inference predictions fix

Dobromir Popov
2025-07-26 23:34:36 +03:00
parent 7c61c12b70
commit 3eb6335169
9 changed files with 1125 additions and 305 deletions

View File

@ -15,6 +15,7 @@ from threading import Lock
from .data_models import BaseDataInput, ModelOutput, create_model_output
from NN.models.enhanced_cnn import EnhancedCNN
from utils.inference_logger import log_model_inference
logger = logging.getLogger(__name__)
@ -339,6 +340,42 @@ class EnhancedCNNAdapter:
metadata=metadata
)
# Log inference with full input data for training feedback
log_model_inference(
model_name=self.model_name,
symbol=base_data.symbol,
action=action,
confidence=confidence,
probabilities={
'BUY': predictions['buy_probability'],
'SELL': predictions['sell_probability'],
'HOLD': predictions['hold_probability']
},
input_features=features.cpu().numpy(), # Store full feature vector
processing_time_ms=inference_duration,
checkpoint_id=None, # Could be enhanced to track checkpoint
metadata={
'base_data_input': {
'symbol': base_data.symbol,
'timestamp': base_data.timestamp.isoformat(),
'ohlcv_1s_count': len(base_data.ohlcv_1s),
'ohlcv_1m_count': len(base_data.ohlcv_1m),
'ohlcv_1h_count': len(base_data.ohlcv_1h),
'ohlcv_1d_count': len(base_data.ohlcv_1d),
'btc_ohlcv_1s_count': len(base_data.btc_ohlcv_1s),
'has_cob_data': base_data.cob_data is not None,
'technical_indicators_count': len(base_data.technical_indicators),
'pivot_points_count': len(base_data.pivot_points),
'last_predictions_count': len(base_data.last_predictions)
},
'model_predictions': {
'pivot_price': pivot_price,
'extrema_prediction': predictions['extrema'],
'price_prediction': predictions['price_prediction']
}
}
)
return model_output
except Exception as e:
@ -401,7 +438,7 @@ class EnhancedCNNAdapter:
def train(self, epochs: int = 1) -> Dict[str, float]:
"""
Train the model with collected data
Train the model with collected data and inference history
Args:
epochs: Number of epochs to train for
@ -415,6 +452,9 @@ class EnhancedCNNAdapter:
training_start = training_start_time.timestamp()
with self.training_lock:
# Get additional training data from inference history
self._load_training_data_from_inference_history()
# Check if we have enough data
if len(self.training_data) < self.batch_size:
logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
@ -583,3 +623,100 @@ class EnhancedCNNAdapter:
except Exception as e:
logger.error(f"Error saving checkpoint: {e}")
def _load_training_data_from_inference_history(self):
"""Load training data from inference history for continuous learning"""
try:
from utils.database_manager import get_database_manager
db_manager = get_database_manager()
# Get recent inference records with input features
inference_records = db_manager.get_inference_records_for_training(
model_name=self.model_name,
hours_back=24, # Last 24 hours
limit=1000
)
if not inference_records:
logger.debug("No inference records found for training")
return
# Convert inference records to training samples
# For now, use a simple approach: treat high-confidence predictions as ground truth
for record in inference_records:
if record.input_features is not None and record.confidence > 0.7:
# Convert action to index
actions = ['BUY', 'SELL', 'HOLD']
if record.action in actions:
action_idx = actions.index(record.action)
# Use confidence as a proxy for reward (high confidence = good prediction)
reward = record.confidence * 2 - 1 # Scale to [-1, 1]
# Convert features to tensor
features_tensor = torch.tensor(record.input_features, dtype=torch.float32, device=self.device)
# Add to training data if not already present (avoid duplicates)
sample_exists = any(
torch.equal(features_tensor, existing[0])
for existing in self.training_data
)
if not sample_exists:
self.training_data.append((features_tensor, action_idx, reward))
logger.info(f"Loaded {len(inference_records)} inference records for training, total training samples: {len(self.training_data)}")
except Exception as e:
logger.error(f"Error loading training data from inference history: {e}")
def evaluate_predictions_against_outcomes(self, hours_back: int = 1) -> Dict[str, float]:
"""
Evaluate past predictions against actual market outcomes
Args:
hours_back: How many hours back to evaluate
Returns:
Dict with evaluation metrics
"""
try:
from utils.database_manager import get_database_manager
db_manager = get_database_manager()
# Get inference records from the specified time period
inference_records = db_manager.get_inference_records_for_training(
model_name=self.model_name,
hours_back=hours_back,
limit=100
)
if not inference_records:
return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
# For now, use a simple evaluation based on confidence
# In a real implementation, this would compare against actual price movements
correct_predictions = 0
total_predictions = len(inference_records)
# Simple heuristic: high confidence predictions are more likely to be correct
for record in inference_records:
if record.confidence > 0.8: # High confidence threshold
correct_predictions += 1
elif record.confidence > 0.6: # Medium confidence
correct_predictions += 0.5
accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
logger.info(f"Prediction evaluation: {correct_predictions:.1f}/{total_predictions} = {accuracy:.3f} accuracy")
return {
'accuracy': accuracy,
'total_predictions': total_predictions,
'correct_predictions': correct_predictions
}
except Exception as e:
logger.error(f"Error evaluating predictions: {e}")
return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
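As the comment above notes, a full evaluation would compare each prediction against the actual price movement rather than confidence alone. A minimal sketch of that idea, assuming a hypothetical get_price_at(symbol, timestamp) lookup, a fixed one-minute horizon, and symbol/timestamp fields on the inference record (none of which are confirmed by the current DatabaseManager API):
from datetime import timedelta
def is_prediction_correct(record, get_price_at, horizon_seconds: int = 60) -> bool:
    """Hypothetical outcome check: did price move in the predicted direction?"""
    entry_price = get_price_at(record.symbol, record.timestamp)
    exit_price = get_price_at(record.symbol, record.timestamp + timedelta(seconds=horizon_seconds))
    if entry_price is None or exit_price is None:
        return False  # cannot evaluate without both prices
    if record.action == 'BUY':
        return exit_price > entry_price
    if record.action == 'SELL':
        return exit_price < entry_price
    # HOLD: count as correct if price stayed roughly flat (0.1% band, arbitrary choice)
    return abs(exit_price - entry_price) / entry_price < 0.001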

View File

@ -268,6 +268,7 @@ class TradingOrchestrator:
# Initialize models, COB integration, and training system
self._initialize_ml_models()
self._initialize_cob_integration()
self._start_cob_integration_sync() # Start COB integration
self._initialize_decision_fusion() # Initialize fusion system
self._initialize_enhanced_training_system() # Initialize real-time training
@ -826,6 +827,31 @@ class TradingOrchestrator:
else:
logger.warning("COB Integration not initialized or start method not available.")
def _start_cob_integration_sync(self):
"""Start COB integration synchronously during initialization"""
if self.cob_integration and hasattr(self.cob_integration, 'start'):
try:
logger.info("Starting COB integration during initialization...")
# If start is async, we need to run it in the event loop
import asyncio
try:
# Try to get current event loop
loop = asyncio.get_event_loop()
if loop.is_running():
# If loop is running, schedule the coroutine
asyncio.create_task(self.cob_integration.start())
else:
# If no loop is running, run it
loop.run_until_complete(self.cob_integration.start())
except RuntimeError:
# No event loop, create one
asyncio.run(self.cob_integration.start())
logger.info("COB Integration started during initialization")
except Exception as e:
logger.warning(f"Failed to start COB integration during initialization: {e}")
else:
logger.debug("COB Integration not available for startup")
def _on_cob_cnn_features(self, symbol: str, cob_data: Dict):
"""Callback for when new COB CNN features are available"""
if not self.realtime_processing:
@ -870,9 +896,37 @@ class TradingOrchestrator:
return
try:
self.latest_cob_data[symbol] = cob_data
# logger.debug(f"COB Dashboard data updated for {symbol}")
# Update data cache with COB data for BaseDataInput
if hasattr(self, 'data_integration') and self.data_integration:
# Convert cob_data to COBData format if needed
from .data_models import COBData
# Create COBData object from the raw cob_data
if 'price_buckets' in cob_data and 'current_price' in cob_data:
cob_data_obj = COBData(
symbol=symbol,
timestamp=datetime.now(),
current_price=cob_data['current_price'],
bucket_size=1.0 if 'ETH' in symbol else 10.0,
price_buckets=cob_data.get('price_buckets', {}),
bid_ask_imbalance=cob_data.get('bid_ask_imbalance', {}),
volume_weighted_prices=cob_data.get('volume_weighted_prices', {}),
order_flow_metrics=cob_data.get('order_flow_metrics', {}),
ma_1s_imbalance=cob_data.get('ma_1s_imbalance', {}),
ma_5s_imbalance=cob_data.get('ma_5s_imbalance', {}),
ma_15s_imbalance=cob_data.get('ma_15s_imbalance', {}),
ma_60s_imbalance=cob_data.get('ma_60s_imbalance', {})
)
# Update cache with COB data
self.data_integration.cache.update('cob_data', symbol, cob_data_obj, 'cob_integration')
logger.debug(f"Updated cache with COB data for {symbol}")
# Update dashboard
if self.dashboard and hasattr(self.dashboard, 'update_cob_data'):
self.dashboard.update_cob_data(symbol, cob_data)
except Exception as e:
logger.error(f"Error in _on_cob_dashboard_data for {symbol}: {e}")
@ -2006,16 +2060,27 @@ class TradingOrchestrator:
try:
result = self.cnn_adapter.predict(base_data)
if result:
# Extract action and probabilities from ModelOutput
action = result.predictions.get('action', 'HOLD')
probabilities = {
'BUY': result.predictions.get('buy_probability', 0.0),
'SELL': result.predictions.get('sell_probability', 0.0),
'HOLD': result.predictions.get('hold_probability', 0.0)
}
prediction = Prediction(
action=result.action,
action=action,
confidence=result.confidence,
probabilities=result.predictions,
probabilities=probabilities,
timeframe="multi", # Multi-timeframe prediction
timestamp=datetime.now(),
model_name="enhanced_cnn",
metadata={
'feature_size': len(base_data.get_feature_vector()),
'data_sources': ['ohlcv_1s', 'ohlcv_1m', 'ohlcv_1h', 'ohlcv_1d', 'btc', 'cob', 'indicators']
'data_sources': ['ohlcv_1s', 'ohlcv_1m', 'ohlcv_1h', 'ohlcv_1d', 'btc', 'cob', 'indicators'],
'pivot_price': result.predictions.get('pivot_price'),
'extrema_prediction': result.predictions.get('extrema'),
'price_prediction': result.predictions.get('price_prediction')
}
)
predictions.append(prediction)
@ -2026,101 +2091,80 @@ class TradingOrchestrator:
except Exception as e:
logger.error(f"Error using CNN adapter: {e}")
# Fallback to legacy CNN prediction if adapter fails
# Fallback to direct model inference using BaseDataInput (unified approach)
if not predictions:
timeframes = getattr(self.config, 'timeframes', ['1m','5m','15m','1h'])
for timeframe in timeframes:
# 1) build or fetch your feature matrix (and optionally augment with COB)…
feature_matrix = self.data_provider.get_feature_matrix(
symbol=symbol,
timeframes=[timeframe],
window_size=getattr(model, 'window_size', 20)
)
if feature_matrix is None:
continue
# …apply COB augmentation here (omitted for brevity)…
enhanced_features = self._augment_with_cob(feature_matrix, symbol)
# 2) Initialize these before we call the model
action_probs, confidence = None, None
# 3) Try the actual model inference
try:
# if your model has an .act() that returns (probs, conf)
if hasattr(model.model, 'act'):
# Flatten / reshape enhanced_features as needed…
x = self._prepare_cnn_input(enhanced_features)
# Debugging: Print the type and content of x before passing to act()
logger.debug(f"CNN input (x) type: {type(x)}, shape: {x.shape}, content sample: {x.flatten()[:5]}...")
action_idx, confidence, action_probs = model.model.act(x, explore=False)
# Debugging: Print the type and content of the unpacked values
logger.debug(f"CNN act() returned: action_idx={action_idx} (type={type(action_idx)}), confidence={confidence} (type={type(confidence)}), action_probs={action_probs[:5]}... (type={type(action_probs)})")
else:
# fallback to generic predict
result = model.predict(enhanced_features)
if isinstance(result, tuple) and len(result)==2:
action_probs, confidence = result
else:
action_probs = result
confidence = 0.7
except Exception as e:
logger.warning(f"CNN inference failed for {symbol}@{timeframe}: {e}")
continue # skip this timeframe entirely
# 4) If we still don't have valid probs, skip
if action_probs is None:
continue
# 5) Build your Prediction
action_names = ['SELL','HOLD','BUY']
best_idx = int(np.argmax(action_probs))
best_action = action_names[best_idx]
pred = Prediction(
action=best_action,
confidence=float(confidence),
probabilities={n: float(p) for n,p in zip(action_names, action_probs)},
timeframe=timeframe,
timestamp=datetime.now(),
model_name=model.name,
metadata={
'feature_shape': str(enhanced_features.shape),
'cob_enhanced': enhanced_features is not feature_matrix
}
)
predictions.append(pred)
# …and capture for the dashboard if you like…
current_price = self._get_current_price(symbol)
if current_price is not None:
predicted_price = current_price * (1 + (0.01 * (confidence if best_action=='BUY' else -confidence if best_action=='SELL' else 0)))
self.capture_cnn_prediction(
symbol,
direction=best_idx,
confidence=confidence,
current_price=current_price,
predicted_price=predicted_price
logger.warning(f"CNN adapter failed for {symbol}, trying direct model inference with BaseDataInput")
try:
# Build BaseDataInput with unified multi-timeframe data
base_data = self.build_base_data_input(symbol)
if not base_data:
logger.warning(f"Cannot build BaseDataInput for CNN fallback: {symbol}")
return predictions
# Convert to unified feature vector (7850 features)
feature_vector = base_data.get_feature_vector()
# Use the model's act method with unified input
if hasattr(model.model, 'act'):
# Convert to tensor format expected by enhanced_cnn
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
features_tensor = torch.tensor(feature_vector, dtype=torch.float32, device=device)
# Call the model's act method
action_idx, confidence, action_probs = model.model.act(features_tensor, explore=False)
# Build prediction with unified timeframe result
action_names = ['BUY', 'SELL', 'HOLD'] # Note: enhanced_cnn uses this order
best_action = action_names[action_idx]
pred = Prediction(
action=best_action,
confidence=float(confidence),
probabilities={
'BUY': float(action_probs[0]),
'SELL': float(action_probs[1]),
'HOLD': float(action_probs[2])
},
timeframe='unified', # Indicates this uses all timeframes
timestamp=datetime.now(),
model_name=model.name,
metadata={
'feature_vector_size': len(feature_vector),
'unified_input': True,
'fallback_method': 'direct_model_inference'
}
)
predictions.append(pred)
# Capture for dashboard
current_price = self._get_current_price(symbol)
if current_price is not None:
predicted_price = current_price * (1 + (0.01 * (confidence if best_action=='BUY' else -confidence if best_action=='SELL' else 0)))
self.capture_cnn_prediction(
symbol,
direction=action_idx,
confidence=confidence,
current_price=current_price,
predicted_price=predicted_price
)
logger.info(f"CNN fallback successful for {symbol}: {best_action} (confidence: {confidence:.3f})")
else:
logger.warning(f"CNN model {model.name} does not have act() method for fallback")
except Exception as e:
logger.error(f"CNN fallback inference failed for {symbol}: {e}")
# Don't continue with old timeframe-by-timeframe approach
except Exception as e:
logger.error(f"Orch: Error getting CNN predictions: {e}")
return predictions
# helper stubs for clarity
def _augment_with_cob(self, feature_matrix, symbol):
# your existing COB augmentation logic…
return feature_matrix
def _prepare_cnn_input(self, features):
arr = features.flatten()
# pad/truncate to 300, reshape to (1,300)
if len(arr) < 300:
arr = np.pad(arr, (0,300-len(arr)), 'constant')
else:
arr = arr[:300]
return arr.reshape(1,-1)
# Note: Removed obsolete _augment_with_cob and _prepare_cnn_input methods
# The unified CNN model now handles all timeframes and COB data internally through BaseDataInput
async def _get_rl_prediction(self, model: RLAgentInterface, symbol: str) -> Optional[Prediction]:
"""Get prediction from RL agent using FIFO queue data"""
try:
@ -2197,59 +2241,63 @@ class TradingOrchestrator:
return None
async def _get_generic_prediction(self, model: ModelInterface, symbol: str) -> Optional[Prediction]:
"""Get prediction from generic model"""
"""Get prediction from generic model using unified BaseDataInput"""
try:
# Safely get timeframes from config
timeframes = getattr(self.config, 'timeframes', None)
if timeframes is None:
timeframes = ['1m', '5m', '15m'] # Default timeframes
# Use unified BaseDataInput approach instead of old timeframe-specific method
base_data = self.build_base_data_input(symbol)
if not base_data:
logger.warning(f"Cannot build BaseDataInput for generic prediction: {symbol}")
return None
# Get feature matrix for the model
feature_matrix = self.data_provider.get_feature_matrix(
symbol=symbol,
timeframes=timeframes[:3], # Use first 3 timeframes
window_size=20
)
# Convert to feature vector for generic models
feature_vector = base_data.get_feature_vector()
if feature_matrix is not None:
prediction_result = model.predict(feature_matrix)
# Handle different return formats from model.predict()
if prediction_result is None:
return None
# Check if it's a tuple (action_probs, confidence)
if isinstance(prediction_result, tuple) and len(prediction_result) == 2:
action_probs, confidence = prediction_result
elif isinstance(prediction_result, dict):
# Handle dictionary return format
action_probs = prediction_result.get('probabilities', None)
confidence = prediction_result.get('confidence', 0.7)
else:
# Assume it's just action probabilities (e.g., a list or numpy array)
action_probs = prediction_result
confidence = 0.7 # Default confidence
if action_probs is not None:
# Ensure action_probs is a numpy array for argmax
if not isinstance(action_probs, np.ndarray):
action_probs = np.array(action_probs)
# For backward compatibility, reshape to matrix format if model expects it
# Most generic models expect a 2D matrix, so reshape the unified vector
feature_matrix = feature_vector.reshape(1, -1) # Shape: (1, 7850)
prediction_result = model.predict(feature_matrix)
# Handle different return formats from model.predict()
if prediction_result is None:
return None
# Check if it's a tuple (action_probs, confidence)
if isinstance(prediction_result, tuple) and len(prediction_result) == 2:
action_probs, confidence = prediction_result
elif isinstance(prediction_result, dict):
# Handle dictionary return format
action_probs = prediction_result.get('probabilities', None)
confidence = prediction_result.get('confidence', 0.7)
else:
# Assume it's just action probabilities (e.g., a list or numpy array)
action_probs = prediction_result
confidence = 0.7 # Default confidence
if action_probs is not None:
# Ensure action_probs is a numpy array for argmax
if not isinstance(action_probs, np.ndarray):
action_probs = np.array(action_probs)
action_names = ['SELL', 'HOLD', 'BUY']
best_action_idx = np.argmax(action_probs)
best_action = action_names[best_action_idx]
prediction = Prediction(
action=best_action,
confidence=float(confidence),
probabilities={name: float(prob) for name, prob in zip(action_names, action_probs)},
timeframe='mixed',
timestamp=datetime.now(),
model_name=model.name,
metadata={'generic_model': True}
)
return prediction
action_names = ['SELL', 'HOLD', 'BUY']
best_action_idx = np.argmax(action_probs)
best_action = action_names[best_action_idx]
prediction = Prediction(
action=best_action,
confidence=float(confidence),
probabilities={name: float(prob) for name, prob in zip(action_names, action_probs)},
timeframe='unified', # Now uses unified multi-timeframe data
timestamp=datetime.now(),
model_name=model.name,
metadata={
'generic_model': True,
'unified_input': True,
'feature_vector_size': len(feature_vector)
}
)
return prediction
return None
@ -2258,45 +2306,29 @@ class TradingOrchestrator:
return None
def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
"""Get current state for RL agent"""
"""Get current state for RL agent - ensure compatibility with saved model"""
try:
# Safely get timeframes from config
timeframes = getattr(self.config, 'timeframes', None)
if timeframes is None:
timeframes = ['1m', '5m', '15m', '1h'] # Default timeframes
# Use unified BaseDataInput approach
base_data = self.build_base_data_input(symbol)
if not base_data:
logger.warning(f"Cannot build BaseDataInput for RL state: {symbol}")
return None
# Get feature matrix for all timeframes
feature_matrix = self.data_provider.get_feature_matrix(
symbol=symbol,
timeframes=timeframes,
window_size=self.config.rl.get('window_size', 20)
)
# Get unified feature vector
feature_vector = base_data.get_feature_vector()
if feature_matrix is not None:
# Flatten the feature matrix for RL agent
# Shape: (n_timeframes, window_size, n_features) -> (n_timeframes * window_size * n_features,)
state = feature_matrix.flatten()
# Add additional state information (position, balance, etc.)
# This would come from a portfolio manager in a real implementation
additional_state = np.array([0.0, 1.0, 0.0]) # [position, balance, unrealized_pnl]
combined_state = np.concatenate([state, additional_state])
# Ensure DQN gets exactly 403 features (expected by the model)
target_size = 403
if len(combined_state) < target_size:
# Pad with zeros
padded_state = np.zeros(target_size)
padded_state[:len(combined_state)] = combined_state
combined_state = padded_state
elif len(combined_state) > target_size:
# Truncate to target size
combined_state = combined_state[:target_size]
return combined_state
return None
# Ensure compatibility with saved model (expects 403 features)
target_size = 403 # Match the saved model's expected input size
if len(feature_vector) < target_size:
# Pad with zeros
padded_state = np.zeros(target_size)
padded_state[:len(feature_vector)] = feature_vector
return padded_state
elif len(feature_vector) > target_size:
# Truncate to target size
return feature_vector[:target_size]
else:
return feature_vector
except Exception as e:
logger.error(f"Error creating RL state for {symbol}: {e}")
@ -3897,7 +3929,7 @@ class TradingOrchestrator:
def build_base_data_input(self, symbol: str) -> Optional[Any]:
"""
Build BaseDataInput using simplified data integration
Build BaseDataInput using simplified data integration (optimized for speed)
Args:
symbol: Trading symbol
@ -3906,71 +3938,9 @@ class TradingOrchestrator:
BaseDataInput with consistent data structure
"""
try:
# Use simplified data integration to build BaseDataInput
# Use simplified data integration to build BaseDataInput (should be instantaneous)
return self.data_integration.build_base_data_input(symbol)
# Verify we have minimum data for all timeframes with fallback strategy
missing_data = []
for data_type, min_count in min_requirements.items():
if not self.ensure_minimum_data(data_type, symbol, min_count):
# Get actual count for better logging
actual_count = 0
if data_type in self.data_queues and symbol in self.data_queues[data_type]:
with self.data_queue_locks[data_type][symbol]:
actual_count = len(self.data_queues[data_type][symbol])
missing_data.append((data_type, actual_count, min_count))
# If we're missing critical 1s data, try to use 1m data as fallback
if missing_data:
critical_missing = [d for d in missing_data if d[0] in ['ohlcv_1s', 'ohlcv_1h']]
if critical_missing:
logger.warning(f"Missing critical data for {symbol}: {critical_missing}")
# Try fallback strategy: use available data with padding
if self._try_fallback_data_strategy(symbol, missing_data):
logger.info(f"Successfully applied fallback data strategy for {symbol}")
else:
for data_type, actual_count, min_count in missing_data:
logger.warning(f"Insufficient {data_type} data for {symbol}: have {actual_count}, need {min_count}")
return None
# Get BTC data (reference symbol)
btc_symbol = 'BTC/USDT'
if not self.ensure_minimum_data('ohlcv_1s', btc_symbol, 100):
# Get actual BTC data count for logging
btc_count = 0
if 'ohlcv_1s' in self.data_queues and btc_symbol in self.data_queues['ohlcv_1s']:
with self.data_queue_locks['ohlcv_1s'][btc_symbol]:
btc_count = len(self.data_queues['ohlcv_1s'][btc_symbol])
logger.warning(f"Insufficient BTC data for reference: have {btc_count}, need 100, using ETH data as fallback")
# Use ETH data as fallback
btc_data = self.get_queue_data('ohlcv_1s', symbol, 300)
else:
btc_data = self.get_queue_data('ohlcv_1s', btc_symbol, 300)
# Build BaseDataInput with queue data
base_data = BaseDataInput(
symbol=symbol,
timestamp=datetime.now(),
ohlcv_1s=self.get_queue_data('ohlcv_1s', symbol, 300),
ohlcv_1m=self.get_queue_data('ohlcv_1m', symbol, 300),
ohlcv_1h=self.get_queue_data('ohlcv_1h', symbol, 300),
ohlcv_1d=self.get_queue_data('ohlcv_1d', symbol, 300),
btc_ohlcv_1s=btc_data,
technical_indicators=self._get_latest_indicators(symbol),
cob_data=self._get_latest_cob_data(symbol),
last_predictions=self._get_recent_model_predictions(symbol)
)
# Validate the data
if not base_data.validate():
logger.warning(f"BaseDataInput validation failed for {symbol}")
return None
return base_data
except Exception as e:
logger.error(f"Error building BaseDataInput for {symbol}: {e}")
return None

View File

@ -6,7 +6,8 @@ Integrates with SmartDataUpdater for efficient data management.
"""
import logging
from datetime import datetime
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
import pandas as pd
@ -29,6 +30,11 @@ class SimplifiedDataIntegration:
# Initialize smart data updater
self.data_updater = SmartDataUpdater(data_provider, symbols)
# Pre-built OHLCV data cache for instant access
self._ohlcv_cache = {} # {symbol: {timeframe: List[OHLCVBar]}}
self._ohlcv_cache_lock = threading.RLock()
self._last_cache_update = {} # {symbol: {timeframe: datetime}}
# Register for tick data if available
self._setup_tick_integration()
@ -61,6 +67,8 @@ class SimplifiedDataIntegration:
def _on_tick_data(self, symbol: str, price: float, volume: float, timestamp: datetime = None):
"""Handle incoming tick data"""
self.data_updater.add_tick(symbol, price, volume, timestamp)
# Invalidate OHLCV cache for this symbol
self._invalidate_ohlcv_cache(symbol)
def _on_websocket_data(self, symbol: str, data: Dict[str, Any]):
"""Handle WebSocket data updates"""
@ -68,12 +76,28 @@ class SimplifiedDataIntegration:
# Extract price and volume from WebSocket data
if 'price' in data and 'volume' in data:
self.data_updater.add_tick(symbol, data['price'], data['volume'])
# Invalidate OHLCV cache for this symbol
self._invalidate_ohlcv_cache(symbol)
except Exception as e:
logger.error(f"Error processing WebSocket data: {e}")
def _invalidate_ohlcv_cache(self, symbol: str):
"""Invalidate OHLCV cache for a symbol when new data arrives"""
try:
with self._ohlcv_cache_lock:
# Remove cached data for all timeframes of this symbol
keys_to_remove = [key for key in self._ohlcv_cache.keys() if key.startswith(f"{symbol}_")]
for key in keys_to_remove:
if key in self._ohlcv_cache:
del self._ohlcv_cache[key]
if key in self._last_cache_update:
del self._last_cache_update[key]
except Exception as e:
logger.error(f"Error invalidating OHLCV cache for {symbol}: {e}")
def build_base_data_input(self, symbol: str) -> Optional[BaseDataInput]:
"""
Build BaseDataInput from cached data (much simpler than FIFO queues)
Build BaseDataInput from cached data (optimized for speed)
Args:
symbol: Trading symbol
@ -82,22 +106,7 @@ class SimplifiedDataIntegration:
BaseDataInput with consistent data structure
"""
try:
# Check if we have minimum required data
required_timeframes = ['1s', '1m', '1h', '1d']
missing_timeframes = []
for timeframe in required_timeframes:
if not self.cache.has_data(f'ohlcv_{timeframe}', symbol, max_age_seconds=300):
missing_timeframes.append(timeframe)
if missing_timeframes:
logger.warning(f"Missing data for {symbol}: {missing_timeframes}")
# Try to use historical data as fallback
if not self._try_historical_fallback(symbol, missing_timeframes):
return None
# Get current OHLCV data
# Get OHLCV data directly from optimized cache (no validation checks for speed)
ohlcv_1s_list = self._get_ohlcv_data_list(symbol, '1s', 300)
ohlcv_1m_list = self._get_ohlcv_data_list(symbol, '1m', 300)
ohlcv_1h_list = self._get_ohlcv_data_list(symbol, '1h', 300)
@ -109,18 +118,13 @@ class SimplifiedDataIntegration:
if not btc_ohlcv_1s_list:
# Use ETH data as fallback
btc_ohlcv_1s_list = ohlcv_1s_list
logger.debug(f"Using {symbol} data as BTC fallback")
# Get technical indicators
# Get cached data (fast lookups)
technical_indicators = self.cache.get('technical_indicators', symbol) or {}
# Get COB data if available
cob_data = self.cache.get('cob_data', symbol)
# Get recent model predictions
last_predictions = self._get_recent_predictions(symbol)
# Build BaseDataInput
# Build BaseDataInput (no validation for speed - assume data is good)
base_data = BaseDataInput(
symbol=symbol,
timestamp=datetime.now(),
@ -134,11 +138,6 @@ class SimplifiedDataIntegration:
last_predictions=last_predictions
)
# Validate the data
if not base_data.validate():
logger.warning(f"BaseDataInput validation failed for {symbol}")
return None
return base_data
except Exception as e:
@ -146,11 +145,39 @@ class SimplifiedDataIntegration:
return None
def _get_ohlcv_data_list(self, symbol: str, timeframe: str, max_count: int) -> List[OHLCVBar]:
"""Get OHLCV data list from cache and historical data"""
"""Get OHLCV data list from pre-built cache for instant access"""
try:
with self._ohlcv_cache_lock:
cache_key = f"{symbol}_{timeframe}"
# Check if we have fresh cached data (updated within last 5 seconds)
last_update = self._last_cache_update.get(cache_key)
if (last_update and
(datetime.now() - last_update).total_seconds() < 5 and
cache_key in self._ohlcv_cache):
cached_data = self._ohlcv_cache[cache_key]
return cached_data[-max_count:] if len(cached_data) >= max_count else cached_data
# Need to rebuild cache for this symbol/timeframe
data_list = self._build_ohlcv_cache(symbol, timeframe, max_count)
# Cache the result
self._ohlcv_cache[cache_key] = data_list
self._last_cache_update[cache_key] = datetime.now()
return data_list[-max_count:] if len(data_list) >= max_count else data_list
except Exception as e:
logger.error(f"Error getting OHLCV data list for {symbol}/{timeframe}: {e}")
return self._create_dummy_data_list(symbol, timeframe, max_count)
def _build_ohlcv_cache(self, symbol: str, timeframe: str, max_count: int) -> List[OHLCVBar]:
"""Build OHLCV cache from historical and current data"""
try:
data_list = []
# Get historical data first
# Get historical data first (this should be fast as it's already cached)
historical_df = self.cache.get_historical_data(symbol, timeframe)
if historical_df is not None and not historical_df.empty:
# Convert historical data to OHLCVBar objects
@ -174,34 +201,14 @@ class SimplifiedDataIntegration:
# Ensure we have the right amount of data (pad if necessary)
while len(data_list) < max_count:
# Pad with the last available data or create dummy data
if data_list:
last_bar = data_list[-1]
dummy_bar = OHLCVBar(
symbol=symbol,
timestamp=last_bar.timestamp,
open=last_bar.close,
high=last_bar.close,
low=last_bar.close,
close=last_bar.close,
volume=0.0,
timeframe=timeframe
)
else:
# Create completely dummy data
dummy_bar = OHLCVBar(
symbol=symbol,
timestamp=datetime.now(),
open=0.0, high=0.0, low=0.0, close=0.0, volume=0.0,
timeframe=timeframe
)
data_list.append(dummy_bar)
data_list.extend(self._create_dummy_data_list(symbol, timeframe, max_count - len(data_list)))
return data_list[-max_count:] # Return last max_count items
return data_list
except Exception as e:
logger.error(f"Error getting OHLCV data list for {symbol} {timeframe}: {e}")
return []
logger.error(f"Error building OHLCV cache for {symbol}/{timeframe}: {e}")
return self._create_dummy_data_list(symbol, timeframe, max_count)
def _try_historical_fallback(self, symbol: str, missing_timeframes: List[str]) -> bool:
"""Try to use historical data for missing timeframes"""