""" Unified Training Manager Combines the previous built-in (normal) training and the EnhancedRealtimeTrainingSystem ideas into a single orchestrator-agnostic manager. Keeps orchestrator lean by moving training logic here. Key responsibilities: - Subscribe to model predictions/outcomes and perform online updates (DQN/COB RL/CNN) - Schedule periodic training (intervals) and replay-based training - Integrate with Enhanced Reward System for evaluated rewards - Work regardless of enhanced system availability """ import asyncio import logging import time from typing import Any, Dict, List, Optional, Tuple from core.enhanced_reward_calculator import TimeFrame logger = logging.getLogger(__name__) class UnifiedTrainingManager: """Unified training controller decoupled from the orchestrator.""" def __init__( self, orchestrator: Any, reward_system: Any = None, dqn_interval_s: int = 5, cob_rl_interval_s: int = 1, cnn_interval_s: int = 10, min_dqn_experiences: int = 16, ): self.orchestrator = orchestrator self.reward_system = reward_system self.dqn_interval_s = dqn_interval_s self.cob_rl_interval_s = cob_rl_interval_s self.cnn_interval_s = cnn_interval_s self.min_dqn_experiences = min_dqn_experiences self.running = False self._tasks: List[asyncio.Task] = [] async def start(self): if self.running: logger.warning("UnifiedTrainingManager already running") return self.running = True logger.info("UnifiedTrainingManager started") # Periodic trainers self._tasks.append(asyncio.create_task(self._dqn_trainer_loop())) self._tasks.append(asyncio.create_task(self._cob_rl_trainer_loop())) self._tasks.append(asyncio.create_task(self._cnn_trainer_loop())) # Reward-driven trainer if self.reward_system is not None: self._tasks.append(asyncio.create_task(self._reward_driven_training_loop())) async def stop(self): self.running = False for t in self._tasks: t.cancel() await asyncio.gather(*self._tasks, return_exceptions=True) self._tasks.clear() logger.info("UnifiedTrainingManager stopped") async def _dqn_trainer_loop(self): while self.running: try: rl_agent = getattr(self.orchestrator, 'rl_agent', None) if rl_agent and hasattr(rl_agent, 'memory'): if len(rl_agent.memory) >= self.min_dqn_experiences and hasattr(rl_agent, 'replay'): loss = rl_agent.replay() if loss is not None: logger.debug(f"DQN replay loss: {loss}") await asyncio.sleep(self.dqn_interval_s) except Exception as e: logger.error(f"DQN trainer loop error: {e}") await asyncio.sleep(self.dqn_interval_s) async def _cob_rl_trainer_loop(self): while self.running: try: cob_agent = getattr(self.orchestrator, 'cob_rl_agent', None) if cob_agent and hasattr(cob_agent, 'replay') and hasattr(cob_agent, 'memory'): if len(getattr(cob_agent, 'memory', [])) >= 8: loss = cob_agent.replay() if loss is not None: logger.debug(f"COB RL replay loss: {loss}") await asyncio.sleep(self.cob_rl_interval_s) except Exception as e: logger.error(f"COB RL trainer loop error: {e}") await asyncio.sleep(self.cob_rl_interval_s) async def _cnn_trainer_loop(self): while self.running: try: # Placeholder: hook to your CNN trainer if available await asyncio.sleep(self.cnn_interval_s) except Exception as e: logger.error(f"CNN trainer loop error: {e}") await asyncio.sleep(self.cnn_interval_s) async def _reward_driven_training_loop(self): while self.running: try: # Pull evaluated samples and feed to respective models symbols = getattr(self.reward_system.reward_calculator, 'symbols', []) if hasattr(self.reward_system, 'reward_calculator') else [] for sym in symbols: # Use short horizon for fast feedback 
samples = self.reward_system.reward_calculator.get_training_data(sym, TimeFrame.SECONDS_1, max_samples=64) if not samples: continue # Currently DQN batch: add to memory and let replay loop train rl_agent = getattr(self.orchestrator, 'rl_agent', None) if rl_agent and hasattr(rl_agent, 'remember'): for rec, reward in samples: # Use state vector captured at inference time when available state = rec.state_vector if getattr(rec, 'state_vector', None) else [] if not state: continue action = rec.predicted_direction + 1 rl_agent.remember(state, action, reward, state, True) await asyncio.sleep(2) except Exception as e: logger.error(f"Reward-driven training loop error: {e}") await asyncio.sleep(5)
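

if __name__ == "__main__":
    # Usage sketch (illustrative only): runs the manager against a bare
    # stand-in orchestrator with no models attached, so every trainer loop
    # simply idles via its getattr() guards. In real wiring, pass the actual
    # orchestrator and an enhanced reward system instance instead of the
    # hypothetical _StubOrchestrator used here.

    class _StubOrchestrator:
        """Stand-in orchestrator with no rl_agent / cob_rl_agent attached."""
        pass

    async def _demo():
        manager = UnifiedTrainingManager(_StubOrchestrator(), dqn_interval_s=1)
        await manager.start()
        await asyncio.sleep(3)  # let a few trainer iterations elapse
        await manager.stop()

    asyncio.run(_demo())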