This commit is contained in:
Dobromir Popov
2025-07-27 20:56:37 +03:00
parent bd986f4534
commit 9e1684f9f8
7 changed files with 531 additions and 112 deletions

View File

@@ -444,12 +444,15 @@ class TradingOrchestrator:
logger.warning("DQN Agent not available")
self.rl_agent = None
# Initialize CNN Model with Adapter
# Initialize CNN Model directly (no adapter)
try:
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from NN.models.enhanced_cnn import EnhancedCNN
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
self.cnn_model = self.cnn_adapter.model # Keep reference for compatibility
# Initialize CNN model directly
input_shape = 7850 # Unified feature vector size
n_actions = 3 # BUY, SELL, HOLD
self.cnn_model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
self.cnn_adapter = None # No adapter needed
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
# Load best checkpoint and capture initial state (using database metadata)
@@ -476,7 +479,7 @@ class TradingOrchestrator:
self.model_states['cnn']['best_loss'] = None
logger.info("CNN starting fresh - no checkpoint found")
logger.info("Enhanced CNN adapter initialized")
logger.info("Enhanced CNN model initialized directly")
except ImportError:
try:
from NN.models.standardized_cnn import StandardizedCNN
@@ -1672,7 +1675,7 @@ class TradingOrchestrator:
processing_time_ms=0.0, # We don't track this in orchestrator
memory_usage_mb=0.0, # We don't track this in orchestrator
input_features=input_features_array,
checkpoint_id=None,f
checkpoint_id=None,
metadata=inference_record.get('metadata', {})
)
@@ -2376,49 +2379,72 @@ class TradingOrchestrator:
async def _train_cnn_model(self, model, model_name: str, record: Dict, prediction: Dict, reward: float) -> bool:
"""Train CNN model with training samples"""
try:
# Check if we have CNN adapter (preferred method)
if hasattr(self, 'cnn_adapter') and self.cnn_adapter and 'cnn' in model_name.lower():
# Direct CNN model training (no adapter)
if hasattr(self, 'cnn_model') and self.cnn_model and 'cnn' in model_name.lower():
try:
symbol = record.get('symbol', 'ETH/USDT')
actual_action = prediction['action']
# Add training sample to adapter
if hasattr(self.cnn_adapter, 'add_training_sample'):
self.cnn_adapter.add_training_sample(symbol, actual_action, reward)
logger.debug(f"Added training sample to CNN adapter: action={actual_action}, reward={reward:.3f}")
# Create training sample from record
model_input = record.get('model_input')
if model_input is not None:
# Convert to tensor and ensure device placement
device = next(self.cnn_model.parameters()).device
# Check if we have enough samples to train
if hasattr(self.cnn_adapter, 'training_data') and hasattr(self.cnn_adapter, 'batch_size'):
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
logger.debug(f"Training CNN with {len(self.cnn_adapter.training_data)} samples")
training_start_time = time.time()
# Add validation to prevent overfitting
training_results = self.cnn_adapter.train(epochs=1)
training_duration_ms = (time.time() - training_start_time) * 1000
if training_results and 'loss' in training_results:
current_loss = training_results['loss']
accuracy = training_results.get('accuracy', 0.0)
# Validate training results - 100% accuracy is suspicious
if accuracy >= 0.99:
logger.warning(f"CNN training shows suspiciously high accuracy: {accuracy:.4f} - possible overfitting")
# Don't update loss if accuracy is too high (likely overfitting)
logger.warning("Skipping loss update due to potential overfitting")
else:
self.update_model_loss(model_name, current_loss)
self._update_model_training_statistics(model_name, current_loss, training_duration_ms)
logger.debug(f"CNN training completed: loss={current_loss:.4f}, accuracy={accuracy:.4f}, time={training_duration_ms:.1f}ms")
return True
else:
# Still update training statistics even if no loss returned
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
else:
logger.debug(f"Not enough samples for CNN training: {len(self.cnn_adapter.training_data)}/{self.cnn_adapter.batch_size}")
return True # Sample added successfully
if hasattr(model_input, 'get_feature_vector'):
features = model_input.get_feature_vector()
elif isinstance(model_input, np.ndarray):
features = model_input
else:
features = np.array(model_input, dtype=np.float32)
features_tensor = torch.tensor(features, dtype=torch.float32, device=device)
if features_tensor.dim() == 1:
features_tensor = features_tensor.unsqueeze(0)
# Convert action to index
actions = ['BUY', 'SELL', 'HOLD']
action_idx = actions.index(actual_action) if actual_action in actions else 2
action_tensor = torch.tensor([action_idx], dtype=torch.long, device=device)
reward_tensor = torch.tensor([reward], dtype=torch.float32, device=device)
# Perform training step
self.cnn_model.train()
self.cnn_optimizer.zero_grad()
# Forward pass
q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.cnn_model(features_tensor)
# Calculate loss
q_values_selected = q_values.gather(1, action_tensor.unsqueeze(1)).squeeze(1)
target_q = reward_tensor # Simplified target
loss = nn.MSELoss()(q_values_selected, target_q)
# Backward pass
training_start_time = time.time()
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.cnn_model.parameters(), max_norm=1.0)
# Optimizer step
self.cnn_optimizer.step()
training_duration_ms = (time.time() - training_start_time) * 1000
# Update statistics
current_loss = loss.item()
self.update_model_loss(model_name, current_loss)
self._update_model_training_statistics(model_name, current_loss, training_duration_ms)
logger.debug(f"CNN direct training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms")
return True
else:
logger.debug(f"CNN adapter doesn't have add_training_sample method")
logger.warning(f"No model input available for CNN training")
return False
except Exception as e:
logger.error(f"Error in direct CNN training: {e}")
return False
# Try direct model training methods
elif hasattr(model, 'add_training_sample'):
@@ -2588,43 +2614,70 @@ class TradingOrchestrator:
logger.warning(f"Cannot build BaseDataInput for CNN prediction: {symbol}")
return predictions
# Use CNN adapter if available
if hasattr(self, 'cnn_adapter') and self.cnn_adapter:
# Direct CNN model inference (no adapter needed)
if hasattr(self, 'cnn_model') and self.cnn_model:
try:
result = self.cnn_adapter.predict(base_data)
if result:
# Extract action and probabilities from ModelOutput
action = result.predictions.get('action', 'HOLD')
# Get feature vector from base_data
features = base_data.get_feature_vector()
# Convert to tensor and ensure proper device placement
device = next(self.cnn_model.parameters()).device
features_tensor = torch.tensor(features, dtype=torch.float32, device=device)
# Ensure batch dimension
if features_tensor.dim() == 1:
features_tensor = features_tensor.unsqueeze(0)
# Set model to evaluation mode
self.cnn_model.eval()
# Get prediction from CNN model
with torch.no_grad():
q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.cnn_model(features_tensor)
# Convert to probabilities using softmax
action_probs = torch.softmax(q_values, dim=1)
action_idx = torch.argmax(action_probs, dim=1).item()
confidence = float(action_probs[0, action_idx].item())
# Map action index to action string
actions = ['BUY', 'SELL', 'HOLD']
action = actions[action_idx]
# Create probabilities dictionary
probabilities = {
'BUY': result.predictions.get('buy_probability', 0.0),
'SELL': result.predictions.get('sell_probability', 0.0),
'HOLD': result.predictions.get('hold_probability', 0.0)
'BUY': float(action_probs[0, 0].item()),
'SELL': float(action_probs[0, 1].item()),
'HOLD': float(action_probs[0, 2].item())
}
# Extract price predictions if available
price_prediction = None
if price_pred is not None:
price_prediction = price_pred.squeeze(0).cpu().numpy().tolist()
prediction = Prediction(
action=action,
confidence=result.confidence,
confidence=confidence,
probabilities=probabilities,
timeframe="multi", # Multi-timeframe prediction
timestamp=datetime.now(),
model_name=model.name, # Use the actual model name, not hardcoded "enhanced_cnn"
model_name=model.name, # Use the actual model name
metadata={
'feature_size': len(base_data.get_feature_vector()),
'data_sources': ['ohlcv_1s', 'ohlcv_1m', 'ohlcv_1h', 'ohlcv_1d', 'btc', 'cob', 'indicators'],
'pivot_price': result.predictions.get('pivot_price'),
'extrema_prediction': result.predictions.get('extrema'),
'price_prediction': result.predictions.get('price_prediction')
'price_prediction': price_prediction,
'extrema_prediction': extrema_pred.squeeze(0).cpu().numpy().tolist() if extrema_pred is not None else None
}
)
predictions.append(prediction)
# Store prediction in database for training
logger.debug(f"Added CNN prediction to database: {prediction}")
# Note: Inference data will be stored in main prediction loop to avoid duplication
logger.debug(f"Added CNN prediction: {action} ({confidence:.3f})")
except Exception as e:
logger.error(f"Error using CNN adapter: {e}")
logger.error(f"Error using direct CNN model: {e}")
import traceback
traceback.print_exc()
# Fallback to direct model inference using BaseDataInput (unified approach)
if not predictions:
@@ -2689,7 +2742,7 @@ class TradingOrchestrator:
logger.info(f"CNN fallback successful for {symbol}: {best_action} (confidence: {confidence:.3f})")
else:
logger.warning(f"CNN model {model.name} does not have act() method for fallback")
logger.debug(f"CNN model {model.name} fallback not needed - direct inference succeeded")
except Exception as e:
logger.error(f"CNN fallback inference failed for {symbol}: {e}")