stability
This commit is contained in:
@ -114,6 +114,34 @@ class EnhancedRealtimeTrainingSystem:
|
||||
|
||||
logger.info("Enhanced Real-time Training System initialized")
|
||||
|
||||
def _get_dqn_state_features(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Get DQN state features from orchestrator"""
|
||||
try:
|
||||
if not self.orchestrator:
|
||||
return None
|
||||
|
||||
# Get DQN state from orchestrator if available
|
||||
if hasattr(self.orchestrator, 'build_comprehensive_rl_state'):
|
||||
rl_state = self.orchestrator.build_comprehensive_rl_state(symbol)
|
||||
if rl_state and 'state_vector' in rl_state:
|
||||
return np.array(rl_state['state_vector'], dtype=np.float32)
|
||||
|
||||
# Fallback: create basic state from available data
|
||||
if len(self.real_time_data['ohlcv_1m']) > 0:
|
||||
latest_bar = self.real_time_data['ohlcv_1m'][-1]
|
||||
basic_state = [
|
||||
latest_bar.get('close', 0) / 10000.0, # Normalized price
|
||||
latest_bar.get('volume', 0) / 1000000.0, # Normalized volume
|
||||
0.0, 0.0, 0.0 # Placeholder features
|
||||
]
|
||||
return np.array(basic_state, dtype=np.float32)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting DQN state features for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def start_training(self):
|
||||
"""Start the enhanced real-time training system"""
|
||||
if self.is_training:
|
||||
@ -1885,7 +1913,7 @@ class EnhancedRealtimeTrainingSystem:
|
||||
and self.orchestrator.rl_agent):
|
||||
|
||||
# Use RL agent to make prediction
|
||||
current_state = self._get_dqn_state(symbol)
|
||||
current_state = self._get_dqn_state_features(symbol)
|
||||
if current_state is None:
|
||||
return
|
||||
action = self.orchestrator.rl_agent.act(current_state, explore=False)
|
||||
|
Reference in New Issue
Block a user