diff --git a/core/orchestrator.py b/core/orchestrator.py index 493c975..4b9d4d3 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -302,10 +302,14 @@ class TradingOrchestrator: # CRITICAL: Register models with the model registry logger.info("Registering models with model registry...") + # Import model interfaces + from models import CNNModelInterface, RLAgentInterface, ModelInterface + # Register RL Agent if self.rl_agent: try: - self.register_model(self.rl_agent, weight=0.3) + rl_interface = RLAgentInterface(self.rl_agent, name="dqn_agent") + self.register_model(rl_interface, weight=0.3) logger.info("RL Agent registered successfully") except Exception as e: logger.error(f"Failed to register RL Agent: {e}") @@ -313,15 +317,35 @@ class TradingOrchestrator: # Register CNN Model if self.cnn_model: try: - self.register_model(self.cnn_model, weight=0.7) + cnn_interface = CNNModelInterface(self.cnn_model, name="enhanced_cnn") + self.register_model(cnn_interface, weight=0.7) logger.info("CNN Model registered successfully") except Exception as e: logger.error(f"Failed to register CNN Model: {e}") - # Register Extrema Trainer + # Register Extrema Trainer (as generic ModelInterface) if self.extrema_trainer: try: - self.register_model(self.extrema_trainer, weight=0.2) + # Create a simple wrapper for extrema trainer + class ExtremaTrainerInterface(ModelInterface): + def __init__(self, model, name: str): + super().__init__(name) + self.model = model + + def predict(self, data): + try: + if hasattr(self.model, 'predict'): + return self.model.predict(data) + return None + except Exception as e: + logger.error(f"Error in extrema trainer prediction: {e}") + return None + + def get_memory_usage(self) -> float: + return 30.0 # MB + + extrema_interface = ExtremaTrainerInterface(self.extrema_trainer, name="extrema_trainer") + self.register_model(extrema_interface, weight=0.2) logger.info("Extrema Trainer registered successfully") except Exception as e: logger.error(f"Failed to register Extrema Trainer: {e}") diff --git a/debug/test_orchestrator_predictions.py b/debug/test_orchestrator_predictions.py deleted file mode 100644 index c96e214..0000000 --- a/debug/test_orchestrator_predictions.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to debug orchestrator prediction issues -""" - -import sys -import os -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider -import asyncio -import logging - -# Set up logging -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - -async def test_orchestrator_predictions(): - """Test orchestrator prediction generation""" - - logger.info("=" * 60) - logger.info("TESTING ORCHESTRATOR PREDICTIONS") - logger.info("=" * 60) - - # Initialize components - logger.info("1. Initializing orchestrator...") - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider) - - # Check configuration - logger.info("2. Checking orchestrator configuration...") - logger.info(f" Confidence threshold: {orchestrator.confidence_threshold}") - logger.info(f" Confidence threshold close: {orchestrator.confidence_threshold_close}") - logger.info(f" Model weights: {orchestrator.model_weights}") - - # Check registered models - logger.info("3. Checking registered models...") - if hasattr(orchestrator, 'model_registry'): - # Check what methods are available - registry_methods = [method for method in dir(orchestrator.model_registry) if not method.startswith('_')] - logger.info(f" Model registry methods: {registry_methods}") - - # Check if we have models - if hasattr(orchestrator, 'rl_agent'): - logger.info(f" RL Agent available: {orchestrator.rl_agent is not None}") - if hasattr(orchestrator, 'cnn_model'): - logger.info(f" CNN Model available: {orchestrator.cnn_model is not None}") - if hasattr(orchestrator, 'extrema_trainer'): - logger.info(f" Extrema Trainer available: {orchestrator.extrema_trainer is not None}") - else: - logger.info(" No model registry found") - - # Test prediction generation - logger.info("4. Testing prediction generation...") - symbol = 'ETH/USDT' - - try: - # Get all predictions - predictions = await orchestrator._get_all_predictions(symbol) - logger.info(f" Total predictions: {len(predictions)}") - - for i, pred in enumerate(predictions): - logger.info(f" Prediction {i+1}: {pred.action} (confidence: {pred.confidence:.3f}, model: {pred.model_name})") - - # Test decision making - logger.info("5. Testing decision making...") - for i in range(3): # Test multiple decisions - decision = await orchestrator.make_trading_decision(symbol) - if decision: - logger.info(f" Decision {i+1}: {decision.action} (confidence: {decision.confidence:.3f})") - logger.info(f" Reasoning: {decision.reasoning}") - else: - logger.info(f" Decision {i+1}: None") - - # Wait a bit between decisions - await asyncio.sleep(1) - - # Test fallback prediction - logger.info("6. Testing fallback prediction...") - current_price = data_provider.get_current_price(symbol) - if current_price: - fallback = await orchestrator._generate_fallback_prediction(symbol, current_price) - if fallback: - logger.info(f" Fallback prediction: {fallback.action} (confidence: {fallback.confidence:.3f})") - else: - logger.info(" No fallback prediction generated") - else: - logger.info(" No current price available for fallback test") - - except Exception as e: - logger.error(f" Error testing predictions: {e}") - import traceback - traceback.print_exc() - - logger.info("=" * 60) - logger.info("TEST COMPLETE") - logger.info("=" * 60) - -if __name__ == "__main__": - asyncio.run(test_orchestrator_predictions()) \ No newline at end of file diff --git a/debug/test_trading_execution.py b/debug/test_trading_execution.py deleted file mode 100644 index bc178c7..0000000 --- a/debug/test_trading_execution.py +++ /dev/null @@ -1,125 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to debug trading execution issues -""" - -import sys -import os -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - -from core.trading_executor import TradingExecutor -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider -import logging - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_trading_execution(): - """Test trading execution in simulation mode""" - - logger.info("=" * 60) - logger.info("TESTING TRADING EXECUTION") - logger.info("=" * 60) - - # Initialize components - logger.info("1. Initializing components...") - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider) - trading_executor = TradingExecutor() - - # Check trading executor status - logger.info("2. Checking trading executor status...") - logger.info(f" Trading enabled: {trading_executor.trading_enabled}") - logger.info(f" Trading mode: {trading_executor.trading_mode}") - logger.info(f" Simulation mode: {trading_executor.simulation_mode}") - logger.info(f" Exchange connected: {trading_executor.exchange is not None}") - - # Check account balance - logger.info("3. Checking account balance...") - try: - balance = trading_executor.get_account_balance() - logger.info(f" Account balance: {balance}") - except Exception as e: - logger.error(f" Error getting balance: {e}") - - # Test manual trade execution - logger.info("4. Testing manual trade execution...") - symbol = 'ETH/USDT' - action = 'BUY' - quantity = 0.01 - - try: - logger.info(f" Executing: {action} {quantity} {symbol}") - result = trading_executor.execute_trade(symbol, action, quantity) - logger.info(f" Result: {result}") - - if result: - # Check positions - positions = trading_executor.get_positions() - logger.info(f" Positions after trade: {positions}") - - except Exception as e: - logger.error(f" Error executing trade: {e}") - import traceback - traceback.print_exc() - - # Test orchestrator decision making - logger.info("5. Testing orchestrator decisions...") - try: - import asyncio - - async def test_decision(): - decision = await orchestrator.make_trading_decision(symbol) - if decision: - logger.info(f" Decision: {decision.action} (confidence: {decision.confidence:.3f})") - - # Test executing the decision - if decision.action != 'HOLD': - result = trading_executor.execute_signal( - symbol=decision.symbol, - action=decision.action, - confidence=decision.confidence, - current_price=decision.price - ) - logger.info(f" Execution result: {result}") - else: - logger.info(" No decision made") - - asyncio.run(test_decision()) - - except Exception as e: - logger.error(f" Error testing orchestrator: {e}") - import traceback - traceback.print_exc() - - # Check safety conditions - logger.info("6. Testing safety conditions...") - try: - # Check if safety conditions are blocking trades - symbol_test = 'ETH/USDT' - action_test = 'BUY' - - # Access private method for testing (not ideal but for debugging) - if hasattr(trading_executor, '_check_safety_conditions'): - safety_ok = trading_executor._check_safety_conditions(symbol_test, action_test) - logger.info(f" Safety conditions OK for {action_test} {symbol_test}: {safety_ok}") - - # Check individual safety conditions - config = trading_executor.mexc_config - logger.info(f" Emergency stop: {config.get('emergency_stop', False)}") - logger.info(f" Allowed symbols: {config.get('allowed_symbols', [])}") - logger.info(f" Daily loss: {trading_executor.daily_loss} / {config.get('max_daily_loss_usd', 5.0)}") - logger.info(f" Daily trades: {trading_executor.daily_trades} / {config.get('max_trades_per_hour', 2) * 24}") - logger.info(f" Concurrent positions: {len(trading_executor.positions)} / {config.get('max_concurrent_positions', 1)}") - - except Exception as e: - logger.error(f" Error checking safety conditions: {e}") - - logger.info("=" * 60) - logger.info("TEST COMPLETE") - logger.info("=" * 60) - -if __name__ == "__main__": - test_trading_execution() \ No newline at end of file