wip cnn training and cob
This commit is contained in:
@ -46,7 +46,7 @@ class EnhancedCNNAdapter:
|
||||
self.max_training_samples = 10000
|
||||
self.batch_size = 32
|
||||
self.learning_rate = 0.0001
|
||||
self.model_name = "enhanced_cnn_v1"
|
||||
self.model_name = "enhanced_cnn"
|
||||
|
||||
# Enhanced metrics tracking
|
||||
self.last_inference_time = None
|
||||
@ -72,6 +72,30 @@ class EnhancedCNNAdapter:
|
||||
|
||||
logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
|
||||
|
||||
def _initialize_model(self):
|
||||
"""Initialize the EnhancedCNN model"""
|
||||
try:
|
||||
# Calculate input shape based on BaseDataInput structure
|
||||
# OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
|
||||
# BTC OHLCV: 300 frames x 5 features = 1500 features
|
||||
# COB: ±20 buckets x 4 metrics = 160 features
|
||||
# MA: 4 timeframes x 10 buckets = 40 features
|
||||
# Technical indicators: 100 features
|
||||
# Last predictions: 50 features
|
||||
# Total: 7850 features
|
||||
input_shape = 7850
|
||||
n_actions = 3 # BUY, SELL, HOLD
|
||||
|
||||
# Create model
|
||||
self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
|
||||
self.model.to(self.device)
|
||||
|
||||
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing EnhancedCNN model: {e}")
|
||||
raise
|
||||
|
||||
def _load_checkpoint(self, checkpoint_path: str) -> bool:
|
||||
"""Load model from checkpoint path"""
|
||||
try:
|
||||
@ -98,6 +122,45 @@ class EnhancedCNNAdapter:
|
||||
logger.error(f"Error loading best checkpoint: {e}")
|
||||
return False
|
||||
|
||||
def load_best_checkpoint(self) -> bool:
|
||||
"""Load the best checkpoint based on accuracy"""
|
||||
try:
|
||||
# Import checkpoint manager
|
||||
from utils.checkpoint_manager import CheckpointManager
|
||||
|
||||
# Create checkpoint manager
|
||||
checkpoint_manager = CheckpointManager(
|
||||
checkpoint_dir=self.checkpoint_dir,
|
||||
max_checkpoints=10,
|
||||
metric_name="accuracy"
|
||||
)
|
||||
|
||||
# Load best checkpoint
|
||||
best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name)
|
||||
|
||||
if not best_checkpoint_path:
|
||||
logger.info(f"No checkpoints found for {self.model_name} - starting in COLD START mode")
|
||||
return False
|
||||
|
||||
# Load model
|
||||
success = self.model.load(best_checkpoint_path)
|
||||
|
||||
if success:
|
||||
logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")
|
||||
|
||||
# Log metrics
|
||||
metrics = best_checkpoint_metadata.get('metrics', {})
|
||||
logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading best checkpoint: {e}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def _create_default_output(self, symbol: str) -> ModelOutput:
|
||||
@ -124,37 +187,7 @@ class EnhancedCNNAdapter:
|
||||
|
||||
return processed_states
|
||||
|
||||
def _initialize_model(self):
|
||||
"""Initialize the EnhancedCNN model"""
|
||||
try:
|
||||
# Calculate input shape based on BaseDataInput structure
|
||||
# OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
|
||||
# BTC OHLCV: 300 frames x 5 features = 1500 features
|
||||
# COB: ±20 buckets x 4 metrics = 160 features
|
||||
# MA: 4 timeframes x 10 buckets = 40 features
|
||||
# Technical indicators: 100 features
|
||||
# Last predictions: 50 features
|
||||
# Total: 7850 features
|
||||
input_shape = 7850
|
||||
n_actions = 3 # BUY, SELL, HOLD
|
||||
|
||||
# Create model
|
||||
self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
|
||||
self.model.to(self.device)
|
||||
|
||||
# Load model if path is provided
|
||||
if self.model_path:
|
||||
success = self.model.load(self.model_path)
|
||||
if success:
|
||||
logger.info(f"Model loaded from {self.model_path}")
|
||||
else:
|
||||
logger.warning(f"Failed to load model from {self.model_path}, using new model")
|
||||
else:
|
||||
logger.info("No model path provided, using new model")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing EnhancedCNN model: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor:
|
||||
"""
|
||||
@ -298,18 +331,35 @@ class EnhancedCNNAdapter:
|
||||
confidence=0.0
|
||||
)
|
||||
|
||||
def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float):
|
||||
def add_training_sample(self, symbol_or_base_data, actual_action: str, reward: float):
|
||||
"""
|
||||
Add a training sample to the training data
|
||||
|
||||
Args:
|
||||
base_data: Standardized input data
|
||||
symbol_or_base_data: Either a symbol string or BaseDataInput object
|
||||
actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
|
||||
reward: Reward received for the action
|
||||
"""
|
||||
try:
|
||||
# Convert BaseDataInput to features
|
||||
features = self._convert_base_data_to_features(base_data)
|
||||
# Handle both symbol string and BaseDataInput object
|
||||
if isinstance(symbol_or_base_data, str):
|
||||
# For cold start mode - create a simple training sample with current features
|
||||
# This is a simplified approach for rapid training
|
||||
symbol = symbol_or_base_data
|
||||
|
||||
# Create a simple feature vector (this could be enhanced with actual market data)
|
||||
# For now, use a random feature vector as placeholder for cold start
|
||||
features = torch.randn(7850, dtype=torch.float32, device=self.device)
|
||||
|
||||
logger.debug(f"Added simplified training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
|
||||
|
||||
else:
|
||||
# Full BaseDataInput object
|
||||
base_data = symbol_or_base_data
|
||||
features = self._convert_base_data_to_features(base_data)
|
||||
symbol = base_data.symbol
|
||||
|
||||
logger.debug(f"Added full training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
|
||||
|
||||
# Convert action to index
|
||||
actions = ['BUY', 'SELL', 'HOLD']
|
||||
@ -325,8 +375,6 @@ class EnhancedCNNAdapter:
|
||||
self.training_data.sort(key=lambda x: x[2], reverse=True)
|
||||
self.training_data = self.training_data[:self.max_training_samples]
|
||||
|
||||
logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding training sample: {e}")
|
||||
|
||||
@ -511,41 +559,3 @@ class EnhancedCNNAdapter:
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint: {e}")
|
||||
|
||||
def load_best_checkpoint(self):
|
||||
"""Load the best checkpoint based on accuracy"""
|
||||
try:
|
||||
# Import checkpoint manager
|
||||
from utils.checkpoint_manager import CheckpointManager
|
||||
|
||||
# Create checkpoint manager
|
||||
checkpoint_manager = CheckpointManager(
|
||||
checkpoint_dir=self.checkpoint_dir,
|
||||
max_checkpoints=10,
|
||||
metric_name="accuracy"
|
||||
)
|
||||
|
||||
# Load best checkpoint
|
||||
best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name)
|
||||
|
||||
if not best_checkpoint_path:
|
||||
logger.info("No checkpoints found")
|
||||
return False
|
||||
|
||||
# Load model
|
||||
success = self.model.load(best_checkpoint_path)
|
||||
|
||||
if success:
|
||||
logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")
|
||||
|
||||
# Log metrics
|
||||
metrics = best_checkpoint_metadata.get('metrics', {})
|
||||
logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading best checkpoint: {e}")
|
||||
return False
|
Reference in New Issue
Block a user