"""
|
|
Unified Training Manager
|
|
|
|
Combines the previous built-in (normal) training and the EnhancedRealtimeTrainingSystem ideas
|
|
into a single orchestrator-agnostic manager. Keeps orchestrator lean by moving training logic here.
|
|
|
|
Key responsibilities:
|
|
- Subscribe to model predictions/outcomes and perform online updates (DQN/COB RL/CNN)
|
|
- Schedule periodic training (intervals) and replay-based training
|
|
- Integrate with Enhanced Reward System for evaluated rewards
|
|
- Work regardless of enhanced system availability
|
|
"""
|
|
|
|
import asyncio
import logging
import time
from typing import Any, Dict, List, Optional, Tuple

from core.enhanced_reward_calculator import TimeFrame

logger = logging.getLogger(__name__)

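
# Duck-typed agent interface assumed by the training loops below (illustrative
# sketch only; the concrete DQN/COB RL agents live elsewhere in the codebase):
#
#     class ReplayAgent(Protocol):
#         memory: Sized                               # experience buffer
#         def replay(self) -> Optional[float]: ...    # one training step, returns loss (or None)
#         def remember(self, state, action, reward, next_state, done) -> None: ...
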
class UnifiedTrainingManager:
    """Unified training controller decoupled from the orchestrator."""

    def __init__(
        self,
        orchestrator: Any,
        reward_system: Any = None,
        dqn_interval_s: int = 5,
        cob_rl_interval_s: int = 1,
        cnn_interval_s: int = 10,
        min_dqn_experiences: int = 16,
    ):
        self.orchestrator = orchestrator
        self.reward_system = reward_system
        self.dqn_interval_s = dqn_interval_s
        self.cob_rl_interval_s = cob_rl_interval_s
        self.cnn_interval_s = cnn_interval_s
        self.min_dqn_experiences = min_dqn_experiences

        self.running = False
        self._tasks: List[asyncio.Task] = []

    async def start(self):
        """Launch the periodic and reward-driven training loops as background tasks."""
        if self.running:
            logger.warning("UnifiedTrainingManager already running")
            return
        self.running = True
        logger.info("UnifiedTrainingManager started")
        # Periodic trainers
        self._tasks.append(asyncio.create_task(self._dqn_trainer_loop()))
        self._tasks.append(asyncio.create_task(self._cob_rl_trainer_loop()))
        self._tasks.append(asyncio.create_task(self._cnn_trainer_loop()))
        # Reward-driven trainer (only when a reward system is attached)
        if self.reward_system is not None:
            self._tasks.append(asyncio.create_task(self._reward_driven_training_loop()))

    async def stop(self):
        """Cancel all training tasks and wait for them to wind down."""
        self.running = False
        for t in self._tasks:
            t.cancel()
        await asyncio.gather(*self._tasks, return_exceptions=True)
        self._tasks.clear()
        logger.info("UnifiedTrainingManager stopped")

    async def _dqn_trainer_loop(self):
        """Periodically run DQN experience replay once enough experiences are buffered."""
        while self.running:
            try:
                rl_agent = getattr(self.orchestrator, 'rl_agent', None)
                if rl_agent and hasattr(rl_agent, 'memory'):
                    if len(rl_agent.memory) >= self.min_dqn_experiences and hasattr(rl_agent, 'replay'):
                        loss = rl_agent.replay()
                        if loss is not None:
                            logger.debug(f"DQN replay loss: {loss}")
                await asyncio.sleep(self.dqn_interval_s)
            except Exception as e:
                logger.error(f"DQN trainer loop error: {e}")
                await asyncio.sleep(self.dqn_interval_s)

    async def _cob_rl_trainer_loop(self):
        """Periodically run COB RL experience replay on a faster cadence."""
        while self.running:
            try:
                cob_agent = getattr(self.orchestrator, 'cob_rl_agent', None)
                if cob_agent and hasattr(cob_agent, 'replay') and hasattr(cob_agent, 'memory'):
                    if len(getattr(cob_agent, 'memory', [])) >= 8:
                        loss = cob_agent.replay()
                        if loss is not None:
                            logger.debug(f"COB RL replay loss: {loss}")
                await asyncio.sleep(self.cob_rl_interval_s)
            except Exception as e:
                logger.error(f"COB RL trainer loop error: {e}")
                await asyncio.sleep(self.cob_rl_interval_s)

    async def _cnn_trainer_loop(self):
        """Periodic CNN training loop (currently a placeholder)."""
        while self.running:
            try:
                # Placeholder: hook to your CNN trainer if available
                await asyncio.sleep(self.cnn_interval_s)
            except Exception as e:
                logger.error(f"CNN trainer loop error: {e}")
                await asyncio.sleep(self.cnn_interval_s)

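    # Illustrative sketch only (not from the original file): if the orchestrator
    # exposed a hypothetical `cnn_trainer` object with a `train_step()` method
    # returning a loss, the placeholder above might be filled in roughly like:
    #
    #     cnn_trainer = getattr(self.orchestrator, 'cnn_trainer', None)
    #     if cnn_trainer and hasattr(cnn_trainer, 'train_step'):
    #         loss = cnn_trainer.train_step()
    #         if loss is not None:
    #             logger.debug(f"CNN train loss: {loss}")
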
    async def _reward_driven_training_loop(self):
        """Pull evaluated (prediction, reward) samples and feed them to the models."""
        while self.running:
            try:
                symbols = (
                    getattr(self.reward_system.reward_calculator, 'symbols', [])
                    if hasattr(self.reward_system, 'reward_calculator')
                    else []
                )
                for sym in symbols:
                    # Use the short horizon for fast feedback
                    samples = self.reward_system.reward_calculator.get_training_data(
                        sym, TimeFrame.SECONDS_1, max_samples=64
                    )
                    if not samples:
                        continue
                    # Currently DQN only: add samples to memory and let the replay loop train
                    rl_agent = getattr(self.orchestrator, 'rl_agent', None)
                    if rl_agent and hasattr(rl_agent, 'remember'):
                        for rec, reward in samples:
                            # Use the state vector captured at inference time when available
                            state = rec.state_vector if getattr(rec, 'state_vector', None) else []
                            if not state:
                                continue
                            # Shift the predicted direction into a non-negative action index
                            action = rec.predicted_direction + 1
                            # Reuse the state as next_state and mark the sample terminal
                            rl_agent.remember(state, action, reward, state, True)
                await asyncio.sleep(2)
            except Exception as e:
                logger.error(f"Reward-driven training loop error: {e}")
                await asyncio.sleep(5)
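
# Minimal usage sketch (added for illustration; not part of the original module).
# It wires the manager to a bare dummy orchestrator with no models attached, so
# every training loop simply idles; a real deployment would pass the production
# orchestrator and, optionally, the enhanced reward system instance.
if __name__ == "__main__":
    class _DummyOrchestrator:
        """Stand-in orchestrator exposing none of the optional agent attributes."""
        pass

    async def _demo():
        manager = UnifiedTrainingManager(_DummyOrchestrator())
        await manager.start()
        await asyncio.sleep(5)   # let the loops tick a few times
        await manager.stop()

    asyncio.run(_demo())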