enable training

This commit is contained in:
Dobromir Popov
2025-09-08 12:13:50 +03:00
parent 4fe952dbee
commit 060fdd28b4
2 changed files with 8 additions and 112 deletions

View File

@@ -1330,110 +1330,6 @@ class TradingOrchestrator:
logger.debug(f"Error building COB state for {symbol}: {e}")
return None
def _combine_predictions(self, symbol: str, price: float, predictions: List[Prediction],
timestamp: datetime) -> TradingDecision:
# Call act_with_confidence and handle different return formats
result = model.model.act_with_confidence(state)
if len(result) == 3:
# EnhancedCNN format: (action, confidence, q_values)
action_idx, confidence, raw_q_values = result
elif len(result) == 2:
# DQN format: (action, confidence)
action_idx, confidence = result
raw_q_values = None
else:
logger.warning(f"Unexpected result format from RL model: {result}")
return None
else:
# Fallback to standard act method
action_idx = model.model.act(state)
confidence = 0.6 # Default confidence
raw_q_values = None
# Convert action index to action name
action_names = ['BUY', 'SELL', 'HOLD']
if 0 <= action_idx < len(action_names):
action = action_names[action_idx]
else:
logger.warning(f"Invalid action index from RL model: {action_idx}")
return None
# Store prediction in database for tracking
if (hasattr(self, 'enhanced_training_system') and
self.enhanced_training_system and
hasattr(self.enhanced_training_system, 'store_model_prediction')):
current_price = self._get_current_price_safe(symbol)
if current_price > 0:
prediction_id = self.enhanced_training_system.store_model_prediction(
model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN",
symbol=symbol,
prediction_type=action,
confidence=confidence,
current_price=current_price
)
logger.debug(f"Stored DQN prediction {prediction_id} for {symbol}")
# Create prediction object
prediction = Prediction(
model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN",
symbol=symbol,
signal=action,
confidence=confidence,
reasoning=f"DQN agent prediction with Q-values: {raw_q_values}",
features=state.tolist() if isinstance(state, np.ndarray) else [],
metadata={
'action_idx': action_idx,
'q_values': raw_q_values.tolist() if raw_q_values is not None else None,
'state_size': len(state) if state is not None else 0
}
)
return prediction
except Exception as e:
logger.error(f"Error getting RL prediction for {symbol}: {e}")
return None
raw_q_values = None # No raw q_values from simple act
else:
logger.error(f"RL model {model.name} has no act method")
return None
action_names = ['SELL', 'HOLD', 'BUY']
action = action_names[action_idx]
# Convert raw_q_values to list if they are a tensor
q_values_for_capture = None
if raw_q_values is not None and hasattr(raw_q_values, 'tolist'):
q_values_for_capture = raw_q_values.tolist()
elif raw_q_values is not None and isinstance(raw_q_values, list):
q_values_for_capture = raw_q_values
# Create prediction object
prediction = Prediction(
action=action,
confidence=float(confidence),
# Use actual q_values if available, otherwise default probabilities
probabilities={action_names[i]: float(q_values_for_capture[i]) if q_values_for_capture else (1.0 / len(action_names)) for i in range(len(action_names))},
timeframe='mixed', # RL uses mixed timeframes
timestamp=datetime.now(),
model_name=model.name,
metadata={'state_size': len(state)}
)
# Capture DQN prediction for dashboard visualization
current_price = self._get_current_price(symbol)
if current_price:
# Only pass q_values if they exist, otherwise pass empty list
q_values_to_pass = q_values_for_capture if q_values_for_capture is not None else []
self.capture_dqn_prediction(symbol, action_idx, float(confidence), current_price, q_values_to_pass)
return prediction
except Exception as e:
logger.error(f"Error getting RL prediction: {e}")
return None
async def _get_generic_prediction(self, model: ModelInterface, symbol: str) -> Optional[Prediction]:
"""Get prediction from generic model"""

View File

@@ -33,7 +33,7 @@ def create_safe_orchestrator() -> Optional[TradingOrchestrator]:
try:
# Create orchestrator with basic configuration (uses correct constructor parameters)
orchestrator = TradingOrchestrator(
enhanced_rl_training=False # Disable problematic training initially
enhanced_rl_training=True # Enable RL training for model improvement
)
logger.info("Trading orchestrator created successfully")