gogo2/core/unified_training_manager_v2.py

"""
Unified Training Manager V2 (Refactored)
Combines UnifiedTrainingManager and EnhancedRLTrainingAdapter into a single,
comprehensive training system that handles:
- Periodic training loops (DQN, COB RL, CNN)
- Reward-driven training with EnhancedRewardCalculator
- Multi-timeframe training coordination
- Batch processing and statistics tracking
- Inference coordination (optional)
This eliminates duplication and provides a single entry point for all training.
"""
import asyncio
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union, Tuple
from dataclasses import dataclass
import numpy as np
import threading
logger = logging.getLogger(__name__)
@dataclass
class TrainingBatch:
"""Training batch for RL models with enhanced reward data"""
model_name: str
symbol: str
timeframe: str
states: List[np.ndarray]
actions: List[int]
rewards: List[float]
next_states: List[np.ndarray]
dones: List[bool]
confidences: List[float]
metadata: Dict[str, Any]
batch_timestamp: datetime
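# Illustrative construction of a TrainingBatch (hypothetical values; the
# dataclass is not referenced elsewhere in this module yet):
#
#   batch = TrainingBatch(
#       model_name='dqn_agent', symbol='ETH/USDT', timeframe='1m',
#       states=[np.zeros(100, dtype=np.float32)], actions=[1], rewards=[0.0],
#       next_states=[np.zeros(100, dtype=np.float32)], dones=[True],
#       confidences=[0.5], metadata={}, batch_timestamp=datetime.now())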
class UnifiedTrainingManager:
"""
Unified training controller that combines periodic and reward-driven training
Features:
- Periodic training loops for DQN, COB RL, CNN
- Reward-driven training with EnhancedRewardCalculator
- Multi-timeframe training coordination
- Batch processing and statistics
- Inference coordination (optional)
"""
def __init__(
self,
orchestrator: Any,
reward_system: Any = None,
inference_coordinator: Any = None,
# Periodic training intervals
dqn_interval_s: int = 5,
cob_rl_interval_s: int = 1,
cnn_interval_s: int = 10,
# Batch configuration
min_dqn_experiences: int = 16,
min_batch_size: int = 8,
max_batch_size: int = 64,
# Reward-driven training
reward_training_interval_s: int = 2,
):
"""
Initialize unified training manager
Args:
orchestrator: Trading orchestrator with models
reward_system: Enhanced reward system (optional)
inference_coordinator: Timeframe inference coordinator (optional)
            dqn_interval_s: DQN training interval in seconds
            cob_rl_interval_s: COB RL training interval in seconds
            cnn_interval_s: CNN training interval in seconds
            min_dqn_experiences: Minimum experiences before DQN training
            min_batch_size: Minimum batch size for reward-driven training
            max_batch_size: Maximum batch size for reward-driven training
            reward_training_interval_s: Reward-driven training check interval in seconds
"""
self.orchestrator = orchestrator
self.reward_system = reward_system
self.inference_coordinator = inference_coordinator
# Training intervals
self.dqn_interval_s = dqn_interval_s
self.cob_rl_interval_s = cob_rl_interval_s
self.cnn_interval_s = cnn_interval_s
self.reward_training_interval_s = reward_training_interval_s
# Batch configuration
self.min_dqn_experiences = min_dqn_experiences
self.min_batch_size = min_batch_size
self.max_batch_size = max_batch_size
# Training statistics
self.training_stats = {
'total_training_batches': 0,
'successful_training_calls': 0,
'failed_training_calls': 0,
'last_training_time': None,
'training_times_per_model': {},
'average_batch_sizes': {},
'periodic_training_counts': {
'dqn': 0,
'cob_rl': 0,
'cnn': 0
},
'reward_driven_training_count': 0
}
# Thread safety
self.lock = threading.RLock()
# Running state
self.running = False
self._tasks: List[asyncio.Task] = []
logger.info("UnifiedTrainingManager V2 initialized")
# Register inference wrappers if coordinator available
if self.inference_coordinator:
self._register_inference_wrappers()
def _register_inference_wrappers(self):
"""Register inference wrappers with coordinator"""
try:
# Register model inference functions
self.inference_coordinator.register_model_inference_function(
'dqn_agent', self._dqn_inference_wrapper
)
self.inference_coordinator.register_model_inference_function(
'cob_rl', self._cob_rl_inference_wrapper
)
self.inference_coordinator.register_model_inference_function(
'enhanced_cnn', self._cnn_inference_wrapper
)
logger.info("Inference wrappers registered with coordinator")
except Exception as e:
logger.warning(f"Could not register inference wrappers: {e}")
async def start(self):
"""Start all training loops"""
if self.running:
logger.warning("UnifiedTrainingManager already running")
return
self.running = True
logger.info("UnifiedTrainingManager started")
# Start periodic training loops
self._tasks.append(asyncio.create_task(self._dqn_trainer_loop()))
self._tasks.append(asyncio.create_task(self._cob_rl_trainer_loop()))
self._tasks.append(asyncio.create_task(self._cnn_trainer_loop()))
# Start reward-driven training if reward system available
if self.reward_system is not None:
self._tasks.append(asyncio.create_task(self._reward_driven_training_loop()))
logger.info("Reward-driven training enabled")
async def stop(self):
"""Stop all training loops"""
if not self.running:
return
self.running = False
# Cancel all tasks
for t in self._tasks:
t.cancel()
# Wait for tasks to complete
await asyncio.gather(*self._tasks, return_exceptions=True)
self._tasks.clear()
logger.info("UnifiedTrainingManager stopped")
# ========================================================================
# PERIODIC TRAINING LOOPS
# ========================================================================
async def _dqn_trainer_loop(self):
"""Periodic DQN training loop"""
while self.running:
try:
rl_agent = getattr(self.orchestrator, 'rl_agent', None)
if rl_agent and hasattr(rl_agent, 'memory') and hasattr(rl_agent, 'replay'):
if len(rl_agent.memory) >= self.min_dqn_experiences:
loss = rl_agent.replay()
if loss is not None:
logger.debug(f"DQN periodic training loss: {loss:.6f}")
self._update_periodic_training_stats('dqn', loss)
await asyncio.sleep(self.dqn_interval_s)
except Exception as e:
logger.error(f"DQN trainer loop error: {e}")
await asyncio.sleep(self.dqn_interval_s)
async def _cob_rl_trainer_loop(self):
"""Periodic COB RL training loop"""
while self.running:
try:
cob_agent = getattr(self.orchestrator, 'cob_rl_agent', None)
if cob_agent and hasattr(cob_agent, 'replay') and hasattr(cob_agent, 'memory'):
if len(getattr(cob_agent, 'memory', [])) >= 8:
loss = cob_agent.replay()
if loss is not None:
logger.debug(f"COB RL periodic training loss: {loss:.6f}")
self._update_periodic_training_stats('cob_rl', loss)
await asyncio.sleep(self.cob_rl_interval_s)
except Exception as e:
logger.error(f"COB RL trainer loop error: {e}")
await asyncio.sleep(self.cob_rl_interval_s)
async def _cnn_trainer_loop(self):
"""Periodic CNN training loop"""
while self.running:
try:
# Hook to CNN trainer if available
cnn_model = getattr(self.orchestrator, 'cnn_model', None)
if cnn_model and hasattr(cnn_model, 'train_step'):
# CNN training would go here
pass
await asyncio.sleep(self.cnn_interval_s)
except Exception as e:
logger.error(f"CNN trainer loop error: {e}")
await asyncio.sleep(self.cnn_interval_s)
# ========================================================================
# REWARD-DRIVEN TRAINING
# ========================================================================
async def _reward_driven_training_loop(self):
"""Reward-driven training loop using EnhancedRewardCalculator"""
while self.running:
try:
# Get reward calculator
reward_calculator = getattr(self.reward_system, 'reward_calculator', None)
if not reward_calculator:
await asyncio.sleep(self.reward_training_interval_s)
continue
# Get symbols to train on
symbols = getattr(reward_calculator, 'symbols', [])
# Import TimeFrame enum
try:
from core.enhanced_reward_calculator import TimeFrame
timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
TimeFrame.HOURS_1, TimeFrame.DAYS_1]
except ImportError:
timeframes = ['1s', '1m', '1h', '1d']
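                # NOTE: the string fallback above assumes get_training_data
                # also accepts plain timeframe strings, not just TimeFrame members.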
# Process each symbol and timeframe
for symbol in symbols:
for timeframe in timeframes:
# Get training data
training_data = reward_calculator.get_training_data(
symbol, timeframe, self.max_batch_size
)
if len(training_data) >= self.min_batch_size:
await self._process_reward_training_batch(
symbol, timeframe, training_data
)
await asyncio.sleep(self.reward_training_interval_s)
except Exception as e:
logger.error(f"Reward-driven training loop error: {e}")
await asyncio.sleep(5)
async def _process_reward_training_batch(self, symbol: str, timeframe: Any,
training_data: List[Tuple[Any, float]]):
"""Process reward-driven training batch"""
try:
# Group by model
model_batches = {}
for prediction_record, reward in training_data:
model_name = getattr(prediction_record, 'model_name', 'unknown')
if model_name not in model_batches:
model_batches[model_name] = []
model_batches[model_name].append((prediction_record, reward))
# Train each model
for model_name, model_data in model_batches.items():
if len(model_data) >= self.min_batch_size:
await self._train_model_with_rewards(
model_name, symbol, timeframe, model_data
)
except Exception as e:
logger.error(f"Error processing reward training batch: {e}")
async def _train_model_with_rewards(self, model_name: str, symbol: str,
timeframe: Any, training_data: List[Tuple[Any, float]]):
"""Train model with reward-evaluated data"""
try:
training_start = time.time()
# Route to appropriate model
if 'dqn' in model_name.lower():
success = await self._train_dqn_with_rewards(training_data)
elif 'cob' in model_name.lower():
success = await self._train_cob_rl_with_rewards(training_data)
elif 'cnn' in model_name.lower():
success = await self._train_cnn_with_rewards(training_data)
else:
logger.warning(f"Unknown model type: {model_name}")
return
training_time = time.time() - training_start
if success:
with self.lock:
self.training_stats['reward_driven_training_count'] += 1
logger.info(f"Reward-driven training: {model_name} on {symbol} "
f"with {len(training_data)} samples in {training_time:.3f}s")
except Exception as e:
logger.error(f"Error in reward-driven training for {model_name}: {e}")
async def _train_dqn_with_rewards(self, training_data: List[Tuple[Any, float]]) -> bool:
"""Train DQN with reward-evaluated data"""
try:
rl_agent = getattr(self.orchestrator, 'rl_agent', None)
if not rl_agent or not hasattr(rl_agent, 'remember'):
return False
# Add experiences to memory
for prediction_record, reward in training_data:
# Get state vector from prediction record
state = getattr(prediction_record, 'state_vector', None)
                if state is None or len(state) == 0:
                    continue
# Convert direction to action
direction = getattr(prediction_record, 'predicted_direction', 0)
action = direction + 1 # Convert -1,0,1 to 0,1,2
# Add to memory
rl_agent.remember(state, action, reward, state, True)
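                # Design note: the same state is reused as next_state with done=True,
                # i.e. each reward-evaluated prediction is stored as a terminal
                # single-step experience.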
return True
except Exception as e:
logger.error(f"Error training DQN with rewards: {e}")
return False
async def _train_cob_rl_with_rewards(self, training_data: List[Tuple[Any, float]]) -> bool:
"""Train COB RL with reward-evaluated data"""
try:
cob_agent = getattr(self.orchestrator, 'cob_rl_agent', None)
if not cob_agent or not hasattr(cob_agent, 'remember'):
return False
# Similar to DQN training
for prediction_record, reward in training_data:
state = getattr(prediction_record, 'state_vector', None)
                if state is None or len(state) == 0:
                    continue
direction = getattr(prediction_record, 'predicted_direction', 0)
action = direction + 1
cob_agent.remember(state, action, reward, state, True)
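                # As in the DQN path above, each sample is stored as a terminal
                # single-step experience (next_state == state, done=True).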
return True
except Exception as e:
logger.error(f"Error training COB RL with rewards: {e}")
return False
async def _train_cnn_with_rewards(self, training_data: List[Tuple[Any, float]]) -> bool:
"""Train CNN with reward-evaluated data"""
try:
            # CNN training with rewards is not implemented yet; it depends on
            # the CNN's training interface. Return False so this stub is not
            # counted as a successful training run.
            return False
except Exception as e:
logger.error(f"Error training CNN with rewards: {e}")
return False
# ========================================================================
# INFERENCE WRAPPERS (Optional - for TimeframeInferenceCoordinator)
# ========================================================================
async def _dqn_inference_wrapper(self, context: Any) -> Optional[Dict[str, Any]]:
"""Wrapper for DQN model inference"""
try:
if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
# Get base data
base_data = await self._get_base_data(context.symbol)
if base_data is None:
return None
# Convert to state
state = self._convert_to_dqn_state(base_data, context)
# Run prediction
if hasattr(self.orchestrator.rl_agent, 'act'):
action_idx = self.orchestrator.rl_agent.act(state)
confidence = getattr(self.orchestrator.rl_agent, 'last_confidence', 0.5)
action_names = ['SELL', 'HOLD', 'BUY']
direction = action_idx - 1
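                    # action_idx 0/1/2 maps back to direction -1/0/+1 (SELL/HOLD/BUY)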
current_price = self._safe_get_current_price(context.symbol)
return {
'predicted_price': current_price,
'current_price': current_price,
'direction': direction,
'confidence': float(confidence),
'action': action_names[action_idx],
'model_state': (state.tolist() if hasattr(state, 'tolist') else state),
'context': context
}
except Exception as e:
logger.error(f"Error in DQN inference wrapper: {e}")
return None
async def _cob_rl_inference_wrapper(self, context: Any) -> Optional[Dict[str, Any]]:
"""Wrapper for COB RL model inference"""
        # Not implemented yet; intended to mirror the corresponding wrapper in EnhancedRLTrainingAdapter
return None
async def _cnn_inference_wrapper(self, context: Any) -> Optional[Dict[str, Any]]:
"""Wrapper for CNN model inference"""
        # Not implemented yet; intended to mirror the corresponding wrapper in EnhancedRLTrainingAdapter
return None
# ========================================================================
# HELPER METHODS
# ========================================================================
async def _get_base_data(self, symbol: str) -> Optional[Any]:
"""Get base data for a symbol"""
try:
if self.orchestrator and hasattr(self.orchestrator, '_build_base_data'):
return await self.orchestrator._build_base_data(symbol)
except Exception as e:
logger.debug(f"Error getting base data: {e}")
return None
def _safe_get_current_price(self, symbol: str) -> float:
"""Get current price safely"""
try:
if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
price = self.orchestrator.data_provider.get_current_price(symbol)
return float(price) if price is not None else 0.0
except Exception as e:
logger.debug(f"Error getting current price: {e}")
return 0.0
def _convert_to_dqn_state(self, base_data: Any, context: Any) -> np.ndarray:
"""Convert base data to DQN state"""
try:
feature_vector = base_data.get_feature_vector() if hasattr(base_data, 'get_feature_vector') else []
if feature_vector:
return np.array(feature_vector, dtype=np.float32)
return np.zeros(100, dtype=np.float32)
except Exception as e:
logger.error(f"Error converting to DQN state: {e}")
return np.zeros(100, dtype=np.float32)
def _update_periodic_training_stats(self, model_type: str, loss: float):
"""Update periodic training statistics"""
with self.lock:
self.training_stats['periodic_training_counts'][model_type] += 1
self.training_stats['last_training_time'] = datetime.now().isoformat()
def get_training_statistics(self) -> Dict[str, Any]:
"""Get training statistics"""
with self.lock:
return self.training_stats.copy()