cb ws
This commit is contained in:
@ -444,12 +444,15 @@ class TradingOrchestrator:
|
||||
logger.warning("DQN Agent not available")
|
||||
self.rl_agent = None
|
||||
|
||||
# Initialize CNN Model with Adapter
|
||||
# Initialize CNN Model directly (no adapter)
|
||||
try:
|
||||
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
from NN.models.enhanced_cnn import EnhancedCNN
|
||||
|
||||
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
|
||||
self.cnn_model = self.cnn_adapter.model # Keep reference for compatibility
|
||||
# Initialize CNN model directly
|
||||
input_shape = 7850 # Unified feature vector size
|
||||
n_actions = 3 # BUY, SELL, HOLD
|
||||
self.cnn_model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
|
||||
self.cnn_adapter = None # No adapter needed
|
||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
|
||||
|
||||
# Load best checkpoint and capture initial state (using database metadata)
|
||||
@ -476,7 +479,7 @@ class TradingOrchestrator:
|
||||
self.model_states['cnn']['best_loss'] = None
|
||||
logger.info("CNN starting fresh - no checkpoint found")
|
||||
|
||||
logger.info("Enhanced CNN adapter initialized")
|
||||
logger.info("Enhanced CNN model initialized directly")
|
||||
except ImportError:
|
||||
try:
|
||||
from NN.models.standardized_cnn import StandardizedCNN
|
||||
@ -1672,7 +1675,7 @@ class TradingOrchestrator:
|
||||
processing_time_ms=0.0, # We don't track this in orchestrator
|
||||
memory_usage_mb=0.0, # We don't track this in orchestrator
|
||||
input_features=input_features_array,
|
||||
checkpoint_id=None,f
|
||||
checkpoint_id=None,
|
||||
metadata=inference_record.get('metadata', {})
|
||||
)
|
||||
|
||||
@ -2376,49 +2379,72 @@ class TradingOrchestrator:
|
||||
async def _train_cnn_model(self, model, model_name: str, record: Dict, prediction: Dict, reward: float) -> bool:
|
||||
"""Train CNN model with training samples"""
|
||||
try:
|
||||
# Check if we have CNN adapter (preferred method)
|
||||
if hasattr(self, 'cnn_adapter') and self.cnn_adapter and 'cnn' in model_name.lower():
|
||||
# Direct CNN model training (no adapter)
|
||||
if hasattr(self, 'cnn_model') and self.cnn_model and 'cnn' in model_name.lower():
|
||||
try:
|
||||
symbol = record.get('symbol', 'ETH/USDT')
|
||||
actual_action = prediction['action']
|
||||
|
||||
# Add training sample to adapter
|
||||
if hasattr(self.cnn_adapter, 'add_training_sample'):
|
||||
self.cnn_adapter.add_training_sample(symbol, actual_action, reward)
|
||||
logger.debug(f"Added training sample to CNN adapter: action={actual_action}, reward={reward:.3f}")
|
||||
# Create training sample from record
|
||||
model_input = record.get('model_input')
|
||||
if model_input is not None:
|
||||
# Convert to tensor and ensure device placement
|
||||
device = next(self.cnn_model.parameters()).device
|
||||
|
||||
# Check if we have enough samples to train
|
||||
if hasattr(self.cnn_adapter, 'training_data') and hasattr(self.cnn_adapter, 'batch_size'):
|
||||
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
|
||||
logger.debug(f"Training CNN with {len(self.cnn_adapter.training_data)} samples")
|
||||
training_start_time = time.time()
|
||||
|
||||
# Add validation to prevent overfitting
|
||||
training_results = self.cnn_adapter.train(epochs=1)
|
||||
training_duration_ms = (time.time() - training_start_time) * 1000
|
||||
|
||||
if training_results and 'loss' in training_results:
|
||||
current_loss = training_results['loss']
|
||||
accuracy = training_results.get('accuracy', 0.0)
|
||||
|
||||
# Validate training results - 100% accuracy is suspicious
|
||||
if accuracy >= 0.99:
|
||||
logger.warning(f"CNN training shows suspiciously high accuracy: {accuracy:.4f} - possible overfitting")
|
||||
# Don't update loss if accuracy is too high (likely overfitting)
|
||||
logger.warning("Skipping loss update due to potential overfitting")
|
||||
else:
|
||||
self.update_model_loss(model_name, current_loss)
|
||||
|
||||
self._update_model_training_statistics(model_name, current_loss, training_duration_ms)
|
||||
logger.debug(f"CNN training completed: loss={current_loss:.4f}, accuracy={accuracy:.4f}, time={training_duration_ms:.1f}ms")
|
||||
return True
|
||||
else:
|
||||
# Still update training statistics even if no loss returned
|
||||
self._update_model_training_statistics(model_name, training_duration_ms=training_duration_ms)
|
||||
else:
|
||||
logger.debug(f"Not enough samples for CNN training: {len(self.cnn_adapter.training_data)}/{self.cnn_adapter.batch_size}")
|
||||
return True # Sample added successfully
|
||||
if hasattr(model_input, 'get_feature_vector'):
|
||||
features = model_input.get_feature_vector()
|
||||
elif isinstance(model_input, np.ndarray):
|
||||
features = model_input
|
||||
else:
|
||||
features = np.array(model_input, dtype=np.float32)
|
||||
|
||||
features_tensor = torch.tensor(features, dtype=torch.float32, device=device)
|
||||
if features_tensor.dim() == 1:
|
||||
features_tensor = features_tensor.unsqueeze(0)
|
||||
|
||||
# Convert action to index
|
||||
actions = ['BUY', 'SELL', 'HOLD']
|
||||
action_idx = actions.index(actual_action) if actual_action in actions else 2
|
||||
action_tensor = torch.tensor([action_idx], dtype=torch.long, device=device)
|
||||
reward_tensor = torch.tensor([reward], dtype=torch.float32, device=device)
|
||||
|
||||
# Perform training step
|
||||
self.cnn_model.train()
|
||||
self.cnn_optimizer.zero_grad()
|
||||
|
||||
# Forward pass
|
||||
q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.cnn_model(features_tensor)
|
||||
|
||||
# Calculate loss
|
||||
q_values_selected = q_values.gather(1, action_tensor.unsqueeze(1)).squeeze(1)
|
||||
target_q = reward_tensor # Simplified target
|
||||
loss = nn.MSELoss()(q_values_selected, target_q)
|
||||
|
||||
# Backward pass
|
||||
training_start_time = time.time()
|
||||
loss.backward()
|
||||
|
||||
# Gradient clipping
|
||||
torch.nn.utils.clip_grad_norm_(self.cnn_model.parameters(), max_norm=1.0)
|
||||
|
||||
# Optimizer step
|
||||
self.cnn_optimizer.step()
|
||||
training_duration_ms = (time.time() - training_start_time) * 1000
|
||||
|
||||
# Update statistics
|
||||
current_loss = loss.item()
|
||||
self.update_model_loss(model_name, current_loss)
|
||||
self._update_model_training_statistics(model_name, current_loss, training_duration_ms)
|
||||
|
||||
logger.debug(f"CNN direct training completed: loss={current_loss:.4f}, time={training_duration_ms:.1f}ms")
|
||||
return True
|
||||
else:
|
||||
logger.debug(f"CNN adapter doesn't have add_training_sample method")
|
||||
logger.warning(f"No model input available for CNN training")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in direct CNN training: {e}")
|
||||
return False
|
||||
|
||||
# Try direct model training methods
|
||||
elif hasattr(model, 'add_training_sample'):
|
||||
@ -2588,43 +2614,70 @@ class TradingOrchestrator:
|
||||
logger.warning(f"Cannot build BaseDataInput for CNN prediction: {symbol}")
|
||||
return predictions
|
||||
|
||||
# Use CNN adapter if available
|
||||
if hasattr(self, 'cnn_adapter') and self.cnn_adapter:
|
||||
# Direct CNN model inference (no adapter needed)
|
||||
if hasattr(self, 'cnn_model') and self.cnn_model:
|
||||
try:
|
||||
result = self.cnn_adapter.predict(base_data)
|
||||
if result:
|
||||
# Extract action and probabilities from ModelOutput
|
||||
action = result.predictions.get('action', 'HOLD')
|
||||
# Get feature vector from base_data
|
||||
features = base_data.get_feature_vector()
|
||||
|
||||
# Convert to tensor and ensure proper device placement
|
||||
device = next(self.cnn_model.parameters()).device
|
||||
features_tensor = torch.tensor(features, dtype=torch.float32, device=device)
|
||||
|
||||
# Ensure batch dimension
|
||||
if features_tensor.dim() == 1:
|
||||
features_tensor = features_tensor.unsqueeze(0)
|
||||
|
||||
# Set model to evaluation mode
|
||||
self.cnn_model.eval()
|
||||
|
||||
# Get prediction from CNN model
|
||||
with torch.no_grad():
|
||||
q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.cnn_model(features_tensor)
|
||||
|
||||
# Convert to probabilities using softmax
|
||||
action_probs = torch.softmax(q_values, dim=1)
|
||||
action_idx = torch.argmax(action_probs, dim=1).item()
|
||||
confidence = float(action_probs[0, action_idx].item())
|
||||
|
||||
# Map action index to action string
|
||||
actions = ['BUY', 'SELL', 'HOLD']
|
||||
action = actions[action_idx]
|
||||
|
||||
# Create probabilities dictionary
|
||||
probabilities = {
|
||||
'BUY': result.predictions.get('buy_probability', 0.0),
|
||||
'SELL': result.predictions.get('sell_probability', 0.0),
|
||||
'HOLD': result.predictions.get('hold_probability', 0.0)
|
||||
'BUY': float(action_probs[0, 0].item()),
|
||||
'SELL': float(action_probs[0, 1].item()),
|
||||
'HOLD': float(action_probs[0, 2].item())
|
||||
}
|
||||
|
||||
# Extract price predictions if available
|
||||
price_prediction = None
|
||||
if price_pred is not None:
|
||||
price_prediction = price_pred.squeeze(0).cpu().numpy().tolist()
|
||||
|
||||
prediction = Prediction(
|
||||
action=action,
|
||||
confidence=result.confidence,
|
||||
confidence=confidence,
|
||||
probabilities=probabilities,
|
||||
timeframe="multi", # Multi-timeframe prediction
|
||||
timestamp=datetime.now(),
|
||||
model_name=model.name, # Use the actual model name, not hardcoded "enhanced_cnn"
|
||||
model_name=model.name, # Use the actual model name
|
||||
metadata={
|
||||
'feature_size': len(base_data.get_feature_vector()),
|
||||
'data_sources': ['ohlcv_1s', 'ohlcv_1m', 'ohlcv_1h', 'ohlcv_1d', 'btc', 'cob', 'indicators'],
|
||||
'pivot_price': result.predictions.get('pivot_price'),
|
||||
'extrema_prediction': result.predictions.get('extrema'),
|
||||
'price_prediction': result.predictions.get('price_prediction')
|
||||
'price_prediction': price_prediction,
|
||||
'extrema_prediction': extrema_pred.squeeze(0).cpu().numpy().tolist() if extrema_pred is not None else None
|
||||
}
|
||||
)
|
||||
predictions.append(prediction)
|
||||
|
||||
# Store prediction in database for training
|
||||
logger.debug(f"Added CNN prediction to database: {prediction}")
|
||||
|
||||
# Note: Inference data will be stored in main prediction loop to avoid duplication
|
||||
logger.debug(f"Added CNN prediction: {action} ({confidence:.3f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error using CNN adapter: {e}")
|
||||
logger.error(f"Error using direct CNN model: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# Fallback to direct model inference using BaseDataInput (unified approach)
|
||||
if not predictions:
|
||||
@ -2689,7 +2742,7 @@ class TradingOrchestrator:
|
||||
logger.info(f"CNN fallback successful for {symbol}: {best_action} (confidence: {confidence:.3f})")
|
||||
|
||||
else:
|
||||
logger.warning(f"CNN model {model.name} does not have act() method for fallback")
|
||||
logger.debug(f"CNN model {model.name} fallback not needed - direct inference succeeded")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CNN fallback inference failed for {symbol}: {e}")
|
||||
|
Reference in New Issue
Block a user