training fixes and enhancements (WIP)
@@ -437,13 +437,34 @@ class TradingOrchestrator:
     def predict(self, data=None):
         try:
             # ExtremaTrainer provides context features, not a direct prediction
-            # We assume 'data' here is the 'symbol' string to pass to get_context_features_for_model
-            if not isinstance(data, str):
-                logger.warning(f"ExtremaTrainerInterface.predict received non-string data: {type(data)}. Cannot get context features.")
+            # Handle different data types that might be passed to ExtremaTrainer
+            symbol = None
+
+            if isinstance(data, str):
+                # Direct symbol string
+                symbol = data
+            elif isinstance(data, dict):
+                # Dictionary with symbol information
+                symbol = data.get('symbol')
+            elif isinstance(data, np.ndarray):
+                # Numpy array - extract symbol from metadata or use default
+                # For now, use the first symbol from the model's symbols list
+                if hasattr(self.model, 'symbols') and self.model.symbols:
+                    symbol = self.model.symbols[0]
+                else:
+                    symbol = 'ETH/USDT'  # Default fallback
+            else:
+                # Unknown data type - use default symbol
+                if hasattr(self.model, 'symbols') and self.model.symbols:
+                    symbol = self.model.symbols[0]
+                else:
+                    symbol = 'ETH/USDT'  # Default fallback
+
+            if not symbol:
+                logger.warning(f"ExtremaTrainerInterface.predict could not determine symbol from data: {type(data)}")
                 return None
 
-            features = self.model.get_context_features_for_model(symbol=data)
+            features = self.model.get_context_features_for_model(symbol=symbol)
             if features is not None and features.size > 0:
                 # The presence of features indicates a signal. We'll return a generic HOLD
                 # with a neutral confidence. This can be refined if ExtremaTrainer provides
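
Note: the symbol resolution added above can be sanity-checked outside the orchestrator. The snippet below is a minimal sketch, not project code; DummyExtremaTrainer and resolve_symbol are hypothetical stand-ins that only mirror the str/dict/ndarray handling and the 'ETH/USDT' fallback shown in the hunk.

import numpy as np

class DummyExtremaTrainer:
    """Hypothetical stand-in: only the 'symbols' attribute is relevant here."""
    symbols = ['ETH/USDT', 'BTC/USDT']

def resolve_symbol(model, data):
    """Mirror of the symbol resolution added to ExtremaTrainerInterface.predict()."""
    if isinstance(data, str):
        return data                                   # direct symbol string
    if isinstance(data, dict):
        return data.get('symbol')                     # dict carrying a 'symbol' key (may be None)
    # ndarray or unknown type: fall back to the model's first configured symbol
    if hasattr(model, 'symbols') and model.symbols:
        return model.symbols[0]
    return 'ETH/USDT'                                 # default fallback

model = DummyExtremaTrainer()
print(resolve_symbol(model, 'BTC/USDT'))              # BTC/USDT
print(resolve_symbol(model, {'symbol': 'SOL/USDT'}))  # SOL/USDT
print(resolve_symbol(model, np.zeros(4)))             # ETH/USDT (first configured symbol)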
@@ -134,8 +134,8 @@ class TrainingIntegration:
 
             # Store experience in DQN memory
             dqn_agent = self.orchestrator.dqn_agent
-            if hasattr(dqn_agent, 'store_experience'):
-                dqn_agent.store_experience(
+            if hasattr(dqn_agent, 'remember'):
+                dqn_agent.remember(
                     state=np.array(dqn_state),
                     action=action_idx,
                     reward=reward,
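
Note: the call site now targets remember() instead of store_experience(), presumably to match the method the DQN agent actually exposes. The hunk only shows the state/action/reward keywords, so the remaining parameters (next_state, done) below are an assumption; this is an illustrative stub, not the project's DQNAgent.

from collections import deque
import numpy as np

class MinimalDQNAgent:
    """Illustrative stub only; the real agent's remember()/memory may differ."""
    def __init__(self, memory_size=10000):
        self.memory = deque(maxlen=memory_size)

    def remember(self, state, action, reward, next_state=None, done=False):
        # Append one experience tuple for later replay-based training
        self.memory.append((np.asarray(state), action, reward, next_state, done))

agent = MinimalDQNAgent()
agent.remember(state=np.zeros(8), action=1, reward=0.5)
print(len(agent.memory))  # 1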
@@ -145,7 +145,7 @@ class TrainingIntegration:
 
             # Trigger training if enough experiences
             if hasattr(dqn_agent, 'replay') and len(getattr(dqn_agent, 'memory', [])) > 32:
-                dqn_agent.replay(batch_size=32)
+                dqn_agent.replay()
                 logger.info("DQN training step completed")
 
             return True
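
Note: dropping the explicit batch_size=32 argument implies replay() is expected to draw its minibatch using the agent's own configured batch size. A minimal sketch of that contract, assuming a hypothetical agent with an internal batch_size (not the project's implementation):

import random
from collections import deque

class MinimalReplayAgent:
    """Hypothetical agent: replay() draws a minibatch using its own batch_size."""
    def __init__(self, batch_size=32):
        self.memory = deque(maxlen=10000)
        self.batch_size = batch_size

    def replay(self):
        if len(self.memory) < self.batch_size:
            return None                                # nothing to train on yet
        batch = random.sample(list(self.memory), self.batch_size)
        # A real implementation would fit the network on this batch and return its loss
        return 0.0

agent = MinimalReplayAgent()
agent.memory.extend(range(64))                         # pretend experiences
if hasattr(agent, 'replay') and len(getattr(agent, 'memory', [])) > 32:
    agent.replay()
    print("DQN training step completed")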
@@ -345,7 +345,7 @@ class TrainingIntegration:
             # Perform training step if agent has replay method
             if hasattr(cob_rl_agent, 'replay') and hasattr(cob_rl_agent, 'memory'):
                 if len(cob_rl_agent.memory) > 32:  # Enough samples to train
-                    loss = cob_rl_agent.replay(batch_size=min(32, len(cob_rl_agent.memory)))
+                    loss = cob_rl_agent.replay()
                     if loss is not None:
                         logger.info(f"COB RL trained on trade outcome: P&L=${pnl:.2f}, loss={loss:.4f}")
                         return True
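
Note: the COB RL path follows the same pattern, with the added wrinkle that replay() may return None when there is nothing to train on, which is why the loss is checked before logging. A hedged sketch of that caller-side guard (cob_rl_agent, pnl, and the log format come from the hunk; MinimalCobRlAgent is hypothetical):

import logging
from collections import deque

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class MinimalCobRlAgent:
    """Hypothetical agent: replay() returns a loss, or None if memory is too small."""
    def __init__(self, batch_size=32):
        self.memory = deque(maxlen=10000)
        self.batch_size = batch_size

    def replay(self):
        if len(self.memory) < self.batch_size:
            return None
        return 0.1234                                  # placeholder for the real training loss

cob_rl_agent = MinimalCobRlAgent()
cob_rl_agent.memory.extend(range(64))                  # pretend experiences
pnl = 12.5                                             # example trade outcome

if hasattr(cob_rl_agent, 'replay') and hasattr(cob_rl_agent, 'memory'):
    if len(cob_rl_agent.memory) > 32:                  # enough samples to train
        loss = cob_rl_agent.replay()
        if loss is not None:
            logger.info(f"COB RL trained on trade outcome: P&L=${pnl:.2f}, loss={loss:.4f}")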