dashboard and training wip
@@ -349,7 +349,8 @@ class TradingOrchestrator:
                 try:
                     self.cob_rl_agent.load_model()  # This loads the state into the model
                     from utils.checkpoint_manager import load_best_checkpoint
-                    result = load_best_checkpoint("cob_rl_model")
+                    # Use consistent model name with checkpoint manager and get_model_states
+                    result = load_best_checkpoint("cob_rl")
                     if result:
                         file_path, metadata = result
                         self.model_states['cob_rl']['initial_loss'] = getattr(metadata, 'initial_loss', None)
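The rename from "cob_rl_model" to "cob_rl" matters because the checkpoint manager resolves checkpoints by name: a lookup key that differs from the key used at save time silently finds nothing, and the orchestrator falls back to fresh weights. A self-contained toy sketch of that failure mode (the registry class below is a hypothetical stand-in, not the repo's utils.checkpoint_manager):

    # Hypothetical sketch: a name-keyed registry, showing why the lookup key
    # must match the save key exactly.
    from typing import Any, Dict, Optional, Tuple

    COB_RL_MODEL_NAME = "cob_rl"

    class CheckpointRegistry:
        """Toy stand-in for the assumed checkpoint-manager behavior."""
        def __init__(self) -> None:
            self._best: Dict[str, Tuple[str, Any]] = {}

        def save(self, name: str, path: str, metadata: Any) -> None:
            self._best[name] = (path, metadata)

        def load_best(self, name: str) -> Optional[Tuple[str, Any]]:
            # Returns None for an unknown name instead of raising, which is
            # exactly how a mismatched key turns into "no checkpoint found".
            return self._best.get(name)

    registry = CheckpointRegistry()
    registry.save(COB_RL_MODEL_NAME, "checkpoints/cob_rl/best.pt", {"initial_loss": 0.9})
    assert registry.load_best("cob_rl_model") is None   # old, mismatched key
    assert registry.load_best(COB_RL_MODEL_NAME)        # new, consistent key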
@@ -1592,13 +1593,16 @@ class TradingOrchestrator:
                 logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled")
                 self.training_enabled = False
                 return
 
-            # Initialize the enhanced training system
-            self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
+            # Initialize unified training manager
+            from utils.training_integration import get_unified_training_manager
+            self.training_manager = get_unified_training_manager(
                 orchestrator=self,
                 data_provider=self.data_provider,
-                dashboard=None  # Will be set by dashboard when available
+                dashboard=None
             )
+            self.training_manager.initialize()
+            # Keep backward-compatible attribute
+            self.enhanced_training_system = getattr(self.training_manager, 'training_system', None)
 
             logger.info("Enhanced real-time training system initialized")
             logger.info(" - Real-time model training: ENABLED")
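This hunk swaps the direct EnhancedRealtimeTrainingSystem construction for a facade obtained from utils.training_integration, while keeping self.enhanced_training_system alive as a backward-compatible alias. The commit does not include training_integration itself, so the following is only a sketch of the facade shape the orchestrator code appears to rely on (all internals are assumptions):

    # Hypothetical sketch of utils.training_integration; not part of this commit.
    from typing import Any, Optional

    class UnifiedTrainingManager:
        def __init__(self, orchestrator: Any, data_provider: Any,
                     dashboard: Any = None) -> None:
            self.orchestrator = orchestrator
            self.data_provider = data_provider
            self.dashboard = dashboard
            # Exposed so the orchestrator can keep a backward-compatible
            # self.enhanced_training_system alias via getattr().
            self.training_system: Optional[Any] = None

        def initialize(self) -> None:
            # In the real code this would construct EnhancedRealtimeTrainingSystem;
            # a placeholder keeps the sketch self-contained.
            self.training_system = object()

        def start(self) -> None:
            ts = self.training_system
            if ts is not None and hasattr(ts, "start_training"):
                ts.start_training()  # delegate to the underlying trainer

        def stop(self) -> None:
            ts = self.training_system
            if ts is not None and hasattr(ts, "stop_training"):
                ts.stop_training()

    _manager: Optional[UnifiedTrainingManager] = None

    def get_unified_training_manager(orchestrator: Any, data_provider: Any,
                                     dashboard: Any = None) -> UnifiedTrainingManager:
        """Module-level singleton accessor (assumed semantics)."""
        global _manager
        if _manager is None:
            _manager = UnifiedTrainingManager(orchestrator, data_provider, dashboard)
        return _manager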
@@ -1614,11 +1618,11 @@ class TradingOrchestrator:
     def start_enhanced_training(self):
         """Start the enhanced real-time training system"""
         try:
-            if not self.training_enabled or not self.enhanced_training_system:
+            if not self.training_enabled or not getattr(self, 'training_manager', None):
                 logger.warning("Enhanced training system not available")
                 return False
 
-            self.enhanced_training_system.start_training()
+            self.training_manager.start()
             logger.info("Enhanced real-time training started")
             return True
@@ -1629,8 +1633,8 @@ class TradingOrchestrator:
     def stop_enhanced_training(self):
         """Stop the enhanced real-time training system"""
         try:
-            if self.enhanced_training_system:
-                self.enhanced_training_system.stop_training()
+            if getattr(self, 'training_manager', None):
+                self.training_manager.stop()
                 logger.info("Enhanced real-time training stopped")
                 return True
             return False
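Both lifecycle methods now guard on getattr(self, 'training_manager', None), so they degrade gracefully when initialization bailed out before the attribute was ever assigned. A self-contained toy illustration of that idiom (class and names are illustrative only):

    class Orchestrator:
        """Toy illustration of the getattr() guard used in both methods above."""
        def __init__(self, enable_training: bool) -> None:
            if enable_training:
                self.training_manager = object()  # stand-in for the real manager
            # When training is disabled the attribute is never created at all.

        def start_enhanced_training(self) -> bool:
            # getattr() with a default avoids AttributeError when __init__
            # skipped the assignment entirely.
            if not getattr(self, 'training_manager', None):
                return False
            return True

    assert Orchestrator(True).start_enhanced_training() is True
    assert Orchestrator(False).start_enhanced_training() is False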
@@ -731,7 +731,8 @@ class RealtimeRLCOBTrader:
         with self.training_lock:
             # Check if we have enough data for training
             predictions = list(self.prediction_history[symbol])
-            if len(predictions) < 10:
+            # Train with fewer samples to kickstart learning
+            if len(predictions) < 6:
                 return
 
             # Calculate rewards for recent predictions
@@ -739,11 +740,11 @@ class RealtimeRLCOBTrader:
 
             # Filter predictions with calculated rewards
             training_predictions = [p for p in predictions if p.reward is not None]
-            if len(training_predictions) < 5:
+            if len(training_predictions) < 3:
                 return
 
             # Prepare training batch
-            batch_size = min(32, len(training_predictions))
+            batch_size = min(16, len(training_predictions))
             batch_predictions = training_predictions[-batch_size:]
 
             # Train model
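Taken together, the two RealtimeRLCOBTrader hunks lower the gate from 10 to 6 buffered predictions and from 5 to 3 reward-labelled ones, and halve the batch cap from 32 to 16, so training kicks in sooner after startup at the cost of noisier early gradients. A minimal sketch of the resulting gating logic (the Prediction shape is assumed from context):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Prediction:
        reward: Optional[float] = None

    MIN_PREDICTIONS = 6   # was 10: start training with fewer buffered samples
    MIN_REWARDED = 3      # was 5: require fewer reward-labelled samples
    MAX_BATCH = 16        # was 32: smaller, more frequent updates

    def select_training_batch(predictions: List[Prediction]) -> Optional[List[Prediction]]:
        """Return the most recent rewarded predictions, or None if gated out."""
        if len(predictions) < MIN_PREDICTIONS:
            return None
        rewarded = [p for p in predictions if p.reward is not None]
        if len(rewarded) < MIN_REWARDED:
            return None
        batch_size = min(MAX_BATCH, len(rewarded))
        return rewarded[-batch_size:]

    # e.g. 6 buffered predictions, 4 of them rewarded -> a batch of 4
    batch = select_training_batch([Prediction(0.1)] * 4 + [Prediction()] * 2)
    assert batch is not None and len(batch) == 4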