inference_enabled, cleanup
This commit is contained in:
@ -1,365 +0,0 @@
|
||||
"""
|
||||
Dashboard CNN Integration
|
||||
|
||||
This module integrates the EnhancedCNNAdapter with the dashboard system,
|
||||
providing real-time training, predictions, and performance metrics display.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from collections import deque
|
||||
import numpy as np
|
||||
|
||||
from .enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
from .standardized_data_provider import StandardizedDataProvider
|
||||
from .data_models import BaseDataInput, ModelOutput, create_model_output
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DashboardCNNIntegration:
    """
    CNN integration for the dashboard system.

    This class:
    1. Manages CNN model lifecycle in the dashboard
    2. Provides real-time training and inference
    3. Tracks performance metrics for dashboard display
    4. Handles model predictions for chart overlay

    Predictions are produced on a daemon thread (see
    ``start_real_time_processing``); the dashboard reads the results through
    ``get_dashboard_metrics`` and ``get_predictions_for_chart``.
    """

    # Sliding window (seconds) over which the per-second rates are computed.
    _RATE_WINDOW_SECONDS = 60.0
    # Safety bound for the timestamp buffers in case nothing prunes them.
    _MAX_TRACKED_EVENTS = 10000

    def __init__(self, data_provider: "StandardizedDataProvider", symbols: Optional[List[str]] = None):
        """
        Initialize the dashboard CNN integration.

        Args:
            data_provider: Standardized data provider; used to fetch model
                inputs and to store model outputs.
            symbols: List of symbols to process. Defaults to
                ``['ETH/USDT', 'BTC/USDT']`` when omitted or empty.
        """
        self.data_provider = data_provider
        self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']

        # Initialize CNN adapter
        self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")

        # Load best checkpoint if available
        self.cnn_adapter.load_best_checkpoint()

        # Performance tracking. 'inference_times' / 'training_times' hold
        # DURATIONS in seconds (not wall-clock timestamps).
        self.performance_metrics = {
            'total_predictions': 0,
            'total_training_samples': 0,
            'last_training_time': None,
            'last_inference_time': None,
            'training_loss_history': deque(maxlen=100),
            'accuracy_history': deque(maxlen=100),
            'inference_times': deque(maxlen=100),
            'training_times': deque(maxlen=100),
            'predictions_per_second': 0.0,
            'training_per_second': 0.0,
            'model_status': 'FRESH',
            'confidence_history': deque(maxlen=100),
            'action_distribution': {'BUY': 0, 'SELL': 0, 'HOLD': 0}
        }

        # Wall-clock completion times of recent inferences / training runs.
        # The per-second rates are derived from these, because the duration
        # deques above cannot tell *when* an event happened.
        self._inference_timestamps = deque(maxlen=self._MAX_TRACKED_EVENTS)
        self._training_timestamps = deque(maxlen=self._MAX_TRACKED_EVENTS)

        # Prediction cache for dashboard display
        self.prediction_cache = {}
        self.prediction_history = {symbol: deque(maxlen=1000) for symbol in self.symbols}

        # Training control
        self.training_enabled = True
        self.inference_enabled = True
        self.training_lock = threading.Lock()

        # Real-time processing
        self.is_running = False
        self.processing_thread = None

        logger.info(f"DashboardCNNIntegration initialized for symbols: {self.symbols}")

    def start_real_time_processing(self):
        """Start the background prediction thread (no-op if already running)."""
        if self.is_running:
            logger.warning("Real-time processing already running")
            return

        self.is_running = True
        self.processing_thread = threading.Thread(target=self._real_time_processing_loop, daemon=True)
        self.processing_thread.start()

        logger.info("Started real-time CNN processing")

    def stop_real_time_processing(self):
        """Signal the background thread to stop and wait up to 5 seconds for it."""
        self.is_running = False
        if self.processing_thread:
            self.processing_thread.join(timeout=5)

        logger.info("Stopped real-time CNN processing")

    def _real_time_processing_loop(self):
        """Main real-time processing loop; runs until ``is_running`` is cleared."""
        last_prediction_time = {}
        prediction_interval = 1.0  # Make prediction every 1 second

        while self.is_running:
            try:
                current_time = time.time()

                for symbol in self.symbols:
                    previous = last_prediction_time.get(symbol)
                    is_due = previous is None or current_time - previous >= prediction_interval

                    # Make prediction if it is due and inference is enabled
                    if is_due and self.inference_enabled:
                        self._make_prediction(symbol)
                        last_prediction_time[symbol] = current_time

                # Update performance metrics
                self._update_performance_metrics()

                # Sleep briefly to prevent overwhelming the system
                time.sleep(0.1)

            except Exception as e:
                # Keep the loop alive: log, back off, and try again.
                logger.error(f"Error in real-time processing loop: {e}")
                time.sleep(1)

    def _make_prediction(self, symbol: str):
        """
        Make a prediction for a symbol and record it.

        Fetches the standardized input from the data provider, runs the CNN
        adapter, updates the performance counters, caches the output for the
        dashboard and stores it in the data provider. Errors are logged, not
        raised.
        """
        try:
            start_time = time.time()

            # Get standardized input data
            base_data = self.data_provider.get_base_data_input(symbol)
            if base_data is None:
                logger.debug(f"No base data available for {symbol}")
                return

            # Make prediction
            model_output = self.cnn_adapter.predict(base_data)

            # Record how long inference took and when it finished
            finished_at = time.time()
            metrics = self.performance_metrics
            metrics['inference_times'].append(finished_at - start_time)
            self._inference_timestamps.append(finished_at)

            # Update performance metrics
            metrics['total_predictions'] += 1
            metrics['last_inference_time'] = datetime.now()
            metrics['confidence_history'].append(model_output.confidence)

            # Update action distribution; tolerate labels outside BUY/SELL/HOLD
            # so an unexpected label does not discard the whole prediction.
            action = model_output.predictions['action']
            distribution = metrics['action_distribution']
            distribution[action] = distribution.get(action, 0) + 1

            # Cache prediction for dashboard
            self.prediction_cache[symbol] = model_output
            self.prediction_history.setdefault(symbol, deque(maxlen=1000)).append(model_output)

            # Store model output in data provider
            self.data_provider.store_model_output(model_output)

            logger.debug(f"CNN prediction for {symbol}: {action} ({model_output.confidence:.3f})")

        except Exception as e:
            logger.error(f"Error making prediction for {symbol}: {e}")

    def add_training_sample(self, symbol: str, actual_action: str, reward: float):
        """
        Add a training sample and trigger training if enabled.

        Args:
            symbol: Symbol whose current input data is paired with the label.
            actual_action: Action label passed to the adapter.
            reward: Reward value passed to the adapter.

        A training run is triggered on every 10th accepted sample. Errors are
        logged, not raised.
        """
        try:
            if not self.training_enabled:
                return

            # Get base data for the symbol
            base_data = self.data_provider.get_base_data_input(symbol)
            if base_data is None:
                logger.debug(f"No base data available for training sample: {symbol}")
                return

            # Add training sample
            self.cnn_adapter.add_training_sample(base_data, actual_action, reward)

            # Update metrics
            self.performance_metrics['total_training_samples'] += 1

            # Train model periodically (every 10 samples)
            if self.performance_metrics['total_training_samples'] % 10 == 0:
                self._train_model()

        except Exception as e:
            logger.error(f"Error adding training sample: {e}")

    def _train_model(self):
        """Run one training epoch under ``training_lock`` and record its metrics."""
        try:
            with self.training_lock:
                start_time = time.time()

                # Train model; treat a falsy result as "no metrics reported"
                metrics = self.cnn_adapter.train(epochs=1) or {}

                # Record training duration and completion time
                finished_at = time.time()
                self.performance_metrics['training_times'].append(finished_at - start_time)
                self._training_timestamps.append(finished_at)

                # Update performance metrics
                self.performance_metrics['last_training_time'] = datetime.now()

                if 'loss' in metrics:
                    self.performance_metrics['training_loss_history'].append(metrics['loss'])

                if 'accuracy' in metrics:
                    self.performance_metrics['accuracy_history'].append(metrics['accuracy'])

                # Update model status
                is_trained = metrics.get('accuracy', 0) > 0.5
                self.performance_metrics['model_status'] = 'TRAINED' if is_trained else 'TRAINING'

                logger.info(f"CNN training completed: loss={metrics.get('loss', 0):.4f}, accuracy={metrics.get('accuracy', 0):.4f}")

        except Exception as e:
            logger.error(f"Error training CNN model: {e}")

    def _update_performance_metrics(self):
        """
        Refresh the per-second rates shown on the dashboard.

        Each rate is the number of events completed within the last
        ``_RATE_WINDOW_SECONDS`` divided by the window length. Timestamps older
        than the window are discarded here.
        """
        try:
            cutoff = time.time() - self._RATE_WINDOW_SECONDS
            tracked = (
                ('predictions_per_second', self._inference_timestamps),
                ('training_per_second', self._training_timestamps),
            )
            for metric_key, stamps in tracked:
                # Only this method pops from the left; producers append on the
                # right, so popleft/len (no iteration) is safe alongside them.
                while stamps and stamps[0] < cutoff:
                    stamps.popleft()
                self.performance_metrics[metric_key] = len(stamps) / self._RATE_WINDOW_SECONDS

        except Exception as e:
            logger.error(f"Error updating performance metrics: {e}")

    def get_dashboard_metrics(self) -> Dict[str, Any]:
        """
        Get metrics for dashboard display.

        Returns:
            A dict with model identity, status, latest loss/accuracy, average
            confidence, counters, rates, formatted timestamps, the most recent
            prediction across all symbols, and the enabled flags. On failure a
            reduced dict with ``'status': 'ERROR'`` and an ``'error'`` message
            is returned instead.
        """
        try:
            metrics = self.performance_metrics

            # Latest loss / accuracy (0.0 until the first training run)
            loss_history = metrics['training_loss_history']
            current_loss = loss_history[-1] if loss_history else 0.0

            accuracy_history = metrics['accuracy_history']
            current_accuracy = accuracy_history[-1] if accuracy_history else 0.0

            # Average confidence over the retained history
            confidences = list(metrics['confidence_history'])
            avg_confidence = float(np.mean(confidences)) if confidences else 0.0

            # Latest prediction across symbols. Work on a snapshot because the
            # processing thread may insert into the cache concurrently.
            latest_symbol, latest_prediction = max(
                list(self.prediction_cache.items()),
                key=lambda item: item[1].timestamp,
                default=(None, None),
            )

            # Format timing information
            last_inference = metrics['last_inference_time']
            last_training = metrics['last_training_time']
            last_inference_str = last_inference.strftime("%H:%M:%S") if last_inference else "None"
            last_training_str = last_training.strftime("%H:%M:%S") if last_training else "None"

            return {
                'model_name': 'CNN',
                'model_type': 'cnn',
                'parameters': '50.0M',
                'status': metrics['model_status'],
                'current_loss': current_loss,
                'accuracy': current_accuracy,
                'confidence': avg_confidence,
                'total_predictions': metrics['total_predictions'],
                'total_training_samples': metrics['total_training_samples'],
                'predictions_per_second': metrics['predictions_per_second'],
                'training_per_second': metrics['training_per_second'],
                'last_inference': last_inference_str,
                'last_training': last_training_str,
                'latest_prediction': {
                    'action': latest_prediction.predictions['action'] if latest_prediction else 'HOLD',
                    'confidence': latest_prediction.confidence if latest_prediction else 0.0,
                    'symbol': latest_symbol or 'ETH/USDT',
                    'timestamp': latest_prediction.timestamp.strftime("%H:%M:%S") if latest_prediction else "None"
                },
                'action_distribution': metrics['action_distribution'].copy(),
                'training_enabled': self.training_enabled,
                'inference_enabled': self.inference_enabled
            }

        except Exception as e:
            logger.error(f"Error getting dashboard metrics: {e}")
            return {
                'model_name': 'CNN',
                'model_type': 'cnn',
                'parameters': '50.0M',
                'status': 'ERROR',
                'current_loss': 0.0,
                'accuracy': 0.0,
                'confidence': 0.0,
                'error': str(e)
            }

    def get_predictions_for_chart(self, symbol: str, timeframe: str = '1s', limit: int = 100) -> List[Dict[str, Any]]:
        """
        Get predictions for chart overlay.

        Args:
            symbol: Symbol whose prediction history is requested.
            timeframe: Accepted for interface compatibility; not used here.
            limit: Maximum number of most-recent predictions to return.
                A value of 0 or less yields an empty list.

        Returns:
            Oldest-to-newest list of dicts with timestamp, action, confidence
            and the buy/sell/hold probabilities (0.0 when absent). Empty list
            for an unknown symbol or on error.
        """
        try:
            history = self.prediction_history.get(symbol)
            # Guard limit <= 0 explicitly: seq[-0:] would return everything.
            if history is None or limit <= 0:
                return []

            return [
                {
                    'timestamp': prediction.timestamp,
                    'action': prediction.predictions['action'],
                    'confidence': prediction.confidence,
                    'buy_probability': prediction.predictions.get('buy_probability', 0.0),
                    'sell_probability': prediction.predictions.get('sell_probability', 0.0),
                    'hold_probability': prediction.predictions.get('hold_probability', 0.0)
                }
                for prediction in list(history)[-limit:]
            ]

        except Exception as e:
            logger.error(f"Error getting predictions for chart: {e}")
            return []

    def set_training_enabled(self, enabled: bool):
        """Enable or disable training"""
        self.training_enabled = enabled
        logger.info(f"CNN training {'enabled' if enabled else 'disabled'}")

    def set_inference_enabled(self, enabled: bool):
        """Enable or disable inference"""
        self.inference_enabled = enabled
        logger.info(f"CNN inference {'enabled' if enabled else 'disabled'}")

    def get_model_info(self) -> Dict[str, Any]:
        """Get static model information plus adapter state for the dashboard."""
        adapter = self.cnn_adapter
        return {
            'name': 'Enhanced CNN',
            'version': '1.0',
            'parameters': '50.0M',
            'input_shape': adapter.model.input_shape if adapter.model else 'Unknown',
            'device': str(adapter.device),
            'checkpoint_dir': adapter.checkpoint_dir,
            'training_samples': len(adapter.training_data),
            'max_training_samples': adapter.max_training_samples
        }
|
@ -1,403 +0,0 @@
|
||||
"""
|
||||
Enhanced CNN Integration for Dashboard
|
||||
|
||||
This module integrates the EnhancedCNNAdapter with the dashboard, providing real-time
|
||||
training and inference capabilities.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
import os
|
||||
|
||||
from .enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
from .standardized_data_provider import StandardizedDataProvider
|
||||
from .data_models import BaseDataInput, ModelOutput, create_model_output
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EnhancedCNNIntegration:
    """
    Integration of EnhancedCNNAdapter with the dashboard.

    This class:
    1. Manages the EnhancedCNNAdapter lifecycle
    2. Provides real-time training and inference
    3. Collects and reports performance metrics
    4. Integrates with the dashboard's model visualization
    """

    # Number of most recent timing samples kept for rate calculations.
    _MAX_TIMING_SAMPLES = 100
    # Symbol used for pricing when no prediction symbol has been recorded.
    _DEFAULT_SYMBOL = 'ETH/USDT'
    # Placeholder price level used by the demonstration feedback in predict().
    _DEMO_PRICE_THRESHOLD = 3000
    # Pause lengths (seconds) used by the continuous training loop.
    _IDLE_PAUSE_SECONDS = 5
    _POST_TRAINING_PAUSE_SECONDS = 10

    def __init__(self, data_provider: "StandardizedDataProvider", checkpoint_dir: str = "models/enhanced_cnn"):
        """
        Initialize the EnhancedCNNIntegration.

        Args:
            data_provider: StandardizedDataProvider instance
            checkpoint_dir: Directory to store checkpoints (created if missing)
        """
        self.data_provider = data_provider
        self.checkpoint_dir = checkpoint_dir
        self.model_name = "enhanced_cnn_v1"

        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Initialize CNN adapter
        self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir)

        # Load best checkpoint if available
        self.cnn_adapter.load_best_checkpoint()

        # Performance tracking (timing lists hold durations in seconds)
        self.inference_times = []
        self.training_times = []
        self.total_inferences = 0
        self.total_training_runs = 0
        self.last_inference_time = None
        self.last_training_time = None
        self.inference_rate = 0.0
        self.training_rate = 0.0
        self.daily_inferences = 0
        self.daily_training_runs = 0
        # Local date the daily counters belong to; they reset when it changes.
        self._daily_counter_date = datetime.now().date()

        # Training settings
        self.training_enabled = True
        self.inference_enabled = True
        self.training_frequency = 10  # Train every N inferences
        self.training_batch_size = 32
        self.training_epochs = 1

        # Latest prediction
        self.latest_prediction = None
        self.latest_prediction_time = None
        self.latest_prediction_symbol = None

        # Training metrics
        self.current_loss = 0.0
        self.initial_loss = None
        self.best_loss = None
        self.current_accuracy = 0.0
        self.improvement_percentage = 0.0

        # Training thread
        self.training_thread = None
        self.training_active = False
        self.stop_training = False

        logger.info(f"EnhancedCNNIntegration initialized with model: {self.model_name}")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _record_duration(self, samples: list, duration: float) -> float:
        """Append *duration* to *samples* (bounded) and return events/second."""
        samples.append(duration)
        del samples[:-self._MAX_TIMING_SAMPLES]
        average = sum(samples) / len(samples)
        return 1.0 / average if average > 0 else 0.0

    def _roll_daily_counters(self):
        """Reset the daily counters when the local calendar date has changed."""
        today = datetime.now().date()
        if getattr(self, '_daily_counter_date', None) != today:
            self.daily_inferences = 0
            self.daily_training_runs = 0
            self._daily_counter_date = today

    def _sleep_unless_stopped(self, seconds: float):
        """Sleep up to *seconds*, waking early once ``stop_training`` is set."""
        deadline = time.time() + seconds
        while not self.stop_training:
            remaining = deadline - time.time()
            if remaining <= 0:
                break
            time.sleep(min(0.5, remaining))

    # ------------------------------------------------------------------
    # Continuous training
    # ------------------------------------------------------------------

    def start_continuous_training(self):
        """Start continuous training in a background thread"""
        if self.training_thread is not None and self.training_thread.is_alive():
            logger.info("Continuous training already running")
            return

        self.stop_training = False
        self.training_thread = threading.Thread(target=self._continuous_training_loop, daemon=True)
        self.training_thread.start()
        logger.info("Started continuous training thread")

    def stop_continuous_training(self):
        """Request the continuous training thread to stop (does not join it)."""
        self.stop_training = True
        logger.info("Stopping continuous training thread")

    def _continuous_training_loop(self):
        """
        Continuous training loop.

        Runs training cycles until ``stop_training`` is set. A failure in one
        cycle is logged and followed by a short pause; it does not end the loop.
        """
        self.training_active = True
        logger.info("Starting continuous training loop")
        try:
            while not self.stop_training:
                try:
                    pause = self._run_training_cycle()
                except Exception as e:
                    logger.error(f"Error in continuous training loop: {e}")
                    pause = self._IDLE_PAUSE_SECONDS
                self._sleep_unless_stopped(pause)
        finally:
            self.training_active = False

    def _run_training_cycle(self) -> float:
        """Run at most one training pass; return seconds to wait before the next."""
        # Check if training is enabled
        if not self.training_enabled:
            return self._IDLE_PAUSE_SECONDS

        # Check if we have enough training samples
        available = len(self.cnn_adapter.training_data)
        if available < self.training_batch_size:
            logger.debug(f"Not enough training samples: {available}/{self.training_batch_size}")
            return self._IDLE_PAUSE_SECONDS

        # Train model; treat a falsy result as "no metrics reported"
        start_time = time.time()
        metrics = self.cnn_adapter.train(epochs=self.training_epochs) or {}
        training_time = time.time() - start_time

        # Update counters and training rate
        self.training_rate = self._record_duration(self.training_times, training_time)
        self._roll_daily_counters()
        self.total_training_runs += 1
        self.daily_training_runs += 1
        self.last_training_time = datetime.now()

        # Update loss and accuracy
        self.current_loss = metrics.get('loss', 0.0)
        self.current_accuracy = metrics.get('accuracy', 0.0)

        # Update initial loss if not set
        if self.initial_loss is None:
            self.initial_loss = self.current_loss

        # Update best loss
        if self.best_loss is None or self.current_loss < self.best_loss:
            self.best_loss = self.current_loss

        # Calculate improvement percentage relative to the first recorded loss
        if self.initial_loss is not None and self.initial_loss > 0:
            self.improvement_percentage = ((self.initial_loss - self.current_loss) / self.initial_loss) * 100

        logger.info(f"Training completed: loss={self.current_loss:.4f}, accuracy={self.current_accuracy:.4f}, samples={metrics.get('samples', 0)}")

        return self._POST_TRAINING_PAUSE_SECONDS

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def predict(self, symbol: str) -> Optional["ModelOutput"]:
        """
        Make a prediction using the EnhancedCNN model.

        Args:
            symbol: Trading symbol

        Returns:
            ModelOutput: Standardized model output, or None when inference is
            disabled, no input data is available, or an error occurs.
        """
        try:
            # Check if inference is enabled
            if not self.inference_enabled:
                return None

            # Get standardized input data
            base_data = self.data_provider.get_base_data_input(symbol)
            if base_data is None:
                logger.warning(f"Failed to get base data input for {symbol}")
                return None

            # Make prediction
            start_time = time.time()
            model_output = self.cnn_adapter.predict(base_data)
            inference_time = time.time() - start_time

            # Update metrics
            self.inference_rate = self._record_duration(self.inference_times, inference_time)
            self._roll_daily_counters()
            self.total_inferences += 1
            self.daily_inferences += 1
            self.last_inference_time = datetime.now()

            # Store latest prediction together with the symbol it belongs to
            self.latest_prediction = model_output
            self.latest_prediction_time = datetime.now()
            self.latest_prediction_symbol = symbol

            # Store model output in data provider
            self.data_provider.store_model_output(model_output)

            # Add training sample if we have a price
            current_price = self._get_current_price(symbol)
            if current_price and current_price > 0:
                # Simulate market feedback based on price movement.
                # NOTE(review): demonstration-only heuristic. It labels samples
                # purely by whether the price is above/below a fixed level, so
                # it should be replaced with actual market performance data.
                action = model_output.predictions['action']

                if current_price > self._DEMO_PRICE_THRESHOLD:
                    best_action = 'BUY'
                elif current_price < self._DEMO_PRICE_THRESHOLD:
                    best_action = 'SELL'
                else:
                    best_action = 'HOLD'

                # Reward depends on whether the action matched the best action
                reward = 0.05 if action == best_action else -0.05

                # Add training sample
                self.cnn_adapter.add_training_sample(base_data, best_action, reward)

                logger.debug(f"Added training sample for {symbol}, action: {action}, best_action: {best_action}, reward: {reward:.4f}")

            return model_output

        except Exception as e:
            logger.error(f"Error making prediction: {e}")
            return None

    def _get_current_price(self, symbol: str) -> Optional[float]:
        """
        Get current price for a symbol.

        Looks first in the provider's ``current_prices`` mapping (keyed by the
        symbol without '/' in upper case), then falls back to the close of the
        latest '1s' historical row. Returns None if neither source has a price.
        """
        try:
            # Try to get price from data provider
            if hasattr(self.data_provider, 'current_prices'):
                binance_symbol = symbol.replace('/', '').upper()
                prices = self.data_provider.current_prices
                if binance_symbol in prices:
                    return prices[binance_symbol]

            # Try to get price from latest OHLCV data
            df = self.data_provider.get_historical_data(symbol, '1s', 1)
            if df is not None and not df.empty:
                return float(df.iloc[-1]['close'])

            return None

        except Exception as e:
            logger.error(f"Error getting current price: {e}")
            return None

    # ------------------------------------------------------------------
    # Dashboard accessors
    # ------------------------------------------------------------------

    def get_model_state(self) -> Dict[str, Any]:
        """
        Get model state for dashboard display.

        Returns:
            Dict[str, Any]: Model state. On failure a reduced dict with
            ``'status': 'ERROR'`` and an ``'error'`` message is returned.
        """
        try:
            self._roll_daily_counters()

            # Format prediction for display
            prediction_info = "FRESH"
            confidence = 0.0

            if self.latest_prediction:
                action = self.latest_prediction.predictions.get('action', 'UNKNOWN')
                confidence = self.latest_prediction.confidence

                # Map action to display text
                display_by_action = {
                    'BUY': "BUY_SIGNAL",
                    'SELL': "SELL_SIGNAL",
                    'HOLD': "HOLD_SIGNAL",
                }
                prediction_info = display_by_action.get(action, "PATTERN_ANALYSIS")

            # Format timing information
            inference_timing = self.last_inference_time.strftime('%H:%M:%S') if self.last_inference_time else "None"
            training_timing = self.last_training_time.strftime('%H:%M:%S') if self.last_training_time else "None"

            # Calculate improvement percentage
            improvement = 0.0
            if self.initial_loss is not None and self.initial_loss > 0 and self.current_loss > 0:
                improvement = ((self.initial_loss - self.current_loss) / self.initial_loss) * 100

            return {
                'model_name': self.model_name,
                'model_type': 'cnn',
                'parameters': 50000000,  # 50M parameters
                'status': 'ACTIVE' if self.inference_enabled else 'DISABLED',
                'checkpoint_loaded': True,  # Assumed; the load result is not tracked
                'last_prediction': prediction_info,
                'confidence': confidence * 100,  # Convert to percentage
                'last_inference_time': inference_timing,
                'last_training_time': training_timing,
                'inference_rate': self.inference_rate,
                'training_rate': self.training_rate,
                'daily_inferences': self.daily_inferences,
                'daily_training_runs': self.daily_training_runs,
                'initial_loss': self.initial_loss,
                'current_loss': self.current_loss,
                'best_loss': self.best_loss,
                'current_accuracy': self.current_accuracy,
                'improvement_percentage': improvement,
                'training_active': self.training_active,
                'training_enabled': self.training_enabled,
                'inference_enabled': self.inference_enabled,
                'training_samples': len(self.cnn_adapter.training_data)
            }

        except Exception as e:
            logger.error(f"Error getting model state: {e}")
            return {
                'model_name': self.model_name,
                'model_type': 'cnn',
                'parameters': 50000000,  # 50M parameters
                'status': 'ERROR',
                'error': str(e)
            }

    def get_pivot_prediction(self) -> Dict[str, Any]:
        """
        Get pivot prediction for dashboard display.

        The pivot type is the arg-max of the latest prediction's 'extrema'
        scores (0=bottom, 1=top, 2=neither); the pivot price is a fixed +/-5%
        offset from the current price of the symbol that prediction was made
        for (falling back to ETH/USDT when no symbol was recorded).

        Returns:
            Dict[str, Any]: Pivot prediction with keys 'next_pivot',
            'pivot_type', 'confidence' (percent) and 'time_to_pivot'.
        """
        no_pivot = {
            'next_pivot': 0.0,
            'pivot_type': 'UNKNOWN',
            'confidence': 0.0,
            'time_to_pivot': 0
        }
        try:
            if not self.latest_prediction:
                return no_pivot

            # Extract pivot prediction from model output; accept any sequence,
            # including array-like objects that expose tolist().
            raw_extrema = self.latest_prediction.predictions.get('extrema', [0, 0, 0])
            extrema_pred = raw_extrema.tolist() if hasattr(raw_extrema, 'tolist') else list(raw_extrema)
            if not extrema_pred:
                return no_pivot

            # Determine pivot type (0=bottom, 1=top, 2=neither)
            strongest = max(extrema_pred)
            pivot_types = ['BOTTOM', 'TOP', 'RANGE_CONTINUATION']
            pivot_type = pivot_types[extrema_pred.index(strongest)]

            # Get current price for the symbol the prediction refers to
            symbol = getattr(self, 'latest_prediction_symbol', None) or self._DEFAULT_SYMBOL
            current_price = self._get_current_price(symbol) or 0.0

            # Calculate next pivot price (simple heuristic for demonstration)
            if pivot_type == 'BOTTOM':
                next_pivot = current_price * 0.95  # 5% below current price
            elif pivot_type == 'TOP':
                next_pivot = current_price * 1.05  # 5% above current price
            else:
                next_pivot = current_price  # Same as current price

            return {
                'next_pivot': next_pivot,
                'pivot_type': pivot_type,
                'confidence': strongest * 100,  # Convert to percentage
                'time_to_pivot': 5  # 5 minutes (fixed placeholder)
            }

        except Exception as e:
            logger.error(f"Error getting pivot prediction: {e}")
            return {
                'next_pivot': 0.0,
                'pivot_type': 'ERROR',
                'confidence': 0.0,
                'time_to_pivot': 0
            }
