"""
Enhanced CNN Trainer with Perfect Move Learning

This trainer implements:
1. Training on marked perfect moves with known outcomes
2. Multi-timeframe CNN model training with confidence scoring
3. Backpropagation on optimal moves when future outcomes are known
4. Progressive learning from real trading experience
5. Symbol-specific and timeframe-specific model fine-tuning
"""

import json
import logging
import sys
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns

from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import PerfectMove, EnhancedTradingOrchestrator
from models import CNNModelInterface
import models

# Module-wide logger; the dataset, model, trainer and main() in this file all log through it.
logger: logging.Logger = logging.getLogger(name=__name__)


class PerfectMoveDataset(Dataset):
    """Dataset for training on perfect moves with known outcomes.

    Every usable ``PerfectMove`` becomes one sample holding: the feature matrix
    the data provider returns for the move's symbol/timeframe up to the move's
    timestamp, the optimal action encoded as a class label, and the confidence
    the model should have produced.
    """

    # Class-index encoding of ``PerfectMove.optimal_action``; any value not in
    # this map is treated as HOLD.
    ACTION_TO_LABEL = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
    DEFAULT_LABEL = 1  # HOLD

    def __init__(self, perfect_moves: List[PerfectMove], data_provider: DataProvider,
                 window_size: int = 20):
        """
        Initialize dataset from perfect moves

        Args:
            perfect_moves: List of perfect moves with known outcomes
            data_provider: Data provider to fetch additional context
            window_size: Window length passed to
                ``data_provider.get_feature_matrix`` (default 20).
        """
        self.perfect_moves = perfect_moves
        self.data_provider = data_provider
        self.window_size = window_size
        self.samples = []
        self._prepare_samples()

    def _prepare_samples(self):
        """Populate ``self.samples`` from ``self.perfect_moves``.

        Moves for which the data provider returns ``None`` are skipped. Moves
        that raise while being prepared are logged as warnings and skipped too,
        so one bad move does not abort dataset construction.
        """
        logger.info(f"Preparing {len(self.perfect_moves)} perfect move samples")

        for move in self.perfect_moves:
            try:
                # Feature matrix requested up to the time of the decision
                feature_matrix = self.data_provider.get_feature_matrix(
                    symbol=move.symbol,
                    timeframes=[move.timeframe],
                    window_size=self.window_size,
                    end_time=move.timestamp
                )
                if feature_matrix is None:
                    continue

                label = self.ACTION_TO_LABEL.get(move.optimal_action, self.DEFAULT_LABEL)

                self.samples.append({
                    'features': feature_matrix,
                    'action_label': label,
                    # What the confidence should have been, given the known outcome
                    'confidence_target': move.confidence_should_have_been,
                    'symbol': move.symbol,
                    'timeframe': move.timeframe,
                    'outcome': move.actual_outcome,
                    'timestamp': move.timestamp
                })
            except Exception as e:
                logger.warning(f"Error preparing sample for perfect move: {e}")

        logger.info(f"Prepared {len(self.samples)} valid training samples")

    def __len__(self):
        """Number of successfully prepared samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample *idx* as tensors plus untouched metadata.

        Returns:
            Dict with ``features`` (FloatTensor of the stored feature matrix),
            ``action_label`` (LongTensor of shape [1]), ``confidence_target``
            (FloatTensor of shape [1]) and ``metadata`` (symbol, timeframe,
            outcome, timestamp as stored).

        NOTE(review): ``metadata`` values keep their original Python types; if
        ``timestamp`` is a ``datetime`` or ``outcome`` can be ``None``, PyTorch's
        default collate cannot batch them, so DataLoaders over this dataset
        should use a custom ``collate_fn``.
        """
        sample = self.samples[idx]

        return {
            'features': torch.FloatTensor(sample['features']),
            'action_label': torch.LongTensor([sample['action_label']]),
            'confidence_target': torch.FloatTensor([sample['confidence_target']]),
            'metadata': {
                'symbol': sample['symbol'],
                'timeframe': sample['timeframe'],
                'outcome': sample['outcome'],
                'timestamp': sample['timestamp']
            }
        }


