From 98ebbe5089d92633a5354d7117582cb1893fd343 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Mon, 8 Sep 2025 15:22:01 +0300 Subject: [PATCH] cleanup --- core/orchestrator.py | 4 +- test_cob_audit.py | 87 ------- test_enhanced_training_integration.py | 144 ----------- test_enhanced_training_simple.py | 78 ------ test_fresh_to_loaded.py | 180 -------------- test_leverage_fix.py | 74 ------ test_model_audit.py | 344 -------------------------- test_model_fixes.py | 226 ----------------- 8 files changed, 2 insertions(+), 1135 deletions(-) delete mode 100644 test_cob_audit.py delete mode 100644 test_enhanced_training_integration.py delete mode 100644 test_enhanced_training_simple.py delete mode 100644 test_fresh_to_loaded.py delete mode 100644 test_leverage_fix.py delete mode 100644 test_model_audit.py delete mode 100644 test_model_fixes.py diff --git a/core/orchestrator.py b/core/orchestrator.py index 7bf876e..49d1866 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -103,8 +103,8 @@ class TradingOrchestrator: # Configuration - AGGRESSIVE for more training data self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.15) # Lowered from 0.20 self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.08) # Lowered from 0.10 - self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30) - self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) # Enhanced to support multiple symbols + self.decision_frequency = self.config.orchestrator.get('decision_frequency', 5) + self.symbols = self.config.get('symbols', ['ETH/USDT']) # Enhanced to support multiple symbols # NEW: Aggressiveness parameters self.entry_aggressiveness = self.config.orchestrator.get('entry_aggressiveness', 0.5) # 0.0 = conservative, 1.0 = very aggressive diff --git a/test_cob_audit.py b/test_cob_audit.py deleted file mode 100644 index 7afacce..0000000 --- a/test_cob_audit.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python3 -""" -Test COB Integration Status in Enhanced Orchestrator -""" - -import asyncio -import sys -from pathlib import Path -sys.path.append(str(Path('.').absolute())) - -from core.enhanced_orchestrator import EnhancedTradingOrchestrator -from core.data_provider import DataProvider - -async def test_cob_integration(): - print("=" * 60) - print("COB INTEGRATION AUDIT") - print("=" * 60) - - try: - data_provider = DataProvider() - orchestrator = EnhancedTradingOrchestrator( - data_provider=data_provider, - symbols=['ETH/USDT', 'BTC/USDT'], - enhanced_rl_training=True - ) - - print(f"✓ Enhanced Orchestrator created") - print(f"Has COB integration attribute: {hasattr(orchestrator, 'cob_integration')}") - print(f"COB integration value: {orchestrator.cob_integration}") - print(f"COB integration type: {type(orchestrator.cob_integration)}") - print(f"COB integration active: {getattr(orchestrator, 'cob_integration_active', 'Not set')}") - - if orchestrator.cob_integration: - print("\n--- COB Integration Details ---") - print(f"COB Integration class: {orchestrator.cob_integration.__class__.__name__}") - - # Check if it has the expected methods - methods_to_check = ['get_statistics', 'get_cob_snapshot', 'add_dashboard_callback', 'start', 'stop'] - for method in methods_to_check: - has_method = hasattr(orchestrator.cob_integration, method) - print(f"Has {method}: {has_method}") - - # Try to get statistics - if hasattr(orchestrator.cob_integration, 'get_statistics'): - try: - stats = orchestrator.cob_integration.get_statistics() - print(f"COB statistics: {stats}") - except Exception as e: - print(f"Error getting COB statistics: {e}") - - # Try to get a snapshot - if hasattr(orchestrator.cob_integration, 'get_cob_snapshot'): - try: - snapshot = orchestrator.cob_integration.get_cob_snapshot('ETH/USDT') - print(f"ETH/USDT snapshot: {snapshot}") - except Exception as e: - print(f"Error getting COB snapshot: {e}") - - # Check if COB integration needs to be started - print(f"\n--- Starting COB Integration ---") - try: - await orchestrator.start_cob_integration() - print("✓ COB integration started successfully") - - # Wait a moment and check statistics again - await asyncio.sleep(3) - if hasattr(orchestrator.cob_integration, 'get_statistics'): - stats = orchestrator.cob_integration.get_statistics() - print(f"COB statistics after start: {stats}") - - except Exception as e: - print(f"Error starting COB integration: {e}") - else: - print("\n❌ COB integration is None - this explains the dashboard issues") - print("The Enhanced Orchestrator failed to initialize COB integration") - - # Check the error flag - if hasattr(orchestrator, '_cob_integration_failed'): - print(f"COB integration failed flag: {orchestrator._cob_integration_failed}") - - except Exception as e: - print(f"Error in COB audit: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - asyncio.run(test_cob_integration()) \ No newline at end of file diff --git a/test_enhanced_training_integration.py b/test_enhanced_training_integration.py deleted file mode 100644 index 3568fff..0000000 --- a/test_enhanced_training_integration.py +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Enhanced Training Integration - -This script tests the integration of EnhancedRealtimeTrainingSystem -into the TradingOrchestrator to ensure it works correctly. -""" - -import sys -import os -import logging -import asyncio -from datetime import datetime - -# Add project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -async def test_enhanced_training_integration(): - """Test the enhanced training system integration""" - try: - logger.info("=" * 60) - logger.info("TESTING ENHANCED TRAINING INTEGRATION") - logger.info("=" * 60) - - # 1. Initialize orchestrator with enhanced training - logger.info("1. Initializing orchestrator with enhanced training...") - data_provider = DataProvider() - orchestrator = TradingOrchestrator( - data_provider=data_provider, - enhanced_rl_training=True - ) - - # 2. Check if training system is available - logger.info("2. Checking training system availability...") - training_available = hasattr(orchestrator, 'enhanced_training_system') - training_enabled = getattr(orchestrator, 'training_enabled', False) - - logger.info(f" - Training system attribute: {'✅ Available' if training_available else '❌ Missing'}") - logger.info(f" - Training enabled: {'✅ Yes' if training_enabled else '❌ No'}") - - # 3. Test training system initialization - if training_available and orchestrator.enhanced_training_system: - logger.info("3. Testing training system methods...") - - # Test getting training statistics - stats = orchestrator.get_enhanced_training_stats() - logger.info(f" - Training stats retrieved: {len(stats)} fields") - logger.info(f" - Training enabled in stats: {stats.get('training_enabled', False)}") - logger.info(f" - System available: {stats.get('system_available', False)}") - - # Test starting training - start_result = orchestrator.start_enhanced_training() - logger.info(f" - Start training result: {'✅ Success' if start_result else '❌ Failed'}") - - if start_result: - # Let it run for a few seconds - logger.info(" - Letting training run for 5 seconds...") - await asyncio.sleep(5) - - # Get updated stats - updated_stats = orchestrator.get_enhanced_training_stats() - logger.info(f" - Updated stats: {updated_stats.get('is_training', False)}") - - # Stop training - stop_result = orchestrator.stop_enhanced_training() - logger.info(f" - Stop training result: {'✅ Success' if stop_result else '❌ Failed'}") - - else: - logger.warning("3. Training system not available - checking fallback behavior...") - - # Test methods when training system is not available - stats = orchestrator.get_enhanced_training_stats() - logger.info(f" - Fallback stats: {stats}") - - start_result = orchestrator.start_enhanced_training() - logger.info(f" - Fallback start result: {start_result}") - - # 4. Test dashboard connection method - logger.info("4. Testing dashboard connection method...") - try: - orchestrator.set_training_dashboard(None) # Test with None - logger.info(" - Dashboard connection method: ✅ Available") - except Exception as e: - logger.error(f" - Dashboard connection method error: {e}") - - # 5. Summary - logger.info("=" * 60) - logger.info("INTEGRATION TEST SUMMARY") - logger.info("=" * 60) - - if training_available and training_enabled: - logger.info("✅ ENHANCED TRAINING INTEGRATION SUCCESSFUL") - logger.info(" - Training system properly integrated") - logger.info(" - All methods available and functional") - logger.info(" - Ready for real-time training") - elif training_available: - logger.info("⚠️ ENHANCED TRAINING PARTIALLY INTEGRATED") - logger.info(" - Training system available but not enabled") - logger.info(" - Check EnhancedRealtimeTrainingSystem import") - else: - logger.info("❌ ENHANCED TRAINING INTEGRATION FAILED") - logger.info(" - Training system not properly integrated") - logger.info(" - Methods missing or non-functional") - - return training_available and training_enabled - - except Exception as e: - logger.error(f"Error in integration test: {e}") - import traceback - logger.error(traceback.format_exc()) - return False - -async def main(): - """Main test function""" - try: - success = await test_enhanced_training_integration() - - if success: - logger.info("🎉 All tests passed! Enhanced training integration is working.") - return 0 - else: - logger.warning("⚠️ Some tests failed. Check the integration.") - return 1 - - except KeyboardInterrupt: - logger.info("Test interrupted by user") - return 0 - except Exception as e: - logger.error(f"Fatal error in test: {e}") - return 1 - -if __name__ == "__main__": - exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file diff --git a/test_enhanced_training_simple.py b/test_enhanced_training_simple.py deleted file mode 100644 index f3f600c..0000000 --- a/test_enhanced_training_simple.py +++ /dev/null @@ -1,78 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple Enhanced Training Test - -Quick test to verify enhanced training system can be enabled and controlled. -""" - -import sys -import os -import logging - -# Add project root to path -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from core.orchestrator import TradingOrchestrator -from core.data_provider import DataProvider - -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def test_enhanced_training(): - """Test enhanced training system""" - try: - logger.info("Testing Enhanced Training System...") - - # 1. Create data provider - data_provider = DataProvider() - - # 2. Create orchestrator with enhanced training ENABLED - logger.info("Creating orchestrator with enhanced_rl_training=True...") - orchestrator = TradingOrchestrator( - data_provider=data_provider, - enhanced_rl_training=True # 🔥 THIS ENABLES IT - ) - - # 3. Check if training system is available - logger.info(f"Training system available: {orchestrator.enhanced_training_system is not None}") - logger.info(f"Training enabled: {orchestrator.training_enabled}") - - # 4. Get training stats - stats = orchestrator.get_enhanced_training_stats() - logger.info(f"Training stats: {stats}") - - # 5. Test start/stop - if orchestrator.enhanced_training_system: - logger.info("Testing start/stop functionality...") - - # Start training - start_result = orchestrator.start_enhanced_training() - logger.info(f"Start result: {start_result}") - - # Get updated stats - updated_stats = orchestrator.get_enhanced_training_stats() - logger.info(f"Updated stats: {updated_stats}") - - # Stop training - stop_result = orchestrator.stop_enhanced_training() - logger.info(f"Stop result: {stop_result}") - - logger.info("✅ Enhanced training system is working!") - return True - else: - logger.warning("❌ Enhanced training system not available") - return False - - except Exception as e: - logger.error(f"Error testing enhanced training: {e}") - return False - -if __name__ == "__main__": - success = test_enhanced_training() - if success: - print("\n🎉 Enhanced training system is ready to use!") - print("To enable it in your main system, use:") - print(" enhanced_rl_training=True when creating TradingOrchestrator") - else: - print("\n⚠️ Enhanced training system has issues. Check the logs above.") \ No newline at end of file diff --git a/test_fresh_to_loaded.py b/test_fresh_to_loaded.py deleted file mode 100644 index bf3edc0..0000000 --- a/test_fresh_to_loaded.py +++ /dev/null @@ -1,180 +0,0 @@ -#!/usr/bin/env python3 -""" -Test FRESH to LOADED Model Status Fix - -This script tests the fix for models showing as FRESH instead of LOADED. -""" - -import logging -import sys -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).resolve().parent -sys.path.insert(0, str(project_root)) - -logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def test_orchestrator_model_initialization(): - """Test that orchestrator initializes all models correctly""" - print("=" * 60) - print("Testing Orchestrator Model Initialization...") - print("=" * 60) - - try: - from core.data_provider import DataProvider - from core.orchestrator import TradingOrchestrator - - # Create data provider and orchestrator - data_provider = DataProvider() - orchestrator = TradingOrchestrator(data_provider=data_provider, enhanced_rl_training=True) - - # Check which models were initialized - models_initialized = [] - - if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent: - models_initialized.append('DQN') - - if hasattr(orchestrator, 'cnn_model') and orchestrator.cnn_model: - models_initialized.append('CNN') - - if hasattr(orchestrator, 'extrema_trainer') and orchestrator.extrema_trainer: - models_initialized.append('ExtremaTrainer') - - if hasattr(orchestrator, 'cob_rl_agent') and orchestrator.cob_rl_agent: - models_initialized.append('COB_RL') - - if hasattr(orchestrator, 'transformer_model') and orchestrator.transformer_model: - models_initialized.append('TRANSFORMER') - - if hasattr(orchestrator, 'decision_model') and orchestrator.decision_model: - models_initialized.append('DECISION') - - print(f"✅ Initialized Models: {', '.join(models_initialized)}") - - # Check model states - print("\nModel States:") - for model_name, state in orchestrator.model_states.items(): - checkpoint_loaded = state.get('checkpoint_loaded', False) - status = "LOADED" if checkpoint_loaded else "FRESH" - filename = state.get('checkpoint_filename', 'none') - print(f" {model_name.upper()}: {status} ({filename})") - - return orchestrator, len(models_initialized) - - except Exception as e: - print(f"❌ Orchestrator initialization failed: {e}") - return None, 0 - -def test_checkpoint_saving(orchestrator): - """Test saving checkpoints for all models""" - print("\n" + "=" * 60) - print("Testing Checkpoint Saving...") - print("=" * 60) - - try: - from model_checkpoint_saver import ModelCheckpointSaver - - saver = ModelCheckpointSaver(orchestrator) - - # Force all models to LOADED status - updated_models = saver.force_all_models_to_loaded() - - print(f"✅ Updated {len(updated_models)} models to LOADED status") - - # Check updated states - print("\nUpdated Model States:") - fresh_count = 0 - loaded_count = 0 - - for model_name, state in orchestrator.model_states.items(): - checkpoint_loaded = state.get('checkpoint_loaded', False) - status = "LOADED" if checkpoint_loaded else "FRESH" - filename = state.get('checkpoint_filename', 'none') - print(f" {model_name.upper()}: {status} ({filename})") - - if checkpoint_loaded: - loaded_count += 1 - else: - fresh_count += 1 - - print(f"\nSummary: {loaded_count} LOADED, {fresh_count} FRESH") - - return fresh_count == 0 - - except Exception as e: - print(f"❌ Checkpoint saving test failed: {e}") - return False - -def test_dashboard_model_status(): - """Test how models show up in dashboard""" - print("\n" + "=" * 60) - print("Testing Dashboard Model Status Display...") - print("=" * 60) - - try: - # Simulate dashboard model status check - from web.component_manager import DashboardComponentManager - - print("✅ Dashboard component manager imports successfully") - print("✅ Model status display logic available") - - return True - - except Exception as e: - print(f"❌ Dashboard test failed: {e}") - return False - -def main(): - """Run all tests""" - print("🔧 Testing FRESH to LOADED Model Status Fix") - print("=" * 60) - - # Test 1: Orchestrator initialization - orchestrator, models_count = test_orchestrator_model_initialization() - if not orchestrator: - print("\n❌ Cannot proceed - orchestrator initialization failed") - return False - - # Test 2: Checkpoint saving - checkpoint_success = test_checkpoint_saving(orchestrator) - - # Test 3: Dashboard integration - dashboard_success = test_dashboard_model_status() - - # Summary - print("\n" + "=" * 60) - print("TEST SUMMARY") - print("=" * 60) - - tests = [ - ("Model Initialization", models_count > 0), - ("Checkpoint Status Fix", checkpoint_success), - ("Dashboard Integration", dashboard_success) - ] - - passed = 0 - for test_name, result in tests: - status = "PASSED" if result else "FAILED" - icon = "✅" if result else "❌" - print(f"{icon} {test_name}: {status}") - if result: - passed += 1 - - print(f"\nOverall: {passed}/{len(tests)} tests passed") - - if passed == len(tests): - print("\n🎉 ALL TESTS PASSED! Models should now show as LOADED instead of FRESH.") - print("\nNext steps:") - print("1. Restart the dashboard") - print("2. Models should now show as LOADED in the status panel") - print("3. The FRESH status issue should be resolved") - else: - print(f"\n⚠️ {len(tests) - passed} tests failed. Some issues may remain.") - - return passed == len(tests) - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1) diff --git a/test_leverage_fix.py b/test_leverage_fix.py deleted file mode 100644 index 905fb7b..0000000 --- a/test_leverage_fix.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python3 - -""" -Test script to verify leverage P&L calculations are working correctly -""" - -from web.clean_dashboard import create_clean_dashboard - -def test_leverage_calculations(): - print("🧮 Testing Leverage P&L Calculations") - print("=" * 50) - - # Create dashboard - dashboard = create_clean_dashboard() - - print("✅ Dashboard created successfully") - - # Test 1: Position leverage vs slider leverage - print("\n📊 Test 1: Position vs Slider Leverage") - dashboard.current_leverage = 25 # Current slider at x25 - dashboard.current_position = { - 'side': 'LONG', - 'size': 0.01, - 'price': 2000.0, # Entry at $2000 - 'leverage': 10, # Position opened at x10 leverage - 'symbol': 'ETH/USDT' - } - - print(f" Position opened at: x{dashboard.current_position['leverage']} leverage") - print(f" Current slider at: x{dashboard.current_leverage} leverage") - print(" ✅ Position uses its stored leverage, not current slider") - - # Test 2: Trading statistics with leveraged P&L - print("\n📈 Test 2: Trading Statistics") - test_trade = { - 'symbol': 'ETH/USDT', - 'side': 'BUY', - 'pnl': 100.0, # Leveraged P&L - 'pnl_raw': 2.0, # Raw P&L (before leverage) - 'leverage_used': 50, # x50 leverage used - 'fees': 0.5 - } - - dashboard.closed_trades.append(test_trade) - dashboard.session_pnl = 100.0 - - stats = dashboard._get_trading_statistics() - - print(f" Trade raw P&L: ${test_trade['pnl_raw']:.2f}") - print(f" Trade leverage: x{test_trade['leverage_used']}") - print(f" Trade leveraged P&L: ${test_trade['pnl']:.2f}") - print(f" Statistics total P&L: ${stats['total_pnl']:.2f}") - print(f" ✅ Statistics use leveraged P&L correctly") - - # Test 3: Session P&L calculation - print("\n💰 Test 3: Session P&L") - print(f" Session P&L: ${dashboard.session_pnl:.2f}") - print(f" Expected: $100.00") - if abs(dashboard.session_pnl - 100.0) < 0.01: - print(" ✅ Session P&L correctly uses leveraged amounts") - else: - print(" ❌ Session P&L calculation error") - - print("\n🎯 Summary:") - print(" • Positions store their original leverage") - print(" • Unrealized P&L uses position leverage (not slider)") - print(" • Completed trades store both raw and leveraged P&L") - print(" • Statistics display leveraged P&L") - print(" • Session totals use leveraged amounts") - - print("\n✅ ALL LEVERAGE P&L CALCULATIONS FIXED!") - -if __name__ == "__main__": - test_leverage_calculations() \ No newline at end of file diff --git a/test_model_audit.py b/test_model_audit.py deleted file mode 100644 index 24576ef..0000000 --- a/test_model_audit.py +++ /dev/null @@ -1,344 +0,0 @@ -#!/usr/bin/env python3 -""" -Model Loading/Saving Audit Test - -This script tests the model registry and saving/loading mechanisms -to identify any issues and provide recommendations. -""" - -import os -import sys -import logging -import torch -import torch.nn as nn -from datetime import datetime -from pathlib import Path - -# Add project root to path -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - -from utils.model_registry import get_model_registry, save_model, load_model, save_checkpoint - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -class SimpleTestModel(nn.Module): - """Simple neural network for testing""" - def __init__(self, input_size=10, hidden_size=32, output_size=2): - super().__init__() - self.net = nn.Sequential( - nn.Linear(input_size, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, output_size) - ) - - def forward(self, x): - return self.net(x) - -def test_model_registry(): - """Test the model registry functionality""" - logger.info("=== MODEL REGISTRY AUDIT ===") - - registry = get_model_registry() - logger.info(f"Registry base directory: {registry.base_dir}") - logger.info(f"Registry metadata file: {registry.metadata_file}") - - # Check existing models - existing_models = registry.list_models() - logger.info(f"Existing models: {existing_models}") - - # Test model creation and saving - logger.info("Creating test model...") - test_model = SimpleTestModel() - - # Generate some fake training data - test_input = torch.randn(32, 10) - test_output = test_model(test_input) - - logger.info(f"Test model created. Input shape: {test_input.shape}, Output shape: {test_output.shape}") - - # Test saving with different methods - logger.info("Testing model saving...") - - # Test 1: Save with unified registry - success = save_model( - model=test_model, - model_name="audit_test_model", - model_type="cnn", - metadata={ - "test_type": "registry_audit", - "created_at": datetime.now().isoformat(), - "input_shape": list(test_input.shape), - "output_shape": list(test_output.shape) - } - ) - - if success: - logger.info("✅ Model saved successfully with unified registry") - else: - logger.error("❌ Failed to save model with unified registry") - - # Test 2: Load model back - logger.info("Testing model loading...") - loaded_model = load_model("audit_test_model", "cnn") - - if loaded_model is not None: - logger.info("✅ Model loaded successfully") - - # Test if loaded model has proper structure - if hasattr(loaded_model, 'state_dict') and callable(loaded_model.state_dict): - state_dict = loaded_model.state_dict() - logger.info(f"Loaded model test - State dict keys: {list(state_dict.keys())}") - - # Check if we can create a new instance and load the state - fresh_model = SimpleTestModel() - try: - fresh_model.load_state_dict(state_dict) - test_output_loaded = fresh_model(test_input) - logger.info(f"Loaded model test - Output shape: {test_output_loaded.shape}") - - # Compare outputs (should be identical) - if torch.allclose(test_output, test_output_loaded, atol=1e-6): - logger.info("✅ Loaded model produces identical outputs") - else: - logger.warning("⚠️ Loaded model outputs differ (this might be expected due to different random states)") - except Exception as e: - logger.warning(f"Could not test loaded model: {e}") - else: - logger.warning("Loaded model does not have proper structure") - - else: - logger.error("❌ Failed to load model") - - # Test 3: Save checkpoint - logger.info("Testing checkpoint saving...") - checkpoint_success = save_checkpoint( - model=test_model, - model_name="audit_test_model", - model_type="cnn", - performance_score=0.85, - metadata={ - "checkpoint_test": True, - "performance_metric": "accuracy", - "epoch": 1 - } - ) - - if checkpoint_success: - logger.info("✅ Checkpoint saved successfully") - else: - logger.error("❌ Failed to save checkpoint") - - # Check registry metadata after operations - logger.info("Checking registry metadata after operations...") - updated_models = registry.list_models() - logger.info(f"Updated models: {updated_models}") - - # Check file system - logger.info("Checking file system...") - models_dir = Path("models") - if models_dir.exists(): - logger.info(f"Models directory contents:") - for item in models_dir.rglob("*"): - if item.is_file(): - logger.info(f" {item.relative_to(models_dir)} ({item.stat().st_size} bytes)") - - return { - "registry_save_success": success, - "registry_load_success": loaded_model is not None, - "checkpoint_success": checkpoint_success, - "existing_models": existing_models, - "updated_models": updated_models - } - -def audit_model_metadata(): - """Audit the model metadata structure""" - logger.info("=== MODEL METADATA AUDIT ===") - - registry = get_model_registry() - - # Check metadata structure - metadata = registry.metadata - logger.info(f"Metadata keys: {list(metadata.keys())}") - - if 'models' in metadata: - models = metadata['models'] - logger.info(f"Number of registered models: {len(models)}") - - for model_name, model_data in models.items(): - logger.info(f"Model '{model_name}':") - logger.info(f" - Type: {model_data.get('type', 'unknown')}") - logger.info(f" - Last saved: {model_data.get('last_saved', 'never')}") - logger.info(f" - Save count: {model_data.get('save_count', 0)}") - logger.info(f" - Latest path: {model_data.get('latest_path', 'none')}") - logger.info(f" - Checkpoints: {len(model_data.get('checkpoints', []))}") - - if 'last_updated' in metadata: - logger.info(f"Last metadata update: {metadata['last_updated']}") - - return metadata - -def analyze_model_files(): - """Analyze the model files on disk""" - logger.info("=== MODEL FILES ANALYSIS ===") - - models_dir = Path("models") - - if not models_dir.exists(): - logger.error("Models directory does not exist") - return {} - - analysis = { - 'total_files': 0, - 'total_size': 0, - 'by_type': {}, - 'by_model': {}, - 'orphaned_files': [], - 'missing_files': [] - } - - # Analyze all .pt files - for pt_file in models_dir.rglob("*.pt"): - analysis['total_files'] += 1 - analysis['total_size'] += pt_file.stat().st_size - - # Categorize by type - parts = pt_file.parts - model_type = "unknown" - if "cnn" in parts: - model_type = "cnn" - elif "dqn" in parts: - model_type = "dqn" - elif "transformer" in parts: - model_type = "transformer" - elif "hybrid" in parts: - model_type = "hybrid" - - if model_type not in analysis['by_type']: - analysis['by_type'][model_type] = [] - analysis['by_type'][model_type].append(str(pt_file)) - - # Try to extract model name - filename = pt_file.name - if "_latest" in filename: - model_name = filename.replace("_latest.pt", "") - elif "_" in filename: - # Extract timestamp-based names - parts = filename.split("_") - if len(parts) >= 2: - model_name = "_".join(parts[:-1]) # Everything except timestamp - else: - model_name = filename.replace(".pt", "") - else: - model_name = filename.replace(".pt", "") - - if model_name not in analysis['by_model']: - analysis['by_model'][model_name] = [] - analysis['by_model'][model_name].append(str(pt_file)) - - logger.info(f"Total model files: {analysis['total_files']}") - logger.info(f"Total size: {analysis['total_size'] / (1024*1024):.2f} MB") - - logger.info("Files by type:") - for model_type, files in analysis['by_type'].items(): - logger.info(f" {model_type}: {len(files)} files") - - logger.info("Files by model:") - for model_name, files in analysis['by_model'].items(): - logger.info(f" {model_name}: {len(files)} files") - - return analysis - -def recommend_best_model_selection(): - """Provide recommendations for best model selection at startup""" - logger.info("=== BEST MODEL SELECTION RECOMMENDATIONS ===") - - registry = get_model_registry() - models = registry.list_models() - - recommendations = { - 'startup_strategy': 'hybrid', - 'fallback_models': [], - 'performance_criteria': [], - 'metadata_requirements': [] - } - - if models: - logger.info("Available models for selection:") - - # Analyze each model type - for model_name, model_info in models.items(): - model_type = model_info.get('type', 'unknown') - logger.info(f" {model_name} ({model_type}) - last saved: {model_info.get('last_saved', 'unknown')}") - - # Check if checkpoints exist - if 'checkpoint_count' in model_info and model_info['checkpoint_count'] > 0: - logger.info(f" - Has {model_info['checkpoint_count']} checkpoints") - recommendations['fallback_models'].append(model_name) - - # Recommendations - logger.info("RECOMMENDATIONS:") - logger.info("1. Startup Strategy:") - logger.info(" - Try to load latest model for each type") - logger.info(" - Fall back to checkpoints if latest model fails") - logger.info(" - Use fallback to basic/default model if all else fails") - - logger.info("2. Performance-based Selection:") - logger.info(" - For models with checkpoints, select highest performance_score") - logger.info(" - Track model age and prefer recently trained models") - logger.info(" - Implement model validation on startup") - - logger.info("3. Metadata Requirements:") - logger.info(" - Store performance metrics in metadata") - logger.info(" - Track training data quality and size") - logger.info(" - Include model validation results") - - else: - logger.info("No models registered - system will need initial training") - logger.info("RECOMMENDATION: Implement default model initialization") - - return recommendations - -def main(): - """Main audit function""" - logger.info("Starting Model Loading/Saving Audit") - logger.info("=" * 60) - - try: - # Test model registry - registry_results = test_model_registry() - logger.info("-" * 40) - - # Audit metadata - metadata = audit_model_metadata() - logger.info("-" * 40) - - # Analyze files - file_analysis = analyze_model_files() - logger.info("-" * 40) - - # Recommendations - recommendations = recommend_best_model_selection() - logger.info("-" * 40) - - # Summary - logger.info("=== AUDIT SUMMARY ===") - logger.info(f"Registry save success: {registry_results.get('registry_save_success', False)}") - logger.info(f"Registry load success: {registry_results.get('registry_load_success', False)}") - logger.info(f"Checkpoint success: {registry_results.get('checkpoint_success', False)}") - logger.info(f"Total model files: {file_analysis.get('total_files', 0)}") - logger.info(f"Registered models: {len(registry_results.get('existing_models', {}))}") - - logger.info("Audit completed successfully!") - - except Exception as e: - logger.error(f"Audit failed with error: {e}") - import traceback - traceback.print_exc() - -if __name__ == "__main__": - main() diff --git a/test_model_fixes.py b/test_model_fixes.py deleted file mode 100644 index ca31d14..0000000 --- a/test_model_fixes.py +++ /dev/null @@ -1,226 +0,0 @@ -#!/usr/bin/env python3 -""" -Test Model Loading and Saving Fixes - -This script validates that all the model loading/saving issues have been resolved. -""" - -import logging -import sys -from pathlib import Path - -# Add project root to path -project_root = Path(__file__).resolve().parent -sys.path.insert(0, str(project_root)) - -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -def test_model_registry(): - """Test the ModelRegistry fixes""" - print("=" * 60) - print("Testing ModelRegistry fixes...") - print("=" * 60) - - try: - from models import get_model_registry, register_model - from NN.models.model_interfaces import ModelInterface - - # Create a simple test model interface - class TestModelInterface(ModelInterface): - def __init__(self, name: str): - super().__init__(name) - - def predict(self, data): - return {"prediction": "test", "confidence": 0.5} - - def get_memory_usage(self) -> float: - return 1.0 - - # Test registry operations - registry = get_model_registry() - test_model = TestModelInterface("test_model") - - # Test registration (this should now work without signature error) - success = register_model(test_model) - if success: - print("✅ ModelRegistry registration: FIXED") - else: - print("❌ ModelRegistry registration: FAILED") - return False - - # Test retrieval - retrieved = registry.get_model("test_model") - if retrieved is not None: - print("✅ ModelRegistry retrieval: WORKING") - else: - print("❌ ModelRegistry retrieval: FAILED") - return False - - return True - - except Exception as e: - print(f"❌ ModelRegistry test failed: {e}") - return False - -def test_checkpoint_manager(): - """Test the CheckpointManager fixes""" - print("\n" + "=" * 60) - print("Testing CheckpointManager fixes...") - print("=" * 60) - - try: - from utils.checkpoint_manager import get_checkpoint_manager - - cm = get_checkpoint_manager() - - # Test loading existing models (should find legacy models) - models_to_test = ['dqn_agent', 'enhanced_cnn'] - found_models = 0 - - for model_name in models_to_test: - result = cm.load_best_checkpoint(model_name) - if result: - file_path, metadata = result - print(f"✅ Found {model_name}: {Path(file_path).name}") - found_models += 1 - else: - print(f"ℹ️ No checkpoint for {model_name} (expected for fresh start)") - - # Test that warnings are not repeated - print(f"✅ CheckpointManager: Found {found_models} legacy models") - print("✅ CheckpointManager: Warning spam reduced (cached)") - - return True - - except Exception as e: - print(f"❌ CheckpointManager test failed: {e}") - return False - -def test_improved_model_saver(): - """Test the ImprovedModelSaver""" - print("\n" + "=" * 60) - print("Testing ImprovedModelSaver...") - print("=" * 60) - - try: - from improved_model_saver import get_improved_model_saver - import torch - import torch.nn as nn - - saver = get_improved_model_saver() - - # Create a simple test model - class SimpleTestModel(nn.Module): - def __init__(self): - super().__init__() - self.linear = nn.Linear(10, 1) - - def forward(self, x): - return self.linear(x) - - test_model = SimpleTestModel() - - # Test saving - success = saver.save_model_safely( - test_model, - "test_simple_model", - "test", - metadata={"test": True, "accuracy": 0.95} - ) - - if success: - print("✅ ImprovedModelSaver save: WORKING") - else: - print("❌ ImprovedModelSaver save: FAILED") - return False - - # Test loading - loaded_model = saver.load_model_safely("test_simple_model", SimpleTestModel) - - if loaded_model is not None: - print("✅ ImprovedModelSaver load: WORKING") - - # Test that model actually works - test_input = torch.randn(1, 10) - output = loaded_model(test_input) - if output is not None: - print("✅ Loaded model functionality: WORKING") - else: - print("❌ Loaded model functionality: FAILED") - return False - else: - print("❌ ImprovedModelSaver load: FAILED") - return False - - return True - - except Exception as e: - print(f"❌ ImprovedModelSaver test failed: {e}") - return False - -def test_orchestrator_caching(): - """Test that orchestrator caching reduces repeated calls""" - print("\n" + "=" * 60) - print("Testing Orchestrator checkpoint caching...") - print("=" * 60) - - try: - # This is harder to test without running the full system - # But we can verify the cache mechanism exists - from core.orchestrator import TradingOrchestrator - print("✅ Orchestrator imports successfully") - print("✅ Checkpoint caching implemented (reduces load frequency)") - return True - - except Exception as e: - print(f"❌ Orchestrator test failed: {e}") - return False - -def main(): - """Run all tests""" - print("🔧 Testing Model Loading/Saving Fixes") - print("=" * 60) - - tests = [ - ("ModelRegistry Signature Fix", test_model_registry), - ("CheckpointManager Improvements", test_checkpoint_manager), - ("ImprovedModelSaver", test_improved_model_saver), - ("Orchestrator Caching", test_orchestrator_caching) - ] - - results = [] - - for test_name, test_func in tests: - try: - result = test_func() - results.append((test_name, result)) - except Exception as e: - print(f"❌ {test_name}: CRASHED - {e}") - results.append((test_name, False)) - - # Summary - print("\n" + "=" * 60) - print("TEST SUMMARY") - print("=" * 60) - - passed = 0 - for test_name, result in results: - status = "PASSED" if result else "FAILED" - icon = "✅" if result else "❌" - print(f"{icon} {test_name}: {status}") - if result: - passed += 1 - - print(f"\nOverall: {passed}/{len(tests)} tests passed") - - if passed == len(tests): - print("\n🎉 ALL MODEL FIXES WORKING! Dashboard should run without registration errors.") - else: - print(f"\n⚠️ {len(tests) - passed} tests failed. Some issues may remain.") - - return passed == len(tests) - -if __name__ == "__main__": - success = main() - sys.exit(0 if success else 1)