enable training
@@ -1330,110 +1330,6 @@ class TradingOrchestrator:
             logger.debug(f"Error building COB state for {symbol}: {e}")
             return None

-    def _combine_predictions(self, symbol: str, price: float, predictions: List[Prediction],
-                             timestamp: datetime) -> TradingDecision:
-            # Call act_with_confidence and handle different return formats
-            result = model.model.act_with_confidence(state)
-
-            if len(result) == 3:
-                # EnhancedCNN format: (action, confidence, q_values)
-                action_idx, confidence, raw_q_values = result
-            elif len(result) == 2:
-                # DQN format: (action, confidence)
-                action_idx, confidence = result
-                raw_q_values = None
-            else:
-                logger.warning(f"Unexpected result format from RL model: {result}")
-                return None
-        else:
-            # Fallback to standard act method
-            action_idx = model.model.act(state)
-            confidence = 0.6  # Default confidence
-            raw_q_values = None
-
-        # Convert action index to action name
-        action_names = ['BUY', 'SELL', 'HOLD']
-        if 0 <= action_idx < len(action_names):
-            action = action_names[action_idx]
-        else:
-            logger.warning(f"Invalid action index from RL model: {action_idx}")
-            return None
-
-        # Store prediction in database for tracking
-        if (hasattr(self, 'enhanced_training_system') and
-                self.enhanced_training_system and
-                hasattr(self.enhanced_training_system, 'store_model_prediction')):
-
-            current_price = self._get_current_price_safe(symbol)
-            if current_price > 0:
-                prediction_id = self.enhanced_training_system.store_model_prediction(
-                    model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN",
-                    symbol=symbol,
-                    prediction_type=action,
-                    confidence=confidence,
-                    current_price=current_price
-                )
-                logger.debug(f"Stored DQN prediction {prediction_id} for {symbol}")
-
-        # Create prediction object
-        prediction = Prediction(
-            model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN",
-            symbol=symbol,
-            signal=action,
-            confidence=confidence,
-            reasoning=f"DQN agent prediction with Q-values: {raw_q_values}",
-            features=state.tolist() if isinstance(state, np.ndarray) else [],
-            metadata={
-                'action_idx': action_idx,
-                'q_values': raw_q_values.tolist() if raw_q_values is not None else None,
-                'state_size': len(state) if state is not None else 0
-            }
-        )
-
-        return prediction
-
-        except Exception as e:
-            logger.error(f"Error getting RL prediction for {symbol}: {e}")
-            return None
-            raw_q_values = None  # No raw q_values from simple act
-        else:
-            logger.error(f"RL model {model.name} has no act method")
-            return None
-
-        action_names = ['SELL', 'HOLD', 'BUY']
-        action = action_names[action_idx]
-
-        # Convert raw_q_values to list if they are a tensor
-        q_values_for_capture = None
-        if raw_q_values is not None and hasattr(raw_q_values, 'tolist'):
-            q_values_for_capture = raw_q_values.tolist()
-        elif raw_q_values is not None and isinstance(raw_q_values, list):
-            q_values_for_capture = raw_q_values
-
-        # Create prediction object
-        prediction = Prediction(
-            action=action,
-            confidence=float(confidence),
-            # Use actual q_values if available, otherwise default probabilities
-            probabilities={action_names[i]: float(q_values_for_capture[i]) if q_values_for_capture else (1.0 / len(action_names)) for i in range(len(action_names))},
-            timeframe='mixed',  # RL uses mixed timeframes
-            timestamp=datetime.now(),
-            model_name=model.name,
-            metadata={'state_size': len(state)}
-        )
-
-        # Capture DQN prediction for dashboard visualization
-        current_price = self._get_current_price(symbol)
-        if current_price:
-            # Only pass q_values if they exist, otherwise pass empty list
-            q_values_to_pass = q_values_for_capture if q_values_for_capture is not None else []
-            self.capture_dqn_prediction(symbol, action_idx, float(confidence), current_price, q_values_to_pass)
-
-        return prediction
-
-    except Exception as e:
-        logger.error(f"Error getting RL prediction: {e}")
-        return None
-
     async def _get_generic_prediction(self, model: ModelInterface, symbol: str) -> Optional[Prediction]:
         """Get prediction from generic model"""
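
The 104 lines removed above are two overlapping copies of the DQN prediction path (note the conflicting action orderings `['BUY', 'SELL', 'HOLD']` versus `['SELL', 'HOLD', 'BUY']`), apparently left behind by an earlier merge. Below is a minimal sketch of the return-shape handling that code performed, assuming the `act_with_confidence` tuple formats described in the removed comments; the helper name, default confidence, and default action ordering are hypothetical, not taken from the codebase:

```python
from typing import Optional, Sequence, Tuple

def normalize_rl_result(result,
                        action_names: Sequence[str] = ('SELL', 'HOLD', 'BUY')
                        ) -> Optional[Tuple[str, float, Optional[list]]]:
    """Hypothetical helper: collapse the return shapes seen in the removed code."""
    if isinstance(result, tuple) and len(result) == 3:
        # EnhancedCNN format: (action, confidence, q_values)
        action_idx, confidence, raw_q_values = result
    elif isinstance(result, tuple) and len(result) == 2:
        # DQN format: (action, confidence)
        action_idx, confidence = result
        raw_q_values = None
    else:
        # Bare act() fallback: only an action index, assumed default confidence
        action_idx, confidence, raw_q_values = result, 0.6, None

    if not 0 <= action_idx < len(action_names):
        return None  # invalid action index -> no prediction

    # Tensors expose .tolist(); plain lists and None pass through unchanged
    q_values = raw_q_values.tolist() if hasattr(raw_q_values, 'tolist') else raw_q_values
    return action_names[action_idx], float(confidence), q_values
```

Consolidating on one such path also removes the risk visible above, where the two duplicate blocks disagreed on the index-to-action mapping.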
@@ -1747,7 +1643,7 @@ class TradingOrchestrator:
         )

         if needs_refresh:
             result = load_best_checkpoint(model_name)
             self._checkpoint_cache[model_name] = result
             self._checkpoint_cache_time[model_name] = current_time

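The hunk above sits inside the checkpoint lookup, where `load_best_checkpoint` results are cached per model and reloaded when stale. A minimal, self-contained sketch of that caching pattern follows; only `load_best_checkpoint(model_name)` and the two cache dicts appear in the diff, while the class name, injected loader, and 60-second TTL are assumptions:

```python
import time
from typing import Any, Callable, Dict

class CheckpointCache:
    """Sketch of a per-model checkpoint cache with a time-based refresh."""

    def __init__(self, loader: Callable[[str], Any], ttl_s: float = 60.0):
        self._loader = loader                         # e.g. load_best_checkpoint from the diff
        self._ttl_s = ttl_s                           # assumed refresh interval
        self._checkpoint_cache: Dict[str, Any] = {}
        self._checkpoint_cache_time: Dict[str, float] = {}

    def get(self, model_name: str) -> Any:
        current_time = time.time()
        cached_at = self._checkpoint_cache_time.get(model_name, 0.0)
        needs_refresh = (model_name not in self._checkpoint_cache
                         or current_time - cached_at > self._ttl_s)
        if needs_refresh:
            result = self._loader(model_name)         # mirrors load_best_checkpoint(model_name)
            self._checkpoint_cache[model_name] = result
            self._checkpoint_cache_time[model_name] = current_time
        return self._checkpoint_cache[model_name]
```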
@@ -1892,15 +1788,15 @@ class TradingOrchestrator:
         # Initialize enhanced training system directly (no external training_integration module needed)
         try:
             from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem

             self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
                 orchestrator=self,
                 data_provider=self.data_provider,
                 dashboard=None
             )

             logger.info("✅ Enhanced training system initialized successfully")

             # Auto-start training by default
             logger.info("🚀 Auto-starting enhanced real-time training...")
             self.start_enhanced_training()
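The initialization above runs inside a `try:` whose `except` branch falls outside this hunk. A sketch of the full guarded pattern, assuming an `ImportError` fallback that leaves the orchestrator usable without the training package; the mixin name and fallback behavior are assumptions, everything else mirrors the diff:

```python
import logging

logger = logging.getLogger(__name__)

class OrchestratorTrainingMixin:
    """Sketch only: the except branch is assumed, since the diff shows the try body alone."""

    def _init_enhanced_training(self) -> None:
        try:
            from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem

            self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
                orchestrator=self,
                data_provider=self.data_provider,
                dashboard=None,
            )
            logger.info("✅ Enhanced training system initialized successfully")

            # Auto-start training by default, as in the hunk above
            logger.info("🚀 Auto-starting enhanced real-time training...")
            self.start_enhanced_training()
        except ImportError as exc:
            # Assumed fallback: keep the orchestrator usable without the training package
            self.enhanced_training_system = None
            logger.warning(f"Enhanced training unavailable: {exc}")
```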
@@ -33,7 +33,7 @@ def create_safe_orchestrator() -> Optional[TradingOrchestrator]:
     try:
         # Create orchestrator with basic configuration (uses correct constructor parameters)
         orchestrator = TradingOrchestrator(
-            enhanced_rl_training=False  # Disable problematic training initially
+            enhanced_rl_training=True  # Enable RL training for model improvement
         )

         logger.info("Trading orchestrator created successfully")
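The functional change in this commit is the `enhanced_rl_training=False` to `True` flip inside `create_safe_orchestrator`. A brief usage sketch follows; `create_safe_orchestrator` and its `Optional[TradingOrchestrator]` return type come from the hunk header, while the import path is not shown in the diff and the failure handling here is an assumption:

```python
import logging

logger = logging.getLogger(__name__)

# create_safe_orchestrator is defined in the module patched above
# (its import path is not visible in this diff).
orchestrator = create_safe_orchestrator()
if orchestrator is None:
    logger.error("Failed to create trading orchestrator; RL training stays off")
else:
    # With this commit, enhanced_rl_training=True is set at construction time.
    logger.info("Orchestrator ready with enhanced RL training enabled")
```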