more MOCK/placeholder training functions replaced with real implementations
This commit is contained in:
@ -147,44 +147,135 @@ class TrainingIntegration:
|
||||
return False
|
||||
|
||||
def _train_cnn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
|
||||
"""Train CNN on trade outcome (placeholder)"""
|
||||
"""Train CNN on trade outcome with real implementation"""
|
||||
try:
|
||||
if not self.orchestrator:
|
||||
return False
|
||||
|
||||
# Check if CNN is available
|
||||
if not hasattr(self.orchestrator, 'williams_cnn') or not self.orchestrator.williams_cnn:
|
||||
cnn_model = None
|
||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
cnn_model = self.orchestrator.cnn_model
|
||||
elif hasattr(self.orchestrator, 'williams_cnn') and self.orchestrator.williams_cnn:
|
||||
cnn_model = self.orchestrator.williams_cnn
|
||||
|
||||
if not cnn_model:
|
||||
logger.debug("CNN not available for training")
|
||||
return False
|
||||
|
||||
# Get CNN features from model inputs
|
||||
model_inputs = trade_record.get('model_inputs_at_entry', {})
|
||||
cnn_features = model_inputs.get('cnn_features')
|
||||
cnn_predictions = model_inputs.get('cnn_predictions')
|
||||
|
||||
if not cnn_features or not cnn_predictions:
|
||||
if not cnn_features:
|
||||
logger.debug("No CNN features available for training")
|
||||
return False
|
||||
|
||||
# CNN training would go here - requires more specific implementation
|
||||
# For now, just log that we could train CNN
|
||||
logger.debug(f"CNN training opportunity: features={len(cnn_features)}, predictions={len(cnn_predictions)}")
|
||||
# Determine target based on trade outcome
|
||||
pnl = trade_record.get('pnl', 0)
|
||||
action = trade_record.get('side', 'HOLD').upper()
|
||||
|
||||
return True
|
||||
# Create target based on trade success
|
||||
if pnl > 0:
|
||||
if action == 'BUY':
|
||||
target = 0 # Successful BUY
|
||||
elif action == 'SELL':
|
||||
target = 1 # Successful SELL
|
||||
else:
|
||||
target = 2 # HOLD
|
||||
else:
|
||||
# For unsuccessful trades, learn the opposite
|
||||
if action == 'BUY':
|
||||
target = 1 # Should have been SELL
|
||||
elif action == 'SELL':
|
||||
target = 0 # Should have been BUY
|
||||
else:
|
||||
target = 2 # HOLD
|
||||
|
||||
# Initialize model attributes if needed
|
||||
if not hasattr(cnn_model, 'optimizer'):
|
||||
import torch
|
||||
cnn_model.optimizer = torch.optim.Adam(cnn_model.parameters(), lr=0.001)
|
||||
|
||||
# Perform actual CNN training
|
||||
try:
|
||||
import torch
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Prepare features
|
||||
if isinstance(cnn_features, list):
|
||||
features = np.array(cnn_features, dtype=np.float32)
|
||||
else:
|
||||
features = np.array(cnn_features, dtype=np.float32)
|
||||
|
||||
# Ensure features are the right size
|
||||
if len(features) < 50:
|
||||
# Pad with zeros
|
||||
padded_features = np.zeros(50)
|
||||
padded_features[:len(features)] = features
|
||||
features = padded_features
|
||||
elif len(features) > 50:
|
||||
# Truncate
|
||||
features = features[:50]
|
||||
|
||||
# Create tensors
|
||||
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device)
|
||||
target_tensor = torch.LongTensor([target]).to(device)
|
||||
|
||||
# Training step
|
||||
cnn_model.train()
|
||||
cnn_model.optimizer.zero_grad()
|
||||
|
||||
outputs = cnn_model(features_tensor)
|
||||
|
||||
# Handle different output formats
|
||||
if isinstance(outputs, dict):
|
||||
if 'main_output' in outputs:
|
||||
logits = outputs['main_output']
|
||||
elif 'action_logits' in outputs:
|
||||
logits = outputs['action_logits']
|
||||
else:
|
||||
logits = list(outputs.values())[0]
|
||||
else:
|
||||
logits = outputs
|
||||
|
||||
# Calculate loss with reward weighting
|
||||
loss_fn = torch.nn.CrossEntropyLoss()
|
||||
loss = loss_fn(logits, target_tensor)
|
||||
|
||||
# Weight loss by reward magnitude
|
||||
weighted_loss = loss * abs(reward)
|
||||
|
||||
# Backward pass
|
||||
weighted_loss.backward()
|
||||
cnn_model.optimizer.step()
|
||||
|
||||
logger.info(f"CNN trained on trade outcome: P&L=${pnl:.2f}, loss={loss.item():.4f}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in CNN training step: {e}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in CNN training: {e}")
|
||||
logger.error(f"Error in CNN training: {e}")
|
||||
return False
|
||||
|
||||
def _train_cob_rl_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
|
||||
"""Train COB RL on trade outcome (placeholder)"""
|
||||
"""Train COB RL on trade outcome with real implementation"""
|
||||
try:
|
||||
if not self.orchestrator:
|
||||
return False
|
||||
|
||||
# Check if COB integration is available
|
||||
if not hasattr(self.orchestrator, 'cob_integration') or not self.orchestrator.cob_integration:
|
||||
logger.debug("COB integration not available for training")
|
||||
# Check if COB RL agent is available
|
||||
cob_rl_agent = None
|
||||
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||
cob_rl_agent = self.orchestrator.rl_agent
|
||||
elif hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
|
||||
cob_rl_agent = self.orchestrator.cob_rl_agent
|
||||
|
||||
if not cob_rl_agent:
|
||||
logger.debug("COB RL agent not available for training")
|
||||
return False
|
||||
|
||||
# Get COB features from model inputs
|
||||
@ -195,14 +286,64 @@ class TrainingIntegration:
|
||||
logger.debug("No COB features available for training")
|
||||
return False
|
||||
|
||||
# COB RL training would go here - requires more specific implementation
|
||||
# For now, just log that we could train COB RL
|
||||
logger.debug(f"COB RL training opportunity: features={len(cob_features)}")
|
||||
# Create state from COB features
|
||||
if isinstance(cob_features, list):
|
||||
state_features = np.array(cob_features, dtype=np.float32)
|
||||
else:
|
||||
state_features = np.array(cob_features, dtype=np.float32)
|
||||
|
||||
# Pad or truncate to expected size
|
||||
if hasattr(cob_rl_agent, 'state_shape'):
|
||||
expected_size = cob_rl_agent.state_shape if isinstance(cob_rl_agent.state_shape, int) else cob_rl_agent.state_shape[0]
|
||||
else:
|
||||
expected_size = 100 # Default size
|
||||
|
||||
if len(state_features) < expected_size:
|
||||
# Pad with zeros
|
||||
padded_features = np.zeros(expected_size)
|
||||
padded_features[:len(state_features)] = state_features
|
||||
state_features = padded_features
|
||||
elif len(state_features) > expected_size:
|
||||
# Truncate
|
||||
state_features = state_features[:expected_size]
|
||||
|
||||
state = np.array(state_features, dtype=np.float32)
|
||||
|
||||
# Determine action from trade record
|
||||
action_str = trade_record.get('side', 'HOLD').upper()
|
||||
if action_str == 'BUY':
|
||||
action = 0
|
||||
elif action_str == 'SELL':
|
||||
action = 1
|
||||
else:
|
||||
action = 2 # HOLD
|
||||
|
||||
# Create next state (similar to current state for simplicity)
|
||||
next_state = state.copy()
|
||||
|
||||
# Use PnL as reward
|
||||
pnl = trade_record.get('pnl', 0)
|
||||
actual_reward = float(pnl * 100) # Scale reward
|
||||
|
||||
# Store experience in agent memory
|
||||
if hasattr(cob_rl_agent, 'remember'):
|
||||
cob_rl_agent.remember(state, action, actual_reward, next_state, done=True)
|
||||
elif hasattr(cob_rl_agent, 'store_experience'):
|
||||
cob_rl_agent.store_experience(state, action, actual_reward, next_state, done=True)
|
||||
|
||||
# Perform training step if agent has replay method
|
||||
if hasattr(cob_rl_agent, 'replay') and hasattr(cob_rl_agent, 'memory'):
|
||||
if len(cob_rl_agent.memory) > 32: # Enough samples to train
|
||||
loss = cob_rl_agent.replay(batch_size=min(32, len(cob_rl_agent.memory)))
|
||||
if loss is not None:
|
||||
logger.info(f"COB RL trained on trade outcome: P&L=${pnl:.2f}, loss={loss:.4f}")
|
||||
return True
|
||||
|
||||
logger.debug(f"COB RL experience stored: P&L=${pnl:.2f}, reward={actual_reward:.2f}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in COB RL training: {e}")
|
||||
logger.error(f"Error in COB RL training: {e}")
|
||||
return False
|
||||
|
||||
def get_training_status(self) -> Dict[str, Any]:
|
||||
|
Reference in New Issue
Block a user