"""
|
|
Enhanced CNN Adapter for Standardized Input Format
|
|
|
|
This module provides an adapter for the EnhancedCNN model to work with the standardized
|
|
BaseDataInput format, enabling seamless integration with the multi-modal trading system.
|
|
"""

import torch
import numpy as np
import logging
import os
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any, Union
from threading import Lock

from .data_models import BaseDataInput, ModelOutput, create_model_output
from NN.models.enhanced_cnn import EnhancedCNN

logger = logging.getLogger(__name__)


class EnhancedCNNAdapter:
    """
    Adapter for the EnhancedCNN model to work with the standardized BaseDataInput format

    This adapter:
    1. Converts BaseDataInput to the format expected by EnhancedCNN
    2. Processes model outputs to create a standardized ModelOutput
    3. Manages model training with collected data
    4. Handles checkpoint management
    """

    def __init__(self, model_path: Optional[str] = None, checkpoint_dir: str = "models/enhanced_cnn"):
        """
        Initialize the EnhancedCNN adapter

        Args:
            model_path: Path to load the model from; if None, a new model is created
            checkpoint_dir: Directory to save checkpoints to
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.model_path = model_path
        self.checkpoint_dir = checkpoint_dir
        self.training_lock = Lock()
        self.training_data = []
        self.max_training_samples = 10000
        self.batch_size = 32
        self.learning_rate = 0.0001
        self.model_name = "enhanced_cnn_v1"

        # Enhanced metrics tracking
        self.last_inference_time = None
        self.last_inference_duration = 0.0
        self.last_prediction_output = None
        self.last_training_time = None
        self.last_training_duration = 0.0
        self.last_training_loss = 0.0
        self.inference_count = 0
        self.training_count = 0

        # Create the checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Initialize the model
        self._initialize_model()

        # Load the given checkpoint if it exists; otherwise fall back to the best saved checkpoint
        if model_path and os.path.exists(model_path):
            self._load_checkpoint(model_path)
        else:
            self._load_best_checkpoint()

        logger.info(f"EnhancedCNNAdapter initialized on {self.device}")

    def _load_checkpoint(self, checkpoint_path: str) -> bool:
        """Load the model from a checkpoint path"""
        try:
            if self.model and os.path.exists(checkpoint_path):
                success = self.model.load(checkpoint_path)
                if success:
                    logger.info(f"Loaded model from {checkpoint_path}")
                    return True
                else:
                    logger.warning(f"Failed to load model from {checkpoint_path}")
                    return False
            else:
                logger.warning(f"Checkpoint path does not exist: {checkpoint_path}")
                return False
        except Exception as e:
            logger.error(f"Error loading checkpoint: {e}")
            return False

    def _load_best_checkpoint(self) -> bool:
        """Load the best available checkpoint"""
        try:
            return self.load_best_checkpoint()
        except Exception as e:
            logger.error(f"Error loading best checkpoint: {e}")
            return False

    def _create_default_output(self, symbol: str) -> ModelOutput:
        """Create a default output when prediction fails"""
        return create_model_output(
            model_type='cnn',
            model_name=self.model_name,
            symbol=symbol,
            action='HOLD',
            confidence=0.0,
            metadata={'error': 'Prediction failed, using default output'}
        )

    def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]:
        """Process hidden states for cross-model feeding"""
        processed_states = {}

        for key, value in hidden_states.items():
            if isinstance(value, torch.Tensor):
                # Convert the tensor to a plain Python list so it can be serialized
                processed_states[key] = value.cpu().numpy().tolist()
            else:
                processed_states[key] = value

        return processed_states

    def _initialize_model(self):
        """Initialize the EnhancedCNN model"""
        try:
            # Calculate the input shape based on the BaseDataInput structure:
            #   OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
            #   BTC OHLCV: 300 frames x 5 features = 1500 features
            #   COB: ±20 buckets x 4 metrics = 160 features
            #   MA: 4 timeframes x 10 buckets = 40 features
            #   Technical indicators: 100 features
            #   Last predictions: 50 features
            #   Total: 6000 + 1500 + 160 + 40 + 100 + 50 = 7850 features
            input_shape = 7850
            n_actions = 3  # BUY, SELL, HOLD

            # Create the model
            self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
            self.model.to(self.device)

            # Load the model if a path was provided
            if self.model_path:
                success = self.model.load(self.model_path)
                if success:
                    logger.info(f"Model loaded from {self.model_path}")
                else:
                    logger.warning(f"Failed to load model from {self.model_path}, using new model")
            else:
                logger.info("No model path provided, using new model")

        except Exception as e:
            logger.error(f"Error initializing EnhancedCNN model: {e}")
            raise

    def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor:
        """
        Convert BaseDataInput to a feature vector for EnhancedCNN

        Args:
            base_data: Standardized input data

        Returns:
            torch.Tensor: Feature vector for EnhancedCNN
        """
        try:
            # Use the get_feature_vector method from BaseDataInput
            features = base_data.get_feature_vector()

            # Convert to a torch tensor on the adapter's device
            features_tensor = torch.tensor(features, dtype=torch.float32, device=self.device)

            return features_tensor

        except Exception as e:
            logger.error(f"Error converting BaseDataInput to features: {e}")
            # Return a zero tensor with the expected shape so downstream code still works
            return torch.zeros(7850, dtype=torch.float32, device=self.device)

    def predict(self, base_data: BaseDataInput) -> ModelOutput:
        """
        Make a prediction using the EnhancedCNN model

        Args:
            base_data: Standardized input data

        Returns:
            ModelOutput: Standardized model output
        """
        try:
            # Track inference timing
            start_time = datetime.now()
            inference_start = start_time.timestamp()

            # Convert BaseDataInput to features
            features = self._convert_base_data_to_features(base_data)

            # Ensure features has a batch dimension
            if features.dim() == 1:
                features = features.unsqueeze(0)

            # Set the model to evaluation mode
            self.model.eval()

            # Make the prediction
            with torch.no_grad():
                q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features)

            # Get action and confidence
            action_probs = torch.softmax(q_values, dim=1)
            action_idx = torch.argmax(action_probs, dim=1).item()
            confidence = float(action_probs[0, action_idx].item())

            # Map the action index to an action string
            actions = ['BUY', 'SELL', 'HOLD']
            action = actions[action_idx]

            # Extract the pivot price prediction (simplified: take the first value from price_pred)
            pivot_price = None
            if price_pred is not None and price_pred.numel() > 0:
                # Get the current price from base_data for context
                current_price = 0.0
                if base_data.ohlcv_1s and len(base_data.ohlcv_1s) > 0:
                    current_price = base_data.ohlcv_1s[-1].close

                # Calculate the pivot price as the current price plus the predicted change;
                # the first prediction value is interpreted as a percentage change
                price_change_pct = float(price_pred.flatten()[0].item())
                pivot_price = current_price * (1 + price_change_pct * 0.01)
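                # Worked example with hypothetical numbers: if current_price is
                # 3000.00 and price_change_pct is 0.5 (i.e. +0.5%), then
                # pivot_price = 3000.00 * (1 + 0.5 * 0.01) = 3015.00.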

            # Create the predictions dictionary
            predictions = {
                'action': action,
                'buy_probability': float(action_probs[0, 0].item()),
                'sell_probability': float(action_probs[0, 1].item()),
                'hold_probability': float(action_probs[0, 2].item()),
                'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
                'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist(),
                'pivot_price': pivot_price
            }

            # Create the hidden states dictionary for cross-model feeding
            hidden_states = {
                'features': features_refined.squeeze(0).cpu().numpy().tolist()
            }

            # Calculate the inference duration in milliseconds
            end_time = datetime.now()
            inference_duration = (end_time.timestamp() - inference_start) * 1000

            # Update metrics
            self.last_inference_time = start_time
            self.last_inference_duration = inference_duration
            self.inference_count += 1

            # Store the last prediction output for the dashboard
            self.last_prediction_output = {
                'action': action,
                'confidence': confidence,
                'pivot_price': pivot_price,
                'timestamp': start_time,
                'symbol': base_data.symbol
            }

            # Create the metadata dictionary
            metadata = {
                'model_version': '1.0',
                'timestamp': start_time.isoformat(),
                'input_shape': list(features.shape),
                'inference_duration_ms': inference_duration,
                'inference_count': self.inference_count
            }

            # Create the ModelOutput
            model_output = ModelOutput(
                model_type='cnn',
                model_name=self.model_name,
                symbol=base_data.symbol,
                timestamp=start_time,
                confidence=confidence,
                predictions=predictions,
                hidden_states=hidden_states,
                metadata=metadata
            )

            return model_output

        except Exception as e:
            logger.error(f"Error making prediction with EnhancedCNN: {e}")
            # Return a default ModelOutput
            return create_model_output(
                model_type='cnn',
                model_name=self.model_name,
                symbol=base_data.symbol,
                action='HOLD',
                confidence=0.0
            )

    def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float):
        """
        Add a training sample to the training data

        Args:
            base_data: Standardized input data
            actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
            reward: Reward received for the action
        """
        try:
            # Convert BaseDataInput to features
            features = self._convert_base_data_to_features(base_data)

            # Convert the action to an index
            actions = ['BUY', 'SELL', 'HOLD']
            action_idx = actions.index(actual_action)

            # Add to the training data
            with self.training_lock:
                self.training_data.append((features, action_idx, reward))

                # Limit the training data size
                if len(self.training_data) > self.max_training_samples:
                    # Sort by reward (highest first) and keep the top samples
                    self.training_data.sort(key=lambda x: x[2], reverse=True)
                    self.training_data = self.training_data[:self.max_training_samples]
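                    # Design note: truncating by reward rather than recency biases
                    # the buffer toward high-reward samples, so recent low-reward
                    # experiences may be dropped before they are ever trained on.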

            logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")

        except Exception as e:
            logger.error(f"Error adding training sample: {e}")

    def train(self, epochs: int = 1) -> Dict[str, float]:
        """
        Train the model with collected data

        Args:
            epochs: Number of epochs to train for

        Returns:
            Dict[str, float]: Training metrics
        """
        try:
            # Track training timing
            training_start_time = datetime.now()
            training_start = training_start_time.timestamp()

            with self.training_lock:
                # Check that we have enough data
                if len(self.training_data) < self.batch_size:
                    logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
                    return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}

                # Set the model to training mode
                self.model.train()

                # Create the optimizer (note: recreated on every call, so optimizer
                # state such as Adam moments does not persist between calls)
                optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

                # Training metrics
                total_loss = 0.0
                num_batches = 0
                correct_predictions = 0
                total_predictions = 0

                # Train for the specified number of epochs
                for epoch in range(epochs):
                    # Shuffle the training data
                    np.random.shuffle(self.training_data)

                    # Process in batches
                    for i in range(0, len(self.training_data), self.batch_size):
                        batch = self.training_data[i:i+self.batch_size]

                        # Skip if the batch is too small
                        if len(batch) < 2:
                            continue

                        # Prepare the batch
                        features = torch.stack([sample[0] for sample in batch])
                        actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
                        rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)

                        # Zero gradients
                        optimizer.zero_grad()

                        # Forward pass
                        q_values, _, _, _, _ = self.model(features)
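                        # Loss sketch (matches the code below): with p_i the softmax
                        # probability of the action actually taken and r_i the batch
                        # min-max normalized reward,
                        #     loss = mean_i[ -log(p_i + 1e-10) * (r_i + 0.1) ]
                        # so higher-reward samples contribute more gradient.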

                        # Calculate the loss (cross-entropy with reward weighting)
                        # First, apply softmax to get probabilities
                        probs = torch.softmax(q_values, dim=1)

                        # Get the probability of the chosen action
                        chosen_probs = probs[torch.arange(len(actions)), actions]

                        # Calculate the negative log-likelihood loss
                        nll_loss = -torch.log(chosen_probs + 1e-10)

                        # Weight by reward (higher reward = higher weight);
                        # normalize rewards to the [0, 1] range first
                        min_reward = rewards.min()
                        max_reward = rewards.max()
                        if max_reward > min_reward:
                            normalized_rewards = (rewards - min_reward) / (max_reward - min_reward)
                        else:
                            normalized_rewards = torch.ones_like(rewards)

                        # Apply the reward weighting (a small constant avoids zero weights)
                        weighted_loss = nll_loss * (normalized_rewards + 0.1)

                        # Mean loss
                        loss = weighted_loss.mean()

                        # Backward pass
                        loss.backward()

                        # Update weights
                        optimizer.step()

                        # Update metrics
                        total_loss += loss.item()
                        num_batches += 1

                        # Calculate accuracy
                        predicted_actions = torch.argmax(q_values, dim=1)
                        correct_predictions += (predicted_actions == actions).sum().item()
                        total_predictions += len(actions)

                # Calculate final metrics over all processed batches
                avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
                accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

                # Calculate the training duration in milliseconds
                training_end_time = datetime.now()
                training_duration = (training_end_time.timestamp() - training_start) * 1000

                # Update training metrics
                self.last_training_time = training_start_time
                self.last_training_duration = training_duration
                self.last_training_loss = avg_loss
                self.training_count += 1

                # Save a checkpoint
                self._save_checkpoint(avg_loss, accuracy)

                logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}, duration={training_duration:.1f}ms")

                return {
                    'loss': avg_loss,
                    'accuracy': accuracy,
                    'samples': len(self.training_data),
                    'duration_ms': training_duration,
                    'training_count': self.training_count
                }

        except Exception as e:
            logger.error(f"Error training model: {e}")
            return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)}

    def _save_checkpoint(self, loss: float, accuracy: float):
        """
        Save a model checkpoint

        Args:
            loss: Training loss
            accuracy: Training accuracy
        """
        try:
            # Import the checkpoint manager
            from utils.checkpoint_manager import CheckpointManager

            # Create the checkpoint manager
            checkpoint_manager = CheckpointManager(
                checkpoint_dir=self.checkpoint_dir,
                max_checkpoints=10,
                metric_name="accuracy"
            )
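            # NOTE: the f"{temp_path}.pt" references below assume that
            # EnhancedCNN.save() appends a ".pt" suffix to the path it is given.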

            # Create a temporary model file
            temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp")
            self.model.save(temp_path)

            # Create metrics
            metrics = {
                'loss': loss,
                'accuracy': accuracy,
                'samples': len(self.training_data)
            }

            # Create metadata
            metadata = {
                'timestamp': datetime.now().isoformat(),
                'model_name': self.model_name,
                'input_shape': self.model.input_shape,
                'n_actions': self.model.n_actions
            }

            # Save the checkpoint
            checkpoint_path = checkpoint_manager.save_checkpoint(
                model_name=self.model_name,
                model_path=f"{temp_path}.pt",
                metrics=metrics,
                metadata=metadata
            )

            # Delete the temporary model file
            if os.path.exists(f"{temp_path}.pt"):
                os.remove(f"{temp_path}.pt")

            logger.info(f"Model checkpoint saved to {checkpoint_path}")

        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")

    def load_best_checkpoint(self) -> bool:
        """Load the best checkpoint based on accuracy"""
        try:
            # Import the checkpoint manager
            from utils.checkpoint_manager import CheckpointManager

            # Create the checkpoint manager
            checkpoint_manager = CheckpointManager(
                checkpoint_dir=self.checkpoint_dir,
                max_checkpoints=10,
                metric_name="accuracy"
            )

            # Load the best checkpoint
            best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name)

            if not best_checkpoint_path:
                logger.info("No checkpoints found")
                return False

            # Load the model
            success = self.model.load(best_checkpoint_path)

            if success:
                logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")

                # Log metrics
                metrics = best_checkpoint_metadata.get('metrics', {})
                logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")

                return True
            else:
                logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
                return False

        except Exception as e:
            logger.error(f"Error loading best checkpoint: {e}")
            return False