"""
|
|
Unified Training Manager V2 (Refactored)
|
|
|
|
Combines UnifiedTrainingManager and EnhancedRLTrainingAdapter into a single,
|
|
comprehensive training system that handles:
|
|
- Periodic training loops (DQN, COB RL, CNN)
|
|
- Reward-driven training with EnhancedRewardCalculator
|
|
- Multi-timeframe training coordination
|
|
- Batch processing and statistics tracking
|
|
- Inference coordination (optional)
|
|
|
|
This eliminates duplication and provides a single entry point for all training.
|
|
"""
|
|
|
|
import asyncio
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union, Tuple
from dataclasses import dataclass
import numpy as np
import threading

logger = logging.getLogger(__name__)


@dataclass
class TrainingBatch:
    """Training batch for RL models with enhanced reward data"""
    model_name: str
    symbol: str
    timeframe: str
    states: List[np.ndarray]
    actions: List[int]
    rewards: List[float]
    next_states: List[np.ndarray]
    dones: List[bool]
    confidences: List[float]
    metadata: Dict[str, Any]
    batch_timestamp: datetime

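# Illustrative only: a minimal sketch of how a TrainingBatch could be assembled
# from a single experience. The field values below are made up for the example;
# nothing else in this module calls this helper.
def _example_training_batch() -> TrainingBatch:
    """Build a dummy TrainingBatch (documentation/testing sketch)."""
    state = np.zeros(100, dtype=np.float32)
    return TrainingBatch(
        model_name='dqn_agent',
        symbol='ETH/USDT',
        timeframe='1m',
        states=[state],
        actions=[1],            # 0=SELL, 1=HOLD, 2=BUY
        rewards=[0.0],
        next_states=[state],
        dones=[True],
        confidences=[0.5],
        metadata={},
        batch_timestamp=datetime.now(),
    )

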
class UnifiedTrainingManager:
    """
    Unified training controller that combines periodic and reward-driven training

    Features:
    - Periodic training loops for DQN, COB RL, CNN
    - Reward-driven training with EnhancedRewardCalculator
    - Multi-timeframe training coordination
    - Batch processing and statistics
    - Inference coordination (optional)
    """

    def __init__(
        self,
        orchestrator: Any,
        reward_system: Any = None,
        inference_coordinator: Any = None,
        # Periodic training intervals
        dqn_interval_s: int = 5,
        cob_rl_interval_s: int = 1,
        cnn_interval_s: int = 10,
        # Batch configuration
        min_dqn_experiences: int = 16,
        min_batch_size: int = 8,
        max_batch_size: int = 64,
        # Reward-driven training
        reward_training_interval_s: int = 2,
    ):
        """
        Initialize the unified training manager

        Args:
            orchestrator: Trading orchestrator that owns the models
            reward_system: Enhanced reward system (optional)
            inference_coordinator: Timeframe inference coordinator (optional)
            dqn_interval_s: DQN training interval in seconds
            cob_rl_interval_s: COB RL training interval in seconds
            cnn_interval_s: CNN training interval in seconds
            min_dqn_experiences: Minimum replay-memory size before DQN training runs
            min_batch_size: Minimum batch size for reward-driven training
            max_batch_size: Maximum batch size for reward-driven training
            reward_training_interval_s: Reward-driven training check interval in seconds
        """
        self.orchestrator = orchestrator
        self.reward_system = reward_system
        self.inference_coordinator = inference_coordinator

        # Training intervals
        self.dqn_interval_s = dqn_interval_s
        self.cob_rl_interval_s = cob_rl_interval_s
        self.cnn_interval_s = cnn_interval_s
        self.reward_training_interval_s = reward_training_interval_s

        # Batch configuration
        self.min_dqn_experiences = min_dqn_experiences
        self.min_batch_size = min_batch_size
        self.max_batch_size = max_batch_size

        # Training statistics
        self.training_stats = {
            'total_training_batches': 0,
            'successful_training_calls': 0,
            'failed_training_calls': 0,
            'last_training_time': None,
            'training_times_per_model': {},
            'average_batch_sizes': {},
            'periodic_training_counts': {
                'dqn': 0,
                'cob_rl': 0,
                'cnn': 0
            },
            'reward_driven_training_count': 0
        }

        # Thread safety
        self.lock = threading.RLock()

        # Running state
        self.running = False
        self._tasks: List[asyncio.Task] = []

        logger.info("UnifiedTrainingManager V2 initialized")

        # Register inference wrappers if coordinator available
        if self.inference_coordinator:
            self._register_inference_wrappers()

    def _register_inference_wrappers(self):
        """Register inference wrappers with coordinator"""
        try:
            # Register model inference functions
            self.inference_coordinator.register_model_inference_function(
                'dqn_agent', self._dqn_inference_wrapper
            )
            self.inference_coordinator.register_model_inference_function(
                'cob_rl', self._cob_rl_inference_wrapper
            )
            self.inference_coordinator.register_model_inference_function(
                'enhanced_cnn', self._cnn_inference_wrapper
            )
            logger.info("Inference wrappers registered with coordinator")
        except Exception as e:
            logger.warning(f"Could not register inference wrappers: {e}")

    async def start(self):
        """Start all training loops"""
        if self.running:
            logger.warning("UnifiedTrainingManager already running")
            return

        self.running = True
        logger.info("UnifiedTrainingManager started")

        # Start periodic training loops
        self._tasks.append(asyncio.create_task(self._dqn_trainer_loop()))
        self._tasks.append(asyncio.create_task(self._cob_rl_trainer_loop()))
        self._tasks.append(asyncio.create_task(self._cnn_trainer_loop()))

        # Start reward-driven training if reward system available
        if self.reward_system is not None:
            self._tasks.append(asyncio.create_task(self._reward_driven_training_loop()))
            logger.info("Reward-driven training enabled")

    async def stop(self):
        """Stop all training loops"""
        if not self.running:
            return

        self.running = False

        # Cancel all tasks
        for t in self._tasks:
            t.cancel()

        # Wait for tasks to complete
        await asyncio.gather(*self._tasks, return_exceptions=True)
        self._tasks.clear()

        logger.info("UnifiedTrainingManager stopped")

    # ========================================================================
    # PERIODIC TRAINING LOOPS
    # ========================================================================

    async def _dqn_trainer_loop(self):
        """Periodic DQN training loop"""
        while self.running:
            try:
                rl_agent = getattr(self.orchestrator, 'rl_agent', None)
                if rl_agent and hasattr(rl_agent, 'memory') and hasattr(rl_agent, 'replay'):
                    if len(rl_agent.memory) >= self.min_dqn_experiences:
                        loss = rl_agent.replay()
                        if loss is not None:
                            logger.debug(f"DQN periodic training loss: {loss:.6f}")
                            self._update_periodic_training_stats('dqn', loss)

                await asyncio.sleep(self.dqn_interval_s)
            except Exception as e:
                logger.error(f"DQN trainer loop error: {e}")
                await asyncio.sleep(self.dqn_interval_s)

    async def _cob_rl_trainer_loop(self):
        """Periodic COB RL training loop"""
        while self.running:
            try:
                cob_agent = getattr(self.orchestrator, 'cob_rl_agent', None)
                if cob_agent and hasattr(cob_agent, 'replay') and hasattr(cob_agent, 'memory'):
                    # Small fixed warm-up threshold before COB RL replay kicks in
                    if len(getattr(cob_agent, 'memory', [])) >= 8:
                        loss = cob_agent.replay()
                        if loss is not None:
                            logger.debug(f"COB RL periodic training loss: {loss:.6f}")
                            self._update_periodic_training_stats('cob_rl', loss)

                await asyncio.sleep(self.cob_rl_interval_s)
            except Exception as e:
                logger.error(f"COB RL trainer loop error: {e}")
                await asyncio.sleep(self.cob_rl_interval_s)

    async def _cnn_trainer_loop(self):
        """Periodic CNN training loop"""
        while self.running:
            try:
                # Hook to CNN trainer if available
                cnn_model = getattr(self.orchestrator, 'cnn_model', None)
                if cnn_model and hasattr(cnn_model, 'train_step'):
                    # CNN training would go here
                    pass

                await asyncio.sleep(self.cnn_interval_s)
            except Exception as e:
                logger.error(f"CNN trainer loop error: {e}")
                await asyncio.sleep(self.cnn_interval_s)

    # ========================================================================
    # REWARD-DRIVEN TRAINING
    # ========================================================================

    async def _reward_driven_training_loop(self):
        """Reward-driven training loop using EnhancedRewardCalculator"""
        while self.running:
            try:
                # Get reward calculator
                reward_calculator = getattr(self.reward_system, 'reward_calculator', None)
                if not reward_calculator:
                    await asyncio.sleep(self.reward_training_interval_s)
                    continue

                # Get symbols to train on
                symbols = getattr(reward_calculator, 'symbols', [])

                # Import TimeFrame enum
                try:
                    from core.enhanced_reward_calculator import TimeFrame
                    timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
                                  TimeFrame.HOURS_1, TimeFrame.DAYS_1]
                except ImportError:
                    timeframes = ['1s', '1m', '1h', '1d']

                # Process each symbol and timeframe
                for symbol in symbols:
                    for timeframe in timeframes:
                        # Get training data
                        training_data = reward_calculator.get_training_data(
                            symbol, timeframe, self.max_batch_size
                        )

                        if len(training_data) >= self.min_batch_size:
                            await self._process_reward_training_batch(
                                symbol, timeframe, training_data
                            )

                await asyncio.sleep(self.reward_training_interval_s)

            except Exception as e:
                logger.error(f"Reward-driven training loop error: {e}")
                await asyncio.sleep(5)

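    # Assumed data contract (not defined in this file): the reward calculator's
    # get_training_data() is expected to return a list of (prediction_record,
    # reward) tuples, where each record exposes at least `model_name`,
    # `predicted_direction` (-1, 0 or 1) and `state_vector`. A minimal stand-in
    # for local experimentation might look like:
    #
    #     from types import SimpleNamespace
    #     record = SimpleNamespace(model_name='dqn_agent',
    #                              predicted_direction=1,
    #                              state_vector=[0.0] * 100)
    #     training_data = [(record, 0.25)] * 8
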
    async def _process_reward_training_batch(self, symbol: str, timeframe: Any,
                                             training_data: List[Tuple[Any, float]]):
        """Process reward-driven training batch"""
        try:
            # Group by model
            model_batches = {}

            for prediction_record, reward in training_data:
                model_name = getattr(prediction_record, 'model_name', 'unknown')
                if model_name not in model_batches:
                    model_batches[model_name] = []
                model_batches[model_name].append((prediction_record, reward))

            # Train each model
            for model_name, model_data in model_batches.items():
                if len(model_data) >= self.min_batch_size:
                    await self._train_model_with_rewards(
                        model_name, symbol, timeframe, model_data
                    )

        except Exception as e:
            logger.error(f"Error processing reward training batch: {e}")

    async def _train_model_with_rewards(self, model_name: str, symbol: str,
                                        timeframe: Any, training_data: List[Tuple[Any, float]]):
        """Train model with reward-evaluated data"""
        try:
            training_start = time.time()

            # Route to appropriate model
            if 'dqn' in model_name.lower():
                success = await self._train_dqn_with_rewards(training_data)
            elif 'cob' in model_name.lower():
                success = await self._train_cob_rl_with_rewards(training_data)
            elif 'cnn' in model_name.lower():
                success = await self._train_cnn_with_rewards(training_data)
            else:
                logger.warning(f"Unknown model type: {model_name}")
                return

            training_time = time.time() - training_start

            if success:
                with self.lock:
                    self.training_stats['reward_driven_training_count'] += 1
                logger.info(f"Reward-driven training: {model_name} on {symbol} "
                            f"with {len(training_data)} samples in {training_time:.3f}s")

        except Exception as e:
            logger.error(f"Error in reward-driven training for {model_name}: {e}")

    async def _train_dqn_with_rewards(self, training_data: List[Tuple[Any, float]]) -> bool:
        """Train DQN with reward-evaluated data"""
        try:
            rl_agent = getattr(self.orchestrator, 'rl_agent', None)
            if not rl_agent or not hasattr(rl_agent, 'remember'):
                return False

            # Add experiences to memory
            for prediction_record, reward in training_data:
                # Get state vector from prediction record (may be a list or ndarray)
                state = getattr(prediction_record, 'state_vector', None)
                if state is None or len(state) == 0:
                    continue

                # Convert direction to action
                direction = getattr(prediction_record, 'predicted_direction', 0)
                action = direction + 1  # Convert -1,0,1 to 0,1,2

                # Add to memory; the state is reused as next_state and the
                # transition is marked terminal for single-step credit assignment
                rl_agent.remember(state, action, reward, state, True)

            return True

        except Exception as e:
            logger.error(f"Error training DQN with rewards: {e}")
            return False

    async def _train_cob_rl_with_rewards(self, training_data: List[Tuple[Any, float]]) -> bool:
        """Train COB RL with reward-evaluated data"""
        try:
            cob_agent = getattr(self.orchestrator, 'cob_rl_agent', None)
            if not cob_agent or not hasattr(cob_agent, 'remember'):
                return False

            # Same experience-replay scheme as the DQN training above
            for prediction_record, reward in training_data:
                state = getattr(prediction_record, 'state_vector', None)
                if state is None or len(state) == 0:
                    continue

                direction = getattr(prediction_record, 'predicted_direction', 0)
                action = direction + 1

                cob_agent.remember(state, action, reward, state, True)

            return True

        except Exception as e:
            logger.error(f"Error training COB RL with rewards: {e}")
            return False

    async def _train_cnn_with_rewards(self, training_data: List[Tuple[Any, float]]) -> bool:
        """Train CNN with reward-evaluated data"""
        try:
            # CNN training with rewards would go here; this depends on the CNN's
            # training interface. Placeholder: report success without training.
            return True

        except Exception as e:
            logger.error(f"Error training CNN with rewards: {e}")
            return False

    # ========================================================================
    # INFERENCE WRAPPERS (Optional - for TimeframeInferenceCoordinator)
    # ========================================================================

    async def _dqn_inference_wrapper(self, context: Any) -> Optional[Dict[str, Any]]:
        """Wrapper for DQN model inference"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
                # Get base data
                base_data = await self._get_base_data(context.symbol)
                if base_data is None:
                    return None

                # Convert to state
                state = self._convert_to_dqn_state(base_data, context)

                # Run prediction
                if hasattr(self.orchestrator.rl_agent, 'act'):
                    action_idx = self.orchestrator.rl_agent.act(state)
                    confidence = getattr(self.orchestrator.rl_agent, 'last_confidence', 0.5)

                    action_names = ['SELL', 'HOLD', 'BUY']
                    direction = action_idx - 1  # Map 0,1,2 back to -1,0,1

                    current_price = self._safe_get_current_price(context.symbol)

                    return {
                        'predicted_price': current_price,
                        'current_price': current_price,
                        'direction': direction,
                        'confidence': float(confidence),
                        'action': action_names[action_idx],
                        'model_state': (state.tolist() if hasattr(state, 'tolist') else state),
                        'context': context
                    }
        except Exception as e:
            logger.error(f"Error in DQN inference wrapper: {e}")

        return None

    async def _cob_rl_inference_wrapper(self, context: Any) -> Optional[Dict[str, Any]]:
        """Wrapper for COB RL model inference"""
        # Implementation similar to EnhancedRLTrainingAdapter
        return None

    async def _cnn_inference_wrapper(self, context: Any) -> Optional[Dict[str, Any]]:
        """Wrapper for CNN model inference"""
        # Implementation similar to EnhancedRLTrainingAdapter
        return None

    # ========================================================================
    # HELPER METHODS
    # ========================================================================

    async def _get_base_data(self, symbol: str) -> Optional[Any]:
        """Get base data for a symbol"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, '_build_base_data'):
                return await self.orchestrator._build_base_data(symbol)
        except Exception as e:
            logger.debug(f"Error getting base data: {e}")
        return None

    def _safe_get_current_price(self, symbol: str) -> float:
        """Get current price safely"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
                price = self.orchestrator.data_provider.get_current_price(symbol)
                return float(price) if price is not None else 0.0
        except Exception as e:
            logger.debug(f"Error getting current price: {e}")
        return 0.0

    def _convert_to_dqn_state(self, base_data: Any, context: Any) -> np.ndarray:
        """Convert base data to DQN state"""
        try:
            feature_vector = base_data.get_feature_vector() if hasattr(base_data, 'get_feature_vector') else []
            if feature_vector is not None and len(feature_vector) > 0:
                return np.array(feature_vector, dtype=np.float32)
            # Fall back to a zero vector of the default state size
            return np.zeros(100, dtype=np.float32)
        except Exception as e:
            logger.error(f"Error converting to DQN state: {e}")
            return np.zeros(100, dtype=np.float32)

    def _update_periodic_training_stats(self, model_type: str, loss: float):
        """Update periodic training statistics (loss is logged by the callers)"""
        with self.lock:
            self.training_stats['periodic_training_counts'][model_type] += 1
            self.training_stats['last_training_time'] = datetime.now().isoformat()

    def get_training_statistics(self) -> Dict[str, Any]:
        """Get a snapshot of the training statistics (shallow copy)"""
        with self.lock:
            return self.training_stats.copy()
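

# Minimal usage sketch (assumption: a fully wired orchestrator is constructed
# elsewhere in the application). Running this module directly only exercises the
# start/stop lifecycle; the bare object() below has no models attached, so the
# training loops simply idle and then shut down cleanly.
if __name__ == "__main__":
    async def _demo() -> None:
        manager = UnifiedTrainingManager(orchestrator=object())
        await manager.start()
        await asyncio.sleep(3)
        await manager.stop()
        print(manager.get_training_statistics())

    asyncio.run(_demo())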