Inference predictions fix
@@ -15,6 +15,7 @@ from threading import Lock
from .data_models import BaseDataInput, ModelOutput, create_model_output
from NN.models.enhanced_cnn import EnhancedCNN
from utils.inference_logger import log_model_inference

logger = logging.getLogger(__name__)
@@ -339,6 +340,42 @@ class EnhancedCNNAdapter:
                metadata=metadata
            )

            # Log inference with full input data for training feedback
            log_model_inference(
                model_name=self.model_name,
                symbol=base_data.symbol,
                action=action,
                confidence=confidence,
                probabilities={
                    'BUY': predictions['buy_probability'],
                    'SELL': predictions['sell_probability'],
                    'HOLD': predictions['hold_probability']
                },
                input_features=features.cpu().numpy(),  # Store full feature vector
                processing_time_ms=inference_duration,
                checkpoint_id=None,  # Could be enhanced to track checkpoint
                metadata={
                    'base_data_input': {
                        'symbol': base_data.symbol,
                        'timestamp': base_data.timestamp.isoformat(),
                        'ohlcv_1s_count': len(base_data.ohlcv_1s),
                        'ohlcv_1m_count': len(base_data.ohlcv_1m),
                        'ohlcv_1h_count': len(base_data.ohlcv_1h),
                        'ohlcv_1d_count': len(base_data.ohlcv_1d),
                        'btc_ohlcv_1s_count': len(base_data.btc_ohlcv_1s),
                        'has_cob_data': base_data.cob_data is not None,
                        'technical_indicators_count': len(base_data.technical_indicators),
                        'pivot_points_count': len(base_data.pivot_points),
                        'last_predictions_count': len(base_data.last_predictions)
                    },
                    'model_predictions': {
                        'pivot_price': pivot_price,
                        'extrema_prediction': predictions['extrema'],
                        'price_prediction': predictions['price_prediction']
                    }
                }
            )

            return model_output

        except Exception as e:
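Note: logging the full feature vector here is what turns each inference into a replayable training sample later in this commit. A minimal sketch of that feedback loop, with a hypothetical InferenceRecord container and in-memory store standing in for whatever utils.inference_logger and the database manager actually persist:

# Illustrative sketch only; InferenceRecord and the storage list are assumptions,
# not the actual utils.inference_logger implementation.
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List

import numpy as np

@dataclass
class InferenceRecord:
    model_name: str
    symbol: str
    action: str                   # 'BUY' | 'SELL' | 'HOLD'
    confidence: float
    input_features: np.ndarray    # full feature vector, replayable at training time
    timestamp: datetime = field(default_factory=datetime.now)
    metadata: Dict = field(default_factory=dict)

_records: List[InferenceRecord] = []   # stand-in for the inference database

def log_inference(record: InferenceRecord) -> None:
    """Persist one inference so it can later be turned into a training sample."""
    _records.append(record)

def replay_for_training(min_confidence: float = 0.7) -> List[InferenceRecord]:
    """Fetch high-confidence records back for the training loop."""
    return [r for r in _records if r.confidence >= min_confidence]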
@@ -401,7 +438,7 @@ class EnhancedCNNAdapter:

    def train(self, epochs: int = 1) -> Dict[str, float]:
        """
        Train the model with collected data
        Train the model with collected data and inference history

        Args:
            epochs: Number of epochs to train for
@@ -415,6 +452,9 @@ class EnhancedCNNAdapter:
        training_start = training_start_time.timestamp()

        with self.training_lock:
            # Get additional training data from inference history
            self._load_training_data_from_inference_history()

            # Check if we have enough data
            if len(self.training_data) < self.batch_size:
                logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
@@ -583,3 +623,100 @@ class EnhancedCNNAdapter:
        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")

    def _load_training_data_from_inference_history(self):
        """Load training data from inference history for continuous learning"""
        try:
            from utils.database_manager import get_database_manager

            db_manager = get_database_manager()

            # Get recent inference records with input features
            inference_records = db_manager.get_inference_records_for_training(
                model_name=self.model_name,
                hours_back=24,  # Last 24 hours
                limit=1000
            )

            if not inference_records:
                logger.debug("No inference records found for training")
                return

            # Convert inference records to training samples
            # For now, use a simple approach: treat high-confidence predictions as ground truth
            for record in inference_records:
                if record.input_features is not None and record.confidence > 0.7:
                    # Convert action to index
                    actions = ['BUY', 'SELL', 'HOLD']
                    if record.action in actions:
                        action_idx = actions.index(record.action)

                        # Use confidence as a proxy for reward (high confidence = good prediction)
                        reward = record.confidence * 2 - 1  # Scale to [-1, 1]

                        # Convert features to tensor
                        features_tensor = torch.tensor(record.input_features, dtype=torch.float32, device=self.device)

                        # Add to training data if not already present (avoid duplicates)
                        sample_exists = any(
                            torch.equal(features_tensor, existing[0])
                            for existing in self.training_data
                        )

                        if not sample_exists:
                            self.training_data.append((features_tensor, action_idx, reward))

            logger.info(f"Loaded {len(inference_records)} inference records for training, total training samples: {len(self.training_data)}")

        except Exception as e:
            logger.error(f"Error loading training data from inference history: {e}")
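Note: the confidence-to-reward mapping above is a plain linear rescale from [0, 1] to [-1, 1], paired with an exact-match duplicate check on the feature tensor. A self-contained sketch of those two steps, written as standalone helpers rather than the adapter's actual methods:

# Sketch under assumptions: standalone helpers, not the adapter's real API.
import torch

def confidence_to_reward(confidence: float) -> float:
    """Linear rescale: confidence 0.0 -> reward -1.0, 0.5 -> 0.0, 1.0 -> +1.0."""
    return confidence * 2 - 1

def add_sample(training_data, features: torch.Tensor, action_idx: int, confidence: float) -> bool:
    """Append (features, action, reward) unless an identical feature vector is already stored."""
    if any(torch.equal(features, existing[0]) for existing in training_data):
        return False
    training_data.append((features, action_idx, confidence_to_reward(confidence)))
    return True

samples = []
add_sample(samples, torch.zeros(10), action_idx=0, confidence=0.85)   # stored with reward 0.70
add_sample(samples, torch.zeros(10), action_idx=0, confidence=0.85)   # duplicate features, skipped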
    def evaluate_predictions_against_outcomes(self, hours_back: int = 1) -> Dict[str, float]:
        """
        Evaluate past predictions against actual market outcomes

        Args:
            hours_back: How many hours back to evaluate

        Returns:
            Dict with evaluation metrics
        """
        try:
            from utils.database_manager import get_database_manager

            db_manager = get_database_manager()

            # Get inference records from the specified time period
            inference_records = db_manager.get_inference_records_for_training(
                model_name=self.model_name,
                hours_back=hours_back,
                limit=100
            )

            if not inference_records:
                return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}

            # For now, use a simple evaluation based on confidence
            # In a real implementation, this would compare against actual price movements
            correct_predictions = 0
            total_predictions = len(inference_records)

            # Simple heuristic: high confidence predictions are more likely to be correct
            for record in inference_records:
                if record.confidence > 0.8:  # High confidence threshold
                    correct_predictions += 1
                elif record.confidence > 0.6:  # Medium confidence
                    correct_predictions += 0.5

            accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

            logger.info(f"Prediction evaluation: {correct_predictions:.1f}/{total_predictions} = {accuracy:.3f} accuracy")

            return {
                'accuracy': accuracy,
                'total_predictions': total_predictions,
                'correct_predictions': correct_predictions
            }

        except Exception as e:
            logger.error(f"Error evaluating predictions: {e}")
            return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
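Note: the evaluation above scores records by confidence alone, as the in-code comment admits. A rough sketch of what the comment's "real implementation" could look like, comparing the predicted direction against the realized move over a short horizon; get_price() and the record fields are assumptions, not existing APIs:

# Sketch only: get_price(symbol, at) is a hypothetical price lookup.
from datetime import timedelta

def was_correct(record, get_price, horizon_minutes: int = 5, hold_band: float = 0.001) -> bool:
    """Compare the predicted direction against the realized price move over a short horizon."""
    p0 = get_price(record.symbol, record.timestamp)
    p1 = get_price(record.symbol, record.timestamp + timedelta(minutes=horizon_minutes))
    if not p0 or not p1:
        return False
    change = (p1 - p0) / p0
    if record.action == 'BUY':
        return change > 0
    if record.action == 'SELL':
        return change < 0
    return abs(change) <= hold_band   # HOLD counts as correct if price stayed roughly flat

# accuracy = sum(was_correct(r, get_price) for r in records) / len(records)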
@@ -268,6 +268,7 @@ class TradingOrchestrator:
        # Initialize models, COB integration, and training system
        self._initialize_ml_models()
        self._initialize_cob_integration()
        self._start_cob_integration_sync()  # Start COB integration
        self._initialize_decision_fusion()  # Initialize fusion system
        self._initialize_enhanced_training_system()  # Initialize real-time training
@@ -826,6 +827,31 @@ class TradingOrchestrator:
        else:
            logger.warning("COB Integration not initialized or start method not available.")

    def _start_cob_integration_sync(self):
        """Start COB integration synchronously during initialization"""
        if self.cob_integration and hasattr(self.cob_integration, 'start'):
            try:
                logger.info("Starting COB integration during initialization...")
                # If start is async, we need to run it in the event loop
                import asyncio
                try:
                    # Try to get current event loop
                    loop = asyncio.get_event_loop()
                    if loop.is_running():
                        # If loop is running, schedule the coroutine
                        asyncio.create_task(self.cob_integration.start())
                    else:
                        # If no loop is running, run it
                        loop.run_until_complete(self.cob_integration.start())
                except RuntimeError:
                    # No event loop, create one
                    asyncio.run(self.cob_integration.start())
                logger.info("COB Integration started during initialization")
            except Exception as e:
                logger.warning(f"Failed to start COB integration during initialization: {e}")
        else:
            logger.debug("COB Integration not available for startup")
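Note: the wrapper above has to cope with a loop that is already running, a loop that exists but is idle, and no loop at all. A minimal standalone sketch of the same idea, written here with asyncio.get_running_loop() instead of the deprecated get_event_loop() path; the start() coroutine is a dummy stand-in, not the real COB integration:

# Sketch of the "run a coroutine from sync code" pattern.
import asyncio

async def start():
    await asyncio.sleep(0)   # placeholder for real async start-up work

def start_sync():
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None
    if loop is not None:
        # Already inside a running loop: schedule the coroutine and return immediately.
        loop.create_task(start())
    else:
        # No running loop in this thread: create one just for this call.
        asyncio.run(start())

start_sync()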
    def _on_cob_cnn_features(self, symbol: str, cob_data: Dict):
        """Callback for when new COB CNN features are available"""
        if not self.realtime_processing:
@@ -870,9 +896,37 @@ class TradingOrchestrator:
            return
        try:
            self.latest_cob_data[symbol] = cob_data
            # logger.debug(f"COB Dashboard data updated for {symbol}")

            # Update data cache with COB data for BaseDataInput
            if hasattr(self, 'data_integration') and self.data_integration:
                # Convert cob_data to COBData format if needed
                from .data_models import COBData

                # Create COBData object from the raw cob_data
                if 'price_buckets' in cob_data and 'current_price' in cob_data:
                    cob_data_obj = COBData(
                        symbol=symbol,
                        timestamp=datetime.now(),
                        current_price=cob_data['current_price'],
                        bucket_size=1.0 if 'ETH' in symbol else 10.0,
                        price_buckets=cob_data.get('price_buckets', {}),
                        bid_ask_imbalance=cob_data.get('bid_ask_imbalance', {}),
                        volume_weighted_prices=cob_data.get('volume_weighted_prices', {}),
                        order_flow_metrics=cob_data.get('order_flow_metrics', {}),
                        ma_1s_imbalance=cob_data.get('ma_1s_imbalance', {}),
                        ma_5s_imbalance=cob_data.get('ma_5s_imbalance', {}),
                        ma_15s_imbalance=cob_data.get('ma_15s_imbalance', {}),
                        ma_60s_imbalance=cob_data.get('ma_60s_imbalance', {})
                    )

                    # Update cache with COB data
                    self.data_integration.cache.update('cob_data', symbol, cob_data_obj, 'cob_integration')
                    logger.debug(f"Updated cache with COB data for {symbol}")

            # Update dashboard
            if self.dashboard and hasattr(self.dashboard, 'update_cob_data'):
                self.dashboard.update_cob_data(symbol, cob_data)

        except Exception as e:
            logger.error(f"Error in _on_cob_dashboard_data for {symbol}: {e}")
@@ -2006,16 +2060,27 @@ class TradingOrchestrator:
            try:
                result = self.cnn_adapter.predict(base_data)
                if result:
                    # Extract action and probabilities from ModelOutput
                    action = result.predictions.get('action', 'HOLD')
                    probabilities = {
                        'BUY': result.predictions.get('buy_probability', 0.0),
                        'SELL': result.predictions.get('sell_probability', 0.0),
                        'HOLD': result.predictions.get('hold_probability', 0.0)
                    }

                    prediction = Prediction(
                        action=result.action,
                        action=action,
                        confidence=result.confidence,
                        probabilities=result.predictions,
                        probabilities=probabilities,
                        timeframe="multi",  # Multi-timeframe prediction
                        timestamp=datetime.now(),
                        model_name="enhanced_cnn",
                        metadata={
                            'feature_size': len(base_data.get_feature_vector()),
                            'data_sources': ['ohlcv_1s', 'ohlcv_1m', 'ohlcv_1h', 'ohlcv_1d', 'btc', 'cob', 'indicators']
                            'data_sources': ['ohlcv_1s', 'ohlcv_1m', 'ohlcv_1h', 'ohlcv_1d', 'btc', 'cob', 'indicators'],
                            'pivot_price': result.predictions.get('pivot_price'),
                            'extrema_prediction': result.predictions.get('extrema'),
                            'price_prediction': result.predictions.get('price_prediction')
                        }
                    )
                    predictions.append(prediction)
@@ -2026,101 +2091,80 @@ class TradingOrchestrator:
        except Exception as e:
            logger.error(f"Error using CNN adapter: {e}")

        # Fallback to legacy CNN prediction if adapter fails
        # Fallback to direct model inference using BaseDataInput (unified approach)
        if not predictions:
            timeframes = getattr(self.config, 'timeframes', ['1m','5m','15m','1h'])
            for timeframe in timeframes:
                # 1) build or fetch your feature matrix (and optionally augment with COB)…
                feature_matrix = self.data_provider.get_feature_matrix(
                    symbol=symbol,
                    timeframes=[timeframe],
                    window_size=getattr(model, 'window_size', 20)
                )
                if feature_matrix is None:
                    continue

                # …apply COB-augmentation here (omitted for brevity)
                enhanced_features = self._augment_with_cob(feature_matrix, symbol)

                # 2) Initialize these before we call the model
                action_probs, confidence = None, None

                # 3) Try the actual model inference
                try:
                    # if your model has an .act() that returns (probs, conf)
                    if hasattr(model.model, 'act'):
                        # Flatten / reshape enhanced_features as needed…
                        x = self._prepare_cnn_input(enhanced_features)

                        # Debugging: print the type and content of x before passing to act()
                        logger.debug(f"CNN input (x) type: {type(x)}, shape: {x.shape}, content sample: {x.flatten()[:5]}...")

                        action_idx, confidence, action_probs = model.model.act(x, explore=False)

                        # Debugging: print the type and content of the unpacked values
                        logger.debug(f"CNN act() returned: action_idx={action_idx} (type={type(action_idx)}), confidence={confidence} (type={type(confidence)}), action_probs={action_probs[:5]}... (type={type(action_probs)})")
                    else:
                        # fallback to generic predict
                        result = model.predict(enhanced_features)
                        if isinstance(result, tuple) and len(result) == 2:
                            action_probs, confidence = result
                        else:
                            action_probs = result
                            confidence = 0.7
                except Exception as e:
                    logger.warning(f"CNN inference failed for {symbol}@{timeframe}: {e}")
                    continue  # skip this timeframe entirely

                # 4) If we still don't have valid probs, skip
                if action_probs is None:
                    continue

                # 5) Build your Prediction
                action_names = ['SELL', 'HOLD', 'BUY']
                best_idx = int(np.argmax(action_probs))
                best_action = action_names[best_idx]
                pred = Prediction(
                    action=best_action,
                    confidence=float(confidence),
                    probabilities={n: float(p) for n, p in zip(action_names, action_probs)},
                    timeframe=timeframe,
                    timestamp=datetime.now(),
                    model_name=model.name,
                    metadata={
                        'feature_shape': str(enhanced_features.shape),
                        'cob_enhanced': enhanced_features is not feature_matrix
                    }
                )
                predictions.append(pred)

                # …and capture for the dashboard if you like…
                current_price = self._get_current_price(symbol)
                if current_price is not None:
                    predicted_price = current_price * (1 + (0.01 * (confidence if best_action=='BUY' else -confidence if best_action=='SELL' else 0)))
                    self.capture_cnn_prediction(
                        symbol,
                        direction=best_idx,
                        confidence=confidence,
                        current_price=current_price,
                        predicted_price=predicted_price
            logger.warning(f"CNN adapter failed for {symbol}, trying direct model inference with BaseDataInput")

            try:
                # Build BaseDataInput with unified multi-timeframe data
                base_data = self.build_base_data_input(symbol)
                if not base_data:
                    logger.warning(f"Cannot build BaseDataInput for CNN fallback: {symbol}")
                    return predictions

                # Convert to unified feature vector (7850 features)
                feature_vector = base_data.get_feature_vector()

                # Use the model's act method with unified input
                if hasattr(model.model, 'act'):
                    # Convert to tensor format expected by enhanced_cnn
                    import torch
                    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                    features_tensor = torch.tensor(feature_vector, dtype=torch.float32, device=device)

                    # Call the model's act method
                    action_idx, confidence, action_probs = model.model.act(features_tensor, explore=False)

                    # Build prediction with unified timeframe result
                    action_names = ['BUY', 'SELL', 'HOLD']  # Note: enhanced_cnn uses this order
                    best_action = action_names[action_idx]

                    pred = Prediction(
                        action=best_action,
                        confidence=float(confidence),
                        probabilities={
                            'BUY': float(action_probs[0]),
                            'SELL': float(action_probs[1]),
                            'HOLD': float(action_probs[2])
                        },
                        timeframe='unified',  # Indicates this uses all timeframes
                        timestamp=datetime.now(),
                        model_name=model.name,
                        metadata={
                            'feature_vector_size': len(feature_vector),
                            'unified_input': True,
                            'fallback_method': 'direct_model_inference'
                        }
                    )
                    predictions.append(pred)

                    # Capture for dashboard
                    current_price = self._get_current_price(symbol)
                    if current_price is not None:
                        predicted_price = current_price * (1 + (0.01 * (confidence if best_action=='BUY' else -confidence if best_action=='SELL' else 0)))
                        self.capture_cnn_prediction(
                            symbol,
                            direction=action_idx,
                            confidence=confidence,
                            current_price=current_price,
                            predicted_price=predicted_price
                        )

                    logger.info(f"CNN fallback successful for {symbol}: {best_action} (confidence: {confidence:.3f})")

                else:
                    logger.warning(f"CNN model {model.name} does not have act() method for fallback")

            except Exception as e:
                logger.error(f"CNN fallback inference failed for {symbol}: {e}")
                # Don't continue with old timeframe-by-timeframe approach
        except Exception as e:
            logger.error(f"Orch: Error getting CNN predictions: {e}")
        return predictions

    # helper stubs for clarity
    def _augment_with_cob(self, feature_matrix, symbol):
        # your existing cob-augmentation logic…
        return feature_matrix

    def _prepare_cnn_input(self, features):
        arr = features.flatten()
        # pad/truncate to 300, reshape to (1,300)
        if len(arr) < 300:
            arr = np.pad(arr, (0, 300 - len(arr)), 'constant')
        else:
            arr = arr[:300]
        return arr.reshape(1, -1)
    # Note: Removed obsolete _augment_with_cob and _prepare_cnn_input methods
    # The unified CNN model now handles all timeframes and COB data internally through BaseDataInput

    async def _get_rl_prediction(self, model: RLAgentInterface, symbol: str) -> Optional[Prediction]:
        """Get prediction from RL agent using FIFO queue data"""
        try:
@@ -2197,59 +2241,63 @@ class TradingOrchestrator:
            return None

    async def _get_generic_prediction(self, model: ModelInterface, symbol: str) -> Optional[Prediction]:
        """Get prediction from generic model"""
        """Get prediction from generic model using unified BaseDataInput"""
        try:
            # Safely get timeframes from config
            timeframes = getattr(self.config, 'timeframes', None)
            if timeframes is None:
                timeframes = ['1m', '5m', '15m']  # Default timeframes
            # Use unified BaseDataInput approach instead of old timeframe-specific method
            base_data = self.build_base_data_input(symbol)
            if not base_data:
                logger.warning(f"Cannot build BaseDataInput for generic prediction: {symbol}")
                return None

            # Get feature matrix for the model
            feature_matrix = self.data_provider.get_feature_matrix(
                symbol=symbol,
                timeframes=timeframes[:3],  # Use first 3 timeframes
                window_size=20
            )
            # Convert to feature vector for generic models
            feature_vector = base_data.get_feature_vector()

            if feature_matrix is not None:
                prediction_result = model.predict(feature_matrix)

                # Handle different return formats from model.predict()
                if prediction_result is None:
                    return None

                # Check if it's a tuple (action_probs, confidence)
                if isinstance(prediction_result, tuple) and len(prediction_result) == 2:
                    action_probs, confidence = prediction_result
                elif isinstance(prediction_result, dict):
                    # Handle dictionary return format
                    action_probs = prediction_result.get('probabilities', None)
                    confidence = prediction_result.get('confidence', 0.7)
                else:
                    # Assume it's just action probabilities (e.g., a list or numpy array)
                    action_probs = prediction_result
                    confidence = 0.7  # Default confidence

                if action_probs is not None:
                    # Ensure action_probs is a numpy array for argmax
                    if not isinstance(action_probs, np.ndarray):
                        action_probs = np.array(action_probs)
            # For backward compatibility, reshape to matrix format if model expects it
            # Most generic models expect a 2D matrix, so reshape the unified vector
            feature_matrix = feature_vector.reshape(1, -1)  # Shape: (1, 7850)

            prediction_result = model.predict(feature_matrix)

            # Handle different return formats from model.predict()
            if prediction_result is None:
                return None

            # Check if it's a tuple (action_probs, confidence)
            if isinstance(prediction_result, tuple) and len(prediction_result) == 2:
                action_probs, confidence = prediction_result
            elif isinstance(prediction_result, dict):
                # Handle dictionary return format
                action_probs = prediction_result.get('probabilities', None)
                confidence = prediction_result.get('confidence', 0.7)
            else:
                # Assume it's just action probabilities (e.g., a list or numpy array)
                action_probs = prediction_result
                confidence = 0.7  # Default confidence

            if action_probs is not None:
                # Ensure action_probs is a numpy array for argmax
                if not isinstance(action_probs, np.ndarray):
                    action_probs = np.array(action_probs)

                action_names = ['SELL', 'HOLD', 'BUY']
                best_action_idx = np.argmax(action_probs)
                best_action = action_names[best_action_idx]

                prediction = Prediction(
                    action=best_action,
                    confidence=float(confidence),
                    probabilities={name: float(prob) for name, prob in zip(action_names, action_probs)},
                    timeframe='mixed',
                    timestamp=datetime.now(),
                    model_name=model.name,
                    metadata={'generic_model': True}
                )

                return prediction
                action_names = ['SELL', 'HOLD', 'BUY']
                best_action_idx = np.argmax(action_probs)
                best_action = action_names[best_action_idx]

                prediction = Prediction(
                    action=best_action,
                    confidence=float(confidence),
                    probabilities={name: float(prob) for name, prob in zip(action_names, action_probs)},
                    timeframe='unified',  # Now uses unified multi-timeframe data
                    timestamp=datetime.now(),
                    model_name=model.name,
                    metadata={
                        'generic_model': True,
                        'unified_input': True,
                        'feature_vector_size': len(feature_vector)
                    }
                )

                return prediction

            return None
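Note: the branching above exists because different model wrappers return predictions in different shapes (a (probs, confidence) tuple, a dict, or a bare probability array). A compact sketch of the same normalization as a standalone helper; the 0.7 default confidence mirrors the code above, everything else is illustrative:

# Sketch: normalize a predict() result into (probs: np.ndarray, confidence: float).
from typing import Any, Optional, Tuple

import numpy as np

def normalize_prediction(result: Any, default_confidence: float = 0.7) -> Optional[Tuple[np.ndarray, float]]:
    if result is None:
        return None
    if isinstance(result, tuple) and len(result) == 2:           # (probs, confidence)
        probs, confidence = result
    elif isinstance(result, dict):                               # {'probabilities': ..., 'confidence': ...}
        probs = result.get('probabilities')
        confidence = result.get('confidence', default_confidence)
    else:                                                        # bare list / ndarray of probabilities
        probs, confidence = result, default_confidence
    if probs is None:
        return None
    return np.asarray(probs, dtype=float), float(confidence)

probs, conf = normalize_prediction([0.2, 0.3, 0.5])
print(probs.argmax(), conf)   # -> 2 0.7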
@@ -2258,45 +2306,29 @@ class TradingOrchestrator:
            return None

    def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
        """Get current state for RL agent"""
        """Get current state for RL agent - ensure compatibility with saved model"""
        try:
            # Safely get timeframes from config
            timeframes = getattr(self.config, 'timeframes', None)
            if timeframes is None:
                timeframes = ['1m', '5m', '15m', '1h']  # Default timeframes
            # Use unified BaseDataInput approach
            base_data = self.build_base_data_input(symbol)
            if not base_data:
                logger.warning(f"Cannot build BaseDataInput for RL state: {symbol}")
                return None

            # Get feature matrix for all timeframes
            feature_matrix = self.data_provider.get_feature_matrix(
                symbol=symbol,
                timeframes=timeframes,
                window_size=self.config.rl.get('window_size', 20)
            )
            # Get unified feature vector
            feature_vector = base_data.get_feature_vector()

            if feature_matrix is not None:
                # Flatten the feature matrix for RL agent
                # Shape: (n_timeframes, window_size, n_features) -> (n_timeframes * window_size * n_features,)
                state = feature_matrix.flatten()

                # Add additional state information (position, balance, etc.)
                # This would come from a portfolio manager in a real implementation
                additional_state = np.array([0.0, 1.0, 0.0])  # [position, balance, unrealized_pnl]

                combined_state = np.concatenate([state, additional_state])

                # Ensure DQN gets exactly 403 features (expected by the model)
                target_size = 403
                if len(combined_state) < target_size:
                    # Pad with zeros
                    padded_state = np.zeros(target_size)
                    padded_state[:len(combined_state)] = combined_state
                    combined_state = padded_state
                elif len(combined_state) > target_size:
                    # Truncate to target size
                    combined_state = combined_state[:target_size]

                return combined_state

            return None
            # Ensure compatibility with saved model (expects 403 features)
            target_size = 403  # Match the saved model's expected input size
            if len(feature_vector) < target_size:
                # Pad with zeros
                padded_state = np.zeros(target_size)
                padded_state[:len(feature_vector)] = feature_vector
                return padded_state
            elif len(feature_vector) > target_size:
                # Truncate to target size
                return feature_vector[:target_size]
            else:
                return feature_vector

        except Exception as e:
            logger.error(f"Error creating RL state for {symbol}: {e}")
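Note: both branches above reduce to "force the state vector to exactly 403 features". A one-function sketch of that fit-to-size step; 403 mirrors the saved model mentioned in the diff, and the helper name is illustrative:

# Sketch: pad with zeros or truncate so the state always matches the saved model's input size.
import numpy as np

def fit_state(feature_vector: np.ndarray, target_size: int = 403) -> np.ndarray:
    state = np.zeros(target_size, dtype=np.float32)
    n = min(len(feature_vector), target_size)
    state[:n] = feature_vector[:n]     # features beyond target_size are dropped
    return state

print(fit_state(np.ones(10)).shape)    # (403,) - short vectors are zero-padded
print(fit_state(np.ones(500)).shape)   # (403,) - long vectors are truncated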
@@ -3897,7 +3929,7 @@ class TradingOrchestrator:

    def build_base_data_input(self, symbol: str) -> Optional[Any]:
        """
        Build BaseDataInput using simplified data integration
        Build BaseDataInput using simplified data integration (optimized for speed)

        Args:
            symbol: Trading symbol
@@ -3906,71 +3938,9 @@ class TradingOrchestrator:
            BaseDataInput with consistent data structure
        """
        try:
            # Use simplified data integration to build BaseDataInput
            # Use simplified data integration to build BaseDataInput (should be instantaneous)
            return self.data_integration.build_base_data_input(symbol)

            # Verify we have minimum data for all timeframes with fallback strategy
            missing_data = []
            for data_type, min_count in min_requirements.items():
                if not self.ensure_minimum_data(data_type, symbol, min_count):
                    # Get actual count for better logging
                    actual_count = 0
                    if data_type in self.data_queues and symbol in self.data_queues[data_type]:
                        with self.data_queue_locks[data_type][symbol]:
                            actual_count = len(self.data_queues[data_type][symbol])

                    missing_data.append((data_type, actual_count, min_count))

            # If we're missing critical 1s data, try to use 1m data as fallback
            if missing_data:
                critical_missing = [d for d in missing_data if d[0] in ['ohlcv_1s', 'ohlcv_1h']]
                if critical_missing:
                    logger.warning(f"Missing critical data for {symbol}: {critical_missing}")

                    # Try fallback strategy: use available data with padding
                    if self._try_fallback_data_strategy(symbol, missing_data):
                        logger.info(f"Successfully applied fallback data strategy for {symbol}")
                    else:
                        for data_type, actual_count, min_count in missing_data:
                            logger.warning(f"Insufficient {data_type} data for {symbol}: have {actual_count}, need {min_count}")
                        return None

            # Get BTC data (reference symbol)
            btc_symbol = 'BTC/USDT'
            if not self.ensure_minimum_data('ohlcv_1s', btc_symbol, 100):
                # Get actual BTC data count for logging
                btc_count = 0
                if 'ohlcv_1s' in self.data_queues and btc_symbol in self.data_queues['ohlcv_1s']:
                    with self.data_queue_locks['ohlcv_1s'][btc_symbol]:
                        btc_count = len(self.data_queues['ohlcv_1s'][btc_symbol])

                logger.warning(f"Insufficient BTC data for reference: have {btc_count}, need 100, using ETH data as fallback")
                # Use ETH data as fallback
                btc_data = self.get_queue_data('ohlcv_1s', symbol, 300)
            else:
                btc_data = self.get_queue_data('ohlcv_1s', btc_symbol, 300)

            # Build BaseDataInput with queue data
            base_data = BaseDataInput(
                symbol=symbol,
                timestamp=datetime.now(),
                ohlcv_1s=self.get_queue_data('ohlcv_1s', symbol, 300),
                ohlcv_1m=self.get_queue_data('ohlcv_1m', symbol, 300),
                ohlcv_1h=self.get_queue_data('ohlcv_1h', symbol, 300),
                ohlcv_1d=self.get_queue_data('ohlcv_1d', symbol, 300),
                btc_ohlcv_1s=btc_data,
                technical_indicators=self._get_latest_indicators(symbol),
                cob_data=self._get_latest_cob_data(symbol),
                last_predictions=self._get_recent_model_predictions(symbol)
            )

            # Validate the data
            if not base_data.validate():
                logger.warning(f"BaseDataInput validation failed for {symbol}")
                return None

            return base_data

        except Exception as e:
            logger.error(f"Error building BaseDataInput for {symbol}: {e}")
            return None
@@ -6,7 +6,8 @@ Integrates with SmartDataUpdater for efficient data management.
"""

import logging
from datetime import datetime
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
import pandas as pd

@@ -29,6 +30,11 @@ class SimplifiedDataIntegration:
        # Initialize smart data updater
        self.data_updater = SmartDataUpdater(data_provider, symbols)

        # Pre-built OHLCV data cache for instant access
        self._ohlcv_cache = {}  # {symbol: {timeframe: List[OHLCVBar]}}
        self._ohlcv_cache_lock = threading.RLock()
        self._last_cache_update = {}  # {symbol: {timeframe: datetime}}

        # Register for tick data if available
        self._setup_tick_integration()
@@ -61,6 +67,8 @@ class SimplifiedDataIntegration:
    def _on_tick_data(self, symbol: str, price: float, volume: float, timestamp: datetime = None):
        """Handle incoming tick data"""
        self.data_updater.add_tick(symbol, price, volume, timestamp)
        # Invalidate OHLCV cache for this symbol
        self._invalidate_ohlcv_cache(symbol)

    def _on_websocket_data(self, symbol: str, data: Dict[str, Any]):
        """Handle WebSocket data updates"""
@@ -68,12 +76,28 @@ class SimplifiedDataIntegration:
            # Extract price and volume from WebSocket data
            if 'price' in data and 'volume' in data:
                self.data_updater.add_tick(symbol, data['price'], data['volume'])
                # Invalidate OHLCV cache for this symbol
                self._invalidate_ohlcv_cache(symbol)
        except Exception as e:
            logger.error(f"Error processing WebSocket data: {e}")

    def _invalidate_ohlcv_cache(self, symbol: str):
        """Invalidate OHLCV cache for a symbol when new data arrives"""
        try:
            with self._ohlcv_cache_lock:
                # Remove cached data for all timeframes of this symbol
                keys_to_remove = [key for key in self._ohlcv_cache.keys() if key.startswith(f"{symbol}_")]
                for key in keys_to_remove:
                    if key in self._ohlcv_cache:
                        del self._ohlcv_cache[key]
                    if key in self._last_cache_update:
                        del self._last_cache_update[key]
        except Exception as e:
            logger.error(f"Error invalidating OHLCV cache for {symbol}: {e}")

    def build_base_data_input(self, symbol: str) -> Optional[BaseDataInput]:
        """
        Build BaseDataInput from cached data (much simpler than FIFO queues)
        Build BaseDataInput from cached data (optimized for speed)

        Args:
            symbol: Trading symbol
@@ -82,22 +106,7 @@ class SimplifiedDataIntegration:
            BaseDataInput with consistent data structure
        """
        try:
            # Check if we have minimum required data
            required_timeframes = ['1s', '1m', '1h', '1d']
            missing_timeframes = []

            for timeframe in required_timeframes:
                if not self.cache.has_data(f'ohlcv_{timeframe}', symbol, max_age_seconds=300):
                    missing_timeframes.append(timeframe)

            if missing_timeframes:
                logger.warning(f"Missing data for {symbol}: {missing_timeframes}")

                # Try to use historical data as fallback
                if not self._try_historical_fallback(symbol, missing_timeframes):
                    return None

            # Get current OHLCV data
            # Get OHLCV data directly from optimized cache (no validation checks for speed)
            ohlcv_1s_list = self._get_ohlcv_data_list(symbol, '1s', 300)
            ohlcv_1m_list = self._get_ohlcv_data_list(symbol, '1m', 300)
            ohlcv_1h_list = self._get_ohlcv_data_list(symbol, '1h', 300)
@@ -109,18 +118,13 @@ class SimplifiedDataIntegration:
            if not btc_ohlcv_1s_list:
                # Use ETH data as fallback
                btc_ohlcv_1s_list = ohlcv_1s_list
                logger.debug(f"Using {symbol} data as BTC fallback")

            # Get technical indicators
            # Get cached data (fast lookups)
            technical_indicators = self.cache.get('technical_indicators', symbol) or {}

            # Get COB data if available
            cob_data = self.cache.get('cob_data', symbol)

            # Get recent model predictions
            last_predictions = self._get_recent_predictions(symbol)

            # Build BaseDataInput
            # Build BaseDataInput (no validation for speed - assume data is good)
            base_data = BaseDataInput(
                symbol=symbol,
                timestamp=datetime.now(),
@@ -134,11 +138,6 @@ class SimplifiedDataIntegration:
                last_predictions=last_predictions
            )

            # Validate the data
            if not base_data.validate():
                logger.warning(f"BaseDataInput validation failed for {symbol}")
                return None

            return base_data

        except Exception as e:
@@ -146,11 +145,39 @@ class SimplifiedDataIntegration:
            return None

    def _get_ohlcv_data_list(self, symbol: str, timeframe: str, max_count: int) -> List[OHLCVBar]:
        """Get OHLCV data list from cache and historical data"""
        """Get OHLCV data list from pre-built cache for instant access"""
        try:
            with self._ohlcv_cache_lock:
                cache_key = f"{symbol}_{timeframe}"

                # Check if we have fresh cached data (updated within last 5 seconds)
                last_update = self._last_cache_update.get(cache_key)
                if (last_update and
                    (datetime.now() - last_update).total_seconds() < 5 and
                    cache_key in self._ohlcv_cache):

                    cached_data = self._ohlcv_cache[cache_key]
                    return cached_data[-max_count:] if len(cached_data) >= max_count else cached_data

                # Need to rebuild cache for this symbol/timeframe
                data_list = self._build_ohlcv_cache(symbol, timeframe, max_count)

                # Cache the result
                self._ohlcv_cache[cache_key] = data_list
                self._last_cache_update[cache_key] = datetime.now()

                return data_list[-max_count:] if len(data_list) >= max_count else data_list

        except Exception as e:
            logger.error(f"Error getting OHLCV data list for {symbol}/{timeframe}: {e}")
            return self._create_dummy_data_list(symbol, timeframe, max_count)
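Note: the lookup above is a small thread-safe, read-through, time-to-live (TTL) cache keyed by symbol and timeframe: entries younger than 5 seconds are served directly, anything else is rebuilt and restamped, and new ticks invalidate a symbol's entries. A generic sketch of the same pattern; the names and the 5-second TTL mirror the diff, the builder callback is a placeholder:

# Sketch of a thread-safe read-through TTL cache, as used for the OHLCV bars above.
import threading
from datetime import datetime
from typing import Callable, Dict, Tuple

class TTLCache:
    def __init__(self, ttl_seconds: float = 5.0):
        self._ttl = ttl_seconds
        self._lock = threading.RLock()
        self._entries: Dict[str, Tuple[datetime, list]] = {}

    def get(self, key: str, build: Callable[[], list]) -> list:
        with self._lock:
            entry = self._entries.get(key)
            if entry and (datetime.now() - entry[0]).total_seconds() < self._ttl:
                return entry[1]                      # fresh enough: serve cached value
            value = build()                          # stale or missing: rebuild
            self._entries[key] = (datetime.now(), value)
            return value

    def invalidate_prefix(self, prefix: str) -> None:
        with self._lock:                             # mirrors _invalidate_ohlcv_cache(symbol)
            for key in [k for k in self._entries if k.startswith(prefix)]:
                del self._entries[key]

cache = TTLCache()
bars = cache.get("ETH/USDT_1m", build=lambda: ["bar1", "bar2"])
cache.invalidate_prefix("ETH/USDT_")                 # new tick arrived: drop all ETH entries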

    def _build_ohlcv_cache(self, symbol: str, timeframe: str, max_count: int) -> List[OHLCVBar]:
        """Build OHLCV cache from historical and current data"""
        try:
            data_list = []

            # Get historical data first
            # Get historical data first (this should be fast as it's already cached)
            historical_df = self.cache.get_historical_data(symbol, timeframe)
            if historical_df is not None and not historical_df.empty:
                # Convert historical data to OHLCVBar objects
@@ -174,34 +201,14 @@ class SimplifiedDataIntegration:

            # Ensure we have the right amount of data (pad if necessary)
            while len(data_list) < max_count:
                # Pad with the last available data or create dummy data
                if data_list:
                    last_bar = data_list[-1]
                    dummy_bar = OHLCVBar(
                        symbol=symbol,
                        timestamp=last_bar.timestamp,
                        open=last_bar.close,
                        high=last_bar.close,
                        low=last_bar.close,
                        close=last_bar.close,
                        volume=0.0,
                        timeframe=timeframe
                    )
                else:
                    # Create completely dummy data
                    dummy_bar = OHLCVBar(
                        symbol=symbol,
                        timestamp=datetime.now(),
                        open=0.0, high=0.0, low=0.0, close=0.0, volume=0.0,
                        timeframe=timeframe
                    )
                data_list.append(dummy_bar)
                data_list.extend(self._create_dummy_data_list(symbol, timeframe, max_count - len(data_list)))

            return data_list[-max_count:]  # Return last max_count items
            return data_list

        except Exception as e:
            logger.error(f"Error getting OHLCV data list for {symbol} {timeframe}: {e}")
            return []
            logger.error(f"Error building OHLCV cache for {symbol}/{timeframe}: {e}")
            return self._create_dummy_data_list(symbol, timeframe, max_count)

    def _try_historical_fallback(self, symbol: str, missing_timeframes: List[str]) -> bool:
        """Try to use historical data for missing timeframes"""