Compare commits
12 Commits
9671d0d363
...
1f35258a66
Author | SHA1 | Date | |
---|---|---|---|
![]() |
1f35258a66 | ||
![]() |
2e1b3be2cd | ||
![]() |
34780d62c7 | ||
![]() |
47d63fddfb | ||
![]() |
2f51966fa8 | ||
![]() |
55fb865e7f | ||
![]() |
a3029d09c2 | ||
![]() |
17e18ae86c | ||
![]() |
8c17082643 | ||
![]() |
729e0bccb1 | ||
![]() |
317c703ea0 | ||
![]() |
0e886527c8 |
5
.cursor/rules/no-duplicate-implementations.mdc
Normal file
5
.cursor/rules/no-duplicate-implementations.mdc
Normal file
@@ -0,0 +1,5 @@
|
||||
---
|
||||
description: Before implementing new idea look if we have existing partial or full implementation that we can work with instead of branching off. if you spot duplicate implementations suggest to merge and streamline them.
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
@@ -6,8 +6,6 @@ Much larger and more sophisticated architecture for better learning
|
||||
|
||||
import os
|
||||
import logging
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from datetime import datetime
|
||||
import math
|
||||
|
||||
@@ -15,10 +13,30 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
# Try to import optional dependencies
|
||||
try:
|
||||
import numpy as np
|
||||
HAS_NUMPY = True
|
||||
except ImportError:
|
||||
np = None
|
||||
HAS_NUMPY = False
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
HAS_MATPLOTLIB = True
|
||||
except ImportError:
|
||||
plt = None
|
||||
HAS_MATPLOTLIB = False
|
||||
|
||||
try:
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
||||
HAS_SKLEARN = True
|
||||
except ImportError:
|
||||
HAS_SKLEARN = False
|
||||
|
||||
# Import checkpoint management
|
||||
from NN.training.model_manager import save_checkpoint, load_best_checkpoint
|
||||
from NN.training.model_manager import create_model_manager
|
||||
@@ -125,11 +143,12 @@ class EnhancedCNNModel(nn.Module):
|
||||
def __init__(self,
|
||||
input_size: int = 60,
|
||||
feature_dim: int = 50,
|
||||
output_size: int = 2, # BUY/SELL for 2-action system
|
||||
output_size: int = 5, # OHLCV prediction (Open, High, Low, Close, Volume)
|
||||
base_channels: int = 256, # Increased from 128 to 256
|
||||
num_blocks: int = 12, # Increased from 6 to 12
|
||||
num_attention_heads: int = 16, # Increased from 8 to 16
|
||||
dropout_rate: float = 0.2):
|
||||
dropout_rate: float = 0.2,
|
||||
prediction_horizon: int = 1): # New: Prediction horizon in minutes
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
@@ -397,48 +416,51 @@ class EnhancedCNNModel(nn.Module):
|
||||
volatility_pred = self._memory_barrier(self.volatility_predictor(processed_features))
|
||||
confidence = self._memory_barrier(self.confidence_head(processed_features))
|
||||
|
||||
# Combine all features for final decision (8 regime classes + 1 volatility)
|
||||
# Combine all features for OHLCV prediction
|
||||
# Create completely independent tensors for concatenation
|
||||
vol_pred_flat = self._memory_barrier(volatility_pred.reshape(volatility_pred.shape[0], -1)) # Flatten instead of squeeze
|
||||
combined_features = torch.cat([processed_features, regime_probs, vol_pred_flat], dim=1)
|
||||
combined_features = self._memory_barrier(combined_features)
|
||||
|
||||
trading_logits = self._memory_barrier(self.decision_head(combined_features))
|
||||
# OHLCV prediction (Open, High, Low, Close, Volume)
|
||||
ohlcv_pred = self._memory_barrier(self.decision_head(combined_features))
|
||||
|
||||
# Apply temperature scaling for better calibration - create new tensor
|
||||
temperature = 1.5
|
||||
scaled_logits = trading_logits / temperature
|
||||
trading_probs = self._memory_barrier(F.softmax(scaled_logits, dim=1))
|
||||
|
||||
# Flatten confidence to ensure consistent shape
|
||||
# Generate confidence based on prediction stability
|
||||
confidence_flat = self._memory_barrier(confidence.reshape(confidence.shape[0], -1))
|
||||
volatility_flat = self._memory_barrier(volatility_pred.reshape(volatility_pred.shape[0], -1))
|
||||
|
||||
# Calculate prediction confidence based on volatility and regime stability
|
||||
regime_stability = torch.std(regime_probs, dim=1, keepdim=True)
|
||||
prediction_confidence = 1.0 / (1.0 + regime_stability + volatility_flat * 0.1)
|
||||
prediction_confidence = self._memory_barrier(prediction_confidence.squeeze(-1))
|
||||
|
||||
return {
|
||||
'logits': self._memory_barrier(trading_logits),
|
||||
'probabilities': self._memory_barrier(trading_probs),
|
||||
'confidence': confidence_flat[:, 0] if confidence_flat.shape[1] > 0 else confidence_flat.reshape(-1)[0],
|
||||
'ohlcv': self._memory_barrier(ohlcv_pred), # [batch_size, 5] - OHLCV predictions
|
||||
'confidence': prediction_confidence,
|
||||
'regime': self._memory_barrier(regime_probs),
|
||||
'volatility': volatility_flat[:, 0] if volatility_flat.shape[1] > 0 else volatility_flat.reshape(-1)[0],
|
||||
'features': self._memory_barrier(processed_features)
|
||||
'features': self._memory_barrier(processed_features),
|
||||
'regime_stability': self._memory_barrier(regime_stability.squeeze(-1))
|
||||
}
|
||||
|
||||
def predict(self, feature_matrix: np.ndarray) -> Dict[str, Any]:
|
||||
def predict(self, feature_matrix) -> Dict[str, Any]:
|
||||
"""
|
||||
Make predictions on feature matrix
|
||||
Make OHLCV predictions on feature matrix
|
||||
Args:
|
||||
feature_matrix: numpy array of shape [sequence_length, features]
|
||||
feature_matrix: tensor or numpy array of shape [sequence_length, features]
|
||||
Returns:
|
||||
Dictionary with prediction results
|
||||
Dictionary with OHLCV prediction results and trading signals
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
# Convert to tensor and add batch dimension
|
||||
if isinstance(feature_matrix, np.ndarray):
|
||||
if HAS_NUMPY and isinstance(feature_matrix, np.ndarray):
|
||||
x = torch.FloatTensor(feature_matrix).unsqueeze(0) # Add batch dim
|
||||
else:
|
||||
elif isinstance(feature_matrix, torch.Tensor):
|
||||
x = feature_matrix.unsqueeze(0)
|
||||
else:
|
||||
x = torch.FloatTensor(feature_matrix).unsqueeze(0)
|
||||
|
||||
# Move to device
|
||||
device = next(self.parameters()).device
|
||||
@@ -447,14 +469,16 @@ class EnhancedCNNModel(nn.Module):
|
||||
# Forward pass
|
||||
outputs = self.forward(x)
|
||||
|
||||
# Extract results with proper shape handling
|
||||
probs = outputs['probabilities'].cpu().numpy()[0]
|
||||
confidence_tensor = outputs['confidence'].cpu().numpy()
|
||||
regime = outputs['regime'].cpu().numpy()[0]
|
||||
volatility = outputs['volatility'].cpu().numpy()
|
||||
# Extract OHLCV predictions
|
||||
ohlcv_pred = outputs['ohlcv'].cpu().numpy()[0] if HAS_NUMPY else outputs['ohlcv'].cpu().tolist()[0]
|
||||
|
||||
# Extract other outputs
|
||||
confidence_tensor = outputs['confidence'].cpu().numpy() if HAS_NUMPY else outputs['confidence'].cpu().tolist()
|
||||
regime = outputs['regime'].cpu().numpy()[0] if HAS_NUMPY else outputs['regime'].cpu().tolist()[0]
|
||||
volatility = outputs['volatility'].cpu().numpy() if HAS_NUMPY else outputs['volatility'].cpu().tolist()
|
||||
|
||||
# Handle confidence shape properly
|
||||
if isinstance(confidence_tensor, np.ndarray):
|
||||
if HAS_NUMPY and isinstance(confidence_tensor, np.ndarray):
|
||||
if confidence_tensor.ndim == 0:
|
||||
confidence = float(confidence_tensor.item())
|
||||
elif confidence_tensor.size == 1:
|
||||
@@ -465,7 +489,7 @@ class EnhancedCNNModel(nn.Module):
|
||||
confidence = float(confidence_tensor)
|
||||
|
||||
# Handle volatility shape properly
|
||||
if isinstance(volatility, np.ndarray):
|
||||
if HAS_NUMPY and isinstance(volatility, np.ndarray):
|
||||
if volatility.ndim == 0:
|
||||
volatility = float(volatility.item())
|
||||
elif volatility.size == 1:
|
||||
@@ -475,19 +499,68 @@ class EnhancedCNNModel(nn.Module):
|
||||
else:
|
||||
volatility = float(volatility)
|
||||
|
||||
# Determine action (0=BUY, 1=SELL for 2-action system)
|
||||
action = int(np.argmax(probs))
|
||||
action_confidence = float(probs[action])
|
||||
# Extract OHLCV values
|
||||
open_price, high_price, low_price, close_price, volume = ohlcv_pred
|
||||
|
||||
# Calculate price movement and direction
|
||||
price_change = close_price - open_price
|
||||
price_change_pct = (price_change / open_price) * 100 if open_price != 0 else 0
|
||||
|
||||
# Calculate candle characteristics
|
||||
body_size = abs(close_price - open_price)
|
||||
upper_wick = high_price - max(open_price, close_price)
|
||||
lower_wick = min(open_price, close_price) - low_price
|
||||
total_range = high_price - low_price
|
||||
|
||||
# Determine trading action based on predicted candle
|
||||
if price_change_pct > 0.1: # Bullish candle (>0.1% gain)
|
||||
action = 0 # BUY
|
||||
action_name = 'BUY'
|
||||
action_confidence = min(0.95, confidence * (1 + abs(price_change_pct) * 10))
|
||||
elif price_change_pct < -0.1: # Bearish candle (<-0.1% loss)
|
||||
action = 1 # SELL
|
||||
action_name = 'SELL'
|
||||
action_confidence = min(0.95, confidence * (1 + abs(price_change_pct) * 10))
|
||||
else: # Sideways/neutral candle
|
||||
# Use body vs wick analysis for weak signals
|
||||
if body_size / total_range > 0.7: # Strong directional body
|
||||
action = 0 if price_change > 0 else 1
|
||||
action_name = 'BUY' if action == 0 else 'SELL'
|
||||
action_confidence = confidence * 0.6 # Reduce confidence for weak signals
|
||||
else:
|
||||
action = 2 # HOLD
|
||||
action_name = 'HOLD'
|
||||
action_confidence = confidence * 0.3 # Very low confidence
|
||||
|
||||
# Adjust confidence based on volatility
|
||||
if volatility > 0.5: # High volatility
|
||||
action_confidence *= 0.8 # Reduce confidence in volatile conditions
|
||||
elif volatility < 0.2: # Low volatility
|
||||
action_confidence *= 1.2 # Increase confidence in stable conditions
|
||||
action_confidence = min(0.95, action_confidence) # Cap at 95%
|
||||
|
||||
return {
|
||||
'action': action,
|
||||
'action_name': 'BUY' if action == 0 else 'SELL',
|
||||
'action_name': action_name,
|
||||
'confidence': float(confidence),
|
||||
'action_confidence': action_confidence,
|
||||
'probabilities': probs.tolist(),
|
||||
'regime_probabilities': regime.tolist(),
|
||||
'ohlcv_prediction': {
|
||||
'open': float(open_price),
|
||||
'high': float(high_price),
|
||||
'low': float(low_price),
|
||||
'close': float(close_price),
|
||||
'volume': float(volume)
|
||||
},
|
||||
'price_change_pct': price_change_pct,
|
||||
'candle_characteristics': {
|
||||
'body_size': body_size,
|
||||
'upper_wick': upper_wick,
|
||||
'lower_wick': lower_wick,
|
||||
'total_range': total_range
|
||||
},
|
||||
'regime_probabilities': regime if isinstance(regime, list) else regime.tolist(),
|
||||
'volatility_prediction': float(volatility),
|
||||
'raw_logits': outputs['logits'].cpu().numpy()[0].tolist()
|
||||
'prediction_quality': 'high' if action_confidence > 0.8 else 'medium' if action_confidence > 0.6 else 'low'
|
||||
}
|
||||
|
||||
def get_memory_usage(self) -> Dict[str, Any]:
|
||||
|
@@ -15,11 +15,19 @@ Architecture:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
# Try to import numpy, but provide fallback if not available
|
||||
try:
|
||||
import numpy as np
|
||||
HAS_NUMPY = True
|
||||
except ImportError:
|
||||
np = None
|
||||
HAS_NUMPY = False
|
||||
logging.warning("NumPy not available - COB RL model will have limited functionality")
|
||||
|
||||
from .model_interfaces import ModelInterface
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -164,12 +172,12 @@ class MassiveRLNetwork(nn.Module):
|
||||
'features': x # Hidden features for analysis
|
||||
}
|
||||
|
||||
def predict(self, cob_features: np.ndarray) -> Dict[str, Any]:
|
||||
def predict(self, cob_features) -> Dict[str, Any]:
|
||||
"""
|
||||
High-level prediction method for COB features
|
||||
|
||||
Args:
|
||||
cob_features: COB features as numpy array [input_size]
|
||||
cob_features: COB features as tensor or numpy array [input_size]
|
||||
|
||||
Returns:
|
||||
Dict containing prediction results
|
||||
@@ -177,10 +185,13 @@ class MassiveRLNetwork(nn.Module):
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
# Convert to tensor and add batch dimension
|
||||
if isinstance(cob_features, np.ndarray):
|
||||
if HAS_NUMPY and isinstance(cob_features, np.ndarray):
|
||||
x = torch.from_numpy(cob_features).float()
|
||||
else:
|
||||
elif isinstance(cob_features, torch.Tensor):
|
||||
x = cob_features.float()
|
||||
else:
|
||||
# Try to convert from list or other format
|
||||
x = torch.tensor(cob_features, dtype=torch.float32)
|
||||
|
||||
if x.dim() == 1:
|
||||
x = x.unsqueeze(0) # Add batch dimension
|
||||
@@ -198,11 +209,17 @@ class MassiveRLNetwork(nn.Module):
|
||||
confidence = outputs['confidence'].item()
|
||||
value = outputs['value'].item()
|
||||
|
||||
# Convert probabilities to list (works with or without numpy)
|
||||
if HAS_NUMPY:
|
||||
probabilities = price_probs.cpu().numpy()[0].tolist()
|
||||
else:
|
||||
probabilities = price_probs.cpu().tolist()[0]
|
||||
|
||||
return {
|
||||
'predicted_direction': predicted_direction, # 0=DOWN, 1=SIDEWAYS, 2=UP
|
||||
'confidence': confidence,
|
||||
'value': value,
|
||||
'probabilities': price_probs.cpu().numpy()[0],
|
||||
'probabilities': probabilities,
|
||||
'direction_text': ['DOWN', 'SIDEWAYS', 'UP'][predicted_direction]
|
||||
}
|
||||
|
||||
@@ -250,15 +267,18 @@ class COBRLModelInterface(ModelInterface):
|
||||
|
||||
logger.info(f"COB RL Model Interface initialized on {self.device}")
|
||||
|
||||
def predict(self, cob_features: np.ndarray) -> Dict[str, Any]:
|
||||
def predict(self, cob_features) -> Dict[str, Any]:
|
||||
"""Make prediction using the model"""
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
# Convert to tensor and add batch dimension
|
||||
if isinstance(cob_features, np.ndarray):
|
||||
if HAS_NUMPY and isinstance(cob_features, np.ndarray):
|
||||
x = torch.from_numpy(cob_features).float()
|
||||
else:
|
||||
elif isinstance(cob_features, torch.Tensor):
|
||||
x = cob_features.float()
|
||||
else:
|
||||
# Try to convert from list or other format
|
||||
x = torch.tensor(cob_features, dtype=torch.float32)
|
||||
|
||||
if x.dim() == 1:
|
||||
x = x.unsqueeze(0) # Add batch dimension
|
||||
@@ -275,11 +295,17 @@ class COBRLModelInterface(ModelInterface):
|
||||
confidence = outputs['confidence'].item()
|
||||
value = outputs['value'].item()
|
||||
|
||||
# Convert probabilities to list (works with or without numpy)
|
||||
if HAS_NUMPY:
|
||||
probabilities = price_probs.cpu().numpy()[0].tolist()
|
||||
else:
|
||||
probabilities = price_probs.cpu().tolist()[0]
|
||||
|
||||
return {
|
||||
'predicted_direction': predicted_direction, # 0=DOWN, 1=SIDEWAYS, 2=UP
|
||||
'confidence': confidence,
|
||||
'value': value,
|
||||
'probabilities': price_probs.cpu().numpy()[0],
|
||||
'probabilities': probabilities,
|
||||
'direction_text': ['DOWN', 'SIDEWAYS', 'UP'][predicted_direction]
|
||||
}
|
||||
|
||||
|
780
NN/models/multi_timeframe_predictor.py
Normal file
780
NN/models/multi_timeframe_predictor.py
Normal file
@@ -0,0 +1,780 @@
|
||||
"""
|
||||
Multi-Timeframe Prediction System for Enhanced Trading
|
||||
|
||||
This module implements a sophisticated multi-timeframe prediction system that allows
|
||||
models to make predictions for different time horizons (1, 5, 10 minutes) with
|
||||
appropriate confidence thresholds and position holding strategies.
|
||||
|
||||
Key Features:
|
||||
- Dynamic sequence length adaptation for different timeframes
|
||||
- Confidence calibration based on prediction horizon
|
||||
- Position holding logic for longer-term trades
|
||||
- Risk-adjusted trading strategies
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PredictionHorizon(Enum):
|
||||
"""Prediction time horizons"""
|
||||
ONE_MINUTE = 1
|
||||
FIVE_MINUTES = 5
|
||||
TEN_MINUTES = 10
|
||||
|
||||
class ConfidenceThreshold(Enum):
|
||||
"""Confidence thresholds for different horizons"""
|
||||
ONE_MINUTE = 0.35 # Lower threshold for quick trades
|
||||
FIVE_MINUTES = 0.65 # Higher threshold for 5-minute holds
|
||||
TEN_MINUTES = 0.80 # Very high threshold for 10-minute holds
|
||||
|
||||
@dataclass
|
||||
class MultiTimeframePrediction:
|
||||
"""Container for multi-timeframe predictions"""
|
||||
symbol: str
|
||||
current_price: float
|
||||
predictions: Dict[PredictionHorizon, Dict[str, Any]]
|
||||
timestamp: datetime
|
||||
market_conditions: Dict[str, Any]
|
||||
|
||||
class MultiTimeframePredictor:
|
||||
"""
|
||||
Advanced multi-timeframe prediction system that adapts model behavior
|
||||
based on desired prediction horizon and market conditions.
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator):
|
||||
self.orchestrator = orchestrator
|
||||
self.horizons = {
|
||||
PredictionHorizon.ONE_MINUTE: {
|
||||
'sequence_length': 60, # 60 minutes for 1-minute predictions
|
||||
'confidence_threshold': ConfidenceThreshold.ONE_MINUTE.value,
|
||||
'max_hold_time': 60, # 1 minute max hold
|
||||
'risk_multiplier': 1.0
|
||||
},
|
||||
PredictionHorizon.FIVE_MINUTES: {
|
||||
'sequence_length': 300, # 300 minutes for 5-minute predictions
|
||||
'confidence_threshold': ConfidenceThreshold.FIVE_MINUTES.value,
|
||||
'max_hold_time': 300, # 5 minutes max hold
|
||||
'risk_multiplier': 1.5 # Higher risk for longer holds
|
||||
},
|
||||
PredictionHorizon.TEN_MINUTES: {
|
||||
'sequence_length': 600, # 600 minutes for 10-minute predictions
|
||||
'confidence_threshold': ConfidenceThreshold.TEN_MINUTES.value,
|
||||
'max_hold_time': 600, # 10 minutes max hold
|
||||
'risk_multiplier': 2.0 # Highest risk for longest holds
|
||||
}
|
||||
}
|
||||
|
||||
# Initialize models for different horizons
|
||||
self.models = {}
|
||||
self._initialize_multi_horizon_models()
|
||||
|
||||
def _initialize_multi_horizon_models(self):
|
||||
"""Initialize separate model instances for different horizons"""
|
||||
try:
|
||||
for horizon, config in self.horizons.items():
|
||||
# CNN Model for this horizon
|
||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
# Create horizon-specific model configuration
|
||||
horizon_model = self._create_horizon_specific_model(
|
||||
self.orchestrator.cnn_model,
|
||||
config['sequence_length'],
|
||||
horizon
|
||||
)
|
||||
self.models[f'cnn_{horizon.value}min'] = horizon_model
|
||||
|
||||
# COB RL Model for this horizon
|
||||
if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
|
||||
self.models[f'cob_rl_{horizon.value}min'] = self.orchestrator.cob_rl_agent
|
||||
|
||||
logger.info(f"Initialized {horizon.value}-minute prediction model")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing multi-horizon models: {e}")
|
||||
|
||||
def _create_horizon_specific_model(self, base_model, sequence_length: int, horizon: PredictionHorizon):
|
||||
"""Create a model instance optimized for specific prediction horizon"""
|
||||
try:
|
||||
# For CNN models, we need to adjust input size and potentially architecture
|
||||
if hasattr(base_model, '__class__'):
|
||||
model_class = base_model.__class__
|
||||
|
||||
# Calculate appropriate input size for horizon
|
||||
# More data for longer predictions
|
||||
adjusted_input_size = min(sequence_length, 300) # Cap at 300 to avoid memory issues
|
||||
|
||||
# Create new model instance with horizon-specific parameters
|
||||
# Use only the parameters that the model actually accepts
|
||||
try:
|
||||
horizon_model = model_class(
|
||||
input_size=adjusted_input_size,
|
||||
feature_dim=getattr(base_model, 'feature_dim', 50),
|
||||
output_size=5, # Always use 5 for OHLCV predictions
|
||||
prediction_horizon=horizon.value
|
||||
)
|
||||
except TypeError:
|
||||
# If the model doesn't accept these parameters, just create with defaults
|
||||
logger.warning(f"Model {model_class.__name__} doesn't accept expected parameters, using defaults")
|
||||
horizon_model = model_class()
|
||||
|
||||
# Try to load pre-trained weights if available
|
||||
try:
|
||||
if hasattr(base_model, 'state_dict'):
|
||||
# Load base model weights and adapt if necessary
|
||||
base_state = base_model.state_dict()
|
||||
horizon_model.load_state_dict(base_state, strict=False)
|
||||
logger.info(f"Loaded base model weights for {horizon.value}-minute horizon")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load base weights for {horizon.value}-minute model: {e}")
|
||||
|
||||
return horizon_model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating horizon-specific model: {e}")
|
||||
return base_model # Fallback to base model
|
||||
|
||||
def generate_multi_timeframe_prediction(self, symbol: str) -> Optional[MultiTimeframePrediction]:
|
||||
"""
|
||||
Generate predictions for all timeframes with appropriate confidence thresholds
|
||||
"""
|
||||
try:
|
||||
# Get current market data
|
||||
current_price = self._get_current_price(symbol)
|
||||
if not current_price:
|
||||
return None
|
||||
|
||||
# Get market conditions for confidence adjustment
|
||||
market_conditions = self._assess_market_conditions(symbol)
|
||||
|
||||
predictions = {}
|
||||
|
||||
# Generate predictions for each horizon
|
||||
for horizon, config in self.horizons.items():
|
||||
prediction = self._generate_single_horizon_prediction(
|
||||
symbol, current_price, horizon, config, market_conditions
|
||||
)
|
||||
if prediction:
|
||||
predictions[horizon] = prediction
|
||||
|
||||
if not predictions:
|
||||
return None
|
||||
|
||||
return MultiTimeframePrediction(
|
||||
symbol=symbol,
|
||||
current_price=current_price,
|
||||
predictions=predictions,
|
||||
timestamp=datetime.now(),
|
||||
market_conditions=market_conditions
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating multi-timeframe prediction: {e}")
|
||||
return None
|
||||
|
||||
def _generate_single_horizon_prediction(self, symbol: str, current_price: float,
|
||||
horizon: PredictionHorizon, config: Dict,
|
||||
market_conditions: Dict) -> Optional[Dict[str, Any]]:
|
||||
"""Generate prediction for single timeframe using iterative candle prediction"""
|
||||
try:
|
||||
# Get base historical data (use shorter sequence for iterative prediction)
|
||||
base_sequence_length = min(60, config['sequence_length'] // 2) # Use half for base data
|
||||
base_data = self._get_sequence_data_for_horizon(symbol, base_sequence_length)
|
||||
|
||||
if not base_data:
|
||||
return None
|
||||
|
||||
# Generate iterative predictions for this horizon
|
||||
iterative_predictions = self._generate_iterative_predictions(
|
||||
symbol, base_data, horizon.value, market_conditions
|
||||
)
|
||||
|
||||
if not iterative_predictions:
|
||||
return None
|
||||
|
||||
# Analyze the predicted price movement over the horizon
|
||||
horizon_prediction = self._analyze_horizon_prediction(
|
||||
iterative_predictions, config, market_conditions
|
||||
)
|
||||
|
||||
# Apply confidence threshold
|
||||
if horizon_prediction['confidence'] < config['confidence_threshold']:
|
||||
return None # Not confident enough for this horizon
|
||||
|
||||
return horizon_prediction
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating {horizon.value}-minute prediction: {e}")
|
||||
return None
|
||||
|
||||
def _get_sequence_data_for_horizon(self, symbol: str, sequence_length: int) -> Optional[torch.Tensor]:
|
||||
"""Get appropriate sequence data for prediction horizon"""
|
||||
try:
|
||||
# This would need to be implemented based on your data provider
|
||||
# For now, return a placeholder
|
||||
if hasattr(self.orchestrator, 'data_provider'):
|
||||
# Get historical data for the required sequence length
|
||||
data = self.orchestrator.data_provider.get_historical_data(
|
||||
symbol, '1m', limit=sequence_length
|
||||
)
|
||||
|
||||
if data is not None and len(data) >= sequence_length // 10: # At least 10% of required data
|
||||
# Convert to tensor format expected by models
|
||||
tensor_data = self._convert_data_to_tensor(data)
|
||||
if tensor_data is not None:
|
||||
logger.debug(f"✅ Converted {len(data)} data points to tensor shape: {tensor_data.shape}")
|
||||
return tensor_data
|
||||
else:
|
||||
logger.warning("Failed to convert data to tensor")
|
||||
return None
|
||||
else:
|
||||
logger.warning(f"Insufficient data for {sequence_length}-point prediction: {len(data) if data is not None else 'None'}")
|
||||
return None
|
||||
|
||||
# Fallback: create mock data if no data provider available
|
||||
logger.warning("No data provider available - creating mock sequence data")
|
||||
return self._create_mock_sequence_data(sequence_length)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting sequence data: {e}")
|
||||
# Fallback: create mock data on error
|
||||
logger.warning("Creating mock sequence data due to error")
|
||||
return self._create_mock_sequence_data(sequence_length)
|
||||
|
||||
def _convert_data_to_tensor(self, data) -> torch.Tensor:
|
||||
"""Convert market data to tensor format"""
|
||||
try:
|
||||
# This is a placeholder - implement based on your data format
|
||||
if hasattr(data, 'values'):
|
||||
# Assume pandas DataFrame
|
||||
features = ['open', 'high', 'low', 'close', 'volume']
|
||||
feature_data = []
|
||||
|
||||
for feature in features:
|
||||
if feature in data.columns:
|
||||
values = data[feature].ffill().fillna(0).values
|
||||
feature_data.append(values)
|
||||
|
||||
if feature_data:
|
||||
# Ensure all feature arrays have the same length
|
||||
min_length = min(len(arr) for arr in feature_data)
|
||||
feature_data = [arr[:min_length] for arr in feature_data]
|
||||
|
||||
# Stack features
|
||||
tensor_data = torch.tensor(feature_data, dtype=torch.float32).transpose(0, 1)
|
||||
|
||||
# Validate tensor data
|
||||
if torch.any(torch.isnan(tensor_data)) or torch.any(torch.isinf(tensor_data)):
|
||||
logger.warning("Found NaN or Inf values in tensor data, replacing with zeros")
|
||||
tensor_data = torch.nan_to_num(tensor_data, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
|
||||
return tensor_data.unsqueeze(0) # Add batch dimension
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting data to tensor: {e}")
|
||||
return None
|
||||
|
||||
def _get_cnn_prediction(self, model, sequence_data: torch.Tensor, config: Dict) -> Optional[Dict]:
|
||||
"""Get CNN model prediction using OHLCV prediction"""
|
||||
try:
|
||||
# Use the predict method which now handles OHLCV predictions
|
||||
if hasattr(model, 'predict'):
|
||||
if sequence_data.dim() == 3: # [batch, seq, features]
|
||||
sequence_data_flat = sequence_data.squeeze(0) # Remove batch dim
|
||||
else:
|
||||
sequence_data_flat = sequence_data
|
||||
|
||||
prediction = model.predict(sequence_data_flat)
|
||||
|
||||
if prediction and 'action_name' in prediction:
|
||||
return {
|
||||
'action': prediction['action_name'],
|
||||
'confidence': prediction.get('action_confidence', 0.5),
|
||||
'model': 'cnn',
|
||||
'horizon': config.get('max_hold_time', 60),
|
||||
'ohlcv_prediction': prediction.get('ohlcv_prediction'),
|
||||
'price_change_pct': prediction.get('price_change_pct', 0)
|
||||
}
|
||||
|
||||
# Fallback to direct forward pass if predict method not available
|
||||
with torch.no_grad():
|
||||
outputs = model(sequence_data)
|
||||
if isinstance(outputs, dict) and 'ohlcv' in outputs:
|
||||
ohlcv = outputs['ohlcv'].cpu().numpy()[0]
|
||||
confidence = outputs['confidence'].cpu().numpy()[0] if hasattr(outputs['confidence'], 'cpu') else outputs['confidence']
|
||||
|
||||
# Determine action from OHLCV
|
||||
price_change_pct = ((ohlcv[3] - ohlcv[0]) / ohlcv[0]) * 100 if ohlcv[0] != 0 else 0
|
||||
|
||||
if price_change_pct > 0.1:
|
||||
action = 'BUY'
|
||||
elif price_change_pct < -0.1:
|
||||
action = 'SELL'
|
||||
else:
|
||||
action = 'HOLD'
|
||||
|
||||
return {
|
||||
'action': action,
|
||||
'confidence': float(confidence),
|
||||
'model': 'cnn',
|
||||
'horizon': config.get('max_hold_time', 60),
|
||||
'ohlcv_prediction': {
|
||||
'open': float(ohlcv[0]),
|
||||
'high': float(ohlcv[1]),
|
||||
'low': float(ohlcv[2]),
|
||||
'close': float(ohlcv[3]),
|
||||
'volume': float(ohlcv[4])
|
||||
},
|
||||
'price_change_pct': price_change_pct
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting CNN prediction: {e}")
|
||||
return None
|
||||
|
||||
def _get_cob_rl_prediction(self, model, sequence_data: torch.Tensor, config: Dict) -> Optional[Dict]:
|
||||
"""Get COB RL model prediction"""
|
||||
try:
|
||||
# This would need to be implemented based on your COB RL model interface
|
||||
if hasattr(model, 'predict'):
|
||||
result = model.predict(sequence_data)
|
||||
return {
|
||||
'action': result.get('action', 'HOLD'),
|
||||
'confidence': result.get('confidence', 0.5),
|
||||
'model': 'cob_rl',
|
||||
'horizon': config.get('max_hold_time', 60)
|
||||
}
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting COB RL prediction: {e}")
|
||||
return None
|
||||
|
||||
def _ensemble_predictions(self, predictions: List[Dict], config: Dict,
|
||||
market_conditions: Dict) -> Dict[str, Any]:
|
||||
"""Ensemble multiple model predictions using OHLCV data"""
|
||||
try:
|
||||
if not predictions:
|
||||
return None
|
||||
|
||||
# Enhanced ensemble considering both action and price movement
|
||||
action_votes = {}
|
||||
confidence_sum = 0
|
||||
price_change_indicators = []
|
||||
|
||||
for pred in predictions:
|
||||
action = pred['action']
|
||||
confidence = pred['confidence']
|
||||
|
||||
# Weight by confidence
|
||||
if action not in action_votes:
|
||||
action_votes[action] = 0
|
||||
action_votes[action] += confidence
|
||||
confidence_sum += confidence
|
||||
|
||||
# Collect price change indicators for ensemble analysis
|
||||
if 'price_change_pct' in pred:
|
||||
price_change_indicators.append(pred['price_change_pct'])
|
||||
|
||||
# Get winning action
|
||||
if action_votes:
|
||||
best_action = max(action_votes, key=action_votes.get)
|
||||
ensemble_confidence = action_votes[best_action] / len(predictions)
|
||||
else:
|
||||
best_action = 'HOLD'
|
||||
ensemble_confidence = 0.1
|
||||
|
||||
# Analyze price movement consensus
|
||||
if price_change_indicators:
|
||||
avg_price_change = sum(price_change_indicators) / len(price_change_indicators)
|
||||
price_consensus = abs(avg_price_change) / 0.1 # Normalize around 0.1% threshold
|
||||
|
||||
# Boost confidence if price movements are consistent
|
||||
if len(price_change_indicators) > 1:
|
||||
price_std = torch.std(torch.tensor(price_change_indicators)).item()
|
||||
if price_std < 0.05: # Low variability in predictions
|
||||
ensemble_confidence *= 1.2
|
||||
elif price_std > 0.15: # High variability
|
||||
ensemble_confidence *= 0.8
|
||||
|
||||
# Override action based on strong price consensus
|
||||
if abs(avg_price_change) > 0.2: # Strong price movement
|
||||
if avg_price_change > 0:
|
||||
best_action = 'BUY'
|
||||
else:
|
||||
best_action = 'SELL'
|
||||
ensemble_confidence = min(ensemble_confidence * 1.3, 0.9)
|
||||
|
||||
# Adjust confidence based on market conditions
|
||||
market_confidence_multiplier = market_conditions.get('confidence_multiplier', 1.0)
|
||||
final_confidence = min(ensemble_confidence * market_confidence_multiplier, 1.0)
|
||||
|
||||
return {
|
||||
'action': best_action,
|
||||
'confidence': final_confidence,
|
||||
'horizon_minutes': config['max_hold_time'] // 60,
|
||||
'risk_multiplier': config['risk_multiplier'],
|
||||
'models_used': len(predictions),
|
||||
'market_conditions': market_conditions,
|
||||
'price_change_indicators': price_change_indicators,
|
||||
'avg_price_change_pct': sum(price_change_indicators) / len(price_change_indicators) if price_change_indicators else 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in prediction ensemble: {e}")
|
||||
return None
|
||||
|
||||
def _assess_market_conditions(self, symbol: str) -> Dict[str, Any]:
|
||||
"""Assess current market conditions for confidence adjustment"""
|
||||
try:
|
||||
conditions = {
|
||||
'volatility': 'medium',
|
||||
'trend': 'sideways',
|
||||
'confidence_multiplier': 1.0,
|
||||
'risk_level': 'normal'
|
||||
}
|
||||
|
||||
# This could be enhanced with actual market analysis
|
||||
# For now, return default conditions
|
||||
return conditions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error assessing market conditions: {e}")
|
||||
return {'confidence_multiplier': 1.0}
|
||||
|
||||
def _get_current_price(self, symbol: str) -> Optional[float]:
|
||||
"""Get current price for symbol"""
|
||||
try:
|
||||
if hasattr(self.orchestrator, 'data_provider'):
|
||||
ticker = self.orchestrator.data_provider.get_current_price(symbol)
|
||||
return ticker
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current price for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def should_execute_trade(self, prediction: MultiTimeframePrediction) -> Tuple[bool, str]:
|
||||
"""
|
||||
Determine if a trade should be executed based on multi-timeframe analysis
|
||||
"""
|
||||
try:
|
||||
if not prediction or not prediction.predictions:
|
||||
return False, "No predictions available"
|
||||
|
||||
# Find the best prediction across all horizons
|
||||
best_prediction = None
|
||||
best_confidence = 0
|
||||
|
||||
for horizon, pred in prediction.predictions.items():
|
||||
if pred['confidence'] > best_confidence:
|
||||
best_confidence = pred['confidence']
|
||||
best_prediction = (horizon, pred)
|
||||
|
||||
if not best_prediction:
|
||||
return False, "No valid predictions"
|
||||
|
||||
horizon, pred = best_prediction
|
||||
config = self.horizons[horizon]
|
||||
|
||||
# Check if confidence meets threshold
|
||||
if pred['confidence'] < config['confidence_threshold']:
|
||||
return False, ".2f"
|
||||
|
||||
# Check market conditions
|
||||
market_risk = prediction.market_conditions.get('risk_level', 'normal')
|
||||
if market_risk == 'high' and horizon.value >= 5:
|
||||
return False, "High market risk - avoiding longer-term predictions"
|
||||
|
||||
return True, f"Valid {horizon.value}-minute prediction with {pred['confidence']:.2f} confidence"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in trade execution decision: {e}")
|
||||
return False, f"Decision error: {e}"
|
||||
|
||||
def get_position_hold_time(self, prediction: MultiTimeframePrediction) -> int:
|
||||
"""Determine how long to hold a position based on prediction horizon"""
|
||||
try:
|
||||
if not prediction or not prediction.predictions:
|
||||
return 60 # Default 1 minute
|
||||
|
||||
# Use the longest horizon prediction that's available and confident
|
||||
max_horizon = 1
|
||||
for horizon, pred in prediction.predictions.items():
|
||||
config = self.horizons[horizon]
|
||||
if pred['confidence'] >= config['confidence_threshold']:
|
||||
max_horizon = max(max_horizon, horizon.value)
|
||||
|
||||
return max_horizon * 60 # Convert minutes to seconds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error determining hold time: {e}")
|
||||
return 60
|
||||
|
||||
def _generate_iterative_predictions(self, symbol: str, base_data: torch.Tensor,
|
||||
num_steps: int, market_conditions: Dict) -> Optional[List[Dict]]:
|
||||
"""Generate iterative candle predictions for the specified number of steps"""
|
||||
try:
|
||||
predictions = []
|
||||
current_data = base_data.clone() # Start with base historical data
|
||||
|
||||
# Get the CNN model for iterative prediction
|
||||
cnn_model = None
|
||||
for model_key, model in self.models.items():
|
||||
if model_key.startswith('cnn_'):
|
||||
cnn_model = model
|
||||
break
|
||||
|
||||
if not cnn_model:
|
||||
logger.warning("No CNN model available for iterative prediction")
|
||||
return None
|
||||
|
||||
# Check if CNN model has predict method
|
||||
if not hasattr(cnn_model, 'predict'):
|
||||
logger.warning("CNN model does not have predict method - trying alternative approach")
|
||||
# Try to use the orchestrator's CNN model directly
|
||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
cnn_model = self.orchestrator.cnn_model
|
||||
logger.info("Using orchestrator's CNN model for predictions")
|
||||
|
||||
# Check if orchestrator's CNN model also lacks predict method
|
||||
if not hasattr(cnn_model, 'predict'):
|
||||
logger.error("Orchestrator's CNN model also lacks predict method - creating mock predictions")
|
||||
return self._create_mock_predictions(num_steps)
|
||||
else:
|
||||
logger.error("No CNN model with predict method available - creating mock predictions")
|
||||
# Create mock predictions for testing
|
||||
return self._create_mock_predictions(num_steps)
|
||||
|
||||
for step in range(num_steps):
|
||||
# Use CNN model to predict next candle
|
||||
try:
|
||||
with torch.no_grad():
|
||||
# Prepare data for CNN prediction
|
||||
# Convert tensor to format expected by predict method
|
||||
if current_data.dim() == 3: # [batch, seq, features]
|
||||
current_data_flat = current_data.squeeze(0) # Remove batch dim
|
||||
else:
|
||||
current_data_flat = current_data
|
||||
|
||||
prediction = cnn_model.predict(current_data_flat)
|
||||
|
||||
if prediction and 'ohlcv_prediction' in prediction:
|
||||
# Add timestamp to the prediction
|
||||
prediction_time = datetime.now() + timedelta(minutes=step + 1)
|
||||
prediction['timestamp'] = prediction_time
|
||||
predictions.append(prediction)
|
||||
logger.debug(f"📊 Step {step}: Added prediction for {prediction_time}, close: {prediction['ohlcv_prediction']['close']:.2f}")
|
||||
|
||||
# Extract predicted OHLCV values
|
||||
ohlcv = prediction['ohlcv_prediction']
|
||||
new_candle = torch.tensor([
|
||||
ohlcv['open'],
|
||||
ohlcv['high'],
|
||||
ohlcv['low'],
|
||||
ohlcv['close'],
|
||||
ohlcv['volume']
|
||||
], dtype=current_data.dtype)
|
||||
|
||||
# Add the predicted candle to our data sequence
|
||||
# Remove oldest candle and add new prediction
|
||||
if current_data.dim() == 3:
|
||||
current_data = torch.cat([
|
||||
current_data[:, 1:, :], # Remove oldest candle
|
||||
new_candle.unsqueeze(0).unsqueeze(0) # Add new prediction
|
||||
], dim=1)
|
||||
else:
|
||||
current_data = torch.cat([
|
||||
current_data[1:, :], # Remove oldest candle
|
||||
new_candle.unsqueeze(0) # Add new prediction
|
||||
], dim=0)
|
||||
else:
|
||||
logger.warning(f"❌ Step {step}: Invalid prediction format")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in iterative prediction step {step}: {e}")
|
||||
break
|
||||
|
||||
return predictions if predictions else None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in iterative predictions: {e}")
|
||||
return None
|
||||
|
||||
def _create_mock_predictions(self, num_steps: int) -> List[Dict]:
|
||||
"""Create mock predictions for testing when CNN model is not available"""
|
||||
try:
|
||||
logger.info(f"Creating {num_steps} mock predictions for testing")
|
||||
predictions = []
|
||||
current_time = datetime.now()
|
||||
base_price = 4300.0 # Mock base price
|
||||
|
||||
for step in range(num_steps):
|
||||
prediction_time = current_time + timedelta(minutes=step + 1)
|
||||
price_change = (step - num_steps // 2) * 2.0 # Mock price movement
|
||||
predicted_price = base_price + price_change
|
||||
|
||||
mock_prediction = {
|
||||
'timestamp': prediction_time,
|
||||
'ohlcv_prediction': {
|
||||
'open': predicted_price,
|
||||
'high': predicted_price + 1.0,
|
||||
'low': predicted_price - 1.0,
|
||||
'close': predicted_price + 0.5,
|
||||
'volume': 1000
|
||||
},
|
||||
'confidence': max(0.3, 0.8 - step * 0.05), # Decreasing confidence
|
||||
'action': 0 if price_change > 0 else 1,
|
||||
'action_name': 'BUY' if price_change > 0 else 'SELL'
|
||||
}
|
||||
predictions.append(mock_prediction)
|
||||
|
||||
logger.info(f"✅ Created {len(predictions)} mock predictions")
|
||||
return predictions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating mock predictions: {e}")
|
||||
return []
|
||||
|
||||
def _create_mock_sequence_data(self, sequence_length: int) -> torch.Tensor:
|
||||
"""Create mock sequence data for testing when real data is not available"""
|
||||
try:
|
||||
logger.info(f"Creating mock sequence data with {sequence_length} points")
|
||||
|
||||
# Create mock OHLCV data
|
||||
base_price = 4300.0
|
||||
mock_data = []
|
||||
|
||||
for i in range(sequence_length):
|
||||
# Simulate price movement
|
||||
price_change = (i - sequence_length // 2) * 0.5
|
||||
price = base_price + price_change
|
||||
|
||||
# Create OHLCV candle
|
||||
candle = [
|
||||
price, # open
|
||||
price + 1.0, # high
|
||||
price - 1.0, # low
|
||||
price + 0.5, # close
|
||||
1000.0 # volume
|
||||
]
|
||||
mock_data.append(candle)
|
||||
|
||||
# Convert to tensor
|
||||
tensor_data = torch.tensor(mock_data, dtype=torch.float32)
|
||||
tensor_data = tensor_data.unsqueeze(0) # Add batch dimension
|
||||
|
||||
logger.debug(f"✅ Created mock sequence data shape: {tensor_data.shape}")
|
||||
return tensor_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating mock sequence data: {e}")
|
||||
# Return minimal valid tensor
|
||||
return torch.zeros((1, 10, 5), dtype=torch.float32)
|
||||
|
||||
def _analyze_horizon_prediction(self, iterative_predictions: List[Dict],
|
||||
config: Dict, market_conditions: Dict) -> Optional[Dict[str, Any]]:
|
||||
"""Analyze the series of iterative predictions to determine overall horizon movement"""
|
||||
try:
|
||||
if not iterative_predictions:
|
||||
return None
|
||||
|
||||
# Extract price data from predictions
|
||||
predicted_prices = []
|
||||
confidences = []
|
||||
actions = []
|
||||
|
||||
for pred in iterative_predictions:
|
||||
if 'ohlcv_prediction' in pred:
|
||||
close_price = pred['ohlcv_prediction']['close']
|
||||
predicted_prices.append(close_price)
|
||||
|
||||
confidence = pred.get('action_confidence', 0.5)
|
||||
confidences.append(confidence)
|
||||
|
||||
action = pred.get('action', 2) # Default to HOLD
|
||||
actions.append(action)
|
||||
|
||||
if not predicted_prices:
|
||||
return None
|
||||
|
||||
# Calculate overall price movement
|
||||
start_price = predicted_prices[0]
|
||||
end_price = predicted_prices[-1]
|
||||
total_change = end_price - start_price
|
||||
total_change_pct = (total_change / start_price) * 100 if start_price != 0 else 0
|
||||
|
||||
# Calculate volatility and trend strength
|
||||
price_volatility = torch.std(torch.tensor(predicted_prices)).item()
|
||||
avg_confidence = sum(confidences) / len(confidences)
|
||||
|
||||
# Determine overall action based on price movement and confidence
|
||||
if total_change_pct > 0.5: # Overall bullish movement
|
||||
action = 0 # BUY
|
||||
action_name = 'BUY'
|
||||
confidence_multiplier = 1.2
|
||||
elif total_change_pct < -0.5: # Overall bearish movement
|
||||
action = 1 # SELL
|
||||
action_name = 'SELL'
|
||||
confidence_multiplier = 1.2
|
||||
else: # Sideways movement
|
||||
# Use majority vote from individual predictions
|
||||
buy_count = sum(1 for a in actions if a == 0)
|
||||
sell_count = sum(1 for a in actions if a == 1)
|
||||
|
||||
if buy_count > sell_count:
|
||||
action = 0
|
||||
action_name = 'BUY'
|
||||
confidence_multiplier = 0.8 # Reduce confidence for mixed signals
|
||||
elif sell_count > buy_count:
|
||||
action = 1
|
||||
action_name = 'SELL'
|
||||
confidence_multiplier = 0.8
|
||||
else:
|
||||
action = 2 # HOLD
|
||||
action_name = 'HOLD'
|
||||
confidence_multiplier = 0.5
|
||||
|
||||
# Calculate final confidence
|
||||
final_confidence = avg_confidence * confidence_multiplier
|
||||
|
||||
# Adjust for market conditions
|
||||
market_multiplier = market_conditions.get('confidence_multiplier', 1.0)
|
||||
final_confidence *= market_multiplier
|
||||
|
||||
# Cap confidence at reasonable levels
|
||||
final_confidence = min(0.95, max(0.1, final_confidence))
|
||||
|
||||
# Adjust for volatility
|
||||
if price_volatility > 0.02: # High volatility in predictions
|
||||
final_confidence *= 0.9
|
||||
|
||||
return {
|
||||
'action': action,
|
||||
'action_name': action_name,
|
||||
'confidence': final_confidence,
|
||||
'horizon_minutes': config['max_hold_time'] // 60,
|
||||
'total_price_change_pct': total_change_pct,
|
||||
'price_volatility': price_volatility,
|
||||
'avg_prediction_confidence': avg_confidence,
|
||||
'num_predictions': len(iterative_predictions),
|
||||
'risk_multiplier': config['risk_multiplier'],
|
||||
'market_conditions': market_conditions,
|
||||
'prediction_series': {
|
||||
'prices': predicted_prices,
|
||||
'confidences': confidences,
|
||||
'actions': actions
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing horizon prediction: {e}")
|
||||
return None
|
@@ -197,6 +197,10 @@ class ModelManager:
|
||||
self.nn_models_dir = self.base_dir / "NN" / "models"
|
||||
self.legacy_models_dir = self.base_dir / "models"
|
||||
|
||||
# Legacy checkpoint directories (where existing checkpoints are stored)
|
||||
self.legacy_checkpoints_dir = self.nn_models_dir / "checkpoints"
|
||||
self.legacy_registry_file = self.legacy_checkpoints_dir / "registry_metadata.json"
|
||||
|
||||
# Metadata and checkpoint management
|
||||
self.metadata_file = self.checkpoints_dir / "model_metadata.json"
|
||||
self.checkpoint_metadata_file = self.checkpoints_dir / "checkpoint_metadata.json"
|
||||
@@ -232,14 +236,72 @@ class ModelManager:
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _load_metadata(self) -> Dict[str, Any]:
|
||||
"""Load model metadata"""
|
||||
"""Load model metadata with legacy support"""
|
||||
metadata = {'models': {}, 'last_updated': datetime.now().isoformat()}
|
||||
|
||||
# First try to load from new unified metadata
|
||||
if self.metadata_file.exists():
|
||||
try:
|
||||
with open(self.metadata_file, 'r') as f:
|
||||
return json.load(f)
|
||||
metadata = json.load(f)
|
||||
logger.info(f"Loaded unified metadata from {self.metadata_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading metadata: {e}")
|
||||
return {'models': {}, 'last_updated': datetime.now().isoformat()}
|
||||
logger.error(f"Error loading unified metadata: {e}")
|
||||
|
||||
# Also load legacy metadata for backward compatibility
|
||||
if self.legacy_registry_file.exists():
|
||||
try:
|
||||
with open(self.legacy_registry_file, 'r') as f:
|
||||
legacy_data = json.load(f)
|
||||
|
||||
# Merge legacy data into unified metadata
|
||||
if 'models' in legacy_data:
|
||||
for model_name, model_info in legacy_data['models'].items():
|
||||
if model_name not in metadata['models']:
|
||||
# Convert legacy path format to absolute path
|
||||
if 'latest_path' in model_info:
|
||||
legacy_path = model_info['latest_path']
|
||||
|
||||
# Handle different legacy path formats
|
||||
if not legacy_path.startswith('/'):
|
||||
# Try multiple path resolution strategies
|
||||
possible_paths = [
|
||||
self.legacy_checkpoints_dir / legacy_path, # NN/models/checkpoints/models/cnn/...
|
||||
self.legacy_checkpoints_dir.parent / legacy_path, # NN/models/models/cnn/...
|
||||
self.base_dir / legacy_path, # /project/models/cnn/...
|
||||
]
|
||||
|
||||
resolved_path = None
|
||||
for path in possible_paths:
|
||||
if path.exists():
|
||||
resolved_path = path
|
||||
break
|
||||
|
||||
if resolved_path:
|
||||
legacy_path = str(resolved_path)
|
||||
else:
|
||||
# If no resolved path found, try to find the file by pattern
|
||||
filename = Path(legacy_path).name
|
||||
for search_path in [self.legacy_checkpoints_dir]:
|
||||
for file_path in search_path.rglob(filename):
|
||||
legacy_path = str(file_path)
|
||||
break
|
||||
|
||||
metadata['models'][model_name] = {
|
||||
'type': model_info.get('type', 'unknown'),
|
||||
'latest_path': legacy_path,
|
||||
'last_saved': model_info.get('last_saved', 'legacy'),
|
||||
'save_count': model_info.get('save_count', 1),
|
||||
'checkpoints': model_info.get('checkpoints', [])
|
||||
}
|
||||
logger.info(f"Migrated legacy metadata for {model_name}: {legacy_path}")
|
||||
|
||||
logger.info(f"Loaded legacy metadata from {self.legacy_registry_file}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading legacy metadata: {e}")
|
||||
|
||||
return metadata
|
||||
|
||||
def _load_checkpoint_metadata(self) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Load checkpoint metadata"""
|
||||
@@ -407,34 +469,125 @@ class ModelManager:
|
||||
return None
|
||||
|
||||
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
|
||||
"""Load the best checkpoint for a model"""
|
||||
"""Load the best checkpoint for a model with legacy support"""
|
||||
try:
|
||||
# First, try the unified registry
|
||||
model_info = self.metadata['models'].get(model_name)
|
||||
if model_info and Path(model_info['latest_path']).exists():
|
||||
# Load from unified registry
|
||||
load_dict = torch.load(model_info['latest_path'], map_location='cpu')
|
||||
return model_info['latest_path'], None
|
||||
logger.info(f"Loading checkpoint from unified registry: {model_info['latest_path']}")
|
||||
# Create metadata from model info for compatibility
|
||||
registry_metadata = CheckpointMetadata(
|
||||
checkpoint_id=f"{model_name}_registry",
|
||||
model_name=model_name,
|
||||
model_type=model_info.get('type', model_name),
|
||||
file_path=model_info['latest_path'],
|
||||
created_at=datetime.fromisoformat(model_info.get('last_saved', datetime.now().isoformat())),
|
||||
file_size_mb=0.0, # Will be calculated if needed
|
||||
performance_score=0.0, # Unknown from registry
|
||||
accuracy=None,
|
||||
loss=None, # Orchestrator will handle this
|
||||
val_accuracy=None,
|
||||
val_loss=None
|
||||
)
|
||||
return model_info['latest_path'], registry_metadata
|
||||
|
||||
# Fallback to checkpoint metadata
|
||||
checkpoints = self.checkpoint_metadata.get(model_name, [])
|
||||
if not checkpoints:
|
||||
logger.warning(f"No checkpoints found for {model_name}")
|
||||
return None
|
||||
if checkpoints:
|
||||
# Get best checkpoint
|
||||
best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)
|
||||
|
||||
# Get best checkpoint
|
||||
best_checkpoint = max(checkpoints, key=lambda x: x.performance_score)
|
||||
if Path(best_checkpoint.file_path).exists():
|
||||
logger.info(f"Loading checkpoint from unified metadata: {best_checkpoint.file_path}")
|
||||
return best_checkpoint.file_path, best_checkpoint
|
||||
|
||||
if not Path(best_checkpoint.file_path).exists():
|
||||
logger.error(f"Best checkpoint file not found: {best_checkpoint.file_path}")
|
||||
return None
|
||||
# Legacy fallback: Look for checkpoints in legacy directories
|
||||
logger.info(f"No checkpoint found in unified structure, checking legacy directories for {model_name}")
|
||||
legacy_path = self._find_legacy_checkpoint(model_name)
|
||||
if legacy_path:
|
||||
logger.info(f"Found legacy checkpoint: {legacy_path}")
|
||||
# Create a basic CheckpointMetadata for the legacy checkpoint
|
||||
legacy_metadata = CheckpointMetadata(
|
||||
checkpoint_id=f"legacy_{model_name}",
|
||||
model_name=model_name,
|
||||
model_type=model_name, # Will be inferred from model type
|
||||
file_path=str(legacy_path),
|
||||
created_at=datetime.fromtimestamp(legacy_path.stat().st_mtime),
|
||||
file_size_mb=legacy_path.stat().st_size / (1024 * 1024),
|
||||
performance_score=0.0, # Unknown for legacy
|
||||
accuracy=None,
|
||||
loss=None
|
||||
)
|
||||
return str(legacy_path), legacy_metadata
|
||||
|
||||
return best_checkpoint.file_path, best_checkpoint
|
||||
logger.warning(f"No checkpoints found for {model_name} in any location")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading best checkpoint for {model_name}: {e}")
|
||||
return None
|
||||
|
||||
def _find_legacy_checkpoint(self, model_name: str) -> Optional[Path]:
|
||||
"""Find checkpoint in legacy directories"""
|
||||
if not self.legacy_checkpoints_dir.exists():
|
||||
return None
|
||||
|
||||
# Use unified model naming throughout the project
|
||||
# All model references use consistent short names: dqn, cnn, cob_rl, transformer, decision
|
||||
# This eliminates complex mapping and ensures consistency across the entire codebase
|
||||
patterns = [model_name]
|
||||
|
||||
# Add minimal backward compatibility patterns
|
||||
if model_name == 'dqn':
|
||||
patterns.extend(['dqn_agent', 'agent'])
|
||||
elif model_name == 'cnn':
|
||||
patterns.extend(['cnn_model', 'enhanced_cnn'])
|
||||
elif model_name == 'cob_rl':
|
||||
patterns.extend(['rl', 'rl_agent', 'trading_agent'])
|
||||
|
||||
# Search in legacy saved directory first
|
||||
legacy_saved_dir = self.legacy_checkpoints_dir / "saved"
|
||||
if legacy_saved_dir.exists():
|
||||
for file_path in legacy_saved_dir.rglob("*.pt"):
|
||||
filename = file_path.name.lower()
|
||||
if any(pattern in filename for pattern in patterns):
|
||||
return file_path
|
||||
|
||||
# Search in model-specific directories
|
||||
for model_type in ['cnn', 'dqn', 'rl', 'transformer', 'decision']:
|
||||
model_dir = self.legacy_checkpoints_dir / model_type
|
||||
if model_dir.exists():
|
||||
saved_dir = model_dir / "saved"
|
||||
if saved_dir.exists():
|
||||
for file_path in saved_dir.rglob("*.pt"):
|
||||
filename = file_path.name.lower()
|
||||
if any(pattern in filename for pattern in patterns):
|
||||
return file_path
|
||||
|
||||
# Search in archive directory
|
||||
archive_dir = self.legacy_checkpoints_dir / "archive"
|
||||
if archive_dir.exists():
|
||||
for file_path in archive_dir.rglob("*.pt"):
|
||||
filename = file_path.name.lower()
|
||||
if any(pattern in filename for pattern in patterns):
|
||||
return file_path
|
||||
|
||||
# Search in backtest directory (might contain RL or other models)
|
||||
backtest_dir = self.legacy_checkpoints_dir / "backtest"
|
||||
if backtest_dir.exists():
|
||||
for file_path in backtest_dir.rglob("*.pt"):
|
||||
filename = file_path.name.lower()
|
||||
if any(pattern in filename for pattern in patterns):
|
||||
return file_path
|
||||
|
||||
# Last resort: search entire legacy directory
|
||||
for file_path in self.legacy_checkpoints_dir.rglob("*.pt"):
|
||||
filename = file_path.name.lower()
|
||||
if any(pattern in filename for pattern in patterns):
|
||||
return file_path
|
||||
|
||||
return None
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
"""Get storage statistics"""
|
||||
try:
|
||||
@@ -467,7 +620,7 @@ class ModelManager:
|
||||
'models': {}
|
||||
}
|
||||
|
||||
# Count files in different directories as "checkpoints"
|
||||
# Count files in new unified directories
|
||||
checkpoint_dirs = [
|
||||
self.checkpoints_dir / "cnn",
|
||||
self.checkpoints_dir / "dqn",
|
||||
@@ -511,6 +664,34 @@ class ModelManager:
|
||||
saved_size = sum(f.stat().st_size for f in saved_files)
|
||||
stats['total_size_mb'] += saved_size / (1024 * 1024)
|
||||
|
||||
# Add legacy checkpoint statistics
|
||||
if self.legacy_checkpoints_dir.exists():
|
||||
legacy_files = list(self.legacy_checkpoints_dir.rglob('*.pt'))
|
||||
if legacy_files:
|
||||
legacy_size = sum(f.stat().st_size for f in legacy_files)
|
||||
stats['total_checkpoints'] += len(legacy_files)
|
||||
stats['total_size_mb'] += legacy_size / (1024 * 1024)
|
||||
|
||||
# Add legacy models to stats
|
||||
legacy_model_dirs = ['cnn', 'dqn', 'rl', 'transformer', 'decision']
|
||||
for model_dir_name in legacy_model_dirs:
|
||||
model_dir = self.legacy_checkpoints_dir / model_dir_name
|
||||
if model_dir.exists():
|
||||
model_files = list(model_dir.rglob('*.pt'))
|
||||
if model_files and model_dir_name not in stats['models']:
|
||||
stats['total_models'] += 1
|
||||
model_size = sum(f.stat().st_size for f in model_files)
|
||||
latest_file = max(model_files, key=lambda f: f.stat().st_mtime)
|
||||
|
||||
stats['models'][model_dir_name] = {
|
||||
'checkpoint_count': len(model_files),
|
||||
'total_size_mb': model_size / (1024 * 1024),
|
||||
'best_performance': 0.0,
|
||||
'best_checkpoint_id': latest_file.name,
|
||||
'latest_checkpoint': latest_file.name,
|
||||
'location': 'legacy'
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
|
@@ -331,8 +331,39 @@ class ExtremaTrainer:
|
||||
|
||||
# Get all available price data for better extrema detection
|
||||
all_candles = list(self.context_data[symbol].candles)
|
||||
prices = [candle['close'] for candle in all_candles]
|
||||
timestamps = [candle['timestamp'] for candle in all_candles]
|
||||
prices = []
|
||||
timestamps = []
|
||||
|
||||
for i, candle in enumerate(all_candles):
|
||||
# Handle different candle formats
|
||||
if isinstance(candle, dict):
|
||||
if 'close' in candle:
|
||||
prices.append(candle['close'])
|
||||
else:
|
||||
# Fallback to other price fields
|
||||
price = candle.get('price') or candle.get('high') or candle.get('low') or candle.get('open') or 0
|
||||
prices.append(price)
|
||||
|
||||
# Handle timestamp with fallbacks
|
||||
if 'timestamp' in candle:
|
||||
timestamps.append(candle['timestamp'])
|
||||
elif 'time' in candle:
|
||||
timestamps.append(candle['time'])
|
||||
else:
|
||||
# Generate timestamp based on index if none available
|
||||
timestamps.append(datetime.now() - timedelta(minutes=len(all_candles) - i))
|
||||
else:
|
||||
# Handle non-dict candle formats (e.g., tuples, lists)
|
||||
if hasattr(candle, '__getitem__'):
|
||||
prices.append(float(candle[3])) # Assume OHLC format: [O, H, L, C]
|
||||
timestamps.append(datetime.now() - timedelta(minutes=len(all_candles) - i))
|
||||
else:
|
||||
# Skip invalid candle data
|
||||
continue
|
||||
|
||||
# Ensure we have enough data
|
||||
if len(prices) < self.window_size * 3:
|
||||
return detected
|
||||
|
||||
# Use a more sophisticated extrema detection algorithm
|
||||
window = self.window_size
|
||||
|
@@ -15,19 +15,41 @@ import asyncio
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Tuple, Union
|
||||
from dataclasses import dataclass, field
|
||||
from collections import deque
|
||||
import json
|
||||
|
||||
# Try to import optional dependencies
|
||||
try:
|
||||
import numpy as np
|
||||
HAS_NUMPY = True
|
||||
except ImportError:
|
||||
np = None
|
||||
HAS_NUMPY = False
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
HAS_PANDAS = True
|
||||
except ImportError:
|
||||
pd = None
|
||||
HAS_PANDAS = False
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
# Try to import PyTorch
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
HAS_TORCH = True
|
||||
except ImportError:
|
||||
torch = None
|
||||
nn = None
|
||||
optim = None
|
||||
HAS_TORCH = False
|
||||
|
||||
from .config import get_config
|
||||
from .data_provider import DataProvider
|
||||
@@ -198,6 +220,7 @@ class TradingOrchestrator:
|
||||
# Load historical data for models and RL training
|
||||
self._load_historical_data_for_models()
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _initialize_ml_models(self):
|
||||
"""Initialize ML models for enhanced trading"""
|
||||
try:
|
||||
@@ -227,7 +250,7 @@ class TradingOrchestrator:
|
||||
self.rl_agent.load_best_checkpoint() # This loads the state into the model
|
||||
# Check if we have checkpoints available
|
||||
from NN.training.model_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("dqn_agent")
|
||||
result = load_best_checkpoint("dqn")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
self.model_states['dqn']['initial_loss'] = getattr(metadata, 'initial_loss', None)
|
||||
@@ -267,17 +290,37 @@ class TradingOrchestrator:
|
||||
checkpoint_loaded = False
|
||||
try:
|
||||
from NN.training.model_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("enhanced_cnn")
|
||||
result = load_best_checkpoint("cnn")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
self.model_states['cnn']['initial_loss'] = 0.412
|
||||
self.model_states['cnn']['current_loss'] = metadata.loss or 0.0187
|
||||
self.model_states['cnn']['best_loss'] = metadata.loss or 0.0134
|
||||
self.model_states['cnn']['checkpoint_loaded'] = True
|
||||
self.model_states['cnn']['checkpoint_filename'] = metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "N/A"
|
||||
logger.info(f"CNN checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||
# Actually load the model weights from the checkpoint
|
||||
try:
|
||||
checkpoint_data = torch.load(file_path, map_location=self.device)
|
||||
if 'model_state_dict' in checkpoint_data:
|
||||
self.cnn_model.load_state_dict(checkpoint_data['model_state_dict'])
|
||||
logger.info(f"CNN model weights loaded from: {file_path}")
|
||||
elif 'state_dict' in checkpoint_data:
|
||||
self.cnn_model.load_state_dict(checkpoint_data['state_dict'])
|
||||
logger.info(f"CNN model weights loaded from: {file_path}")
|
||||
else:
|
||||
# Try loading directly as state dict
|
||||
self.cnn_model.load_state_dict(checkpoint_data)
|
||||
logger.info(f"CNN model weights loaded directly from: {file_path}")
|
||||
|
||||
# Update model states
|
||||
self.model_states['cnn']['initial_loss'] = checkpoint_data.get('initial_loss', 0.412)
|
||||
self.model_states['cnn']['current_loss'] = metadata.loss or checkpoint_data.get('loss', 0.0187)
|
||||
self.model_states['cnn']['best_loss'] = metadata.loss or checkpoint_data.get('best_loss', 0.0134)
|
||||
self.model_states['cnn']['checkpoint_loaded'] = True
|
||||
self.model_states['cnn']['checkpoint_filename'] = metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "N/A"
|
||||
logger.info(f"CNN checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||
except Exception as load_error:
|
||||
logger.warning(f"Failed to load CNN model weights: {load_error}")
|
||||
# Continue with fresh model but mark as loaded for metadata purposes
|
||||
self.model_states['cnn']['checkpoint_loaded'] = True
|
||||
checkpoint_loaded = True
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading CNN checkpoint: {e}")
|
||||
|
||||
@@ -347,58 +390,97 @@ class TradingOrchestrator:
|
||||
self.extrema_trainer = None
|
||||
|
||||
# Initialize COB RL Model - UNIFIED with ModelManager
|
||||
cob_rl_available = False
|
||||
try:
|
||||
from NN.models.cob_rl_model import COBRLModelInterface
|
||||
cob_rl_available = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"COB RL dependencies not available: {e}")
|
||||
cob_rl_available = False
|
||||
|
||||
# Initialize COB RL model using unified approach
|
||||
self.cob_rl_agent = COBRLModelInterface(
|
||||
model_checkpoint_dir="@checkpoints/cob_rl",
|
||||
device='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
)
|
||||
if cob_rl_available:
|
||||
try:
|
||||
# Initialize COB RL model using unified approach
|
||||
self.cob_rl_agent = COBRLModelInterface(
|
||||
model_checkpoint_dir="@checkpoints/cob_rl",
|
||||
device='cuda' if (HAS_TORCH and torch.cuda.is_available()) else 'cpu'
|
||||
)
|
||||
|
||||
# Add COB RL to model states tracking
|
||||
self.model_states['cob_rl'] = {
|
||||
'initial_loss': None,
|
||||
'current_loss': None,
|
||||
'best_loss': None,
|
||||
'checkpoint_loaded': False
|
||||
}
|
||||
|
||||
# Load best checkpoint using unified ModelManager
|
||||
checkpoint_loaded = False
|
||||
try:
|
||||
from NN.training.model_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("cob_rl")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
self.model_states['cob_rl']['initial_loss'] = getattr(metadata, 'loss', None)
|
||||
self.model_states['cob_rl']['current_loss'] = getattr(metadata, 'loss', None)
|
||||
self.model_states['cob_rl']['best_loss'] = getattr(metadata, 'loss', None)
|
||||
self.model_states['cob_rl']['checkpoint_loaded'] = True
|
||||
self.model_states['cob_rl']['checkpoint_filename'] = getattr(metadata, 'checkpoint_id', 'unknown')
|
||||
checkpoint_loaded = True
|
||||
loss_str = f"{getattr(metadata, 'loss', 'N/A'):.4f}" if getattr(metadata, 'loss', None) is not None else "N/A"
|
||||
logger.info(f"COB RL checkpoint loaded: {getattr(metadata, 'checkpoint_id', 'unknown')} (loss={loss_str})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading COB RL checkpoint: {e}")
|
||||
|
||||
if not checkpoint_loaded:
|
||||
# New model - no synthetic data, start fresh
|
||||
self.model_states['cob_rl']['initial_loss'] = None
|
||||
self.model_states['cob_rl']['current_loss'] = None
|
||||
self.model_states['cob_rl']['best_loss'] = None
|
||||
self.model_states['cob_rl']['checkpoint_filename'] = 'none (fresh start)'
|
||||
logger.info("COB RL starting fresh - no checkpoint found")
|
||||
|
||||
logger.info("COB RL Agent initialized and integrated with unified ModelManager")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing COB RL: {e}")
|
||||
self.cob_rl_agent = None
|
||||
cob_rl_available = False
|
||||
|
||||
if not cob_rl_available:
|
||||
# COB RL not available due to missing dependencies
|
||||
# Still try to load checkpoint metadata for display purposes
|
||||
logger.info("COB RL dependencies missing - attempting checkpoint metadata load only")
|
||||
|
||||
# Add COB RL to model states tracking
|
||||
self.model_states['cob_rl'] = {
|
||||
'initial_loss': None,
|
||||
'current_loss': None,
|
||||
'best_loss': None,
|
||||
'checkpoint_loaded': False
|
||||
'checkpoint_loaded': False,
|
||||
'checkpoint_filename': 'dependencies missing'
|
||||
}
|
||||
|
||||
# Load best checkpoint using unified ModelManager
|
||||
checkpoint_loaded = False
|
||||
# Try to load checkpoint metadata even without the model
|
||||
try:
|
||||
from NN.training.model_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("cob_rl_agent")
|
||||
result = load_best_checkpoint("cob_rl")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
self.model_states['cob_rl']['initial_loss'] = metadata.loss
|
||||
self.model_states['cob_rl']['current_loss'] = metadata.loss
|
||||
self.model_states['cob_rl']['best_loss'] = metadata.loss
|
||||
self.model_states['cob_rl']['checkpoint_loaded'] = True
|
||||
self.model_states['cob_rl']['checkpoint_filename'] = metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "N/A"
|
||||
logger.info(f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||
self.model_states['cob_rl']['checkpoint_filename'] = getattr(metadata, 'checkpoint_id', 'found')
|
||||
logger.info(f"COB RL checkpoint metadata loaded (model unavailable): {getattr(metadata, 'checkpoint_id', 'unknown')}")
|
||||
else:
|
||||
logger.info("No COB RL checkpoint found")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading COB RL checkpoint: {e}")
|
||||
logger.debug(f"Could not load COB RL checkpoint metadata: {e}")
|
||||
|
||||
if not checkpoint_loaded:
|
||||
# New model - no synthetic data, start fresh
|
||||
self.model_states['cob_rl']['initial_loss'] = None
|
||||
self.model_states['cob_rl']['current_loss'] = None
|
||||
self.model_states['cob_rl']['best_loss'] = None
|
||||
self.model_states['cob_rl']['checkpoint_filename'] = 'none (fresh start)'
|
||||
logger.info("COB RL starting fresh - no checkpoint found")
|
||||
|
||||
logger.info("COB RL Agent initialized and integrated with unified ModelManager")
|
||||
logger.info(" - Uses @checkpoints/ directory structure")
|
||||
logger.info(" - Follows same load/save/checkpoint flow as other models")
|
||||
logger.info(" - Integrated with enhanced real-time training system")
|
||||
|
||||
except ImportError as e:
|
||||
logger.warning(f"COB RL Model not available: {e}")
|
||||
self.cob_rl_agent = None
|
||||
|
||||
logger.info("COB RL initialization completed")
|
||||
logger.info(" - Uses @checkpoints/ directory structure")
|
||||
logger.info(" - Follows same load/save/checkpoint flow as other models")
|
||||
logger.info(" - Gracefully handles missing dependencies")
|
||||
|
||||
# Initialize TRANSFORMER Model
|
||||
try:
|
||||
from NN.models.advanced_transformer_trading import create_trading_transformer, TradingTransformerConfig
|
||||
@@ -531,6 +613,7 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error in extrema trainer prediction: {e}")
|
||||
return None
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_memory_usage(self) -> float:
|
||||
return 30.0 # MB
|
||||
|
||||
@@ -562,6 +645,7 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error in transformer prediction: {e}")
|
||||
return None
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_memory_usage(self) -> float:
|
||||
return 60.0 # MB estimate for transformer
|
||||
|
||||
@@ -588,6 +672,7 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error in decision model prediction: {e}")
|
||||
return None
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_memory_usage(self) -> float:
|
||||
return 40.0 # MB estimate for decision model
|
||||
|
||||
@@ -605,6 +690,7 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing ML models: {e}")
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def update_model_loss(self, model_name: str, current_loss: float, best_loss: float = None):
|
||||
"""Update model loss and potentially best loss"""
|
||||
if model_name in self.model_states:
|
||||
@@ -615,6 +701,7 @@ class TradingOrchestrator:
|
||||
self.model_states[model_name]['best_loss'] = current_loss
|
||||
logger.debug(f"Updated {model_name} loss: current={current_loss:.4f}, best={self.model_states[model_name]['best_loss']:.4f}")
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def checkpoint_saved(self, model_name: str, checkpoint_data: Dict[str, Any]):
|
||||
"""Callback when a model checkpoint is saved"""
|
||||
if model_name in self.model_states:
|
||||
@@ -628,6 +715,7 @@ class TradingOrchestrator:
|
||||
self.model_states[model_name]['best_loss'] = saved_loss
|
||||
logger.info(f"New best loss for {model_name}: {saved_loss:.4f}")
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_recent_predictions(self, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""Get recent predictions from all models for data streaming"""
|
||||
try:
|
||||
@@ -667,6 +755,7 @@ class TradingOrchestrator:
|
||||
logger.debug(f"Error getting recent predictions: {e}")
|
||||
return []
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def _save_orchestrator_state(self):
|
||||
"""Save the current state of the orchestrator, including model states."""
|
||||
state = {
|
||||
@@ -681,6 +770,7 @@ class TradingOrchestrator:
|
||||
json.dump(state, f, indent=4)
|
||||
logger.info(f"Orchestrator state saved to {save_path}")
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def _load_orchestrator_state(self):
|
||||
"""Load the orchestrator state from a saved file."""
|
||||
save_path = os.path.join(self.config.paths.get('checkpoint_dir', './models/saved'), 'orchestrator_state.json')
|
||||
@@ -716,6 +806,7 @@ class TradingOrchestrator:
|
||||
self.trade_loop_task = asyncio.create_task(self._trading_decision_loop())
|
||||
logger.info("Continuous trading loop initiated.")
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def _initialize_cob_integration(self):
|
||||
"""Initialize COB integration for real-time market microstructure data"""
|
||||
if COB_INTEGRATION_AVAILABLE:
|
||||
@@ -746,12 +837,14 @@ class TradingOrchestrator:
|
||||
else:
|
||||
logger.warning("COB Integration not initialized. Cannot start streaming.")
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def _start_cob_matrix_worker(self):
|
||||
"""Start a background worker to continuously update COB matrices for models"""
|
||||
if not self.cob_integration:
|
||||
logger.warning("COB Integration not available, cannot start COB matrix worker.")
|
||||
return
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def matrix_worker():
|
||||
logger.info("COB Matrix Worker started.")
|
||||
while self.realtime_processing:
|
||||
@@ -790,6 +883,7 @@ class TradingOrchestrator:
|
||||
matrix_thread = threading.Thread(target=matrix_worker, daemon=True)
|
||||
matrix_thread.start()
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def _update_cob_matrix_for_symbol(self, symbol: str):
|
||||
"""Updates the COB matrix and features for a specific symbol."""
|
||||
if not self.cob_integration:
|
||||
@@ -906,6 +1000,7 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error generating COB DQN features for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def _on_cob_cnn_features(self, symbol: str, cob_data: Dict):
|
||||
"""Callback for when new COB CNN features are available"""
|
||||
if not self.realtime_processing:
|
||||
@@ -923,6 +1018,7 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _on_cob_cnn_features for {symbol}: {e}")
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def _on_cob_dqn_features(self, symbol: str, cob_data: Dict):
|
||||
"""Callback for when new COB DQN features are available"""
|
||||
if not self.realtime_processing:
|
||||
@@ -940,6 +1036,7 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _on_cob_dqn_features for {symbol}: {e}")
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def _on_cob_dashboard_data(self, symbol: str, cob_data: Dict):
|
||||
"""Callback for when new COB data is available for the dashboard"""
|
||||
if not self.realtime_processing:
|
||||
@@ -952,20 +1049,24 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _on_cob_dashboard_data for {symbol}: {e}")
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_cob_features(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Get the latest COB features for CNN model"""
|
||||
return self.latest_cob_features.get(symbol)
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_cob_state(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Get the latest COB state for DQN model"""
|
||||
return self.latest_cob_state.get(symbol)
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def get_cob_snapshot(self, symbol: str) -> Optional[COBSnapshot]:
|
||||
"""Get the latest raw COB snapshot for a symbol"""
|
||||
if self.cob_integration:
|
||||
return self.cob_integration.get_latest_cob_snapshot(symbol)
|
||||
return None
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def get_cob_feature_matrix(self, symbol: str, sequence_length: int = 60) -> Optional[np.ndarray]:
|
||||
"""Get a sequence of COB CNN features for sequence models"""
|
||||
if symbol not in self.cob_feature_history or not self.cob_feature_history[symbol]:
|
||||
@@ -998,6 +1099,7 @@ class TradingOrchestrator:
|
||||
|
||||
# Weight normalization removed - handled by ModelManager
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def add_decision_callback(self, callback):
|
||||
"""Add a callback function to be called when decisions are made"""
|
||||
self.decision_callbacks.append(callback)
|
||||
@@ -1261,6 +1363,7 @@ class TradingOrchestrator:
|
||||
logger.debug(f"Error building RL state for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _get_cob_state(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Build COB state vector for COB RL agent"""
|
||||
try:
|
||||
@@ -1417,6 +1520,7 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error creating RL state for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _combine_predictions(self, symbol: str, price: float,
|
||||
predictions: List[Prediction],
|
||||
timestamp: datetime) -> TradingDecision:
|
||||
@@ -1532,6 +1636,7 @@ class TradingOrchestrator:
|
||||
current_position_pnl=0.0
|
||||
)
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _get_timeframe_weight(self, timeframe: str) -> float:
|
||||
"""Get importance weight for a timeframe"""
|
||||
# Higher timeframes get more weight in decision making
|
||||
@@ -1544,12 +1649,14 @@ class TradingOrchestrator:
|
||||
# Model performance and weight adaptation removed - handled by ModelManager
|
||||
# Use self.model_manager for all model performance tracking
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_recent_decisions(self, symbol: str, limit: int = 10) -> List[TradingDecision]:
|
||||
"""Get recent decisions for a symbol"""
|
||||
if symbol in self.recent_decisions:
|
||||
return self.recent_decisions[symbol][-limit:]
|
||||
return []
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_performance_metrics(self) -> Dict[str, Any]:
|
||||
"""Get performance metrics for the orchestrator"""
|
||||
return {
|
||||
@@ -1564,6 +1671,7 @@ class TradingOrchestrator:
|
||||
}
|
||||
}
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_model_states(self) -> Dict[str, Dict]:
|
||||
"""Get current model states with REAL checkpoint data - SSOT for dashboard"""
|
||||
try:
|
||||
@@ -1688,6 +1796,7 @@ class TradingOrchestrator:
|
||||
'extrema_trainer': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False}
|
||||
}
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _initialize_decision_fusion(self):
|
||||
"""Initialize the decision fusion neural network for learning model effectiveness"""
|
||||
try:
|
||||
@@ -1706,6 +1815,7 @@ class TradingOrchestrator:
|
||||
self.fc3 = nn.Linear(hidden_size, 3) # BUY, SELL, HOLD
|
||||
self.dropout = nn.Dropout(0.2)
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def forward(self, x):
|
||||
x = torch.relu(self.fc1(x))
|
||||
x = self.dropout(x)
|
||||
@@ -1720,6 +1830,7 @@ class TradingOrchestrator:
|
||||
logger.warning(f"Decision fusion initialization failed: {e}")
|
||||
self.decision_fusion_enabled = False
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _initialize_enhanced_training_system(self):
|
||||
"""Initialize the enhanced real-time training system"""
|
||||
try:
|
||||
@@ -1764,6 +1875,7 @@ class TradingOrchestrator:
|
||||
self.training_enabled = False
|
||||
self.enhanced_training_system = None
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def start_enhanced_training(self):
|
||||
"""Start the enhanced real-time training system"""
|
||||
try:
|
||||
@@ -1784,6 +1896,7 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error starting enhanced training: {e}")
|
||||
return False
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def stop_enhanced_training(self):
|
||||
"""Stop the enhanced real-time training system"""
|
||||
try:
|
||||
@@ -1797,6 +1910,7 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error stopping enhanced training: {e}")
|
||||
return False
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_enhanced_training_stats(self) -> Dict[str, Any]:
|
||||
"""Get enhanced training system statistics with orchestrator integration"""
|
||||
try:
|
||||
@@ -1893,6 +2007,7 @@ class TradingOrchestrator:
|
||||
'error': str(e)
|
||||
}
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def set_training_dashboard(self, dashboard):
|
||||
"""Set the dashboard reference for the training system"""
|
||||
try:
|
||||
@@ -1911,6 +2026,7 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error getting universal data stream: {e}")
|
||||
return None
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_universal_data_for_model(self, model_type: str = 'cnn') -> Optional[Dict[str, Any]]:
|
||||
"""Get formatted universal data for specific model types"""
|
||||
try:
|
||||
@@ -1953,6 +2069,7 @@ class TradingOrchestrator:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _calculate_aggressiveness_thresholds(self, current_pnl: float, symbol: str) -> tuple:
|
||||
"""Calculate confidence thresholds based on aggressiveness settings"""
|
||||
# Base thresholds
|
||||
@@ -1975,6 +2092,7 @@ class TradingOrchestrator:
|
||||
|
||||
return entry_threshold, exit_threshold
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _apply_pnl_feedback(self, action: str, confidence: float, current_pnl: float,
|
||||
symbol: str, reasoning: dict) -> tuple:
|
||||
"""Apply P&L-based feedback to decision making"""
|
||||
@@ -2008,6 +2126,7 @@ class TradingOrchestrator:
|
||||
logger.debug(f"Error applying P&L feedback: {e}")
|
||||
return action, confidence
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _calculate_dynamic_entry_aggressiveness(self, symbol: str) -> float:
|
||||
"""Calculate dynamic entry aggressiveness based on recent performance"""
|
||||
try:
|
||||
@@ -2036,6 +2155,7 @@ class TradingOrchestrator:
|
||||
logger.debug(f"Error calculating dynamic entry aggressiveness: {e}")
|
||||
return 0.5
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _calculate_dynamic_exit_aggressiveness(self, symbol: str, current_pnl: float) -> float:
|
||||
"""Calculate dynamic exit aggressiveness based on P&L and market conditions"""
|
||||
try:
|
||||
@@ -2058,11 +2178,13 @@ class TradingOrchestrator:
|
||||
logger.debug(f"Error calculating dynamic exit aggressiveness: {e}")
|
||||
return 0.5
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def set_trading_executor(self, trading_executor):
|
||||
"""Set the trading executor for position tracking"""
|
||||
self.trading_executor = trading_executor
|
||||
logger.info("Trading executor set for position tracking and P&L feedback")
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _get_current_price(self, symbol: str) -> float:
|
||||
"""Get current price for symbol"""
|
||||
try:
|
||||
@@ -2108,6 +2230,7 @@ class TradingOrchestrator:
|
||||
else:
|
||||
return 1000.0
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _generate_fallback_prediction(self, symbol: str) -> Dict[str, Any]:
|
||||
"""Generate fallback prediction when models fail"""
|
||||
try:
|
||||
@@ -2128,6 +2251,7 @@ class TradingOrchestrator:
|
||||
'model': 'fallback'
|
||||
}
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def capture_dqn_prediction(self, symbol: str, action_idx: int, confidence: float, price: float, q_values: List[float] = None):
|
||||
"""Capture DQN prediction for dashboard visualization"""
|
||||
try:
|
||||
@@ -2144,6 +2268,7 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.debug(f"Error capturing DQN prediction: {e}")
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def capture_cnn_prediction(self, symbol: str, direction: int, confidence: float, current_price: float, predicted_price: float):
|
||||
"""Capture CNN prediction for dashboard visualization"""
|
||||
try:
|
||||
@@ -2209,6 +2334,7 @@ class TradingOrchestrator:
|
||||
logger.warning(f"Data stream monitor initialization failed: {e}")
|
||||
self.data_stream_monitor = None
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def start_data_stream(self) -> bool:
|
||||
"""Start data streaming if not already active."""
|
||||
try:
|
||||
@@ -2221,6 +2347,7 @@ class TradingOrchestrator:
|
||||
logger.error(f"Failed to start data stream: {e}")
|
||||
return False
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def stop_data_stream(self) -> bool:
|
||||
"""Stop data streaming if active."""
|
||||
try:
|
||||
@@ -2231,6 +2358,7 @@ class TradingOrchestrator:
|
||||
logger.error(f"Failed to stop data stream: {e}")
|
||||
return False
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def get_data_stream_status(self) -> Dict[str, any]:
|
||||
"""Return current data stream status and buffer sizes."""
|
||||
status = {
|
||||
@@ -2249,6 +2377,7 @@ class TradingOrchestrator:
|
||||
pass
|
||||
return status
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def save_data_snapshot(self, filepath: str = None) -> str:
|
||||
"""Save a snapshot of current data stream buffers to a file.
|
||||
|
||||
@@ -2276,6 +2405,7 @@ class TradingOrchestrator:
|
||||
logger.error(f"Failed to save data snapshot: {e}")
|
||||
raise
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_stream_summary(self) -> Dict[str, any]:
|
||||
"""Get a summary of current data stream activity."""
|
||||
status = self.get_data_stream_status()
|
||||
@@ -2299,6 +2429,7 @@ class TradingOrchestrator:
|
||||
|
||||
return summary
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_cob_data(self, symbol: str, limit: int = 300) -> List:
|
||||
"""Get COB data for a symbol with specified limit."""
|
||||
try:
|
||||
@@ -2309,6 +2440,7 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error getting COB data: {e}")
|
||||
return []
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _load_historical_data_for_models(self):
|
||||
"""Load 300 historical candles for all required timeframes and symbols for model training"""
|
||||
logger.info("Loading 300 historical candles for model training and RL context...")
|
||||
@@ -2364,6 +2496,7 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in historical data loading: {e}")
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _initialize_models_with_historical_data(self, symbols_timeframes: List[Tuple[str, str]]):
|
||||
"""Initialize all NN models with historical data using data provider's normalized methods"""
|
||||
try:
|
||||
@@ -2397,6 +2530,7 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing models with historical data: {e}")
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _initialize_cnn_with_provider_data(self):
|
||||
"""Initialize CNN using data provider's normalized feature extraction"""
|
||||
try:
|
||||
@@ -2427,6 +2561,7 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing CNN with provider data: {e}")
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _initialize_dqn_with_provider_data(self, symbols_timeframes: List[Tuple[str, str]]):
|
||||
"""Initialize DQN using data provider's normalized state vector creation"""
|
||||
try:
|
||||
@@ -2444,6 +2579,7 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing DQN with provider data: {e}")
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _initialize_transformer_with_provider_data(self, symbols_timeframes: List[Tuple[str, str]]):
|
||||
"""Initialize Transformer using data provider's normalized sequence creation"""
|
||||
try:
|
||||
@@ -2461,6 +2597,7 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing Transformer with provider data: {e}")
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _initialize_decision_with_provider_data(self, symbol_features: Dict[str, Dict[str, pd.DataFrame]]):
|
||||
"""Initialize Decision Fusion using data provider's feature aggregation"""
|
||||
try:
|
||||
@@ -2490,6 +2627,7 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing Decision Fusion with provider data: {e}")
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def get_ohlcv_data(self, symbol: str, timeframe: str, limit: int = 300) -> List:
|
||||
"""Get OHLCV data for a symbol with specified timeframe and limit."""
|
||||
try:
|
||||
|
@@ -850,6 +850,115 @@ class TradingExecutor:
|
||||
"""Get trade history"""
|
||||
return self.trade_history.copy()
|
||||
|
||||
def export_trades_to_csv(self, filename: Optional[str] = None) -> str:
|
||||
"""Export trade history to CSV file with comprehensive analysis"""
|
||||
import csv
|
||||
from pathlib import Path
|
||||
|
||||
if not self.trade_history:
|
||||
logger.warning("No trades to export")
|
||||
return ""
|
||||
|
||||
# Generate filename if not provided
|
||||
if filename is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"trade_history_{timestamp}.csv"
|
||||
|
||||
# Ensure .csv extension
|
||||
if not filename.endswith('.csv'):
|
||||
filename += '.csv'
|
||||
|
||||
# Create trades directory if it doesn't exist
|
||||
trades_dir = Path("trades")
|
||||
trades_dir.mkdir(exist_ok=True)
|
||||
filepath = trades_dir / filename
|
||||
|
||||
try:
|
||||
with open(filepath, 'w', newline='', encoding='utf-8') as csvfile:
|
||||
fieldnames = [
|
||||
'symbol', 'side', 'quantity', 'entry_price', 'exit_price',
|
||||
'entry_time', 'exit_time', 'pnl', 'fees', 'confidence',
|
||||
'hold_time_seconds', 'hold_time_minutes', 'leverage',
|
||||
'pnl_percentage', 'net_pnl', 'profit_loss', 'trade_duration',
|
||||
'entry_hour', 'exit_hour', 'day_of_week'
|
||||
]
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
|
||||
total_pnl = 0
|
||||
winning_trades = 0
|
||||
losing_trades = 0
|
||||
|
||||
for trade in self.trade_history:
|
||||
# Calculate additional metrics
|
||||
pnl_percentage = (trade.pnl / trade.entry_price) * 100 if trade.entry_price != 0 else 0
|
||||
net_pnl = trade.pnl - trade.fees
|
||||
profit_loss = "PROFIT" if net_pnl > 0 else "LOSS"
|
||||
trade_duration = trade.exit_time - trade.entry_time
|
||||
hold_time_minutes = trade.hold_time_seconds / 60
|
||||
|
||||
# Track statistics
|
||||
total_pnl += net_pnl
|
||||
if net_pnl > 0:
|
||||
winning_trades += 1
|
||||
else:
|
||||
losing_trades += 1
|
||||
|
||||
writer.writerow({
|
||||
'symbol': trade.symbol,
|
||||
'side': trade.side,
|
||||
'quantity': trade.quantity,
|
||||
'entry_price': trade.entry_price,
|
||||
'exit_price': trade.exit_price,
|
||||
'entry_time': trade.entry_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'exit_time': trade.exit_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'pnl': trade.pnl,
|
||||
'fees': trade.fees,
|
||||
'confidence': trade.confidence,
|
||||
'hold_time_seconds': trade.hold_time_seconds,
|
||||
'hold_time_minutes': hold_time_minutes,
|
||||
'leverage': trade.leverage,
|
||||
'pnl_percentage': pnl_percentage,
|
||||
'net_pnl': net_pnl,
|
||||
'profit_loss': profit_loss,
|
||||
'trade_duration': str(trade_duration),
|
||||
'entry_hour': trade.entry_time.hour,
|
||||
'exit_hour': trade.exit_time.hour,
|
||||
'day_of_week': trade.entry_time.strftime('%A')
|
||||
})
|
||||
|
||||
# Create summary statistics file
|
||||
summary_filename = filename.replace('.csv', '_summary.txt')
|
||||
summary_filepath = trades_dir / summary_filename
|
||||
|
||||
total_trades = len(self.trade_history)
|
||||
win_rate = (winning_trades / total_trades * 100) if total_trades > 0 else 0
|
||||
avg_pnl = total_pnl / total_trades if total_trades > 0 else 0
|
||||
avg_hold_time = sum(t.hold_time_seconds for t in self.trade_history) / total_trades if total_trades > 0 else 0
|
||||
|
||||
with open(summary_filepath, 'w', encoding='utf-8') as f:
|
||||
f.write("TRADE ANALYSIS SUMMARY\n")
|
||||
f.write("=" * 50 + "\n")
|
||||
f.write(f"Total Trades: {total_trades}\n")
|
||||
f.write(f"Winning Trades: {winning_trades}\n")
|
||||
f.write(f"Losing Trades: {losing_trades}\n")
|
||||
f.write(f"Win Rate: {win_rate:.1f}%\n")
|
||||
f.write(f"Total P&L: ${total_pnl:.2f}\n")
|
||||
f.write(f"Average P&L per Trade: ${avg_pnl:.2f}\n")
|
||||
f.write(f"Average Hold Time: {avg_hold_time:.1f} seconds ({avg_hold_time/60:.1f} minutes)\n")
|
||||
f.write(f"Export Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"Data File: {filename}\n")
|
||||
|
||||
logger.info(f"📊 Trade history exported to: {filepath}")
|
||||
logger.info(f"📈 Trade summary saved to: {summary_filepath}")
|
||||
logger.info(f"📊 Total Trades: {total_trades} | Win Rate: {win_rate:.1f}% | Total P&L: ${total_pnl:.2f}")
|
||||
|
||||
return str(filepath)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error exporting trades to CSV: {e}")
|
||||
return ""
|
||||
|
||||
def get_daily_stats(self) -> Dict[str, Any]:
|
||||
"""Get daily trading statistics with enhanced fee analysis"""
|
||||
total_pnl = sum(trade.pnl for trade in self.trade_history)
|
||||
|
@@ -37,16 +37,23 @@ import traceback
|
||||
import gc
|
||||
import time
|
||||
import psutil
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
# Try to import torch
|
||||
try:
|
||||
import torch
|
||||
HAS_TORCH = True
|
||||
except ImportError:
|
||||
torch = None
|
||||
HAS_TORCH = False
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def clear_gpu_memory():
|
||||
"""Clear GPU memory cache"""
|
||||
if torch.cuda.is_available():
|
||||
if HAS_TORCH and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -272,7 +272,7 @@ class DashboardComponentManager:
|
||||
logger.error(f"Error formatting system status: {e}")
|
||||
return [html.P(f"Error: {str(e)}", className="text-danger small")]
|
||||
|
||||
def format_cob_data(self, cob_snapshot, symbol, cumulative_imbalance_stats=None, cob_mode="Unknown"):
|
||||
def format_cob_data(self, cob_snapshot, symbol, cumulative_imbalance_stats=None, cob_mode="Unknown", imbalance_ma_data=None):
|
||||
"""Format COB data into a split view with summary, imbalance stats, and a compact ladder."""
|
||||
try:
|
||||
if not cob_snapshot:
|
||||
@@ -317,7 +317,7 @@ class DashboardComponentManager:
|
||||
}
|
||||
|
||||
# --- Left Panel: Overview and Stats ---
|
||||
overview_panel = self._create_cob_overview_panel(symbol, stats, cumulative_imbalance_stats, cob_mode)
|
||||
overview_panel = self._create_cob_overview_panel(symbol, stats, cumulative_imbalance_stats, cob_mode, imbalance_ma_data)
|
||||
|
||||
# --- Right Panel: Compact Ladder ---
|
||||
ladder_panel = self._create_cob_ladder_panel(bids, asks, mid_price, symbol)
|
||||
@@ -331,7 +331,7 @@ class DashboardComponentManager:
|
||||
logger.error(f"Error formatting split COB data: {e}")
|
||||
return html.P(f"Error: {str(e)}", className="text-danger small")
|
||||
|
||||
def _create_cob_overview_panel(self, symbol, stats, cumulative_imbalance_stats, cob_mode="Unknown"):
|
||||
def _create_cob_overview_panel(self, symbol, stats, cumulative_imbalance_stats, cob_mode="Unknown", imbalance_ma_data=None):
|
||||
"""Creates the left panel with summary and imbalance stats."""
|
||||
mid_price = stats.get('mid_price', 0)
|
||||
spread_bps = stats.get('spread_bps', 0)
|
||||
@@ -373,6 +373,18 @@ class DashboardComponentManager:
|
||||
|
||||
html.Div(imbalance_stats_display),
|
||||
|
||||
# COB Imbalance Moving Averages
|
||||
html.Div([
|
||||
html.H6("Imbalance MAs", className="mt-3 mb-2 small text-muted text-uppercase"),
|
||||
*[
|
||||
html.Div([
|
||||
html.Strong(f"{timeframe}: ", className="small"),
|
||||
html.Span(f"MA {timeframe}: {ma_value:.3f}", className=f"small {'text-success' if ma_value > 0 else 'text-danger'}")
|
||||
], className="mb-1")
|
||||
for timeframe, ma_value in (imbalance_ma_data or {}).items()
|
||||
]
|
||||
]) if imbalance_ma_data else html.Div(),
|
||||
|
||||
html.Hr(className="my-2"),
|
||||
|
||||
html.Table([
|
||||
|
@@ -37,33 +37,37 @@ class DashboardLayoutManager:
|
||||
"🧠 Model Predictions & Performance Tracking"
|
||||
], className="text-light mb-3"),
|
||||
|
||||
# Summary cards row
|
||||
# Summary cards row - Enhanced with real metrics
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6("0", id="total-predictions-count", className="mb-0 text-primary"),
|
||||
html.Small("Total Predictions", className="text-light")
|
||||
], className="card-body text-center p-2 bg-dark")
|
||||
], className="card col-md-3 mx-1 bg-dark border-secondary"),
|
||||
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6("0", id="pending-predictions-count", className="mb-0 text-warning"),
|
||||
html.Small("Pending Resolution", className="text-light")
|
||||
html.Small("Recent Signals", className="text-light"),
|
||||
html.Small("", id="predictions-trend", className="d-block text-xs text-muted")
|
||||
], className="card-body text-center p-2 bg-dark")
|
||||
], className="card col-md-3 mx-1 bg-dark border-secondary"),
|
||||
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6("0", id="active-models-count", className="mb-0 text-info"),
|
||||
html.Small("Active Models", className="text-light")
|
||||
html.Small("Loaded Models", className="text-light"),
|
||||
html.Small("", id="models-status", className="d-block text-xs text-success")
|
||||
], className="card-body text-center p-2 bg-dark")
|
||||
], className="card col-md-3 mx-1 bg-dark border-secondary"),
|
||||
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6("0.0", id="total-rewards-sum", className="mb-0 text-success"),
|
||||
html.Small("Total Rewards", className="text-light")
|
||||
html.H6("0.00", id="avg-confidence", className="mb-0 text-warning"),
|
||||
html.Small("Avg Confidence", className="text-light"),
|
||||
html.Small("", id="confidence-trend", className="d-block text-xs text-muted")
|
||||
], className="card-body text-center p-2 bg-dark")
|
||||
], className="card col-md-3 mx-1 bg-dark border-secondary"),
|
||||
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6("+0.00", id="total-rewards-sum", className="mb-0 text-success"),
|
||||
html.Small("Total Rewards", className="text-light"),
|
||||
html.Small("", id="rewards-trend", className="d-block text-xs text-muted")
|
||||
], className="card-body text-center p-2 bg-dark")
|
||||
], className="card col-md-3 mx-1 bg-dark border-secondary")
|
||||
], className="row mb-3"),
|
||||
@@ -453,3 +457,4 @@ class DashboardLayoutManager:
|
||||
], className="d-flex")
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user