big cleanup
@@ -1,219 +0,0 @@
"""
|
||||
CNN-RL Bridge Module
|
||||
|
||||
This module provides the interface between CNN models and RL training,
|
||||
extracting hidden features and predictions from CNN models for use in RL state building.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CNNRLBridge:
|
||||
"""Bridge between CNN models and RL training for feature extraction"""
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
"""Initialize CNN-RL bridge"""
|
||||
self.config = config
|
||||
self.cnn_models = {}
|
||||
self.feature_cache = {}
|
||||
self.cache_timeout = 60 # Cache features for 60 seconds
|
||||
|
||||
# Initialize CNN model registry if available
|
||||
self._initialize_cnn_models()
|
||||
|
||||
logger.info("CNN-RL Bridge initialized")
|
||||
|
||||
    def _initialize_cnn_models(self):
        """Initialize CNN models from config or model registry"""
        try:
            # Try to load CNN models from config
            if hasattr(self.config, 'cnn_models') and self.config.cnn_models:
                for model_name, model_config in self.config.cnn_models.items():
                    try:
                        # Load CNN model (implementation depends on your CNN architecture)
                        model = self._load_cnn_model(model_name, model_config)
                        if model:
                            self.cnn_models[model_name] = model
                            logger.info(f"Loaded CNN model: {model_name}")
                    except Exception as e:
                        logger.warning(f"Failed to load CNN model {model_name}: {e}")

            if not self.cnn_models:
                logger.info("No CNN models available - RL will train without CNN features")

        except Exception as e:
            logger.warning(f"Error initializing CNN models: {e}")

    def _load_cnn_model(self, model_name: str, model_config: Dict) -> Optional[nn.Module]:
        """Load a CNN model from configuration"""
        try:
            # This would implement actual CNN model loading.
            # For now, return None to indicate no models available.
            # In your implementation, this would load your specific CNN architecture.

            logger.info(f"CNN model loading framework ready for {model_name}")
            return None

        except Exception as e:
            logger.error(f"Error loading CNN model {model_name}: {e}")
            return None
    def get_latest_features_for_symbol(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Get latest CNN features and predictions for a symbol"""
        try:
            # Check cache first
            cache_key = f"{symbol}_{datetime.now().strftime('%Y%m%d_%H%M')}"
            if cache_key in self.feature_cache:
                cached_data = self.feature_cache[cache_key]
                # total_seconds() avoids the day-wrapping behavior of timedelta.seconds
                if (datetime.now() - cached_data['timestamp']).total_seconds() < self.cache_timeout:
                    return cached_data['features']

            # Generate new features if models available
            if self.cnn_models:
                features = self._extract_cnn_features_for_symbol(symbol)

                # Cache the features
                self.feature_cache[cache_key] = {
                    'timestamp': datetime.now(),
                    'features': features
                }

                # Clean old cache entries
                self._cleanup_cache()

                return features

            return None

        except Exception as e:
            logger.warning(f"Error getting CNN features for {symbol}: {e}")
            return None
    def _extract_cnn_features_for_symbol(self, symbol: str) -> Dict[str, Any]:
        """Extract CNN hidden features and predictions for a symbol"""
        try:
            extracted_features = {
                'hidden_features': {},
                'predictions': {}
            }

            for model_name, model in self.cnn_models.items():
                try:
                    # Extract features from each CNN model
                    hidden_features, predictions = self._extract_model_features(model, symbol)

                    if hidden_features is not None:
                        extracted_features['hidden_features'][model_name] = hidden_features

                    if predictions is not None:
                        extracted_features['predictions'][model_name] = predictions

                except Exception as e:
                    logger.warning(f"Error extracting features from {model_name}: {e}")

            return extracted_features

        except Exception as e:
            logger.error(f"Error extracting CNN features for {symbol}: {e}")
            return {'hidden_features': {}, 'predictions': {}}

    def _extract_model_features(self, model: nn.Module, symbol: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Extract hidden features and predictions from a specific CNN model"""
        try:
            # This would implement the actual feature extraction from your CNN models.
            # The implementation depends on your specific CNN architecture.

            # For now, return mock data to show the structure.
            # In a real implementation, this would:
            # 1. Get market data for the model
            # 2. Run a forward pass through the CNN
            # 3. Extract hidden layer activations
            # 4. Get model predictions

            # Mock hidden features (last hidden layer of the CNN)
            hidden_features = np.random.random(512).astype(np.float32)

            # Mock predictions for different timeframes:
            # [1s_pred, 1m_pred, 1h_pred, 1d_pred]
            predictions = np.array([
                0.45,  # 1s prediction (probability of up move)
                0.52,  # 1m prediction
                0.38,  # 1h prediction
                0.61   # 1d prediction
            ]).astype(np.float32)

            logger.debug(f"Extracted CNN features for {symbol}: {len(hidden_features)} hidden, {len(predictions)} predictions")

            return hidden_features, predictions

        except Exception as e:
            logger.warning(f"Error extracting features from model: {e}")
            return None, None
    def _cleanup_cache(self):
        """Clean up old cache entries"""
        try:
            current_time = datetime.now()
            expired_keys = []

            for key, data in self.feature_cache.items():
                # total_seconds() avoids the day-wrapping behavior of timedelta.seconds
                if (current_time - data['timestamp']).total_seconds() > self.cache_timeout * 2:
                    expired_keys.append(key)

            for key in expired_keys:
                del self.feature_cache[key]

        except Exception as e:
            logger.warning(f"Error cleaning up feature cache: {e}")
    def register_cnn_model(self, model_name: str, model: nn.Module):
        """Register a CNN model for feature extraction"""
        try:
            self.cnn_models[model_name] = model
            logger.info(f"Registered CNN model: {model_name}")
        except Exception as e:
            logger.error(f"Error registering CNN model {model_name}: {e}")

    def unregister_cnn_model(self, model_name: str):
        """Unregister a CNN model"""
        try:
            if model_name in self.cnn_models:
                del self.cnn_models[model_name]
                logger.info(f"Unregistered CNN model: {model_name}")
        except Exception as e:
            logger.error(f"Error unregistering CNN model {model_name}: {e}")

    def get_available_models(self) -> List[str]:
        """Get list of available CNN models"""
        return list(self.cnn_models.keys())

    def is_model_available(self, model_name: str) -> bool:
        """Check if a specific CNN model is available"""
        return model_name in self.cnn_models

    def get_feature_dimensions(self) -> Dict[str, int]:
        """Get the dimensions of features extracted from CNN models"""
        return {
            'hidden_features_per_model': 512,
            'predictions_per_model': 4,  # 1s, 1m, 1h, 1d
            'total_models': len(self.cnn_models)
        }

    def validate_cnn_integration(self) -> Dict[str, Any]:
        """Validate CNN integration status"""
        status = {
            'models_available': len(self.cnn_models),
            'models_list': list(self.cnn_models.keys()),
            'cache_entries': len(self.feature_cache),
            'integration_ready': len(self.cnn_models) > 0,
            'expected_feature_size': len(self.cnn_models) * 512,  # hidden features
            'expected_prediction_size': len(self.cnn_models) * 4  # predictions
        }

        return status
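
# --- Usage sketch (illustrative, not part of the original commit) ---
# A minimal example of driving the bridge: register a model, pull features for
# a symbol, and inspect integration status. `TinyCNN` is a hypothetical
# stand-in; real deployments would register an actual trained CNN.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    class TinyCNN(nn.Module):
        def forward(self, x):
            return x

    bridge = CNNRLBridge(config={})  # plain dict config: no cnn_models section
    bridge.register_cnn_model('tiny', TinyCNN())
    print(bridge.get_available_models())                      # ['tiny']
    print(bridge.get_latest_features_for_symbol('ETH/USDT'))  # mock feature dict
    print(bridge.validate_cnn_integration())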
@@ -1,491 +0,0 @@
"""
CNN Training Pipeline

This module handles training of the CNN model using ONLY real market data.
All training metrics are logged to TensorBoard for real-time monitoring.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import pandas as pd
import logging
from typing import Dict, List, Tuple, Optional
from pathlib import Path
import time
from sklearn.metrics import classification_report, confusion_matrix
import json

from core.config import get_config
from core.data_provider import DataProvider
from models.cnn.scalping_cnn import MultiTimeframeCNN, ScalpingDataGenerator

logger = logging.getLogger(__name__)

class CNNDataset(Dataset):
    """Dataset for CNN training with real market data"""

    def __init__(self, features: np.ndarray, labels: np.ndarray):
        self.features = torch.FloatTensor(features)
        self.labels = torch.LongTensor(np.argmax(labels, axis=1))  # Convert one-hot to class indices

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]
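
# Illustrative sketch (not from the original file): the argmax conversion that
# CNNDataset applies to one-hot labels. With three classes, a one-hot row
# [0, 0, 1] becomes class index 2, the form nn.CrossEntropyLoss expects.
#
#   one_hot = np.array([[1, 0, 0], [0, 0, 1]], dtype=np.float32)
#   np.argmax(one_hot, axis=1)  # -> array([0, 2])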
class CNNTrainer:
    """CNN Trainer using ONLY real market data with TensorBoard monitoring"""

    def __init__(self, config: Optional[Dict] = None):
        """Initialize CNN trainer"""
        self.config = config or get_config()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Training parameters
        self.learning_rate = self.config.training.get('learning_rate', 0.001)
        self.batch_size = self.config.training.get('batch_size', 32)
        self.epochs = self.config.training.get('epochs', 100)
        self.validation_split = self.config.training.get('validation_split', 0.2)
        self.early_stopping_patience = self.config.training.get('early_stopping_patience', 10)

        # Model parameters - will be updated based on real data
        self.n_timeframes = len(self.config.timeframes)
        self.window_size = self.config.cnn.get('window_size', 20)
        self.n_features = self.config.cnn.get('features', 26)  # Will be dynamically updated
        self.n_classes = 3  # BUY, SELL, HOLD

        # Initialize components
        self.data_provider = DataProvider(self.config)
        self.data_generator = ScalpingDataGenerator(self.data_provider, self.window_size)
        self.model = None

        # TensorBoard setup
        self.setup_tensorboard()

        logger.info(f"CNNTrainer initialized with {self.n_timeframes} timeframes, {self.n_features} features")
        logger.info("Will use ONLY real market data for training")
    def setup_tensorboard(self):
        """Setup TensorBoard logging"""
        # Create tensorboard logs directory
        log_dir = Path("runs") / f"cnn_training_{int(time.time())}"
        log_dir.mkdir(parents=True, exist_ok=True)

        self.writer = SummaryWriter(log_dir=str(log_dir))
        self.tensorboard_dir = log_dir

        logger.info(f"TensorBoard logging to: {log_dir}")
        logger.info("Run: tensorboard --logdir=runs")
    def log_model_architecture(self):
        """Log model architecture to TensorBoard"""
        if self.model is not None:
            # Log model graph (requires a dummy input)
            dummy_input = torch.randn(1, self.n_timeframes, self.window_size, self.n_features).to(self.device)
            try:
                self.writer.add_graph(self.model, dummy_input)
                logger.info("Model architecture logged to TensorBoard")
            except Exception as e:
                logger.warning(f"Could not log model graph: {e}")

            # Log model parameter counts
            total_params = sum(p.numel() for p in self.model.parameters())
            trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

            self.writer.add_scalar('Model/TotalParameters', total_params, 0)
            self.writer.add_scalar('Model/TrainableParameters', trainable_params, 0)
    def create_model(self) -> MultiTimeframeCNN:
        """Create CNN model"""
        model = MultiTimeframeCNN(
            n_timeframes=self.n_timeframes,
            window_size=self.window_size,
            n_features=self.n_features,
            n_classes=self.n_classes,
            dropout_rate=self.config.cnn.get('dropout', 0.2)
        )

        model = model.to(self.device)

        # Log model info
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        memory_usage = model.get_memory_usage()

        logger.info(f"Model created with {total_params:,} total parameters")
        logger.info(f"Trainable parameters: {trainable_params:,}")
        logger.info(f"Estimated memory usage: {memory_usage}MB")

        return model
    def prepare_data(self, symbols: List[str], num_samples: int = 10000) -> Tuple[np.ndarray, np.ndarray, Dict]:
        """Prepare training data from REAL market data"""
        logger.info("Preparing training data...")
        logger.info("Data source: REAL market data from exchange APIs")

        all_features = []
        all_labels = []
        all_metadata = []

        for symbol in symbols:
            logger.info(f"Generating data for {symbol}...")

            features, labels, metadata = self.data_generator.generate_training_cases(
                symbol=symbol,
                timeframes=self.config.timeframes,
                num_samples=num_samples
            )

            if features is not None:
                all_features.append(features)
                all_labels.append(labels)
                all_metadata.append(metadata)

                logger.info(f"Generated {len(features)} samples for {symbol}")

                # Update feature count if needed
                actual_features = features.shape[-1]
                if actual_features != self.n_features:
                    logger.info(f"Updating feature count from {self.n_features} to {actual_features}")
                    self.n_features = actual_features

        if not all_features:
            raise ValueError("No training data generated from real market data")

        # Combine all data
        features = np.concatenate(all_features, axis=0)
        labels = np.concatenate(all_labels, axis=0)

        # Log data statistics to TensorBoard
        self.log_data_statistics(features, labels)

        return features, labels, all_metadata
    def log_data_statistics(self, features: np.ndarray, labels: np.ndarray):
        """Log data statistics to TensorBoard"""
        # Dataset size
        self.writer.add_scalar('Data/TotalSamples', len(features), 0)
        self.writer.add_scalar('Data/Features', features.shape[-1], 0)
        self.writer.add_scalar('Data/Timeframes', features.shape[1], 0)
        self.writer.add_scalar('Data/WindowSize', features.shape[2], 0)

        # Class distribution
        class_counts = np.bincount(np.argmax(labels, axis=1))
        for i, count in enumerate(class_counts):
            self.writer.add_scalar(f'Data/Class_{i}_Count', count, 0)

        # Feature statistics
        feature_means = features.mean(axis=(0, 1, 2))
        feature_stds = features.std(axis=(0, 1, 2))

        for i in range(min(10, len(feature_means))):  # Log first 10 features
            self.writer.add_scalar(f'Data/Feature_{i}_Mean', feature_means[i], 0)
            self.writer.add_scalar(f'Data/Feature_{i}_Std', feature_stds[i], 0)
    def train_epoch(self, model: nn.Module, train_loader: DataLoader,
                    optimizer: torch.optim.Optimizer, criterion: nn.Module, epoch: int) -> Tuple[float, float]:
        """Train for one epoch with TensorBoard logging"""
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (features, labels) in enumerate(train_loader):
            features, labels = features.to(self.device), labels.to(self.device)

            optimizer.zero_grad()
            predictions = model(features)
            loss = criterion(predictions['action'], labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(predictions['action'].data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Log batch metrics
            step = epoch * len(train_loader) + batch_idx
            self.writer.add_scalar('Training/BatchLoss', loss.item(), step)

            if batch_idx % 50 == 0:  # Log every 50 batches
                batch_acc = 100. * (predicted == labels).sum().item() / labels.size(0)
                self.writer.add_scalar('Training/BatchAccuracy', batch_acc, step)

                # Log confidence scores
                avg_confidence = predictions['confidence'].mean().item()
                self.writer.add_scalar('Training/BatchConfidence', avg_confidence, step)

        epoch_loss = total_loss / len(train_loader)
        epoch_accuracy = correct / total

        return epoch_loss, epoch_accuracy
    def validate_epoch(self, model: nn.Module, val_loader: DataLoader,
                       criterion: nn.Module, epoch: int) -> Tuple[float, float, Dict]:
        """Validate for one epoch with TensorBoard logging"""
        model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        all_predictions = []
        all_labels = []
        all_confidences = []

        with torch.no_grad():
            for features, labels in val_loader:
                features, labels = features.to(self.device), labels.to(self.device)

                predictions = model(features)
                loss = criterion(predictions['action'], labels)

                total_loss += loss.item()
                _, predicted = torch.max(predictions['action'].data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_confidences.extend(predictions['confidence'].cpu().numpy())

        epoch_loss = total_loss / len(val_loader)
        epoch_accuracy = correct / total

        # Calculate detailed metrics
        metrics = self.calculate_detailed_metrics(all_predictions, all_labels, all_confidences)

        # Log validation metrics to TensorBoard
        self.writer.add_scalar('Validation/Loss', epoch_loss, epoch)
        self.writer.add_scalar('Validation/Accuracy', epoch_accuracy, epoch)
        self.writer.add_scalar('Validation/AvgConfidence', metrics['avg_confidence'], epoch)

        for class_idx, acc in metrics['class_accuracies'].items():
            self.writer.add_scalar(f'Validation/Class_{class_idx}_Accuracy', acc, epoch)

        return epoch_loss, epoch_accuracy, metrics
    def calculate_detailed_metrics(self, predictions: List, labels: List, confidences: List) -> Dict:
        """Calculate detailed training metrics"""
        predictions = np.array(predictions)
        labels = np.array(labels)
        confidences = np.array(confidences)

        # Class-wise accuracies
        class_accuracies = {}
        for class_idx in range(self.n_classes):
            class_mask = labels == class_idx
            if class_mask.sum() > 0:
                class_acc = (predictions[class_mask] == labels[class_mask]).mean()
                class_accuracies[class_idx] = class_acc

        return {
            'class_accuracies': class_accuracies,
            'avg_confidence': confidences.mean(),
            'confusion_matrix': confusion_matrix(labels, predictions)
        }
    def train(self, symbols: List[str], save_path: str = 'models/cnn/scalping_cnn_trained.pt',
              num_samples: int = 10000) -> Dict:
        """Train CNN model with TensorBoard monitoring"""
        logger.info("Starting CNN training...")
        logger.info("Using ONLY real market data from exchange APIs")

        # Prepare data
        features, labels, metadata = self.prepare_data(symbols, num_samples)

        # Log training configuration
        self.writer.add_text('Config/Symbols', str(symbols), 0)
        self.writer.add_text('Config/Timeframes', str(self.config.timeframes), 0)
        self.writer.add_scalar('Config/LearningRate', self.learning_rate, 0)
        self.writer.add_scalar('Config/BatchSize', self.batch_size, 0)
        self.writer.add_scalar('Config/MaxEpochs', self.epochs, 0)

        # Create datasets
        dataset = CNNDataset(features, labels)

        # Split data
        val_size = int(len(dataset) * self.validation_split)
        train_size = len(dataset) - val_size
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)

        logger.info(f"Total dataset: {len(dataset)} samples")
        logger.info(f"Features shape: {features.shape}")
        logger.info(f"Labels shape: {labels.shape}")
        logger.info(f"Train samples: {train_size}")
        logger.info(f"Validation samples: {val_size}")

        # Log class distributions
        train_labels = [dataset[i][1].item() for i in train_dataset.indices]
        val_labels = [dataset[i][1].item() for i in val_dataset.indices]

        logger.info(f"Train label distribution: {np.bincount(train_labels)}")
        logger.info(f"Val label distribution: {np.bincount(val_labels)}")

        # Create model
        self.model = self.create_model()
        self.log_model_architecture()

        # Setup training
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)

        # Training loop
        best_val_loss = float('inf')
        best_val_accuracy = 0.0
        patience_counter = 0
        start_time = time.time()

        for epoch in range(self.epochs):
            epoch_start = time.time()

            # Train
            train_loss, train_accuracy = self.train_epoch(self.model, train_loader, optimizer, criterion, epoch)

            # Validate
            val_loss, val_accuracy, val_metrics = self.validate_epoch(self.model, val_loader, criterion, epoch)

            # Update learning rate
            scheduler.step(val_loss)
            current_lr = optimizer.param_groups[0]['lr']

            # Log epoch metrics
            self.writer.add_scalar('Training/EpochLoss', train_loss, epoch)
            self.writer.add_scalar('Training/EpochAccuracy', train_accuracy, epoch)
            self.writer.add_scalar('Training/LearningRate', current_lr, epoch)

            epoch_time = time.time() - epoch_start
            self.writer.add_scalar('Training/EpochTime', epoch_time, epoch)

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_val_accuracy = val_accuracy
                patience_counter = 0

                best_path = save_path.replace('.pt', '_best.pt')
                self.model.save(best_path)
                logger.info(f"New best model saved: {best_path}")

                # Log best metrics
                self.writer.add_scalar('Best/ValidationLoss', best_val_loss, epoch)
                self.writer.add_scalar('Best/ValidationAccuracy', best_val_accuracy, epoch)
            else:
                patience_counter += 1

            logger.info(f"Epoch {epoch+1}/{self.epochs} - "
                        f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f} - "
                        f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f} - "
                        f"Time: {epoch_time:.2f}s")

            # Log detailed metrics every 10 epochs
            if (epoch + 1) % 10 == 0:
                logger.info(f"Class accuracies: {val_metrics['class_accuracies']}")
                logger.info(f"Average confidence: {val_metrics['avg_confidence']:.4f}")

            # Early stopping
            if patience_counter >= self.early_stopping_patience:
                logger.info(f"Early stopping triggered after {epoch+1} epochs")
                break

        # Training completed
        total_time = time.time() - start_time
        logger.info(f"Training completed in {total_time:.2f} seconds")
        logger.info(f"Best validation loss: {best_val_loss:.4f}")
        logger.info(f"Best validation accuracy: {best_val_accuracy:.4f}")

        # Log final metrics
        self.writer.add_scalar('Final/TotalTrainingTime', total_time, 0)
        self.writer.add_scalar('Final/TotalEpochs', epoch + 1, 0)

        # Save final model
        self.model.save(save_path)
        logger.info(f"Final model saved: {save_path}")

        # Log training summary
        self.writer.add_text('Training/Summary',
                             f"Completed training with {len(features)} real market samples. "
                             f"Best validation accuracy: {best_val_accuracy:.4f}", 0)

        return {
            'best_val_loss': best_val_loss,
            'best_val_accuracy': best_val_accuracy,
            'total_epochs': epoch + 1,
            'training_time': total_time,
            'tensorboard_dir': str(self.tensorboard_dir)
        }
    def evaluate(self, symbols: List[str], num_samples: int = 5000) -> Dict:
        """Evaluate trained model on test data"""
        if self.model is None:
            raise ValueError("Model not trained yet")

        logger.info("Evaluating model...")

        # Generate test data from real market data
        features, labels, metadata = self.prepare_data(symbols, num_samples)

        # Create test dataset and loader
        test_dataset = CNNDataset(features, labels)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        # Evaluate
        criterion = nn.CrossEntropyLoss()
        test_loss, test_accuracy, test_metrics = self.validate_epoch(
            self.model, test_loader, criterion, epoch=0
        )

        # Generate detailed classification report
        # (classification_report is imported at module level)
        class_names = ['BUY', 'SELL', 'HOLD']
        all_predictions = []
        all_labels = []

        with torch.no_grad():
            for features_batch, labels_batch in test_loader:
                features_batch = features_batch.to(self.device)
                predictions = self.model(features_batch)
                _, predicted = torch.max(predictions['action'].data, 1)
                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels_batch.numpy())

        classification_rep = classification_report(
            all_labels, all_predictions, target_names=class_names, output_dict=True
        )

        evaluation_results = {
            'test_loss': test_loss,
            'test_accuracy': test_accuracy,
            'classification_report': classification_rep,
            'class_accuracies': test_metrics['class_accuracies'],
            'avg_confidence': test_metrics['avg_confidence'],
            'confusion_matrix': test_metrics['confusion_matrix']
        }

        logger.info(f"Test accuracy: {test_accuracy:.4f}")
        logger.info(f"Test loss: {test_loss:.4f}")

        return evaluation_results
    def close_tensorboard(self):
        """Close TensorBoard writer"""
        if hasattr(self, 'writer'):
            self.writer.close()
            logger.info("TensorBoard writer closed")

    def __del__(self):
        """Cleanup"""
        self.close_tensorboard()


# Export
__all__ = ['CNNTrainer', 'CNNDataset']
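
# --- Usage sketch (illustrative, not part of the original module) ---
# How the trainer is intended to be driven end-to-end. Assumes a project config
# providing `timeframes`, `cnn`, and `training` sections and a DataProvider
# with live exchange access; the symbol here is only an example.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    trainer = CNNTrainer()
    results = trainer.train(symbols=['ETH/USDT'], num_samples=10000)
    print(f"Best val accuracy: {results['best_val_accuracy']:.4f}")
    metrics = trainer.evaluate(symbols=['ETH/USDT'], num_samples=5000)
    trainer.close_tensorboard()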
@@ -1,811 +0,0 @@
"""
Enhanced CNN Trainer with Perfect Move Learning

This trainer implements:
1. Training on marked perfect moves with known outcomes
2. Multi-timeframe CNN model training with confidence scoring
3. Backpropagation on optimal moves when future outcomes are known
4. Progressive learning from real trading experience
5. Symbol-specific and timeframe-specific model fine-tuning
"""

import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Union  # Union used by EnhancedCNNModel.train
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json

from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import PerfectMove, EnhancedTradingOrchestrator
from models import CNNModelInterface
import models

logger = logging.getLogger(__name__)

class PerfectMoveDataset(Dataset):
    """Dataset for training on perfect moves with known outcomes"""

    def __init__(self, perfect_moves: List[PerfectMove], data_provider: DataProvider):
        """
        Initialize dataset from perfect moves

        Args:
            perfect_moves: List of perfect moves with known outcomes
            data_provider: Data provider to fetch additional context
        """
        self.perfect_moves = perfect_moves
        self.data_provider = data_provider
        self.samples = []
        self._prepare_samples()

    def _prepare_samples(self):
        """Prepare training samples from perfect moves"""
        logger.info(f"Preparing {len(self.perfect_moves)} perfect move samples")

        for move in self.perfect_moves:
            try:
                # Get feature matrix at the time of the decision
                feature_matrix = self.data_provider.get_feature_matrix(
                    symbol=move.symbol,
                    timeframes=[move.timeframe],
                    window_size=20,
                    end_time=move.timestamp
                )

                if feature_matrix is not None:
                    # Convert optimal action to label
                    action_to_label = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
                    label = action_to_label.get(move.optimal_action, 1)

                    # Create confidence target (what confidence should have been)
                    confidence_target = move.confidence_should_have_been

                    sample = {
                        'features': feature_matrix,
                        'action_label': label,
                        'confidence_target': confidence_target,
                        'symbol': move.symbol,
                        'timeframe': move.timeframe,
                        'outcome': move.actual_outcome,
                        'timestamp': move.timestamp
                    }
                    self.samples.append(sample)

            except Exception as e:
                logger.warning(f"Error preparing sample for perfect move: {e}")

        logger.info(f"Prepared {len(self.samples)} valid training samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Convert to tensors
        features = torch.FloatTensor(sample['features'])
        action_label = torch.LongTensor([sample['action_label']])
        confidence_target = torch.FloatTensor([sample['confidence_target']])

        return {
            'features': features,
            'action_label': action_label,
            'confidence_target': confidence_target,
            'metadata': {
                'symbol': sample['symbol'],
                'timeframe': sample['timeframe'],
                'outcome': sample['outcome'],
                'timestamp': sample['timestamp']
            }
        }
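
# Illustrative sketch (an assumption, not the real core.enhanced_orchestrator
# API): PerfectMoveDataset only reads the attributes below from each
# PerfectMove, so a lightweight stand-in like this is enough to exercise the
# dataset offline.
#
#   from dataclasses import dataclass
#
#   @dataclass
#   class FakePerfectMove:
#       symbol: str                          # e.g. 'ETH/USDT'
#       timeframe: str                       # e.g. '1h'
#       timestamp: datetime                  # decision time
#       optimal_action: str                  # 'BUY' / 'SELL' / 'HOLD'
#       confidence_should_have_been: float   # target confidence in [0, 1]
#       actual_outcome: float                # realized price move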
class EnhancedCNNModel(nn.Module, CNNModelInterface):
    """Enhanced CNN model with timeframe-specific predictions and confidence scoring"""

    def __init__(self, config: Dict[str, Any]):
        nn.Module.__init__(self)
        CNNModelInterface.__init__(self, config)

        self.timeframes = config.get('timeframes', ['1h', '4h', '1d'])
        self.n_features = len(config.get('features', ['open', 'high', 'low', 'close', 'volume']))
        self.window_size = config.get('window_size', 20)

        # Build the neural network
        self._build_network()

        # Initialize device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

        # Training components.
        # forward() returns softmax probabilities, so use NLLLoss over their log;
        # CrossEntropyLoss expects raw logits and would double-apply softmax.
        self.optimizer = optim.Adam(self.parameters(), lr=config.get('learning_rate', 0.001))
        self.action_criterion = nn.NLLLoss()
        self.confidence_criterion = nn.MSELoss()

        logger.info(f"Enhanced CNN model initialized for {len(self.timeframes)} timeframes")
    def _build_network(self):
        """Build the CNN architecture"""
        # Convolutional feature extraction
        self.conv_layers = nn.Sequential(
            # First conv block
            nn.Conv1d(self.n_features, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Second conv block
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Third conv block
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Global average pooling
            nn.AdaptiveAvgPool1d(1)
        )

        # Timeframe-specific heads
        self.timeframe_heads = nn.ModuleDict()
        for timeframe in self.timeframes:
            self.timeframe_heads[timeframe] = nn.Sequential(
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Dropout(0.3)
            )

        # Action prediction heads (one per timeframe)
        self.action_heads = nn.ModuleDict()
        for timeframe in self.timeframes:
            self.action_heads[timeframe] = nn.Linear(64, 3)  # BUY, HOLD, SELL

        # Confidence prediction heads (one per timeframe)
        self.confidence_heads = nn.ModuleDict()
        for timeframe in self.timeframes:
            self.confidence_heads[timeframe] = nn.Sequential(
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, 1),
                nn.Sigmoid()  # Output between 0 and 1
            )
    def forward(self, x, timeframe: str = None):
        """
        Forward pass through the network

        Args:
            x: Input tensor [batch_size, window_size, features]
            timeframe: Specific timeframe to predict for

        Returns:
            action_probs: Action probabilities
            confidence: Confidence score
        """
        # Reshape for conv1d: [batch, features, sequence]
        x = x.transpose(1, 2)

        # Extract features
        features = self.conv_layers(x)   # [batch, 256, 1]
        features = features.squeeze(-1)  # [batch, 256]

        if timeframe and timeframe in self.timeframe_heads:
            # Timeframe-specific prediction
            tf_features = self.timeframe_heads[timeframe](features)
            action_logits = self.action_heads[timeframe](tf_features)
            confidence = self.confidence_heads[timeframe](tf_features)

            action_probs = torch.softmax(action_logits, dim=1)
            return action_probs, confidence.squeeze(-1)
        else:
            # Multi-timeframe prediction (average across timeframes)
            all_action_probs = []
            all_confidences = []

            for tf in self.timeframes:
                tf_features = self.timeframe_heads[tf](features)
                action_logits = self.action_heads[tf](tf_features)
                confidence = self.confidence_heads[tf](tf_features)

                action_probs = torch.softmax(action_logits, dim=1)
                all_action_probs.append(action_probs)
                all_confidences.append(confidence.squeeze(-1))

            # Average predictions across timeframes
            avg_action_probs = torch.stack(all_action_probs).mean(dim=0)
            avg_confidence = torch.stack(all_confidences).mean(dim=0)

            return avg_action_probs, avg_confidence
    def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]:
        """Predict action probabilities and confidence"""
        self.eval()
        with torch.no_grad():
            x = torch.FloatTensor(features).to(self.device)
            if len(x.shape) == 2:
                x = x.unsqueeze(0)  # Add batch dimension

            action_probs, confidence = self.forward(x)

            return action_probs[0].cpu().numpy(), confidence[0].cpu().item()

    def predict_timeframe(self, features: np.ndarray, timeframe: str) -> Tuple[np.ndarray, float]:
        """Predict for a specific timeframe"""
        self.eval()
        with torch.no_grad():
            x = torch.FloatTensor(features).to(self.device)
            if len(x.shape) == 2:
                x = x.unsqueeze(0)  # Add batch dimension

            action_probs, confidence = self.forward(x, timeframe)

            return action_probs[0].cpu().numpy(), confidence[0].cpu().item()
    def get_memory_usage(self) -> int:
        """Get memory usage in MB"""
        if torch.cuda.is_available():
            return torch.cuda.memory_allocated(self.device) // (1024 * 1024)
        else:
            # Rough estimate for CPU
            param_count = sum(p.numel() for p in self.parameters())
            return (param_count * 4) // (1024 * 1024)  # 4 bytes per float32

    def train(self, mode_or_data: Union[bool, Dict[str, Any]] = True):
        """Train entry point shared by two callers.

        nn.Module.train()/eval() call this with a bool to toggle training mode,
        while CNNModelInterface passes a training-data dict. The original
        single-purpose override silently broke train()/eval() mode switching,
        so dispatch on the argument type instead.
        """
        if isinstance(mode_or_data, bool):
            return nn.Module.train(self, mode_or_data)
        # Interface placeholder: actual training happens in EnhancedCNNTrainer
        return {}
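
# Illustrative sketch (not from the original file): calling the model directly.
# The window shape follows _build_network: [window_size, n_features] with the
# default 5 OHLCV features; the config values here are assumptions, and the
# CNNModelInterface base class may require additional config keys.
#
#   model = EnhancedCNNModel({'timeframes': ['1h'], 'window_size': 20})
#   window = np.random.rand(20, 5).astype(np.float32)    # placeholder market window
#   probs, conf = model.predict_timeframe(window, '1h')  # 3-way action probs, conf in [0, 1]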
class EnhancedCNNTrainer:
    """Enhanced CNN trainer using perfect moves and real market outcomes"""

    def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
        """Initialize the enhanced trainer"""
        self.config = config or get_config()
        self.orchestrator = orchestrator
        self.data_provider = DataProvider(self.config)

        # Training parameters
        self.learning_rate = self.config.training.get('learning_rate', 0.001)
        self.batch_size = self.config.training.get('batch_size', 32)
        self.epochs = self.config.training.get('epochs', 100)
        self.patience = self.config.training.get('early_stopping_patience', 10)

        # Model
        self.model = EnhancedCNNModel(self.config.cnn)

        # Training history
        self.training_history = {
            'train_loss': [],
            'val_loss': [],
            'train_accuracy': [],
            'val_accuracy': [],
            'confidence_accuracy': []
        }

        # Create save directory
        models_path = self.config.cnn.get('model_dir', "models/enhanced_cnn")
        self.save_dir = Path(models_path)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        logger.info("Enhanced CNN trainer initialized")
    def train_on_perfect_moves(self, min_samples: int = 100) -> Dict[str, Any]:
        """Train the model on perfect moves from the orchestrator"""
        if not self.orchestrator:
            raise ValueError("Orchestrator required for perfect move training")

        # Get perfect moves from orchestrator
        perfect_moves = []
        for symbol in self.config.symbols:
            symbol_moves = self.orchestrator.get_perfect_moves_for_training(symbol=symbol)
            perfect_moves.extend(symbol_moves)

        if len(perfect_moves) < min_samples:
            logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)} < {min_samples}")
            return {'error': 'insufficient_data', 'samples': len(perfect_moves)}

        logger.info(f"Training on {len(perfect_moves)} perfect moves")

        # Create dataset
        dataset = PerfectMoveDataset(perfect_moves, self.data_provider)

        # Split into train/validation
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)

        # Training loop
        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(self.epochs):
            # Training phase
            train_loss, train_acc = self._train_epoch(train_loader)

            # Validation phase
            val_loss, val_acc, conf_acc = self._validate_epoch(val_loader)

            # Update history
            self.training_history['train_loss'].append(train_loss)
            self.training_history['val_loss'].append(val_loss)
            self.training_history['train_accuracy'].append(train_acc)
            self.training_history['val_accuracy'].append(val_acc)
            self.training_history['confidence_accuracy'].append(conf_acc)

            # Log progress
            logger.info(f"Epoch {epoch+1}/{self.epochs}: "
                        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
                        f"Conf Acc: {conf_acc:.4f}")

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self._save_model('best_model.pt')
            else:
                patience_counter += 1
                if patience_counter >= self.patience:
                    logger.info(f"Early stopping at epoch {epoch+1}")
                    break

        # Save final model
        self._save_model('final_model.pt')

        # Generate training report
        return self._generate_training_report()
    def _train_epoch(self, train_loader: DataLoader) -> Tuple[float, float]:
        """Train for one epoch"""
        self.model.train()
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        for batch in train_loader:
            features = batch['features'].to(self.model.device)
            action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
            confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)

            # Zero gradients
            self.model.optimizer.zero_grad()

            # Forward pass
            action_probs, confidence_pred = self.model(features)

            # Calculate losses (NLLLoss over log-probabilities; see action_criterion)
            action_loss = self.model.action_criterion(torch.log(action_probs + 1e-8), action_labels)
            confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)

            # Combined loss
            total_loss_batch = action_loss + 0.5 * confidence_loss

            # Backward pass
            total_loss_batch.backward()
            self.model.optimizer.step()

            # Track metrics
            total_loss += total_loss_batch.item()
            predicted_actions = torch.argmax(action_probs, dim=1)
            correct_predictions += (predicted_actions == action_labels).sum().item()
            total_predictions += action_labels.size(0)

        avg_loss = total_loss / len(train_loader)
        accuracy = correct_predictions / total_predictions

        return avg_loss, accuracy
    def _validate_epoch(self, val_loader: DataLoader) -> Tuple[float, float, float]:
        """Validate for one epoch"""
        self.model.eval()
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        confidence_errors = []

        with torch.no_grad():
            for batch in val_loader:
                features = batch['features'].to(self.model.device)
                action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
                confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)

                # Forward pass
                action_probs, confidence_pred = self.model(features)

                # Calculate losses (same log-probability form as training)
                action_loss = self.model.action_criterion(torch.log(action_probs + 1e-8), action_labels)
                confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)
                total_loss_batch = action_loss + 0.5 * confidence_loss

                # Track metrics
                total_loss += total_loss_batch.item()
                predicted_actions = torch.argmax(action_probs, dim=1)
                correct_predictions += (predicted_actions == action_labels).sum().item()
                total_predictions += action_labels.size(0)

                # Track confidence accuracy
                conf_errors = torch.abs(confidence_pred - confidence_targets)
                confidence_errors.extend(conf_errors.cpu().numpy())

        avg_loss = total_loss / len(val_loader)
        accuracy = correct_predictions / total_predictions
        confidence_accuracy = 1.0 - np.mean(confidence_errors)  # 1 - mean absolute error

        return avg_loss, accuracy, confidence_accuracy
    def _save_model(self, filename: str):
        """Save the model"""
        save_path = self.save_dir / filename
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.model.optimizer.state_dict(),
            'config': self.config.cnn,
            'training_history': self.training_history
        }, save_path)
        logger.info(f"Model saved to {save_path}")

    def load_model(self, filename: str) -> bool:
        """Load a saved model"""
        load_path = self.save_dir / filename
        if not load_path.exists():
            logger.error(f"Model file not found: {load_path}")
            return False

        try:
            checkpoint = torch.load(load_path, map_location=self.model.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.training_history = checkpoint.get('training_history', {})
            logger.info(f"Model loaded from {load_path}")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False
    def _generate_training_report(self) -> Dict[str, Any]:
        """Generate comprehensive training report"""
        if not self.training_history['train_loss']:
            return {'error': 'no_training_data'}

        # Calculate final metrics
        final_train_loss = self.training_history['train_loss'][-1]
        final_val_loss = self.training_history['val_loss'][-1]
        final_train_acc = self.training_history['train_accuracy'][-1]
        final_val_acc = self.training_history['val_accuracy'][-1]
        final_conf_acc = self.training_history['confidence_accuracy'][-1]

        # Best metrics
        best_val_loss = min(self.training_history['val_loss'])
        best_val_acc = max(self.training_history['val_accuracy'])
        best_conf_acc = max(self.training_history['confidence_accuracy'])

        report = {
            'training_completed': True,
            'epochs_trained': len(self.training_history['train_loss']),
            'final_metrics': {
                'train_loss': final_train_loss,
                'val_loss': final_val_loss,
                'train_accuracy': final_train_acc,
                'val_accuracy': final_val_acc,
                'confidence_accuracy': final_conf_acc
            },
            'best_metrics': {
                'val_loss': best_val_loss,
                'val_accuracy': best_val_acc,
                'confidence_accuracy': best_conf_acc
            },
            'model_info': {
                'timeframes': self.model.timeframes,
                'memory_usage_mb': self.model.get_memory_usage(),
                'device': str(self.model.device)
            }
        }

        # Generate plots
        self._plot_training_history()

        logger.info("Training completed successfully")
        logger.info(f"Final validation accuracy: {final_val_acc:.4f}")
        logger.info(f"Final confidence accuracy: {final_conf_acc:.4f}")

        return report
    def _plot_training_history(self):
        """Plot training history"""
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        fig.suptitle('Enhanced CNN Training History')

        # Loss plot
        axes[0, 0].plot(self.training_history['train_loss'], label='Train Loss')
        axes[0, 0].plot(self.training_history['val_loss'], label='Val Loss')
        axes[0, 0].set_title('Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()

        # Accuracy plot
        axes[0, 1].plot(self.training_history['train_accuracy'], label='Train Accuracy')
        axes[0, 1].plot(self.training_history['val_accuracy'], label='Val Accuracy')
        axes[0, 1].set_title('Action Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()

        # Confidence accuracy plot
        axes[1, 0].plot(self.training_history['confidence_accuracy'], label='Confidence Accuracy')
        axes[1, 0].set_title('Confidence Prediction Accuracy')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Accuracy')
        axes[1, 0].legend()

        # Learning curves comparison
        axes[1, 1].plot(self.training_history['val_loss'], label='Validation Loss')
        axes[1, 1].plot(self.training_history['confidence_accuracy'], label='Confidence Accuracy')
        axes[1, 1].set_title('Model Performance Overview')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].legend()

        plt.tight_layout()
        plt.savefig(self.save_dir / 'training_history.png', dpi=300, bbox_inches='tight')
        plt.close()

        logger.info(f"Training plots saved to {self.save_dir / 'training_history.png'}")
    def get_model(self) -> EnhancedCNNModel:
        """Get the trained model"""
        return self.model

    def close_tensorboard(self):
        """Close TensorBoard writer if it exists"""
        if hasattr(self, 'writer') and self.writer:
            try:
                self.writer.close()
            except Exception:
                pass

    def __del__(self):
        """Cleanup when object is destroyed"""
        self.close_tensorboard()
def main():
|
||||
"""Main function for standalone CNN live training with backtesting and analysis"""
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
parser = argparse.ArgumentParser(description='Enhanced CNN Live Training with Backtesting and Analysis')
|
||||
parser.add_argument('--symbols', type=str, nargs='+', default=['ETH/USDT', 'BTC/USDT'],
|
||||
help='Trading symbols to train on')
|
||||
parser.add_argument('--timeframes', type=str, nargs='+', default=['1m', '5m', '15m', '1h'],
|
||||
help='Timeframes to use for training')
|
||||
parser.add_argument('--epochs', type=int, default=100,
|
||||
help='Number of training epochs')
|
||||
parser.add_argument('--batch-size', type=int, default=32,
|
||||
help='Training batch size')
|
||||
parser.add_argument('--learning-rate', type=float, default=0.001,
|
||||
help='Learning rate')
|
||||
parser.add_argument('--save-path', type=str, default='models/enhanced_cnn/live_trained_model.pt',
|
||||
help='Path to save the trained model')
|
||||
parser.add_argument('--enable-backtesting', action='store_true', default=True,
|
||||
help='Enable backtesting after training')
|
||||
parser.add_argument('--enable-analysis', action='store_true', default=True,
|
||||
help='Enable detailed analysis and reporting')
|
||||
parser.add_argument('--enable-live-validation', action='store_true', default=True,
|
||||
help='Enable live validation during training')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
logger.info("="*80)
|
||||
logger.info("ENHANCED CNN LIVE TRAINING WITH BACKTESTING & ANALYSIS")
|
||||
logger.info("="*80)
|
||||
logger.info(f"Symbols: {args.symbols}")
|
||||
logger.info(f"Timeframes: {args.timeframes}")
|
||||
logger.info(f"Epochs: {args.epochs}")
|
||||
logger.info(f"Batch Size: {args.batch_size}")
|
||||
logger.info(f"Learning Rate: {args.learning_rate}")
|
||||
logger.info(f"Save Path: {args.save_path}")
|
||||
logger.info(f"Backtesting: {'Enabled' if args.enable_backtesting else 'Disabled'}")
|
||||
logger.info(f"Analysis: {'Enabled' if args.enable_analysis else 'Disabled'}")
|
||||
logger.info(f"Live Validation: {'Enabled' if args.enable_live_validation else 'Disabled'}")
|
||||
logger.info("="*80)
|
||||
|
||||
try:
|
||||
# Update config with command line arguments
|
||||
config = get_config()
|
||||
config.update('symbols', args.symbols)
|
||||
config.update('timeframes', args.timeframes)
|
||||
config.update('training', {
|
||||
**config.training,
|
||||
'epochs': args.epochs,
|
||||
'batch_size': args.batch_size,
|
||||
'learning_rate': args.learning_rate
|
||||
})
|
||||
|
||||
# Initialize enhanced trainer
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
data_provider = DataProvider(config)
|
||||
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||
trainer = EnhancedCNNTrainer(config, orchestrator)
|
        # Phase 1: Data Collection and Preparation
        logger.info("📊 Phase 1: Collecting and preparing training data...")
        training_data = trainer.collect_training_data(args.symbols, lookback_days=30)
        logger.info(f" Collected {len(training_data)} training samples")

        # Phase 2: Model Training
        logger.info("Phase 2: Training Enhanced CNN Model...")
        training_results = trainer.train_on_perfect_moves(min_samples=1000)

        logger.info("Training Results:")
        logger.info(f" Best Validation Accuracy: {training_results['best_val_accuracy']:.4f}")
        logger.info(f" Best Validation Loss: {training_results['best_val_loss']:.4f}")
        logger.info(f" Total Epochs: {training_results['epochs_completed']}")
        logger.info(f" Training Time: {training_results['total_time']:.2f}s")

        # Phase 3: Model Evaluation
        logger.info("📈 Phase 3: Model Evaluation...")
        evaluation_results = trainer.evaluate_model(args.symbols[:1])  # Use first symbol for evaluation

        logger.info("Evaluation Results:")
        logger.info(f" Test Accuracy: {evaluation_results['test_accuracy']:.4f}")
        logger.info(f" Test Loss: {evaluation_results['test_loss']:.4f}")
        logger.info(f" Confidence Score: {evaluation_results['avg_confidence']:.4f}")

        # Phase 4: Backtesting (if enabled)
        if args.enable_backtesting:
            logger.info("📊 Phase 4: Backtesting...")

            # Create backtest environment
            from trading.backtest_environment import BacktestEnvironment
            backtest_env = BacktestEnvironment(
                symbols=args.symbols,
                timeframes=args.timeframes,
                initial_balance=10000.0,
                data_provider=data_provider
            )

            # Run backtest
            backtest_results = backtest_env.run_backtest_with_model(
                model=trainer.model,
                lookback_days=7,  # Test on last 7 days
                max_trades_per_day=50
            )

            logger.info("Backtesting Results:")
            logger.info(f" Total Return: {backtest_results['total_return']:.2f}%")
            logger.info(f" Win Rate: {backtest_results['win_rate']:.2f}%")
            logger.info(f" Sharpe Ratio: {backtest_results['sharpe_ratio']:.4f}")
            logger.info(f" Max Drawdown: {backtest_results['max_drawdown']:.2f}%")
            logger.info(f" Total Trades: {backtest_results['total_trades']}")
            logger.info(f" Profit Factor: {backtest_results['profit_factor']:.4f}")

        # Phase 5: Analysis and Reporting (if enabled)
        if args.enable_analysis:
            logger.info("📋 Phase 5: Analysis and Reporting...")

            # Generate comprehensive analysis report
            analysis_report = trainer.generate_analysis_report(
                training_results=training_results,
                evaluation_results=evaluation_results,
                backtest_results=backtest_results if args.enable_backtesting else None
            )

            # Save analysis report
            report_path = Path(args.save_path).parent / "analysis_report.json"
            report_path.parent.mkdir(parents=True, exist_ok=True)

            with open(report_path, 'w') as f:
                json.dump(analysis_report, f, indent=2, default=str)

            logger.info(f" Analysis report saved: {report_path}")

            # Generate performance plots
            plots_dir = Path(args.save_path).parent / "plots"
            plots_dir.mkdir(parents=True, exist_ok=True)

            trainer.generate_performance_plots(
                training_results=training_results,
                evaluation_results=evaluation_results,
                save_dir=plots_dir
            )

            logger.info(f" Performance plots saved: {plots_dir}")

        # Phase 6: Model Saving
        logger.info("💾 Phase 6: Saving trained model...")
        model_path = Path(args.save_path)
        model_path.parent.mkdir(parents=True, exist_ok=True)

        trainer.model.save(str(model_path))
        logger.info(f" Model saved: {model_path}")

        # Save training metadata
        metadata = {
            'training_config': {
                'symbols': args.symbols,
                'timeframes': args.timeframes,
                'epochs': args.epochs,
                'batch_size': args.batch_size,
                'learning_rate': args.learning_rate
            },
            'training_results': training_results,
            'evaluation_results': evaluation_results
        }

        if args.enable_backtesting:
            metadata['backtest_results'] = backtest_results

        metadata_path = model_path.with_suffix('.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2, default=str)

        logger.info(f" Training metadata saved: {metadata_path}")

        # Phase 7: Live Validation (if enabled)
        if args.enable_live_validation:
            logger.info("🔄 Phase 7: Live Validation...")

            # Test model on recent live data
            live_validation_results = trainer.run_live_validation(
                symbols=args.symbols[:1],  # Use first symbol
                validation_hours=2  # Validate on last 2 hours
            )

            logger.info("Live Validation Results:")
            logger.info(f" Prediction Accuracy: {live_validation_results['accuracy']:.2f}%")
            logger.info(f" Average Confidence: {live_validation_results['avg_confidence']:.4f}")
            logger.info(f" Predictions Made: {live_validation_results['total_predictions']}")

        logger.info("="*80)
        logger.info("🎉 ENHANCED CNN LIVE TRAINING COMPLETED SUCCESSFULLY!")
        logger.info("="*80)
        logger.info(f"📊 Model Path: {model_path}")
        logger.info(f"📋 Metadata: {metadata_path}")
        if args.enable_analysis:
            logger.info(f"📈 Analysis Report: {report_path}")
            logger.info(f"📊 Performance Plots: {plots_dir}")
        logger.info("="*80)

    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
        return 1
    except Exception as e:
        logger.error(f"Training failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
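# Example invocation (a sketch -- the exact flag spellings are assumptions
# inferred from the args.* fields read above; the real argparse definitions
# live earlier in this script):
#
#   python train_enhanced_cnn_live.py \
#       --symbols ETH/USDT BTC/USDT --timeframes 1m 1h \
#       --epochs 50 --batch-size 64 --learning-rate 0.001 \
#       --save-path models/enhanced_cnn/model.pt \
#       --enable-backtesting --enable-analysis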
@ -1,584 +0,0 @@
"""
Enhanced Pivot-Based RL Trainer

Integrates Williams Market Structure pivot points with CNN predictions
for improved trading decisions and training rewards.

Key Features:
- Train RL model to buy/sell at local pivot points
- CNN predicts next pivot to avoid late signals
- Different thresholds for entry vs exit
- Rewards for staying uninvested when uncertain
- Uncertainty-based confidence adjustment
"""

import asyncio
import logging
import time
import numpy as np
import pandas as pd
from collections import deque, namedtuple
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Union, TYPE_CHECKING
import matplotlib.pyplot as plt
from pathlib import Path

from core.config import get_config
from core.data_provider import DataProvider
from training.williams_market_structure import WilliamsMarketStructure, SwingType, SwingPoint

# Use TYPE_CHECKING to avoid circular import
if TYPE_CHECKING:
    from core.enhanced_orchestrator import EnhancedTradingOrchestrator

logger = logging.getLogger(__name__)

class PivotReward:
    """Reward structure for pivot-based trading decisions"""

    def __init__(self):
        # Pivot-based reward weights
        self.pivot_hit_bonus = 2.0  # Bonus for trading at actual pivot points
        self.pivot_anticipation_bonus = 1.5  # Bonus for trading before pivot (CNN prediction)
        self.wrong_direction_penalty = -1.0  # Penalty for trading opposite to pivot direction
        self.late_entry_penalty = -0.5  # Penalty for entering after pivot is confirmed

        # Stay uninvested rewards
        self.uninvested_reward = 0.1  # Small positive reward for staying out of poor setups
        self.avoid_false_signal_bonus = 0.5  # Bonus for avoiding false signals

        # Uncertainty penalties
        self.overconfidence_penalty = -0.3  # Penalty for being overconfident on losses
        self.underconfidence_penalty = -0.1  # Small penalty for being underconfident on wins

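# Worked example of how these weights combine (illustrative numbers only): a
# profitable BUY filled within 0.5% of a confirmed swing low and within 30
# minutes of it earns pivot_hit_bonus + pivot_anticipation_bonus
# = 2.0 + 1.5 = 3.5 on top of the PnL-based reward, while the same entry
# taken more than two hours after the pivot is docked late_entry_penalty = -0.5.
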
class EnhancedPivotRLTrainer:
    """Enhanced RL trainer focused on Williams pivot points and CNN predictions"""

    def __init__(self,
                 data_provider: DataProvider = None,
                 orchestrator: Optional["EnhancedTradingOrchestrator"] = None):

        self.config = get_config()
        self.data_provider = data_provider or DataProvider()
        self.orchestrator = orchestrator

        # Initialize Williams Market Structure with CNN
        self.williams = WilliamsMarketStructure(
            swing_strengths=[2, 4, 6, 8, 10],  # Multiple strengths for better detection
            enable_cnn_feature=True,
            training_data_provider=data_provider
        )

        # Pivot tracking
        self.recent_pivots = deque(maxlen=50)
        self.pivot_predictions = deque(maxlen=20)
        self.trade_outcomes = deque(maxlen=100)

        # Threshold management - different for entry vs exit
        self.entry_threshold = 0.65  # Higher threshold for entering positions
        self.exit_threshold = 0.35  # Lower threshold for exiting positions
        self.max_uninvested_reward_threshold = 0.60  # Stay out if confidence below this

        # Confidence learning parameters
        self.confidence_history = deque(maxlen=200)
        self.mistake_severity_tracker = deque(maxlen=50)

        # Reward calculator
        self.pivot_reward = PivotReward()

        logger.info("Enhanced Pivot RL Trainer initialized")
        logger.info(f"Entry threshold: {self.entry_threshold:.2%}")
        logger.info(f"Exit threshold: {self.exit_threshold:.2%}")
        logger.info(f"Uninvested reward threshold: {self.max_uninvested_reward_threshold:.2%}")

    def calculate_pivot_based_reward(self,
                                     trade_decision: Dict[str, Any],
                                     market_data: pd.DataFrame,
                                     trade_outcome: Dict[str, Any]) -> float:
        """
        Calculate enhanced reward based on pivot points and CNN predictions

        Args:
            trade_decision: The trading decision made by the model
            market_data: Market data context
            trade_outcome: Actual trade outcome

        Returns:
            Enhanced reward score
        """
        try:
            base_pnl = trade_outcome.get('net_pnl', 0.0)
            confidence = trade_decision.get('confidence', 0.5)
            action = trade_decision.get('action', 'HOLD')
            entry_price = trade_decision.get('price', 0.0)
            exit_price = trade_outcome.get('exit_price', entry_price)
            duration = trade_outcome.get('duration', timedelta(0))

            # Base PnL reward
            base_reward = base_pnl / 5.0

            # 1. Pivot Point Analysis Rewards
            pivot_reward = self._calculate_pivot_rewards(
                trade_decision, market_data, trade_outcome
            )

            # 2. CNN Prediction Accuracy Rewards
            cnn_reward = self._calculate_cnn_prediction_rewards(
                trade_decision, market_data, trade_outcome
            )

            # 3. Uninvested Period Rewards
            uninvested_reward = self._calculate_uninvested_rewards(
                trade_decision, confidence
            )

            # 4. Uncertainty-based Confidence Adjustment
            confidence_adjustment = self._calculate_confidence_adjustment(
                trade_decision, trade_outcome
            )

            # 5. Time efficiency with pivot context
            time_reward = self._calculate_time_efficiency_reward(
                duration, base_pnl, market_data
            )

            # Combine all rewards
            total_reward = (
                base_reward +
                pivot_reward +
                cnn_reward +
                uninvested_reward +
                confidence_adjustment +
                time_reward
            )

            # Log detailed reward breakdown
            self._log_reward_breakdown(
                trade_decision, trade_outcome, {
                    'base': base_reward,
                    'pivot': pivot_reward,
                    'cnn': cnn_reward,
                    'uninvested': uninvested_reward,
                    'confidence': confidence_adjustment,
                    'time': time_reward,
                    'total': total_reward
                }
            )

            # Track for learning
            self._track_reward_outcome(trade_decision, trade_outcome, total_reward)

            return np.clip(total_reward, -15.0, 10.0)

        except Exception as e:
            logger.error(f"Error calculating pivot-based reward: {e}")
            return 0.0

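    # Illustrative composition (made-up numbers, not from a real run): a trade
    # with net_pnl 2.5 gives base_reward 0.5; a pivot hit in the right
    # direction adds 3.5; a correct CNN follow at confidence 0.8 adds 0.8; a
    # 20-minute profitable hold adds 0.3. The total, 5.1, sits well inside the
    # final np.clip(-15.0, 10.0) bounds.
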
    def _calculate_pivot_rewards(self,
                                 trade_decision: Dict[str, Any],
                                 market_data: pd.DataFrame,
                                 trade_outcome: Dict[str, Any]) -> float:
        """Calculate rewards based on proximity to pivot points"""
        try:
            entry_price = trade_decision.get('price', 0.0)
            action = trade_decision.get('action', 'HOLD')
            entry_time = trade_decision.get('timestamp', datetime.now())
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            # Find recent pivot points from Williams analysis
            ohlcv_array = self._convert_dataframe_to_ohlcv_array(market_data)
            if ohlcv_array is None or len(ohlcv_array) < 20:
                return 0.0

            # Get pivot points from Williams structure
            structure_levels = self.williams.calculate_recursive_pivot_points(ohlcv_array)
            if not structure_levels or 'level_0' not in structure_levels:
                return 0.0

            level_0_pivots = structure_levels['level_0'].swing_points
            if not level_0_pivots:
                return 0.0

            # Find closest pivot to entry
            closest_pivot = self._find_closest_pivot(entry_price, entry_time, level_0_pivots)
            if not closest_pivot:
                return 0.0

            # Calculate distance to pivot (price and time)
            price_distance = abs(entry_price - closest_pivot.price) / closest_pivot.price
            time_distance = abs((entry_time - closest_pivot.timestamp).total_seconds()) / 3600.0  # hours

            pivot_reward = 0.0

            # Reward trading at or near pivot points
            if price_distance < 0.005:  # Within 0.5% of pivot
                if time_distance < 0.5:  # Within 30 minutes
                    pivot_reward += self.pivot_reward.pivot_hit_bonus
                    logger.debug(f"PIVOT HIT BONUS: {self.pivot_reward.pivot_hit_bonus:.2f}")

                # Check if trade direction aligns with pivot
                if self._trade_aligns_with_pivot(action, closest_pivot, net_pnl):
                    pivot_reward += self.pivot_reward.pivot_anticipation_bonus
                    logger.debug(f"PIVOT DIRECTION BONUS: {self.pivot_reward.pivot_anticipation_bonus:.2f}")
                else:
                    pivot_reward += self.pivot_reward.wrong_direction_penalty
                    logger.debug(f"WRONG DIRECTION PENALTY: {self.pivot_reward.wrong_direction_penalty:.2f}")

            # Penalty for late entry after pivot confirmation
            if time_distance > 2.0:  # More than 2 hours after pivot
                pivot_reward += self.pivot_reward.late_entry_penalty
                logger.debug(f"LATE ENTRY PENALTY: {self.pivot_reward.late_entry_penalty:.2f}")

            return pivot_reward

        except Exception as e:
            logger.error(f"Error calculating pivot rewards: {e}")
            return 0.0

    def _calculate_cnn_prediction_rewards(self,
                                          trade_decision: Dict[str, Any],
                                          market_data: pd.DataFrame,
                                          trade_outcome: Dict[str, Any]) -> float:
        """Calculate rewards based on CNN pivot predictions"""
        try:
            # Check if we have CNN predictions available
            if not hasattr(self.williams, 'cnn_model') or not self.williams.cnn_model:
                return 0.0

            action = trade_decision.get('action', 'HOLD')
            confidence = trade_decision.get('confidence', 0.5)
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            # Get latest CNN prediction if available
            # This would be the prediction made before the trade
            cnn_prediction = self._get_latest_cnn_prediction()
            if cnn_prediction is None:  # explicit None check: truth-testing a numpy array raises ValueError
                return 0.0

            cnn_reward = 0.0

            # Reward for following CNN predictions that turn out correct
            predicted_direction = self._interpret_cnn_prediction(cnn_prediction)

            if predicted_direction == action and net_pnl > 0:
                # CNN prediction was correct and we followed it
                cnn_reward += 1.0 * confidence  # Scale by confidence
                logger.debug(f"CNN CORRECT FOLLOW: +{1.0 * confidence:.2f}")

            elif predicted_direction != action and net_pnl < 0:
                # We didn't follow CNN and it was right (we were wrong)
                cnn_reward -= 0.5
                logger.debug("CNN IGNORE PENALTY: -0.5")

            elif predicted_direction == action and net_pnl < 0:
                # We followed CNN but it was wrong
                cnn_reward -= 0.2  # Small penalty, CNN predictions can be wrong
                logger.debug("CNN WRONG FOLLOW: -0.2")

            return cnn_reward

        except Exception as e:
            logger.error(f"Error calculating CNN prediction rewards: {e}")
            return 0.0

    def _calculate_uninvested_rewards(self,
                                      trade_decision: Dict[str, Any],
                                      confidence: float) -> float:
        """Calculate rewards for staying uninvested when uncertain"""
        try:
            action = trade_decision.get('action', 'HOLD')

            # Reward staying out when confidence is low
            if action == 'HOLD' and confidence < self.max_uninvested_reward_threshold:
                uninvested_reward = self.pivot_reward.uninvested_reward

                # Bonus for avoiding very uncertain setups
                if confidence < 0.4:
                    uninvested_reward += self.pivot_reward.avoid_false_signal_bonus
                    logger.debug(f"AVOID FALSE SIGNAL BONUS: +{self.pivot_reward.avoid_false_signal_bonus:.2f}")

                logger.debug(f"UNINVESTED REWARD: +{uninvested_reward:.2f}")
                return uninvested_reward

            return 0.0

        except Exception as e:
            logger.error(f"Error calculating uninvested rewards: {e}")
            return 0.0

    def _calculate_confidence_adjustment(self,
                                         trade_decision: Dict[str, Any],
                                         trade_outcome: Dict[str, Any]) -> float:
        """Adjust rewards based on confidence vs outcome to reduce overconfidence"""
        try:
            confidence = trade_decision.get('confidence', 0.5)
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            confidence_adjustment = 0.0

            # Track mistake severity
            mistake_severity = abs(net_pnl) if net_pnl < 0 else 0.0
            self.mistake_severity_tracker.append(mistake_severity)

            # Penalize overconfidence on losses
            if net_pnl < 0 and confidence > 0.7:
                # High confidence but loss - penalize overconfidence
                overconfidence_factor = (confidence - 0.7) / 0.3  # 0-1 scale
                severity_factor = min(mistake_severity / 2.0, 1.0)  # Scale by loss size

                penalty = self.pivot_reward.overconfidence_penalty * overconfidence_factor * severity_factor
                confidence_adjustment += penalty

                logger.debug(f"OVERCONFIDENCE PENALTY: {penalty:.2f} (conf: {confidence:.2f}, loss: ${net_pnl:.2f})")

            # Small penalty for underconfidence on wins
            elif net_pnl > 0 and confidence < 0.4:
                underconfidence_factor = (0.4 - confidence) / 0.4  # 0-1 scale
                penalty = self.pivot_reward.underconfidence_penalty * underconfidence_factor
                confidence_adjustment += penalty

                logger.debug(f"UNDERCONFIDENCE PENALTY: {penalty:.2f} (conf: {confidence:.2f}, profit: ${net_pnl:.2f})")

            # Update confidence learning
            self._update_confidence_learning(confidence, net_pnl, mistake_severity)

            return confidence_adjustment

        except Exception as e:
            logger.error(f"Error calculating confidence adjustment: {e}")
            return 0.0

    def _calculate_time_efficiency_reward(self,
                                          duration: timedelta,
                                          net_pnl: float,
                                          market_data: pd.DataFrame) -> float:
        """Calculate time-based rewards considering market context"""
        try:
            duration_hours = duration.total_seconds() / 3600.0

            # Quick profitable trades get bonus
            if net_pnl > 0 and duration_hours < 0.5:  # Less than 30 minutes
                return 0.3

            # Holding losses too long gets penalty
            elif net_pnl < 0 and duration_hours > 2.0:  # More than 2 hours
                return -0.5

            return 0.0

        except Exception as e:
            logger.error(f"Error calculating time efficiency reward: {e}")
            return 0.0

    def update_thresholds_based_on_performance(self):
        """Dynamically adjust entry/exit thresholds based on recent performance"""
        try:
            if len(self.trade_outcomes) < 20:
                return

            recent_outcomes = list(self.trade_outcomes)[-20:]

            # Calculate win rate and average PnL
            wins = sum(1 for outcome in recent_outcomes if outcome['net_pnl'] > 0)
            win_rate = wins / len(recent_outcomes)
            avg_pnl = np.mean([outcome['net_pnl'] for outcome in recent_outcomes])

            # Adjust thresholds based on performance
            if win_rate < 0.4:  # Low win rate - be more selective
                self.entry_threshold = min(self.entry_threshold + 0.02, 0.80)
                logger.info(f"Low win rate ({win_rate:.2%}) - increased entry threshold to {self.entry_threshold:.2%}")

            elif win_rate > 0.6 and avg_pnl > 0:  # High win rate - can be more aggressive
                self.entry_threshold = max(self.entry_threshold - 0.01, 0.50)
                logger.info(f"High win rate ({win_rate:.2%}) - decreased entry threshold to {self.entry_threshold:.2%}")

            # Adjust exit threshold based on loss severity
            avg_loss_severity = np.mean(list(self.mistake_severity_tracker)) if self.mistake_severity_tracker else 0

            if avg_loss_severity > 1.0:  # Large average losses
                self.exit_threshold = max(self.exit_threshold - 0.01, 0.20)
                logger.info(f"High loss severity - decreased exit threshold to {self.exit_threshold:.2%}")

        except Exception as e:
            logger.error(f"Error updating thresholds: {e}")

    def get_current_thresholds(self) -> Dict[str, float]:
        """Get current entry and exit thresholds"""
        return {
            'entry_threshold': self.entry_threshold,
            'exit_threshold': self.exit_threshold,
            'uninvested_threshold': self.max_uninvested_reward_threshold
        }

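    # Example of the adaptation above (illustrative numbers): with a recent
    # win rate of 35%, the entry threshold steps 0.65 -> 0.67 (capped at 0.80);
    # with a 65% win rate and positive average PnL it steps back 0.65 -> 0.64
    # (floored at 0.50), so the trainer drifts toward a sustainable level.
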
    # Helper methods

    def _convert_dataframe_to_ohlcv_array(self, df: pd.DataFrame) -> Optional[np.ndarray]:
        """Convert pandas DataFrame to numpy array for Williams analysis"""
        try:
            if df.empty:
                return None

            # Ensure we have required columns
            required_cols = ['open', 'high', 'low', 'close', 'volume']
            if not all(col in df.columns for col in required_cols):
                return None

            # Convert to numpy array
            timestamps = df.index.astype(np.int64) // 10**9  # Convert to Unix timestamp
            ohlcv_array = np.column_stack([
                timestamps,
                df['open'].values,
                df['high'].values,
                df['low'].values,
                df['close'].values,
                df['volume'].values
            ])

            return ohlcv_array.astype(np.float64)

        except Exception as e:
            logger.error(f"Error converting DataFrame to OHLCV array: {e}")
            return None

    def _find_closest_pivot(self,
                            entry_price: float,
                            entry_time: datetime,
                            pivots: List[SwingPoint]) -> Optional[SwingPoint]:
        """Find the closest pivot point to the trade entry"""
        try:
            if not pivots:
                return None

            # Find pivot closest in time and price
            best_pivot = None
            best_score = float('inf')

            for pivot in pivots:
                time_diff = abs((entry_time - pivot.timestamp).total_seconds()) / 3600.0
                price_diff = abs(entry_price - pivot.price) / pivot.price

                # Combined score (weighted by time and price proximity)
                score = time_diff * 0.3 + price_diff * 100  # Weight price difference more heavily

                if score < best_score:
                    best_score = score
                    best_pivot = pivot

            return best_pivot

        except Exception as e:
            logger.error(f"Error finding closest pivot: {e}")
            return None

    def _trade_aligns_with_pivot(self,
                                 action: str,
                                 pivot: SwingPoint,
                                 net_pnl: float) -> bool:
        """Check if trade direction aligns with pivot type and was profitable"""
        try:
            if net_pnl <= 0:  # Only consider profitable trades as aligned
                return False

            if action == 'BUY' and pivot.swing_type == SwingType.SWING_LOW:
                return True  # Bought at/near swing low
            elif action == 'SELL' and pivot.swing_type == SwingType.SWING_HIGH:
                return True  # Sold at/near swing high

            return False

        except Exception as e:
            logger.error(f"Error checking trade alignment: {e}")
            return False

    def _get_latest_cnn_prediction(self) -> Optional[np.ndarray]:
        """Get the latest CNN prediction from Williams structure"""
        try:
            # This would access the Williams CNN model's latest prediction
            # For now, return None if not available
            if hasattr(self.williams, 'latest_cnn_prediction'):
                return self.williams.latest_cnn_prediction
            return None

        except Exception as e:
            logger.error(f"Error getting CNN prediction: {e}")
            return None

    def _interpret_cnn_prediction(self, prediction: np.ndarray) -> str:
        """Interpret CNN prediction array to trading action"""
        try:
            if len(prediction) < 2:
                return 'HOLD'

            # Assuming prediction format: [type, price] for level 0
            predicted_type = prediction[0]  # 0 = LOW, 1 = HIGH

            if predicted_type > 0.5:
                return 'SELL'  # Expecting swing high - sell
            else:
                return 'BUY'  # Expecting swing low - buy

        except Exception as e:
            logger.error(f"Error interpreting CNN prediction: {e}")
            return 'HOLD'

    def _update_confidence_learning(self,
                                    confidence: float,
                                    net_pnl: float,
                                    mistake_severity: float):
        """Update confidence learning parameters"""
        try:
            self.confidence_history.append({
                'confidence': confidence,
                'net_pnl': net_pnl,
                'mistake_severity': mistake_severity,
                'timestamp': datetime.now()
            })

            # Periodically update thresholds based on confidence patterns
            if len(self.confidence_history) % 10 == 0:
                self.update_thresholds_based_on_performance()

        except Exception as e:
            logger.error(f"Error updating confidence learning: {e}")

    def _track_reward_outcome(self,
                              trade_decision: Dict[str, Any],
                              trade_outcome: Dict[str, Any],
                              total_reward: float):
        """Track reward outcomes for analysis"""
        try:
            outcome_record = {
                'timestamp': datetime.now(),
                'action': trade_decision.get('action'),
                'confidence': trade_decision.get('confidence'),
                'net_pnl': trade_outcome.get('net_pnl'),
                'reward': total_reward,
                'duration': trade_outcome.get('duration')
            }

            self.trade_outcomes.append(outcome_record)

        except Exception as e:
            logger.error(f"Error tracking reward outcome: {e}")

    def _log_reward_breakdown(self,
                              trade_decision: Dict[str, Any],
                              trade_outcome: Dict[str, Any],
                              rewards: Dict[str, float]):
        """Log detailed reward breakdown"""
        try:
            action = trade_decision.get('action', 'UNKNOWN')
            confidence = trade_decision.get('confidence', 0.0)
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            logger.info(f"[REWARD] {action} (conf: {confidence:.2%}) PnL: ${net_pnl:.2f} -> Total: {rewards['total']:.2f}")
            logger.debug(f" Base: {rewards['base']:.2f}, Pivot: {rewards['pivot']:.2f}, CNN: {rewards['cnn']:.2f}")
            logger.debug(f" Uninvested: {rewards['uninvested']:.2f}, Confidence: {rewards['confidence']:.2f}, Time: {rewards['time']:.2f}")

        except Exception as e:
            logger.error(f"Error logging reward breakdown: {e}")

def create_enhanced_pivot_trainer(data_provider: DataProvider = None,
                                  orchestrator: Optional["EnhancedTradingOrchestrator"] = None) -> EnhancedPivotRLTrainer:
    """Factory function to create enhanced pivot trainer"""
    return EnhancedPivotRLTrainer(data_provider, orchestrator)
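# Minimal usage sketch (assumes the project's DataProvider can be constructed
# with defaults; the dict keys below are exactly the ones the reward methods
# read via .get()):
#
#   trainer = create_enhanced_pivot_trainer()
#   reward = trainer.calculate_pivot_based_reward(
#       trade_decision={'action': 'BUY', 'confidence': 0.72,
#                       'price': 3050.0, 'timestamp': datetime.now()},
#       market_data=ohlcv_df,  # DataFrame with open/high/low/close/volume
#       trade_outcome={'net_pnl': 4.2, 'exit_price': 3071.0,
#                      'duration': timedelta(minutes=25)},
#   )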
@ -1,708 +0,0 @@
"""
Enhanced RL State Builder for Comprehensive Market Data Integration

This module implements the specification requirements for RL training with:
- 300s of raw tick data for momentum detection
- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d) for ETH and BTC
- CNN hidden layer features integration
- CNN predictions from all timeframes
- Pivot point predictions using Williams market structure
- Market regime analysis

State Vector Components:
- ETH tick data: ~3000 features (300s * 10 features/tick)
- ETH OHLCV 1s: ~2400 features (300 bars * 8 features)
- ETH OHLCV 1m: ~2400 features (300 bars * 8 features)
- ETH OHLCV 1h: ~2400 features (300 bars * 8 features)
- ETH OHLCV 1d: ~2400 features (300 bars * 8 features)
- BTC reference: ~2400 features (300 bars * 8 features)
- CNN features: ~512 features (hidden layer)
- CNN predictions: ~16 features (4 timeframes * 4 outputs)
- Pivot points: ~250 features (Williams structure)
- Market regime: ~20 features

Total: ~15,800 features
"""

import logging
import numpy as np
import pandas as pd

try:
    import ta
except ImportError:
    logger = logging.getLogger(__name__)
    logger.warning("ta library not available, using pandas for technical indicators")
    ta = None

from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta
from dataclasses import dataclass

from core.universal_data_adapter import UniversalDataStream

logger = logging.getLogger(__name__)

@dataclass
class TickData:
    """Tick data structure"""
    timestamp: datetime
    price: float
    volume: float
    bid: float = 0.0
    ask: float = 0.0

    @property
    def spread(self) -> float:
        return self.ask - self.bid if self.ask > 0 and self.bid > 0 else 0.0

@dataclass
class OHLCVData:
    """OHLCV data structure"""
    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float

    # Technical indicators (optional)
    rsi: Optional[float] = None
    macd: Optional[float] = None
    bb_upper: Optional[float] = None
    bb_lower: Optional[float] = None
    sma_20: Optional[float] = None
    ema_12: Optional[float] = None
    atr: Optional[float] = None

@dataclass
class StateComponentConfig:
    """Configuration for state component sizes"""
    eth_ticks: int = 3000  # 300s * 10 features per tick
    eth_1s_ohlcv: int = 2400  # 300 bars * 8 features (OHLCV + indicators)
    eth_1m_ohlcv: int = 2400  # 300 bars * 8 features
    eth_1h_ohlcv: int = 2400  # 300 bars * 8 features
    eth_1d_ohlcv: int = 2400  # 300 bars * 8 features
    btc_reference: int = 2400  # BTC reference data
    cnn_features: int = 512  # CNN hidden layer features
    cnn_predictions: int = 16  # CNN predictions (4 timeframes * 4 outputs)
    pivot_points: int = 250  # Recursive pivot points (5 levels * 50 points)
    market_regime: int = 20  # Market regime features

    @property
    def total_size(self) -> int:
        """Calculate total state size"""
        return (self.eth_ticks + self.eth_1s_ohlcv + self.eth_1m_ohlcv +
                self.eth_1h_ohlcv + self.eth_1d_ohlcv + self.btc_reference +
                self.cnn_features + self.cnn_predictions + self.pivot_points +
                self.market_regime)

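# For the defaults above, total_size evaluates to
# 3000 + 5 * 2400 + 512 + 16 + 250 + 20 = 15,798 features.
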
class EnhancedRLStateBuilder:
    """
    Comprehensive RL state builder implementing specification requirements

    Features:
    - 300s tick data processing with momentum detection
    - Multi-timeframe OHLCV integration
    - CNN hidden layer feature extraction
    - Pivot point calculation and integration
    - Market regime analysis
    - BTC reference data processing
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config

        # Data windows
        self.tick_window_seconds = 300  # 5 minutes of tick data
        self.ohlcv_window_bars = 300  # 300 bars for each timeframe

        # State component sizes
        self.state_components = {
            'eth_ticks': 300 * 10,  # 3000 features: tick data with derived features
            'eth_1s_ohlcv': 300 * 8,  # 2400 features: OHLCV + indicators
            'eth_1m_ohlcv': 300 * 8,  # 2400 features: OHLCV + indicators
            'eth_1h_ohlcv': 300 * 8,  # 2400 features: OHLCV + indicators
            'eth_1d_ohlcv': 300 * 8,  # 2400 features: OHLCV + indicators
            'btc_reference': 300 * 8,  # 2400 features: BTC reference data
            'cnn_features': 512,  # 512 features: CNN hidden layer
            'cnn_predictions': 16,  # 16 features: CNN predictions (4 timeframes * 4 outputs)
            'pivot_points': 250,  # 250 features: Williams market structure
            'market_regime': 20  # 20 features: Market regime indicators
        }

        self.total_state_size = sum(self.state_components.values())

        # Data buffers for maintaining windows
        self.tick_buffers = {}
        self.ohlcv_buffers = {}

        # Normalization parameters
        self.normalization_params = self._initialize_normalization_params()

        # Feature extractors
        self.momentum_detector = TickMomentumDetector()
        self.indicator_calculator = TechnicalIndicatorCalculator()
        self.regime_analyzer = MarketRegimeAnalyzer()

        logger.info("Enhanced RL State Builder initialized")
        logger.info(f"Total state size: {self.total_state_size} features")
        logger.info(f"State components: {self.state_components}")

    def build_rl_state(self,
                       eth_ticks: List[TickData],
                       eth_ohlcv: Dict[str, List[OHLCVData]],
                       btc_ohlcv: Dict[str, List[OHLCVData]],
                       cnn_hidden_features: Optional[Dict[str, np.ndarray]] = None,
                       cnn_predictions: Optional[Dict[str, np.ndarray]] = None,
                       pivot_data: Optional[Dict[str, Any]] = None) -> np.ndarray:
        """
        Build comprehensive RL state vector from all data sources

        Args:
            eth_ticks: List of ETH tick data (last 300s)
            eth_ohlcv: Dict of ETH OHLCV data by timeframe
            btc_ohlcv: Dict of BTC OHLCV data by timeframe
            cnn_hidden_features: CNN hidden layer features by timeframe
            cnn_predictions: CNN predictions by timeframe
            pivot_data: Pivot point data from Williams analysis

        Returns:
            np.ndarray: Comprehensive state vector (~15,800 features)
        """
        try:
            state_vector = []

            # 1. Process ETH tick data (3000 features)
            tick_features = self._process_tick_data(eth_ticks)
            state_vector.extend(tick_features)

            # 2. Process ETH multi-timeframe OHLCV (9600 features total)
            for timeframe in ['1s', '1m', '1h', '1d']:
                if timeframe in eth_ohlcv:
                    ohlcv_features = self._process_ohlcv_data(
                        eth_ohlcv[timeframe], timeframe, symbol='ETH'
                    )
                else:
                    ohlcv_features = np.zeros(self.state_components[f'eth_{timeframe}_ohlcv'])
                state_vector.extend(ohlcv_features)

            # 3. Process BTC reference data (2400 features)
            btc_features = self._process_btc_reference_data(btc_ohlcv)
            state_vector.extend(btc_features)

            # 4. Process CNN hidden layer features (512 features)
            cnn_hidden = self._process_cnn_hidden_features(cnn_hidden_features)
            state_vector.extend(cnn_hidden)

            # 5. Process CNN predictions (16 features)
            cnn_pred = self._process_cnn_predictions(cnn_predictions)
            state_vector.extend(cnn_pred)

            # 6. Process pivot points (250 features)
            pivot_features = self._process_pivot_points(pivot_data, eth_ohlcv)
            state_vector.extend(pivot_features)

            # 7. Process market regime features (20 features)
            regime_features = self._process_market_regime(eth_ohlcv, btc_ohlcv)
            state_vector.extend(regime_features)

            # Convert to numpy array and validate size
            state_array = np.array(state_vector, dtype=np.float32)

            if len(state_array) != self.total_state_size:
                logger.warning(f"State size mismatch: expected {self.total_state_size}, got {len(state_array)}")
                # Pad or truncate to expected size
                if len(state_array) < self.total_state_size:
                    padding = np.zeros(self.total_state_size - len(state_array))
                    state_array = np.concatenate([state_array, padding])
                else:
                    state_array = state_array[:self.total_state_size]

            # Apply normalization
            state_array = self._normalize_state(state_array)

            return state_array

        except Exception as e:
            logger.error(f"Error building RL state: {e}")
            # Return zero state on error
            return np.zeros(self.total_state_size, dtype=np.float32)

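    # Layout of the assembled vector, by component offset (derived from the
    # state_components dict above):
    #   [0:3000)        eth_ticks          [3000:5400)    eth_1s_ohlcv
    #   [5400:7800)     eth_1m_ohlcv       [7800:10200)   eth_1h_ohlcv
    #   [10200:12600)   eth_1d_ohlcv       [12600:15000)  btc_reference
    #   [15000:15512)   cnn_features       [15512:15528)  cnn_predictions
    #   [15528:15778)   pivot_points       [15778:15798)  market_regime
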
    def _process_tick_data(self, ticks: List[TickData]) -> List[float]:
        """Process raw tick data into features for momentum detection"""
        features = []

        if not ticks or len(ticks) < 10:
            # Return zeros if insufficient data
            return [0.0] * self.state_components['eth_ticks']

        # Ensure we have exactly 300 data points (pad or sample)
        processed_ticks = self._normalize_tick_window(ticks, 300)

        for i, tick in enumerate(processed_ticks):
            # Basic tick features
            tick_features = [
                tick.price,
                tick.volume,
                tick.bid,
                tick.ask,
                tick.spread
            ]

            # Derived features
            if i > 0:
                prev_tick = processed_ticks[i-1]
                price_change = (tick.price - prev_tick.price) / prev_tick.price if prev_tick.price > 0 else 0
                volume_change = (tick.volume - prev_tick.volume) / prev_tick.volume if prev_tick.volume > 0 else 0

                tick_features.extend([
                    price_change,
                    volume_change,
                    tick.price / prev_tick.price - 1.0 if prev_tick.price > 0 else 0,  # Price ratio
                    np.log(tick.volume / prev_tick.volume) if prev_tick.volume > 0 else 0,  # Log volume ratio
                    self.momentum_detector.calculate_micro_momentum(processed_ticks[max(0, i-5):i+1])
                ])
            else:
                tick_features.extend([0.0, 0.0, 0.0, 0.0, 0.0])

            features.extend(tick_features)

        return features[:self.state_components['eth_ticks']]

    def _process_ohlcv_data(self, ohlcv_data: List[OHLCVData],
                            timeframe: str, symbol: str = 'ETH') -> List[float]:
        """Process OHLCV data with technical indicators"""
        features = []

        if not ohlcv_data or len(ohlcv_data) < 20:
            component_key = f'{symbol.lower()}_{timeframe}_ohlcv' if symbol == 'ETH' else 'btc_reference'
            return [0.0] * self.state_components[component_key]

        # Convert to DataFrame for indicator calculation
        df = pd.DataFrame([{
            'timestamp': bar.timestamp,
            'open': bar.open,
            'high': bar.high,
            'low': bar.low,
            'close': bar.close,
            'volume': bar.volume
        } for bar in ohlcv_data[-self.ohlcv_window_bars:]])

        # Calculate technical indicators
        df = self.indicator_calculator.add_all_indicators(df)

        # Ensure we have exactly 300 bars
        if len(df) < 300:
            # Pad with last known values
            last_row = df.iloc[-1:].copy()
            padding_rows = []
            for _ in range(300 - len(df)):
                padding_rows.append(last_row)
            if padding_rows:
                df = pd.concat([df] + padding_rows, ignore_index=True)
        else:
            df = df.tail(300)

        # Extract features for each bar
        feature_columns = ['open', 'high', 'low', 'close', 'volume', 'rsi', 'macd', 'bb_middle']

        for _, row in df.iterrows():
            bar_features = []
            for col in feature_columns:
                if col in row and not pd.isna(row[col]):
                    bar_features.append(float(row[col]))
                else:
                    bar_features.append(0.0)
            features.extend(bar_features)

        component_key = f'{symbol.lower()}_{timeframe}_ohlcv' if symbol == 'ETH' else 'btc_reference'
        return features[:self.state_components[component_key]]

    def _process_btc_reference_data(self, btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Process BTC reference data (using 1h timeframe as primary)"""
        if '1h' in btc_ohlcv and btc_ohlcv['1h']:
            return self._process_ohlcv_data(btc_ohlcv['1h'], '1h', 'BTC')
        elif '1m' in btc_ohlcv and btc_ohlcv['1m']:
            return self._process_ohlcv_data(btc_ohlcv['1m'], '1m', 'BTC')
        else:
            return [0.0] * self.state_components['btc_reference']

    def _process_cnn_hidden_features(self, cnn_features: Optional[Dict[str, np.ndarray]]) -> List[float]:
        """Process CNN hidden layer features"""
        if not cnn_features:
            return [0.0] * self.state_components['cnn_features']

        # Combine features from all timeframes
        combined_features = []
        timeframes = ['1s', '1m', '1h', '1d']
        features_per_timeframe = self.state_components['cnn_features'] // len(timeframes)

        for tf in timeframes:
            if tf in cnn_features and cnn_features[tf] is not None:
                tf_features = cnn_features[tf].flatten()
                # Truncate or pad to fit allocation
                if len(tf_features) >= features_per_timeframe:
                    combined_features.extend(tf_features[:features_per_timeframe])
                else:
                    combined_features.extend(tf_features)
                    combined_features.extend([0.0] * (features_per_timeframe - len(tf_features)))
            else:
                combined_features.extend([0.0] * features_per_timeframe)

        return combined_features[:self.state_components['cnn_features']]

    def _process_cnn_predictions(self, cnn_predictions: Optional[Dict[str, np.ndarray]]) -> List[float]:
        """Process CNN predictions from all timeframes"""
        if not cnn_predictions:
            return [0.0] * self.state_components['cnn_predictions']

        predictions = []
        timeframes = ['1s', '1m', '1h', '1d']

        for tf in timeframes:
            if tf in cnn_predictions and cnn_predictions[tf] is not None:
                pred = cnn_predictions[tf].flatten()
                # Expecting 4 outputs per timeframe (BUY, SELL, HOLD, confidence)
                if len(pred) >= 4:
                    predictions.extend(pred[:4])
                else:
                    predictions.extend(pred)
                    predictions.extend([0.0] * (4 - len(pred)))
            else:
                predictions.extend([0.0, 0.0, 1.0, 0.0])  # Default to HOLD with 0 confidence

        return predictions[:self.state_components['cnn_predictions']]

    def _process_pivot_points(self, pivot_data: Optional[Dict[str, Any]],
                              eth_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Process pivot points using Williams market structure"""
        if pivot_data:
            # Use provided pivot data
            return self._extract_pivot_features(pivot_data)
        elif '1m' in eth_ohlcv and eth_ohlcv['1m']:
            # Calculate pivot points from 1m data
            from training.williams_market_structure import WilliamsMarketStructure
            williams = WilliamsMarketStructure()

            # Convert OHLCV to numpy array
            ohlcv_array = self._ohlcv_to_array(eth_ohlcv['1m'])
            pivot_data = williams.calculate_recursive_pivot_points(ohlcv_array)
            return self._extract_pivot_features(pivot_data)
        else:
            return [0.0] * self.state_components['pivot_points']

    def _process_market_regime(self, eth_ohlcv: Dict[str, List[OHLCVData]],
                               btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Process market regime indicators"""
        regime_features = []

        # ETH regime analysis
        if '1h' in eth_ohlcv and eth_ohlcv['1h']:
            eth_regime = self.regime_analyzer.analyze_regime(eth_ohlcv['1h'])
            regime_features.extend([
                eth_regime['volatility'],
                eth_regime['trend_strength'],
                eth_regime['volume_trend'],
                eth_regime['momentum'],
                1.0 if eth_regime['regime'] == 'trending' else 0.0,
                1.0 if eth_regime['regime'] == 'ranging' else 0.0,
                1.0 if eth_regime['regime'] == 'volatile' else 0.0
            ])
        else:
            regime_features.extend([0.0] * 7)

        # BTC regime analysis
        if '1h' in btc_ohlcv and btc_ohlcv['1h']:
            btc_regime = self.regime_analyzer.analyze_regime(btc_ohlcv['1h'])
            regime_features.extend([
                btc_regime['volatility'],
                btc_regime['trend_strength'],
                btc_regime['volume_trend'],
                btc_regime['momentum'],
                1.0 if btc_regime['regime'] == 'trending' else 0.0,
                1.0 if btc_regime['regime'] == 'ranging' else 0.0,
                1.0 if btc_regime['regime'] == 'volatile' else 0.0
            ])
        else:
            regime_features.extend([0.0] * 7)

        # Correlation features
        correlation_features = self._calculate_btc_eth_correlation(eth_ohlcv, btc_ohlcv)
        regime_features.extend(correlation_features)

        return regime_features[:self.state_components['market_regime']]

    def _normalize_tick_window(self, ticks: List[TickData], target_size: int) -> List[TickData]:
        """Normalize tick window to target size"""
        if len(ticks) == target_size:
            return ticks
        elif len(ticks) > target_size:
            # Sample evenly
            step = len(ticks) / target_size
            indices = [int(i * step) for i in range(target_size)]
            return [ticks[i] for i in indices]
        else:
            # Pad with last tick
            result = ticks.copy()
            last_tick = ticks[-1] if ticks else TickData(datetime.now(), 0, 0)
            while len(result) < target_size:
                result.append(last_tick)
            return result

    def _extract_pivot_features(self, pivot_data: Dict[str, Any]) -> List[float]:
        """Extract features from pivot point data"""
        features = []

        for level in range(5):  # 5 levels of recursion
            level_key = f'level_{level}'
            if level_key in pivot_data:
                level_data = pivot_data[level_key]

                # Swing point features
                swing_points = level_data.get('swing_points', [])
                if swing_points:
                    # Last 10 swing points
                    recent_swings = swing_points[-10:]
                    for swing in recent_swings:
                        features.extend([
                            swing['price'],
                            1.0 if swing['type'] == 'swing_high' else 0.0,
                            swing['index']
                        ])

                    # Pad if fewer than 10 swings
                    while len(recent_swings) < 10:
                        features.extend([0.0, 0.0, 0.0])
                        recent_swings.append({'type': 'none'})
                else:
                    features.extend([0.0] * 30)  # 10 swings * 3 features

                # Trend features
                features.extend([
                    level_data.get('trend_strength', 0.0),
                    1.0 if level_data.get('trend_direction') == 'up' else 0.0,
                    1.0 if level_data.get('trend_direction') == 'down' else 0.0
                ])
            else:
                features.extend([0.0] * 33)  # 30 swing + 3 trend features

        return features[:self.state_components['pivot_points']]

    def _ohlcv_to_array(self, ohlcv_data: List[OHLCVData]) -> np.ndarray:
        """Convert OHLCV data to numpy array"""
        return np.array([[
            bar.timestamp.timestamp(),
            bar.open,
            bar.high,
            bar.low,
            bar.close,
            bar.volume
        ] for bar in ohlcv_data])

    def _calculate_btc_eth_correlation(self, eth_ohlcv: Dict[str, List[OHLCVData]],
                                       btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Calculate BTC-ETH correlation features"""
        try:
            # Use 1h data for correlation
            if '1h' not in eth_ohlcv or '1h' not in btc_ohlcv:
                return [0.0] * 6

            eth_prices = [bar.close for bar in eth_ohlcv['1h'][-50:]]  # Last 50 hours
            btc_prices = [bar.close for bar in btc_ohlcv['1h'][-50:]]

            if len(eth_prices) < 10 or len(btc_prices) < 10:
                return [0.0] * 6

            # Align lengths
            min_len = min(len(eth_prices), len(btc_prices))
            eth_prices = eth_prices[-min_len:]
            btc_prices = btc_prices[-min_len:]

            # Calculate returns
            eth_returns = np.diff(eth_prices) / eth_prices[:-1]
            btc_returns = np.diff(btc_prices) / btc_prices[:-1]

            # Correlation
            correlation = np.corrcoef(eth_returns, btc_returns)[0, 1] if len(eth_returns) > 1 else 0.0

            # Price ratio
            current_ratio = eth_prices[-1] / btc_prices[-1] if btc_prices[-1] > 0 else 0.0
            avg_ratio = np.mean([e/b for e, b in zip(eth_prices, btc_prices) if b > 0])
            ratio_deviation = (current_ratio - avg_ratio) / avg_ratio if avg_ratio > 0 else 0.0

            # Volatility comparison
            eth_vol = np.std(eth_returns) if len(eth_returns) > 1 else 0.0
            btc_vol = np.std(btc_returns) if len(btc_returns) > 1 else 0.0
            vol_ratio = eth_vol / btc_vol if btc_vol > 0 else 1.0

            return [
                correlation,
                current_ratio,
                ratio_deviation,
                vol_ratio,
                eth_vol,
                btc_vol
            ]

        except Exception as e:
            logger.warning(f"Error calculating BTC-ETH correlation: {e}")
            return [0.0] * 6

    def _initialize_normalization_params(self) -> Dict[str, Dict[str, float]]:
        """Initialize normalization parameters for different feature types"""
        return {
            'price_features': {'mean': 0.0, 'std': 1.0, 'min': -10.0, 'max': 10.0},
            'volume_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0},
            'indicator_features': {'mean': 0.0, 'std': 1.0, 'min': -3.0, 'max': 3.0},
            'cnn_features': {'mean': 0.0, 'std': 1.0, 'min': -2.0, 'max': 2.0},
            'pivot_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0}
        }

    def _normalize_state(self, state: np.ndarray) -> np.ndarray:
        """Apply normalization to state vector"""
        try:
            # Simple clipping and scaling for now
            # More sophisticated normalization can be added based on training data
            normalized_state = np.clip(state, -10.0, 10.0)

            # Replace any NaN or inf values
            normalized_state = np.nan_to_num(normalized_state, nan=0.0, posinf=10.0, neginf=-10.0)

            return normalized_state.astype(np.float32)

        except Exception as e:
            logger.error(f"Error normalizing state: {e}")
            return state.astype(np.float32)

class TickMomentumDetector:
    """Detect momentum from tick-level data"""

    def calculate_micro_momentum(self, ticks: List[TickData]) -> float:
        """Calculate micro-momentum from tick sequence"""
        if len(ticks) < 2:
            return 0.0

        # Price momentum
        prices = [tick.price for tick in ticks]
        price_changes = np.diff(prices)
        price_momentum = np.sum(price_changes) / len(price_changes) if len(price_changes) > 0 else 0.0

        # Volume-weighted momentum
        volumes = [tick.volume for tick in ticks]
        if sum(volumes[1:]) > 0:  # guard the actual divisor to avoid division by zero
            weighted_changes = [pc * v for pc, v in zip(price_changes, volumes[1:])]
            volume_momentum = sum(weighted_changes) / sum(volumes[1:])
        else:
            volume_momentum = 0.0

        return (price_momentum + volume_momentum) / 2.0

class TechnicalIndicatorCalculator:
    """Calculate technical indicators for OHLCV data"""

    def add_all_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add all technical indicators to DataFrame"""
        df = df.copy()

        # RSI
        df['rsi'] = self.calculate_rsi(df['close'])

        # MACD
        df['macd'] = self.calculate_macd(df['close'])

        # Bollinger Bands
        df['bb_middle'] = df['close'].rolling(20).mean()
        df['bb_std'] = df['close'].rolling(20).std()
        df['bb_upper'] = df['bb_middle'] + (df['bb_std'] * 2)
        df['bb_lower'] = df['bb_middle'] - (df['bb_std'] * 2)

        # Fill NaN values
        df = df.ffill().fillna(0)

        return df

    def calculate_rsi(self, prices: pd.Series, period: int = 14) -> pd.Series:
        """Calculate RSI"""
        delta = prices.diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
        rs = gain / loss
        rsi = 100 - (100 / (1 + rs))
        return rsi.fillna(50)

    def calculate_macd(self, prices: pd.Series, fast: int = 12, slow: int = 26) -> pd.Series:
        """Calculate MACD"""
        ema_fast = prices.ewm(span=fast).mean()
        ema_slow = prices.ewm(span=slow).mean()
        macd = ema_fast - ema_slow
        return macd.fillna(0)

class MarketRegimeAnalyzer:
    """Analyze market regime from OHLCV data"""

    def analyze_regime(self, ohlcv_data: List[OHLCVData]) -> Dict[str, Any]:
        """Analyze market regime"""
        if len(ohlcv_data) < 20:
            return {
                'regime': 'unknown',
                'volatility': 0.0,
                'trend_strength': 0.0,
                'volume_trend': 0.0,
                'momentum': 0.0
            }

        prices = [bar.close for bar in ohlcv_data[-50:]]  # Last 50 bars
        volumes = [bar.volume for bar in ohlcv_data[-50:]]

        # Calculate volatility
        returns = np.diff(prices) / prices[:-1]
        volatility = np.std(returns) * 100  # Percentage volatility

        # Calculate trend strength
        sma_short = np.mean(prices[-10:])
        sma_long = np.mean(prices[-30:])
        trend_strength = abs(sma_short - sma_long) / sma_long if sma_long > 0 else 0.0

        # Volume trend
        volume_ma_short = np.mean(volumes[-10:])
        volume_ma_long = np.mean(volumes[-30:])
        volume_trend = (volume_ma_short - volume_ma_long) / volume_ma_long if volume_ma_long > 0 else 0.0

        # Momentum
        momentum = (prices[-1] - prices[-10]) / prices[-10] if len(prices) >= 10 and prices[-10] > 0 else 0.0

        # Determine regime
        if volatility > 3.0:  # High volatility
            regime = 'volatile'
        elif abs(momentum) > 0.02:  # Strong momentum
            regime = 'trending'
        else:
            regime = 'ranging'

        return {
            'regime': regime,
            'volatility': volatility,
            'trend_strength': trend_strength,
            'volume_trend': volume_trend,
            'momentum': momentum
        }

    def get_state_info(self) -> Dict[str, Any]:
        """Get information about the state structure"""
        return {
            'total_size': self.config.total_size,
            'components': {
                'eth_ticks': self.config.eth_ticks,
                'eth_1s_ohlcv': self.config.eth_1s_ohlcv,
                'eth_1m_ohlcv': self.config.eth_1m_ohlcv,
                'eth_1h_ohlcv': self.config.eth_1h_ohlcv,
                'eth_1d_ohlcv': self.config.eth_1d_ohlcv,
                'btc_reference': self.config.btc_reference,
                'cnn_features': self.config.cnn_features,
                'cnn_predictions': self.config.cnn_predictions,
                'pivot_points': self.config.pivot_points,
                'market_regime': self.config.market_regime,
            },
            'data_windows': {
                'tick_window_seconds': self.tick_window_seconds,
                'ohlcv_window_bars': self.ohlcv_window_bars,
            }
        }
@ -1,821 +0,0 @@
"""
Enhanced RL Trainer with Continuous Learning

This module implements sophisticated RL training with:
- Prioritized experience replay
- Market regime adaptation
- Continuous learning from trading outcomes
- Performance tracking and visualization
"""

import asyncio
import logging
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import random
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Union
import matplotlib.pyplot as plt
from pathlib import Path

from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator, MarketState, TradingAction
from models import RLAgentInterface
import models
from training.enhanced_rl_state_builder import EnhancedRLStateBuilder
from training.williams_market_structure import WilliamsMarketStructure
from training.cnn_rl_bridge import CNNRLBridge

logger = logging.getLogger(__name__)

# Experience tuple for replay buffer
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done', 'priority'])

class PrioritizedReplayBuffer:
    """Prioritized experience replay buffer for RL training"""

    def __init__(self, capacity: int = 10000, alpha: float = 0.6):
        """
        Initialize prioritized replay buffer

        Args:
            capacity: Maximum number of experiences to store
            alpha: Priority exponent (0 = uniform, 1 = fully prioritized)
        """
        self.capacity = capacity
        self.alpha = alpha
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.position = 0
        self.size = 0

    def add(self, experience: Experience):
        """Add experience to buffer with priority"""
        max_priority = self.priorities[:self.size].max() if self.size > 0 else 1.0

        if self.size < self.capacity:
            self.buffer.append(experience)
            self.size += 1
        else:
            self.buffer[self.position] = experience

        self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size: int, beta: float = 0.4) -> Tuple[List[Experience], np.ndarray, np.ndarray]:
        """Sample batch with prioritized sampling"""
        if self.size == 0:
            return [], np.array([]), np.array([])

        # Calculate sampling probabilities
        priorities = self.priorities[:self.size] ** self.alpha
        probabilities = priorities / priorities.sum()

        # Sample indices
        indices = np.random.choice(self.size, batch_size, p=probabilities)
        experiences = [self.buffer[i] for i in indices]

        # Calculate importance sampling weights
        weights = (self.size * probabilities[indices]) ** (-beta)
        weights = weights / weights.max()  # Normalize

        return experiences, indices, weights

    def update_priorities(self, indices: np.ndarray, priorities: np.ndarray):
        """Update priorities for sampled experiences"""
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority + 1e-6  # Small epsilon to avoid zero priority

    def __len__(self):
        return self.size

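# Typical training-loop usage of the buffer above (a sketch; td_errors would
# come from the DQN loss of the sampled batch):
#
#   buffer = PrioritizedReplayBuffer(capacity=10000)
#   buffer.add(Experience(state, action, reward, next_state, done, priority=1.0))
#   batch, indices, is_weights = buffer.sample(batch_size=64, beta=0.4)
#   buffer.update_priorities(indices, np.abs(td_errors))
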
||||
class EnhancedDQNAgent(nn.Module, RLAgentInterface):
|
||||
"""Enhanced DQN agent with market environment adaptation"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
nn.Module.__init__(self)
|
||||
RLAgentInterface.__init__(self, config)
|
||||
|
||||
# Network architecture
|
||||
self.state_size = config.get('state_size', 100)
|
||||
self.action_space = config.get('action_space', 3)
|
||||
self.hidden_size = config.get('hidden_size', 256)
|
||||
|
||||
# Build networks
|
||||
self._build_networks()
|
||||
|
||||
# Training parameters
|
||||
self.learning_rate = config.get('learning_rate', 0.0001)
|
||||
self.gamma = config.get('gamma', 0.99)
|
||||
self.epsilon = config.get('epsilon', 1.0)
|
||||
self.epsilon_decay = config.get('epsilon_decay', 0.995)
|
||||
self.epsilon_min = config.get('epsilon_min', 0.01)
|
||||
self.target_update_freq = config.get('target_update_freq', 1000)
|
||||
|
||||
# Initialize device and optimizer
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.to(self.device)
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
|
||||
|
||||
# Experience replay
|
||||
self.replay_buffer = PrioritizedReplayBuffer(config.get('buffer_size', 10000))
|
||||
self.batch_size = config.get('batch_size', 64)
|
||||
|
||||
# Market adaptation
|
||||
self.market_regime_weights = {
|
||||
'trending': 1.2, # Higher confidence in trending markets
|
||||
'ranging': 0.8, # Lower confidence in ranging markets
|
||||
'volatile': 0.6 # Much lower confidence in volatile markets
|
||||
}
|
||||
|
||||
# Training statistics
|
||||
self.training_steps = 0
|
||||
self.losses = []
|
||||
self.rewards = []
|
||||
self.epsilon_history = []
|
||||
|
||||
logger.info(f"Enhanced DQN agent initialized with state size: {self.state_size}")
|
||||
|

    def _build_networks(self):
        """Build main and target networks"""
        # Main network
        self.main_network = nn.Sequential(
            nn.Linear(self.state_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Dueling network heads
        self.value_head = nn.Linear(128, 1)
        self.advantage_head = nn.Linear(128, self.action_space)

        # Target network (same architecture as the main network)
        self.target_network = nn.Sequential(
            nn.Linear(self.state_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        self.target_value_head = nn.Linear(128, 1)
        self.target_advantage_head = nn.Linear(128, self.action_space)

        # Initialize target network with the same weights as the main network
        self._update_target_network()

    def forward(self, state, target: bool = False):
        """Forward pass through the network"""
        if target:
            features = self.target_network(state)
            value = self.target_value_head(features)
            advantage = self.target_advantage_head(features)
        else:
            features = self.main_network(state)
            value = self.value_head(features)
            advantage = self.advantage_head(features)

        # Dueling architecture: Q(s,a) = V(s) + A(s,a) - mean_a A(s,a)
        q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))

        return q_values
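
    # Worked example (illustrative, not original code): with toy numbers
    # V(s) = 2.0 and A(s, ·) = [1.0, -1.0, 0.0] (mean 0.0), the aggregation
    # above gives Q(s, ·) = [3.0, 1.0, 2.0]. Subtracting the mean advantage
    # makes the decomposition identifiable: mean_a Q(s, a) equals V(s) = 2.0.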

    def act(self, state: np.ndarray) -> int:
        """Choose action using epsilon-greedy policy"""
        if random.random() < self.epsilon:
            return random.randint(0, self.action_space - 1)

        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.forward(state_tensor)
            return q_values.argmax().item()

    def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
        """Choose action with confidence score adapted to market regime"""
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.forward(state_tensor)

            # Convert Q-values to probabilities
            action_probs = torch.softmax(q_values, dim=1)
            action = q_values.argmax().item()
            base_confidence = action_probs[0, action].item()

            # Adapt confidence based on market regime
            regime_weight = self.market_regime_weights.get(market_regime, 1.0)
            adapted_confidence = min(base_confidence * regime_weight, 1.0)

        return action, adapted_confidence

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool):
        """Store experience in replay buffer"""
        # Calculate TD error for the initial priority
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)

            current_q = self.forward(state_tensor)[0, action]
            next_q = self.forward(next_state_tensor, target=True).max(1)[0]
            target_q = reward + (self.gamma * next_q * (1 - done))

            td_error = abs(current_q.item() - target_q.item())

        experience = Experience(state, action, reward, next_state, done, td_error)
        self.replay_buffer.add(experience)

    def replay(self) -> Optional[float]:
        """Train the network on a batch of experiences"""
        if len(self.replay_buffer) < self.batch_size:
            return None

        # Sample batch
        experiences, indices, weights = self.replay_buffer.sample(self.batch_size)

        if not experiences:
            return None

        # Convert to tensors (stack into arrays first to avoid slow list conversion)
        states = torch.FloatTensor(np.array([e.state for e in experiences])).to(self.device)
        actions = torch.LongTensor([e.action for e in experiences]).to(self.device)
        rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device)
        next_states = torch.FloatTensor(np.array([e.next_state for e in experiences])).to(self.device)
        dones = torch.BoolTensor([e.done for e in experiences]).to(self.device)
        weights_tensor = torch.FloatTensor(weights).to(self.device)

        # Current Q-values
        current_q_values = self.forward(states).gather(1, actions.unsqueeze(1))

        # Target Q-values (Double DQN)
        with torch.no_grad():
            # Use main network to select actions
            next_actions = self.forward(next_states).argmax(1)
            # Use target network to evaluate actions
            next_q_values = self.forward(next_states, target=True).gather(1, next_actions.unsqueeze(1))
            target_q_values = rewards.unsqueeze(1) + (self.gamma * next_q_values * ~dones.unsqueeze(1))

        # Calculate importance-weighted loss; squeeze so the per-sample weights
        # align elementwise instead of broadcasting to a (batch, batch) matrix
        td_errors = target_q_values - current_q_values
        loss = (weights_tensor * (td_errors.squeeze(1) ** 2)).mean()

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
        self.optimizer.step()

        # Update priorities with the new absolute TD errors
        new_priorities = torch.abs(td_errors).detach().cpu().numpy().flatten()
        self.replay_buffer.update_priorities(indices, new_priorities)

        # Update target network
        self.training_steps += 1
        if self.training_steps % self.target_update_freq == 0:
            self._update_target_network()

        # Decay epsilon
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        # Track statistics
        self.losses.append(loss.item())
        self.epsilon_history.append(self.epsilon)

        return loss.item()

    def _update_target_network(self):
        """Update target network with main network weights"""
        self.target_network.load_state_dict(self.main_network.state_dict())
        self.target_value_head.load_state_dict(self.value_head.state_dict())
        self.target_advantage_head.load_state_dict(self.advantage_head.state_dict())

    def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]:
        """Predict action probabilities and confidence (required by ModelInterface)"""
        action, confidence = self.act_with_confidence(features)
        # Convert the greedy action to a one-hot probability vector
        action_probs = np.zeros(self.action_space)
        action_probs[action] = 1.0
        return action_probs, confidence

    def get_memory_usage(self) -> int:
        """Get memory usage in MB"""
        if torch.cuda.is_available():
            return torch.cuda.memory_allocated(self.device) // (1024 * 1024)
        else:
            # Rough CPU estimate: 4 bytes per float32 parameter plus buffered states
            param_count = sum(p.numel() for p in self.parameters())
            buffer_size = len(self.replay_buffer) * self.state_size * 4
            return (param_count * 4 + buffer_size) // (1024 * 1024)
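

# Illustrative sketch (not part of the original module): one act/remember/replay
# cycle on random data; the 16-feature state and small batch size are assumptions.
def _demo_enhanced_dqn_agent() -> None:
    agent = EnhancedDQNAgent({'state_size': 16, 'action_space': 3, 'batch_size': 8})
    for _ in range(32):
        state = np.random.randn(16).astype(np.float32)
        next_state = np.random.randn(16).astype(np.float32)
        action = agent.act(state)
        reward = float(np.random.randn())  # Placeholder reward signal
        agent.remember(state, action, reward, next_state, done=False)
    loss = agent.replay()  # None until the buffer holds at least one full batch
    logger.info(f"Demo replay loss: {loss}")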


class EnhancedRLTrainer:
    """Enhanced RL trainer with comprehensive state representation and real data integration"""

    def __init__(self, config: Optional[Dict] = None, orchestrator: Optional[EnhancedTradingOrchestrator] = None):
        """Initialize enhanced RL trainer with comprehensive state building"""
        self.config = config or get_config()
        self.orchestrator = orchestrator

        # Initialize comprehensive state builder (replaces mock code)
        self.state_builder = EnhancedRLStateBuilder(self.config)
        self.williams_structure = WilliamsMarketStructure()
        self.cnn_rl_bridge = CNNRLBridge(self.config) if hasattr(self.config, 'cnn_models') else None

        # Training configuration (symbols must be set before agents are created)
        self.symbols = self.config.symbols
        self.save_dir = Path(self.config.rl.get('save_dir', 'models/rl/saved'))
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Enhanced RL agents with a much larger state space
        self.agents = {}
        self.initialize_agents()

        # Performance tracking
        self.training_metrics = {
            'total_episodes': 0,
            'total_rewards': {symbol: [] for symbol in self.symbols},
            'losses': {symbol: [] for symbol in self.symbols},
            'epsilon_values': {symbol: [] for symbol in self.symbols}
        }

        self.performance_history = {symbol: [] for symbol in self.symbols}

        # Real-time learning parameters
        self.learning_active = False
        self.experience_buffer_size = 1000
        self.min_experiences_for_training = 100

        logger.info("Enhanced RL Trainer initialized with comprehensive state representation")
        logger.info(f"State builder total size: {self.state_builder.total_state_size} features")
        logger.info(f"Symbols: {self.symbols}")

    def initialize_agents(self):
        """Initialize RL agents with enhanced state size"""
        for symbol in self.symbols:
            agent_config = {
                'state_size': self.state_builder.total_state_size,  # ~13,400 features
                'action_space': 3,       # BUY, SELL, HOLD
                'hidden_size': 1024,     # Larger hidden layers for the complex state
                'learning_rate': 0.0001,
                'gamma': 0.99,
                'epsilon': 1.0,
                'epsilon_decay': 0.995,
                'epsilon_min': 0.01,
                'buffer_size': 50000,    # Larger replay buffer
                'batch_size': 128,
                'target_update_freq': 1000
            }

            self.agents[symbol] = EnhancedDQNAgent(agent_config)
            logger.info(f"Initialized {symbol} RL agent with state size: {agent_config['state_size']}")

    async def continuous_learning_loop(self):
        """Main continuous learning loop"""
        logger.info("Starting continuous RL learning loop")

        while True:
            try:
                # Train agents with recent experiences
                await self._train_all_agents()

                # Evaluate recent actions
                if self.orchestrator:
                    await self.orchestrator.evaluate_actions_with_rl()

                # Adapt to market regime changes
                await self._adapt_to_market_changes()

                # Update performance metrics
                self._update_performance_metrics()

                # Save models periodically
                if self.training_metrics['total_episodes'] % 100 == 0:
                    self._save_all_models()

                # Wait before next training cycle
                await asyncio.sleep(3600)  # Train every hour

            except Exception as e:
                logger.error(f"Error in continuous learning loop: {e}")
                await asyncio.sleep(60)  # Wait 1 minute on error

    async def _train_all_agents(self):
        """Train all RL agents with their experiences"""
        for symbol, agent in self.agents.items():
            try:
                if len(agent.replay_buffer) >= self.min_experiences_for_training:
                    # Train for multiple steps per cycle
                    losses = []
                    for _ in range(10):
                        loss = agent.replay()
                        if loss is not None:
                            losses.append(loss)

                    if losses:
                        avg_loss = np.mean(losses)
                        self.training_metrics['losses'][symbol].append(avg_loss)
                        self.training_metrics['epsilon_values'][symbol].append(agent.epsilon)

                        logger.info(f"Trained {symbol} RL agent: Loss={avg_loss:.4f}, Epsilon={agent.epsilon:.4f}")

            except Exception as e:
                logger.error(f"Error training {symbol} agent: {e}")

    async def _adapt_to_market_changes(self):
        """Adapt agents to market regime changes"""
        if not self.orchestrator:
            return

        for symbol in self.symbols:
            try:
                # Get the last 10 market states
                recent_states = list(self.orchestrator.market_states[symbol])[-10:]

                if len(recent_states) < 5:
                    continue

                # Analyze regime stability: the fraction of distinct regimes in
                # the window, so lower values mean a more stable regime
                regimes = [state.market_regime for state in recent_states]
                regime_stability = len(set(regimes)) / len(regimes)

                # Adjust learning parameters based on stability
                agent = self.agents[symbol]
                if regime_stability < 0.3:    # Stable regime
                    agent.epsilon *= 0.99     # Faster epsilon decay
                elif regime_stability > 0.7:  # Unstable regime
                    agent.epsilon = min(agent.epsilon * 1.01, 0.5)  # Increase exploration

                logger.debug(f"{symbol} regime stability: {regime_stability:.3f}, epsilon: {agent.epsilon:.3f}")

            except Exception as e:
                logger.error(f"Error adapting {symbol} to market changes: {e}")
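
    # Worked example (illustrative, not original code): for the regime window
    # ['trending', 'trending', 'ranging', 'trending', 'trending'] there are
    # 2 distinct regimes across 5 states, so regime_stability = 2/5 = 0.4,
    # which falls in the middle band and leaves epsilon unchanged.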

    def add_trading_experience(self, symbol: str, action: TradingAction,
                               initial_state: MarketState, final_state: MarketState,
                               reward: float):
        """Add trading experience to the appropriate agent"""
        if symbol not in self.agents:
            logger.warning(f"No agent for symbol {symbol}")
            return

        try:
            # Convert market states to RL state vectors
            initial_rl_state = self._market_state_to_rl_state(initial_state)
            final_rl_state = self._market_state_to_rl_state(final_state)

            # Convert action to RL action index
            action_mapping = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
            action_idx = action_mapping.get(action.action, 1)

            # Store experience
            agent = self.agents[symbol]
            agent.remember(
                state=initial_rl_state,
                action=action_idx,
                reward=reward,
                next_state=final_rl_state,
                done=False
            )

            # Track reward
            self.training_metrics['total_rewards'][symbol].append(reward)

            logger.debug(f"Added experience for {symbol}: action={action.action}, reward={reward:.4f}")

        except Exception as e:
            logger.error(f"Error adding experience for {symbol}: {e}")

    def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
        """Convert market state to comprehensive RL state vector using real data"""
        try:
            # Extract data from market state and orchestrator
            if not self.orchestrator:
                logger.warning("No orchestrator available for comprehensive state building")
                return self._fallback_state_conversion(market_state)

            # Get real tick data from orchestrator's data provider
            symbol = market_state.symbol
            eth_ticks = self._get_recent_tick_data(symbol, seconds=300)

            # Get multi-timeframe OHLCV data
            eth_ohlcv = self._get_multiframe_ohlcv_data(symbol)
            btc_ohlcv = self._get_multiframe_ohlcv_data('BTC/USDT')

            # Get CNN features if available
            cnn_hidden_features = None
            cnn_predictions = None
            if self.cnn_rl_bridge:
                cnn_data = self.cnn_rl_bridge.get_latest_features_for_symbol(symbol)
                if cnn_data:
                    cnn_hidden_features = cnn_data.get('hidden_features', {})
                    cnn_predictions = cnn_data.get('predictions', {})

            # Get pivot point data
            pivot_data = self._calculate_pivot_points(eth_ohlcv)

            # Build comprehensive state using enhanced state builder
            comprehensive_state = self.state_builder.build_rl_state(
                eth_ticks=eth_ticks,
                eth_ohlcv=eth_ohlcv,
                btc_ohlcv=btc_ohlcv,
                cnn_hidden_features=cnn_hidden_features,
                cnn_predictions=cnn_predictions,
                pivot_data=pivot_data
            )

            logger.debug(f"Built comprehensive RL state: {len(comprehensive_state)} features")
            return comprehensive_state

        except Exception as e:
            logger.error(f"Error building comprehensive RL state: {e}")
            return self._fallback_state_conversion(market_state)

    def _get_recent_tick_data(self, symbol: str, seconds: int = 300) -> List:
        """Get recent tick data from orchestrator's data provider"""
        try:
            if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
                # Get recent ticks from data provider
                recent_ticks = self.orchestrator.data_provider.get_recent_ticks(symbol, count=seconds * 10)

                # Convert to the format expected by the state builder
                tick_data = []
                for tick in recent_ticks[-300:]:  # Last 300 ticks max
                    tick_data.append({
                        'timestamp': tick.timestamp,
                        'price': tick.price,
                        'volume': tick.volume,
                        'quantity': getattr(tick, 'quantity', tick.volume),
                        'side': getattr(tick, 'side', 'unknown'),
                        'trade_id': getattr(tick, 'trade_id', 'unknown')
                    })

                return tick_data

            return []

        except Exception as e:
            logger.warning(f"Error getting tick data for {symbol}: {e}")
            return []

    def _get_multiframe_ohlcv_data(self, symbol: str) -> Dict[str, List]:
        """Get multi-timeframe OHLCV data"""
        try:
            if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
                ohlcv_data = {}
                timeframes = ['1s', '1m', '1h', '1d']

                for tf in timeframes:
                    try:
                        # Get historical data for this timeframe
                        df = self.orchestrator.data_provider.get_historical_data(
                            symbol=symbol,
                            timeframe=tf,
                            limit=300,
                            refresh=True
                        )

                        if df is not None and not df.empty:
                            # Convert to a list of bar dictionaries
                            bars = []
                            for _, row in df.tail(300).iterrows():
                                bar = {
                                    'timestamp': row.name if hasattr(row, 'name') else datetime.now(),
                                    'open': float(row.get('open', 0)),
                                    'high': float(row.get('high', 0)),
                                    'low': float(row.get('low', 0)),
                                    'close': float(row.get('close', 0)),
                                    'volume': float(row.get('volume', 0))
                                }
                                bars.append(bar)

                            ohlcv_data[tf] = bars
                        else:
                            ohlcv_data[tf] = []

                    except Exception as e:
                        logger.warning(f"Error getting {tf} data for {symbol}: {e}")
                        ohlcv_data[tf] = []

                return ohlcv_data

            return {}

        except Exception as e:
            logger.warning(f"Error getting OHLCV data for {symbol}: {e}")
            return {}

    def _calculate_pivot_points(self, eth_ohlcv: Dict[str, List]) -> Dict[str, Any]:
        """Calculate Williams pivot points from OHLCV data"""
        try:
            if '1m' in eth_ohlcv and eth_ohlcv['1m']:
                bars = eth_ohlcv['1m']
                if len(bars) >= 50:  # Need a minimum amount of data for pivot calculation
                    # Convert to a numpy array for the Williams calculation;
                    # columns are [timestamp, open, high, low, close, volume]
                    ohlc_array = np.array([
                        [bar['timestamp'].timestamp() if hasattr(bar['timestamp'], 'timestamp') else time.time(),
                         bar['open'], bar['high'], bar['low'], bar['close'], bar['volume']]
                        for bar in bars[-200:]  # Last 200 bars
                    ])

                    pivot_data = self.williams_structure.calculate_recursive_pivot_points(ohlc_array)
                    return pivot_data

            return {}

        except Exception as e:
            logger.warning(f"Error calculating pivot points: {e}")
            return {}

    def _fallback_state_conversion(self, market_state: MarketState) -> np.ndarray:
        """Fallback to basic state conversion if comprehensive state building fails"""
        logger.warning("Using fallback state conversion - limited features")

        state_components = [
            market_state.volatility,
            market_state.volume,
            market_state.trend_strength
        ]

        # Add price features
        for timeframe in sorted(market_state.prices.keys()):
            state_components.append(market_state.prices[timeframe])

        # Pad (or truncate) to match the expected state size
        expected_size = self.state_builder.total_state_size
        if len(state_components) < expected_size:
            state_components.extend([0.0] * (expected_size - len(state_components)))
        else:
            state_components = state_components[:expected_size]

        return np.array(state_components, dtype=np.float32)

    def _update_performance_metrics(self):
        """Update performance tracking metrics"""
        self.training_metrics['total_episodes'] += 1

        # Calculate recent performance for each agent (last 100 rewards)
        for symbol, agent in self.agents.items():
            recent_rewards = self.training_metrics['total_rewards'][symbol][-100:]
            if recent_rewards:
                avg_reward = np.mean(recent_rewards)
                self.performance_history[symbol].append({
                    'timestamp': datetime.now(),
                    'avg_reward': avg_reward,
                    'epsilon': agent.epsilon,
                    'experiences': len(agent.replay_buffer)
                })

    def _save_all_models(self):
        """Save all RL models"""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        for symbol, agent in self.agents.items():
            # Sanitize the symbol so pairs like 'ETH/USDT' do not create subdirectories
            safe_symbol = symbol.replace('/', '_')
            filename = f"rl_agent_{safe_symbol}_{timestamp}.pt"
            filepath = self.save_dir / filename

            torch.save({
                'model_state_dict': agent.state_dict(),
                'optimizer_state_dict': agent.optimizer.state_dict(),
                'config': self.config.rl,
                'training_metrics': self.training_metrics,
                'symbol': symbol,
                'epsilon': agent.epsilon,
                'training_steps': agent.training_steps
            }, filepath)

            logger.info(f"Saved {symbol} RL agent to {filepath}")

    def load_models(self, timestamp: Optional[str] = None):
        """Load RL models from files"""
        if timestamp is None:
            # Find the most recent models
            model_files = list(self.save_dir.glob("rl_agent_*.pt"))
            if not model_files:
                logger.warning("No saved RL models found")
                return False

            # Group by timestamp (the last two underscore-separated stem parts)
            # and take the most recent
            timestamps = set(f.stem.split('_')[-2] + '_' + f.stem.split('_')[-1] for f in model_files)
            timestamp = max(timestamps)

        loaded_count = 0
        for symbol in self.symbols:
            safe_symbol = symbol.replace('/', '_')
            filename = f"rl_agent_{safe_symbol}_{timestamp}.pt"
            filepath = self.save_dir / filename

            if filepath.exists():
                try:
                    checkpoint = torch.load(filepath, map_location=self.agents[symbol].device)
                    self.agents[symbol].load_state_dict(checkpoint['model_state_dict'])
                    self.agents[symbol].optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                    self.agents[symbol].epsilon = checkpoint.get('epsilon', 0.1)
                    self.agents[symbol].training_steps = checkpoint.get('training_steps', 0)

                    logger.info(f"Loaded {symbol} RL agent from {filepath}")
                    loaded_count += 1

                except Exception as e:
                    logger.error(f"Error loading {symbol} RL agent: {e}")

        return loaded_count > 0

    def get_performance_report(self) -> Dict[str, Any]:
        """Generate performance report for all agents"""
        report = {
            'total_episodes': self.training_metrics['total_episodes'],
            'agents': {}
        }

        for symbol, agent in self.agents.items():
            recent_rewards = self.training_metrics['total_rewards'][symbol][-100:]
            recent_losses = self.training_metrics['losses'][symbol][-10:]

            agent_report = {
                'symbol': symbol,
                'epsilon': agent.epsilon,
                'training_steps': agent.training_steps,
                'experiences_stored': len(agent.replay_buffer),
                'memory_usage_mb': agent.get_memory_usage(),
                'avg_recent_reward': np.mean(recent_rewards) if recent_rewards else 0.0,
                'avg_recent_loss': np.mean(recent_losses) if recent_losses else 0.0,
                'total_rewards': len(self.training_metrics['total_rewards'][symbol])
            }

            report['agents'][symbol] = agent_report

        return report

    def plot_training_metrics(self):
        """Plot training metrics for all agents"""
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Enhanced RL Training Metrics')

        symbols = list(self.agents.keys())
        colors = ['blue', 'red', 'green', 'orange'][:len(symbols)]

        # Rewards plot (moving average)
        for i, symbol in enumerate(symbols):
            rewards = self.training_metrics['total_rewards'][symbol]
            if rewards:
                window = min(100, len(rewards))
                if len(rewards) >= window:
                    moving_avg = np.convolve(rewards, np.ones(window) / window, mode='valid')
                    axes[0, 0].plot(moving_avg, label=f'{symbol}', color=colors[i])

        axes[0, 0].set_title('Average Rewards (Moving Average)')
        axes[0, 0].set_xlabel('Episodes')
        axes[0, 0].set_ylabel('Reward')
        axes[0, 0].legend()

        # Losses plot
        for i, symbol in enumerate(symbols):
            losses = self.training_metrics['losses'][symbol]
            if losses:
                axes[0, 1].plot(losses, label=f'{symbol}', color=colors[i])

        axes[0, 1].set_title('Training Losses')
        axes[0, 1].set_xlabel('Training Steps')
        axes[0, 1].set_ylabel('Loss')
        axes[0, 1].legend()

        # Epsilon values
        for i, symbol in enumerate(symbols):
            epsilon_values = self.training_metrics['epsilon_values'][symbol]
            if epsilon_values:
                axes[1, 0].plot(epsilon_values, label=f'{symbol}', color=colors[i])

        axes[1, 0].set_title('Exploration Rate (Epsilon)')
        axes[1, 0].set_xlabel('Training Steps')
        axes[1, 0].set_ylabel('Epsilon')
        axes[1, 0].legend()

        # Experience buffer sizes
        buffer_sizes = [len(agent.replay_buffer) for agent in self.agents.values()]
        axes[1, 1].bar(symbols, buffer_sizes, color=colors[:len(symbols)])
        axes[1, 1].set_title('Experience Buffer Sizes')
        axes[1, 1].set_ylabel('Number of Experiences')

        plt.tight_layout()
        plt.savefig(self.save_dir / 'rl_training_metrics.png', dpi=300, bbox_inches='tight')
        plt.close()

        logger.info(f"RL training plots saved to {self.save_dir / 'rl_training_metrics.png'}")

    def get_agents(self) -> Dict[str, EnhancedDQNAgent]:
        """Get all RL agents"""
        return self.agents
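

# Illustrative sketch (not part of the original module): minimal wiring of the
# trainer. A real run needs a populated config (symbols, rl.save_dir) and an
# EnhancedTradingOrchestrator; without one, states fall back to basic features.
def _demo_enhanced_rl_trainer() -> None:
    trainer = EnhancedRLTrainer(config=get_config(), orchestrator=None)
    report = trainer.get_performance_report()
    logger.info(f"Initialized agents: {list(report['agents'].keys())}")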
@ -1,523 +0,0 @@
"""
RL Training Pipeline - Scalping Agent Training

Comprehensive training pipeline for scalping RL agents:
- Environment setup and management
- Agent training with experience replay
- Performance tracking and evaluation
- Memory-efficient training loops
"""

import torch
import numpy as np
import pandas as pd
import logging
from typing import Dict, List, Tuple, Optional, Any
import time
from pathlib import Path
import matplotlib.pyplot as plt
from collections import deque
import random
from torch.utils.tensorboard import SummaryWriter

# Add project imports
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core.config import get_config
from core.data_provider import DataProvider
from models.rl.scalping_agent import ScalpingEnvironment, ScalpingRLAgent
from utils.model_utils import robust_save, robust_load

logger = logging.getLogger(__name__)


class RLTrainer:
    """RL training pipeline for scalping"""

    def __init__(self, data_provider: DataProvider, config: Optional[Dict] = None):
        self.data_provider = data_provider
        self.config = config or get_config()

        # Training parameters
        self.num_episodes = 1000
        self.max_steps_per_episode = 1000
        self.training_frequency = 4     # Train every N steps
        self.evaluation_frequency = 50  # Evaluate every N episodes
        self.save_frequency = 100       # Save model every N episodes

        # Environment parameters
        self.symbols = ['ETH/USDT']
        self.initial_balance = 1000.0
        self.max_position_size = 0.1

        # Agent parameters (state_dim is set once the environment reports it)
        self.state_dim = None
        self.action_dim = 3  # BUY, SELL, HOLD
        self.learning_rate = 1e-4
        self.memory_size = 50000

        # Device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Training state
        self.environment = None
        self.agent = None
        self.episode_rewards = []
        self.episode_lengths = []
        self.episode_balances = []
        self.episode_trades = []
        self.training_losses = []

        # Performance tracking
        self.best_reward = -float('inf')
        self.best_balance = 0.0
        self.win_rates = []
        self.avg_rewards = []

        # TensorBoard setup
        self.setup_tensorboard()

        logger.info(f"RLTrainer initialized for symbols: {self.symbols}")

    def setup_tensorboard(self):
        """Setup TensorBoard logging"""
        # Create the TensorBoard log directory
        log_dir = Path("runs") / f"rl_training_{int(time.time())}"
        log_dir.mkdir(parents=True, exist_ok=True)

        self.writer = SummaryWriter(log_dir=str(log_dir))
        self.tensorboard_dir = log_dir

        logger.info(f"TensorBoard logging to: {log_dir}")
        logger.info("Run: tensorboard --logdir=runs")

    def setup_environment_and_agent(self) -> Tuple[ScalpingEnvironment, ScalpingRLAgent]:
        """Setup trading environment and RL agent"""
        logger.info("Setting up environment and agent...")

        # Create environment
        environment = ScalpingEnvironment(
            data_provider=self.data_provider,
            symbol=self.symbols[0],
            initial_balance=self.initial_balance,
            max_position_size=self.max_position_size
        )

        # Get the state dimension by resetting the environment
        initial_state = environment.reset()
        if initial_state is None:
            raise ValueError("Could not get initial state from environment")

        self.state_dim = len(initial_state)
        logger.info(f"State dimension: {self.state_dim}")

        # Create agent
        agent = ScalpingRLAgent(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            learning_rate=self.learning_rate,
            memory_size=self.memory_size
        )

        return environment, agent

    def run_episode(self, episode_num: int, training: bool = True) -> Dict:
        """Run a single episode"""
        state = self.environment.reset()
        if state is None:
            return {'error': 'Could not reset environment'}

        episode_reward = 0.0
        episode_loss = 0.0
        step_count = 0
        trades_made = 0
        info = {}  # Keeps the post-loop lookups safe even if no step completes

        # Episode loop
        for step in range(self.max_steps_per_episode):
            # Select action
            action = self.agent.act(state, training=training)

            # Execute action in environment
            next_state, reward, done, info = self.environment.step(action, step)

            if next_state is None:
                break

            # Store experience if training
            if training:
                # Mark large-reward or executed-trade experiences as high priority
                priority = (abs(reward) > 0.1 or
                            info.get('trade_info', {}).get('executed', False))

                self.agent.remember(state, action, reward, next_state, done, priority)

                # Train the agent every `training_frequency` steps once memory is warm
                if step % self.training_frequency == 0 and len(self.agent.memory) > self.agent.batch_size:
                    loss = self.agent.replay()
                    if loss is not None:
                        episode_loss += loss

            # Update state
            state = next_state
            episode_reward += reward
            step_count += 1

            # Track trades
            if info.get('trade_info', {}).get('executed', False):
                trades_made += 1

            if done:
                break

        # Episode results
        final_balance = info.get('balance', self.initial_balance)
        total_fees = info.get('total_fees', 0.0)

        episode_results = {
            'episode': episode_num,
            'reward': episode_reward,
            'steps': step_count,
            'balance': final_balance,
            'trades': trades_made,
            'fees': total_fees,
            'pnl': final_balance - self.initial_balance,
            'pnl_percentage': (final_balance - self.initial_balance) / self.initial_balance * 100,
            'avg_loss': episode_loss / max(step_count // self.training_frequency, 1) if training else 0
        }

        return episode_results

    def evaluate_agent(self, num_episodes: int = 10) -> Dict:
        """Evaluate agent performance"""
        logger.info(f"Evaluating agent over {num_episodes} episodes...")

        evaluation_results = []
        total_reward = 0.0
        total_balance = 0.0
        total_trades = 0
        winning_episodes = 0

        # Set agent to evaluation mode: no exploration
        original_epsilon = self.agent.epsilon
        self.agent.epsilon = 0.0

        for episode in range(num_episodes):
            results = self.run_episode(episode, training=False)
            evaluation_results.append(results)

            total_reward += results['reward']
            total_balance += results['balance']
            total_trades += results['trades']

            if results['pnl'] > 0:
                winning_episodes += 1

        # Restore original epsilon
        self.agent.epsilon = original_epsilon

        # Calculate summary statistics
        avg_reward = total_reward / num_episodes
        avg_balance = total_balance / num_episodes
        avg_trades = total_trades / num_episodes
        win_rate = winning_episodes / num_episodes

        evaluation_summary = {
            'num_episodes': num_episodes,
            'avg_reward': avg_reward,
            'avg_balance': avg_balance,
            'avg_pnl': avg_balance - self.initial_balance,
            'avg_pnl_percentage': (avg_balance - self.initial_balance) / self.initial_balance * 100,
            'avg_trades': avg_trades,
            'win_rate': win_rate,
            'results': evaluation_results
        }

        logger.info(f"Evaluation complete - Avg Reward: {avg_reward:.4f}, Win Rate: {win_rate:.2%}")

        return evaluation_summary

    def train(self, save_path: Optional[str] = None) -> Dict:
        """Train the RL agent"""
        logger.info("Starting RL agent training...")

        # Setup environment and agent
        self.environment, self.agent = self.setup_environment_and_agent()

        # Training state
        start_time = time.time()
        best_eval_reward = -float('inf')

        # Training loop
        for episode in range(self.num_episodes):
            episode_start_time = time.time()

            # Run training episode
            results = self.run_episode(episode, training=True)

            # Track metrics
            self.episode_rewards.append(results['reward'])
            self.episode_lengths.append(results['steps'])
            self.episode_balances.append(results['balance'])
            self.episode_trades.append(results['trades'])

            if results.get('avg_loss', 0) > 0:
                self.training_losses.append(results['avg_loss'])

            # Update best metrics
            if results['reward'] > self.best_reward:
                self.best_reward = results['reward']

            if results['balance'] > self.best_balance:
                self.best_balance = results['balance']

            # Calculate running averages over the last 100 episodes
            recent_rewards = self.episode_rewards[-100:]
            recent_balances = self.episode_balances[-100:]

            avg_reward = np.mean(recent_rewards)
            avg_balance = np.mean(recent_balances)

            self.avg_rewards.append(avg_reward)

            # Log progress
            episode_time = time.time() - episode_start_time

            if episode % 10 == 0:
                logger.info(
                    f"Episode {episode}/{self.num_episodes} - "
                    f"Reward: {results['reward']:.4f}, Balance: ${results['balance']:.2f}, "
                    f"Trades: {results['trades']}, PnL: {results['pnl_percentage']:.2f}%, "
                    f"Epsilon: {self.agent.epsilon:.3f}, Time: {episode_time:.2f}s"
                )

            # Evaluation
            if episode % self.evaluation_frequency == 0 and episode > 0:
                eval_results = self.evaluate_agent(num_episodes=5)

                # Track win rate
                self.win_rates.append(eval_results['win_rate'])

                logger.info(
                    f"Evaluation - Avg Reward: {eval_results['avg_reward']:.4f}, "
                    f"Win Rate: {eval_results['win_rate']:.2%}, "
                    f"Avg PnL: {eval_results['avg_pnl_percentage']:.2f}%"
                )

                # Save best model
                if eval_results['avg_reward'] > best_eval_reward:
                    best_eval_reward = eval_results['avg_reward']
                    if save_path:
                        best_path = save_path.replace('.pt', '_best.pt')
                        self.agent.save(best_path)
                        logger.info(f"New best model saved: {best_path}")

            # Save checkpoint
            if episode % self.save_frequency == 0 and episode > 0 and save_path:
                checkpoint_path = save_path.replace('.pt', f'_checkpoint_{episode}.pt')
                self.agent.save(checkpoint_path)
                logger.info(f"Checkpoint saved: {checkpoint_path}")

        # Training complete
        total_time = time.time() - start_time
        logger.info(f"Training completed in {total_time:.2f} seconds")

        # Final evaluation
        final_eval = self.evaluate_agent(num_episodes=20)

        # Save final model
        if save_path:
            self.agent.save(save_path)
            logger.info(f"Final model saved: {save_path}")

        # Prepare training results
        training_results = {
            'total_episodes': self.num_episodes,
            'total_time': total_time,
            'best_reward': self.best_reward,
            'best_balance': self.best_balance,
            'final_evaluation': final_eval,
            'episode_rewards': self.episode_rewards,
            'episode_balances': self.episode_balances,
            'episode_trades': self.episode_trades,
            'training_losses': self.training_losses,
            'avg_rewards': self.avg_rewards,
            'win_rates': self.win_rates,
            'agent_config': {
                'state_dim': self.state_dim,
                'action_dim': self.action_dim,
                'learning_rate': self.learning_rate,
                'epsilon_final': self.agent.epsilon
            }
        }

        return training_results

    def backtest_agent(self, agent_path: str, test_episodes: int = 50) -> Dict:
        """Backtest a trained agent"""
        logger.info(f"Backtesting agent from {agent_path}...")

        # Setup environment and agent
        self.environment, self.agent = self.setup_environment_and_agent()

        # Load trained agent
        self.agent.load(agent_path)

        # Run backtest
        backtest_results = self.evaluate_agent(test_episodes)

        # Additional analysis over per-episode PnL percentages
        results = backtest_results['results']
        pnls = [r['pnl_percentage'] for r in results]
        rewards = [r['reward'] for r in results]
        trades = [r['trades'] for r in results]

        analysis = {
            'total_episodes': test_episodes,
            'avg_pnl': np.mean(pnls),
            'std_pnl': np.std(pnls),
            'max_pnl': np.max(pnls),
            'min_pnl': np.min(pnls),
            'avg_reward': np.mean(rewards),
            'avg_trades': np.mean(trades),
            'win_rate': backtest_results['win_rate'],
            'profit_factor': np.sum([p for p in pnls if p > 0]) / abs(np.sum([p for p in pnls if p < 0])) if any(p < 0 for p in pnls) else float('inf'),
            # Per-episode Sharpe ratio (not annualized)
            'sharpe_ratio': np.mean(pnls) / np.std(pnls) if np.std(pnls) > 0 else 0,
            'max_drawdown': self._calculate_max_drawdown(pnls)
        }

        logger.info(f"Backtest complete - Win Rate: {analysis['win_rate']:.2%}, Avg PnL: {analysis['avg_pnl']:.2f}%")

        return {
            'backtest_results': backtest_results,
            'analysis': analysis
        }

    def _calculate_max_drawdown(self, pnls: List[float]) -> float:
        """Calculate the maximum drawdown of the cumulative PnL curve"""
        cumulative = np.cumsum(pnls)
        running_max = np.maximum.accumulate(cumulative)
        drawdowns = running_max - cumulative
        return np.max(drawdowns) if len(drawdowns) > 0 else 0.0
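
    # Worked example (illustrative, not original code): for per-episode PnLs
    # [1.0, -2.0, 0.5] the cumulative curve is [1.0, -1.0, -0.5], the running
    # maximum is [1.0, 1.0, 1.0], so drawdowns are [0.0, 2.0, 1.5] and the
    # maximum drawdown is 2.0.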

    def plot_training_progress(self, save_path: Optional[str] = None):
        """Plot training progress"""
        if not self.episode_rewards:
            logger.warning("No training data to plot")
            return

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

        episodes = range(1, len(self.episode_rewards) + 1)

        # Episode rewards
        ax1.plot(episodes, self.episode_rewards, alpha=0.6, label='Episode Reward')
        if self.avg_rewards:
            ax1.plot(episodes, self.avg_rewards, 'r-', label='Avg Reward (100 episodes)')
        ax1.set_title('Training Rewards')
        ax1.set_xlabel('Episode')
        ax1.set_ylabel('Reward')
        ax1.legend()
        ax1.grid(True)

        # Episode balances
        ax2.plot(episodes, self.episode_balances, alpha=0.6, label='Episode Balance')
        ax2.axhline(y=self.initial_balance, color='r', linestyle='--', label='Initial Balance')
        ax2.set_title('Portfolio Balance')
        ax2.set_xlabel('Episode')
        ax2.set_ylabel('Balance ($)')
        ax2.legend()
        ax2.grid(True)

        # Training losses
        if self.training_losses:
            loss_episodes = np.linspace(1, len(self.episode_rewards), len(self.training_losses))
            ax3.plot(loss_episodes, self.training_losses, 'g-', alpha=0.8)
            ax3.set_title('Training Loss')
            ax3.set_xlabel('Episode')
            ax3.set_ylabel('Loss')
            ax3.grid(True)

        # Win rates
        if self.win_rates:
            eval_episodes = np.arange(self.evaluation_frequency,
                                      len(self.episode_rewards) + 1,
                                      self.evaluation_frequency)[:len(self.win_rates)]
            ax4.plot(eval_episodes, self.win_rates, 'purple', marker='o')
            ax4.set_title('Win Rate')
            ax4.set_xlabel('Episode')
            ax4.set_ylabel('Win Rate')
            ax4.grid(True)
            ax4.set_ylim(0, 1)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Training progress plot saved: {save_path}")

        plt.show()

    def log_episode_metrics(self, episode: int, metrics: Dict):
        """Log episode metrics to TensorBoard"""
        # Note: expects an aggregated metrics dict (total_reward, win_rate, ...)
        # rather than the raw results dict returned by run_episode

        # Main performance metrics
        self.writer.add_scalar('Episode/TotalReward', metrics['total_reward'], episode)
        self.writer.add_scalar('Episode/FinalBalance', metrics['final_balance'], episode)
        self.writer.add_scalar('Episode/TotalReturn', metrics['total_return'], episode)
        self.writer.add_scalar('Episode/Steps', metrics['steps'], episode)

        # Trading metrics
        self.writer.add_scalar('Trading/TotalTrades', metrics['total_trades'], episode)
        self.writer.add_scalar('Trading/WinRate', metrics['win_rate'], episode)
        self.writer.add_scalar('Trading/ProfitFactor', metrics.get('profit_factor', 0), episode)
        self.writer.add_scalar('Trading/MaxDrawdown', metrics.get('max_drawdown', 0), episode)

        # Agent metrics
        self.writer.add_scalar('Agent/Epsilon', metrics['epsilon'], episode)
        self.writer.add_scalar('Agent/LearningRate', metrics.get('learning_rate', self.learning_rate), episode)
        self.writer.add_scalar('Agent/MemorySize', metrics.get('memory_size', 0), episode)

        # Loss metrics (if available)
        if 'loss' in metrics:
            self.writer.add_scalar('Agent/Loss', metrics['loss'], episode)


class HybridTrainer:
    """Hybrid training pipeline combining CNN and RL"""

    def __init__(self, data_provider: DataProvider):
        self.data_provider = data_provider
        self.cnn_trainer = None
        self.rl_trainer = None

    def train_hybrid(self, symbols: List[str], cnn_save_path: str, rl_save_path: str) -> Dict:
        """Train CNN first, then RL with CNN features"""
        logger.info("Starting hybrid CNN + RL training...")

        # Phase 1: Train CNN
        logger.info("Phase 1: Training CNN...")
        from training.cnn_trainer import CNNTrainer

        self.cnn_trainer = CNNTrainer(self.data_provider)
        cnn_results = self.cnn_trainer.train(symbols, cnn_save_path)

        # Phase 2: Train RL
        logger.info("Phase 2: Training RL...")
        self.rl_trainer = RLTrainer(self.data_provider)
        rl_results = self.rl_trainer.train(rl_save_path)

        # Combine results
        hybrid_results = {
            'cnn_results': cnn_results,
            'rl_results': rl_results,
            'total_time': cnn_results['total_time'] + rl_results['total_time']
        }

        logger.info("Hybrid training completed!")
        return hybrid_results


# Export
__all__ = ['RLTrainer', 'HybridTrainer']
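

# Illustrative sketch (not part of the original module): a typical entry point.
# The no-argument DataProvider construction and the save paths are assumptions.
def _demo_rl_training() -> None:
    data_provider = DataProvider()
    trainer = RLTrainer(data_provider)
    results = trainer.train(save_path='models/rl/scalping_agent.pt')
    trainer.plot_training_progress(save_path='models/rl/training_progress.png')
    logger.info(f"Best reward: {results['best_reward']:.4f}")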