new cnn model
This commit is contained in:
482
NN/models/standardized_cnn.py
Normal file
482
NN/models/standardized_cnn.py
Normal file
@ -0,0 +1,482 @@
|
||||
"""
|
||||
Standardized CNN Model for Multi-Modal Trading System
|
||||
|
||||
This module extends the existing EnhancedCNN to work with standardized BaseDataInput format
|
||||
and provides ModelOutput for cross-model feeding.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the project root to the path to import core modules
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
from core.data_models import BaseDataInput, ModelOutput, create_model_output
|
||||
from .enhanced_cnn import EnhancedCNN, SelfAttention, ResidualBlock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class StandardizedCNN(nn.Module):
    """
    Standardized CNN Model that accepts BaseDataInput and outputs ModelOutput

    Features:
    - Accepts standardized BaseDataInput format
    - Processes COB+OHLCV data: 300 frames (1s,1m,1h,1d) ETH + 300s 1s BTC
    - Includes COB ±20 buckets and MA (1s,5s,15s,60s) of COB imbalance ±5 buckets
    - Outputs BUY/SELL/HOLD trading action with confidence scores
    - Provides hidden states for cross-model feeding
    - Integrates with checkpoint management system
    """

    # Class index -> action label. Shared by prediction and training so the
    # two mappings can never drift apart.
    ACTION_NAMES = ('BUY', 'SELL', 'HOLD')

    # Lower bound applied to probabilities before taking their log in the loss.
    _PROB_EPS = 1e-8

    def __init__(self, model_name: str = "standardized_cnn_v1", confidence_threshold: float = 0.6):
        """
        Initialize the standardized CNN model

        Args:
            model_name: Name identifier for this model instance
            confidence_threshold: Minimum confidence threshold for predictions
        """
        super().__init__()

        self.model_name = model_name
        self.model_type = "cnn"
        self.confidence_threshold = confidence_threshold

        # Calculate expected input dimensions from BaseDataInput
        self.expected_feature_dim = self._calculate_expected_features()

        # Initialize the underlying enhanced CNN with calculated dimensions
        self.enhanced_cnn = EnhancedCNN(
            input_shape=self.expected_feature_dim,
            n_actions=len(self.ACTION_NAMES),  # BUY, SELL, HOLD
            confidence_threshold=confidence_threshold
        )

        # Additional layers for processing BaseDataInput structure
        self.input_processor = self._build_input_processor()

        # Output processing layers
        self.output_processor = self._build_output_processor()

        # Device management: prefer CUDA when available, then move every
        # registered sub-module to that device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

        logger.info(f"StandardizedCNN '{model_name}' initialized")
        logger.info(f"Expected feature dimension: {self.expected_feature_dim}")
        logger.info(f"Device: {self.device}")

    def _calculate_expected_features(self) -> int:
        """
        Calculate expected feature dimension from BaseDataInput structure

        Based on actual BaseDataInput.get_feature_vector():
        - OHLCV ETH: 300 frames x 4 timeframes x 5 features = 6000
        - OHLCV BTC: 300 frames x 5 features = 1500
        - COB features: ~184 features (actual from implementation)
        - Technical indicators: 100 features (padded)
        - Last predictions: 50 features (padded)
        Total: ~7834 features (actual measured)

        Returns:
            int: The hard-coded feature count (7834).
        """
        return 7834  # Based on actual BaseDataInput.get_feature_vector() measurement

    def _build_input_processor(self) -> nn.Module:
        """
        Build input processing layers for BaseDataInput

        The stack narrows expected_feature_dim -> 4096 -> 2048 -> 1024.

        Returns:
            nn.Module: Input processing layers
        """
        return nn.Sequential(
            # Initial processing of raw BaseDataInput features
            nn.Linear(self.expected_feature_dim, 4096),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.BatchNorm1d(4096),

            # Feature refinement
            nn.Linear(4096, 2048),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.BatchNorm1d(2048),

            # Final feature extraction
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

    def _build_output_processor(self) -> nn.Module:
        """
        Build output processing layers for standardized ModelOutput

        The final Softmax means this head emits probabilities, not logits.

        Returns:
            nn.Module: Output processing layers
        """
        return nn.Sequential(
            # Process CNN outputs for standardized format
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Final action prediction
            nn.Linear(512, 3),  # BUY, SELL, HOLD
            nn.Softmax(dim=1)
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass through the standardized CNN

        Args:
            x: Input tensor from BaseDataInput.get_feature_vector(). A 1-D
                vector is treated as a batch of one. Inputs whose feature
                count differs from expected_feature_dim are zero-padded or
                truncated (with a warning) rather than rejected.

        Returns:
            Tuple of (action_probabilities, hidden_states_dict). The
            probabilities are softmax outputs of shape [batch, 3]; every
            tensor in the hidden-state dict is detached from the graph.
        """
        # Accept an un-batched feature vector instead of failing on x.size(1).
        if x.dim() == 1:
            x = x.unsqueeze(0)

        batch_size = x.size(0)
        feature_count = x.size(1)

        # Validate input dimensions
        if feature_count != self.expected_feature_dim:
            logger.warning(f"Input dimension mismatch: expected {self.expected_feature_dim}, got {feature_count}")
            # Pad or truncate as needed
            if feature_count < self.expected_feature_dim:
                # Match dtype as well as device so the concatenation never
                # silently changes the input's precision.
                padding = torch.zeros(
                    batch_size,
                    self.expected_feature_dim - feature_count,
                    dtype=x.dtype,
                    device=x.device,
                )
                x = torch.cat([x, padding], dim=1)
            else:
                x = x[:, :self.expected_feature_dim]

        # Process input through input processor
        processed_features = self.input_processor(x)  # [batch, 1024]

        # Get enhanced CNN predictions (using processed features as input)
        # We need to reshape for the enhanced CNN which expects different input format
        cnn_input = processed_features.unsqueeze(1)  # Add sequence dimension

        try:
            q_values, extrema_pred, price_pred, cnn_features, advanced_pred = self.enhanced_cnn(cnn_input)
        except Exception as e:
            # Deliberate best-effort: keep producing an output from the
            # processed features when the wrapped network rejects the input.
            logger.warning(f"Enhanced CNN forward pass failed: {e}, using fallback")
            # Fallback to direct processing
            cnn_features = processed_features
            q_values = torch.zeros(batch_size, 3, device=x.device)
            extrema_pred = torch.zeros(batch_size, 3, device=x.device)
            price_pred = torch.zeros(batch_size, 3, device=x.device)
            advanced_pred = torch.zeros(batch_size, 5, device=x.device)

        # Process outputs for standardized format
        # NOTE(review): assumes EnhancedCNN's feature output is 1024 wide like
        # processed_features - confirm against enhanced_cnn.py.
        action_probs = self.output_processor(cnn_features)  # [batch, 3]

        # Prepare hidden states for cross-model feeding
        hidden_states = {
            'processed_features': processed_features.detach(),
            'cnn_features': cnn_features.detach(),
            'q_values': q_values.detach(),
            'extrema_predictions': extrema_pred.detach(),
            'price_predictions': price_pred.detach(),
            'advanced_predictions': advanced_pred.detach(),
            'attention_weights': torch.ones(batch_size, 1, device=x.device)  # Placeholder
        }

        return action_probs, hidden_states

    def predict_from_base_input(self, base_input: BaseDataInput) -> ModelOutput:
        """
        Make prediction from BaseDataInput and return standardized ModelOutput

        The model is switched to evaluation mode and is left in that mode
        when the call returns. Any exception is logged and converted into the
        default HOLD output from _create_default_output().

        Args:
            base_input: Standardized input data

        Returns:
            ModelOutput: Standardized model output
        """
        try:
            # Convert BaseDataInput to feature vector
            feature_vector = base_input.get_feature_vector()

            # Convert to tensor and add batch dimension
            input_tensor = torch.tensor(feature_vector, dtype=torch.float32, device=self.device).unsqueeze(0)

            # Set model to evaluation mode
            self.eval()

            with torch.no_grad():
                # Forward pass
                action_probs, hidden_states = self.forward(input_tensor)

            # Get action and confidence
            action_probs_np = action_probs.squeeze(0).cpu().numpy()
            action_idx = int(np.argmax(action_probs_np))
            confidence = float(action_probs_np[action_idx])

            # Map action index to action name
            action = self.ACTION_NAMES[action_idx]

            # Prepare predictions dictionary
            predictions = {
                'action': action,
                'buy_probability': float(action_probs_np[0]),
                'sell_probability': float(action_probs_np[1]),
                'hold_probability': float(action_probs_np[2]),
                'action_probabilities': action_probs_np.tolist(),
                'extrema_detected': self._interpret_extrema(hidden_states.get('extrema_predictions')),
                'price_direction': self._interpret_price_direction(hidden_states.get('price_predictions')),
                'market_conditions': self._interpret_advanced_predictions(hidden_states.get('advanced_predictions'))
            }

            # Prepare hidden states for cross-model feeding (convert tensors to
            # plain nested lists)
            cross_model_states = {
                key: value.squeeze(0).cpu().numpy().tolist() if isinstance(value, torch.Tensor) else value
                for key, value in hidden_states.items()
            }

            # Create metadata
            metadata = {
                'model_version': '1.0',
                'confidence_threshold': self.confidence_threshold,
                'feature_dimension': self.expected_feature_dim,
                'processing_time_ms': 0,  # Could add timing if needed
                'input_validation': base_input.validate()
            }

            # Create standardized ModelOutput
            return ModelOutput(
                model_type=self.model_type,
                model_name=self.model_name,
                symbol=base_input.symbol,
                timestamp=datetime.now(),
                confidence=confidence,
                predictions=predictions,
                hidden_states=cross_model_states,
                metadata=metadata
            )

        except Exception as e:
            logger.error(f"Error in CNN prediction: {e}")
            # Return default output
            return self._create_default_output(base_input.symbol)

    @staticmethod
    def _top_label(scores: torch.Tensor, labels: List[str]) -> str:
        """Return the label whose index has the highest softmax score in a 1-D tensor."""
        probabilities = torch.softmax(scores, dim=0)
        return labels[int(torch.argmax(probabilities).item())]

    def _interpret_extrema(self, extrema_tensor: Optional[torch.Tensor]) -> str:
        """
        Interpret extrema predictions

        Returns:
            str: 'bottom', 'top' or 'neither'; 'unknown' when the tensor is
            missing or cannot be interpreted.
        """
        if extrema_tensor is None:
            return "unknown"

        try:
            return StandardizedCNN._top_label(extrema_tensor.squeeze(0), ['bottom', 'top', 'neither'])
        except Exception as e:
            # Best-effort interpretation, but never swallow KeyboardInterrupt
            # or SystemExit as a bare `except:` would.
            logger.debug(f"Could not interpret extrema predictions: {e}")
            return "unknown"

    def _interpret_price_direction(self, price_tensor: Optional[torch.Tensor]) -> str:
        """
        Interpret price direction predictions

        Returns:
            str: 'up', 'down' or 'sideways'; 'unknown' when the tensor is
            missing or cannot be interpreted.
        """
        if price_tensor is None:
            return "unknown"

        try:
            return StandardizedCNN._top_label(price_tensor.squeeze(0), ['up', 'down', 'sideways'])
        except Exception as e:
            logger.debug(f"Could not interpret price predictions: {e}")
            return "unknown"

    def _interpret_advanced_predictions(self, advanced_tensor: Optional[torch.Tensor]) -> Dict[str, str]:
        """
        Interpret advanced market predictions

        Returns:
            Dict with 'volatility' (one of five levels, or 'unknown') and
            'risk' (currently a fixed 'medium' placeholder, or 'unknown' on
            failure).
        """
        unknown = {"volatility": "unknown", "risk": "unknown"}
        if advanced_tensor is None:
            return unknown

        try:
            # Assuming advanced predictions include volatility (5 classes)
            if advanced_tensor.size(-1) >= 5:
                volatility = StandardizedCNN._top_label(
                    advanced_tensor.squeeze(0)[:5],
                    ['very_low', 'low', 'medium', 'high', 'very_high'],
                )
            else:
                volatility = "unknown"

            return {
                "volatility": volatility,
                "risk": "medium"  # Placeholder
            }
        except Exception as e:
            logger.debug(f"Could not interpret advanced predictions: {e}")
            return unknown

    def _create_default_output(self, symbol: str) -> ModelOutput:
        """Create default ModelOutput (HOLD at 0.5 confidence) for error cases"""
        return create_model_output(
            model_type=self.model_type,
            model_name=self.model_name,
            symbol=symbol,
            action='HOLD',
            confidence=0.5,
            metadata={'error': True, 'default_output': True}
        )

    def train_step(self, base_inputs: List[BaseDataInput], targets: List[str],
                   optimizer: torch.optim.Optimizer) -> float:
        """
        Perform a single training step

        Args:
            base_inputs: List of BaseDataInput for training
            targets: List of target actions ('BUY', 'SELL', 'HOLD'); any
                other label is treated as 'HOLD'
            optimizer: PyTorch optimizer

        Returns:
            float: Training loss, or float('inf') when the step failed (the
            error is logged, not raised).
        """
        self.train()

        try:
            # Convert inputs to tensors
            feature_vectors = [base_input.get_feature_vector() for base_input in base_inputs]
            input_tensor = torch.tensor(np.array(feature_vectors), dtype=torch.float32, device=self.device)

            # Convert targets to tensor
            action_to_idx = {name: idx for idx, name in enumerate(self.ACTION_NAMES)}
            hold_idx = action_to_idx['HOLD']
            target_indices = [action_to_idx.get(target, hold_idx) for target in targets]
            target_tensor = torch.tensor(target_indices, dtype=torch.long, device=self.device)

            # Forward pass
            # NOTE: the BatchNorm1d layers cannot normalise a single-sample
            # batch in training mode, so a batch of one ends in the except
            # path below.
            action_probs, _ = self.forward(input_tensor)

            # Calculate loss. forward() already returns softmax probabilities,
            # and F.cross_entropy would apply log_softmax a second time, so use
            # the negative log-likelihood of the (clamped) probabilities.
            log_probs = torch.log(action_probs.clamp_min(self._PROB_EPS))
            loss = F.nll_loss(log_probs, target_tensor)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            return float(loss.item())

        except Exception as e:
            logger.error(f"Error in training step: {e}")
            return float('inf')

    def evaluate(self, base_inputs: List[BaseDataInput], targets: List[str]) -> Dict[str, float]:
        """
        Evaluate model performance

        Args:
            base_inputs: List of BaseDataInput for evaluation
            targets: List of target actions

        Returns:
            Dict containing evaluation metrics: 'accuracy', 'avg_confidence',
            'correct_predictions' and 'total_predictions'. All are zero when
            evaluation fails. Only samples that have both an input and a
            target are scored.
        """
        self.eval()

        try:
            if len(base_inputs) != len(targets):
                logger.warning(
                    f"evaluate() received {len(base_inputs)} inputs but {len(targets)} targets; "
                    "scoring only the paired samples"
                )

            correct = 0
            total = 0  # counts the pairs actually scored, not len(base_inputs)
            total_confidence = 0.0

            with torch.no_grad():
                for base_input, target in zip(base_inputs, targets):
                    model_output = self.predict_from_base_input(base_input)
                    total += 1

                    if model_output.predictions['action'] == target:
                        correct += 1

                    total_confidence += model_output.confidence

            accuracy = correct / total if total > 0 else 0.0
            avg_confidence = total_confidence / total if total > 0 else 0.0

            return {
                'accuracy': accuracy,
                'avg_confidence': avg_confidence,
                'correct_predictions': correct,
                'total_predictions': total
            }

        except Exception as e:
            logger.error(f"Error in evaluation: {e}")
            return {'accuracy': 0.0, 'avg_confidence': 0.0, 'correct_predictions': 0, 'total_predictions': 0}

    def save_checkpoint(self, filepath: str, metadata: Optional[Dict[str, Any]] = None):
        """
        Save model checkpoint

        Failures are logged and not raised.

        Args:
            filepath: Path to save checkpoint
            metadata: Optional metadata to save with checkpoint
        """
        try:
            checkpoint = {
                'model_state_dict': self.state_dict(),
                'model_name': self.model_name,
                'model_type': self.model_type,
                'confidence_threshold': self.confidence_threshold,
                'expected_feature_dim': self.expected_feature_dim,
                'metadata': metadata or {},
                'timestamp': datetime.now().isoformat()
            }

            torch.save(checkpoint, filepath)
            logger.info(f"Checkpoint saved to {filepath}")

        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")

    def load_checkpoint(self, filepath: str) -> bool:
        """
        Load model checkpoint

        Args:
            filepath: Path to checkpoint file

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        try:
            # SECURITY: torch.load unpickles the file; only load checkpoints
            # from trusted sources.
            checkpoint = torch.load(filepath, map_location=self.device)

            # Load model state
            self.load_state_dict(checkpoint['model_state_dict'])

            # Load configuration, keeping current values for missing keys
            self.model_name = checkpoint.get('model_name', self.model_name)
            self.confidence_threshold = checkpoint.get('confidence_threshold', self.confidence_threshold)
            self.expected_feature_dim = checkpoint.get('expected_feature_dim', self.expected_feature_dim)

            logger.info(f"Checkpoint loaded from {filepath}")
            return True

        except Exception as e:
            logger.error(f"Error loading checkpoint: {e}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information (identity, configuration, device and parameter counts)"""
        return {
            'model_name': self.model_name,
            'model_type': self.model_type,
            'confidence_threshold': self.confidence_threshold,
            'expected_feature_dim': self.expected_feature_dim,
            'device': str(self.device),
            'parameter_count': sum(p.numel() for p in self.parameters()),
            'trainable_parameters': sum(p.numel() for p in self.parameters() if p.requires_grad)
        }
|
Reference in New Issue
Block a user