class EnhancedCNNModel(nn.Module, CNNModelInterface):
    """Enhanced CNN model with timeframe-specific predictions and confidence scoring.

    The architecture has two parts:

    - A shared Conv1d trunk: three conv/batch-norm/ReLU/dropout blocks, then
      global average pooling.
    - One head stack per configured timeframe, fed by the trunk. Each stack
      yields softmax probabilities over 3 actions and a sigmoid confidence in
      [0, 1].
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Args:
            config: Model configuration. The following keys are read:

                - ``timeframes``: default ``['1h', '4h', '1d']``.
                - ``features``: only its length is used, as the number of input
                  channels.
                - ``window_size``: default 20; stored but not used by the
                  layers, which pool globally.
                - ``learning_rate``: default 0.001, used by the Adam optimizer.
        """
        nn.Module.__init__(self)
        CNNModelInterface.__init__(self, config)

        self.timeframes = config.get('timeframes', ['1h', '4h', '1d'])
        self.n_features = len(config.get('features', ['open', 'high', 'low', 'close', 'volume']))
        self.window_size = config.get('window_size', 20)

        # Build the neural network
        self._build_network()

        # Initialize device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

        # Training components
        self.optimizer = optim.Adam(self.parameters(), lr=config.get('learning_rate', 0.001))
        self.action_criterion = nn.CrossEntropyLoss()
        self.confidence_criterion = nn.MSELoss()

        logger.info(f"Enhanced CNN model initialized for {len(self.timeframes)} timeframes")

    def _build_network(self):
        """Build the convolutional trunk and the per-timeframe heads."""
        # Convolutional feature extraction
        self.conv_layers = nn.Sequential(
            # First conv block
            nn.Conv1d(self.n_features, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Second conv block
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Third conv block
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Global average pooling
            nn.AdaptiveAvgPool1d(1)
        )

        # Timeframe-specific heads
        self.timeframe_heads = nn.ModuleDict()
        for timeframe in self.timeframes:
            self.timeframe_heads[timeframe] = nn.Sequential(
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Dropout(0.3)
            )

        # Action prediction heads (one per timeframe). 3 action classes;
        # PerfectMoveDataset labels them SELL=0, HOLD=1, BUY=2.
        self.action_heads = nn.ModuleDict()
        for timeframe in self.timeframes:
            self.action_heads[timeframe] = nn.Linear(64, 3)

        # Confidence prediction heads (one per timeframe)
        self.confidence_heads = nn.ModuleDict()
        for timeframe in self.timeframes:
            self.confidence_heads[timeframe] = nn.Sequential(
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, 1),
                nn.Sigmoid()  # Output between 0 and 1
            )

    def _timeframe_outputs(self, features, timeframe: str):
        """Run one timeframe's heads on trunk features [batch, 256].

        Returns:
            (action_probs [batch, 3], confidence [batch])
        """
        tf_features = self.timeframe_heads[timeframe](features)
        action_logits = self.action_heads[timeframe](tf_features)
        confidence = self.confidence_heads[timeframe](tf_features)
        return torch.softmax(action_logits, dim=1), confidence.squeeze(-1)

    def forward(self, x, timeframe: Optional[str] = None):
        """
        Forward pass through the network

        Args:
            x: Input tensor [batch_size, window_size, features]
            timeframe: Specific timeframe to predict for. If falsy or not a
                configured timeframe, predictions of all timeframes are averaged.

        Returns:
            action_probs: Action probabilities [batch_size, 3] (softmax, not logits)
            confidence: Confidence score [batch_size]
        """
        # Reshape for conv1d: [batch, features, sequence]
        x = x.transpose(1, 2)

        # Extract features: [batch, 256, 1] -> [batch, 256]
        features = self.conv_layers(x).squeeze(-1)

        if timeframe and timeframe in self.timeframe_heads:
            return self._timeframe_outputs(features, timeframe)

        # Multi-timeframe prediction: average probabilities and confidences
        per_timeframe = [self._timeframe_outputs(features, tf) for tf in self.timeframes]
        avg_action_probs = torch.stack([probs for probs, _ in per_timeframe]).mean(dim=0)
        avg_confidence = torch.stack([conf for _, conf in per_timeframe]).mean(dim=0)
        return avg_action_probs, avg_confidence

    def _predict_single(self, features: np.ndarray, timeframe: Optional[str]) -> Tuple[np.ndarray, float]:
        """Inference helper: eval mode, no grad, returns the first sample's outputs."""
        self.eval()
        with torch.no_grad():
            x = torch.FloatTensor(features).to(self.device)
            if len(x.shape) == 2:
                x = x.unsqueeze(0)  # Add batch dimension

            action_probs, confidence = self.forward(x, timeframe)

            return action_probs[0].cpu().numpy(), confidence[0].cpu().item()

    def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]:
        """Predict action probabilities and confidence averaged over all timeframes."""
        return self._predict_single(features, None)

    def predict_timeframe(self, features: np.ndarray, timeframe: str) -> Tuple[np.ndarray, float]:
        """Predict for a specific timeframe (falls back to the average if unknown)."""
        return self._predict_single(features, timeframe)

    def get_memory_usage(self) -> int:
        """Get memory usage in MB.

        On CUDA this is the process-wide tensor allocation on ``self.device``
        (``torch.cuda.memory_allocated``), not just this model. On CPU it is a
        rough estimate from the parameter count at 4 bytes per value.
        """
        if torch.cuda.is_available():
            return torch.cuda.memory_allocated(self.device) // (1024 * 1024)
        param_count = sum(p.numel() for p in self.parameters())
        return (param_count * 4) // (1024 * 1024)  # 4 bytes per float32

    def train(self, training_data: Any = True) -> Any:
        """Serve both ``nn.Module.train(mode)`` and the model interface's ``train``.

        ``nn.Module.eval()`` and ``nn.Module.train()`` both dispatch through
        this method name. A plain interface-style override would therefore
        break mode switching: ``model.train()`` would raise ``TypeError``, and
        ``model.eval()`` would leave dropout and batch-norm in training mode.

        Args:
            training_data: There are two cases:

                - A ``bool``, or omitted: set the training mode as
                  ``nn.Module.train`` does, and return ``self``.
                - Anything else (interface training data): placeholder for
                  interface compatibility, which returns ``{}``.
        """
        if isinstance(training_data, bool):
            return nn.Module.train(self, training_data)
        return {}


