inference_enabled, cleanup

This commit is contained in:
Dobromir Popov
2025-08-04 14:24:39 +03:00
parent 29382ac0db
commit e223bc90e9
39 changed files with 315 additions and 90858 deletions

View File

@@ -1,365 +0,0 @@
"""
Dashboard CNN Integration
This module integrates the EnhancedCNNAdapter with the dashboard system,
providing real-time training, predictions, and performance metrics display.
"""
import logging
import time
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from collections import deque
import numpy as np
from .enhanced_cnn_adapter import EnhancedCNNAdapter
from .standardized_data_provider import StandardizedDataProvider
from .data_models import BaseDataInput, ModelOutput, create_model_output
logger = logging.getLogger(__name__)
class DashboardCNNIntegration:
"""
CNN integration for the dashboard system
This class:
1. Manages CNN model lifecycle in the dashboard
2. Provides real-time training and inference
3. Tracks performance metrics for dashboard display
4. Handles model predictions for chart overlay
"""
def __init__(self, data_provider: StandardizedDataProvider, symbols: List[str] = None):
"""
Initialize the dashboard CNN integration
Args:
data_provider: Standardized data provider
symbols: List of symbols to process
"""
self.data_provider = data_provider
self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
# Initialize CNN adapter
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
# Load best checkpoint if available
self.cnn_adapter.load_best_checkpoint()
# Performance tracking
self.performance_metrics = {
'total_predictions': 0,
'total_training_samples': 0,
'last_training_time': None,
'last_inference_time': None,
'training_loss_history': deque(maxlen=100),
'accuracy_history': deque(maxlen=100),
'inference_times': deque(maxlen=100),
'training_times': deque(maxlen=100),
'predictions_per_second': 0.0,
'training_per_second': 0.0,
'model_status': 'FRESH',
'confidence_history': deque(maxlen=100),
'action_distribution': {'BUY': 0, 'SELL': 0, 'HOLD': 0}
}
# Prediction cache for dashboard display
self.prediction_cache = {}
self.prediction_history = {symbol: deque(maxlen=1000) for symbol in self.symbols}
# Training control
self.training_enabled = True
self.inference_enabled = True
self.training_lock = threading.Lock()
# Real-time processing
self.is_running = False
self.processing_thread = None
logger.info(f"DashboardCNNIntegration initialized for symbols: {self.symbols}")
def start_real_time_processing(self):
"""Start real-time CNN processing"""
if self.is_running:
logger.warning("Real-time processing already running")
return
self.is_running = True
self.processing_thread = threading.Thread(target=self._real_time_processing_loop, daemon=True)
self.processing_thread.start()
logger.info("Started real-time CNN processing")
def stop_real_time_processing(self):
"""Stop real-time CNN processing"""
self.is_running = False
if self.processing_thread:
self.processing_thread.join(timeout=5)
logger.info("Stopped real-time CNN processing")
def _real_time_processing_loop(self):
"""Main real-time processing loop"""
last_prediction_time = {}
prediction_interval = 1.0 # Make prediction every 1 second
while self.is_running:
try:
current_time = time.time()
for symbol in self.symbols:
# Check if it's time to make a prediction for this symbol
if (symbol not in last_prediction_time or
current_time - last_prediction_time[symbol] >= prediction_interval):
# Make prediction if inference is enabled
if self.inference_enabled:
self._make_prediction(symbol)
last_prediction_time[symbol] = current_time
# Update performance metrics
self._update_performance_metrics()
# Sleep briefly to prevent overwhelming the system
time.sleep(0.1)
except Exception as e:
logger.error(f"Error in real-time processing loop: {e}")
time.sleep(1)
def _make_prediction(self, symbol: str):
"""Make a prediction for a symbol"""
try:
start_time = time.time()
# Get standardized input data
base_data = self.data_provider.get_base_data_input(symbol)
if base_data is None:
logger.debug(f"No base data available for {symbol}")
return
# Make prediction
model_output = self.cnn_adapter.predict(base_data)
# Record inference time
inference_time = time.time() - start_time
self.performance_metrics['inference_times'].append(inference_time)
# Update performance metrics
self.performance_metrics['total_predictions'] += 1
self.performance_metrics['last_inference_time'] = datetime.now()
self.performance_metrics['confidence_history'].append(model_output.confidence)
# Update action distribution
action = model_output.predictions['action']
self.performance_metrics['action_distribution'][action] += 1
# Cache prediction for dashboard
self.prediction_cache[symbol] = model_output
self.prediction_history[symbol].append(model_output)
# Store model output in data provider
self.data_provider.store_model_output(model_output)
logger.debug(f"CNN prediction for {symbol}: {action} ({model_output.confidence:.3f})")
except Exception as e:
logger.error(f"Error making prediction for {symbol}: {e}")
def add_training_sample(self, symbol: str, actual_action: str, reward: float):
"""Add a training sample and trigger training if enabled"""
try:
if not self.training_enabled:
return
# Get base data for the symbol
base_data = self.data_provider.get_base_data_input(symbol)
if base_data is None:
logger.debug(f"No base data available for training sample: {symbol}")
return
# Add training sample
self.cnn_adapter.add_training_sample(base_data, actual_action, reward)
# Update metrics
self.performance_metrics['total_training_samples'] += 1
# Train model periodically (every 10 samples)
if self.performance_metrics['total_training_samples'] % 10 == 0:
self._train_model()
except Exception as e:
logger.error(f"Error adding training sample: {e}")
def _train_model(self):
"""Train the CNN model"""
try:
with self.training_lock:
start_time = time.time()
# Train model
metrics = self.cnn_adapter.train(epochs=1)
# Record training time
training_time = time.time() - start_time
self.performance_metrics['training_times'].append(training_time)
# Update performance metrics
self.performance_metrics['last_training_time'] = datetime.now()
if 'loss' in metrics:
self.performance_metrics['training_loss_history'].append(metrics['loss'])
if 'accuracy' in metrics:
self.performance_metrics['accuracy_history'].append(metrics['accuracy'])
# Update model status
if metrics.get('accuracy', 0) > 0.5:
self.performance_metrics['model_status'] = 'TRAINED'
else:
self.performance_metrics['model_status'] = 'TRAINING'
logger.info(f"CNN training completed: loss={metrics.get('loss', 0):.4f}, accuracy={metrics.get('accuracy', 0):.4f}")
except Exception as e:
logger.error(f"Error training CNN model: {e}")
def _update_performance_metrics(self):
"""Update performance metrics for dashboard display"""
try:
# 'inference_times' and 'training_times' hold durations (seconds per run), not timestamps,
# so derive throughput from the average recent duration rather than a wall-clock window
inference_times = self.performance_metrics['inference_times']
if inference_times:
avg_inference = sum(inference_times) / len(inference_times)
self.performance_metrics['predictions_per_second'] = 1.0 / avg_inference if avg_inference > 0 else 0.0
else:
self.performance_metrics['predictions_per_second'] = 0.0
training_times = self.performance_metrics['training_times']
if training_times:
avg_training = sum(training_times) / len(training_times)
self.performance_metrics['training_per_second'] = 1.0 / avg_training if avg_training > 0 else 0.0
else:
self.performance_metrics['training_per_second'] = 0.0
except Exception as e:
logger.error(f"Error updating performance metrics: {e}")
def get_dashboard_metrics(self) -> Dict[str, Any]:
"""Get metrics for dashboard display"""
try:
# Calculate current loss
current_loss = (self.performance_metrics['training_loss_history'][-1]
if self.performance_metrics['training_loss_history'] else 0.0)
# Calculate current accuracy
current_accuracy = (self.performance_metrics['accuracy_history'][-1]
if self.performance_metrics['accuracy_history'] else 0.0)
# Calculate average confidence
avg_confidence = (np.mean(list(self.performance_metrics['confidence_history']))
if self.performance_metrics['confidence_history'] else 0.0)
# Get latest prediction
latest_prediction = None
latest_symbol = None
for symbol, prediction in self.prediction_cache.items():
if latest_prediction is None or prediction.timestamp > latest_prediction.timestamp:
latest_prediction = prediction
latest_symbol = symbol
# Format timing information
last_inference_str = "None"
last_training_str = "None"
if self.performance_metrics['last_inference_time']:
last_inference_str = self.performance_metrics['last_inference_time'].strftime("%H:%M:%S")
if self.performance_metrics['last_training_time']:
last_training_str = self.performance_metrics['last_training_time'].strftime("%H:%M:%S")
return {
'model_name': 'CNN',
'model_type': 'cnn',
'parameters': '50.0M',
'status': self.performance_metrics['model_status'],
'current_loss': current_loss,
'accuracy': current_accuracy,
'confidence': avg_confidence,
'total_predictions': self.performance_metrics['total_predictions'],
'total_training_samples': self.performance_metrics['total_training_samples'],
'predictions_per_second': self.performance_metrics['predictions_per_second'],
'training_per_second': self.performance_metrics['training_per_second'],
'last_inference': last_inference_str,
'last_training': last_training_str,
'latest_prediction': {
'action': latest_prediction.predictions['action'] if latest_prediction else 'HOLD',
'confidence': latest_prediction.confidence if latest_prediction else 0.0,
'symbol': latest_symbol or 'ETH/USDT',
'timestamp': latest_prediction.timestamp.strftime("%H:%M:%S") if latest_prediction else "None"
},
'action_distribution': self.performance_metrics['action_distribution'].copy(),
'training_enabled': self.training_enabled,
'inference_enabled': self.inference_enabled
}
except Exception as e:
logger.error(f"Error getting dashboard metrics: {e}")
return {
'model_name': 'CNN',
'model_type': 'cnn',
'parameters': '50.0M',
'status': 'ERROR',
'current_loss': 0.0,
'accuracy': 0.0,
'confidence': 0.0,
'error': str(e)
}
def get_predictions_for_chart(self, symbol: str, timeframe: str = '1s', limit: int = 100) -> List[Dict[str, Any]]:
"""Get predictions for chart overlay"""
try:
if symbol not in self.prediction_history:
return []
predictions = list(self.prediction_history[symbol])[-limit:]
chart_data = []
for prediction in predictions:
chart_data.append({
'timestamp': prediction.timestamp,
'action': prediction.predictions['action'],
'confidence': prediction.confidence,
'buy_probability': prediction.predictions.get('buy_probability', 0.0),
'sell_probability': prediction.predictions.get('sell_probability', 0.0),
'hold_probability': prediction.predictions.get('hold_probability', 0.0)
})
return chart_data
except Exception as e:
logger.error(f"Error getting predictions for chart: {e}")
return []
def set_training_enabled(self, enabled: bool):
"""Enable or disable training"""
self.training_enabled = enabled
logger.info(f"CNN training {'enabled' if enabled else 'disabled'}")
def set_inference_enabled(self, enabled: bool):
"""Enable or disable inference"""
self.inference_enabled = enabled
logger.info(f"CNN inference {'enabled' if enabled else 'disabled'}")
def get_model_info(self) -> Dict[str, Any]:
"""Get model information for dashboard"""
return {
'name': 'Enhanced CNN',
'version': '1.0',
'parameters': '50.0M',
'input_shape': self.cnn_adapter.model.input_shape if self.cnn_adapter.model else 'Unknown',
'device': str(self.cnn_adapter.device),
'checkpoint_dir': self.cnn_adapter.checkpoint_dir,
'training_samples': len(self.cnn_adapter.training_data),
'max_training_samples': self.cnn_adapter.max_training_samples
}
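For reference, a minimal sketch of how this integration would be driven from a dashboard process. The package path and the provider construction are assumptions for illustration and are not taken from this commit.

# Illustrative usage sketch (assumed import paths; DashboardCNNIntegration is removed by this commit)
import time

from core.standardized_data_provider import StandardizedDataProvider   # assumed path
from core.dashboard_cnn_integration import DashboardCNNIntegration     # assumed path

provider = StandardizedDataProvider()   # construction details are an assumption
integration = DashboardCNNIntegration(provider, symbols=['ETH/USDT', 'BTC/USDT'])
integration.start_real_time_processing()

try:
    while True:
        # Poll the metrics dict that the dashboard model panel renders
        metrics = integration.get_dashboard_metrics()
        print(metrics['status'], metrics['latest_prediction'])
        time.sleep(5)
except KeyboardInterrupt:
    pass
finally:
    integration.stop_real_time_processing()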

View File

@@ -1,403 +0,0 @@
"""
Enhanced CNN Integration for Dashboard
This module integrates the EnhancedCNNAdapter with the dashboard, providing real-time
training and inference capabilities.
"""
import logging
import threading
import time
from datetime import datetime
from typing import Dict, List, Optional, Any, Union
import os
from .enhanced_cnn_adapter import EnhancedCNNAdapter
from .standardized_data_provider import StandardizedDataProvider
from .data_models import BaseDataInput, ModelOutput, create_model_output
logger = logging.getLogger(__name__)
class EnhancedCNNIntegration:
"""
Integration of EnhancedCNNAdapter with the dashboard
This class:
1. Manages the EnhancedCNNAdapter lifecycle
2. Provides real-time training and inference
3. Collects and reports performance metrics
4. Integrates with the dashboard's model visualization
"""
def __init__(self, data_provider: StandardizedDataProvider, checkpoint_dir: str = "models/enhanced_cnn"):
"""
Initialize the EnhancedCNNIntegration
Args:
data_provider: StandardizedDataProvider instance
checkpoint_dir: Directory to store checkpoints
"""
self.data_provider = data_provider
self.checkpoint_dir = checkpoint_dir
self.model_name = "enhanced_cnn_v1"
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
# Initialize CNN adapter
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir)
# Load best checkpoint if available
self.cnn_adapter.load_best_checkpoint()
# Performance tracking
self.inference_times = []
self.training_times = []
self.total_inferences = 0
self.total_training_runs = 0
self.last_inference_time = None
self.last_training_time = None
self.inference_rate = 0.0
self.training_rate = 0.0
self.daily_inferences = 0
self.daily_training_runs = 0
# Training settings
self.training_enabled = True
self.inference_enabled = True
self.training_frequency = 10 # Train every N inferences
self.training_batch_size = 32
self.training_epochs = 1
# Latest prediction
self.latest_prediction = None
self.latest_prediction_time = None
# Training metrics
self.current_loss = 0.0
self.initial_loss = None
self.best_loss = None
self.current_accuracy = 0.0
self.improvement_percentage = 0.0
# Training thread
self.training_thread = None
self.training_active = False
self.stop_training = False
logger.info(f"EnhancedCNNIntegration initialized with model: {self.model_name}")
def start_continuous_training(self):
"""Start continuous training in a background thread"""
if self.training_thread is not None and self.training_thread.is_alive():
logger.info("Continuous training already running")
return
self.stop_training = False
self.training_thread = threading.Thread(target=self._continuous_training_loop, daemon=True)
self.training_thread.start()
logger.info("Started continuous training thread")
def stop_continuous_training(self):
"""Stop continuous training"""
self.stop_training = True
logger.info("Stopping continuous training thread")
def _continuous_training_loop(self):
"""Continuous training loop"""
try:
self.training_active = True
logger.info("Starting continuous training loop")
while not self.stop_training:
# Check if training is enabled
if not self.training_enabled:
time.sleep(5)
continue
# Check if we have enough training samples
if len(self.cnn_adapter.training_data) < self.training_batch_size:
logger.debug(f"Not enough training samples: {len(self.cnn_adapter.training_data)}/{self.training_batch_size}")
time.sleep(5)
continue
# Train model
start_time = time.time()
metrics = self.cnn_adapter.train(epochs=self.training_epochs)
training_time = time.time() - start_time
# Update metrics
self.training_times.append(training_time)
if len(self.training_times) > 100:
self.training_times.pop(0)
self.total_training_runs += 1
self.daily_training_runs += 1
self.last_training_time = datetime.now()
# Calculate training rate
if self.training_times:
avg_training_time = sum(self.training_times) / len(self.training_times)
self.training_rate = 1.0 / avg_training_time if avg_training_time > 0 else 0.0
# Update loss and accuracy
self.current_loss = metrics.get('loss', 0.0)
self.current_accuracy = metrics.get('accuracy', 0.0)
# Update initial loss if not set
if self.initial_loss is None:
self.initial_loss = self.current_loss
# Update best loss
if self.best_loss is None or self.current_loss < self.best_loss:
self.best_loss = self.current_loss
# Calculate improvement percentage
if self.initial_loss is not None and self.initial_loss > 0:
self.improvement_percentage = ((self.initial_loss - self.current_loss) / self.initial_loss) * 100
logger.info(f"Training completed: loss={self.current_loss:.4f}, accuracy={self.current_accuracy:.4f}, samples={metrics.get('samples', 0)}")
# Sleep before next training
time.sleep(10)
except Exception as e:
logger.error(f"Error in continuous training loop: {e}")
finally:
self.training_active = False
def predict(self, symbol: str) -> Optional[ModelOutput]:
"""
Make a prediction using the EnhancedCNN model
Args:
symbol: Trading symbol
Returns:
ModelOutput: Standardized model output
"""
try:
# Check if inference is enabled
if not self.inference_enabled:
return None
# Get standardized input data
base_data = self.data_provider.get_base_data_input(symbol)
if base_data is None:
logger.warning(f"Failed to get base data input for {symbol}")
return None
# Make prediction
start_time = time.time()
model_output = self.cnn_adapter.predict(base_data)
inference_time = time.time() - start_time
# Update metrics
self.inference_times.append(inference_time)
if len(self.inference_times) > 100:
self.inference_times.pop(0)
self.total_inferences += 1
self.daily_inferences += 1
self.last_inference_time = datetime.now()
# Calculate inference rate
if self.inference_times:
avg_inference_time = sum(self.inference_times) / len(self.inference_times)
self.inference_rate = 1.0 / avg_inference_time if avg_inference_time > 0 else 0.0
# Store latest prediction
self.latest_prediction = model_output
self.latest_prediction_time = datetime.now()
# Store model output in data provider
self.data_provider.store_model_output(model_output)
# Add training sample if we have a price
current_price = self._get_current_price(symbol)
if current_price and current_price > 0:
# Simulate market feedback based on price movement
# In a real system, this would be replaced with actual market performance data
action = model_output.predictions['action']
# For demonstration, we'll use a simple heuristic:
# - If price is above 3000, BUY is good
# - If price is below 3000, SELL is good
# - Otherwise, HOLD is good
if current_price > 3000:
best_action = 'BUY'
elif current_price < 3000:
best_action = 'SELL'
else:
best_action = 'HOLD'
# Calculate reward based on whether the action matched the best action
if action == best_action:
reward = 0.05 # Positive reward for correct action
else:
reward = -0.05 # Negative reward for incorrect action
# Add training sample
self.cnn_adapter.add_training_sample(base_data, best_action, reward)
logger.debug(f"Added training sample for {symbol}, action: {action}, best_action: {best_action}, reward: {reward:.4f}")
return model_output
except Exception as e:
logger.error(f"Error making prediction: {e}")
return None
def _get_current_price(self, symbol: str) -> Optional[float]:
"""Get current price for a symbol"""
try:
# Try to get price from data provider
if hasattr(self.data_provider, 'current_prices'):
binance_symbol = symbol.replace('/', '').upper()
if binance_symbol in self.data_provider.current_prices:
return self.data_provider.current_prices[binance_symbol]
# Try to get price from latest OHLCV data
df = self.data_provider.get_historical_data(symbol, '1s', 1)
if df is not None and not df.empty:
return float(df.iloc[-1]['close'])
return None
except Exception as e:
logger.error(f"Error getting current price: {e}")
return None
def get_model_state(self) -> Dict[str, Any]:
"""
Get model state for dashboard display
Returns:
Dict[str, Any]: Model state
"""
try:
# Format prediction for display
prediction_info = "FRESH"
confidence = 0.0
if self.latest_prediction:
action = self.latest_prediction.predictions.get('action', 'UNKNOWN')
confidence = self.latest_prediction.confidence
# Map action to display text
if action == 'BUY':
prediction_info = "BUY_SIGNAL"
elif action == 'SELL':
prediction_info = "SELL_SIGNAL"
elif action == 'HOLD':
prediction_info = "HOLD_SIGNAL"
else:
prediction_info = "PATTERN_ANALYSIS"
# Format timing information
inference_timing = "None"
training_timing = "None"
if self.last_inference_time:
inference_timing = self.last_inference_time.strftime('%H:%M:%S')
if self.last_training_time:
training_timing = self.last_training_time.strftime('%H:%M:%S')
# Calculate improvement percentage
improvement = 0.0
if self.initial_loss is not None and self.initial_loss > 0 and self.current_loss > 0:
improvement = ((self.initial_loss - self.current_loss) / self.initial_loss) * 100
return {
'model_name': self.model_name,
'model_type': 'cnn',
'parameters': 50000000, # 50M parameters
'status': 'ACTIVE' if self.inference_enabled else 'DISABLED',
'checkpoint_loaded': True, # Assume checkpoint is loaded
'last_prediction': prediction_info,
'confidence': confidence * 100, # Convert to percentage
'last_inference_time': inference_timing,
'last_training_time': training_timing,
'inference_rate': self.inference_rate,
'training_rate': self.training_rate,
'daily_inferences': self.daily_inferences,
'daily_training_runs': self.daily_training_runs,
'initial_loss': self.initial_loss,
'current_loss': self.current_loss,
'best_loss': self.best_loss,
'current_accuracy': self.current_accuracy,
'improvement_percentage': improvement,
'training_active': self.training_active,
'training_enabled': self.training_enabled,
'inference_enabled': self.inference_enabled,
'training_samples': len(self.cnn_adapter.training_data)
}
except Exception as e:
logger.error(f"Error getting model state: {e}")
return {
'model_name': self.model_name,
'model_type': 'cnn',
'parameters': 50000000, # 50M parameters
'status': 'ERROR',
'error': str(e)
}
def get_pivot_prediction(self) -> Dict[str, Any]:
"""
Get pivot prediction for dashboard display
Returns:
Dict[str, Any]: Pivot prediction
"""
try:
if not self.latest_prediction:
return {
'next_pivot': 0.0,
'pivot_type': 'UNKNOWN',
'confidence': 0.0,
'time_to_pivot': 0
}
# Extract pivot prediction from model output
extrema_pred = self.latest_prediction.predictions.get('extrema', [0, 0, 0])
# Determine pivot type (0=bottom, 1=top, 2=range continuation)
pivot_type_idx = extrema_pred.index(max(extrema_pred))
pivot_types = ['BOTTOM', 'TOP', 'RANGE_CONTINUATION']
pivot_type = pivot_types[pivot_type_idx]
# Get current price
current_price = self._get_current_price('ETH/USDT') or 0.0
# Calculate next pivot price (simple heuristic for demonstration)
if pivot_type == 'BOTTOM':
next_pivot = current_price * 0.95 # 5% below current price
elif pivot_type == 'TOP':
next_pivot = current_price * 1.05 # 5% above current price
else:
next_pivot = current_price # Same as current price
# Calculate confidence
confidence = max(extrema_pred) * 100 # Convert to percentage
# Calculate time to pivot (simple heuristic for demonstration)
time_to_pivot = 5 # 5 minutes
return {
'next_pivot': next_pivot,
'pivot_type': pivot_type,
'confidence': confidence,
'time_to_pivot': time_to_pivot
}
except Exception as e:
logger.error(f"Error getting pivot prediction: {e}")
return {
'next_pivot': 0.0,
'pivot_type': 'ERROR',
'confidence': 0.0,
'time_to_pivot': 0
}
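A similar hedged sketch of this class's lifecycle, assuming a StandardizedDataProvider instance named data_provider is already available and the polling cadence is arbitrary.

# Illustrative lifecycle sketch (assumes data_provider is a StandardizedDataProvider instance)
import time

integration = EnhancedCNNIntegration(data_provider, checkpoint_dir="models/enhanced_cnn")
integration.start_continuous_training()          # background thread trains once enough samples exist

for _ in range(10):
    output = integration.predict('ETH/USDT')     # None if inference is disabled or data is missing
    state = integration.get_model_state()        # dict consumed by the dashboard model panel
    pivot = integration.get_pivot_prediction()   # next_pivot / pivot_type / confidence / time_to_pivot
    print(state['status'], state['current_loss'], pivot['pivot_type'])
    time.sleep(1)

integration.stop_continuous_training()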

View File

@@ -3369,12 +3369,17 @@ class TradingOrchestrator:
)
logger.info(f" Outcome: {outcome_status}")
# Add performance summary
# Add comprehensive performance summary
if model_name in self.model_performance:
perf = self.model_performance[model_name]
logger.info(
f" Performance: {perf['accuracy']:.1%} ({perf['correct']}/{perf['total']})"
f" Performance: {perf['directional_accuracy']:.1%} directional ({perf['directional_correct']}/{perf['total']}) | "
f"{perf['accuracy']:.1%} profitable ({perf['correct']}/{perf['total']})"
)
if perf["pivot_attempted"] > 0:
logger.info(
f" Pivot Detection: {perf['pivot_accuracy']:.1%} ({perf['pivot_detected']}/{perf['pivot_attempted']})"
)
except Exception as e:
logger.error(f"Error in immediate training for {model_name}: {e}")
@@ -3453,32 +3458,62 @@ class TradingOrchestrator:
predicted_price_vector=predicted_price_vector,
)
# Update model performance tracking
# Initialize enhanced model performance tracking
if model_name not in self.model_performance:
self.model_performance[model_name] = {
"correct": 0,
"correct": 0, # Profitability accuracy (backwards compatible)
"total": 0,
"accuracy": 0.0,
"accuracy": 0.0, # Profitability accuracy (backwards compatible)
"directional_correct": 0, # NEW: Directional accuracy
"directional_accuracy": 0.0, # NEW: Directional accuracy %
"pivot_detected": 0, # NEW: Successful pivot detections
"pivot_attempted": 0, # NEW: Total pivot attempts
"pivot_accuracy": 0.0, # NEW: Pivot detection accuracy
"price_predictions": {"total": 0, "accurate": 0, "avg_error": 0.0},
}
# Ensure all new keys exist (for existing models)
perf = self.model_performance[model_name]
if "directional_correct" not in perf:
perf["directional_correct"] = 0
perf["directional_accuracy"] = 0.0
perf["pivot_detected"] = 0
perf["pivot_attempted"] = 0
perf["pivot_accuracy"] = 0.0
# Ensure price_predictions key exists
if "price_predictions" not in self.model_performance[model_name]:
self.model_performance[model_name]["price_predictions"] = {
"total": 0,
"accurate": 0,
"avg_error": 0.0,
}
if "price_predictions" not in perf:
perf["price_predictions"] = {"total": 0, "accurate": 0, "avg_error": 0.0}
self.model_performance[model_name]["total"] += 1
if was_correct:
self.model_performance[model_name]["correct"] += 1
self.model_performance[model_name]["accuracy"] = (
self.model_performance[model_name]["correct"]
/ self.model_performance[model_name]["total"]
# Calculate directional accuracy separately
directional_correct = (
(predicted_action == "BUY" and price_change_pct > 0) or
(predicted_action == "SELL" and price_change_pct < 0) or
(predicted_action == "HOLD" and abs(price_change_pct) < 0.05)
)
# Update all accuracy metrics
perf["total"] += 1
if was_correct: # Profitability accuracy
perf["correct"] += 1
if directional_correct:
perf["directional_correct"] += 1
# Update pivot detection tracking
is_significant_move = abs(price_change_pct) > 0.08 # 0.08% threshold for "significant"
if predicted_action in ["BUY", "SELL"] and is_significant_move:
perf["pivot_attempted"] += 1
if directional_correct:
perf["pivot_detected"] += 1
# Calculate all accuracy percentages
perf["accuracy"] = perf["correct"] / perf["total"] # Profitability accuracy
perf["directional_accuracy"] = perf["directional_correct"] / perf["total"] # Directional accuracy
if perf["pivot_attempted"] > 0:
perf["pivot_accuracy"] = perf["pivot_detected"] / perf["pivot_attempted"] # Pivot accuracy
else:
perf["pivot_accuracy"] = 0.0
# Track price prediction accuracy if available
if inference_price is not None:
price_prediction_stats = self.model_performance[model_name][
@@ -3504,7 +3539,8 @@ class TradingOrchestrator:
f"({price_prediction_stats['avg_error']:.2f}% avg error)"
)
# Enhanced logging for training evaluation
# Enhanced logging with new accuracy metrics
perf = self.model_performance[model_name]
logger.info(f"Training evaluation for {model_name}:")
logger.info(
f" Action: {predicted_action} | Confidence: {prediction_confidence:.3f}"
@@ -3512,10 +3548,15 @@ class TradingOrchestrator:
logger.info(
f" Price change: {price_change_pct:+.3f}% | Time: {time_diff_seconds:.1f}s"
)
logger.info(f" Reward: {reward:.4f} | Correct: {was_correct}")
logger.info(f" Reward: {reward:.4f} | Profitable: {was_correct} | Directional: {directional_correct}")
logger.info(
f" Accuracy: {self.model_performance[model_name]['accuracy']:.1%} ({self.model_performance[model_name]['correct']}/{self.model_performance[model_name]['total']})"
f" Profitability: {perf['accuracy']:.1%} ({perf['correct']}/{perf['total']}) | "
f"Directional: {perf['directional_accuracy']:.1%} ({perf['directional_correct']}/{perf['total']})"
)
if perf["pivot_attempted"] > 0:
logger.info(
f" Pivot Detection: {perf['pivot_accuracy']:.1%} ({perf['pivot_detected']}/{perf['pivot_attempted']})"
)
# Train the specific model based on sophisticated outcome
await self._train_model_on_outcome(
@@ -3549,6 +3590,45 @@ class TradingOrchestrator:
except Exception as e:
logger.error(f"Error evaluating and training on record: {e}")
def _is_pivot_point(self, price_change_pct: float, prediction_confidence: float, time_diff_minutes: float) -> tuple[bool, str, float]:
"""
Detect if this is a significant pivot point worth trading.
Pivot points are the key moments where markets change direction or momentum.
Returns:
tuple: (is_pivot, pivot_type, pivot_strength)
"""
abs_change = abs(price_change_pct)
# Pivot point thresholds (much more realistic for crypto)
minor_pivot = 0.08 # 0.08% - small but tradeable pivot
medium_pivot = 0.25 # 0.25% - significant pivot
major_pivot = 0.6 # 0.6% - major pivot
massive_pivot = 1.2 # 1.2% - massive pivot
# Time-based multipliers (faster pivots are more valuable)
time_multiplier = 1.0
if time_diff_minutes < 2.0: # Very fast pivot
time_multiplier = 2.0
elif time_diff_minutes < 5.0: # Fast pivot
time_multiplier = 1.5
elif time_diff_minutes > 15.0: # Slow pivot - less valuable
time_multiplier = 0.7
# Confidence multiplier (high confidence pivots are more valuable)
confidence_multiplier = 0.5 + (prediction_confidence * 1.5) # 0.5 to 2.0
if abs_change >= massive_pivot:
return True, "MASSIVE_PIVOT", 10.0 * time_multiplier * confidence_multiplier
elif abs_change >= major_pivot:
return True, "MAJOR_PIVOT", 5.0 * time_multiplier * confidence_multiplier
elif abs_change >= medium_pivot:
return True, "MEDIUM_PIVOT", 2.5 * time_multiplier * confidence_multiplier
elif abs_change >= minor_pivot:
return True, "MINOR_PIVOT", 1.2 * time_multiplier * confidence_multiplier
else:
return False, "NO_PIVOT", 0.1 # Very small reward for noise
def _calculate_sophisticated_reward(
self,
predicted_action: str,
@@ -3562,11 +3642,19 @@ class TradingOrchestrator:
predicted_price_vector: dict = None,
) -> tuple[float, bool]:
"""
Calculate sophisticated reward based on prediction accuracy, confidence, and price movement magnitude
Now considers position status and current P&L when evaluating decisions
NOISE REDUCTION: Treats neutral/low-confidence signals as HOLD to reduce training noise
PRICE VECTOR BONUS: Rewards accurate price direction and magnitude predictions
PIVOT-POINT FOCUSED REWARD SYSTEM
This system heavily rewards models for correctly identifying pivot points -
the actual profitable trading opportunities in the market. Small movements
are treated as noise and given minimal rewards.
Key Features:
- Separate directional accuracy vs profitability accuracy tracking
- Heavy rewards for successful pivot point detection
- Minimal penalties for noise (small movements)
- Time-weighted rewards (faster detection = better)
- Confidence-weighted rewards (higher confidence = better)
Args:
predicted_action: The predicted action ('BUY', 'SELL', 'HOLD')
prediction_confidence: Model's confidence in the prediction (0.0 to 1.0)
@@ -3579,21 +3667,36 @@ class TradingOrchestrator:
predicted_price_vector: Dict with 'direction' (-1 to 1) and 'confidence' (0 to 1)
Returns:
tuple: (reward, was_correct)
tuple: (reward, was_correct), where was_correct reflects fee-aware profitability; directional and pivot results are tracked separately by the caller
"""
try:
# NOISE REDUCTION: Treat low-confidence signals as HOLD
confidence_threshold = 0.6 # Only consider BUY/SELL if confidence > 60%
if prediction_confidence < confidence_threshold:
predicted_action = "HOLD"
logger.debug(f"Low confidence ({prediction_confidence:.2f}) - treating as HOLD for noise reduction")
# Store original action for directional accuracy tracking
original_action = predicted_action
# FEE-AWARE THRESHOLDS: Account for trading fees (0.05-0.06% per trade, ~0.12% round trip)
fee_cost = 0.12 # 0.12% round trip fee cost
movement_threshold = 0.15 # Minimum movement to be profitable after fees
strong_movement_threshold = 0.5 # Strong movements - good profit potential
rapid_movement_threshold = 1.0 # Rapid movements - excellent profit potential
massive_movement_threshold = 2.0 # Massive movements - extraordinary profit potential
# PIVOT POINT DETECTION
is_pivot, pivot_type, pivot_strength = self._is_pivot_point(
price_change_pct, prediction_confidence, time_diff_minutes
)
# DIRECTIONAL ACCURACY (simple direction prediction)
directional_correct = False
if predicted_action == "BUY" and price_change_pct > 0:
directional_correct = True
elif predicted_action == "SELL" and price_change_pct < 0:
directional_correct = True
elif predicted_action == "HOLD" and abs(price_change_pct) < 0.05: # Very small movement
directional_correct = True
# PROFITABILITY ACCURACY (fee-aware profitable trades)
fee_cost = 0.10 # 0.10% round trip fee cost (realistic for most exchanges)
profitability_correct = False
if predicted_action == "BUY" and price_change_pct > fee_cost:
profitability_correct = True
elif predicted_action == "SELL" and price_change_pct < -fee_cost:
profitability_correct = True
elif predicted_action == "HOLD" and abs(price_change_pct) < fee_cost:
profitability_correct = True
# Determine current position status if not provided
if has_position is None and symbol:
@@ -3604,210 +3707,104 @@ class TradingOrchestrator:
elif has_position is None:
has_position = False
# Determine if prediction was directionally correct
was_correct = False
directional_accuracy = 0.0
if predicted_action == "BUY":
# BUY signals need to overcome fee costs for profitability
was_correct = price_change_pct > movement_threshold
# PIVOT POINT REWARD CALCULATION
base_reward = 0.0
pivot_bonus = 0.0
# For backwards compatibility, use profitability_correct as the main "was_correct"
was_correct = profitability_correct
# MASSIVE REWARDS FOR SUCCESSFUL PIVOT POINT DETECTION
if is_pivot and directional_correct:
# Base pivot reward
base_reward = pivot_strength
# ENHANCED FEE-AWARE REWARD STRUCTURE
if price_change_pct > massive_movement_threshold:
# Massive movements (2%+) - EXTRAORDINARY rewards for high confidence
directional_accuracy = price_change_pct * 5.0 # 5x multiplier for massive moves
if prediction_confidence > 0.8:
directional_accuracy *= 2.0 # Additional 2x for high confidence (10x total)
elif price_change_pct > rapid_movement_threshold:
# Rapid movements (1%+) - EXCELLENT rewards for high confidence
directional_accuracy = price_change_pct * 3.0 # 3x multiplier for rapid moves
if prediction_confidence > 0.7:
directional_accuracy *= 1.5 # Additional 1.5x for good confidence (4.5x total)
elif price_change_pct > strong_movement_threshold:
# Strong movements (0.5%+) - GOOD rewards
directional_accuracy = price_change_pct * 2.0 # 2x multiplier for strong moves
else:
# Small movements - minimal rewards (fees eat most profit)
directional_accuracy = max(0, (price_change_pct - fee_cost)) * 0.5 # Penalty for fee cost
# EXTRAORDINARY bonuses for successful pivot predictions
if pivot_type == "MASSIVE_PIVOT":
pivot_bonus = 50.0 * prediction_confidence # Up to 50x reward!
logger.info(f"MASSIVE PIVOT SUCCESS: {pivot_type} detected with {prediction_confidence:.2f} confidence = {pivot_bonus:.1f}x bonus!")
elif pivot_type == "MAJOR_PIVOT":
pivot_bonus = 20.0 * prediction_confidence # Up to 20x reward!
logger.info(f"MAJOR PIVOT SUCCESS: {pivot_type} detected with {prediction_confidence:.2f} confidence = {pivot_bonus:.1f}x bonus!")
elif pivot_type == "MEDIUM_PIVOT":
pivot_bonus = 8.0 * prediction_confidence # Up to 8x reward!
logger.info(f"MEDIUM PIVOT SUCCESS: {pivot_type} detected with {prediction_confidence:.2f} confidence = {pivot_bonus:.1f}x bonus!")
elif pivot_type == "MINOR_PIVOT":
pivot_bonus = 3.0 * prediction_confidence # Up to 3x reward!
logger.info(f"MINOR PIVOT SUCCESS: {pivot_type} detected with {prediction_confidence:.2f} confidence = {pivot_bonus:.1f}x bonus!")
elif predicted_action == "SELL":
# SELL signals need to overcome fee costs for profitability
was_correct = price_change_pct < -movement_threshold
# Additional time-based bonus for early detection
if time_diff_minutes < 1.0:
time_bonus = pivot_bonus * 0.5 # 50% bonus for very fast detection
pivot_bonus += time_bonus
logger.info(f"EARLY DETECTION BONUS: Detected {pivot_type} in {time_diff_minutes:.1f}m = +{time_bonus:.1f} bonus")
base_reward += pivot_bonus
elif is_pivot and not directional_correct:
# MODERATE penalty for missing pivot points (still valuable to learn from)
base_reward = -pivot_strength * 0.3 # Small penalty to encourage learning
logger.debug(f"MISSED PIVOT: {pivot_type} missed, small penalty = {base_reward:.2f}")
elif not is_pivot and directional_correct:
# Small reward for correct direction on non-pivots (noise)
base_reward = 0.2 * prediction_confidence
logger.debug(f"NOISE CORRECT: Correct direction on noise movement = {base_reward:.2f}")
# ENHANCED FEE-AWARE REWARD STRUCTURE (symmetric to BUY)
abs_change = abs(price_change_pct)
if abs_change > massive_movement_threshold:
# Massive movements (2%+) - EXTRAORDINARY rewards for high confidence
directional_accuracy = abs_change * 5.0 # 5x multiplier for massive moves
if prediction_confidence > 0.8:
directional_accuracy *= 2.0 # Additional 2x for high confidence (10x total)
elif abs_change > rapid_movement_threshold:
# Rapid movements (1%+) - EXCELLENT rewards for high confidence
directional_accuracy = abs_change * 3.0 # 3x multiplier for rapid moves
if prediction_confidence > 0.7:
directional_accuracy *= 1.5 # Additional 1.5x for good confidence (4.5x total)
elif abs_change > strong_movement_threshold:
# Strong movements (0.5%+) - GOOD rewards
directional_accuracy = abs_change * 2.0 # 2x multiplier for strong moves
else:
# Small movements - minimal rewards (fees eat most profit)
directional_accuracy = max(0, (abs_change - fee_cost)) * 0.5 # Penalty for fee cost
elif predicted_action == "HOLD":
# HOLD evaluation with noise reduction - smaller rewards to reduce training noise
if has_position:
# If we have a position, HOLD evaluation depends on P&L and price movement
if current_position_pnl > 0: # Currently profitable position
# Holding a profitable position is good if price continues favorably
if price_change_pct > 0: # Price went up while holding profitable position - excellent
was_correct = True
directional_accuracy = price_change_pct * 0.8 # Reduced from 1.5 to reduce noise
elif abs(price_change_pct) < movement_threshold: # Price stable - good
was_correct = True
directional_accuracy = movement_threshold * 0.5 # Reduced reward to reduce noise
else: # Price dropped while holding profitable position - still okay but less reward
was_correct = True
directional_accuracy = max(0, (current_position_pnl / 100.0) - abs(price_change_pct) * 0.3)
elif current_position_pnl < 0: # Currently losing position
# Holding a losing position is generally bad - should consider closing
if price_change_pct > movement_threshold: # Price recovered - good hold
was_correct = True
directional_accuracy = price_change_pct * 0.6 # Reduced reward
else: # Price continued down or stayed flat - bad hold
was_correct = False
# Penalty proportional to loss magnitude
directional_accuracy = abs(current_position_pnl / 100.0) * 0.3 # Reduced penalty
else: # Breakeven position
# Standard HOLD evaluation for breakeven positions
if abs(price_change_pct) < movement_threshold: # Price stable - good
was_correct = True
directional_accuracy = movement_threshold * 0.4 # Reduced reward
else: # Price moved significantly - missed opportunity
was_correct = False
directional_accuracy = max(0, movement_threshold - abs(price_change_pct)) * 0.5
else:
# If we don't have a position, HOLD is correct if price stayed relatively stable
was_correct = abs(price_change_pct) < movement_threshold
directional_accuracy = max(0, movement_threshold - abs(price_change_pct)) * 0.4 # Reduced reward
# Calculate FEE-AWARE magnitude-based multiplier (aggressive rewards for profitable movements)
abs_movement = abs(price_change_pct)
if abs_movement > massive_movement_threshold:
magnitude_multiplier = min(abs_movement / 1.0, 8.0) # Up to 8x for massive moves (8% = 8x)
elif abs_movement > rapid_movement_threshold:
magnitude_multiplier = min(abs_movement / 1.5, 4.0) # Up to 4x for rapid moves (6% = 4x)
elif abs_movement > strong_movement_threshold:
magnitude_multiplier = min(abs_movement / 2.0, 2.0) # Up to 2x for strong moves (4% = 2x)
else:
# Small movements get minimal multiplier due to fees
magnitude_multiplier = max(0.1, (abs_movement - fee_cost) / 2.0) # Penalty for fee cost
# Calculate confidence-based reward adjustment
if was_correct:
# Reward confident correct predictions more, penalize unconfident correct predictions less
confidence_multiplier = 0.5 + (
prediction_confidence * 1.5
) # Range: 0.5 to 2.0
base_reward = (
directional_accuracy * magnitude_multiplier * confidence_multiplier
# Very small penalty for wrong direction on noise (don't overtrain on noise)
base_reward = -0.1 * prediction_confidence
logger.debug(f"NOISE INCORRECT: Wrong direction on noise movement = {base_reward:.2f}")
# POSITION-AWARE ADJUSTMENTS
if has_position:
# Adjust rewards based on current position status
if current_position_pnl > 0.5: # Profitable position
if predicted_action == "HOLD" and price_change_pct > 0:
base_reward += 0.5 # Bonus for holding profitable position during uptrend
logger.debug(f"POSITION BONUS: Holding profitable position during uptrend = +0.5")
elif current_position_pnl < -0.5: # Losing position
if predicted_action in ["BUY", "SELL"] and directional_correct:
base_reward += 0.3 # Bonus for taking action to exit losing position
logger.debug(f"EXIT BONUS: Taking action on losing position = +0.3")
# PRICE VECTOR BONUS (if available)
if predicted_price_vector and isinstance(predicted_price_vector, dict):
vector_bonus = self._calculate_price_vector_bonus(
predicted_price_vector, price_change_pct, abs(price_change_pct), prediction_confidence
)
if vector_bonus > 0:
base_reward += vector_bonus
logger.debug(f"PRICE VECTOR BONUS: +{vector_bonus:.3f}")
# ENHANCED HIGH-CONFIDENCE BONUSES for profitable movements
abs_movement = abs(price_change_pct)
# Extraordinary confidence bonus for massive movements
if prediction_confidence > 0.9 and abs_movement > massive_movement_threshold:
base_reward *= 3.0 # 300% bonus for ultra-confident massive moves
logger.info(f"ULTRA CONFIDENCE BONUS: {prediction_confidence:.2f} confidence + {abs_movement:.2f}% movement = 3x reward")
# Excellent confidence bonus for rapid movements
elif prediction_confidence > 0.8 and abs_movement > rapid_movement_threshold:
base_reward *= 2.0 # 200% bonus for very confident rapid moves
logger.info(f"HIGH CONFIDENCE BONUS: {prediction_confidence:.2f} confidence + {abs_movement:.2f}% movement = 2x reward")
# Good confidence bonus for strong movements
elif prediction_confidence > 0.7 and abs_movement > strong_movement_threshold:
base_reward *= 1.5 # 150% bonus for confident strong moves
logger.info(f"CONFIDENCE BONUS: {prediction_confidence:.2f} confidence + {abs_movement:.2f}% movement = 1.5x reward")
# Rapid movement detection bonus (speed matters for fees)
if time_diff_minutes < 5.0 and abs_movement > rapid_movement_threshold:
base_reward *= 1.3 # 30% bonus for rapid detection of big moves
logger.info(f"RAPID DETECTION BONUS: {abs_movement:.2f}% movement in {time_diff_minutes:.1f}m = 1.3x reward")
# PRICE VECTOR ACCURACY BONUS - Reward models for accurate price direction/magnitude predictions
if predicted_price_vector and isinstance(predicted_price_vector, dict):
vector_bonus = self._calculate_price_vector_bonus(
predicted_price_vector, price_change_pct, abs_movement, prediction_confidence
)
if vector_bonus > 0:
base_reward += vector_bonus
logger.info(f"PRICE VECTOR BONUS: +{vector_bonus:.3f} for accurate direction/magnitude prediction")
else:
# ENHANCED PENALTY SYSTEM: Discourage fee-losing trades
abs_movement = abs(price_change_pct)
# Penalize incorrect predictions more severely if they were confident
confidence_penalty = 0.5 + (prediction_confidence * 1.5) # Higher confidence = higher penalty
base_penalty = abs_movement * confidence_penalty
# SEVERE penalties for confident wrong predictions on big moves
if prediction_confidence > 0.8 and abs_movement > rapid_movement_threshold:
base_penalty *= 5.0 # 5x penalty for very confident wrong on big moves
logger.warning(f"SEVERE PENALTY: {prediction_confidence:.2f} confidence wrong on {abs_movement:.2f}% movement = 5x penalty")
elif prediction_confidence > 0.7 and abs_movement > strong_movement_threshold:
base_penalty *= 3.0 # 3x penalty for confident wrong on strong moves
logger.warning(f"HIGH PENALTY: {prediction_confidence:.2f} confidence wrong on {abs_movement:.2f}% movement = 3x penalty")
elif prediction_confidence > 0.8:
base_penalty *= 2.0 # 2x penalty for overconfident wrong predictions
# ADDITIONAL penalty for predictions that would lose money to fees
if abs_movement < fee_cost and prediction_confidence > 0.5:
fee_loss_penalty = (fee_cost - abs_movement) * 2.0 # Penalty for fee-losing trades
base_penalty += fee_loss_penalty
logger.warning(f"FEE LOSS PENALTY: {abs_movement:.2f}% movement < {fee_cost:.2f}% fees = +{fee_loss_penalty:.3f} penalty")
base_reward = -base_penalty
# Time decay factor (predictions should be evaluated quickly)
time_decay = max(
0.1, 1.0 - (time_diff_minutes / 60.0)
) # Decay over 1 hour, min 10%
# Final reward calculation
# Time decay factor (pivot detection should be fast)
time_decay = max(0.3, 1.0 - (time_diff_minutes / 30.0)) # Decay over 30 minutes, min 30%
# Apply time decay
final_reward = base_reward * time_decay
# Bonus for accurate price predictions
if (
has_price_prediction and abs(price_change_pct) < 1.0
): # Accurate price prediction
final_reward *= 1.2 # 20% bonus for accurate price predictions
logger.debug(
f"Applied price prediction accuracy bonus: {final_reward:.3f}"
)
# Clamp reward to reasonable range
final_reward = max(-5.0, min(5.0, final_reward))
# Clamp reward to reasonable range (higher range for pivot bonuses)
final_reward = max(-10.0, min(100.0, final_reward))
# Log detailed accuracy information
logger.debug(
f"REWARD CALCULATION: action={predicted_action}, confidence={prediction_confidence:.3f}, "
f"price_change={price_change_pct:.3f}%, pivot={is_pivot}/{pivot_type}, "
f"directional_correct={directional_correct}, profitability_correct={profitability_correct}, "
f"reward={final_reward:.3f}"
)
return final_reward, was_correct
except Exception as e:
logger.error(f"Error calculating sophisticated reward: {e}")
# Fallback to simple reward with position awareness
has_position = self._has_open_position(symbol) if symbol else False
if predicted_action == "HOLD" and has_position:
# If holding a position, HOLD is correct if price didn't drop significantly
simple_correct = price_change_pct > -0.2 # Allow small losses while holding
else:
# Standard evaluation for other cases
simple_correct = (
(predicted_action == "BUY" and price_change_pct > 0.1)
or (predicted_action == "SELL" and price_change_pct < -0.1)
or (predicted_action == "HOLD" and abs(price_change_pct) < 0.1)
)
return (1.0 if simple_correct else -0.5, simple_correct)
# Fallback to simple directional accuracy
simple_correct = (
(predicted_action == "BUY" and price_change_pct > 0) or
(predicted_action == "SELL" and price_change_pct < 0) or
(predicted_action == "HOLD" and abs(price_change_pct) < 0.05)
)
return (1.0 if simple_correct else -0.1, simple_correct)
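Continuing the same sample inputs through the full reward path (assuming no open position and no predicted price vector), a sketch of the arithmetic:

# Illustrative trace of the pivot-focused reward: BUY call, +0.30% move, 0.80 confidence, 3 minutes
prediction_confidence, price_change_pct, time_diff_minutes = 0.80, 0.30, 3.0

pivot_strength = 6.375                      # MEDIUM_PIVOT, from the _is_pivot_point example above
base_reward = pivot_strength                # pivot detected and direction correct
base_reward += 8.0 * prediction_confidence  # MEDIUM_PIVOT bonus +6.4 (no early-detection bonus at 3 minutes)

time_decay = max(0.3, 1.0 - time_diff_minutes / 30.0)      # 0.9
final_reward = max(-10.0, min(100.0, base_reward * time_decay))
was_correct = price_change_pct > 0.10       # profitability check clears the 0.10% round-trip fee cost
print(final_reward, was_correct)            # -> ~11.5 True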
def _calculate_price_vector_bonus(
self,
@@ -4334,6 +4331,25 @@ class TradingOrchestrator:
# Create training sample from record
model_input = record.get("model_input")
# If model_input is None, try to generate fresh state for training
if model_input is None:
logger.debug(f"No stored model input for {model_name}, generating fresh state")
try:
# Generate fresh input state for training
if hasattr(self, 'data_provider') and self.data_provider:
# Use data provider to generate current market state
fresh_state = self._generate_fresh_state_fallback(model_name)
if fresh_state is not None and len(fresh_state) > 0:
model_input = fresh_state
logger.debug(f"Generated fresh training state for {model_name}: shape={fresh_state.shape if hasattr(fresh_state, 'shape') else len(fresh_state)}")
else:
logger.warning(f"Failed to generate fresh state for {model_name}")
else:
logger.warning(f"No data provider available for generating fresh state for {model_name}")
except Exception as e:
logger.warning(f"Error generating fresh state for {model_name}: {e}")
if model_input is not None:
# Convert to tensor and ensure device placement
device = next(self.cnn_model.parameters()).device
@@ -4432,7 +4448,71 @@ class TradingOrchestrator:
)
return True
else:
logger.warning(f"No model input available for CNN training")
logger.warning(f"No model input available for CNN training for {model_name}. This prevents the model from learning.")
# Try one more time to generate training data from current market conditions
try:
if hasattr(self, 'data_provider') and self.data_provider:
# Create minimal training sample from current market data
symbol = record.get("symbol", "ETH/USDT")
current_price = self._get_current_price(symbol) or 0.0  # guard against None before the normalization below
# Get variables from function scope
actual_action = prediction["action"]
pred_confidence = prediction.get("confidence", 0.5)
# Create a basic feature vector (this is a fallback)
basic_features = np.array([
current_price / 10000.0, # Normalized price
pred_confidence, # Model confidence
reward, # Current reward
1.0 if actual_action == "BUY" else 0.0,
1.0 if actual_action == "SELL" else 0.0,
1.0 if actual_action == "HOLD" else 0.0
], dtype=np.float32)
# Pad to expected size if needed
expected_size = 512 # Adjust based on your model's expected input size
if len(basic_features) < expected_size:
padding = np.zeros(expected_size - len(basic_features), dtype=np.float32)
basic_features = np.concatenate([basic_features, padding])
logger.info(f"Created fallback training features for {model_name}: shape={basic_features.shape}")
# Now perform training with the fallback features
device = next(self.cnn_model.parameters()).device
features_tensor = torch.tensor(basic_features, dtype=torch.float32, device=device).unsqueeze(0)
# Convert action to index
actions = ["BUY", "SELL", "HOLD"]
action_idx = actions.index(actual_action) if actual_action in actions else 2
action_tensor = torch.tensor([action_idx], dtype=torch.long, device=device)
reward_tensor = torch.tensor([reward], dtype=torch.float32, device=device)
# Perform minimal training step
self.cnn_model.train()
self.cnn_optimizer.zero_grad()
# Forward pass
q_values, _, _, _, _ = self.cnn_model(features_tensor)
# Calculate basic loss
q_values_selected = q_values.gather(1, action_tensor.unsqueeze(1)).squeeze(1)
loss = nn.MSELoss()(q_values_selected, reward_tensor)
# Backward pass
loss.backward()
torch.nn.utils.clip_grad_norm_(self.cnn_model.parameters(), max_norm=1.0)
self.cnn_optimizer.step()
logger.info(f"Fallback CNN training completed for {model_name}: loss={loss.item():.4f}")
return True
except Exception as fallback_error:
logger.error(f"Fallback CNN training failed for {model_name}: {fallback_error}")
# If we reach here, even fallback training failed
logger.error(f"All CNN training methods failed for {model_name}. Model will not learn from this prediction.")
return False
# Try model interface training methods