|
@ -3369,12 +3369,17 @@ class TradingOrchestrator:
|
||||
)
|
||||
logger.info(f" Outcome: {outcome_status}")
|
||||
|
||||
# Add performance summary
|
||||
# Add comprehensive performance summary
|
||||
if model_name in self.model_performance:
|
||||
perf = self.model_performance[model_name]
|
||||
logger.info(
|
||||
f" Performance: {perf['accuracy']:.1%} ({perf['correct']}/{perf['total']})"
|
||||
f" Performance: {perf['directional_accuracy']:.1%} directional ({perf['directional_correct']}/{perf['total']}) | "
|
||||
f"{perf['accuracy']:.1%} profitable ({perf['correct']}/{perf['total']})"
|
||||
)
|
||||
if perf["pivot_attempted"] > 0:
|
||||
logger.info(
|
||||
f" Pivot Detection: {perf['pivot_accuracy']:.1%} ({perf['pivot_detected']}/{perf['pivot_attempted']})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in immediate training for {model_name}: {e}")
|
||||
@ -3453,32 +3458,62 @@ class TradingOrchestrator:
|
||||
predicted_price_vector=predicted_price_vector,
|
||||
)
|
||||
|
||||
# Update model performance tracking
|
||||
# Initialize enhanced model performance tracking
|
||||
if model_name not in self.model_performance:
|
||||
self.model_performance[model_name] = {
|
||||
"correct": 0,
|
||||
"correct": 0, # Profitability accuracy (backwards compatible)
|
||||
"total": 0,
|
||||
"accuracy": 0.0,
|
||||
"accuracy": 0.0, # Profitability accuracy (backwards compatible)
|
||||
"directional_correct": 0, # NEW: Directional accuracy
|
||||
"directional_accuracy": 0.0, # NEW: Directional accuracy %
|
||||
"pivot_detected": 0, # NEW: Successful pivot detections
|
||||
"pivot_attempted": 0, # NEW: Total pivot attempts
|
||||
"pivot_accuracy": 0.0, # NEW: Pivot detection accuracy
|
||||
"price_predictions": {"total": 0, "accurate": 0, "avg_error": 0.0},
|
||||
}
|
||||
|
||||
# Ensure all new keys exist (for existing models)
|
||||
perf = self.model_performance[model_name]
|
||||
if "directional_correct" not in perf:
|
||||
perf["directional_correct"] = 0
|
||||
perf["directional_accuracy"] = 0.0
|
||||
perf["pivot_detected"] = 0
|
||||
perf["pivot_attempted"] = 0
|
||||
perf["pivot_accuracy"] = 0.0
|
||||
|
||||
# Ensure price_predictions key exists
|
||||
if "price_predictions" not in self.model_performance[model_name]:
|
||||
self.model_performance[model_name]["price_predictions"] = {
|
||||
"total": 0,
|
||||
"accurate": 0,
|
||||
"avg_error": 0.0,
|
||||
}
|
||||
if "price_predictions" not in perf:
|
||||
perf["price_predictions"] = {"total": 0, "accurate": 0, "avg_error": 0.0}
|
||||
|
||||
self.model_performance[model_name]["total"] += 1
|
||||
if was_correct:
|
||||
self.model_performance[model_name]["correct"] += 1
|
||||
|
||||
self.model_performance[model_name]["accuracy"] = (
|
||||
self.model_performance[model_name]["correct"]
|
||||
/ self.model_performance[model_name]["total"]
|
||||
# Calculate directional accuracy separately
|
||||
directional_correct = (
|
||||
(predicted_action == "BUY" and price_change_pct > 0) or
|
||||
(predicted_action == "SELL" and price_change_pct < 0) or
|
||||
(predicted_action == "HOLD" and abs(price_change_pct) < 0.05)
|
||||
)
|
||||
|
||||
# Update all accuracy metrics
|
||||
perf["total"] += 1
|
||||
if was_correct: # Profitability accuracy
|
||||
perf["correct"] += 1
|
||||
if directional_correct:
|
||||
perf["directional_correct"] += 1
|
||||
|
||||
# Update pivot detection tracking
|
||||
is_significant_move = abs(price_change_pct) > 0.08 # 0.08% threshold for "significant"
|
||||
if predicted_action in ["BUY", "SELL"] and is_significant_move:
|
||||
perf["pivot_attempted"] += 1
|
||||
if directional_correct:
|
||||
perf["pivot_detected"] += 1
|
||||
|
||||
# Calculate all accuracy percentages
|
||||
perf["accuracy"] = perf["correct"] / perf["total"] # Profitability accuracy
|
||||
perf["directional_accuracy"] = perf["directional_correct"] / perf["total"] # Directional accuracy
|
||||
if perf["pivot_attempted"] > 0:
|
||||
perf["pivot_accuracy"] = perf["pivot_detected"] / perf["pivot_attempted"] # Pivot accuracy
|
||||
else:
|
||||
perf["pivot_accuracy"] = 0.0
|
||||
|
||||
# Track price prediction accuracy if available
|
||||
if inference_price is not None:
|
||||
price_prediction_stats = self.model_performance[model_name][
|
||||
@ -3504,7 +3539,8 @@ class TradingOrchestrator:
|
||||
f"({price_prediction_stats['avg_error']:.2f}% avg error)"
|
||||
)
|
||||
|
||||
# Enhanced logging for training evaluation
|
||||
# Enhanced logging with new accuracy metrics
|
||||
perf = self.model_performance[model_name]
|
||||
logger.info(f"Training evaluation for {model_name}:")
|
||||
logger.info(
|
||||
f" Action: {predicted_action} | Confidence: {prediction_confidence:.3f}"
|
||||
@ -3512,10 +3548,15 @@ class TradingOrchestrator:
|
||||
logger.info(
|
||||
f" Price change: {price_change_pct:+.3f}% | Time: {time_diff_seconds:.1f}s"
|
||||
)
|
||||
logger.info(f" Reward: {reward:.4f} | Correct: {was_correct}")
|
||||
logger.info(f" Reward: {reward:.4f} | Profitable: {was_correct} | Directional: {directional_correct}")
|
||||
logger.info(
|
||||
f" Accuracy: {self.model_performance[model_name]['accuracy']:.1%} ({self.model_performance[model_name]['correct']}/{self.model_performance[model_name]['total']})"
|
||||
f" Profitability: {perf['accuracy']:.1%} ({perf['correct']}/{perf['total']}) | "
|
||||
f"Directional: {perf['directional_accuracy']:.1%} ({perf['directional_correct']}/{perf['total']})"
|
||||
)
|
||||
if perf["pivot_attempted"] > 0:
|
||||
logger.info(
|
||||
f" Pivot Detection: {perf['pivot_accuracy']:.1%} ({perf['pivot_detected']}/{perf['pivot_attempted']})"
|
||||
)
|
||||
|
||||
# Train the specific model based on sophisticated outcome
|
||||
await self._train_model_on_outcome(
|
||||
@ -3549,6 +3590,45 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating and training on record: {e}")
|
||||
|
||||
def _is_pivot_point(self, price_change_pct: float, prediction_confidence: float, time_diff_minutes: float) -> tuple[bool, str, float]:
|
||||
"""
|
||||
Detect if this is a significant pivot point worth trading.
|
||||
Pivot points are the key moments where markets change direction or momentum.
|
||||
|
||||
Returns:
|
||||
tuple: (is_pivot, pivot_type, pivot_strength)
|
||||
"""
|
||||
abs_change = abs(price_change_pct)
|
||||
|
||||
# Pivot point thresholds (much more realistic for crypto)
|
||||
minor_pivot = 0.08 # 0.08% - small but tradeable pivot
|
||||
medium_pivot = 0.25 # 0.25% - significant pivot
|
||||
major_pivot = 0.6 # 0.6% - major pivot
|
||||
massive_pivot = 1.2 # 1.2% - massive pivot
|
||||
|
||||
# Time-based multipliers (faster pivots are more valuable)
|
||||
time_multiplier = 1.0
|
||||
if time_diff_minutes < 2.0: # Very fast pivot
|
||||
time_multiplier = 2.0
|
||||
elif time_diff_minutes < 5.0: # Fast pivot
|
||||
time_multiplier = 1.5
|
||||
elif time_diff_minutes > 15.0: # Slow pivot - less valuable
|
||||
time_multiplier = 0.7
|
||||
|
||||
# Confidence multiplier (high confidence pivots are more valuable)
|
||||
confidence_multiplier = 0.5 + (prediction_confidence * 1.5) # 0.5 to 2.0
|
||||
|
||||
if abs_change >= massive_pivot:
|
||||
return True, "MASSIVE_PIVOT", 10.0 * time_multiplier * confidence_multiplier
|
||||
elif abs_change >= major_pivot:
|
||||
return True, "MAJOR_PIVOT", 5.0 * time_multiplier * confidence_multiplier
|
||||
elif abs_change >= medium_pivot:
|
||||
return True, "MEDIUM_PIVOT", 2.5 * time_multiplier * confidence_multiplier
|
||||
elif abs_change >= minor_pivot:
|
||||
return True, "MINOR_PIVOT", 1.2 * time_multiplier * confidence_multiplier
|
||||
else:
|
||||
return False, "NO_PIVOT", 0.1 # Very small reward for noise
|
||||
|
||||
def _calculate_sophisticated_reward(
|
||||
self,
|
||||
predicted_action: str,
|
||||
@ -3562,11 +3642,19 @@ class TradingOrchestrator:
|
||||
predicted_price_vector: dict = None,
|
||||
) -> tuple[float, bool]:
|
||||
"""
|
||||
Calculate sophisticated reward based on prediction accuracy, confidence, and price movement magnitude
|
||||
Now considers position status and current P&L when evaluating decisions
|
||||
NOISE REDUCTION: Treats neutral/low-confidence signals as HOLD to reduce training noise
|
||||
PRICE VECTOR BONUS: Rewards accurate price direction and magnitude predictions
|
||||
|
||||
PIVOT-POINT FOCUSED REWARD SYSTEM
|
||||
|
||||
This system heavily rewards models for correctly identifying pivot points -
|
||||
the actual profitable trading opportunities in the market. Small movements
|
||||
are treated as noise and given minimal rewards.
|
||||
|
||||
Key Features:
|
||||
- Separate directional accuracy vs profitability accuracy tracking
|
||||
- Heavy rewards for successful pivot point detection
|
||||
- Minimal penalties for noise (small movements)
|
||||
- Time-weighted rewards (faster detection = better)
|
||||
- Confidence-weighted rewards (higher confidence = better)
|
||||
|
||||
Args:
|
||||
predicted_action: The predicted action ('BUY', 'SELL', 'HOLD')
|
||||
prediction_confidence: Model's confidence in the prediction (0.0 to 1.0)
|
||||
@ -3579,21 +3667,36 @@ class TradingOrchestrator:
|
||||
predicted_price_vector: Dict with 'direction' (-1 to 1) and 'confidence' (0 to 1)
|
||||
|
||||
Returns:
|
||||
tuple: (reward, was_correct)
|
||||
tuple: (reward, directional_correct, profitability_correct, pivot_detected)
|
||||
"""
|
||||
try:
|
||||
# NOISE REDUCTION: Treat low-confidence signals as HOLD
|
||||
confidence_threshold = 0.6 # Only consider BUY/SELL if confidence > 60%
|
||||
if prediction_confidence < confidence_threshold:
|
||||
predicted_action = "HOLD"
|
||||
logger.debug(f"Low confidence ({prediction_confidence:.2f}) - treating as HOLD for noise reduction")
|
||||
# Store original action for directional accuracy tracking
|
||||
original_action = predicted_action
|
||||
|
||||
# FEE-AWARE THRESHOLDS: Account for trading fees (0.05-0.06% per trade, ~0.12% round trip)
|
||||
fee_cost = 0.12 # 0.12% round trip fee cost
|
||||
movement_threshold = 0.15 # Minimum movement to be profitable after fees
|
||||
strong_movement_threshold = 0.5 # Strong movements - good profit potential
|
||||
rapid_movement_threshold = 1.0 # Rapid movements - excellent profit potential
|
||||
massive_movement_threshold = 2.0 # Massive movements - extraordinary profit potential
|
||||
# PIVOT POINT DETECTION
|
||||
is_pivot, pivot_type, pivot_strength = self._is_pivot_point(
|
||||
price_change_pct, prediction_confidence, time_diff_minutes
|
||||
)
|
||||
|
||||
# DIRECTIONAL ACCURACY (simple direction prediction)
|
||||
directional_correct = False
|
||||
if predicted_action == "BUY" and price_change_pct > 0:
|
||||
directional_correct = True
|
||||
elif predicted_action == "SELL" and price_change_pct < 0:
|
||||
directional_correct = True
|
||||
elif predicted_action == "HOLD" and abs(price_change_pct) < 0.05: # Very small movement
|
||||
directional_correct = True
|
||||
|
||||
# PROFITABILITY ACCURACY (fee-aware profitable trades)
|
||||
fee_cost = 0.10 # 0.10% round trip fee cost (realistic for most exchanges)
|
||||
profitability_correct = False
|
||||
|
||||
if predicted_action == "BUY" and price_change_pct > fee_cost:
|
||||
profitability_correct = True
|
||||
elif predicted_action == "SELL" and price_change_pct < -fee_cost:
|
||||
profitability_correct = True
|
||||
elif predicted_action == "HOLD" and abs(price_change_pct) < fee_cost:
|
||||
profitability_correct = True
|
||||
|
||||
# Determine current position status if not provided
|
||||
if has_position is None and symbol:
|
||||
@ -3604,210 +3707,104 @@ class TradingOrchestrator:
|
||||
elif has_position is None:
|
||||
has_position = False
|
||||
|
||||
# Determine if prediction was directionally correct
|
||||
was_correct = False
|
||||
directional_accuracy = 0.0
|
||||
|
||||
if predicted_action == "BUY":
|
||||
# BUY signals need to overcome fee costs for profitability
|
||||
was_correct = price_change_pct > movement_threshold
|
||||
# PIVOT POINT REWARD CALCULATION
|
||||
base_reward = 0.0
|
||||
pivot_bonus = 0.0
|
||||
|
||||
# For backwards compatibility, use profitability_correct as the main "was_correct"
|
||||
was_correct = profitability_correct
|
||||
|
||||
# MASSIVE REWARDS FOR SUCCESSFUL PIVOT POINT DETECTION
|
||||
if is_pivot and directional_correct:
|
||||
# Base pivot reward
|
||||
base_reward = pivot_strength
|
||||
|
||||
# ENHANCED FEE-AWARE REWARD STRUCTURE
|
||||
if price_change_pct > massive_movement_threshold:
|
||||
# Massive movements (2%+) - EXTRAORDINARY rewards for high confidence
|
||||
directional_accuracy = price_change_pct * 5.0 # 5x multiplier for massive moves
|
||||
if prediction_confidence > 0.8:
|
||||
directional_accuracy *= 2.0 # Additional 2x for high confidence (10x total)
|
||||
elif price_change_pct > rapid_movement_threshold:
|
||||
# Rapid movements (1%+) - EXCELLENT rewards for high confidence
|
||||
directional_accuracy = price_change_pct * 3.0 # 3x multiplier for rapid moves
|
||||
if prediction_confidence > 0.7:
|
||||
directional_accuracy *= 1.5 # Additional 1.5x for good confidence (4.5x total)
|
||||
elif price_change_pct > strong_movement_threshold:
|
||||
# Strong movements (0.5%+) - GOOD rewards
|
||||
directional_accuracy = price_change_pct * 2.0 # 2x multiplier for strong moves
|
||||
else:
|
||||
# Small movements - minimal rewards (fees eat most profit)
|
||||
directional_accuracy = max(0, (price_change_pct - fee_cost)) * 0.5 # Penalty for fee cost
|
||||
# EXTRAORDINARY bonuses for successful pivot predictions
|
||||
if pivot_type == "MASSIVE_PIVOT":
|
||||
pivot_bonus = 50.0 * prediction_confidence # Up to 50x reward!
|
||||
logger.info(f"MASSIVE PIVOT SUCCESS: {pivot_type} detected with {prediction_confidence:.2f} confidence = {pivot_bonus:.1f}x bonus!")
|
||||
elif pivot_type == "MAJOR_PIVOT":
|
||||
pivot_bonus = 20.0 * prediction_confidence # Up to 20x reward!
|
||||
logger.info(f"MAJOR PIVOT SUCCESS: {pivot_type} detected with {prediction_confidence:.2f} confidence = {pivot_bonus:.1f}x bonus!")
|
||||
elif pivot_type == "MEDIUM_PIVOT":
|
||||
pivot_bonus = 8.0 * prediction_confidence # Up to 8x reward!
|
||||
logger.info(f"MEDIUM PIVOT SUCCESS: {pivot_type} detected with {prediction_confidence:.2f} confidence = {pivot_bonus:.1f}x bonus!")
|
||||
elif pivot_type == "MINOR_PIVOT":
|
||||
pivot_bonus = 3.0 * prediction_confidence # Up to 3x reward!
|
||||
logger.info(f"MINOR PIVOT SUCCESS: {pivot_type} detected with {prediction_confidence:.2f} confidence = {pivot_bonus:.1f}x bonus!")
|
||||
|
||||
elif predicted_action == "SELL":
|
||||
# SELL signals need to overcome fee costs for profitability
|
||||
was_correct = price_change_pct < -movement_threshold
|
||||
# Additional time-based bonus for early detection
|
||||
if time_diff_minutes < 1.0:
|
||||
time_bonus = pivot_bonus * 0.5 # 50% bonus for very fast detection
|
||||
pivot_bonus += time_bonus
|
||||
logger.info(f"EARLY DETECTION BONUS: Detected {pivot_type} in {time_diff_minutes:.1f}m = +{time_bonus:.1f} bonus")
|
||||
|
||||
base_reward += pivot_bonus
|
||||
|
||||
elif is_pivot and not directional_correct:
|
||||
# MODERATE penalty for missing pivot points (still valuable to learn from)
|
||||
base_reward = -pivot_strength * 0.3 # Small penalty to encourage learning
|
||||
logger.debug(f"MISSED PIVOT: {pivot_type} missed, small penalty = {base_reward:.2f}")
|
||||
|
||||
elif not is_pivot and directional_correct:
|
||||
# Small reward for correct direction on non-pivots (noise)
|
||||
base_reward = 0.2 * prediction_confidence
|
||||
logger.debug(f"NOISE CORRECT: Correct direction on noise movement = {base_reward:.2f}")
|
||||
|
||||
# ENHANCED FEE-AWARE REWARD STRUCTURE (symmetric to BUY)
|
||||
abs_change = abs(price_change_pct)
|
||||
if abs_change > massive_movement_threshold:
|
||||
# Massive movements (2%+) - EXTRAORDINARY rewards for high confidence
|
||||
directional_accuracy = abs_change * 5.0 # 5x multiplier for massive moves
|
||||
if prediction_confidence > 0.8:
|
||||
directional_accuracy *= 2.0 # Additional 2x for high confidence (10x total)
|
||||
elif abs_change > rapid_movement_threshold:
|
||||
# Rapid movements (1%+) - EXCELLENT rewards for high confidence
|
||||
directional_accuracy = abs_change * 3.0 # 3x multiplier for rapid moves
|
||||
if prediction_confidence > 0.7:
|
||||
directional_accuracy *= 1.5 # Additional 1.5x for good confidence (4.5x total)
|
||||
elif abs_change > strong_movement_threshold:
|
||||
# Strong movements (0.5%+) - GOOD rewards
|
||||
directional_accuracy = abs_change * 2.0 # 2x multiplier for strong moves
|
||||
else:
|
||||
# Small movements - minimal rewards (fees eat most profit)
|
||||
directional_accuracy = max(0, (abs_change - fee_cost)) * 0.5 # Penalty for fee cost
|
||||
|
||||
elif predicted_action == "HOLD":
|
||||
# HOLD evaluation with noise reduction - smaller rewards to reduce training noise
|
||||
if has_position:
|
||||
# If we have a position, HOLD evaluation depends on P&L and price movement
|
||||
if current_position_pnl > 0: # Currently profitable position
|
||||
# Holding a profitable position is good if price continues favorably
|
||||
if price_change_pct > 0: # Price went up while holding profitable position - excellent
|
||||
was_correct = True
|
||||
directional_accuracy = price_change_pct * 0.8 # Reduced from 1.5 to reduce noise
|
||||
elif abs(price_change_pct) < movement_threshold: # Price stable - good
|
||||
was_correct = True
|
||||
directional_accuracy = movement_threshold * 0.5 # Reduced reward to reduce noise
|
||||
else: # Price dropped while holding profitable position - still okay but less reward
|
||||
was_correct = True
|
||||
directional_accuracy = max(0, (current_position_pnl / 100.0) - abs(price_change_pct) * 0.3)
|
||||
elif current_position_pnl < 0: # Currently losing position
|
||||
# Holding a losing position is generally bad - should consider closing
|
||||
if price_change_pct > movement_threshold: # Price recovered - good hold
|
||||
was_correct = True
|
||||
directional_accuracy = price_change_pct * 0.6 # Reduced reward
|
||||
else: # Price continued down or stayed flat - bad hold
|
||||
was_correct = False
|
||||
# Penalty proportional to loss magnitude
|
||||
directional_accuracy = abs(current_position_pnl / 100.0) * 0.3 # Reduced penalty
|
||||
else: # Breakeven position
|
||||
# Standard HOLD evaluation for breakeven positions
|
||||
if abs(price_change_pct) < movement_threshold: # Price stable - good
|
||||
was_correct = True
|
||||
directional_accuracy = movement_threshold * 0.4 # Reduced reward
|
||||
else: # Price moved significantly - missed opportunity
|
||||
was_correct = False
|
||||
directional_accuracy = max(0, movement_threshold - abs(price_change_pct)) * 0.5
|
||||
else:
|
||||
# If we don't have a position, HOLD is correct if price stayed relatively stable
|
||||
was_correct = abs(price_change_pct) < movement_threshold
|
||||
directional_accuracy = max(0, movement_threshold - abs(price_change_pct)) * 0.4 # Reduced reward
|
||||
|
||||
# Calculate FEE-AWARE magnitude-based multiplier (aggressive rewards for profitable movements)
|
||||
abs_movement = abs(price_change_pct)
|
||||
if abs_movement > massive_movement_threshold:
|
||||
magnitude_multiplier = min(abs_movement / 1.0, 8.0) # Up to 8x for massive moves (8% = 8x)
|
||||
elif abs_movement > rapid_movement_threshold:
|
||||
magnitude_multiplier = min(abs_movement / 1.5, 4.0) # Up to 4x for rapid moves (6% = 4x)
|
||||
elif abs_movement > strong_movement_threshold:
|
||||
magnitude_multiplier = min(abs_movement / 2.0, 2.0) # Up to 2x for strong moves (4% = 2x)
|
||||
else:
|
||||
# Small movements get minimal multiplier due to fees
|
||||
magnitude_multiplier = max(0.1, (abs_movement - fee_cost) / 2.0) # Penalty for fee cost
|
||||
|
||||
# Calculate confidence-based reward adjustment
|
||||
if was_correct:
|
||||
# Reward confident correct predictions more, penalize unconfident correct predictions less
|
||||
confidence_multiplier = 0.5 + (
|
||||
prediction_confidence * 1.5
|
||||
) # Range: 0.5 to 2.0
|
||||
base_reward = (
|
||||
directional_accuracy * magnitude_multiplier * confidence_multiplier
|
||||
# Very small penalty for wrong direction on noise (don't overtrain on noise)
|
||||
base_reward = -0.1 * prediction_confidence
|
||||
logger.debug(f"NOISE INCORRECT: Wrong direction on noise movement = {base_reward:.2f}")
|
||||
|
||||
# POSITION-AWARE ADJUSTMENTS
|
||||
if has_position:
|
||||
# Adjust rewards based on current position status
|
||||
if current_position_pnl > 0.5: # Profitable position
|
||||
if predicted_action == "HOLD" and price_change_pct > 0:
|
||||
base_reward += 0.5 # Bonus for holding profitable position during uptrend
|
||||
logger.debug(f"POSITION BONUS: Holding profitable position during uptrend = +0.5")
|
||||
elif current_position_pnl < -0.5: # Losing position
|
||||
if predicted_action in ["BUY", "SELL"] and directional_correct:
|
||||
base_reward += 0.3 # Bonus for taking action to exit losing position
|
||||
logger.debug(f"EXIT BONUS: Taking action on losing position = +0.3")
|
||||
|
||||
# PRICE VECTOR BONUS (if available)
|
||||
if predicted_price_vector and isinstance(predicted_price_vector, dict):
|
||||
vector_bonus = self._calculate_price_vector_bonus(
|
||||
predicted_price_vector, price_change_pct, abs(price_change_pct), prediction_confidence
|
||||
)
|
||||
if vector_bonus > 0:
|
||||
base_reward += vector_bonus
|
||||
logger.debug(f"PRICE VECTOR BONUS: +{vector_bonus:.3f}")
|
||||
|
||||
# ENHANCED HIGH-CONFIDENCE BONUSES for profitable movements
|
||||
abs_movement = abs(price_change_pct)
|
||||
|
||||
# Extraordinary confidence bonus for massive movements
|
||||
if prediction_confidence > 0.9 and abs_movement > massive_movement_threshold:
|
||||
base_reward *= 3.0 # 300% bonus for ultra-confident massive moves
|
||||
logger.info(f"ULTRA CONFIDENCE BONUS: {prediction_confidence:.2f} confidence + {abs_movement:.2f}% movement = 3x reward")
|
||||
|
||||
# Excellent confidence bonus for rapid movements
|
||||
elif prediction_confidence > 0.8 and abs_movement > rapid_movement_threshold:
|
||||
base_reward *= 2.0 # 200% bonus for very confident rapid moves
|
||||
logger.info(f"HIGH CONFIDENCE BONUS: {prediction_confidence:.2f} confidence + {abs_movement:.2f}% movement = 2x reward")
|
||||
|
||||
# Good confidence bonus for strong movements
|
||||
elif prediction_confidence > 0.7 and abs_movement > strong_movement_threshold:
|
||||
base_reward *= 1.5 # 150% bonus for confident strong moves
|
||||
logger.info(f"CONFIDENCE BONUS: {prediction_confidence:.2f} confidence + {abs_movement:.2f}% movement = 1.5x reward")
|
||||
|
||||
# Rapid movement detection bonus (speed matters for fees)
|
||||
if time_diff_minutes < 5.0 and abs_movement > rapid_movement_threshold:
|
||||
base_reward *= 1.3 # 30% bonus for rapid detection of big moves
|
||||
logger.info(f"RAPID DETECTION BONUS: {abs_movement:.2f}% movement in {time_diff_minutes:.1f}m = 1.3x reward")
|
||||
|
||||
# PRICE VECTOR ACCURACY BONUS - Reward models for accurate price direction/magnitude predictions
|
||||
if predicted_price_vector and isinstance(predicted_price_vector, dict):
|
||||
vector_bonus = self._calculate_price_vector_bonus(
|
||||
predicted_price_vector, price_change_pct, abs_movement, prediction_confidence
|
||||
)
|
||||
if vector_bonus > 0:
|
||||
base_reward += vector_bonus
|
||||
logger.info(f"PRICE VECTOR BONUS: +{vector_bonus:.3f} for accurate direction/magnitude prediction")
|
||||
|
||||
else:
|
||||
# ENHANCED PENALTY SYSTEM: Discourage fee-losing trades
|
||||
abs_movement = abs(price_change_pct)
|
||||
|
||||
# Penalize incorrect predictions more severely if they were confident
|
||||
confidence_penalty = 0.5 + (prediction_confidence * 1.5) # Higher confidence = higher penalty
|
||||
base_penalty = abs_movement * confidence_penalty
|
||||
|
||||
# SEVERE penalties for confident wrong predictions on big moves
|
||||
if prediction_confidence > 0.8 and abs_movement > rapid_movement_threshold:
|
||||
base_penalty *= 5.0 # 5x penalty for very confident wrong on big moves
|
||||
logger.warning(f"SEVERE PENALTY: {prediction_confidence:.2f} confidence wrong on {abs_movement:.2f}% movement = 5x penalty")
|
||||
elif prediction_confidence > 0.7 and abs_movement > strong_movement_threshold:
|
||||
base_penalty *= 3.0 # 3x penalty for confident wrong on strong moves
|
||||
logger.warning(f"HIGH PENALTY: {prediction_confidence:.2f} confidence wrong on {abs_movement:.2f}% movement = 3x penalty")
|
||||
elif prediction_confidence > 0.8:
|
||||
base_penalty *= 2.0 # 2x penalty for overconfident wrong predictions
|
||||
|
||||
# ADDITIONAL penalty for predictions that would lose money to fees
|
||||
if abs_movement < fee_cost and prediction_confidence > 0.5:
|
||||
fee_loss_penalty = (fee_cost - abs_movement) * 2.0 # Penalty for fee-losing trades
|
||||
base_penalty += fee_loss_penalty
|
||||
logger.warning(f"FEE LOSS PENALTY: {abs_movement:.2f}% movement < {fee_cost:.2f}% fees = +{fee_loss_penalty:.3f} penalty")
|
||||
|
||||
base_reward = -base_penalty
|
||||
|
||||
# Time decay factor (predictions should be evaluated quickly)
|
||||
time_decay = max(
|
||||
0.1, 1.0 - (time_diff_minutes / 60.0)
|
||||
) # Decay over 1 hour, min 10%
|
||||
|
||||
# Final reward calculation
|
||||
# Time decay factor (pivot detection should be fast)
|
||||
time_decay = max(0.3, 1.0 - (time_diff_minutes / 30.0)) # Decay over 30 minutes, min 30%
|
||||
|
||||
# Apply time decay
|
||||
final_reward = base_reward * time_decay
|
||||
|
||||
# Bonus for accurate price predictions
|
||||
if (
|
||||
has_price_prediction and abs(price_change_pct) < 1.0
|
||||
): # Accurate price prediction
|
||||
final_reward *= 1.2 # 20% bonus for accurate price predictions
|
||||
logger.debug(
|
||||
f"Applied price prediction accuracy bonus: {final_reward:.3f}"
|
||||
)
|
||||
|
||||
# Clamp reward to reasonable range
|
||||
final_reward = max(-5.0, min(5.0, final_reward))
|
||||
|
||||
|
||||
# Clamp reward to reasonable range (higher range for pivot bonuses)
|
||||
final_reward = max(-10.0, min(100.0, final_reward))
|
||||
|
||||
# Log detailed accuracy information
|
||||
logger.debug(
|
||||
f"REWARD CALCULATION: action={predicted_action}, confidence={prediction_confidence:.3f}, "
|
||||
f"price_change={price_change_pct:.3f}%, pivot={is_pivot}/{pivot_type}, "
|
||||
f"directional_correct={directional_correct}, profitability_correct={profitability_correct}, "
|
||||
f"reward={final_reward:.3f}"
|
||||
)
|
||||
|
||||
return final_reward, was_correct
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating sophisticated reward: {e}")
|
||||
# Fallback to simple reward with position awareness
|
||||
has_position = self._has_open_position(symbol) if symbol else False
|
||||
|
||||
if predicted_action == "HOLD" and has_position:
|
||||
# If holding a position, HOLD is correct if price didn't drop significantly
|
||||
simple_correct = price_change_pct > -0.2 # Allow small losses while holding
|
||||
else:
|
||||
# Standard evaluation for other cases
|
||||
simple_correct = (
|
||||
(predicted_action == "BUY" and price_change_pct > 0.1)
|
||||
or (predicted_action == "SELL" and price_change_pct < -0.1)
|
||||
or (predicted_action == "HOLD" and abs(price_change_pct) < 0.1)
|
||||
)
|
||||
return (1.0 if simple_correct else -0.5, simple_correct)
|
||||
# Fallback to simple directional accuracy
|
||||
simple_correct = (
|
||||
(predicted_action == "BUY" and price_change_pct > 0) or
|
||||
(predicted_action == "SELL" and price_change_pct < 0) or
|
||||
(predicted_action == "HOLD" and abs(price_change_pct) < 0.05)
|
||||
)
|
||||
return (1.0 if simple_correct else -0.1, simple_correct)
|
||||
|
||||
def _calculate_price_vector_bonus(
|
||||
self,
|
||||
@ -4334,6 +4331,25 @@ class TradingOrchestrator:
|
||||
|
||||
# Create training sample from record
|
||||
model_input = record.get("model_input")
|
||||
|
||||
# If model_input is None, try to generate fresh state for training
|
||||
if model_input is None:
|
||||
logger.debug(f"No stored model input for {model_name}, generating fresh state")
|
||||
try:
|
||||
# Generate fresh input state for training
|
||||
if hasattr(self, 'data_provider') and self.data_provider:
|
||||
# Use data provider to generate current market state
|
||||
fresh_state = self._generate_fresh_state_fallback(model_name)
|
||||
if fresh_state is not None and len(fresh_state) > 0:
|
||||
model_input = fresh_state
|
||||
logger.debug(f"Generated fresh training state for {model_name}: shape={fresh_state.shape if hasattr(fresh_state, 'shape') else len(fresh_state)}")
|
||||
else:
|
||||
logger.warning(f"Failed to generate fresh state for {model_name}")
|
||||
else:
|
||||
logger.warning(f"No data provider available for generating fresh state for {model_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error generating fresh state for {model_name}: {e}")
|
||||
|
||||
if model_input is not None:
|
||||
# Convert to tensor and ensure device placement
|
||||
device = next(self.cnn_model.parameters()).device
|
||||
@ -4432,7 +4448,71 @@ class TradingOrchestrator:
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"No model input available for CNN training")
|
||||
logger.warning(f"No model input available for CNN training for {model_name}. This prevents the model from learning.")
|
||||
|
||||
# Try one more time to generate training data from current market conditions
|
||||
try:
|
||||
if hasattr(self, 'data_provider') and self.data_provider:
|
||||
# Create minimal training sample from current market data
|
||||
symbol = record.get("symbol", "ETH/USDT")
|
||||
current_price = self._get_current_price(symbol)
|
||||
|
||||
# Get variables from function scope
|
||||
actual_action = prediction["action"]
|
||||
pred_confidence = prediction.get("confidence", 0.5)
|
||||
|
||||
# Create a basic feature vector (this is a fallback)
|
||||
basic_features = np.array([
|
||||
current_price / 10000.0, # Normalized price
|
||||
pred_confidence, # Model confidence
|
||||
reward, # Current reward
|
||||
1.0 if actual_action == "BUY" else 0.0,
|
||||
1.0 if actual_action == "SELL" else 0.0,
|
||||
1.0 if actual_action == "HOLD" else 0.0
|
||||
], dtype=np.float32)
|
||||
|
||||
# Pad to expected size if needed
|
||||
expected_size = 512 # Adjust based on your model's expected input size
|
||||
if len(basic_features) < expected_size:
|
||||
padding = np.zeros(expected_size - len(basic_features), dtype=np.float32)
|
||||
basic_features = np.concatenate([basic_features, padding])
|
||||
|
||||
logger.info(f"Created fallback training features for {model_name}: shape={basic_features.shape}")
|
||||
|
||||
# Now perform training with the fallback features
|
||||
device = next(self.cnn_model.parameters()).device
|
||||
features_tensor = torch.tensor(basic_features, dtype=torch.float32, device=device).unsqueeze(0)
|
||||
|
||||
# Convert action to index
|
||||
actions = ["BUY", "SELL", "HOLD"]
|
||||
action_idx = actions.index(actual_action) if actual_action in actions else 2
|
||||
action_tensor = torch.tensor([action_idx], dtype=torch.long, device=device)
|
||||
reward_tensor = torch.tensor([reward], dtype=torch.float32, device=device)
|
||||
|
||||
# Perform minimal training step
|
||||
self.cnn_model.train()
|
||||
self.cnn_optimizer.zero_grad()
|
||||
|
||||
# Forward pass
|
||||
q_values, _, _, _, _ = self.cnn_model(features_tensor)
|
||||
|
||||
# Calculate basic loss
|
||||
q_values_selected = q_values.gather(1, action_tensor.unsqueeze(1)).squeeze(1)
|
||||
loss = nn.MSELoss()(q_values_selected, reward_tensor)
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.cnn_model.parameters(), max_norm=1.0)
|
||||
self.cnn_optimizer.step()
|
||||
|
||||
logger.info(f"Fallback CNN training completed for {model_name}: loss={loss.item():.4f}")
|
||||
return True
|
||||
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Fallback CNN training failed for {model_name}: {fallback_error}")
|
||||
|
||||
# If we reach here, even fallback training failed
|
||||
logger.error(f"All CNN training methods failed for {model_name}. Model will not learn from this prediction.")
|
||||
return False
|
||||
|
||||
# Try model interface training methods
|
||||
|
Reference in New Issue
Block a user