replay system
core/cnn_training_pipeline.py (new file, 785 additions)

@@ -0,0 +1,785 @@
"""
CNN Training Pipeline with Comprehensive Data Storage and Replay

This module implements a robust CNN training pipeline that:

1. Integrates with the comprehensive training data collection system
2. Stores all backpropagation data for gradient replay
3. Enables retraining on the most profitable setups
4. Maintains training episode profitability tracking
5. Supports both real-time and batch training modes

Key Features:
- Integration with TrainingDataCollector for data validation
- Gradient and loss storage for each training step
- Profitable episode prioritization and replay
- Comprehensive training metrics and validation
- Real-time pivot point prediction with outcome tracking
"""

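# Illustrative usage sketch (assumes a TrainingDataCollector is already
# populating validated episodes for the symbol; the symbol string below is
# made up, everything else is defined in this module):
#
#     trainer = get_cnn_trainer()                   # builds a default CNNPivotPredictor
#     trainer.start_real_time_training('ETH/USDT')  # background training loop
#     trainer.train_on_profitable_episodes('ETH/USDT', min_profitability=0.8)
#     trainer.stop_training()
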
import asyncio
import logging
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field
import json
import pickle
from collections import deque, defaultdict
import threading
from concurrent.futures import ThreadPoolExecutor

from .training_data_collector import (
    TrainingDataCollector,
    TrainingEpisode,
    ModelInputPackage,
    get_training_data_collector
)

logger = logging.getLogger(__name__)

@dataclass
class CNNTrainingStep:
    """Single CNN training step with complete backpropagation data"""
    step_id: str
    timestamp: datetime
    episode_id: str

    # Input data
    input_features: torch.Tensor
    target_labels: torch.Tensor

    # Forward pass results
    model_outputs: Dict[str, torch.Tensor]
    predictions: Dict[str, Any]
    confidence_scores: torch.Tensor

    # Loss components
    total_loss: float
    pivot_prediction_loss: float
    confidence_loss: float
    regularization_loss: float

    # Backpropagation data
    gradients: Dict[str, torch.Tensor]  # Gradients for each parameter
    gradient_norms: Dict[str, float]    # Gradient norms for monitoring

    # Model state
    model_state_dict: Optional[Dict[str, torch.Tensor]] = None
    optimizer_state: Optional[Dict[str, Any]] = None

    # Training metadata
    learning_rate: float = 0.001
    batch_size: int = 32
    epoch: int = 0

    # Profitability tracking
    actual_profitability: Optional[float] = None
    prediction_accuracy: Optional[float] = None
    training_value: float = 0.0  # Value of this training step for replay

@dataclass
class CNNTrainingSession:
    """Complete CNN training session with multiple steps"""
    session_id: str
    start_timestamp: datetime
    end_timestamp: Optional[datetime] = None

    # Session configuration
    training_mode: str = 'real_time'  # 'real_time', 'batch', 'profitable_replay', 'high_value_replay'
    symbol: str = ''

    # Training steps
    training_steps: List[CNNTrainingStep] = field(default_factory=list)

    # Session metrics
    total_steps: int = 0
    average_loss: float = 0.0
    best_loss: float = float('inf')
    convergence_achieved: bool = False

    # Profitability metrics
    profitable_predictions: int = 0
    total_predictions: int = 0
    profitability_rate: float = 0.0

    # Session value for replay prioritization
    session_value: float = 0.0

class CNNPivotPredictor(nn.Module):
    """CNN model for pivot point prediction with comprehensive output"""

    def __init__(self,
                 input_channels: int = 10,    # Multiple timeframes
                 sequence_length: int = 300,  # 300 bars
                 hidden_dim: int = 256,
                 num_pivot_classes: int = 3,  # high, low, none
                 dropout_rate: float = 0.2):

        super(CNNPivotPredictor, self).__init__()

        self.input_channels = input_channels
        self.sequence_length = sequence_length
        self.hidden_dim = hidden_dim

        # Convolutional layers for pattern extraction
        self.conv_layers = nn.Sequential(
            # First conv block
            nn.Conv1d(input_channels, 64, kernel_size=7, padding=3),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            # Second conv block
            nn.Conv1d(64, 128, kernel_size=5, padding=2),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            # Third conv block
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )

        # LSTM for temporal dependencies
        self.lstm = nn.LSTM(
            input_size=256,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=dropout_rate,
            bidirectional=True
        )

        # Attention mechanism
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,  # Bidirectional LSTM
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True
        )

        # Output heads
        self.pivot_classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, num_pivot_classes)
        )

        self.pivot_price_regressor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, 1)
        )

        self.confidence_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights with proper scaling"""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """
        Forward pass through CNN pivot predictor

        Args:
            x: Input tensor [batch_size, input_channels, sequence_length]

        Returns:
            Dict containing predictions and hidden states
        """
        batch_size = x.size(0)

        # Convolutional feature extraction
        conv_features = self.conv_layers(x)  # [batch, 256, sequence_length]

        # Prepare for LSTM (transpose to [batch, sequence, features])
        lstm_input = conv_features.transpose(1, 2)  # [batch, sequence_length, 256]

        # LSTM processing
        lstm_output, (hidden, cell) = self.lstm(lstm_input)  # [batch, sequence_length, hidden_dim*2]

        # Attention mechanism (self-attention over the LSTM outputs)
        attended_output, attention_weights = self.attention(
            lstm_output, lstm_output, lstm_output
        )

        # Use the last timestep for predictions
        final_features = attended_output[:, -1, :]  # [batch, hidden_dim*2]

        # Generate predictions
        pivot_logits = self.pivot_classifier(final_features)
        pivot_price = self.pivot_price_regressor(final_features)
        confidence = self.confidence_head(final_features)

        return {
            'pivot_logits': pivot_logits,
            'pivot_price': pivot_price,
            'confidence': confidence,
            'hidden_states': final_features,
            'attention_weights': attention_weights,
            'conv_features': conv_features,
            'lstm_output': lstm_output
        }

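# Shape contract for CNNPivotPredictor (a sketch assuming the constructor
# defaults of 10 input channels and 300 bars; actual sizes depend on the
# upstream feature pipeline):
#   input x:                  [batch, 10, 300]
#   conv_features:            [batch, 256, 300]
#   lstm_output:              [batch, 300, 512]  (2 * hidden_dim, bidirectional)
#   pivot_logits:             [batch, 3]
#   pivot_price, confidence:  [batch, 1]
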
class CNNTrainingDataset(Dataset):
    """Dataset for CNN training with training episodes"""

    def __init__(self, training_episodes: List[TrainingEpisode]):
        self.episodes = training_episodes
        self.valid_episodes = self._validate_episodes()

    def _validate_episodes(self) -> List[TrainingEpisode]:
        """Validate and filter episodes for training"""
        valid = []
        for episode in self.episodes:
            try:
                # Check if episode has required data
                if (episode.input_package.cnn_features is not None and
                        episode.actual_outcome.outcome_validated):
                    valid.append(episode)
            except Exception as e:
                logger.warning(f"Invalid episode {episode.episode_id}: {e}")

        logger.info(f"Validated {len(valid)}/{len(self.episodes)} episodes for training")
        return valid

    def __len__(self):
        return len(self.valid_episodes)

    def __getitem__(self, idx):
        episode = self.valid_episodes[idx]

        # Extract features
        features = torch.from_numpy(episode.input_package.cnn_features).float()

        # Create labels from actual outcomes
        pivot_class = self._determine_pivot_class(episode.actual_outcome)
        pivot_price = episode.actual_outcome.optimal_exit_price
        confidence_target = episode.actual_outcome.profitability_score

        return {
            'features': features,
            'pivot_class': torch.tensor(pivot_class, dtype=torch.long),
            'pivot_price': torch.tensor(pivot_price, dtype=torch.float),
            'confidence_target': torch.tensor(confidence_target, dtype=torch.float),
            'episode_id': episode.episode_id,
            'profitability': episode.actual_outcome.profitability_score
        }

    def _determine_pivot_class(self, outcome) -> int:
        """Determine pivot class from outcome"""
        if outcome.price_change_15m > 0.5:     # Significant upward movement
            return 0  # High pivot
        elif outcome.price_change_15m < -0.5:  # Significant downward movement
            return 1  # Low pivot
        else:
            return 2  # No significant pivot

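# Each dataset item is a plain dict, so the default DataLoader collation
# applies. Shapes are illustrative and assume cnn_features is a 2-D float
# array matching the model input (e.g. 10 x 300):
#   {'features': FloatTensor[10, 300], 'pivot_class': LongTensor (0/1/2),
#    'pivot_price': FloatTensor scalar, 'confidence_target': FloatTensor scalar,
#    'episode_id': str, 'profitability': float}
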
class CNNTrainer:
    """CNN trainer with comprehensive data storage and replay capabilities"""

    def __init__(self,
                 model: CNNPivotPredictor,
                 device: Optional[str] = None,
                 learning_rate: float = 0.001,
                 storage_dir: str = "cnn_training_storage"):

        # Default to CUDA when available, otherwise fall back to CPU
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.learning_rate = learning_rate

        # Storage
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=1e-5
        )

        # Learning rate scheduler
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=10, factor=0.5
        )

        # Training data collector
        self.data_collector = get_training_data_collector()

        # Training sessions storage
        self.training_sessions: List[CNNTrainingSession] = []
        self.current_session: Optional[CNNTrainingSession] = None

        # Training statistics
        self.training_stats = {
            'total_sessions': 0,
            'total_steps': 0,
            'best_validation_loss': float('inf'),
            'profitable_predictions': 0,
            'total_predictions': 0,
            'replay_sessions': 0
        }

        # Background training
        self.is_training = False
        self.training_thread = None

        logger.info("CNN Trainer initialized")
        logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        logger.info(f"Storage directory: {self.storage_dir}")

    def start_real_time_training(self, symbol: str):
        """Start real-time training for a symbol"""
        if self.is_training:
            logger.warning("CNN training already running")
            return

        self.is_training = True
        self.training_thread = threading.Thread(
            target=self._real_time_training_worker,
            args=(symbol,),
            daemon=True
        )
        self.training_thread.start()

        logger.info(f"Started real-time CNN training for {symbol}")

    def stop_training(self):
        """Stop training"""
        self.is_training = False
        if self.training_thread:
            self.training_thread.join(timeout=10)

        if self.current_session:
            self._finalize_training_session()

        logger.info("CNN training stopped")

    def _real_time_training_worker(self, symbol: str):
        """Real-time training worker"""
        logger.info(f"Real-time CNN training worker started for {symbol}")

        while self.is_training:
            try:
                # Get high-priority episodes for training
                episodes = self.data_collector.get_high_priority_episodes(
                    symbol=symbol,
                    limit=100,
                    min_priority=0.3
                )

                if len(episodes) >= 32:  # Minimum batch size
                    self._train_on_episodes(episodes, training_mode='real_time')

                # Wait before next training cycle
                time.sleep(300)  # Train every 5 minutes

            except Exception as e:
                logger.error(f"Error in real-time training worker: {e}")
                time.sleep(60)  # Wait before retrying

        logger.info(f"Real-time CNN training worker stopped for {symbol}")

    def train_on_profitable_episodes(self,
                                     symbol: str,
                                     min_profitability: float = 0.7,
                                     max_episodes: int = 500) -> Dict[str, Any]:
        """Train specifically on the most profitable episodes"""
        try:
            # Get all episodes for the symbol
            all_episodes = self.data_collector.training_episodes.get(symbol, [])

            # Filter for profitable episodes
            profitable_episodes = [
                ep for ep in all_episodes
                if (ep.actual_outcome.is_profitable and
                    ep.actual_outcome.profitability_score >= min_profitability)
            ]

            # Sort by profitability and limit
            profitable_episodes.sort(
                key=lambda x: x.actual_outcome.profitability_score,
                reverse=True
            )
            profitable_episodes = profitable_episodes[:max_episodes]

            if len(profitable_episodes) < 10:
                logger.warning(f"Insufficient profitable episodes for {symbol}: {len(profitable_episodes)}")
                return {'status': 'insufficient_data', 'episodes_found': len(profitable_episodes)}

            # Train on profitable episodes
            results = self._train_on_episodes(
                profitable_episodes,
                training_mode='profitable_replay'
            )

            logger.info(f"Trained on {len(profitable_episodes)} profitable episodes for {symbol}")
            return results

        except Exception as e:
            logger.error(f"Error training on profitable episodes: {e}")
            return {'status': 'error', 'error': str(e)}

    def _train_on_episodes(self,
                           episodes: List[TrainingEpisode],
                           training_mode: str = 'batch') -> Dict[str, Any]:
        """Train on a batch of episodes with comprehensive data storage"""
        try:
            # Start new training session
            session = CNNTrainingSession(
                session_id=f"{training_mode}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                start_timestamp=datetime.now(),
                training_mode=training_mode,
                symbol=episodes[0].input_package.symbol if episodes else 'unknown'
            )
            self.current_session = session

            # Create dataset and dataloader
            dataset = CNNTrainingDataset(episodes)
            dataloader = DataLoader(
                dataset,
                batch_size=32,
                shuffle=True,
                num_workers=2
            )

            # Training loop
            self.model.train()
            total_loss = 0.0
            num_batches = 0

            for batch_idx, batch in enumerate(dataloader):
                # Move to device
                features = batch['features'].to(self.device)
                pivot_class = batch['pivot_class'].to(self.device)
                pivot_price = batch['pivot_price'].to(self.device)
                confidence_target = batch['confidence_target'].to(self.device)

                # Forward pass
                self.optimizer.zero_grad()
                outputs = self.model(features)

                # Calculate losses (squeeze only the last dim so a batch of 1
                # keeps its batch dimension)
                classification_loss = F.cross_entropy(outputs['pivot_logits'], pivot_class)
                regression_loss = F.mse_loss(outputs['pivot_price'].squeeze(-1), pivot_price)
                confidence_loss = F.binary_cross_entropy(
                    outputs['confidence'].squeeze(-1),
                    confidence_target
                )

                # Combined loss
                total_batch_loss = classification_loss + 0.5 * regression_loss + 0.3 * confidence_loss

                # Backward pass
                total_batch_loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                # Store gradients before optimizer step
                gradients = {}
                gradient_norms = {}
                for name, param in self.model.named_parameters():
                    if param.grad is not None:
                        gradients[name] = param.grad.clone().detach()
                        gradient_norms[name] = param.grad.norm().item()

                # Optimizer step
                self.optimizer.step()

                # Create training step record
                step = CNNTrainingStep(
                    step_id=f"{session.session_id}_step_{batch_idx}",
                    timestamp=datetime.now(),
                    episode_id=','.join(batch['episode_id']),  # IDs of the episodes in this batch, so replay can find them
                    input_features=features.detach().cpu(),
                    target_labels=pivot_class.detach().cpu(),
                    model_outputs={k: v.detach().cpu() for k, v in outputs.items()},
                    predictions=self._extract_predictions(outputs),
                    confidence_scores=outputs['confidence'].detach().cpu(),
                    total_loss=total_batch_loss.item(),
                    pivot_prediction_loss=classification_loss.item(),
                    confidence_loss=confidence_loss.item(),
                    regularization_loss=0.0,
                    gradients=gradients,
                    gradient_norms=gradient_norms,
                    learning_rate=self.optimizer.param_groups[0]['lr'],
                    batch_size=features.size(0)
                )

                # Calculate training value for this step
                step.training_value = self._calculate_step_training_value(step, batch)

                # Add to session
                session.training_steps.append(step)

                total_loss += total_batch_loss.item()
                num_batches += 1

                # Log progress
                if batch_idx % 10 == 0:
                    logger.debug(f"Batch {batch_idx}: Loss = {total_batch_loss.item():.4f}")

            # Finalize session
            session.end_timestamp = datetime.now()
            session.total_steps = num_batches
            session.average_loss = total_loss / num_batches if num_batches > 0 else 0.0
            if session.training_steps:
                session.best_loss = min(step.total_loss for step in session.training_steps)

            # Calculate session value
            session.session_value = self._calculate_session_value(session)

            # Update scheduler
            self.scheduler.step(session.average_loss)

            # Save session and keep it in memory for replay prioritization
            self._save_training_session(session)
            self.training_sessions.append(session)

            # Update statistics
            self.training_stats['total_sessions'] += 1
            self.training_stats['total_steps'] += session.total_steps
            if training_mode == 'profitable_replay':
                self.training_stats['replay_sessions'] += 1

            logger.info(f"Training session completed: {session.session_id}")
            logger.info(f"Average loss: {session.average_loss:.4f}")
            logger.info(f"Session value: {session.session_value:.3f}")

            return {
                'status': 'success',
                'session_id': session.session_id,
                'average_loss': session.average_loss,
                'total_steps': session.total_steps,
                'session_value': session.session_value
            }

        except Exception as e:
            logger.error(f"Error in training session: {e}")
            return {'status': 'error', 'error': str(e)}
        finally:
            self.current_session = None

    def _extract_predictions(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Extract human-readable predictions from model outputs"""
        try:
            pivot_probs = F.softmax(outputs['pivot_logits'], dim=1)
            predicted_class = torch.argmax(pivot_probs, dim=1)

            # detach() before numpy(): these tensors still carry autograd history
            return {
                'pivot_class': predicted_class.detach().cpu().numpy().tolist(),
                'pivot_probabilities': pivot_probs.detach().cpu().numpy().tolist(),
                'pivot_price': outputs['pivot_price'].detach().cpu().numpy().tolist(),
                'confidence': outputs['confidence'].detach().cpu().numpy().tolist()
            }
        except Exception as e:
            logger.warning(f"Error extracting predictions: {e}")
            return {}

    def _calculate_step_training_value(self,
                                       step: CNNTrainingStep,
                                       batch: Dict[str, Any]) -> float:
        """Calculate the training value of a step for replay prioritization"""
        try:
            value = 0.0

            # Base value from loss (lower loss = higher value)
            if step.total_loss > 0:
                value += 1.0 / (1.0 + step.total_loss)

            # Bonus for high-profitability episodes in the batch
            avg_profitability = torch.mean(batch['profitability']).item()
            value += avg_profitability * 0.3

            # Bonus for gradient magnitude (indicates learning)
            if step.gradient_norms:
                avg_grad_norm = np.mean(list(step.gradient_norms.values()))
                value += min(avg_grad_norm / 10.0, 0.2)  # Cap at 0.2

            return min(value, 1.0)

        except Exception as e:
            logger.warning(f"Error calculating step training value: {e}")
            return 0.0

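    # Worked example (illustrative numbers): a step with total_loss = 0.5 from
    # a batch with mean profitability 0.6 and mean gradient norm 1.0 scores
    #   1/(1 + 0.5) + 0.3*0.6 + min(1.0/10, 0.2) = 0.667 + 0.18 + 0.10 ≈ 0.95
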
    def _calculate_session_value(self, session: CNNTrainingSession) -> float:
        """Calculate overall session value for replay prioritization"""
        try:
            if not session.training_steps:
                return 0.0

            # Average step values
            avg_step_value = np.mean([step.training_value for step in session.training_steps])

            # Bonus for convergence
            convergence_bonus = 0.0
            if len(session.training_steps) > 10:
                early_loss = np.mean([s.total_loss for s in session.training_steps[:5]])
                late_loss = np.mean([s.total_loss for s in session.training_steps[-5:]])
                if early_loss > late_loss:
                    convergence_bonus = min((early_loss - late_loss) / early_loss, 0.3)

            # Bonus for profitable replay sessions
            mode_bonus = 0.2 if session.training_mode == 'profitable_replay' else 0.0

            return min(avg_step_value + convergence_bonus + mode_bonus, 1.0)

        except Exception as e:
            logger.warning(f"Error calculating session value: {e}")
            return 0.0

    def _save_training_session(self, session: CNNTrainingSession):
        """Save training session to disk"""
        try:
            session_dir = self.storage_dir / session.symbol / 'sessions'
            session_dir.mkdir(parents=True, exist_ok=True)

            # Save full session data
            session_file = session_dir / f"{session.session_id}.pkl"
            with open(session_file, 'wb') as f:
                pickle.dump(session, f)

            # Save session metadata
            metadata = {
                'session_id': session.session_id,
                'start_timestamp': session.start_timestamp.isoformat(),
                'end_timestamp': session.end_timestamp.isoformat() if session.end_timestamp else None,
                'training_mode': session.training_mode,
                'symbol': session.symbol,
                'total_steps': session.total_steps,
                'average_loss': session.average_loss,
                'best_loss': session.best_loss,
                'session_value': session.session_value
            }

            metadata_file = session_dir / f"{session.session_id}_metadata.json"
            with open(metadata_file, 'w') as f:
                json.dump(metadata, f, indent=2)

            logger.debug(f"Saved training session: {session.session_id}")

        except Exception as e:
            logger.error(f"Error saving training session: {e}")

    def _finalize_training_session(self):
        """Finalize current training session"""
        if self.current_session:
            self.current_session.end_timestamp = datetime.now()
            self._save_training_session(self.current_session)
            self.training_sessions.append(self.current_session)
            self.current_session = None

    def get_training_statistics(self) -> Dict[str, Any]:
        """Get comprehensive training statistics"""
        stats = self.training_stats.copy()

        # Add recent session information
        if self.training_sessions:
            recent_sessions = sorted(
                self.training_sessions,
                key=lambda x: x.start_timestamp,
                reverse=True
            )[:10]

            stats['recent_sessions'] = [
                {
                    'session_id': s.session_id,
                    'timestamp': s.start_timestamp.isoformat(),
                    'mode': s.training_mode,
                    'average_loss': s.average_loss,
                    'session_value': s.session_value
                }
                for s in recent_sessions
            ]

        # Calculate profitability rate
        if stats['total_predictions'] > 0:
            stats['profitability_rate'] = stats['profitable_predictions'] / stats['total_predictions']
        else:
            stats['profitability_rate'] = 0.0

        return stats

    def replay_high_value_sessions(self,
                                   symbol: str,
                                   min_session_value: float = 0.7,
                                   max_sessions: int = 10) -> Dict[str, Any]:
        """Replay high-value training sessions"""
        try:
            # Find high-value sessions
            high_value_sessions = [
                s for s in self.training_sessions
                if (s.symbol == symbol and
                    s.session_value >= min_session_value)
            ]

            # Sort by value and limit
            high_value_sessions.sort(key=lambda x: x.session_value, reverse=True)
            high_value_sessions = high_value_sessions[:max_sessions]

            if not high_value_sessions:
                return {'status': 'no_high_value_sessions', 'sessions_found': 0}

            # Replay sessions
            total_replayed = 0
            for session in high_value_sessions:
                # Extract the episode IDs recorded on the session steps
                # (each step stores a comma-joined list of the batch's IDs)
                episode_ids = list(set(
                    episode_id
                    for step in session.training_steps
                    for episode_id in step.episode_id.split(',')
                ))

                # Get corresponding episodes
                episodes = []
                for episode_id in episode_ids:
                    # Find episode in data collector
                    for ep in self.data_collector.training_episodes.get(symbol, []):
                        if ep.episode_id == episode_id:
                            episodes.append(ep)
                            break

                if episodes:
                    self._train_on_episodes(episodes, training_mode='high_value_replay')
                    total_replayed += 1

            logger.info(f"Replayed {total_replayed} high-value sessions for {symbol}")
            return {
                'status': 'success',
                'sessions_replayed': total_replayed,
                'sessions_found': len(high_value_sessions)
            }

        except Exception as e:
            logger.error(f"Error replaying high-value sessions: {e}")
            return {'status': 'error', 'error': str(e)}

# Global instance
cnn_trainer = None

def get_cnn_trainer(model: CNNPivotPredictor = None) -> CNNTrainer:
    """Get global CNN trainer instance"""
    global cnn_trainer
    if cnn_trainer is None:
        if model is None:
            model = CNNPivotPredictor()
        cnn_trainer = CNNTrainer(model)
    return cnn_trainer
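
if __name__ == "__main__":
    # Minimal smoke test (illustrative only): pushes a random batch through
    # CNNPivotPredictor and prints output shapes. It exercises just the model;
    # the trainer itself needs a live TrainingDataCollector to do anything.
    # Run as a module so the relative import resolves:
    #     python -m core.cnn_training_pipeline
    logging.basicConfig(level=logging.INFO)

    model = CNNPivotPredictor()
    model.eval()  # disable dropout, use BatchNorm running stats
    x = torch.randn(4, model.input_channels, model.sequence_length)
    with torch.no_grad():
        out = model(x)

    print(f"pivot_logits: {tuple(out['pivot_logits'].shape)}")  # (4, 3)
    print(f"pivot_price:  {tuple(out['pivot_price'].shape)}")   # (4, 1)
    print(f"confidence:   {tuple(out['confidence'].shape)}")    # (4, 1)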