fix broken merge

@@ -299,14 +299,24 @@ class TradingOrchestrator:
    Includes EnhancedRealtimeTrainingSystem for continuous learning
    """

-    def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_registry: Optional[ModelRegistry] = None):
+    def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_registry: Optional[Any] = None):
        """Initialize the enhanced orchestrator with full ML capabilities"""
        self.config = get_config()
        self.data_provider = data_provider or DataProvider()
        self.universal_adapter = UniversalDataAdapter(self.data_provider)
-        self.model_manager = model_manager or create_model_manager()
+        self.model_manager = None  # Will be initialized later if needed
        self.enhanced_rl_training = enhanced_rl_training

        # Set primary trading symbol
        self.symbol = self.config.get('primary_symbol', 'ETH/USDT')
        self.ref_symbols = self.config.get('reference_symbols', ['BTC/USDT'])

        # Initialize signal accumulator
        self.signal_accumulator = {}

        # Initialize confidence threshold
        self.confidence_threshold = self.config.get('confidence_threshold', 0.6)

-        # Determine the device to use (GPU if available, else CPU)
+        # Initialize device - force CPU mode to avoid CUDA errors
        if torch.cuda.is_available():

@@ -449,8 +459,8 @@
        self.last_inference: Dict[str, Dict] = {}  # {model_name: last_inference_record}

        # Initialize inference logger
-        self.inference_logger = get_inference_logger()
-        self.db_manager = get_database_manager()
+        self.inference_logger = None  # Will be initialized later if needed
+        self.db_manager = None  # Will be initialized later if needed

        # ENHANCED: Real-time Training System Integration
        self.enhanced_training_system = (

@@ -510,6 +520,247 @@
        self._initialize_decision_fusion()  # Initialize fusion system
        self._initialize_transformer_model()  # Initialize transformer model
        self._initialize_enhanced_training_system()  # Initialize real-time training

    def _normalize_model_name(self, model_name: str) -> str:
        """Normalize model name for consistent storage"""
        import re

        # Convert to lowercase
        normalized = model_name.lower()

        # Replace spaces, hyphens, and other non-alphanumeric separators with underscores
        normalized = re.sub(r'[^a-z0-9]+', '_', normalized)

        # Collapse multiple consecutive underscores into a single underscore
        normalized = re.sub(r'_+', '_', normalized)

        # Strip leading and trailing underscores
        normalized = normalized.strip('_')

        return normalized
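
# Example mappings (worked from the regexes above; illustrative only):
#   "DQN Agent"       -> "dqn_agent"
#   "cob-rl:v2"       -> "cob_rl_v2"
#   "__Transformer__" -> "transformer"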

    def _log_data_status(self):
        """Log data provider status"""
        logger.info(f"Data provider initialized for symbols: {self.data_provider.symbols}")
        logger.info(f"Available timeframes: {self.data_provider.timeframes}")

    def _schedule_database_cleanup(self):
        """
        Schedule periodic database cleanup tasks.

        This method sets up a background task that periodically cleans up old
        inference records from the database to prevent it from growing indefinitely.

        Side effects:
            - Creates a background asyncio task that runs every 24 hours
            - Cleans up records older than 30 days by default
            - Logs cleanup operations and any errors
        """
        try:
            from utils.database_manager import get_database_manager

            # Get database manager instance
            db_manager = get_database_manager()

            async def cleanup_task():
                """Background task for periodic database cleanup"""
                while True:
                    try:
                        logger.info("Running scheduled database cleanup...")
                        success = db_manager.cleanup_old_records(days_to_keep=30)
                        if success:
                            logger.info("Database cleanup completed successfully")
                        else:
                            logger.warning("Database cleanup failed")
                    except Exception as e:
                        logger.error(f"Error during database cleanup: {e}")

                    # Wait 24 hours before next cleanup
                    await asyncio.sleep(24 * 60 * 60)  # 24 hours in seconds

            # Create and start the cleanup task
            self._db_cleanup_task = asyncio.create_task(cleanup_task())
            logger.info("Database cleanup scheduler started - will run every 24 hours")

        except Exception as e:
            logger.error(f"Failed to schedule database cleanup: {e}")
            logger.warning("Database cleanup will not be performed automatically")
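
# Note (sketch): asyncio.create_task() requires a running event loop and raises
# RuntimeError otherwise, which the except-block above reports as "will not be
# performed automatically". A guard like the following (hypothetical helper,
# not part of this commit) defers scheduling until a loop exists:
import asyncio

def create_task_if_loop_running(coro_factory):
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return None  # no running loop yet; caller can retry once one is up
    return asyncio.create_task(coro_factory())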

    def _initialize_checkpoint_manager(self):
        """
        Initialize the global checkpoint manager for model checkpoint management.

        This method initializes the checkpoint manager that handles:
        - Saving model checkpoints with metadata
        - Loading the best performing checkpoints
        - Managing checkpoint storage and cleanup

        Returns:
            CheckpointManager: The initialized checkpoint manager instance, or None if initialization fails

        Side effects:
            - Sets self.checkpoint_manager to the global checkpoint manager instance
            - Creates checkpoint directory if it doesn't exist
            - Logs initialization status
        """
        try:
            from utils.checkpoint_manager import get_checkpoint_manager

            # Initialize the global checkpoint manager
            self.checkpoint_manager = get_checkpoint_manager(
                checkpoint_dir="models/checkpoints",
                max_checkpoints=10,
                metric_name="accuracy"
            )

            logger.info("Checkpoint manager initialized successfully with directory: models/checkpoints")
            logger.info("Maximum checkpoints per model: 10, Primary metric: accuracy")

            return self.checkpoint_manager

        except Exception as e:
            logger.error(f"Failed to initialize checkpoint manager: {e}")
            self.checkpoint_manager = None
            return None
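
# Usage sketch (illustrative): consuming the manager initialized above. The
# (path, metadata) return shape follows load_best_checkpoint() as used in
# _initialize_transformer_model below; the None-handling mirrors the decision
# fusion loader later in this file.
from utils.checkpoint_manager import get_checkpoint_manager

manager = get_checkpoint_manager(checkpoint_dir="models/checkpoints",
                                 max_checkpoints=10, metric_name="accuracy")
result = manager.load_best_checkpoint("transformer")
if result:
    path, metadata = result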

    def _start_cob_integration_sync(self):
        """
        Start COB (Consolidated Order Book) integration synchronization.

        This method initiates the COB integration system that provides real-time
        market microstructure data to the trading models. The COB integration
        streams order book data and generates features for CNN and DQN models.

        Side effects:
            - Creates an async task to start COB integration if available
            - Registers COB data callbacks for model feeding
            - Begins streaming COB features to registered models
            - Logs integration status and any errors
        """
        try:
            if self.cob_integration is None:
                logger.warning("COB integration not initialized - cannot start sync")
                return

            # Wrap startup in a coroutine so it can run on the event loop.
            # Note: this method is called from __init__ (a sync context), and
            # asyncio.create_task() requires an already-running event loop.
            async def start_cob_task():
                try:
                    await self.start_cob_integration()
                    logger.info("COB integration synchronization started successfully")
                except Exception as e:
                    logger.error(f"Failed to start COB integration sync: {e}")

            # Create the task; if no loop is running, this raises and is caught below
            self._cob_sync_task = asyncio.create_task(start_cob_task())
            logger.info("COB integration sync task created - will start when event loop is available")

        except Exception as e:
            logger.error(f"Failed to initialize COB integration sync: {e}")
            logger.warning("COB integration will not be available")

    def _initialize_transformer_model(self):
        """
        Initialize the transformer model for advanced trading pattern recognition.

        This method loads or creates an AdvancedTradingTransformer model that uses
        attention mechanisms to analyze complex market patterns and generate trading signals.
        The model is optimized for COB (Consolidated Order Book) data and technical indicators.

        Returns:
            bool: True if initialization successful, False otherwise

        Side effects:
            - Sets self.primary_transformer to the loaded/created transformer model
            - Sets self.primary_transformer_trainer to the associated trainer
            - Updates self.transformer_checkpoint_info with checkpoint metadata
            - Loads best available checkpoint if exists
            - Moves model to appropriate device (CPU/GPU)
            - Logs initialization status and any errors
        """
        try:
            from NN.models.advanced_transformer_trading import (
                AdvancedTradingTransformer,
                TradingTransformerTrainer,
                TradingTransformerConfig
            )

            logger.info("Initializing transformer model for trading...")

            # Create transformer configuration
            config = TradingTransformerConfig()

            # Initialize the transformer model
            self.primary_transformer = AdvancedTradingTransformer(config)
            logger.info(f"AdvancedTradingTransformer created with config: d_model={config.d_model}, "
                        f"n_heads={config.n_heads}, n_layers={config.n_layers}")

            # Initialize the trainer
            self.primary_transformer_trainer = TradingTransformerTrainer(
                model=self.primary_transformer,
                config=config
            )
            logger.info("TradingTransformerTrainer initialized")

            # Move model to device
            if hasattr(self, 'device') and self.device:
                self.primary_transformer.to(self.device)
                logger.info(f"Transformer model moved to device: {self.device}")
            else:
                logger.info("Transformer model using default device")

            # Try to load best checkpoint
            checkpoint_loaded = False
            try:
                if self.checkpoint_manager:
                    checkpoint_path, checkpoint_metadata = self.checkpoint_manager.load_best_checkpoint("transformer")
                    if checkpoint_path and checkpoint_metadata:
                        # Load the checkpoint
                        checkpoint = torch.load(checkpoint_path, map_location=self.device)
                        self.primary_transformer.load_state_dict(checkpoint.get('model_state_dict', checkpoint))

                        # Update checkpoint info
                        self.transformer_checkpoint_info = {
                            'path': checkpoint_path,
                            'metadata': checkpoint_metadata,
                            'loaded_at': datetime.now().isoformat()
                        }

                        logger.info(f"Transformer checkpoint loaded from: {checkpoint_path}")
                        logger.info(f"Checkpoint metrics: {checkpoint_metadata.get('performance_metrics', {})}")
                        checkpoint_loaded = True
                    else:
                        logger.info("No transformer checkpoint found - using fresh model")
                else:
                    logger.warning("Checkpoint manager not available - cannot load transformer checkpoint")

            except Exception as e:
                logger.error(f"Error loading transformer checkpoint: {e}")
                logger.info("Continuing with fresh transformer model")

            if not checkpoint_loaded:
                # Initialize checkpoint info for new model
                self.transformer_checkpoint_info = {
                    'status': 'fresh_model',
                    'created_at': datetime.now().isoformat()
                }

            logger.info("Transformer model initialization completed successfully")
            return True

        except ImportError as e:
            logger.warning(f"Advanced transformer trading module not available: {e}")
            self.primary_transformer = None
            self.primary_transformer_trainer = None
            logger.info("Transformer model will not be available")
            return False

        except Exception as e:
            logger.error(f"Failed to initialize transformer model: {e}")
            self.primary_transformer = None
            self.primary_transformer_trainer = None
            return False
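
# Sketch: a checkpoint file compatible with the loader above. The save-side
# schema is an assumption; the loader only requires either a raw state_dict or
# a dict carrying a 'model_state_dict' key. Here `model` stands in for
# self.primary_transformer and the path is illustrative.
torch.save({'model_state_dict': model.state_dict()},
           "models/checkpoints/transformer_latest.pt")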

    def _initialize_ml_models(self):
        """Initialize ML models for enhanced trading"""
        try:

@@ -549,7 +800,7 @@
            if hasattr(self.rl_agent, "load_best_checkpoint"):
                try:
                    self.rl_agent.load_best_checkpoint()
                    checkpoint_loaded = True
-                    checkpoint_loaded = True
                    logger.info("DQN checkpoint loaded successfully")
                except Exception as e:
                    logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")

@@ -593,7 +844,7 @@
                try:
                    # CNN checkpoint loading would go here
                    logger.info("CNN checkpoint loaded successfully")
                    checkpoint_loaded = True
-                    checkpoint_loaded = True
                except Exception as e:
                    logger.warning(f"Error loading CNN checkpoint: {e}")
                    checkpoint_loaded = False

@@ -686,8 +937,8 @@
            except ImportError:
                logger.warning("Extrema trainer not available")
                self.extrema_trainer = None

            self.cob_rl_agent = None
-
-            self.cob_rl_agent = None

            # CRITICAL: Register models with the model registry

@@ -704,7 +955,7 @@
            try:
                rl_interface = RLAgentInterface(self.rl_agent, name="dqn_agent")
                if self.model_registry.register_model(rl_interface):
                    logger.info("RL Agent registered successfully")
-                    logger.info("RL Agent registered successfully")
                else:
                    logger.error("Failed to register RL Agent with registry")
            except Exception as e:

@@ -715,7 +966,7 @@
            try:
                cnn_interface = CNNModelInterface(self.cnn_model, name="cnn_model")
                if self.model_registry.register_model(cnn_interface):
                    logger.info("CNN Model registered successfully")
-                    logger.info("CNN Model registered successfully")
                else:
                    logger.error("Failed to register CNN Model with registry")
            except Exception as e:

@@ -726,14 +977,13 @@
            try:
                extrema_interface = ExtremaTrainerInterface(self.extrema_trainer, name="extrema_trainer")
                if self.model_registry.register_model(extrema_interface):
                    logger.info("Extrema Trainer registered successfully")
-                    logger.info("Extrema Trainer registered successfully")
                else:
                    logger.error("Failed to register Extrema Trainer with registry")
            except Exception as e:
                logger.error(f"Failed to register Extrema Trainer: {e}")

        except Exception as e:
-        except Exception as e:
            logger.error(f"Error initializing ML models: {e}")

    def get_model_training_stats(self) -> Dict[str, Dict[str, Any]]:

@@ -841,7 +1091,7 @@
        }

        for model_name, stats in dashboard_stats.items():
            if model_name in self.model_states:
-            if model_name in self.model_states:
                self.model_states[model_name]["current_loss"] = stats["current_loss"]
                self.model_states[model_name]["initial_loss"] = stats["initial_loss"]
                if (

@@ -1066,7 +1316,7 @@
            with open(session_file, "w", encoding="utf-8") as f:
                json.dump(existing, f, indent=2)
        except Exception as e:
-        except Exception as e:
            logger.error(f"Error appending session snapshot: {e}")

    def get_model_toggle_state(self, model_name: str) -> Dict[str, bool]:

@@ -1124,8 +1374,7 @@
                self._save_ui_state()
                return True
            return False

        except Exception as e:
-        except Exception as e:
            logger.error(f"Error registering model {model_name} dynamically: {e}")
            return False

@@ -1239,8 +1488,8 @@
                self.cob_integration.add_dashboard_callback(
                    self._on_cob_dashboard_data
                )

            except Exception as e:
-
-            except Exception as e:
                logger.warning(f"Failed to initialize COB Integration: {e}")
                self.cob_integration = None
        else:

@@ -1460,10 +1709,10 @@
            )

            # Use programmatic decision for actual actions
            decision = self._combine_predictions(
                symbol=symbol,
                price=current_price,
                predictions=predictions,
-            decision = self._combine_predictions(
-                symbol=symbol,
-                price=current_price,
-                predictions=predictions,
                timestamp=current_time,
            )
        else:

@@ -1545,7 +1794,7 @@
                    price_change_pct = (
                        (current_price - recent_prices[-2]) / recent_prices[-2] * 100
                    )
            except Exception as e:
-            except Exception as e:
                logger.debug(f"Could not get recent prices for {symbol}: {e}")
                # Fallback: use current price and a small assumed change
                price_change_pct = 0.1  # Assume small positive change

@@ -1658,7 +1907,7 @@
            # Validate base_data has the required method
            if not hasattr(base_data, 'get_feature_vector'):
                logger.debug(f"BaseDataInput for {symbol} missing get_feature_vector method")
                return None
-                return None

            # Get unified feature vector (7850 features including all timeframes and COB data)
            feature_vector = base_data.get_feature_vector()
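
# Sketch (assumption: the vector is a flat array of fixed length): a cheap
# sanity check that could follow the call above, since downstream networks
# expect exactly the documented 7850 features.
EXPECTED_FEATURES = 7850  # from the comment above
if len(feature_vector) != EXPECTED_FEATURES:
    logger.warning(f"Unexpected feature vector length: {len(feature_vector)}")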

@@ -1724,7 +1973,7 @@
                action_scores[action] += random.uniform(-0.001, 0.001)

            best_action = max(action_scores.keys(), key=lambda k: action_scores[k])
            best_confidence = action_scores[best_action]
-            best_confidence = action_scores[best_action]

            # DEBUG: Log action scores to understand bias
            logger.debug(f"Action scores for {symbol}: BUY={action_scores['BUY']:.3f}, SELL={action_scores['SELL']:.3f}, HOLD={action_scores['HOLD']:.3f}")

@@ -1958,8 +2207,8 @@
            # Try to load decision fusion checkpoint
            result = load_best_checkpoint("decision_fusion")
            if result:
                file_path, metadata = result
-            if result:
-                file_path, metadata = result
                # Load the checkpoint into the network
                checkpoint = torch.load(file_path, map_location=self.device)

@@ -1989,11 +2238,11 @@
                logger.info(
                    f"Decision fusion network loaded from checkpoint: {metadata.checkpoint_id} (loss={loss_str})"
                )
            else:
-            else:
                logger.info(
                    "No existing decision fusion checkpoint found, starting fresh"
                )
        except Exception as e:
-        except Exception as e:
            logger.warning(f"Error loading decision fusion checkpoint: {e}")
            logger.info("Decision fusion network starting fresh")

@@ -2053,6 +2302,7 @@
    def stop_enhanced_training(self):
        """Stop the enhanced real-time training system"""
        try:
            if self.enhanced_training_system:
                self.enhanced_training_system.stop_training()
                logger.info("Enhanced real-time training stopped")
                return True

@@ -2075,12 +2325,18 @@
        # Get base stats from enhanced training system
-        stats = {}
        if hasattr(self.enhanced_training_system, "get_training_statistics"):
            stats = self.enhanced_training_system.get_training_statistics()
-            stats = self.enhanced_training_system.get_training_statistics()
+        else:
+            stats = {}

        stats["training_enabled"] = self.training_enabled
        stats["system_available"] = ENHANCED_TRAINING_AVAILABLE

        # Add orchestrator-specific training integration data
        stats["orchestrator_integration"] = {
            "enhanced_training_enabled": self.enhanced_training_enabled,
            "model_registry_count": len(self.model_registry.models) if hasattr(self, 'model_registry') else 0,
            "decision_fusion_enabled": self.decision_fusion_enabled
        }

        # Add model-specific training status from orchestrator
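
# Illustrative payload (shape inferred from the assignments above; values are
# made up):
# {
#     "training_enabled": True,
#     "system_available": True,
#     "orchestrator_integration": {
#         "enhanced_training_enabled": True,
#         "model_registry_count": 3,
#         "decision_fusion_enabled": False,
#     },
#     ...  # plus whatever get_training_statistics() returned
# }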

@@ -2220,7 +2476,7 @@
                    current_time
                )
            elif self.universal_adapter:
                return self.universal_adapter.get_universal_data_stream(current_time)
-                return self.universal_adapter.get_universal_data_stream(current_time)
            return None
        except Exception as e:
            logger.error(f"Error getting universal data stream: {e}")

@@ -2235,10 +2491,10 @@
                    stream, model_type
                )
            elif self.universal_adapter:
                stream = self.universal_adapter.get_universal_data_stream()
                if stream:
                    return self.universal_adapter.format_for_model(stream, model_type)
                return None
-                stream = self.universal_adapter.get_universal_data_stream()
-                if stream:
-                    return self.universal_adapter.format_for_model(stream, model_type)
-                return None
        except Exception as e:
            logger.error(f"Error getting universal data for {model_type}: {e}")
            return None

@@ -2278,7 +2534,7 @@
                side = position.get("side", "LONG")

                if entry_price and size > 0:
                    if side.upper() == "LONG":
-                    if side.upper() == "LONG":
                        pnl = (current_price - entry_price) * size
                    else:  # SHORT
                        pnl = (entry_price - current_price) * size
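
# Worked example of the PnL arithmetic above (values are illustrative):
#   LONG,  entry_price=100.0, current_price=105.0, size=2  ->  (105-100)*2 = +10.0
#   SHORT, entry_price=100.0, current_price=105.0, size=2  ->  (100-105)*2 = -10.0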