remove dummy data, improve training , follow architecture
This commit is contained in:
@ -22,6 +22,7 @@ from collections import deque
|
||||
|
||||
from .config import get_config
|
||||
from .data_provider import DataProvider
|
||||
from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
|
||||
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface, ModelRegistry
|
||||
|
||||
# Import COB integration for real-time market microstructure data
|
||||
@ -69,6 +70,7 @@ class TradingOrchestrator:
|
||||
"""Initialize the enhanced orchestrator with full ML capabilities"""
|
||||
self.config = get_config()
|
||||
self.data_provider = data_provider or DataProvider()
|
||||
self.universal_adapter = UniversalDataAdapter(self.data_provider)
|
||||
self.model_registry = model_registry or get_model_registry()
|
||||
self.enhanced_rl_training = enhanced_rl_training
|
||||
|
||||
@ -144,6 +146,7 @@ class TradingOrchestrator:
|
||||
logger.info(f"Confidence threshold: {self.confidence_threshold}")
|
||||
logger.info(f"Decision frequency: {self.decision_frequency}s")
|
||||
logger.info(f"Symbols: {self.symbols}")
|
||||
logger.info("Universal Data Adapter integrated for centralized data flow")
|
||||
|
||||
# Initialize models and COB integration
|
||||
self._initialize_ml_models()
|
||||
@ -181,9 +184,9 @@ class TradingOrchestrator:
|
||||
result = load_best_checkpoint("dqn_agent")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
self.model_states['dqn']['initial_loss'] = 0.285
|
||||
self.model_states['dqn']['current_loss'] = metadata.loss or 0.0145
|
||||
self.model_states['dqn']['best_loss'] = metadata.loss or 0.0098
|
||||
self.model_states['dqn']['initial_loss'] = getattr(metadata, 'initial_loss', None)
|
||||
self.model_states['dqn']['current_loss'] = metadata.loss
|
||||
self.model_states['dqn']['best_loss'] = metadata.loss
|
||||
self.model_states['dqn']['checkpoint_loaded'] = True
|
||||
self.model_states['dqn']['checkpoint_filename'] = metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
@ -192,10 +195,10 @@ class TradingOrchestrator:
|
||||
logger.warning(f"Error loading DQN checkpoint: {e}")
|
||||
|
||||
if not checkpoint_loaded:
|
||||
# New model - set initial loss for tracking
|
||||
self.model_states['dqn']['initial_loss'] = 0.285 # Typical DQN starting loss
|
||||
self.model_states['dqn']['current_loss'] = 0.285
|
||||
self.model_states['dqn']['best_loss'] = 0.285
|
||||
# New model - no synthetic data, start fresh
|
||||
self.model_states['dqn']['initial_loss'] = None
|
||||
self.model_states['dqn']['current_loss'] = None
|
||||
self.model_states['dqn']['best_loss'] = None
|
||||
self.model_states['dqn']['checkpoint_filename'] = 'none (fresh start)'
|
||||
logger.info("DQN starting fresh - no checkpoint found")
|
||||
|
||||
@ -230,9 +233,10 @@ class TradingOrchestrator:
|
||||
logger.warning(f"Error loading CNN checkpoint: {e}")
|
||||
|
||||
if not checkpoint_loaded:
|
||||
self.model_states['cnn']['initial_loss'] = 0.412 # Typical CNN starting loss
|
||||
self.model_states['cnn']['current_loss'] = 0.412
|
||||
self.model_states['cnn']['best_loss'] = 0.412
|
||||
# New model - no synthetic data
|
||||
self.model_states['cnn']['initial_loss'] = None
|
||||
self.model_states['cnn']['current_loss'] = None
|
||||
self.model_states['cnn']['best_loss'] = None
|
||||
logger.info("CNN starting fresh - no checkpoint found")
|
||||
|
||||
logger.info("Enhanced CNN model initialized")
|
||||
@ -251,9 +255,9 @@ class TradingOrchestrator:
|
||||
self.model_states['cnn']['checkpoint_loaded'] = True
|
||||
logger.info(f"CNN checkpoint loaded: loss={checkpoint_data.get('loss', 'N/A')}")
|
||||
else:
|
||||
self.model_states['cnn']['initial_loss'] = 0.412
|
||||
self.model_states['cnn']['current_loss'] = 0.412
|
||||
self.model_states['cnn']['best_loss'] = 0.412
|
||||
self.model_states['cnn']['initial_loss'] = None
|
||||
self.model_states['cnn']['current_loss'] = None
|
||||
self.model_states['cnn']['best_loss'] = None
|
||||
logger.info("CNN starting fresh - no checkpoint found")
|
||||
|
||||
logger.info("Basic CNN model initialized")
|
||||
@ -279,9 +283,9 @@ class TradingOrchestrator:
|
||||
self.model_states['extrema_trainer']['checkpoint_loaded'] = True
|
||||
logger.info(f"Extrema trainer checkpoint loaded: loss={checkpoint_data.get('loss', 'N/A')}")
|
||||
else:
|
||||
self.model_states['extrema_trainer']['initial_loss'] = 0.356
|
||||
self.model_states['extrema_trainer']['current_loss'] = 0.356
|
||||
self.model_states['extrema_trainer']['best_loss'] = 0.356
|
||||
self.model_states['extrema_trainer']['initial_loss'] = None
|
||||
self.model_states['extrema_trainer']['current_loss'] = None
|
||||
self.model_states['extrema_trainer']['best_loss'] = None
|
||||
logger.info("Extrema trainer starting fresh - no checkpoint found")
|
||||
|
||||
logger.info("Extrema trainer initialized")
|
||||
@ -289,15 +293,15 @@ class TradingOrchestrator:
|
||||
logger.warning("Extrema trainer not available")
|
||||
self.extrema_trainer = None
|
||||
|
||||
# Initialize COB RL model state (placeholder)
|
||||
self.model_states['cob_rl']['initial_loss'] = 0.356
|
||||
self.model_states['cob_rl']['current_loss'] = 0.0098
|
||||
self.model_states['cob_rl']['best_loss'] = 0.0076
|
||||
# Initialize COB RL model state - no synthetic data
|
||||
self.model_states['cob_rl']['initial_loss'] = None
|
||||
self.model_states['cob_rl']['current_loss'] = None
|
||||
self.model_states['cob_rl']['best_loss'] = None
|
||||
|
||||
# Initialize Decision model state (placeholder)
|
||||
self.model_states['decision']['initial_loss'] = 0.298
|
||||
self.model_states['decision']['current_loss'] = 0.0089
|
||||
self.model_states['decision']['best_loss'] = 0.0065
|
||||
# Initialize Decision model state - no synthetic data
|
||||
self.model_states['decision']['initial_loss'] = None
|
||||
self.model_states['decision']['current_loss'] = None
|
||||
self.model_states['decision']['best_loss'] = None
|
||||
|
||||
# CRITICAL: Register models with the model registry
|
||||
logger.info("Registering models with model registry...")
|
||||
@ -1289,7 +1293,7 @@ class TradingOrchestrator:
|
||||
direction = best_action_idx # 0=SELL, 1=HOLD, 2=BUY
|
||||
pred_confidence = float(confidence) if confidence is not None else float(action_probs[best_action_idx])
|
||||
predicted_price = current_price * (1 + (pred_confidence * 0.01 if best_action == 'BUY' else -pred_confidence * 0.01 if best_action == 'SELL' else 0))
|
||||
self.capture_cnn_prediction(symbol, direction, pred_confidence, current_price, predicted_price)
|
||||
self.capture_cnn_prediction(symbol, int(direction), pred_confidence, current_price, predicted_price)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting CNN predictions: {e}")
|
||||
@ -2539,4 +2543,23 @@ class TradingOrchestrator:
|
||||
logger.info(f"🧠 Decision fusion checkpoint saved: {metadata.checkpoint_id} (loss={avg_loss:.4f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving fusion checkpoint: {e}")
|
||||
logger.error(f"Error saving fusion checkpoint: {e}")
|
||||
|
||||
def get_universal_data_stream(self, current_time: datetime = None) -> Optional[UniversalDataStream]:
|
||||
"""Get universal data stream for external consumers like dashboard"""
|
||||
try:
|
||||
return self.universal_adapter.get_universal_data_stream(current_time)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting universal data stream: {e}")
|
||||
return None
|
||||
|
||||
def get_universal_data_for_model(self, model_type: str = 'cnn') -> Optional[Dict[str, Any]]:
|
||||
"""Get formatted universal data for specific model types"""
|
||||
try:
|
||||
stream = self.universal_adapter.get_universal_data_stream()
|
||||
if stream:
|
||||
return self.universal_adapter.format_for_model(stream, model_type)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting universal data for {model_type}: {e}")
|
||||
return None
|
Reference in New Issue
Block a user