checkpoint manager
core/cnn_dashboard_integration.py (new file, 276 lines)
@@ -0,0 +1,276 @@
"""
CNN Dashboard Integration

This module integrates the EnhancedCNN model with the dashboard, providing real-time
training and visualization of model predictions.
"""

import logging
import threading
import time
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
import os
import json

# NOTE: EnhancedCNNAdapter is imported lazily inside _initialize_cnn_adapter to
# avoid a circular import between this module and enhanced_cnn_adapter.
from .data_models import BaseDataInput, ModelOutput, create_model_output
from utils.training_integration import get_training_integration

logger = logging.getLogger(__name__)

class CNNDashboardIntegration:
    """
    Integrates the EnhancedCNN model with the dashboard

    This class:
    1. Loads and initializes the CNN model
    2. Processes real-time data for model inference
    3. Manages continuous training of the model
    4. Provides visualization data for the dashboard
    """

    def __init__(self, data_provider=None, checkpoint_dir: str = "models/enhanced_cnn"):
        """
        Initialize the CNN dashboard integration

        Args:
            data_provider: Data provider instance
            checkpoint_dir: Directory to save checkpoints to
        """
        self.data_provider = data_provider
        self.checkpoint_dir = checkpoint_dir
        self.cnn_adapter = None
        self.training_thread = None
        self.training_active = False
        self.training_interval = 60  # Train every 60 seconds
        self.training_samples = []
        self.max_training_samples = 1000
        self.last_training_time = 0
        self.last_predictions = {}
        self.performance_metrics = {}
        self.model_name = "enhanced_cnn_v1"

        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Initialize CNN adapter
        self._initialize_cnn_adapter()

        logger.info(f"CNNDashboardIntegration initialized with checkpoint_dir: {checkpoint_dir}")

    def _initialize_cnn_adapter(self):
        """Initialize the CNN adapter"""
        try:
            # Import here to avoid circular imports
            from .enhanced_cnn_adapter import EnhancedCNNAdapter

            # Create CNN adapter
            self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=self.checkpoint_dir)

            # Load best checkpoint if available
            self.cnn_adapter.load_best_checkpoint()

            logger.info("CNN adapter initialized successfully")

        except Exception as e:
            logger.error(f"Error initializing CNN adapter: {e}")
            self.cnn_adapter = None

    def start_training_thread(self):
        """Start the training thread"""
        if self.training_thread is not None and self.training_thread.is_alive():
            logger.info("Training thread already running")
            return

        self.training_active = True
        self.training_thread = threading.Thread(target=self._training_loop, daemon=True)
        self.training_thread.start()

        logger.info("CNN training thread started")

    def stop_training_thread(self):
        """Stop the training thread"""
        self.training_active = False
        if self.training_thread is not None:
            self.training_thread.join(timeout=5)
            self.training_thread = None

        logger.info("CNN training thread stopped")

    def _training_loop(self):
        """Training loop for continuous model training"""
        while self.training_active:
            try:
                # Check if it's time to train
                current_time = time.time()
                if current_time - self.last_training_time >= self.training_interval and len(self.training_samples) >= 10:
                    logger.info(f"Training CNN model with {len(self.training_samples)} samples")

                    # Train model
                    if self.cnn_adapter is not None:
                        metrics = self.cnn_adapter.train(epochs=1)

                        # Update performance metrics
                        self.performance_metrics = {
                            'loss': metrics.get('loss', 0.0),
                            'accuracy': metrics.get('accuracy', 0.0),
                            'samples': metrics.get('samples', 0),
                            'last_training': datetime.now().isoformat()
                        }

                        # Log training metrics
                        logger.info(f"CNN training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}")

                    # Update last training time
                    self.last_training_time = current_time

                # Sleep to avoid high CPU usage
                time.sleep(1)

            except Exception as e:
                logger.error(f"Error in CNN training loop: {e}")
                time.sleep(5)  # Sleep longer on error

    def process_data(self, symbol: str, base_data: BaseDataInput) -> Optional[ModelOutput]:
        """
        Process data for model inference and training

        Args:
            symbol: Trading symbol
            base_data: Standardized input data

        Returns:
            Optional[ModelOutput]: Model output, or None if processing failed
        """
        try:
            if self.cnn_adapter is None:
                logger.warning("CNN adapter not initialized")
                return None

            # Make prediction
            model_output = self.cnn_adapter.predict(base_data)

            # Store prediction
            self.last_predictions[symbol] = model_output

            # Store model output in data provider
            if self.data_provider is not None:
                self.data_provider.store_model_output(model_output)

            return model_output

        except Exception as e:
            logger.error(f"Error processing data for CNN model: {e}")
            return None

    def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float):
        """
        Add a training sample

        Args:
            base_data: Standardized input data
            actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
            reward: Reward received for the action
        """
        try:
            if self.cnn_adapter is None:
                logger.warning("CNN adapter not initialized")
                return

            # Add training sample to CNN adapter
            self.cnn_adapter.add_training_sample(base_data, actual_action, reward)

            # Add to local training samples
            self.training_samples.append((base_data.symbol, actual_action, reward))

            # Limit training samples
            if len(self.training_samples) > self.max_training_samples:
                self.training_samples = self.training_samples[-self.max_training_samples:]

            logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")

        except Exception as e:
            logger.error(f"Error adding training sample: {e}")

    def get_performance_metrics(self) -> Dict[str, Any]:
        """
        Get performance metrics

        Returns:
            Dict[str, Any]: Performance metrics
        """
        metrics = self.performance_metrics.copy()

        # Add additional metrics
        metrics['training_samples'] = len(self.training_samples)
        metrics['model_name'] = self.model_name

        # Add last prediction metrics
        if self.last_predictions:
            for symbol, prediction in self.last_predictions.items():
                metrics[f'{symbol}_last_action'] = prediction.predictions.get('action', 'UNKNOWN')
                metrics[f'{symbol}_last_confidence'] = prediction.confidence

        return metrics

    def get_visualization_data(self, symbol: str) -> Dict[str, Any]:
        """
        Get visualization data for the dashboard

        Args:
            symbol: Trading symbol

        Returns:
            Dict[str, Any]: Visualization data
        """
        data = {
            'model_name': self.model_name,
            'symbol': symbol,
            'timestamp': datetime.now().isoformat(),
            'performance_metrics': self.get_performance_metrics()
        }

        # Add last prediction
        if symbol in self.last_predictions:
            prediction = self.last_predictions[symbol]
            data['last_prediction'] = {
                'action': prediction.predictions.get('action', 'UNKNOWN'),
                'confidence': prediction.confidence,
                'timestamp': prediction.timestamp.isoformat(),
                'buy_probability': prediction.predictions.get('buy_probability', 0.0),
                'sell_probability': prediction.predictions.get('sell_probability', 0.0),
                'hold_probability': prediction.predictions.get('hold_probability', 0.0)
            }

        # Add training samples summary
        symbol_samples = [s for s in self.training_samples if s[0] == symbol]
        data['training_samples'] = {
            'total': len(symbol_samples),
            'buy': len([s for s in symbol_samples if s[1] == 'BUY']),
            'sell': len([s for s in symbol_samples if s[1] == 'SELL']),
            'hold': len([s for s in symbol_samples if s[1] == 'HOLD']),
            'avg_reward': sum(s[2] for s in symbol_samples) / len(symbol_samples) if symbol_samples else 0.0
        }

        return data

# Global CNN dashboard integration instance
_cnn_dashboard_integration = None

def get_cnn_dashboard_integration(data_provider=None) -> CNNDashboardIntegration:
    """
    Get the global CNN dashboard integration instance

    Args:
        data_provider: Data provider instance

    Returns:
        CNNDashboardIntegration: Global CNN dashboard integration instance
    """
    global _cnn_dashboard_integration

    if _cnn_dashboard_integration is None:
        _cnn_dashboard_integration = CNNDashboardIntegration(data_provider=data_provider)

    return _cnn_dashboard_integration
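A minimal wiring sketch for the module above. This is illustrative, not part of the commit: `provider` and its `build_base_data_input` helper are assumed names; the commit only requires that the data provider expose store_model_output() and that the caller supply a BaseDataInput.

from core.cnn_dashboard_integration import get_cnn_dashboard_integration

integration = get_cnn_dashboard_integration(data_provider=provider)  # provider is an assumed object
integration.start_training_thread()

# On each dashboard tick: run inference, then feed the realized outcome back
base_data = provider.build_base_data_input('ETH/USDT')  # hypothetical helper
output = integration.process_data('ETH/USDT', base_data)
if output is not None:
    print(output.predictions.get('action'), f"{output.confidence:.2f}")

# Once the trade outcome is known, close the loop for continuous training
integration.add_training_sample(base_data, actual_action='BUY', reward=0.75)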
@@ -1467,12 +1467,10 @@ class DataProvider:
                 # Update COB data cache for distribution
                 binance_symbol = symbol.replace('/', '').upper()
                 if binance_symbol not in self.cob_data_cache or self.cob_data_cache[binance_symbol] is None:
-                    from collections import deque
                     self.cob_data_cache[binance_symbol] = deque(maxlen=300)
 
                 # Ensure the deque is properly initialized
                 if not isinstance(self.cob_data_cache[binance_symbol], deque):
-                    from collections import deque
                     self.cob_data_cache[binance_symbol] = deque(maxlen=300)
 
                 self.cob_data_cache[binance_symbol].append({
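The cache above relies on deque(maxlen=300) to keep only the most recent 300 COB snapshots per symbol: once full, each append silently evicts the oldest entry, so no manual trimming is needed. A small standalone illustration:

from collections import deque

cob_cache = deque(maxlen=300)
for i in range(305):
    cob_cache.append({'snapshot': i})

print(len(cob_cache))  # 300 -- bounded; the five oldest entries were dropped
print(cob_cache[0])    # {'snapshot': 5}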
core/enhanced_cnn_adapter.py (new file, 430 lines)
@@ -0,0 +1,430 @@
"""
Enhanced CNN Adapter for Standardized Input Format

This module provides an adapter for the EnhancedCNN model to work with the standardized
BaseDataInput format, enabling seamless integration with the multi-modal trading system.
"""

import torch
import numpy as np
import logging
import os
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any, Union
from threading import Lock

from .data_models import BaseDataInput, ModelOutput, create_model_output
from NN.models.enhanced_cnn import EnhancedCNN

logger = logging.getLogger(__name__)

class EnhancedCNNAdapter:
    """
    Adapter for EnhancedCNN model to work with standardized BaseDataInput format

    This adapter:
    1. Converts BaseDataInput to the format expected by EnhancedCNN
    2. Processes model outputs to create standardized ModelOutput
    3. Manages model training with collected data
    4. Handles checkpoint management
    """

    def __init__(self, model_path: str = None, checkpoint_dir: str = "models/enhanced_cnn"):
        """
        Initialize the EnhancedCNN adapter

        Args:
            model_path: Path to load model from, if None a new model is created
            checkpoint_dir: Directory to save checkpoints to
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.model_path = model_path
        self.checkpoint_dir = checkpoint_dir
        self.training_lock = Lock()
        self.training_data = []
        self.max_training_samples = 10000
        self.batch_size = 32
        self.learning_rate = 0.0001
        self.model_name = "enhanced_cnn_v1"

        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Initialize model
        self._initialize_model()

        logger.info(f"EnhancedCNNAdapter initialized with device: {self.device}")

    def _initialize_model(self):
        """Initialize the EnhancedCNN model"""
        try:
            # Calculate input shape based on BaseDataInput structure
            # OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
            # BTC OHLCV: 300 frames x 5 features = 1500 features
            # COB: ±20 buckets x 4 metrics = 160 features
            # MA: 4 timeframes x 10 buckets = 40 features
            # Technical indicators: 100 features
            # Last predictions: 50 features
            # Total: 7850 features
            input_shape = 7850
            n_actions = 3  # BUY, SELL, HOLD

            # Create model
            self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
            self.model.to(self.device)

            # Load model if path is provided
            if self.model_path:
                success = self.model.load(self.model_path)
                if success:
                    logger.info(f"Model loaded from {self.model_path}")
                else:
                    logger.warning(f"Failed to load model from {self.model_path}, using new model")
            else:
                logger.info("No model path provided, using new model")

        except Exception as e:
            logger.error(f"Error initializing EnhancedCNN model: {e}")
            raise

    def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor:
        """
        Convert BaseDataInput to feature vector for EnhancedCNN

        Args:
            base_data: Standardized input data

        Returns:
            torch.Tensor: Feature vector for EnhancedCNN
        """
        try:
            # Use the get_feature_vector method from BaseDataInput
            features = base_data.get_feature_vector()

            # Convert to torch tensor
            features_tensor = torch.tensor(features, dtype=torch.float32, device=self.device)

            return features_tensor

        except Exception as e:
            logger.error(f"Error converting BaseDataInput to features: {e}")
            # Return empty tensor with correct shape
            return torch.zeros(7850, dtype=torch.float32, device=self.device)

    def predict(self, base_data: BaseDataInput) -> ModelOutput:
        """
        Make a prediction using the EnhancedCNN model

        Args:
            base_data: Standardized input data

        Returns:
            ModelOutput: Standardized model output
        """
        try:
            # Convert BaseDataInput to features
            features = self._convert_base_data_to_features(base_data)

            # Ensure features has batch dimension
            if features.dim() == 1:
                features = features.unsqueeze(0)

            # Set model to evaluation mode
            self.model.eval()

            # Make prediction
            with torch.no_grad():
                q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features)

            # Get action and confidence
            action_probs = torch.softmax(q_values, dim=1)
            action_idx = torch.argmax(action_probs, dim=1).item()
            confidence = float(action_probs[0, action_idx].item())

            # Map action index to action string
            actions = ['BUY', 'SELL', 'HOLD']
            action = actions[action_idx]

            # Create predictions dictionary
            predictions = {
                'action': action,
                'buy_probability': float(action_probs[0, 0].item()),
                'sell_probability': float(action_probs[0, 1].item()),
                'hold_probability': float(action_probs[0, 2].item()),
                'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
                'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist()
            }

            # Create hidden states dictionary
            hidden_states = {
                'features': features_refined.squeeze(0).cpu().numpy().tolist()
            }

            # Create metadata dictionary
            metadata = {
                'model_version': '1.0',
                'timestamp': datetime.now().isoformat(),
                'input_shape': features.shape
            }

            # Create ModelOutput
            model_output = ModelOutput(
                model_type='cnn',
                model_name=self.model_name,
                symbol=base_data.symbol,
                timestamp=datetime.now(),
                confidence=confidence,
                predictions=predictions,
                hidden_states=hidden_states,
                metadata=metadata
            )

            return model_output

        except Exception as e:
            logger.error(f"Error making prediction with EnhancedCNN: {e}")
            # Return default ModelOutput
            return create_model_output(
                model_type='cnn',
                model_name=self.model_name,
                symbol=base_data.symbol,
                action='HOLD',
                confidence=0.0
            )

    def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float):
        """
        Add a training sample to the training data

        Args:
            base_data: Standardized input data
            actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
            reward: Reward received for the action
        """
        try:
            # Convert BaseDataInput to features
            features = self._convert_base_data_to_features(base_data)

            # Convert action to index
            actions = ['BUY', 'SELL', 'HOLD']
            action_idx = actions.index(actual_action)

            # Add to training data
            with self.training_lock:
                self.training_data.append((features, action_idx, reward))

                # Limit training data size
                if len(self.training_data) > self.max_training_samples:
                    # Sort by reward (highest first) and keep top samples
                    self.training_data.sort(key=lambda x: x[2], reverse=True)
                    self.training_data = self.training_data[:self.max_training_samples]

            logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")

        except Exception as e:
            logger.error(f"Error adding training sample: {e}")

    def train(self, epochs: int = 1) -> Dict[str, float]:
        """
        Train the model with collected data

        Args:
            epochs: Number of epochs to train for

        Returns:
            Dict[str, float]: Training metrics
        """
        try:
            with self.training_lock:
                # Check if we have enough data
                if len(self.training_data) < self.batch_size:
                    logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
                    return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}

                # Set model to training mode
                self.model.train()

                # Create optimizer
                optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

                # Training metrics
                total_loss = 0.0
                correct_predictions = 0
                total_predictions = 0
                num_batches = 0  # count processed batches so the loss average is exact

                # Train for specified number of epochs
                for epoch in range(epochs):
                    # Shuffle training data
                    np.random.shuffle(self.training_data)

                    # Process in batches
                    for i in range(0, len(self.training_data), self.batch_size):
                        batch = self.training_data[i:i+self.batch_size]

                        # Skip if batch is too small
                        if len(batch) < 2:
                            continue

                        # Prepare batch
                        features = torch.stack([sample[0] for sample in batch])
                        actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
                        rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)

                        # Zero gradients
                        optimizer.zero_grad()

                        # Forward pass
                        q_values, _, _, _, _ = self.model(features)

                        # Calculate loss (cross-entropy with reward weighting)
                        # First, apply softmax to get probabilities
                        probs = torch.softmax(q_values, dim=1)

                        # Get probability of chosen action
                        chosen_probs = probs[torch.arange(len(actions)), actions]

                        # Calculate negative log likelihood loss
                        nll_loss = -torch.log(chosen_probs + 1e-10)

                        # Weight by reward (higher reward = higher weight)
                        # Normalize rewards to [0, 1] range
                        min_reward = rewards.min()
                        max_reward = rewards.max()
                        if max_reward > min_reward:
                            normalized_rewards = (rewards - min_reward) / (max_reward - min_reward)
                        else:
                            normalized_rewards = torch.ones_like(rewards)

                        # Apply reward weighting (higher reward = higher weight)
                        weighted_loss = nll_loss * (normalized_rewards + 0.1)  # Add small constant to avoid zero weights

                        # Mean loss
                        loss = weighted_loss.mean()

                        # Backward pass
                        loss.backward()

                        # Update weights
                        optimizer.step()

                        # Update metrics
                        total_loss += loss.item()
                        num_batches += 1

                        # Calculate accuracy
                        predicted_actions = torch.argmax(q_values, dim=1)
                        correct_predictions += (predicted_actions == actions).sum().item()
                        total_predictions += len(actions)

                # Calculate final metrics
                # (average over the batches actually processed, across all epochs)
                avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
                accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

                # Save checkpoint
                self._save_checkpoint(avg_loss, accuracy)

                logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}")

                return {
                    'loss': avg_loss,
                    'accuracy': accuracy,
                    'samples': len(self.training_data)
                }

        except Exception as e:
            logger.error(f"Error training model: {e}")
            return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)}

    def _save_checkpoint(self, loss: float, accuracy: float):
        """
        Save model checkpoint

        Args:
            loss: Training loss
            accuracy: Training accuracy
        """
        try:
            # Import checkpoint manager
            from utils.checkpoint_manager import CheckpointManager

            # Create checkpoint manager
            checkpoint_manager = CheckpointManager(
                checkpoint_dir=self.checkpoint_dir,
                max_checkpoints=10,
                metric_name="accuracy"
            )

            # Create temporary model file
            temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp")
            self.model.save(temp_path)

            # Create metrics
            metrics = {
                'loss': loss,
                'accuracy': accuracy,
                'samples': len(self.training_data)
            }

            # Create metadata
            metadata = {
                'timestamp': datetime.now().isoformat(),
                'model_name': self.model_name,
                'input_shape': self.model.input_shape,
                'n_actions': self.model.n_actions
            }

            # Save checkpoint
            checkpoint_path = checkpoint_manager.save_checkpoint(
                model_name=self.model_name,
                model_path=f"{temp_path}.pt",
                metrics=metrics,
                metadata=metadata
            )

            # Delete temporary model file
            if os.path.exists(f"{temp_path}.pt"):
                os.remove(f"{temp_path}.pt")

            logger.info(f"Model checkpoint saved to {checkpoint_path}")

        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")

    def load_best_checkpoint(self):
        """Load the best checkpoint based on accuracy"""
        try:
            # Import checkpoint manager
            from utils.checkpoint_manager import CheckpointManager

            # Create checkpoint manager
            checkpoint_manager = CheckpointManager(
                checkpoint_dir=self.checkpoint_dir,
                max_checkpoints=10,
                metric_name="accuracy"
            )

            # Load best checkpoint
            best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name)

            if not best_checkpoint_path:
                logger.info("No checkpoints found")
                return False

            # Load model
            success = self.model.load(best_checkpoint_path)

            if success:
                logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")

                # Log metrics
                metrics = best_checkpoint_metadata.get('metrics', {})
                logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")

                return True
            else:
                logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
                return False

        except Exception as e:
            logger.error(f"Error loading best checkpoint: {e}")
            return False
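The adapter's only contract with BaseDataInput is that get_feature_vector() returns 7850 floats and that a symbol attribute exists. A standalone sketch with a stub input (the stub is illustrative and not part of this commit; the real class lives in core.data_models):

import numpy as np
from core.enhanced_cnn_adapter import EnhancedCNNAdapter

class StubBaseDataInput:
    """Illustrative stand-in for core.data_models.BaseDataInput."""
    symbol = 'ETH/USDT'
    def get_feature_vector(self):
        return np.zeros(7850, dtype=np.float32)  # must match the adapter's 7850-feature layout

adapter = EnhancedCNNAdapter(checkpoint_dir='models/enhanced_cnn')
output = adapter.predict(StubBaseDataInput())
print(output.predictions['action'], f"{output.confidence:.2f}")

adapter.add_training_sample(StubBaseDataInput(), 'HOLD', reward=0.1)
metrics = adapter.train(epochs=1)  # with fewer than batch_size (32) samples this just reports the sample count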
@@ -1,34 +1,31 @@
 """
 Model Output Manager
 
-This module provides extensible model output storage and management for the multi-modal trading system.
-Supports CNN, RL, LSTM, Transformer, and future model types with cross-model feeding capabilities.
+This module provides a centralized storage and management system for model outputs,
+enabling cross-model feeding and evaluation.
 """
 
-import logging
 import os
 import json
-import pickle
-from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Any, Union
-from collections import deque, defaultdict
+import logging
+import time
+from datetime import datetime
+from typing import Dict, List, Optional, Any
 from threading import Lock
-from pathlib import Path
 
-from .data_models import ModelOutput, create_model_output
+from .data_models import ModelOutput
 
 logger = logging.getLogger(__name__)
 
 class ModelOutputManager:
     """
-    Extensible model output storage and management system
+    Centralized storage and management system for model outputs
 
-    Features:
-    - Standardized ModelOutput storage for all model types
-    - Cross-model feeding with hidden states
-    - Historical output tracking
-    - Metadata management
-    - Persistence and recovery
-    - Performance analytics
+    This class:
+    1. Stores model outputs for all models
+    2. Provides access to current and historical outputs
+    3. Handles persistence of outputs to disk
+    4. Supports evaluation of model performance
     """
 
     def __init__(self, cache_dir: str = "cache/model_outputs", max_history: int = 1000):
@@ -36,279 +33,226 @@ class ModelOutputManager:
         Initialize the model output manager
 
         Args:
-            cache_dir: Directory for persistent storage
-            max_history: Maximum number of outputs to keep in memory per model
+            cache_dir: Directory to store model outputs
+            max_history: Maximum number of historical outputs to keep per model
         """
-        self.cache_dir = Path(cache_dir)
-        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        self.cache_dir = cache_dir
         self.max_history = max_history
+        self.outputs_lock = Lock()
 
-        # In-memory storage
-        self.current_outputs: Dict[str, Dict[str, ModelOutput]] = defaultdict(dict)  # {symbol: {model_name: ModelOutput}}
-        self.output_history: Dict[str, Dict[str, deque]] = defaultdict(lambda: defaultdict(lambda: deque(maxlen=max_history)))  # {symbol: {model_name: deque}}
-        self.cross_model_states: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict))  # {symbol: {model_name: hidden_states}}
+        # Current outputs for each model and symbol
+        # {symbol: {model_name: ModelOutput}}
+        self.current_outputs: Dict[str, Dict[str, ModelOutput]] = {}
 
-        # Metadata tracking
-        self.model_metadata: Dict[str, Dict[str, Any]] = defaultdict(dict)  # {model_name: metadata}
-        self.performance_stats: Dict[str, Dict[str, Any]] = defaultdict(lambda: defaultdict(dict))  # {symbol: {model_name: stats}}
+        # Historical outputs for each model and symbol
+        # {symbol: {model_name: List[ModelOutput]}}
+        self.historical_outputs: Dict[str, Dict[str, List[ModelOutput]]] = {}
 
-        # Thread safety
-        self.storage_lock = Lock()
+        # Performance metrics for each model and symbol
+        # {symbol: {model_name: Dict[str, float]}}
+        self.performance_metrics: Dict[str, Dict[str, Dict[str, float]]] = {}
 
-        # Supported model types
-        self.supported_model_types = {
-            'cnn', 'rl', 'lstm', 'transformer', 'orchestrator',
-            'ensemble', 'hybrid', 'custom'  # Extensible for future types
-        }
+        # Create cache directory if it doesn't exist
+        os.makedirs(cache_dir, exist_ok=True)
 
-        logger.info(f"ModelOutputManager initialized with cache dir: {self.cache_dir}")
-        logger.info(f"Supported model types: {self.supported_model_types}")
+        logger.info(f"ModelOutputManager initialized with cache_dir: {cache_dir}")
 
     def store_output(self, model_output: ModelOutput) -> bool:
         """
-        Store model output with full extensibility support
+        Store a model output
 
         Args:
-            model_output: ModelOutput from any model type
+            model_output: Model output to store
 
         Returns:
-            bool: True if stored successfully, False otherwise
+            bool: True if successful, False otherwise
         """
         try:
-            with self.storage_lock:
-                symbol = model_output.symbol
-                model_name = model_output.model_name
-                model_type = model_output.model_type
-
-                # Validate model type (extensible)
-                if model_type not in self.supported_model_types:
-                    logger.warning(f"Unknown model type '{model_type}' - adding to supported types")
-                    self.supported_model_types.add(model_type)
+            symbol = model_output.symbol
+            model_name = model_output.model_name
 
+            with self.outputs_lock:
+                # Initialize dictionaries if they don't exist
+                if symbol not in self.current_outputs:
+                    self.current_outputs[symbol] = {}
+                if symbol not in self.historical_outputs:
+                    self.historical_outputs[symbol] = {}
+                if model_name not in self.historical_outputs[symbol]:
+                    self.historical_outputs[symbol][model_name] = []
 
                 # Store current output
                 self.current_outputs[symbol][model_name] = model_output
 
-                # Add to history
-                self.output_history[symbol][model_name].append(model_output)
-
-                # Store cross-model states if available
-                if model_output.hidden_states:
-                    self.cross_model_states[symbol][model_name] = model_output.hidden_states
-
-                # Update model metadata
-                self._update_model_metadata(model_name, model_type, model_output.metadata)
-
-                # Update performance statistics
-                self._update_performance_stats(symbol, model_name, model_output)
-
-                # Persist to disk (async to avoid blocking)
-                self._persist_output_async(model_output)
-
-                logger.debug(f"Stored output from {model_name} ({model_type}) for {symbol}")
-                return True
+                # Add to historical outputs
+                self.historical_outputs[symbol][model_name].append(model_output)
+
+                # Limit historical outputs
+                if len(self.historical_outputs[symbol][model_name]) > self.max_history:
+                    self.historical_outputs[symbol][model_name] = self.historical_outputs[symbol][model_name][-self.max_history:]
+
+            # Persist output to disk
+            self._persist_output(model_output)
+
+            return True
 
         except Exception as e:
             logger.error(f"Error storing model output: {e}")
             return False
 
     def get_current_output(self, symbol: str, model_name: str) -> Optional[ModelOutput]:
         """
-        Get the current (latest) output from a specific model
+        Get the current output for a model and symbol
 
         Args:
-            symbol: Trading symbol
-            model_name: Name of the model
+            symbol: Symbol to get output for
+            model_name: Model name to get output for
 
         Returns:
-            ModelOutput: Latest output from the model, or None if not available
+            ModelOutput: Current output, or None if not available
         """
         try:
-            return self.current_outputs.get(symbol, {}).get(model_name)
+            with self.outputs_lock:
+                if symbol in self.current_outputs and model_name in self.current_outputs[symbol]:
+                    return self.current_outputs[symbol][model_name]
+                return None
 
         except Exception as e:
-            logger.error(f"Error getting current output for {model_name}: {e}")
+            logger.error(f"Error getting current output: {e}")
             return None
 
     def get_all_current_outputs(self, symbol: str) -> Dict[str, ModelOutput]:
         """
-        Get all current outputs for a symbol (for cross-model feeding)
+        Get all current outputs for a symbol
 
         Args:
-            symbol: Trading symbol
+            symbol: Symbol to get outputs for
 
         Returns:
-            Dict[str, ModelOutput]: Dictionary of current outputs by model name
+            Dict[str, ModelOutput]: Dictionary of model name to output
         """
         try:
-            return dict(self.current_outputs.get(symbol, {}))
+            with self.outputs_lock:
+                if symbol in self.current_outputs:
+                    return self.current_outputs[symbol].copy()
+                return {}
 
         except Exception as e:
-            logger.error(f"Error getting all current outputs for {symbol}: {e}")
+            logger.error(f"Error getting all current outputs: {e}")
             return {}
 
-    def get_output_history(self, symbol: str, model_name: str, count: int = 10) -> List[ModelOutput]:
+    def get_historical_outputs(self, symbol: str, model_name: str, limit: int = None) -> List[ModelOutput]:
         """
-        Get historical outputs from a model
+        Get historical outputs for a model and symbol
 
         Args:
-            symbol: Trading symbol
-            model_name: Name of the model
-            count: Number of historical outputs to retrieve
+            symbol: Symbol to get outputs for
+            model_name: Model name to get outputs for
+            limit: Maximum number of outputs to return, None for all
 
         Returns:
-            List[ModelOutput]: List of historical outputs (most recent first)
+            List[ModelOutput]: List of historical outputs
         """
         try:
-            history = self.output_history.get(symbol, {}).get(model_name, deque())
-            return list(history)[-count:][::-1]  # Most recent first
+            with self.outputs_lock:
+                if symbol in self.historical_outputs and model_name in self.historical_outputs[symbol]:
+                    outputs = self.historical_outputs[symbol][model_name]
+                    if limit is not None:
+                        outputs = outputs[-limit:]
+                    return outputs.copy()
+                return []
 
         except Exception as e:
-            logger.error(f"Error getting output history for {model_name}: {e}")
+            logger.error(f"Error getting historical outputs: {e}")
             return []
 
-    def get_cross_model_states(self, symbol: str, requesting_model: str) -> Dict[str, Dict[str, Any]]:
+    def evaluate_model_performance(self, symbol: str, model_name: str) -> Dict[str, float]:
         """
-        Get hidden states from other models for cross-model feeding
+        Evaluate model performance based on historical outputs
 
         Args:
-            symbol: Trading symbol
-            requesting_model: Name of the model requesting the states
+            symbol: Symbol to evaluate
+            model_name: Model name to evaluate
 
         Returns:
-            Dict[str, Dict[str, Any]]: Hidden states from other models
+            Dict[str, float]: Performance metrics
         """
         try:
-            all_states = self.cross_model_states.get(symbol, {})
-            # Return states from all models except the requesting one
-            return {model_name: states for model_name, states in all_states.items()
-                    if model_name != requesting_model}
-        except Exception as e:
-            logger.error(f"Error getting cross-model states for {requesting_model}: {e}")
-            return {}
-
-    def get_model_types_active(self, symbol: str) -> List[str]:
-        """
-        Get list of active model types for a symbol
-
-        Args:
-            symbol: Trading symbol
-
-        Returns:
-            List[str]: List of active model types
-        """
-        try:
-            current_outputs = self.current_outputs.get(symbol, {})
-            return [output.model_type for output in current_outputs.values()]
-        except Exception as e:
-            logger.error(f"Error getting active model types for {symbol}: {e}")
-            return []
-
-    def get_consensus_prediction(self, symbol: str, confidence_threshold: float = 0.5) -> Optional[Dict[str, Any]]:
-        """
-        Get consensus prediction from all active models
-
-        Args:
-            symbol: Trading symbol
-            confidence_threshold: Minimum confidence threshold for inclusion
-
-        Returns:
-            Dict containing consensus prediction or None
-        """
-        try:
-            current_outputs = self.current_outputs.get(symbol, {})
-            if not current_outputs:
-                return None
+            # Get historical outputs
+            outputs = self.get_historical_outputs(symbol, model_name)
 
-            # Filter by confidence threshold
-            high_confidence_outputs = [
-                output for output in current_outputs.values()
-                if output.confidence >= confidence_threshold
-            ]
+            if not outputs:
+                return {'accuracy': 0.0, 'confidence': 0.0, 'samples': 0}
 
-            if not high_confidence_outputs:
-                return None
+            # Calculate metrics
+            total_outputs = len(outputs)
+            total_confidence = sum(output.confidence for output in outputs)
+            avg_confidence = total_confidence / total_outputs if total_outputs > 0 else 0.0
 
-            # Calculate consensus
-            buy_votes = sum(1 for output in high_confidence_outputs
-                            if output.predictions.get('action') == 'BUY')
-            sell_votes = sum(1 for output in high_confidence_outputs
-                             if output.predictions.get('action') == 'SELL')
-            hold_votes = sum(1 for output in high_confidence_outputs
-                             if output.predictions.get('action') == 'HOLD')
+            # For now, we don't have ground truth to calculate accuracy
+            # In the future, we can add this by comparing predictions to actual market movements
 
-            total_votes = len(high_confidence_outputs)
-            avg_confidence = sum(output.confidence for output in high_confidence_outputs) / total_votes
-
-            # Determine consensus action
-            if buy_votes > sell_votes and buy_votes > hold_votes:
-                consensus_action = 'BUY'
-            elif sell_votes > buy_votes and sell_votes > hold_votes:
-                consensus_action = 'SELL'
-            else:
-                consensus_action = 'HOLD'
-
-            return {
-                'action': consensus_action,
+            metrics = {
                 'confidence': avg_confidence,
-                'votes': {'BUY': buy_votes, 'SELL': sell_votes, 'HOLD': hold_votes},
-                'total_models': total_votes,
-                'model_types': [output.model_type for output in high_confidence_outputs]
+                'samples': total_outputs,
+                'last_update': datetime.now().isoformat()
             }
 
-        except Exception as e:
-            logger.error(f"Error calculating consensus prediction for {symbol}: {e}")
-            return None
-
-    def _update_model_metadata(self, model_name: str, model_type: str, metadata: Dict[str, Any]):
-        """Update metadata for a model"""
-        try:
-            if model_name not in self.model_metadata:
-                self.model_metadata[model_name] = {
-                    'model_type': model_type,
-                    'first_seen': datetime.now(),
-                    'total_predictions': 0,
-                    'custom_metadata': {}
-                }
+            # Store metrics
+            with self.outputs_lock:
+                if symbol not in self.performance_metrics:
+                    self.performance_metrics[symbol] = {}
+                self.performance_metrics[symbol][model_name] = metrics
 
-            self.model_metadata[model_name]['total_predictions'] += 1
-            self.model_metadata[model_name]['last_seen'] = datetime.now()
-
-            # Merge custom metadata
-            if metadata:
-                self.model_metadata[model_name]['custom_metadata'].update(metadata)
-
-        except Exception as e:
-            logger.error(f"Error updating model metadata: {e}")
-
-    def _update_performance_stats(self, symbol: str, model_name: str, model_output: ModelOutput):
-        """Update performance statistics for a model"""
-        try:
-            stats = self.performance_stats[symbol][model_name]
-
-            if 'prediction_count' not in stats:
-                stats['prediction_count'] = 0
-                stats['confidence_sum'] = 0.0
-                stats['action_counts'] = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
-                stats['first_prediction'] = model_output.timestamp
-
-            stats['prediction_count'] += 1
-            stats['confidence_sum'] += model_output.confidence
-            stats['avg_confidence'] = stats['confidence_sum'] / stats['prediction_count']
-            stats['last_prediction'] = model_output.timestamp
-
-            action = model_output.predictions.get('action', 'HOLD')
-            if action in stats['action_counts']:
-                stats['action_counts'][action] += 1
+            return metrics
 
         except Exception as e:
-            logger.error(f"Error updating performance stats: {e}")
+            logger.error(f"Error evaluating model performance: {e}")
+            return {'error': str(e)}
 
-    def _persist_output_async(self, model_output: ModelOutput):
-        """Persist model output to disk (simplified version)"""
+    def get_performance_metrics(self, symbol: str, model_name: str) -> Dict[str, float]:
+        """
+        Get performance metrics for a model and symbol
+
+        Args:
+            symbol: Symbol to get metrics for
+            model_name: Model name to get metrics for
+
+        Returns:
+            Dict[str, float]: Performance metrics
+        """
         try:
-            # Create filename based on model and timestamp
-            timestamp_str = model_output.timestamp.strftime("%Y%m%d_%H%M%S")
-            filename = f"{model_output.model_name}_{model_output.symbol.replace('/', '_')}_{timestamp_str}.json"
-            filepath = self.cache_dir / filename
+            with self.outputs_lock:
+                if symbol in self.performance_metrics and model_name in self.performance_metrics[symbol]:
+                    return self.performance_metrics[symbol][model_name].copy()
 
-            # Convert to JSON-serializable format
+            # If no metrics are available, calculate them
+            return self.evaluate_model_performance(symbol, model_name)
+
+        except Exception as e:
+            logger.error(f"Error getting performance metrics: {e}")
+            return {'error': str(e)}
+
+    def _persist_output(self, model_output: ModelOutput) -> bool:
+        """
+        Persist a model output to disk
+
+        Args:
+            model_output: Model output to persist
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        try:
+            # Create directory if it doesn't exist
+            symbol_dir = os.path.join(self.cache_dir, model_output.symbol.replace('/', '_'))
+            os.makedirs(symbol_dir, exist_ok=True)
+
+            # Create filename with timestamp
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            filename = f"{model_output.model_name}_{model_output.symbol.replace('/', '_')}_{timestamp}.json"
+            filepath = os.path.join(self.cache_dir, filename)
+
+            # Convert ModelOutput to dictionary
             output_dict = {
                 'model_type': model_output.model_type,
                 'model_name': model_output.model_name,
@@ -319,77 +263,120 @@ class ModelOutputManager:
                 'metadata': model_output.metadata
             }
 
-            # Save to file (in a real implementation, this would be async)
+            # Don't store hidden states in file (too large)
+
+            # Write to file
             with open(filepath, 'w') as f:
                 json.dump(output_dict, f, indent=2)
 
+            return True
+
         except Exception as e:
             logger.error(f"Error persisting model output: {e}")
+            return False
 
-    def get_performance_summary(self, symbol: str) -> Dict[str, Any]:
+    def load_outputs_from_disk(self, symbol: str = None, model_name: str = None) -> int:
         """
-        Get performance summary for all models for a symbol
+        Load model outputs from disk
 
         Args:
-            symbol: Trading symbol
+            symbol: Symbol to load outputs for, None for all
+            model_name: Model name to load outputs for, None for all
 
         Returns:
-            Dict containing performance summary
+            int: Number of outputs loaded
         """
         try:
-            summary = {
-                'symbol': symbol,
-                'active_models': len(self.current_outputs.get(symbol, {})),
-                'model_stats': {}
-            }
+            # Find all output files
+            import glob
 
-            for model_name, stats in self.performance_stats.get(symbol, {}).items():
-                summary['model_stats'][model_name] = {
-                    'predictions': stats.get('prediction_count', 0),
-                    'avg_confidence': round(stats.get('avg_confidence', 0.0), 3),
-                    'action_distribution': stats.get('action_counts', {}),
-                    'model_type': self.model_metadata.get(model_name, {}).get('model_type', 'unknown')
-                }
+            if symbol and model_name:
+                pattern = os.path.join(self.cache_dir, f"{model_name}_{symbol.replace('/', '_')}*.json")
+            elif symbol:
+                pattern = os.path.join(self.cache_dir, f"*_{symbol.replace('/', '_')}*.json")
+            elif model_name:
+                pattern = os.path.join(self.cache_dir, f"{model_name}_*.json")
+            else:
+                pattern = os.path.join(self.cache_dir, "*.json")
 
-            return summary
+            output_files = glob.glob(pattern)
 
+            if not output_files:
+                logger.info(f"No output files found for pattern: {pattern}")
+                return 0
+
+            # Load each file
+            loaded_count = 0
+            for filepath in output_files:
+                try:
+                    with open(filepath, 'r') as f:
+                        output_dict = json.load(f)
+
+                    # Create ModelOutput
+                    model_output = ModelOutput(
+                        model_type=output_dict['model_type'],
+                        model_name=output_dict['model_name'],
+                        symbol=output_dict['symbol'],
+                        timestamp=datetime.fromisoformat(output_dict['timestamp']),
+                        confidence=output_dict['confidence'],
+                        predictions=output_dict['predictions'],
+                        hidden_states={},  # Don't load hidden states from disk
+                        metadata=output_dict.get('metadata', {})
+                    )
+
+                    # Store output
+                    self.store_output(model_output)
+                    loaded_count += 1
+
+                except Exception as e:
+                    logger.error(f"Error loading output file {filepath}: {e}")
+
+            logger.info(f"Loaded {loaded_count} model outputs from disk")
+            return loaded_count
+
         except Exception as e:
-            logger.error(f"Error getting performance summary: {e}")
-            return {'symbol': symbol, 'error': str(e)}
+            logger.error(f"Error loading outputs from disk: {e}")
+            return 0
 
-    def cleanup_old_outputs(self, max_age_hours: int = 24):
+    def cleanup_old_outputs(self, max_age_days: int = 30) -> int:
         """
-        Clean up old outputs to manage memory usage
+        Clean up old output files
 
         Args:
-            max_age_hours: Maximum age of outputs to keep in hours
+            max_age_days: Maximum age of files to keep in days
+
+        Returns:
+            int: Number of files deleted
         """
         try:
-            cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
+            # Find all output files
+            import glob
+            output_files = glob.glob(os.path.join(self.cache_dir, "*.json"))
 
-            with self.storage_lock:
-                for symbol in self.output_history:
-                    for model_name in self.output_history[symbol]:
-                        history = self.output_history[symbol][model_name]
-                        # Remove old outputs
-                        while history and history[0].timestamp < cutoff_time:
-                            history.popleft()
+            if not output_files:
+                return 0
 
-            logger.info(f"Cleaned up outputs older than {max_age_hours} hours")
+            # Calculate cutoff time
+            cutoff_time = time.time() - (max_age_days * 24 * 60 * 60)
+
+            # Delete old files
+            deleted_count = 0
+            for filepath in output_files:
+                try:
+                    # Get file modification time
+                    mtime = os.path.getmtime(filepath)
+
+                    # Delete if older than cutoff
+                    if mtime < cutoff_time:
+                        os.remove(filepath)
+                        deleted_count += 1
+
+                except Exception as e:
+                    logger.error(f"Error deleting file {filepath}: {e}")
+
+            logger.info(f"Deleted {deleted_count} old model output files")
+            return deleted_count
 
         except Exception as e:
             logger.error(f"Error cleaning up old outputs: {e}")
-
-    def add_custom_model_type(self, model_type: str):
-        """
-        Add support for a new custom model type
-
-        Args:
-            model_type: Name of the new model type
-        """
-        self.supported_model_types.add(model_type)
-        logger.info(f"Added support for custom model type: {model_type}")
-
-    def get_supported_model_types(self) -> List[str]:
-        """Get list of all supported model types"""
-        return list(self.supported_model_types)
+            return 0
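A sketch of the revised ModelOutputManager API after this change. The module path core.model_output_manager is assumed from the relative imports above, and create_model_output is called with the same arguments the adapter uses; both are inferred, not confirmed by this commit.

from core.model_output_manager import ModelOutputManager  # assumed module path
from core.data_models import create_model_output

manager = ModelOutputManager(cache_dir="cache/model_outputs", max_history=1000)

out = create_model_output(model_type='cnn', model_name='enhanced_cnn_v1',
                          symbol='ETH/USDT', action='BUY', confidence=0.8)
manager.store_output(out)  # also persists a JSON file under cache_dir

latest = manager.get_current_output('ETH/USDT', 'enhanced_cnn_v1')
recent = manager.get_historical_outputs('ETH/USDT', 'enhanced_cnn_v1', limit=10)
stats = manager.evaluate_model_performance('ETH/USDT', 'enhanced_cnn_v1')
deleted = manager.cleanup_old_outputs(max_age_days=30)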