gogo2/core/unified_training_manager.py

"""
Unified Training Manager
Combines the previous built-in (normal) training and the EnhancedRealtimeTrainingSystem ideas
into a single orchestrator-agnostic manager. Keeps orchestrator lean by moving training logic here.
Key responsibilities:
- Subscribe to model predictions/outcomes and perform online updates (DQN/COB RL/CNN)
- Schedule periodic training (intervals) and replay-based training
- Integrate with Enhanced Reward System for evaluated rewards
- Work regardless of enhanced system availability
"""
import asyncio
import logging
import time
from typing import Any, Dict, List, Optional, Tuple
from core.enhanced_reward_calculator import TimeFrame
logger = logging.getLogger(__name__)
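
# The manager only duck-types its dependencies; attributes are probed at runtime
# (summarized from the trainer loops below, nothing is enforced):
#   orchestrator.rl_agent            -> `memory`, `replay()`, `remember(...)`
#   orchestrator.cob_rl_agent        -> `memory`, `replay()`
#   reward_system.reward_calculator  -> `symbols`,
#                                       `get_training_data(symbol, timeframe, max_samples)`
# Missing attributes are simply skipped, so partially wired orchestrators still run.
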
class UnifiedTrainingManager:
"""Unified training controller decoupled from the orchestrator."""
def __init__(
self,
orchestrator: Any,
reward_system: Any = None,
dqn_interval_s: int = 5,
cob_rl_interval_s: int = 1,
cnn_interval_s: int = 10,
min_dqn_experiences: int = 16,
):
self.orchestrator = orchestrator
self.reward_system = reward_system
self.dqn_interval_s = dqn_interval_s
self.cob_rl_interval_s = cob_rl_interval_s
self.cnn_interval_s = cnn_interval_s
self.min_dqn_experiences = min_dqn_experiences
self.running = False
self._tasks: List[asyncio.Task] = []

    async def start(self):
        if self.running:
            logger.warning("UnifiedTrainingManager already running")
            return
        self.running = True
        logger.info("UnifiedTrainingManager started")
        # Periodic trainers
        self._tasks.append(asyncio.create_task(self._dqn_trainer_loop()))
        self._tasks.append(asyncio.create_task(self._cob_rl_trainer_loop()))
        self._tasks.append(asyncio.create_task(self._cnn_trainer_loop()))
        # Reward-driven trainer
        if self.reward_system is not None:
            self._tasks.append(asyncio.create_task(self._reward_driven_training_loop()))

    async def stop(self):
        self.running = False
        for t in self._tasks:
            t.cancel()
        await asyncio.gather(*self._tasks, return_exceptions=True)
        self._tasks.clear()
        logger.info("UnifiedTrainingManager stopped")

    async def _dqn_trainer_loop(self):
        while self.running:
            try:
                rl_agent = getattr(self.orchestrator, 'rl_agent', None)
                if rl_agent and hasattr(rl_agent, 'memory'):
                    if len(rl_agent.memory) >= self.min_dqn_experiences and hasattr(rl_agent, 'replay'):
                        loss = rl_agent.replay()
                        if loss is not None:
                            logger.debug(f"DQN replay loss: {loss}")
                await asyncio.sleep(self.dqn_interval_s)
            except Exception as e:
                logger.error(f"DQN trainer loop error: {e}")
                await asyncio.sleep(self.dqn_interval_s)

    async def _cob_rl_trainer_loop(self):
        while self.running:
            try:
                cob_agent = getattr(self.orchestrator, 'cob_rl_agent', None)
                if cob_agent and hasattr(cob_agent, 'replay') and hasattr(cob_agent, 'memory'):
                    if len(getattr(cob_agent, 'memory', [])) >= 8:
                        loss = cob_agent.replay()
                        if loss is not None:
                            logger.debug(f"COB RL replay loss: {loss}")
                await asyncio.sleep(self.cob_rl_interval_s)
            except Exception as e:
                logger.error(f"COB RL trainer loop error: {e}")
                await asyncio.sleep(self.cob_rl_interval_s)

    async def _cnn_trainer_loop(self):
        while self.running:
            try:
                # Placeholder: hook to your CNN trainer if available
                await asyncio.sleep(self.cnn_interval_s)
            except Exception as e:
                logger.error(f"CNN trainer loop error: {e}")
                await asyncio.sleep(self.cnn_interval_s)
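
    # The placeholder in the loop above is intentionally empty. A minimal sketch
    # of what the hook could look like, assuming a hypothetical `cnn_trainer`
    # attribute on the orchestrator exposing a no-argument `train_step()` that
    # returns a loss (adapt the names to the real CNN trainer before enabling):
    #
    #     cnn_trainer = getattr(self.orchestrator, 'cnn_trainer', None)
    #     if cnn_trainer and hasattr(cnn_trainer, 'train_step'):
    #         loss = cnn_trainer.train_step()
    #         if loss is not None:
    #             logger.debug(f"CNN train loss: {loss}")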

    async def _reward_driven_training_loop(self):
        while self.running:
            try:
                # Pull evaluated samples and feed them to the respective models
                symbols = (
                    getattr(self.reward_system.reward_calculator, 'symbols', [])
                    if hasattr(self.reward_system, 'reward_calculator')
                    else []
                )
                for sym in symbols:
                    # Use a short horizon for fast feedback
                    samples = self.reward_system.reward_calculator.get_training_data(
                        sym, TimeFrame.SECONDS_1, max_samples=64
                    )
                    if not samples:
                        continue
                    # Currently a DQN-only batch: add to memory and let the replay loop train
                    rl_agent = getattr(self.orchestrator, 'rl_agent', None)
                    if rl_agent and hasattr(rl_agent, 'remember'):
                        for rec, reward in samples:
                            # Use the state vector captured at inference time when available
                            state = rec.state_vector if getattr(rec, 'state_vector', None) else []
                            if not state:
                                continue
                            # Shift the predicted direction into a non-negative action index
                            action = rec.predicted_direction + 1
                            # The outcome is already evaluated, so reuse the state as
                            # next_state and mark the transition as terminal
                            rl_agent.remember(state, action, reward, state, True)
                await asyncio.sleep(2)
            except Exception as e:
                logger.error(f"Reward-driven training loop error: {e}")
                await asyncio.sleep(5)
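

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of wiring the manager, using a hypothetical stub orchestrator;
# in the real system the orchestrator provides `rl_agent` / `cob_rl_agent`, and
# `reward_system` comes from the Enhanced Reward System integration.
if __name__ == "__main__":
    class _StubOrchestrator:
        """Hypothetical stand-in for the real orchestrator (no trainable agents)."""
        rl_agent = None
        cob_rl_agent = None

    async def _demo():
        manager = UnifiedTrainingManager(orchestrator=_StubOrchestrator())
        await manager.start()
        await asyncio.sleep(3)  # let the trainer loops tick a few times
        await manager.stop()

    asyncio.run(_demo())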