merge training system

Dobromir Popov
2025-08-23 16:27:05 +03:00
parent 81749ee18e
commit f86457fc38
8 changed files with 166 additions and 8 deletions

View File

@@ -78,6 +78,11 @@ class StandardizedCNN(nn.Module):
         # Device management
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.to(self.device)
+        try:
+            import torch.backends.cudnn as cudnn
+            cudnn.benchmark = True
+        except Exception:
+            pass
 
         logger.info(f"StandardizedCNN '{model_name}' initialized")
         logger.info(f"Expected feature dimension: {self.expected_feature_dim}")

View File

@@ -46,6 +46,8 @@ class PredictionRecord:
     confidence: float
     current_price: float
     model_name: str
+    # Optional state vector used for prediction/training (standardized feature/state)
+    state_vector: Optional[list] = None
 
     # Outcome fields (set when outcome is determined)
     actual_price: Optional[float] = None
@@ -144,7 +146,8 @@ class EnhancedRewardCalculator:
                        predicted_direction: int,
                        confidence: float,
                        current_price: float,
-                       model_name: str) -> str:
+                       model_name: str,
+                       state_vector: Optional[list] = None) -> str:
         """
         Add a new prediction to track
@@ -169,7 +172,8 @@ class EnhancedRewardCalculator:
             predicted_direction=predicted_direction,
             confidence=confidence,
             current_price=current_price,
-            model_name=model_name
+            model_name=model_name,
+            state_vector=state_vector
         )
 
         # If predicted_return provided, prefer computing implied predicted_price

View File

@@ -20,6 +20,7 @@ from datetime import datetime
 from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame
 from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator
 from core.enhanced_rl_training_adapter import EnhancedRLTrainingAdapter
+from core.unified_training_manager import UnifiedTrainingManager
 
 logger = logging.getLogger(__name__)
@@ -60,6 +61,12 @@ class EnhancedRewardSystemIntegration:
             training_system=getattr(orchestrator, 'enhanced_training_system', None)
         )
 
+        # Unified Training Manager (always available)
+        self.unified_training = UnifiedTrainingManager(
+            orchestrator=orchestrator,
+            reward_system=self,
+        )
+
         # Integration state
         self.is_running = False
         self.integration_stats = {
@@ -83,6 +90,7 @@ class EnhancedRewardSystemIntegration:
             # Start core components
             await self.inference_coordinator.start_coordination()
             await self.training_adapter.start_training_loop()
+            await self.unified_training.start()
 
             # Start price monitoring
             asyncio.create_task(self._price_monitoring_loop())
@@ -107,6 +115,7 @@ class EnhancedRewardSystemIntegration:
             # Stop components
             await self.inference_coordinator.stop_coordination()
             await self.training_adapter.stop_training_loop()
+            await self.unified_training.stop()
 
             self.is_running = False

View File

@@ -6894,6 +6894,14 @@ class TradingOrchestrator:
         try:
             if not self.training_enabled or not self.enhanced_training_system:
                 logger.warning("Enhanced training system not available")
+                # Still start enhanced reward system + timeframe coordinator unconditionally
+                try:
+                    from core.enhanced_reward_system_integration import start_enhanced_rewards_for_orchestrator
+                    import asyncio as _asyncio
+                    _asyncio.create_task(start_enhanced_rewards_for_orchestrator(self, symbols=[self.symbol] + self.ref_symbols))
+                    logger.info("Enhanced reward system started (without enhanced training)")
+                except Exception as e:
+                    logger.error(f"Error starting enhanced reward system: {e}")
                 return False
 
             if hasattr(self.enhanced_training_system, "start_training"):

View File

@@ -317,7 +317,9 @@ class TimeframeInferenceCoordinator:
                 predicted_direction=prediction.get('direction', 0),
                 confidence=prediction.get('confidence', 0.0),
                 current_price=prediction.get('current_price', 0.0),
-                model_name=model_name
+                model_name=model_name,
+                predicted_return=prediction.get('predicted_return'),
+                state_vector=prediction.get('model_state') or prediction.get('model_features')
             )
 
             logger.debug(f"Added prediction {prediction_id} from {model_name} "

View File

@@ -0,0 +1,130 @@
"""
Unified Training Manager
Combines the previous built-in (normal) training and the EnhancedRealtimeTrainingSystem ideas
into a single orchestrator-agnostic manager. Keeps orchestrator lean by moving training logic here.
Key responsibilities:
- Subscribe to model predictions/outcomes and perform online updates (DQN/COB RL/CNN)
- Schedule periodic training (intervals) and replay-based training
- Integrate with Enhanced Reward System for evaluated rewards
- Work regardless of enhanced system availability
"""
import asyncio
import logging
import time
from typing import Any, Dict, List, Optional, Tuple
from core.enhanced_reward_calculator import TimeFrame
logger = logging.getLogger(__name__)
class UnifiedTrainingManager:
"""Unified training controller decoupled from the orchestrator."""
def __init__(
self,
orchestrator: Any,
reward_system: Any = None,
dqn_interval_s: int = 5,
cob_rl_interval_s: int = 1,
cnn_interval_s: int = 10,
min_dqn_experiences: int = 16,
):
self.orchestrator = orchestrator
self.reward_system = reward_system
self.dqn_interval_s = dqn_interval_s
self.cob_rl_interval_s = cob_rl_interval_s
self.cnn_interval_s = cnn_interval_s
self.min_dqn_experiences = min_dqn_experiences
self.running = False
self._tasks: List[asyncio.Task] = []
async def start(self):
if self.running:
logger.warning("UnifiedTrainingManager already running")
return
self.running = True
logger.info("UnifiedTrainingManager started")
# Periodic trainers
self._tasks.append(asyncio.create_task(self._dqn_trainer_loop()))
self._tasks.append(asyncio.create_task(self._cob_rl_trainer_loop()))
self._tasks.append(asyncio.create_task(self._cnn_trainer_loop()))
# Reward-driven trainer
if self.reward_system is not None:
self._tasks.append(asyncio.create_task(self._reward_driven_training_loop()))
async def stop(self):
self.running = False
for t in self._tasks:
t.cancel()
await asyncio.gather(*self._tasks, return_exceptions=True)
self._tasks.clear()
logger.info("UnifiedTrainingManager stopped")
async def _dqn_trainer_loop(self):
while self.running:
try:
rl_agent = getattr(self.orchestrator, 'rl_agent', None)
if rl_agent and hasattr(rl_agent, 'memory'):
if len(rl_agent.memory) >= self.min_dqn_experiences and hasattr(rl_agent, 'replay'):
loss = rl_agent.replay()
if loss is not None:
logger.debug(f"DQN replay loss: {loss}")
await asyncio.sleep(self.dqn_interval_s)
except Exception as e:
logger.error(f"DQN trainer loop error: {e}")
await asyncio.sleep(self.dqn_interval_s)
async def _cob_rl_trainer_loop(self):
while self.running:
try:
cob_agent = getattr(self.orchestrator, 'cob_rl_agent', None)
if cob_agent and hasattr(cob_agent, 'replay') and hasattr(cob_agent, 'memory'):
if len(getattr(cob_agent, 'memory', [])) >= 8:
loss = cob_agent.replay()
if loss is not None:
logger.debug(f"COB RL replay loss: {loss}")
await asyncio.sleep(self.cob_rl_interval_s)
except Exception as e:
logger.error(f"COB RL trainer loop error: {e}")
await asyncio.sleep(self.cob_rl_interval_s)
async def _cnn_trainer_loop(self):
while self.running:
try:
# Placeholder: hook to your CNN trainer if available
await asyncio.sleep(self.cnn_interval_s)
except Exception as e:
logger.error(f"CNN trainer loop error: {e}")
await asyncio.sleep(self.cnn_interval_s)
async def _reward_driven_training_loop(self):
while self.running:
try:
# Pull evaluated samples and feed to respective models
symbols = getattr(self.reward_system.reward_calculator, 'symbols', []) if hasattr(self.reward_system, 'reward_calculator') else []
for sym in symbols:
# Use short horizon for fast feedback
samples = self.reward_system.reward_calculator.get_training_data(sym, TimeFrame.SECONDS_1, max_samples=64)
if not samples:
continue
# Currently DQN batch: add to memory and let replay loop train
rl_agent = getattr(self.orchestrator, 'rl_agent', None)
if rl_agent and hasattr(rl_agent, 'remember'):
for rec, reward in samples:
# Use state vector captured at inference time when available
state = rec.state_vector if getattr(rec, 'state_vector', None) else []
if not state:
continue
action = rec.predicted_direction + 1
rl_agent.remember(state, action, reward, state, True)
await asyncio.sleep(2)
except Exception as e:
logger.error(f"Reward-driven training loop error: {e}")
await asyncio.sleep(5)
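
For reference, a minimal usage sketch of the new manager (not part of this commit): it wires UnifiedTrainingManager to a stand-in orchestrator and runs the start/stop lifecycle. DummyOrchestrator and DummyAgent are hypothetical placeholders for whatever the real orchestrator exposes as rl_agent and cob_rl_agent, and the import assumes the repository's core package is importable.

import asyncio

from core.unified_training_manager import UnifiedTrainingManager  # assumes repo root on PYTHONPATH


class DummyAgent:
    """Hypothetical stand-in for orchestrator.rl_agent with the attributes the manager probes."""

    def __init__(self):
        self.memory = []

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def replay(self):
        return 0.0  # pretend a training step ran; the manager only logs the returned loss


class DummyOrchestrator:
    """Hypothetical orchestrator exposing only the attributes UnifiedTrainingManager looks up."""

    def __init__(self):
        self.rl_agent = DummyAgent()
        self.cob_rl_agent = None  # COB RL loop idles when no agent is present


async def main():
    manager = UnifiedTrainingManager(
        orchestrator=DummyOrchestrator(),
        min_dqn_experiences=4,  # lower threshold so replay triggers with few samples
    )
    await manager.start()      # spawns DQN/COB RL/CNN trainer loops as asyncio tasks
    await asyncio.sleep(12)    # let a couple of DQN intervals (5 s each) elapse
    await manager.stop()       # cancels and awaits all trainer tasks


if __name__ == "__main__":
    asyncio.run(main())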

View File

@@ -162,8 +162,8 @@ training:
   # RL specific training
   rl_training_interval: 300  # Train RL every 5 minutes (was 1 hour)
-  min_experiences: 50  # Reduced from 100 for faster learning
-  training_steps_per_cycle: 20  # Increased from 10 for more learning
+  min_experiences: 16  # Lowered to trigger replay sooner in cold-start
+  training_steps_per_cycle: 32  # More steps per cycle to use GPU effectively
 
   model_type: "optimized_short_term"
   use_realtime: true

View File

@@ -8149,12 +8149,12 @@ class CleanTradingDashboard:
                     'price_at_prediction': self._get_current_price(symbol)
                 }
 
-                # Sleep for 10 seconds (0.1Hz prediction rate for cold start)
-                time.sleep(10.0)
+                # Sleep for 2 seconds to improve GPU utilization and responsiveness
+                time.sleep(2.0)
             except Exception as e:
                 logger.error(f"Error in CNN prediction worker: {e}")
-                time.sleep(10.0)  # Wait same interval on error
+                time.sleep(2.0)  # Wait same interval on error
 
         # Start the worker thread
         import threading