"""
|
|
CNN Training Pipeline with Comprehensive Data Storage and Replay
|
|
|
|
This module implements a robust CNN training pipeline that:
|
|
1. Integrates with the comprehensive training data collection system
|
|
2. Stores all backpropagation data for gradient replay
|
|
3. Enables retraining on most profitable setups
|
|
4. Maintains training episode profitability tracking
|
|
5. Supports both real-time and batch training modes
|
|
|
|
Key Features:
|
|
- Integration with TrainingDataCollector for data validation
|
|
- Gradient and loss storage for each training step
|
|
- Profitable episode prioritization and replay
|
|
- Comprehensive training metrics and validation
|
|
- Real-time pivot point prediction with outcome tracking
|
|
"""
|
|
|
|
import json
import logging
import pickle
import threading
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from .training_data_collector import (
    TrainingDataCollector,
    TrainingEpisode,
    ModelInputPackage,
    get_training_data_collector
)

logger = logging.getLogger(__name__)


@dataclass
class CNNTrainingStep:
    """Single CNN training step with complete backpropagation data"""
    step_id: str
    timestamp: datetime
    episode_id: str

    # Input data
    input_features: torch.Tensor
    target_labels: torch.Tensor

    # Forward pass results
    model_outputs: Dict[str, torch.Tensor]
    predictions: Dict[str, Any]
    confidence_scores: torch.Tensor

    # Loss components
    total_loss: float
    pivot_prediction_loss: float
    confidence_loss: float
    regularization_loss: float

    # Backpropagation data
    gradients: Dict[str, torch.Tensor]  # Gradients for each parameter
    gradient_norms: Dict[str, float]    # Gradient norms for monitoring

    # Model state
    model_state_dict: Optional[Dict[str, torch.Tensor]] = None
    optimizer_state: Optional[Dict[str, Any]] = None

    # Training metadata
    learning_rate: float = 0.001
    batch_size: int = 32
    epoch: int = 0

    # Profitability tracking
    actual_profitability: Optional[float] = None
    prediction_accuracy: Optional[float] = None
    training_value: float = 0.0  # Value of this training step for replay


@dataclass
class CNNTrainingSession:
    """Complete CNN training session with multiple steps"""
    session_id: str
    start_timestamp: datetime
    end_timestamp: Optional[datetime] = None

    # Session configuration
    training_mode: str = 'real_time'  # 'real_time', 'batch', 'replay'
    symbol: str = ''

    # Training steps
    training_steps: List[CNNTrainingStep] = field(default_factory=list)

    # Session metrics
    total_steps: int = 0
    average_loss: float = 0.0
    best_loss: float = float('inf')
    convergence_achieved: bool = False

    # Profitability metrics
    profitable_predictions: int = 0
    total_predictions: int = 0
    profitability_rate: float = 0.0

    # Session value for replay prioritization
    session_value: float = 0.0


class CNNPivotPredictor(nn.Module):
    """CNN model for pivot point prediction with comprehensive output"""

    def __init__(self,
                 input_channels: int = 10,    # Multiple timeframes
                 sequence_length: int = 300,  # 300 bars
                 hidden_dim: int = 256,
                 num_pivot_classes: int = 3,  # high, low, none
                 dropout_rate: float = 0.2):
        super().__init__()

        self.input_channels = input_channels
        self.sequence_length = sequence_length
        self.hidden_dim = hidden_dim

        # Convolutional layers for pattern extraction
        self.conv_layers = nn.Sequential(
            # First conv block
            nn.Conv1d(input_channels, 64, kernel_size=7, padding=3),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            # Second conv block
            nn.Conv1d(64, 128, kernel_size=5, padding=2),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            # Third conv block
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )

        # LSTM for temporal dependencies
        self.lstm = nn.LSTM(
            input_size=256,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=dropout_rate,
            bidirectional=True
        )

        # Attention mechanism
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,  # Bidirectional LSTM
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True
        )

        # Output heads
        self.pivot_classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, num_pivot_classes)
        )

        self.pivot_price_regressor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, 1)
        )

        self.confidence_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights with proper scaling"""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """
        Forward pass through CNN pivot predictor

        Args:
            x: Input tensor [batch_size, input_channels, sequence_length]

        Returns:
            Dict containing predictions and hidden states
        """
        # Convolutional feature extraction
        conv_features = self.conv_layers(x)  # [batch, 256, sequence_length]

        # Prepare for LSTM (transpose to [batch, sequence, features])
        lstm_input = conv_features.transpose(1, 2)  # [batch, sequence_length, 256]

        # LSTM processing
        lstm_output, (hidden, cell) = self.lstm(lstm_input)  # [batch, sequence_length, hidden_dim*2]

        # Self-attention over the LSTM outputs
        attended_output, attention_weights = self.attention(
            lstm_output, lstm_output, lstm_output
        )

        # Use the last timestep for predictions
        final_features = attended_output[:, -1, :]  # [batch, hidden_dim*2]

        # Generate predictions
        pivot_logits = self.pivot_classifier(final_features)
        pivot_price = self.pivot_price_regressor(final_features)
        confidence = self.confidence_head(final_features)

        return {
            'pivot_logits': pivot_logits,
            'pivot_price': pivot_price,
            'confidence': confidence,
            'hidden_states': final_features,
            'attention_weights': attention_weights,
            'conv_features': conv_features,
            'lstm_output': lstm_output
        }


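# Shape sanity check for the predictor above -- a minimal, illustrative sketch.
# The function name and the batch size of 4 are arbitrary choices for this
# example; it exercises only the default constructor arguments and is not
# called anywhere in the pipeline.
def _sanity_check_predictor_shapes():
    model = CNNPivotPredictor()
    model.eval()
    x = torch.randn(4, model.input_channels, model.sequence_length)
    with torch.no_grad():
        out = model(x)
    # The conv blocks preserve sequence length (stride 1, "same"-style padding),
    # so the LSTM sees [4, 300, 256] and attention returns [4, 300, 512].
    assert out['pivot_logits'].shape == (4, 3)
    assert out['pivot_price'].shape == (4, 1)
    assert out['confidence'].shape == (4, 1)

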
class CNNTrainingDataset(Dataset):
    """Dataset for CNN training with training episodes"""

    def __init__(self, training_episodes: List[TrainingEpisode]):
        self.episodes = training_episodes
        self.valid_episodes = self._validate_episodes()

    def _validate_episodes(self) -> List[TrainingEpisode]:
        """Validate and filter episodes for training"""
        valid = []
        for episode in self.episodes:
            try:
                # Check if episode has required data
                if (episode.input_package.cnn_features is not None and
                        episode.actual_outcome.outcome_validated):
                    valid.append(episode)
            except Exception as e:
                logger.warning(f"Invalid episode {episode.episode_id}: {e}")

        logger.info(f"Validated {len(valid)}/{len(self.episodes)} episodes for training")
        return valid

    def __len__(self):
        return len(self.valid_episodes)

    def __getitem__(self, idx):
        episode = self.valid_episodes[idx]

        # Extract features
        features = torch.from_numpy(episode.input_package.cnn_features).float()

        # Create labels from actual outcomes
        pivot_class = self._determine_pivot_class(episode.actual_outcome)
        pivot_price = episode.actual_outcome.optimal_exit_price
        confidence_target = episode.actual_outcome.profitability_score

        return {
            'features': features,
            'pivot_class': torch.tensor(pivot_class, dtype=torch.long),
            'pivot_price': torch.tensor(pivot_price, dtype=torch.float),
            'confidence_target': torch.tensor(confidence_target, dtype=torch.float),
            'episode_id': episode.episode_id,
            'profitability': episode.actual_outcome.profitability_score
        }

    def _determine_pivot_class(self, outcome) -> int:
        """Determine pivot class from outcome (the 0.5 thresholds assume
        price_change_15m is expressed in percent)"""
        if outcome.price_change_15m > 0.5:     # Significant upward movement
            return 0  # High pivot
        elif outcome.price_change_15m < -0.5:  # Significant downward movement
            return 1  # Low pivot
        else:
            return 2  # No significant pivot


class CNNTrainer:
    """CNN trainer with comprehensive data storage and replay capabilities"""

    def __init__(self,
                 model: CNNPivotPredictor,
                 device: str = 'cuda',
                 learning_rate: float = 0.001,
                 storage_dir: str = "cnn_training_storage"):
        self.model = model.to(device)
        self.device = device
        self.learning_rate = learning_rate

        # Storage
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=1e-5
        )

        # Learning rate scheduler
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=10, factor=0.5
        )

        # Training data collector
        self.data_collector = get_training_data_collector()

        # Training sessions storage
        self.training_sessions: List[CNNTrainingSession] = []
        self.current_session: Optional[CNNTrainingSession] = None

        # Training statistics
        self.training_stats = {
            'total_sessions': 0,
            'total_steps': 0,
            'best_validation_loss': float('inf'),
            'profitable_predictions': 0,
            'total_predictions': 0,
            'replay_sessions': 0
        }

        # Background training; stop_event lets stop_training() interrupt the
        # worker's sleep instead of waiting out the full training interval
        self.is_training = False
        self.training_thread = None
        self.stop_event = threading.Event()

        logger.info("CNN Trainer initialized")
        logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        logger.info(f"Storage directory: {self.storage_dir}")

    def start_real_time_training(self, symbol: str):
        """Start real-time training for a symbol"""
        if self.is_training:
            logger.warning("CNN training already running")
            return

        self.is_training = True
        self.stop_event.clear()
        self.training_thread = threading.Thread(
            target=self._real_time_training_worker,
            args=(symbol,),
            daemon=True
        )
        self.training_thread.start()

        logger.info(f"Started real-time CNN training for {symbol}")

    def stop_training(self):
        """Stop training"""
        self.is_training = False
        self.stop_event.set()  # Wake the worker if it is sleeping
        if self.training_thread:
            self.training_thread.join(timeout=10)

        if self.current_session:
            self._finalize_training_session()

        logger.info("CNN training stopped")

    def _real_time_training_worker(self, symbol: str):
        """Real-time training worker"""
        logger.info(f"Real-time CNN training worker started for {symbol}")

        while self.is_training:
            try:
                # Get high-priority episodes for training
                episodes = self.data_collector.get_high_priority_episodes(
                    symbol=symbol,
                    limit=100,
                    min_priority=0.3
                )

                if len(episodes) >= 32:  # Minimum batch size
                    self._train_on_episodes(episodes, training_mode='real_time')

                # Wait before next training cycle (returns early when
                # stop_training() sets the event)
                self.stop_event.wait(300)  # Train every 5 minutes

            except Exception as e:
                logger.error(f"Error in real-time training worker: {e}")
                self.stop_event.wait(60)  # Wait before retrying

        logger.info(f"Real-time CNN training worker stopped for {symbol}")

    def train_on_profitable_episodes(self,
                                     symbol: str,
                                     min_profitability: float = 0.7,
                                     max_episodes: int = 500) -> Dict[str, Any]:
        """Train specifically on the most profitable episodes"""
        try:
            # Get all episodes for symbol
            all_episodes = self.data_collector.training_episodes.get(symbol, [])

            # Filter for profitable episodes
            profitable_episodes = [
                ep for ep in all_episodes
                if (ep.actual_outcome.is_profitable and
                    ep.actual_outcome.profitability_score >= min_profitability)
            ]

            # Sort by profitability and limit
            profitable_episodes.sort(
                key=lambda x: x.actual_outcome.profitability_score,
                reverse=True
            )
            profitable_episodes = profitable_episodes[:max_episodes]

            if len(profitable_episodes) < 10:
                logger.warning(f"Insufficient profitable episodes for {symbol}: {len(profitable_episodes)}")
                return {'status': 'insufficient_data', 'episodes_found': len(profitable_episodes)}

            # Train on profitable episodes
            results = self._train_on_episodes(
                profitable_episodes,
                training_mode='profitable_replay'
            )

            logger.info(f"Trained on {len(profitable_episodes)} profitable episodes for {symbol}")
            return results

        except Exception as e:
            logger.error(f"Error training on profitable episodes: {e}")
            return {'status': 'error', 'error': str(e)}

    def _train_on_episodes(self,
                           episodes: List[TrainingEpisode],
                           training_mode: str = 'batch') -> Dict[str, Any]:
        """Train on a batch of episodes with comprehensive data storage"""
        try:
            # Start new training session
            session = CNNTrainingSession(
                session_id=f"{training_mode}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                start_timestamp=datetime.now(),
                training_mode=training_mode,
                symbol=episodes[0].input_package.symbol if episodes else 'unknown'
            )
            self.current_session = session

            # Create dataset and dataloader
            dataset = CNNTrainingDataset(episodes)
            dataloader = DataLoader(
                dataset,
                batch_size=32,
                shuffle=True,
                num_workers=2
            )

            # Training loop
            self.model.train()
            total_loss = 0.0
            num_batches = 0

            for batch_idx, batch in enumerate(dataloader):
                # Move to device
                features = batch['features'].to(self.device)
                pivot_class = batch['pivot_class'].to(self.device)
                pivot_price = batch['pivot_price'].to(self.device)
                confidence_target = batch['confidence_target'].to(self.device)

                # Forward pass
                self.optimizer.zero_grad()
                outputs = self.model(features)

                # Calculate losses
                classification_loss = F.cross_entropy(outputs['pivot_logits'], pivot_class)
                regression_loss = F.mse_loss(outputs['pivot_price'].squeeze(), pivot_price)
                confidence_loss = F.binary_cross_entropy(
                    outputs['confidence'].squeeze(),
                    confidence_target
                )

                # Combined multi-task loss (regression and confidence are
                # down-weighted relative to pivot classification)
                total_batch_loss = classification_loss + 0.5 * regression_loss + 0.3 * confidence_loss

                # Backward pass
                total_batch_loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                # Store gradients before optimizer step
                gradients = {}
                gradient_norms = {}
                for name, param in self.model.named_parameters():
                    if param.grad is not None:
                        gradients[name] = param.grad.clone().detach()
                        gradient_norms[name] = param.grad.norm().item()

                # Optimizer step
                self.optimizer.step()

                # Create training step record; the default collate keeps
                # 'episode_id' as a list of strings, so store the batch's ids
                # joined to allow exact episode lookup during replay
                step = CNNTrainingStep(
                    step_id=f"{session.session_id}_step_{batch_idx}",
                    timestamp=datetime.now(),
                    episode_id=",".join(batch['episode_id']),
                    input_features=features.detach().cpu(),
                    target_labels=pivot_class.detach().cpu(),
                    model_outputs={k: v.detach().cpu() for k, v in outputs.items()},
                    predictions=self._extract_predictions(outputs),
                    confidence_scores=outputs['confidence'].detach().cpu(),
                    total_loss=total_batch_loss.item(),
                    pivot_prediction_loss=classification_loss.item(),
                    confidence_loss=confidence_loss.item(),
                    regularization_loss=0.0,  # weight decay is handled by AdamW
                    gradients=gradients,
                    gradient_norms=gradient_norms,
                    learning_rate=self.optimizer.param_groups[0]['lr'],
                    batch_size=features.size(0)
                )

                # Calculate training value for this step
                step.training_value = self._calculate_step_training_value(step, batch)

                # Add to session
                session.training_steps.append(step)

                total_loss += total_batch_loss.item()
                num_batches += 1

                # Log progress
                if batch_idx % 10 == 0:
                    logger.debug(f"Batch {batch_idx}: Loss = {total_batch_loss.item():.4f}")

            # Finalize session
            session.end_timestamp = datetime.now()
            session.total_steps = num_batches
            session.average_loss = total_loss / num_batches if num_batches > 0 else 0.0
            session.best_loss = min(
                (step.total_loss for step in session.training_steps),
                default=float('inf')
            )

            # Calculate session value
            session.session_value = self._calculate_session_value(session)

            # Update scheduler
            self.scheduler.step(session.average_loss)

            # Save session and keep it in memory for replay prioritization
            self._save_training_session(session)
            self.training_sessions.append(session)

            # Update statistics
            self.training_stats['total_sessions'] += 1
            self.training_stats['total_steps'] += session.total_steps
            if training_mode == 'profitable_replay':
                self.training_stats['replay_sessions'] += 1

            logger.info(f"Training session completed: {session.session_id}")
            logger.info(f"Average loss: {session.average_loss:.4f}")
            logger.info(f"Session value: {session.session_value:.3f}")

            return {
                'status': 'success',
                'session_id': session.session_id,
                'average_loss': session.average_loss,
                'total_steps': session.total_steps,
                'session_value': session.session_value
            }

        except Exception as e:
            logger.error(f"Error in training session: {e}")
            return {'status': 'error', 'error': str(e)}
        finally:
            self.current_session = None

    def _extract_predictions(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Extract human-readable predictions from model outputs"""
        try:
            pivot_probs = F.softmax(outputs['pivot_logits'], dim=1)
            predicted_class = torch.argmax(pivot_probs, dim=1)

            # Outputs still carry gradients at this point, so detach
            # before converting to numpy
            return {
                'pivot_class': predicted_class.detach().cpu().numpy().tolist(),
                'pivot_probabilities': pivot_probs.detach().cpu().numpy().tolist(),
                'pivot_price': outputs['pivot_price'].detach().cpu().numpy().tolist(),
                'confidence': outputs['confidence'].detach().cpu().numpy().tolist()
            }
        except Exception as e:
            logger.warning(f"Error extracting predictions: {e}")
            return {}

    def _calculate_step_training_value(self,
                                       step: CNNTrainingStep,
                                       batch: Dict[str, Any]) -> float:
        """Calculate the training value of a step for replay prioritization"""
        try:
            value = 0.0

            # Base value from loss (lower loss = higher value)
            if step.total_loss > 0:
                value += 1.0 / (1.0 + step.total_loss)

            # Bonus for high-profitability episodes in the batch
            avg_profitability = torch.mean(batch['profitability']).item()
            value += avg_profitability * 0.3

            # Bonus for gradient magnitude (indicates learning)
            avg_grad_norm = np.mean(list(step.gradient_norms.values()))
            value += min(avg_grad_norm / 10.0, 0.2)  # Cap at 0.2

            return min(value, 1.0)

        except Exception as e:
            logger.warning(f"Error calculating step training value: {e}")
            return 0.0

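    # Worked example (illustrative numbers only): for total_loss = 1.0,
    # average batch profitability = 0.8 and mean gradient norm = 2.0,
    # the step value is 1/(1+1.0) + 0.8*0.3 + min(2.0/10, 0.2)
    # = 0.5 + 0.24 + 0.2 = 0.94, which stays below the 1.0 cap.
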
    def _calculate_session_value(self, session: CNNTrainingSession) -> float:
        """Calculate overall session value for replay prioritization"""
        try:
            if not session.training_steps:
                return 0.0

            # Average step values
            avg_step_value = np.mean([step.training_value for step in session.training_steps])

            # Bonus for convergence (loss decreasing over the session)
            convergence_bonus = 0.0
            if len(session.training_steps) > 10:
                early_loss = np.mean([s.total_loss for s in session.training_steps[:5]])
                late_loss = np.mean([s.total_loss for s in session.training_steps[-5:]])
                if early_loss > late_loss:
                    convergence_bonus = min((early_loss - late_loss) / early_loss, 0.3)

            # Bonus for profitable replay sessions
            mode_bonus = 0.2 if session.training_mode == 'profitable_replay' else 0.0

            return min(avg_step_value + convergence_bonus + mode_bonus, 1.0)

        except Exception as e:
            logger.warning(f"Error calculating session value: {e}")
            return 0.0

    def _save_training_session(self, session: CNNTrainingSession):
        """Save training session to disk"""
        try:
            session_dir = self.storage_dir / (session.symbol or 'unknown') / 'sessions'
            session_dir.mkdir(parents=True, exist_ok=True)

            # Save full session data (includes tensors, so use pickle)
            session_file = session_dir / f"{session.session_id}.pkl"
            with open(session_file, 'wb') as f:
                pickle.dump(session, f)

            # Save session metadata
            metadata = {
                'session_id': session.session_id,
                'start_timestamp': session.start_timestamp.isoformat(),
                'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
                'training_mode': session.training_mode,
                'symbol': session.symbol,
                'total_steps': session.total_steps,
                'average_loss': session.average_loss,
                'best_loss': session.best_loss,
                'session_value': session.session_value
            }

            metadata_file = session_dir / f"{session.session_id}_metadata.json"
            with open(metadata_file, 'w') as f:
                json.dump(metadata, f, indent=2)

            logger.debug(f"Saved training session: {session.session_id}")

        except Exception as e:
            logger.error(f"Error saving training session: {e}")

    def _finalize_training_session(self):
        """Finalize the current training session"""
        if self.current_session:
            self.current_session.end_timestamp = datetime.now()
            self._save_training_session(self.current_session)
            self.training_sessions.append(self.current_session)
            self.current_session = None

    def get_training_statistics(self) -> Dict[str, Any]:
        """Get comprehensive training statistics"""
        stats = self.training_stats.copy()

        # Add recent session information
        if self.training_sessions:
            recent_sessions = sorted(
                self.training_sessions,
                key=lambda x: x.start_timestamp,
                reverse=True
            )[:10]

            stats['recent_sessions'] = [
                {
                    'session_id': s.session_id,
                    'timestamp': s.start_timestamp.isoformat(),
                    'mode': s.training_mode,
                    'average_loss': s.average_loss,
                    'session_value': s.session_value
                }
                for s in recent_sessions
            ]

        # Calculate profitability rate
        if stats['total_predictions'] > 0:
            stats['profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions']
        else:
            stats['profitability_rate'] = 0.0

        return stats

    def replay_high_value_sessions(self,
                                   symbol: str,
                                   min_session_value: float = 0.7,
                                   max_sessions: int = 10) -> Dict[str, Any]:
        """Replay high-value training sessions"""
        try:
            # Find high-value sessions
            high_value_sessions = [
                s for s in self.training_sessions
                if (s.symbol == symbol and
                    s.session_value >= min_session_value)
            ]

            # Sort by value and limit
            high_value_sessions.sort(key=lambda x: x.session_value, reverse=True)
            high_value_sessions = high_value_sessions[:max_sessions]

            if not high_value_sessions:
                return {'status': 'no_high_value_sessions', 'sessions_found': 0}

            # Index episodes once for O(1) lookup by id
            episodes_by_id = {
                ep.episode_id: ep
                for ep in self.data_collector.training_episodes.get(symbol, [])
            }

            # Replay sessions
            total_replayed = 0
            for session in high_value_sessions:
                # Each step stores its batch's episode ids as a joined string
                # (ids are assumed not to contain commas)
                episode_ids = {
                    eid
                    for step in session.training_steps
                    for eid in step.episode_id.split(',')
                }

                # Get corresponding episodes
                episodes = [
                    episodes_by_id[eid] for eid in episode_ids
                    if eid in episodes_by_id
                ]

                if episodes:
                    self._train_on_episodes(episodes, training_mode='high_value_replay')
                    total_replayed += 1

            logger.info(f"Replayed {total_replayed} high-value sessions for {symbol}")
            return {
                'status': 'success',
                'sessions_replayed': total_replayed,
                'sessions_found': len(high_value_sessions)
            }

        except Exception as e:
            logger.error(f"Error replaying high-value sessions: {e}")
            return {'status': 'error', 'error': str(e)}


# Global instance
cnn_trainer = None


def get_cnn_trainer(model: Optional[CNNPivotPredictor] = None) -> CNNTrainer:
    """Get global CNN trainer instance"""
    global cnn_trainer
    if cnn_trainer is None:
        if model is None:
            model = CNNPivotPredictor()
        cnn_trainer = CNNTrainer(model)
    return cnn_trainer
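

# Minimal smoke-run sketch (illustrative only; run with
# `python -m <package>.<this_module>` since the module uses relative imports).
# It builds the default model on whatever device is available and runs one
# forward pass on random data; it does NOT exercise training, which assumes a
# TrainingDataCollector already populated with validated episodes.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    trainer = CNNTrainer(CNNPivotPredictor(), device=device)
    trainer.model.eval()
    dummy = torch.randn(2, trainer.model.input_channels,
                        trainer.model.sequence_length).to(device)
    with torch.no_grad():
        outputs = trainer.model(dummy)
    logger.info(f"pivot_logits: {tuple(outputs['pivot_logits'].shape)}, "
                f"confidence: {tuple(outputs['confidence'].shape)}")
    logger.info(f"Training stats: {trainer.get_training_statistics()}")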