wip cnn training and cob

This commit is contained in:
Dobromir Popov
2025-07-23 23:33:36 +03:00
parent 8677c4c01c
commit 5437495003
4 changed files with 599 additions and 210 deletions

View File

@ -46,7 +46,7 @@ class EnhancedCNNAdapter:
self.max_training_samples = 10000
self.batch_size = 32
self.learning_rate = 0.0001
self.model_name = "enhanced_cnn_v1"
self.model_name = "enhanced_cnn"
# Enhanced metrics tracking
self.last_inference_time = None
@ -72,6 +72,30 @@ class EnhancedCNNAdapter:
logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
def _initialize_model(self):
"""Initialize the EnhancedCNN model"""
try:
# Calculate input shape based on BaseDataInput structure
# OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
# BTC OHLCV: 300 frames x 5 features = 1500 features
# COB: ±20 buckets x 4 metrics = 160 features
# MA: 4 timeframes x 10 buckets = 40 features
# Technical indicators: 100 features
# Last predictions: 50 features
# Total: 7850 features
input_shape = 7850
n_actions = 3 # BUY, SELL, HOLD
# Create model
self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
self.model.to(self.device)
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions}")
except Exception as e:
logger.error(f"Error initializing EnhancedCNN model: {e}")
raise
def _load_checkpoint(self, checkpoint_path: str) -> bool:
"""Load model from checkpoint path"""
try:
@ -98,6 +122,45 @@ class EnhancedCNNAdapter:
logger.error(f"Error loading best checkpoint: {e}")
return False
def load_best_checkpoint(self) -> bool:
"""Load the best checkpoint based on accuracy"""
try:
# Import checkpoint manager
from utils.checkpoint_manager import CheckpointManager
# Create checkpoint manager
checkpoint_manager = CheckpointManager(
checkpoint_dir=self.checkpoint_dir,
max_checkpoints=10,
metric_name="accuracy"
)
# Load best checkpoint
best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name)
if not best_checkpoint_path:
logger.info(f"No checkpoints found for {self.model_name} - starting in COLD START mode")
return False
# Load model
success = self.model.load(best_checkpoint_path)
if success:
logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")
# Log metrics
metrics = best_checkpoint_metadata.get('metrics', {})
logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")
return True
else:
logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
return False
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return False
def _create_default_output(self, symbol: str) -> ModelOutput:
@ -124,37 +187,7 @@ class EnhancedCNNAdapter:
return processed_states
def _initialize_model(self):
"""Initialize the EnhancedCNN model"""
try:
# Calculate input shape based on BaseDataInput structure
# OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
# BTC OHLCV: 300 frames x 5 features = 1500 features
# COB: ±20 buckets x 4 metrics = 160 features
# MA: 4 timeframes x 10 buckets = 40 features
# Technical indicators: 100 features
# Last predictions: 50 features
# Total: 7850 features
input_shape = 7850
n_actions = 3 # BUY, SELL, HOLD
# Create model
self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
self.model.to(self.device)
# Load model if path is provided
if self.model_path:
success = self.model.load(self.model_path)
if success:
logger.info(f"Model loaded from {self.model_path}")
else:
logger.warning(f"Failed to load model from {self.model_path}, using new model")
else:
logger.info("No model path provided, using new model")
except Exception as e:
logger.error(f"Error initializing EnhancedCNN model: {e}")
raise
def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor:
"""
@ -298,18 +331,35 @@ class EnhancedCNNAdapter:
confidence=0.0
)
def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float):
def add_training_sample(self, symbol_or_base_data, actual_action: str, reward: float):
"""
Add a training sample to the training data
Args:
base_data: Standardized input data
symbol_or_base_data: Either a symbol string or BaseDataInput object
actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
reward: Reward received for the action
"""
try:
# Convert BaseDataInput to features
features = self._convert_base_data_to_features(base_data)
# Handle both symbol string and BaseDataInput object
if isinstance(symbol_or_base_data, str):
# For cold start mode - create a simple training sample with current features
# This is a simplified approach for rapid training
symbol = symbol_or_base_data
# Create a simple feature vector (this could be enhanced with actual market data)
# For now, use a random feature vector as placeholder for cold start
features = torch.randn(7850, dtype=torch.float32, device=self.device)
logger.debug(f"Added simplified training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
else:
# Full BaseDataInput object
base_data = symbol_or_base_data
features = self._convert_base_data_to_features(base_data)
symbol = base_data.symbol
logger.debug(f"Added full training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
# Convert action to index
actions = ['BUY', 'SELL', 'HOLD']
@ -325,8 +375,6 @@ class EnhancedCNNAdapter:
self.training_data.sort(key=lambda x: x[2], reverse=True)
self.training_data = self.training_data[:self.max_training_samples]
logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")
except Exception as e:
logger.error(f"Error adding training sample: {e}")
@ -511,41 +559,3 @@ class EnhancedCNNAdapter:
except Exception as e:
logger.error(f"Error saving checkpoint: {e}")
def load_best_checkpoint(self):
"""Load the best checkpoint based on accuracy"""
try:
# Import checkpoint manager
from utils.checkpoint_manager import CheckpointManager
# Create checkpoint manager
checkpoint_manager = CheckpointManager(
checkpoint_dir=self.checkpoint_dir,
max_checkpoints=10,
metric_name="accuracy"
)
# Load best checkpoint
best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name)
if not best_checkpoint_path:
logger.info("No checkpoints found")
return False
# Load model
success = self.model.load(best_checkpoint_path)
if success:
logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")
# Log metrics
metrics = best_checkpoint_metadata.get('metrics', {})
logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")
return True
else:
logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
return False
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return False