timezones
This commit is contained in:
@ -1961,9 +1961,28 @@ class TradingOrchestrator:
|
||||
if historical_data is None or historical_data.empty:
|
||||
return
|
||||
|
||||
# Find closest price to prediction timestamp
|
||||
prediction_price = historical_data['close'].iloc[-1] # Simplified
|
||||
price_change_pct = (current_price - prediction_price) / prediction_price * 100
|
||||
# Use predicted price if available, otherwise fall back to historical price
|
||||
predicted_price = None
|
||||
if 'price_prediction' in prediction and prediction['price_prediction']:
|
||||
try:
|
||||
# Extract predicted price change from CNN output
|
||||
price_prediction_data = prediction['price_prediction']
|
||||
if isinstance(price_prediction_data, list) and len(price_prediction_data) > 0:
|
||||
predicted_price_change_pct = float(price_prediction_data[0]) * 0.01 # Convert to percentage
|
||||
predicted_price = current_price * (1 + predicted_price_change_pct)
|
||||
logger.debug(f"Using CNN price prediction: {predicted_price_change_pct:.3f}% -> ${predicted_price:.2f}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error parsing price prediction: {e}")
|
||||
|
||||
# Fall back to historical price if no prediction available
|
||||
if predicted_price is None:
|
||||
prediction_price = historical_data['close'].iloc[-1] # Simplified
|
||||
price_change_pct = (current_price - prediction_price) / prediction_price * 100
|
||||
logger.debug(f"Using historical price comparison: ${prediction_price:.2f} -> ${current_price:.2f}")
|
||||
else:
|
||||
# Use predicted price for reward calculation
|
||||
price_change_pct = (current_price - predicted_price) / predicted_price * 100
|
||||
logger.debug(f"Using predicted price comparison: ${predicted_price:.2f} -> ${current_price:.2f}")
|
||||
|
||||
# Enhanced reward system based on prediction confidence and price movement magnitude
|
||||
predicted_action = prediction['action']
|
||||
@ -1974,12 +1993,16 @@ class TradingOrchestrator:
|
||||
predicted_action,
|
||||
prediction_confidence,
|
||||
price_change_pct,
|
||||
time_diff
|
||||
time_diff,
|
||||
predicted_price is not None # Add price prediction flag
|
||||
)
|
||||
|
||||
# Update model performance tracking
|
||||
if model_name not in self.model_performance:
|
||||
self.model_performance[model_name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
|
||||
self.model_performance[model_name] = {
|
||||
'correct': 0, 'total': 0, 'accuracy': 0.0,
|
||||
'price_predictions': {'total': 0, 'accurate': 0, 'avg_error': 0.0}
|
||||
}
|
||||
|
||||
self.model_performance[model_name]['total'] += 1
|
||||
if was_correct:
|
||||
@ -1990,6 +2013,26 @@ class TradingOrchestrator:
|
||||
self.model_performance[model_name]['total']
|
||||
)
|
||||
|
||||
# Track price prediction accuracy if available
|
||||
if predicted_price is not None:
|
||||
price_prediction_stats = self.model_performance[model_name]['price_predictions']
|
||||
price_prediction_stats['total'] += 1
|
||||
|
||||
# Calculate prediction error
|
||||
prediction_error_pct = abs(price_change_pct)
|
||||
price_prediction_stats['avg_error'] = (
|
||||
(price_prediction_stats['avg_error'] * (price_prediction_stats['total'] - 1) + prediction_error_pct) /
|
||||
price_prediction_stats['total']
|
||||
)
|
||||
|
||||
# Consider prediction accurate if error < 1%
|
||||
if prediction_error_pct < 1.0:
|
||||
price_prediction_stats['accurate'] += 1
|
||||
|
||||
logger.debug(f"Price prediction accuracy for {model_name}: "
|
||||
f"{price_prediction_stats['accurate']}/{price_prediction_stats['total']} "
|
||||
f"({price_prediction_stats['avg_error']:.2f}% avg error)")
|
||||
|
||||
# Train the specific model based on sophisticated outcome
|
||||
await self._train_model_on_outcome(record, was_correct, price_change_pct, reward)
|
||||
|
||||
@ -2003,15 +2046,17 @@ class TradingOrchestrator:
|
||||
'evaluated_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
price_pred_info = f"predicted: ${predicted_price:.2f}" if predicted_price is not None else "no price prediction"
|
||||
logger.debug(f"Evaluated {model_name} prediction: {'✓' if was_correct else '✗'} "
|
||||
f"({prediction['action']}, {price_change_pct:.2f}% change, "
|
||||
f"confidence: {prediction_confidence:.3f}, reward: {reward:.3f})")
|
||||
f"confidence: {prediction_confidence:.3f}, {price_pred_info}, reward: {reward:.3f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating and training on record: {e}")
|
||||
|
||||
def _calculate_sophisticated_reward(self, predicted_action: str, prediction_confidence: float,
|
||||
price_change_pct: float, time_diff_minutes: float) -> tuple[float, bool]:
|
||||
price_change_pct: float, time_diff_minutes: float,
|
||||
has_price_prediction: bool = False) -> tuple[float, bool]:
|
||||
"""
|
||||
Calculate sophisticated reward based on prediction accuracy, confidence, and price movement magnitude
|
||||
|
||||
@ -2072,6 +2117,11 @@ class TradingOrchestrator:
|
||||
# Final reward calculation
|
||||
final_reward = base_reward * time_decay
|
||||
|
||||
# Bonus for accurate price predictions
|
||||
if has_price_prediction and abs(price_change_pct) < 1.0: # Accurate price prediction
|
||||
final_reward *= 1.2 # 20% bonus for accurate price predictions
|
||||
logger.debug(f"Applied price prediction accuracy bonus: {final_reward:.3f}")
|
||||
|
||||
# Clamp reward to reasonable range
|
||||
final_reward = max(-5.0, min(5.0, final_reward))
|
||||
|
||||
@ -2224,6 +2274,11 @@ class TradingOrchestrator:
|
||||
batch_size = getattr(model, 'batch_size', 32)
|
||||
if memory_size >= batch_size:
|
||||
logger.debug(f"Training {model_name} with {memory_size} experiences")
|
||||
|
||||
# Ensure model is in training mode
|
||||
if hasattr(model, 'policy_net'):
|
||||
model.policy_net.train()
|
||||
|
||||
training_start_time = time.time()
|
||||
training_loss = model.replay()
|
||||
training_duration_ms = (time.time() - training_start_time) * 1000
|
||||
@ -2233,6 +2288,11 @@ class TradingOrchestrator:
|
||||
self._update_model_training_statistics(model_name, training_loss, training_duration_ms)
|
||||
logger.debug(f"RL training completed for {model_name}: loss={training_loss:.4f}, time={training_duration_ms:.1f}ms")
|
||||
return True
|
||||
elif training_loss == 0.0:
|
||||
logger.warning(f"RL training returned zero loss for {model_name} - possible gradient issue")
|
||||
# Still update training statistics
|
||||
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
|
||||
return False # Training failed
|
||||
else:
|
||||
# Still update training statistics even if no loss returned
|
||||
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
|
||||
@ -2321,7 +2381,7 @@ class TradingOrchestrator:
|
||||
symbol = record.get('symbol', 'ETH/USDT')
|
||||
actual_action = prediction['action']
|
||||
|
||||
# Check if adapter has add_training_sample method
|
||||
# Add training sample to adapter
|
||||
if hasattr(self.cnn_adapter, 'add_training_sample'):
|
||||
self.cnn_adapter.add_training_sample(symbol, actual_action, reward)
|
||||
logger.debug(f"Added training sample to CNN adapter: action={actual_action}, reward={reward:.3f}")
|
||||
@ -2331,14 +2391,25 @@ class TradingOrchestrator:
|
||||
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
|
||||
logger.debug(f"Training CNN with {len(self.cnn_adapter.training_data)} samples")
|
||||
training_start_time = time.time()
|
||||
|
||||
# Add validation to prevent overfitting
|
||||
training_results = self.cnn_adapter.train(epochs=1)
|
||||
training_duration_ms = (time.time() - training_start_time) * 1000
|
||||
|
||||
if training_results and 'loss' in training_results:
|
||||
current_loss = training_results['loss']
|
||||
self.update_model_loss(model_name, current_loss)
|
||||
accuracy = training_results.get('accuracy', 0.0)
|
||||
|
||||
# Validate training results - 100% accuracy is suspicious
|
||||
if accuracy >= 0.99:
|
||||
logger.warning(f"CNN training shows suspiciously high accuracy: {accuracy:.4f} - possible overfitting")
|
||||
# Don't update loss if accuracy is too high (likely overfitting)
|
||||
logger.warning("Skipping loss update due to potential overfitting")
|
||||
else:
|
||||
self.update_model_loss(model_name, current_loss)
|
||||
|
||||
self._update_model_training_statistics(model_name, current_loss, training_duration_ms)
|
||||
logger.debug(f"CNN training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms")
|
||||
logger.debug(f"CNN training completed: loss={current_loss:.4f}, accuracy={accuracy:.4f}, time={training_duration_ms:.1f}ms")
|
||||
return True
|
||||
else:
|
||||
# Still update training statistics even if no loss returned
|
||||
@ -2350,7 +2421,7 @@ class TradingOrchestrator:
|
||||
logger.debug(f"CNN adapter doesn't have add_training_sample method")
|
||||
|
||||
# Try direct model training methods
|
||||
if hasattr(model, 'add_training_sample'):
|
||||
elif hasattr(model, 'add_training_sample'):
|
||||
symbol = record.get('symbol', 'ETH/USDT')
|
||||
actual_action = prediction['action']
|
||||
model.add_training_sample(symbol, actual_action, reward)
|
||||
@ -2365,7 +2436,14 @@ class TradingOrchestrator:
|
||||
|
||||
if training_results and 'loss' in training_results:
|
||||
current_loss = training_results['loss']
|
||||
self.update_model_loss(model_name, current_loss)
|
||||
accuracy = training_results.get('accuracy', 0.0)
|
||||
|
||||
# Validate training results
|
||||
if accuracy >= 0.99:
|
||||
logger.warning(f"CNN training shows suspiciously high accuracy: {accuracy:.4f} - possible overfitting")
|
||||
else:
|
||||
self.update_model_loss(model_name, current_loss)
|
||||
|
||||
self._update_model_training_statistics(model_name, current_loss, training_duration_ms)
|
||||
logger.debug(f"CNN training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms")
|
||||
return True
|
||||
@ -2378,7 +2456,6 @@ class TradingOrchestrator:
|
||||
elif hasattr(model, 'train'):
|
||||
logger.debug(f"Using basic train method for {model_name}")
|
||||
# For now, just acknowledge that training was attempted
|
||||
# The EnhancedCNN model might need specific training data format
|
||||
logger.debug(f"CNN model {model_name} training acknowledged (basic train method available)")
|
||||
return True
|
||||
|
||||
@ -2514,9 +2591,7 @@ class TradingOrchestrator:
|
||||
# Use CNN adapter if available
|
||||
if hasattr(self, 'cnn_adapter') and self.cnn_adapter:
|
||||
try:
|
||||
cnn_start_time = time.time()
|
||||
result = self.cnn_adapter.predict(base_data)
|
||||
cnn_duration_ms = (time.time() - cnn_start_time) * 1000
|
||||
if result:
|
||||
# Extract action and probabilities from ModelOutput
|
||||
action = result.predictions.get('action', 'HOLD')
|
||||
|
Reference in New Issue
Block a user