Files
gogo2/core/enhanced_cnn_adapter.py
2025-07-23 22:11:19 +03:00

430 lines
17 KiB
Python

"""
Enhanced CNN Adapter for Standardized Input Format
This module provides an adapter for the EnhancedCNN model to work with the standardized
BaseDataInput format, enabling seamless integration with the multi-modal trading system.
"""
import torch
import numpy as np
import logging
import os
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any, Union
from threading import Lock
from .data_models import BaseDataInput, ModelOutput, create_model_output
from NN.models.enhanced_cnn import EnhancedCNN
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class EnhancedCNNAdapter:
    """
    Adapter for EnhancedCNN model to work with standardized BaseDataInput format

    This adapter:
    1. Converts BaseDataInput to the format expected by EnhancedCNN
    2. Processes model outputs to create standardized ModelOutput
    3. Manages model training with collected data
    4. Handles checkpoint management
    """

    # Length of the flattened feature vector the model is built for.
    # Breakdown of the BaseDataInput layout this adapter expects:
    #   OHLCV: 300 frames x 4 timeframes x 5 features = 6000
    #   BTC OHLCV: 300 frames x 5 features            = 1500
    #   COB: ±20 buckets x 4 metrics                  =  160
    #   MA: 4 timeframes x 10 buckets                 =   40
    #   Technical indicators                          =  100
    #   Last predictions                              =   50
    #   Total                                         = 7850
    INPUT_SIZE = 7850

    # Action labels, in the order of the model's output columns.
    ACTIONS = ('BUY', 'SELL', 'HOLD')

    def __init__(self, model_path: Optional[str] = None, checkpoint_dir: str = "models/enhanced_cnn"):
        """
        Initialize the EnhancedCNN adapter

        Args:
            model_path: Path to load model from; if None a new model is created
            checkpoint_dir: Directory to save checkpoints to (created if missing)
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.model_path = model_path
        self.checkpoint_dir = checkpoint_dir
        # Guards self.training_data; train() holds it for the whole run.
        self.training_lock = Lock()
        # Entries are (features_tensor, action_index, reward) tuples.
        self.training_data = []
        self.max_training_samples = 10000
        self.batch_size = 32
        self.learning_rate = 0.0001
        self.model_name = "enhanced_cnn_v1"

        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        self._initialize_model()
        logger.info(f"EnhancedCNNAdapter initialized with device: {self.device}")

    def _initialize_model(self):
        """Build the EnhancedCNN on self.device and optionally load weights.

        Weights are loaded from self.model_path when it is set. A failed load
        is only logged, and the freshly initialized model is kept.

        Raises:
            Exception: any error while constructing the model is logged and re-raised.
        """
        try:
            self.model = EnhancedCNN(input_shape=self.INPUT_SIZE, n_actions=len(self.ACTIONS))
            self.model.to(self.device)

            if not self.model_path:
                logger.info("No model path provided, using new model")
            elif self.model.load(self.model_path):
                logger.info(f"Model loaded from {self.model_path}")
            else:
                logger.warning(f"Failed to load model from {self.model_path}, using new model")
        except Exception as e:
            logger.error(f"Error initializing EnhancedCNN model: {e}")
            raise

    def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor:
        """
        Convert BaseDataInput to feature vector for EnhancedCNN

        Args:
            base_data: Standardized input data

        Returns:
            torch.Tensor: float32 feature vector on self.device. On any error
            this is an all-zero vector of length INPUT_SIZE.
        """
        try:
            features = base_data.get_feature_vector()
            # torch.tensor copies, so later mutation of the source array
            # cannot alter tensors stored in training_data.
            return torch.tensor(features, dtype=torch.float32, device=self.device)
        except Exception as e:
            logger.error(f"Error converting BaseDataInput to features: {e}")
            # NOTE(review): this zero-vector fallback is also fed into training
            # by add_training_sample; consider rejecting the sample instead.
            return torch.zeros(self.INPUT_SIZE, dtype=torch.float32, device=self.device)

    def predict(self, base_data: BaseDataInput) -> ModelOutput:
        """
        Make a prediction using the EnhancedCNN model

        Args:
            base_data: Standardized input data

        Returns:
            ModelOutput: Standardized model output. On any error this is a
            default 'HOLD' output with confidence 0.0.
        """
        try:
            features = self._convert_base_data_to_features(base_data)
            # The model expects a batch dimension
            if features.dim() == 1:
                features = features.unsqueeze(0)

            # NOTE(review): eval() flips the shared model's mode; if train()
            # runs concurrently in another thread it is affected as well.
            self.model.eval()
            with torch.no_grad():
                q_values, extrema_pred, price_pred, features_refined, _advanced_pred = self.model(features)
                action_probs = torch.softmax(q_values, dim=1)
                action_idx = int(torch.argmax(action_probs, dim=1).item())
                confidence = float(action_probs[0, action_idx].item())

            # Single timestamp so metadata and ModelOutput agree
            now = datetime.now()

            predictions = {
                'action': self.ACTIONS[action_idx],
                'buy_probability': float(action_probs[0, 0].item()),
                'sell_probability': float(action_probs[0, 1].item()),
                'hold_probability': float(action_probs[0, 2].item()),
                'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
                'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist()
            }
            hidden_states = {
                'features': features_refined.squeeze(0).cpu().numpy().tolist()
            }
            metadata = {
                'model_version': '1.0',
                'timestamp': now.isoformat(),
                # plain tuple rather than torch.Size so it serializes cleanly
                'input_shape': tuple(features.shape)
            }

            return ModelOutput(
                model_type='cnn',
                model_name=self.model_name,
                symbol=base_data.symbol,
                timestamp=now,
                confidence=confidence,
                predictions=predictions,
                hidden_states=hidden_states,
                metadata=metadata
            )
        except Exception as e:
            logger.error(f"Error making prediction with EnhancedCNN: {e}")
            return create_model_output(
                model_type='cnn',
                model_name=self.model_name,
                symbol=base_data.symbol,
                action='HOLD',
                confidence=0.0
            )

    def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float):
        """
        Add a training sample to the training data

        Args:
            base_data: Standardized input data
            actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
            reward: Reward received for the action

        Errors (including an unknown action label) are logged, not raised.
        """
        try:
            features = self._convert_base_data_to_features(base_data)
            action_idx = self.ACTIONS.index(actual_action)

            with self.training_lock:
                self.training_data.append((features, action_idx, float(reward)))

                # Over capacity: keep the highest-reward samples.
                # NOTE(review): this evicts low/negative-reward samples first,
                # which biases the buffer; a FIFO/reservoir policy may be preferable.
                if len(self.training_data) > self.max_training_samples:
                    self.training_data.sort(key=lambda sample: sample[2], reverse=True)
                    del self.training_data[self.max_training_samples:]

            logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")
        except Exception as e:
            logger.error(f"Error adding training sample: {e}")

    @staticmethod
    def _reward_weights(rewards: torch.Tensor) -> torch.Tensor:
        """Map a batch of rewards to per-sample loss weights in [0.1, 1.1].

        Rewards are min-max normalized within the batch. A batch whose rewards
        are all equal gets weight 1.0 before the offset. The +0.1 offset keeps
        every sample contributing a little to the loss.
        """
        min_reward = rewards.min()
        max_reward = rewards.max()
        if max_reward > min_reward:
            normalized = (rewards - min_reward) / (max_reward - min_reward)
        else:
            normalized = torch.ones_like(rewards)
        return normalized + 0.1

    def train(self, epochs: int = 1) -> Dict[str, float]:
        """
        Train the model with collected data

        Args:
            epochs: Number of epochs to train for

        Returns:
            Dict[str, float]: 'loss' (mean per processed batch), 'accuracy',
            'samples', plus 'error' if training failed.
        """
        try:
            with self.training_lock:
                n_samples = len(self.training_data)
                if n_samples < self.batch_size:
                    logger.info(f"Not enough training data: {n_samples} samples, need at least {self.batch_size}")
                    return {'loss': 0.0, 'accuracy': 0.0, 'samples': n_samples}

                self.model.train()
                # NOTE(review): a fresh Adam per call resets its moment estimates
                # every time train() runs.
                optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

                total_loss = 0.0
                n_batches = 0
                correct_predictions = 0
                total_predictions = 0

                for _epoch in range(epochs):
                    np.random.shuffle(self.training_data)

                    for start in range(0, n_samples, self.batch_size):
                        batch = self.training_data[start:start + self.batch_size]
                        # Skip if batch is too small
                        if len(batch) < 2:
                            continue

                        features = torch.stack([sample[0] for sample in batch])
                        actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
                        rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)

                        optimizer.zero_grad()
                        q_values, _, _, _, _ = self.model(features)

                        # Per-sample negative log-likelihood of the taken action,
                        # computed stably via log-softmax inside cross_entropy.
                        nll_loss = torch.nn.functional.cross_entropy(q_values, actions, reduction='none')
                        loss = (nll_loss * self._reward_weights(rewards)).mean()

                        loss.backward()
                        optimizer.step()

                        total_loss += loss.item()
                        n_batches += 1
                        predicted_actions = torch.argmax(q_values, dim=1)
                        correct_predictions += int((predicted_actions == actions).sum().item())
                        total_predictions += len(actions)

                # Average over batches actually processed (across all epochs)
                avg_loss = total_loss / n_batches if n_batches else 0.0
                accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

                self._save_checkpoint(avg_loss, accuracy)
                logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={n_samples}")
                return {
                    'loss': avg_loss,
                    'accuracy': accuracy,
                    'samples': n_samples
                }
        except Exception as e:
            logger.error(f"Error training model: {e}")
            return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)}

    def _save_checkpoint(self, loss: float, accuracy: float):
        """
        Save model checkpoint via the project CheckpointManager

        The model is first written to a temporary file in checkpoint_dir. That
        file is always removed afterwards, even if saving fails. Errors are
        logged, not raised.

        Args:
            loss: Training loss
            accuracy: Training accuracy
        """
        temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp")
        # presumably model.save appends ".pt" to the path it is given — confirm
        temp_file = f"{temp_path}.pt"
        try:
            from utils.checkpoint_manager import CheckpointManager

            checkpoint_manager = CheckpointManager(
                checkpoint_dir=self.checkpoint_dir,
                max_checkpoints=10,
                metric_name="accuracy"
            )

            self.model.save(temp_path)

            metrics = {
                'loss': loss,
                'accuracy': accuracy,
                'samples': len(self.training_data)
            }
            metadata = {
                'timestamp': datetime.now().isoformat(),
                'model_name': self.model_name,
                'input_shape': self.model.input_shape,
                'n_actions': self.model.n_actions
            }

            checkpoint_path = checkpoint_manager.save_checkpoint(
                model_name=self.model_name,
                model_path=temp_file,
                metrics=metrics,
                metadata=metadata
            )
            logger.info(f"Model checkpoint saved to {checkpoint_path}")
        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")
        finally:
            # Never leave the temporary model file behind
            if os.path.exists(temp_file):
                try:
                    os.remove(temp_file)
                except OSError as e:
                    logger.warning(f"Could not remove temporary checkpoint file {temp_file}: {e}")

    def load_best_checkpoint(self) -> bool:
        """Load the best checkpoint based on accuracy.

        Returns:
            bool: True if a checkpoint was found and loaded into the model,
            False otherwise. Errors are logged, not raised.
        """
        try:
            from utils.checkpoint_manager import CheckpointManager

            checkpoint_manager = CheckpointManager(
                checkpoint_dir=self.checkpoint_dir,
                max_checkpoints=10,
                metric_name="accuracy"
            )

            result = checkpoint_manager.load_best_checkpoint(self.model_name)
            if not result:
                logger.info("No checkpoints found")
                return False

            best_checkpoint_path, best_checkpoint_metadata = result
            if not best_checkpoint_path:
                logger.info("No checkpoints found")
                return False

            if not self.model.load(best_checkpoint_path):
                logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
                return False

            logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")

            # Metadata format is owned by CheckpointManager; only log metrics
            # when it is a dict, so a successful load is never reported as failure.
            metrics = {}
            if isinstance(best_checkpoint_metadata, dict):
                metrics = best_checkpoint_metadata.get('metrics') or {}
            logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")
            return True
        except Exception as e:
            logger.error(f"Error loading best checkpoint: {e}")
            return False