class EnhancedCNNTrainer:
    """Enhanced CNN trainer using perfect moves and real market outcomes.

    The trainer does the following:

    - Pulls perfect moves for every configured symbol from the orchestrator.
    - Trains an ``EnhancedCNNModel`` on them with early stopping on
      validation loss.
    - Checkpoints to ``save_dir``.
    - Produces a metrics report plus training plots.
    """

    def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
        """Initialize the enhanced trainer.

        Args:
            config: Project config; ``get_config()`` is used when falsy. Its
                ``training``, ``cnn`` and ``symbols`` attributes are read.
            orchestrator: Source of perfect moves; required by
                ``train_on_perfect_moves``.
        """
        self.config = config or get_config()
        self.orchestrator = orchestrator
        self.data_provider = DataProvider(self.config)

        # Training parameters
        training_cfg = self.config.training
        self.learning_rate = training_cfg.get('learning_rate', 0.001)
        self.batch_size = training_cfg.get('batch_size', 32)
        self.epochs = training_cfg.get('epochs', 100)
        self.patience = training_cfg.get('early_stopping_patience', 10)

        # Model
        self.model = EnhancedCNNModel(self.config.cnn)

        # The model builds its optimizer from the CNN config's learning rate.
        # When the training config sets one explicitly (main() writes
        # --learning-rate there), apply it so that setting actually takes effect.
        if training_cfg.get('learning_rate') is not None:
            for param_group in self.model.optimizer.param_groups:
                param_group['lr'] = self.learning_rate

        # Training history
        self.training_history = self._empty_history()

        # Create save directory
        models_path = self.config.cnn.get('model_dir', "models/enhanced_cnn")
        self.save_dir = Path(models_path)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        logger.info("Enhanced CNN trainer initialized")

    @staticmethod
    def _empty_history() -> Dict[str, List[float]]:
        """Fresh per-epoch metric history with every key the trainer appends to."""
        return {
            'train_loss': [],
            'val_loss': [],
            'train_accuracy': [],
            'val_accuracy': [],
            'confidence_accuracy': []
        }

    @staticmethod
    def _collate_batch(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate ``PerfectMoveDataset`` items into one batch.

        Tensors are stacked exactly as the default collate would. The
        per-sample ``metadata`` dicts are kept as a plain list, because the
        default collate raises ``TypeError`` on values it cannot batch (such as
        ``datetime`` timestamps or ``None`` outcomes).
        """
        return {
            'features': torch.stack([s['features'] for s in samples]),
            'action_label': torch.stack([s['action_label'] for s in samples]),
            'confidence_target': torch.stack([s['confidence_target'] for s in samples]),
            'metadata': [s['metadata'] for s in samples],
        }

    @staticmethod
    def _action_loss(action_probs: torch.Tensor, action_labels: torch.Tensor) -> torch.Tensor:
        """Negative log-likelihood of *action_labels* under *action_probs*.

        The model returns softmax probabilities (averaged over timeframes), not
        logits. ``CrossEntropyLoss`` would apply a second softmax on top of
        them and flatten the gradients, so take the log and use NLL instead.
        """
        log_probs = torch.log(action_probs.clamp_min(1e-8))
        return torch.nn.functional.nll_loss(log_probs, action_labels)

    def _set_training_mode(self, mode: bool) -> None:
        """Switch dropout/batch-norm mode through ``nn.Module.train`` directly.

        ``EnhancedCNNModel`` defines ``train(training_data)`` for its model
        interface, which shadows ``nn.Module.train``. Calling the base method
        explicitly keeps mode switching correct whatever that override does.
        """
        nn.Module.train(self.model, mode)

    def train_on_perfect_moves(self, min_samples: int = 100) -> Dict[str, Any]:
        """Train the model on perfect moves from the orchestrator.

        Args:
            min_samples: Minimum number of perfect moves required to train.

        Returns:
            The training report from ``_generate_training_report``. If there
            is not enough data, returns
            ``{'error': 'insufficient_data', 'samples': n}`` instead.

        Raises:
            ValueError: If no orchestrator was provided.
        """
        if not self.orchestrator:
            raise ValueError("Orchestrator required for perfect move training")

        # Get perfect moves from orchestrator
        perfect_moves = []
        for symbol in self.config.symbols:
            perfect_moves.extend(self.orchestrator.get_perfect_moves_for_training(symbol=symbol))

        if len(perfect_moves) < min_samples:
            logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)} < {min_samples}")
            return {'error': 'insufficient_data', 'samples': len(perfect_moves)}

        logger.info(f"Training on {len(perfect_moves)} perfect moves")

        dataset = PerfectMoveDataset(perfect_moves, self.data_provider)

        # Moves can be dropped while building samples. At least one sample is
        # needed on each side of the split, otherwise epoch averages divide by 0.
        if len(dataset) < 2:
            logger.warning(f"Not enough valid training samples after preparation: {len(dataset)}")
            return {'error': 'insufficient_data', 'samples': len(dataset)}

        # Split into train/validation (int(0.8 * n) is in [1, n - 1] for n >= 2)
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True,
                                  collate_fn=self._collate_batch)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False,
                                collate_fn=self._collate_batch)

        # Training loop with early stopping on validation loss
        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(self.epochs):
            train_loss, train_acc = self._train_epoch(train_loader)
            val_loss, val_acc, conf_acc = self._validate_epoch(val_loader)

            self.training_history['train_loss'].append(train_loss)
            self.training_history['val_loss'].append(val_loss)
            self.training_history['train_accuracy'].append(train_acc)
            self.training_history['val_accuracy'].append(val_acc)
            self.training_history['confidence_accuracy'].append(conf_acc)

            logger.info(f"Epoch {epoch+1}/{self.epochs}: "
                        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
                        f"Conf Acc: {conf_acc:.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self._save_model('best_model.pt')
            else:
                patience_counter += 1
                if patience_counter >= self.patience:
                    logger.info(f"Early stopping at epoch {epoch+1}")
                    break

        self._save_model('final_model.pt')

        return self._generate_training_report()

    def _train_epoch(self, train_loader: DataLoader) -> Tuple[float, float]:
        """Train for one epoch.

        Returns:
            (mean batch loss, action accuracy); both NaN for an empty loader.
        """
        self._set_training_mode(True)
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        for batch in train_loader:
            features = batch['features'].to(self.model.device)
            action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
            confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)

            self.model.optimizer.zero_grad()

            action_probs, confidence_pred = self.model(features)

            action_loss = self._action_loss(action_probs, action_labels)
            confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)
            # Combined loss; confidence is an auxiliary objective
            total_loss_batch = action_loss + 0.5 * confidence_loss

            total_loss_batch.backward()
            self.model.optimizer.step()

            total_loss += total_loss_batch.item()
            predicted_actions = torch.argmax(action_probs, dim=1)
            correct_predictions += (predicted_actions == action_labels).sum().item()
            total_predictions += action_labels.size(0)

        num_batches = len(train_loader)
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        accuracy = correct_predictions / total_predictions if total_predictions else float('nan')

        return avg_loss, accuracy

    def _validate_epoch(self, val_loader: DataLoader) -> Tuple[float, float, float]:
        """Validate for one epoch.

        Returns:
            The tuple (mean batch loss, action accuracy, confidence accuracy).
            Confidence accuracy is 1 minus the mean absolute confidence error.
            All three are plain floats (NaN for an empty loader), so they are
            checkpoint- and JSON-friendly.
        """
        self._set_training_mode(False)
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        confidence_errors = []

        with torch.no_grad():
            for batch in val_loader:
                features = batch['features'].to(self.model.device)
                action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
                confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)

                action_probs, confidence_pred = self.model(features)

                action_loss = self._action_loss(action_probs, action_labels)
                confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)
                total_loss_batch = action_loss + 0.5 * confidence_loss

                total_loss += total_loss_batch.item()
                predicted_actions = torch.argmax(action_probs, dim=1)
                correct_predictions += (predicted_actions == action_labels).sum().item()
                total_predictions += action_labels.size(0)

                conf_errors = torch.abs(confidence_pred - confidence_targets)
                confidence_errors.extend(conf_errors.cpu().numpy())

        num_batches = len(val_loader)
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        accuracy = correct_predictions / total_predictions if total_predictions else float('nan')
        # float() avoids storing numpy scalars in the history, which
        # torch.load(weights_only=True) refuses to unpickle.
        confidence_accuracy = (1.0 - float(np.mean(confidence_errors))) if confidence_errors else float('nan')

        return avg_loss, accuracy, confidence_accuracy

    def _save_model(self, filename: str):
        """Save model/optimizer state, CNN config and history to ``save_dir/filename``."""
        save_path = self.save_dir / filename
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.model.optimizer.state_dict(),
            'config': self.config.cnn,
            'training_history': self.training_history
        }, save_path)
        logger.info(f"Model saved to {save_path}")

    def load_model(self, filename: str) -> bool:
        """Load a checkpoint written by ``_save_model``.

        Returns:
            True on success; False if the file is missing or loading fails
            (the error is logged).
        """
        load_path = self.save_dir / filename
        if not load_path.exists():
            logger.error(f"Model file not found: {load_path}")
            return False

        try:
            checkpoint = torch.load(load_path, map_location=self.model.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # Start from a full history so later appends never hit a missing key
            history = self._empty_history()
            history.update(checkpoint.get('training_history') or {})
            self.training_history = history
            logger.info(f"Model loaded from {load_path}")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False

    def _generate_training_report(self) -> Dict[str, Any]:
        """Summarize final/best metrics and model info, and write training plots."""
        history = self.training_history
        if not history['train_loss']:
            return {'error': 'no_training_data'}

        final_val_acc = history['val_accuracy'][-1]
        final_conf_acc = history['confidence_accuracy'][-1]

        report = {
            'training_completed': True,
            'epochs_trained': len(history['train_loss']),
            'final_metrics': {
                'train_loss': history['train_loss'][-1],
                'val_loss': history['val_loss'][-1],
                'train_accuracy': history['train_accuracy'][-1],
                'val_accuracy': final_val_acc,
                'confidence_accuracy': final_conf_acc
            },
            'best_metrics': {
                'val_loss': min(history['val_loss']),
                'val_accuracy': max(history['val_accuracy']),
                'confidence_accuracy': max(history['confidence_accuracy'])
            },
            'model_info': {
                'timeframes': self.model.timeframes,
                'memory_usage_mb': self.model.get_memory_usage(),
                'device': str(self.model.device)
            }
        }

        try:
            self._plot_training_history()
        except Exception as e:
            # Plots are a best-effort artifact; don't lose the report over them.
            logger.warning(f"Could not generate training plots: {e}")

        logger.info("Training completed successfully")
        logger.info(f"Final validation accuracy: {final_val_acc:.4f}")
        logger.info(f"Final confidence accuracy: {final_conf_acc:.4f}")

        return report

    def _plot_training_history(self):
        """Write a 2x2 figure of the training history to ``save_dir/training_history.png``."""
        history = self.training_history
        plot_path = self.save_dir / 'training_history.png'

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        try:
            fig.suptitle('Enhanced CNN Training History')

            # Loss plot
            axes[0, 0].plot(history['train_loss'], label='Train Loss')
            axes[0, 0].plot(history['val_loss'], label='Val Loss')
            axes[0, 0].set_title('Loss')
            axes[0, 0].set_xlabel('Epoch')
            axes[0, 0].set_ylabel('Loss')
            axes[0, 0].legend()

            # Accuracy plot
            axes[0, 1].plot(history['train_accuracy'], label='Train Accuracy')
            axes[0, 1].plot(history['val_accuracy'], label='Val Accuracy')
            axes[0, 1].set_title('Action Accuracy')
            axes[0, 1].set_xlabel('Epoch')
            axes[0, 1].set_ylabel('Accuracy')
            axes[0, 1].legend()

            # Confidence accuracy plot
            axes[1, 0].plot(history['confidence_accuracy'], label='Confidence Accuracy')
            axes[1, 0].set_title('Confidence Prediction Accuracy')
            axes[1, 0].set_xlabel('Epoch')
            axes[1, 0].set_ylabel('Accuracy')
            axes[1, 0].legend()

            # Learning curves comparison
            axes[1, 1].plot(history['val_loss'], label='Validation Loss')
            axes[1, 1].plot(history['confidence_accuracy'], label='Confidence Accuracy')
            axes[1, 1].set_title('Model Performance Overview')
            axes[1, 1].set_xlabel('Epoch')
            axes[1, 1].legend()

            fig.tight_layout()
            fig.savefig(plot_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting/saving fails
            plt.close(fig)

        logger.info(f"Training plots saved to {plot_path}")

    def get_model(self) -> EnhancedCNNModel:
        """Get the trained model"""
        return self.model

    def __del__(self):
        """Release an attached TensorBoard writer, if any.

        The call is guarded because this class itself does not define
        ``close_tensorboard``. Calling it unconditionally raised
        ``AttributeError`` during garbage collection.
        """
        close_tensorboard = getattr(self, 'close_tensorboard', None)
        if callable(close_tensorboard):
            close_tensorboard()


def _fmt(value: Any, spec: str = '.4f') -> str:
    """Format a metric for logging.

    ``None`` renders as ``'n/a'``, and values that reject *spec* render via
    ``str()``, so a missing metric never aborts the run.
    """
    if value is None:
        return 'n/a'
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


def _optional_method(obj: Any, name: str):
    """Return ``getattr(obj, name)`` if it is callable; otherwise warn and return None.

    main() drives several pipeline phases through trainer methods that
    ``EnhancedCNNTrainer`` in this module does not define (for example
    ``evaluate_model``). Those phases are skipped rather than aborting the
    whole run with ``AttributeError``.
    """
    method = getattr(obj, name, None)
    if callable(method):
        return method
    logger.warning(f"   Skipping: {type(obj).__name__}.{name}() is not available")
    return None


def _parse_args():
    """Define and parse the command-line interface of ``main``."""
    import argparse

    parser = argparse.ArgumentParser(description='Enhanced CNN Live Training with Backtesting and Analysis')
    parser.add_argument('--symbols', type=str, nargs='+', default=['ETH/USDT', 'BTC/USDT'],
                        help='Trading symbols to train on')
    parser.add_argument('--timeframes', type=str, nargs='+', default=['1m', '5m', '15m', '1h'],
                        help='Timeframes to use for training')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Training batch size')
    parser.add_argument('--learning-rate', type=float, default=0.001,
                        help='Learning rate')
    parser.add_argument('--save-path', type=str, default='models/enhanced_cnn/live_trained_model.pt',
                        help='Path to save the trained model')

    # The --enable-* switches default to True, so store_true alone could never
    # turn them off; each --no-* switch writes False to the same destination.
    parser.add_argument('--enable-backtesting', action='store_true', default=True,
                        help='Enable backtesting after training')
    parser.add_argument('--no-backtesting', dest='enable_backtesting', action='store_false',
                        help='Disable backtesting after training')
    parser.add_argument('--enable-analysis', action='store_true', default=True,
                        help='Enable detailed analysis and reporting')
    parser.add_argument('--no-analysis', dest='enable_analysis', action='store_false',
                        help='Disable detailed analysis and reporting')
    parser.add_argument('--enable-live-validation', action='store_true', default=True,
                        help='Enable live validation during training')
    parser.add_argument('--no-live-validation', dest='enable_live_validation', action='store_false',
                        help='Disable live validation')

    return parser.parse_args()


def _run_backtest(args, data_provider, model) -> Optional[Dict[str, Any]]:
    """Backtest *model* on the last 7 days and log the results.

    The backtest is an optional phase that runs before the model is saved. A
    failure is therefore logged and ``None`` is returned, so the trained model
    is not lost.
    """
    try:
        from trading.backtest_environment import BacktestEnvironment

        backtest_env = BacktestEnvironment(
            symbols=args.symbols,
            timeframes=args.timeframes,
            initial_balance=10000.0,
            data_provider=data_provider
        )
        results = backtest_env.run_backtest_with_model(
            model=model,
            lookback_days=7,  # Test on last 7 days
            max_trades_per_day=50
        )
    except Exception:
        logger.exception("   Backtesting failed; continuing without backtest results")
        return None

    logger.info("Backtesting Results:")
    logger.info(f"   Total Returns: {_fmt(results.get('total_return'), '.2f')}%")
    logger.info(f"   Win Rate: {_fmt(results.get('win_rate'), '.2f')}%")
    logger.info(f"   Sharpe Ratio: {_fmt(results.get('sharpe_ratio'))}")
    logger.info(f"   Max Drawdown: {_fmt(results.get('max_drawdown'), '.2f')}%")
    logger.info(f"   Total Trades: {results.get('total_trades', 'n/a')}")
    logger.info(f"   Profit Factor: {_fmt(results.get('profit_factor'))}")
    return results


def main():
    """Main function for standalone CNN live training with backtesting and analysis.

    Returns:
        Process exit code: 0 on success, 1 on failure, insufficient training
        data, or user interrupt.
    """
    import sys
    import time

    # Add project root to path
    project_root = Path(__file__).parent.parent
    sys.path.insert(0, str(project_root))

    args = _parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    logger.info("="*80)
    logger.info("🧠 ENHANCED CNN LIVE TRAINING WITH BACKTESTING & ANALYSIS")
    logger.info("="*80)
    logger.info(f"Symbols: {args.symbols}")
    logger.info(f"Timeframes: {args.timeframes}")
    logger.info(f"Epochs: {args.epochs}")
    logger.info(f"Batch Size: {args.batch_size}")
    logger.info(f"Learning Rate: {args.learning_rate}")
    logger.info(f"Save Path: {args.save_path}")
    logger.info(f"Backtesting: {'Enabled' if args.enable_backtesting else 'Disabled'}")
    logger.info(f"Analysis: {'Enabled' if args.enable_analysis else 'Disabled'}")
    logger.info(f"Live Validation: {'Enabled' if args.enable_live_validation else 'Disabled'}")
    logger.info("="*80)

    try:
        # Update config with command line arguments
        config = get_config()
        config.update('symbols', args.symbols)
        config.update('timeframes', args.timeframes)
        config.update('training', {
            **config.training,
            'epochs': args.epochs,
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate
        })

        data_provider = DataProvider(config)
        orchestrator = EnhancedTradingOrchestrator(data_provider)
        trainer = EnhancedCNNTrainer(config, orchestrator)

        # Phase 1: Data Collection and Preparation
        logger.info("📊 Phase 1: Collecting and preparing training data...")
        collect = _optional_method(trainer, 'collect_training_data')
        if collect is not None:
            training_data = collect(args.symbols, lookback_days=30)
            logger.info(f"   Collected {len(training_data)} training samples")

        # Phase 2: Model Training
        logger.info("🧠 Phase 2: Training Enhanced CNN Model...")
        start_time = time.perf_counter()
        training_results = trainer.train_on_perfect_moves(min_samples=1000)
        training_time = time.perf_counter() - start_time

        if 'error' in training_results:
            logger.error(f"Training did not run: {training_results}")
            return 1
        training_results.setdefault('total_time', training_time)

        # The trainer report nests best metrics under 'best_metrics'
        best_metrics = training_results.get('best_metrics', {})
        logger.info("Training Results:")
        logger.info(f"   Best Validation Accuracy: {_fmt(best_metrics.get('val_accuracy'))}")
        logger.info(f"   Best Validation Loss: {_fmt(best_metrics.get('val_loss'))}")
        logger.info(f"   Total Epochs: {training_results.get('epochs_trained', 'n/a')}")
        logger.info(f"   Training Time: {training_time:.2f}s")

        # Phase 3: Model Evaluation
        logger.info("📈 Phase 3: Model Evaluation...")
        evaluate = _optional_method(trainer, 'evaluate_model')
        evaluation_results = evaluate(args.symbols[:1]) if evaluate else None  # First symbol only
        if evaluation_results is not None:
            logger.info("Evaluation Results:")
            logger.info(f"   Test Accuracy: {_fmt(evaluation_results.get('test_accuracy'))}")
            logger.info(f"   Test Loss: {_fmt(evaluation_results.get('test_loss'))}")
            logger.info(f"   Confidence Score: {_fmt(evaluation_results.get('avg_confidence'))}")

        # Phase 4: Backtesting (if enabled)
        backtest_results = None
        if args.enable_backtesting:
            logger.info("📊 Phase 4: Backtesting...")
            backtest_results = _run_backtest(args, data_provider, trainer.model)

        # Phase 5: Analysis and Reporting (if enabled)
        report_path = None
        plots_dir = None
        if args.enable_analysis:
            logger.info("📋 Phase 5: Analysis and Reporting...")
            output_dir = Path(args.save_path).parent

            build_report = _optional_method(trainer, 'generate_analysis_report')
            if build_report is not None:
                analysis_report = build_report(
                    training_results=training_results,
                    evaluation_results=evaluation_results,
                    backtest_results=backtest_results
                )
                report_path = output_dir / "analysis_report.json"
                report_path.parent.mkdir(parents=True, exist_ok=True)
                with open(report_path, 'w') as f:
                    json.dump(analysis_report, f, indent=2, default=str)
                logger.info(f"   Analysis report saved: {report_path}")

            make_plots = _optional_method(trainer, 'generate_performance_plots')
            if make_plots is not None:
                plots_dir = output_dir / "plots"
                plots_dir.mkdir(parents=True, exist_ok=True)
                make_plots(
                    training_results=training_results,
                    evaluation_results=evaluation_results,
                    save_dir=plots_dir
                )
                logger.info(f"   Performance plots saved: {plots_dir}")

        # Phase 6: Model Saving
        logger.info("💾 Phase 6: Saving trained model...")
        model_path = Path(args.save_path)
        model_path.parent.mkdir(parents=True, exist_ok=True)

        save_model = getattr(trainer.model, 'save', None)
        if callable(save_model):
            save_model(str(model_path))
        else:
            # save() is not defined on EnhancedCNNModel itself; fall back to its state dict.
            torch.save(trainer.model.state_dict(), model_path)
        logger.info(f"   Model saved: {model_path}")

        metadata = {
            'training_config': {
                'symbols': args.symbols,
                'timeframes': args.timeframes,
                'epochs': args.epochs,
                'batch_size': args.batch_size,
                'learning_rate': args.learning_rate
            },
            'training_results': training_results,
            'evaluation_results': evaluation_results
        }
        if backtest_results is not None:
            metadata['backtest_results'] = backtest_results

        metadata_path = model_path.with_suffix('.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2, default=str)
        logger.info(f"   Training metadata saved: {metadata_path}")

        # Phase 7: Live Validation (if enabled)
        if args.enable_live_validation:
            logger.info("🔄 Phase 7: Live Validation...")
            validate_live = _optional_method(trainer, 'run_live_validation')
            if validate_live is not None:
                live_validation_results = validate_live(
                    symbols=args.symbols[:1],  # Use first symbol
                    validation_hours=2  # Validate on last 2 hours
                )
                logger.info("Live Validation Results:")
                logger.info(f"   Prediction Accuracy: {_fmt(live_validation_results.get('accuracy'), '.2f')}%")
                logger.info(f"   Average Confidence: {_fmt(live_validation_results.get('avg_confidence'))}")
                logger.info(f"   Predictions Made: {live_validation_results.get('total_predictions', 'n/a')}")

        logger.info("="*80)
        logger.info("🎉 ENHANCED CNN LIVE TRAINING COMPLETED SUCCESSFULLY!")
        logger.info("="*80)
        logger.info(f"📊 Model Path: {model_path}")
        logger.info(f"📋 Metadata: {metadata_path}")
        if report_path is not None:
            logger.info(f"📈 Analysis Report: {report_path}")
        if plots_dir is not None:
            logger.info(f"📊 Performance Plots: {plots_dir}")
        logger.info("="*80)

    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
        return 1
    except Exception as e:
        logger.error(f"Training failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return 1

    return 0


if __name__ == "__main__":
    # Propagate main()'s return code (0 success, 1 failure) as the exit status.
    # Relies on the module-level `import sys`; main() only imports sys locally.
    sys.exit(main())