stability
This commit is contained in:
@ -230,7 +230,8 @@ class TradingOrchestrator:
|
|||||||
self.model_states['dqn']['checkpoint_loaded'] = True
|
self.model_states['dqn']['checkpoint_loaded'] = True
|
||||||
self.model_states['dqn']['checkpoint_filename'] = metadata.checkpoint_id
|
self.model_states['dqn']['checkpoint_filename'] = metadata.checkpoint_id
|
||||||
checkpoint_loaded = True
|
checkpoint_loaded = True
|
||||||
logger.info(f"DQN checkpoint loaded: {metadata.checkpoint_id} (loss={metadata.loss:.4f})")
|
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "N/A"
|
||||||
|
logger.info(f"DQN checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error loading DQN checkpoint: {e}")
|
logger.warning(f"Error loading DQN checkpoint: {e}")
|
||||||
|
|
||||||
@ -269,7 +270,8 @@ class TradingOrchestrator:
|
|||||||
self.model_states['cnn']['checkpoint_loaded'] = True
|
self.model_states['cnn']['checkpoint_loaded'] = True
|
||||||
self.model_states['cnn']['checkpoint_filename'] = metadata.checkpoint_id
|
self.model_states['cnn']['checkpoint_filename'] = metadata.checkpoint_id
|
||||||
checkpoint_loaded = True
|
checkpoint_loaded = True
|
||||||
logger.info(f"CNN checkpoint loaded: {metadata.checkpoint_id} (loss={metadata.loss:.4f})")
|
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "N/A"
|
||||||
|
logger.info(f"CNN checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error loading CNN checkpoint: {e}")
|
logger.warning(f"Error loading CNN checkpoint: {e}")
|
||||||
|
|
||||||
@ -356,7 +358,8 @@ class TradingOrchestrator:
|
|||||||
self.model_states['cob_rl']['checkpoint_loaded'] = True
|
self.model_states['cob_rl']['checkpoint_loaded'] = True
|
||||||
self.model_states['cob_rl']['checkpoint_filename'] = metadata.checkpoint_id
|
self.model_states['cob_rl']['checkpoint_filename'] = metadata.checkpoint_id
|
||||||
checkpoint_loaded = True
|
checkpoint_loaded = True
|
||||||
logger.info(f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={metadata.loss:.4f})")
|
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "N/A"
|
||||||
|
logger.info(f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error loading COB RL checkpoint: {e}")
|
logger.warning(f"Error loading COB RL checkpoint: {e}")
|
||||||
|
|
||||||
@ -547,7 +550,7 @@ class TradingOrchestrator:
|
|||||||
if self.cob_integration:
|
if self.cob_integration:
|
||||||
try:
|
try:
|
||||||
logger.info("Attempting to start COB integration...")
|
logger.info("Attempting to start COB integration...")
|
||||||
await self.cob_integration.start_streaming()
|
await self.cob_integration.start()
|
||||||
logger.info("COB Integration streaming started successfully.")
|
logger.info("COB Integration streaming started successfully.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start COB integration streaming: {e}")
|
logger.error(f"Failed to start COB integration streaming: {e}")
|
||||||
|
@ -114,6 +114,34 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
|
|
||||||
logger.info("Enhanced Real-time Training System initialized")
|
logger.info("Enhanced Real-time Training System initialized")
|
||||||
|
|
||||||
|
def _get_dqn_state_features(self, symbol: str) -> Optional[np.ndarray]:
|
||||||
|
"""Get DQN state features from orchestrator"""
|
||||||
|
try:
|
||||||
|
if not self.orchestrator:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get DQN state from orchestrator if available
|
||||||
|
if hasattr(self.orchestrator, 'build_comprehensive_rl_state'):
|
||||||
|
rl_state = self.orchestrator.build_comprehensive_rl_state(symbol)
|
||||||
|
if rl_state and 'state_vector' in rl_state:
|
||||||
|
return np.array(rl_state['state_vector'], dtype=np.float32)
|
||||||
|
|
||||||
|
# Fallback: create basic state from available data
|
||||||
|
if len(self.real_time_data['ohlcv_1m']) > 0:
|
||||||
|
latest_bar = self.real_time_data['ohlcv_1m'][-1]
|
||||||
|
basic_state = [
|
||||||
|
latest_bar.get('close', 0) / 10000.0, # Normalized price
|
||||||
|
latest_bar.get('volume', 0) / 1000000.0, # Normalized volume
|
||||||
|
0.0, 0.0, 0.0 # Placeholder features
|
||||||
|
]
|
||||||
|
return np.array(basic_state, dtype=np.float32)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error getting DQN state features for {symbol}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def start_training(self):
|
def start_training(self):
|
||||||
"""Start the enhanced real-time training system"""
|
"""Start the enhanced real-time training system"""
|
||||||
if self.is_training:
|
if self.is_training:
|
||||||
@ -1885,7 +1913,7 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
and self.orchestrator.rl_agent):
|
and self.orchestrator.rl_agent):
|
||||||
|
|
||||||
# Use RL agent to make prediction
|
# Use RL agent to make prediction
|
||||||
current_state = self._get_dqn_state(symbol)
|
current_state = self._get_dqn_state_features(symbol)
|
||||||
if current_state is None:
|
if current_state is None:
|
||||||
return
|
return
|
||||||
action = self.orchestrator.rl_agent.act(current_state, explore=False)
|
action = self.orchestrator.rl_agent.act(current_state, explore=False)
|
||||||
|
@ -5205,9 +5205,7 @@ class CleanTradingDashboard:
|
|||||||
if self.orchestrator and hasattr(self.orchestrator, 'add_decision_callback'):
|
if self.orchestrator and hasattr(self.orchestrator, 'add_decision_callback'):
|
||||||
def connect_worker():
|
def connect_worker():
|
||||||
try:
|
try:
|
||||||
loop = asyncio.new_event_loop()
|
self.orchestrator.add_decision_callback(self._on_trading_decision)
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
loop.run_until_complete(self.orchestrator.add_decision_callback(self._on_trading_decision))
|
|
||||||
logger.info("Successfully connected to orchestrator for trading signals.")
|
logger.info("Successfully connected to orchestrator for trading signals.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Orchestrator connection worker failed: {e}")
|
logger.error(f"Orchestrator connection worker failed: {e}")
|
||||||
|
Reference in New Issue
Block a user