merge training system
@@ -78,6 +78,11 @@ class StandardizedCNN(nn.Module):
         # Device management
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.to(self.device)
+        try:
+            import torch.backends.cudnn as cudnn
+            cudnn.benchmark = True
+        except Exception:
+            pass
 
         logger.info(f"StandardizedCNN '{model_name}' initialized")
         logger.info(f"Expected feature dimension: {self.expected_feature_dim}")
@@ -46,6 +46,8 @@ class PredictionRecord:
     confidence: float
     current_price: float
     model_name: str
+    # Optional state vector used for prediction/training (standardized feature/state)
+    state_vector: Optional[list] = None
 
     # Outcome fields (set when outcome is determined)
     actual_price: Optional[float] = None
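
For quick reference, a self-contained sketch of how the record carries the new field (only the fields visible in this hunk are shown; the real PredictionRecord defines more). Because the field defaults to None, existing call sites keep working unchanged:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class PredictionRecordSketch:
        # Fields visible in this hunk; the full PredictionRecord has additional ones
        confidence: float
        current_price: float
        model_name: str
        # Standardized feature/state captured at inference, replayed later for training
        state_vector: Optional[list] = None
        actual_price: Optional[float] = None  # set when the outcome is determined

    record = PredictionRecordSketch(confidence=0.62, current_price=2405.5, model_name="dqn_agent")
    assert record.state_vector is None  # older call sites are unaffected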
@@ -144,7 +146,8 @@ class EnhancedRewardCalculator:
                        predicted_direction: int,
                        confidence: float,
                        current_price: float,
-                       model_name: str) -> str:
+                       model_name: str,
+                       state_vector: Optional[list] = None) -> str:
         """
         Add a new prediction to track
 
@@ -169,7 +172,8 @@ class EnhancedRewardCalculator:
             predicted_direction=predicted_direction,
             confidence=confidence,
             current_price=current_price,
-            model_name=model_name
+            model_name=model_name,
+            state_vector=state_vector
         )
 
         # If predicted_return provided, prefer computing implied predicted_price
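
An illustrative call with the extended signature (a sketch, not the project's code): model_name, state_vector and predicted_return are confirmed by this diff, while the symbol/timeframe keywords and the surrounding objects are assumptions:

    from core.enhanced_reward_calculator import TimeFrame

    def record_prediction(reward_calculator, feature_vector):
        """Hypothetical call-site sketch; see the hedges above."""
        return reward_calculator.add_prediction(
            symbol="ETH/USDT",                  # assumed keyword
            timeframe=TimeFrame.SECONDS_1,      # assumed keyword
            predicted_direction=1,              # encoding inferred from the `direction + 1` action mapping
            confidence=0.62,
            current_price=2405.5,
            predicted_return=0.004,             # optional; used to derive an implied predicted_price
            model_name="dqn_agent",
            state_vector=list(feature_vector),  # standardized state captured at inference
        )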
@@ -20,6 +20,7 @@ from datetime import datetime
 from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame
 from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator
 from core.enhanced_rl_training_adapter import EnhancedRLTrainingAdapter
+from core.unified_training_manager import UnifiedTrainingManager
 
 logger = logging.getLogger(__name__)
 
@@ -60,6 +61,12 @@ class EnhancedRewardSystemIntegration:
             training_system=getattr(orchestrator, 'enhanced_training_system', None)
         )
 
+        # Unified Training Manager (always available)
+        self.unified_training = UnifiedTrainingManager(
+            orchestrator=orchestrator,
+            reward_system=self,
+        )
+
         # Integration state
         self.is_running = False
         self.integration_stats = {
@@ -83,6 +90,7 @@ class EnhancedRewardSystemIntegration:
             # Start core components
             await self.inference_coordinator.start_coordination()
             await self.training_adapter.start_training_loop()
+            await self.unified_training.start()
 
             # Start price monitoring
             asyncio.create_task(self._price_monitoring_loop())
@@ -107,6 +115,7 @@ class EnhancedRewardSystemIntegration:
             # Stop components
             await self.inference_coordinator.stop_coordination()
             await self.training_adapter.stop_training_loop()
+            await self.unified_training.stop()
 
             self.is_running = False
 
@@ -6894,6 +6894,14 @@ class TradingOrchestrator:
         try:
             if not self.training_enabled or not self.enhanced_training_system:
                 logger.warning("Enhanced training system not available")
+                # Still start enhanced reward system + timeframe coordinator unconditionally
+                try:
+                    from core.enhanced_reward_system_integration import start_enhanced_rewards_for_orchestrator
+                    import asyncio as _asyncio
+                    _asyncio.create_task(start_enhanced_rewards_for_orchestrator(self, symbols=[self.symbol] + self.ref_symbols))
+                    logger.info("Enhanced reward system started (without enhanced training)")
+                except Exception as e:
+                    logger.error(f"Error starting enhanced reward system: {e}")
                 return False
 
             if hasattr(self.enhanced_training_system, "start_training"):
@@ -317,7 +317,9 @@ class TimeframeInferenceCoordinator:
                 predicted_direction=prediction.get('direction', 0),
                 confidence=prediction.get('confidence', 0.0),
                 current_price=prediction.get('current_price', 0.0),
-                model_name=model_name
+                model_name=model_name,
+                predicted_return=prediction.get('predicted_return'),
+                state_vector=prediction.get('model_state') or prediction.get('model_features')
             )
 
             logger.debug(f"Added prediction {prediction_id} from {model_name} "
new file: core/unified_training_manager.py (130 lines added)
@@ -0,0 +1,130 @@
+"""
+Unified Training Manager
+
+Combines the previous built-in (normal) training and the EnhancedRealtimeTrainingSystem ideas
+into a single orchestrator-agnostic manager. Keeps orchestrator lean by moving training logic here.
+
+Key responsibilities:
+- Subscribe to model predictions/outcomes and perform online updates (DQN/COB RL/CNN)
+- Schedule periodic training (intervals) and replay-based training
+- Integrate with Enhanced Reward System for evaluated rewards
+- Work regardless of enhanced system availability
+"""
+
+import asyncio
+import logging
+import time
+from typing import Any, Dict, List, Optional, Tuple
+
+from core.enhanced_reward_calculator import TimeFrame
+
+logger = logging.getLogger(__name__)
+
+
+class UnifiedTrainingManager:
+    """Unified training controller decoupled from the orchestrator."""
+
+    def __init__(
+        self,
+        orchestrator: Any,
+        reward_system: Any = None,
+        dqn_interval_s: int = 5,
+        cob_rl_interval_s: int = 1,
+        cnn_interval_s: int = 10,
+        min_dqn_experiences: int = 16,
+    ):
+        self.orchestrator = orchestrator
+        self.reward_system = reward_system
+        self.dqn_interval_s = dqn_interval_s
+        self.cob_rl_interval_s = cob_rl_interval_s
+        self.cnn_interval_s = cnn_interval_s
+        self.min_dqn_experiences = min_dqn_experiences
+
+        self.running = False
+        self._tasks: List[asyncio.Task] = []
+
+    async def start(self):
+        if self.running:
+            logger.warning("UnifiedTrainingManager already running")
+            return
+        self.running = True
+        logger.info("UnifiedTrainingManager started")
+        # Periodic trainers
+        self._tasks.append(asyncio.create_task(self._dqn_trainer_loop()))
+        self._tasks.append(asyncio.create_task(self._cob_rl_trainer_loop()))
+        self._tasks.append(asyncio.create_task(self._cnn_trainer_loop()))
+        # Reward-driven trainer
+        if self.reward_system is not None:
+            self._tasks.append(asyncio.create_task(self._reward_driven_training_loop()))
+
+    async def stop(self):
+        self.running = False
+        for t in self._tasks:
+            t.cancel()
+        await asyncio.gather(*self._tasks, return_exceptions=True)
+        self._tasks.clear()
+        logger.info("UnifiedTrainingManager stopped")
+
+    async def _dqn_trainer_loop(self):
+        while self.running:
+            try:
+                rl_agent = getattr(self.orchestrator, 'rl_agent', None)
+                if rl_agent and hasattr(rl_agent, 'memory'):
+                    if len(rl_agent.memory) >= self.min_dqn_experiences and hasattr(rl_agent, 'replay'):
+                        loss = rl_agent.replay()
+                        if loss is not None:
+                            logger.debug(f"DQN replay loss: {loss}")
+                await asyncio.sleep(self.dqn_interval_s)
+            except Exception as e:
+                logger.error(f"DQN trainer loop error: {e}")
+                await asyncio.sleep(self.dqn_interval_s)
+
+    async def _cob_rl_trainer_loop(self):
+        while self.running:
+            try:
+                cob_agent = getattr(self.orchestrator, 'cob_rl_agent', None)
+                if cob_agent and hasattr(cob_agent, 'replay') and hasattr(cob_agent, 'memory'):
+                    if len(getattr(cob_agent, 'memory', [])) >= 8:
+                        loss = cob_agent.replay()
+                        if loss is not None:
+                            logger.debug(f"COB RL replay loss: {loss}")
+                await asyncio.sleep(self.cob_rl_interval_s)
+            except Exception as e:
+                logger.error(f"COB RL trainer loop error: {e}")
+                await asyncio.sleep(self.cob_rl_interval_s)
+
+    async def _cnn_trainer_loop(self):
+        while self.running:
+            try:
+                # Placeholder: hook to your CNN trainer if available
+                await asyncio.sleep(self.cnn_interval_s)
+            except Exception as e:
+                logger.error(f"CNN trainer loop error: {e}")
+                await asyncio.sleep(self.cnn_interval_s)
+
+    async def _reward_driven_training_loop(self):
+        while self.running:
+            try:
+                # Pull evaluated samples and feed to respective models
+                symbols = getattr(self.reward_system.reward_calculator, 'symbols', []) if hasattr(self.reward_system, 'reward_calculator') else []
+                for sym in symbols:
+                    # Use short horizon for fast feedback
+                    samples = self.reward_system.reward_calculator.get_training_data(sym, TimeFrame.SECONDS_1, max_samples=64)
+                    if not samples:
+                        continue
+                    # Currently DQN batch: add to memory and let replay loop train
+                    rl_agent = getattr(self.orchestrator, 'rl_agent', None)
+                    if rl_agent and hasattr(rl_agent, 'remember'):
+                        for rec, reward in samples:
+                            # Use state vector captured at inference time when available
+                            state = rec.state_vector if getattr(rec, 'state_vector', None) else []
+                            if not state:
+                                continue
+                            action = rec.predicted_direction + 1
+                            rl_agent.remember(state, action, reward, state, True)
+                await asyncio.sleep(2)
+            except Exception as e:
+                logger.error(f"Reward-driven training loop error: {e}")
+                await asyncio.sleep(5)
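
A minimal usage sketch for the new manager (the stub orchestrator is illustrative; any object exposing rl_agent / cob_rl_agent attributes with memory/replay/remember methods fits the loops above):

    import asyncio

    from core.unified_training_manager import UnifiedTrainingManager

    class StubOrchestrator:
        """Illustrative stand-in; the real TradingOrchestrator supplies these attributes."""
        rl_agent = None       # DQN-style agent with .memory / .replay() / .remember()
        cob_rl_agent = None   # optional COB RL agent with the same interface

    async def main():
        manager = UnifiedTrainingManager(
            orchestrator=StubOrchestrator(),
            reward_system=None,          # pass the reward system integration to enable the reward-driven loop
            min_dqn_experiences=16,      # matches the lowered config value in this commit
        )
        await manager.start()            # spawns the periodic trainer tasks
        await asyncio.sleep(30)          # ... run alongside inference ...
        await manager.stop()             # cancels and awaits all trainer tasks

    if __name__ == "__main__":
        asyncio.run(main())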
@@ -162,8 +162,8 @@ training:
 
   # RL specific training
   rl_training_interval: 300  # Train RL every 5 minutes (was 1 hour)
-  min_experiences: 50  # Reduced from 100 for faster learning
-  training_steps_per_cycle: 20  # Increased from 10 for more learning
+  min_experiences: 16  # Lowered to trigger replay sooner in cold-start
+  training_steps_per_cycle: 32  # More steps per cycle to use GPU effectively
 
   model_type: "optimized_short_term"
   use_realtime: true
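
A hypothetical glue snippet showing how these values could be read and fed into the new manager's replay trigger; the config file name and the PyYAML loader are assumptions, only the key names and values appear in the diff:

    import yaml  # assumes PyYAML is available

    with open("config.yaml") as f:          # assumed file name
        training_cfg = yaml.safe_load(f).get("training", {})

    min_experiences = int(training_cfg.get("min_experiences", 16))
    steps_per_cycle = int(training_cfg.get("training_steps_per_cycle", 32))
    print(f"DQN replay after {min_experiences} experiences, {steps_per_cycle} steps per cycle")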
@@ -8149,12 +8149,12 @@ class CleanTradingDashboard:
                         'price_at_prediction': self._get_current_price(symbol)
                     }
 
-                    # Sleep for 10 seconds (0.1Hz prediction rate for cold start)
-                    time.sleep(10.0)
+                    # Sleep for 2 seconds to improve GPU utilization and responsiveness
+                    time.sleep(2.0)
 
                 except Exception as e:
                     logger.error(f"Error in CNN prediction worker: {e}")
-                    time.sleep(10.0)  # Wait same interval on error
+                    time.sleep(2.0)  # Wait same interval on error
 
             # Start the worker thread
             import threading
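
The hunk ends just before the worker thread is launched; a typical start pattern for such a background worker looks like the sketch below (the function name and body are assumptions, only the 2-second cadence comes from the diff):

    import threading
    import time

    def cnn_prediction_worker():
        # Illustrative stand-in for the dashboard worker body shown above
        while True:
            # ... build the prediction snapshot ...
            time.sleep(2.0)  # matches the new 2-second cadence

    worker = threading.Thread(target=cnn_prediction_worker, daemon=True)
    worker.start()  # daemon thread so it never blocks process shutdown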