cleanup
This commit is contained in:
@@ -103,8 +103,8 @@ class TradingOrchestrator:
|
|||||||
# Configuration - AGGRESSIVE for more training data
|
# Configuration - AGGRESSIVE for more training data
|
||||||
self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.15) # Lowered from 0.20
|
self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.15) # Lowered from 0.20
|
||||||
self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.08) # Lowered from 0.10
|
self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.08) # Lowered from 0.10
|
||||||
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
|
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 5)
|
||||||
self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) # Enhanced to support multiple symbols
|
self.symbols = self.config.get('symbols', ['ETH/USDT']) # Enhanced to support multiple symbols
|
||||||
|
|
||||||
# NEW: Aggressiveness parameters
|
# NEW: Aggressiveness parameters
|
||||||
self.entry_aggressiveness = self.config.orchestrator.get('entry_aggressiveness', 0.5) # 0.0 = conservative, 1.0 = very aggressive
|
self.entry_aggressiveness = self.config.orchestrator.get('entry_aggressiveness', 0.5) # 0.0 = conservative, 1.0 = very aggressive
|
||||||
|
@@ -1,87 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test COB Integration Status in Enhanced Orchestrator
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
sys.path.append(str(Path('.').absolute()))
|
|
||||||
|
|
||||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
|
||||||
from core.data_provider import DataProvider
|
|
||||||
|
|
||||||
async def test_cob_integration():
|
|
||||||
print("=" * 60)
|
|
||||||
print("COB INTEGRATION AUDIT")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
data_provider = DataProvider()
|
|
||||||
orchestrator = EnhancedTradingOrchestrator(
|
|
||||||
data_provider=data_provider,
|
|
||||||
symbols=['ETH/USDT', 'BTC/USDT'],
|
|
||||||
enhanced_rl_training=True
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"✓ Enhanced Orchestrator created")
|
|
||||||
print(f"Has COB integration attribute: {hasattr(orchestrator, 'cob_integration')}")
|
|
||||||
print(f"COB integration value: {orchestrator.cob_integration}")
|
|
||||||
print(f"COB integration type: {type(orchestrator.cob_integration)}")
|
|
||||||
print(f"COB integration active: {getattr(orchestrator, 'cob_integration_active', 'Not set')}")
|
|
||||||
|
|
||||||
if orchestrator.cob_integration:
|
|
||||||
print("\n--- COB Integration Details ---")
|
|
||||||
print(f"COB Integration class: {orchestrator.cob_integration.__class__.__name__}")
|
|
||||||
|
|
||||||
# Check if it has the expected methods
|
|
||||||
methods_to_check = ['get_statistics', 'get_cob_snapshot', 'add_dashboard_callback', 'start', 'stop']
|
|
||||||
for method in methods_to_check:
|
|
||||||
has_method = hasattr(orchestrator.cob_integration, method)
|
|
||||||
print(f"Has {method}: {has_method}")
|
|
||||||
|
|
||||||
# Try to get statistics
|
|
||||||
if hasattr(orchestrator.cob_integration, 'get_statistics'):
|
|
||||||
try:
|
|
||||||
stats = orchestrator.cob_integration.get_statistics()
|
|
||||||
print(f"COB statistics: {stats}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error getting COB statistics: {e}")
|
|
||||||
|
|
||||||
# Try to get a snapshot
|
|
||||||
if hasattr(orchestrator.cob_integration, 'get_cob_snapshot'):
|
|
||||||
try:
|
|
||||||
snapshot = orchestrator.cob_integration.get_cob_snapshot('ETH/USDT')
|
|
||||||
print(f"ETH/USDT snapshot: {snapshot}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error getting COB snapshot: {e}")
|
|
||||||
|
|
||||||
# Check if COB integration needs to be started
|
|
||||||
print(f"\n--- Starting COB Integration ---")
|
|
||||||
try:
|
|
||||||
await orchestrator.start_cob_integration()
|
|
||||||
print("✓ COB integration started successfully")
|
|
||||||
|
|
||||||
# Wait a moment and check statistics again
|
|
||||||
await asyncio.sleep(3)
|
|
||||||
if hasattr(orchestrator.cob_integration, 'get_statistics'):
|
|
||||||
stats = orchestrator.cob_integration.get_statistics()
|
|
||||||
print(f"COB statistics after start: {stats}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error starting COB integration: {e}")
|
|
||||||
else:
|
|
||||||
print("\n❌ COB integration is None - this explains the dashboard issues")
|
|
||||||
print("The Enhanced Orchestrator failed to initialize COB integration")
|
|
||||||
|
|
||||||
# Check the error flag
|
|
||||||
if hasattr(orchestrator, '_cob_integration_failed'):
|
|
||||||
print(f"COB integration failed flag: {orchestrator._cob_integration_failed}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error in COB audit: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(test_cob_integration())
|
|
@@ -1,144 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test Enhanced Training Integration
|
|
||||||
|
|
||||||
This script tests the integration of EnhancedRealtimeTrainingSystem
|
|
||||||
into the TradingOrchestrator to ensure it works correctly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
import asyncio
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
# Add project root to path
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
from core.orchestrator import TradingOrchestrator
|
|
||||||
from core.data_provider import DataProvider
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
async def test_enhanced_training_integration():
|
|
||||||
"""Test the enhanced training system integration"""
|
|
||||||
try:
|
|
||||||
logger.info("=" * 60)
|
|
||||||
logger.info("TESTING ENHANCED TRAINING INTEGRATION")
|
|
||||||
logger.info("=" * 60)
|
|
||||||
|
|
||||||
# 1. Initialize orchestrator with enhanced training
|
|
||||||
logger.info("1. Initializing orchestrator with enhanced training...")
|
|
||||||
data_provider = DataProvider()
|
|
||||||
orchestrator = TradingOrchestrator(
|
|
||||||
data_provider=data_provider,
|
|
||||||
enhanced_rl_training=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Check if training system is available
|
|
||||||
logger.info("2. Checking training system availability...")
|
|
||||||
training_available = hasattr(orchestrator, 'enhanced_training_system')
|
|
||||||
training_enabled = getattr(orchestrator, 'training_enabled', False)
|
|
||||||
|
|
||||||
logger.info(f" - Training system attribute: {'✅ Available' if training_available else '❌ Missing'}")
|
|
||||||
logger.info(f" - Training enabled: {'✅ Yes' if training_enabled else '❌ No'}")
|
|
||||||
|
|
||||||
# 3. Test training system initialization
|
|
||||||
if training_available and orchestrator.enhanced_training_system:
|
|
||||||
logger.info("3. Testing training system methods...")
|
|
||||||
|
|
||||||
# Test getting training statistics
|
|
||||||
stats = orchestrator.get_enhanced_training_stats()
|
|
||||||
logger.info(f" - Training stats retrieved: {len(stats)} fields")
|
|
||||||
logger.info(f" - Training enabled in stats: {stats.get('training_enabled', False)}")
|
|
||||||
logger.info(f" - System available: {stats.get('system_available', False)}")
|
|
||||||
|
|
||||||
# Test starting training
|
|
||||||
start_result = orchestrator.start_enhanced_training()
|
|
||||||
logger.info(f" - Start training result: {'✅ Success' if start_result else '❌ Failed'}")
|
|
||||||
|
|
||||||
if start_result:
|
|
||||||
# Let it run for a few seconds
|
|
||||||
logger.info(" - Letting training run for 5 seconds...")
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
|
|
||||||
# Get updated stats
|
|
||||||
updated_stats = orchestrator.get_enhanced_training_stats()
|
|
||||||
logger.info(f" - Updated stats: {updated_stats.get('is_training', False)}")
|
|
||||||
|
|
||||||
# Stop training
|
|
||||||
stop_result = orchestrator.stop_enhanced_training()
|
|
||||||
logger.info(f" - Stop training result: {'✅ Success' if stop_result else '❌ Failed'}")
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.warning("3. Training system not available - checking fallback behavior...")
|
|
||||||
|
|
||||||
# Test methods when training system is not available
|
|
||||||
stats = orchestrator.get_enhanced_training_stats()
|
|
||||||
logger.info(f" - Fallback stats: {stats}")
|
|
||||||
|
|
||||||
start_result = orchestrator.start_enhanced_training()
|
|
||||||
logger.info(f" - Fallback start result: {start_result}")
|
|
||||||
|
|
||||||
# 4. Test dashboard connection method
|
|
||||||
logger.info("4. Testing dashboard connection method...")
|
|
||||||
try:
|
|
||||||
orchestrator.set_training_dashboard(None) # Test with None
|
|
||||||
logger.info(" - Dashboard connection method: ✅ Available")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f" - Dashboard connection method error: {e}")
|
|
||||||
|
|
||||||
# 5. Summary
|
|
||||||
logger.info("=" * 60)
|
|
||||||
logger.info("INTEGRATION TEST SUMMARY")
|
|
||||||
logger.info("=" * 60)
|
|
||||||
|
|
||||||
if training_available and training_enabled:
|
|
||||||
logger.info("✅ ENHANCED TRAINING INTEGRATION SUCCESSFUL")
|
|
||||||
logger.info(" - Training system properly integrated")
|
|
||||||
logger.info(" - All methods available and functional")
|
|
||||||
logger.info(" - Ready for real-time training")
|
|
||||||
elif training_available:
|
|
||||||
logger.info("⚠️ ENHANCED TRAINING PARTIALLY INTEGRATED")
|
|
||||||
logger.info(" - Training system available but not enabled")
|
|
||||||
logger.info(" - Check EnhancedRealtimeTrainingSystem import")
|
|
||||||
else:
|
|
||||||
logger.info("❌ ENHANCED TRAINING INTEGRATION FAILED")
|
|
||||||
logger.info(" - Training system not properly integrated")
|
|
||||||
logger.info(" - Methods missing or non-functional")
|
|
||||||
|
|
||||||
return training_available and training_enabled
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in integration test: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""Main test function"""
|
|
||||||
try:
|
|
||||||
success = await test_enhanced_training_integration()
|
|
||||||
|
|
||||||
if success:
|
|
||||||
logger.info("🎉 All tests passed! Enhanced training integration is working.")
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ Some tests failed. Check the integration.")
|
|
||||||
return 1
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("Test interrupted by user")
|
|
||||||
return 0
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Fatal error in test: {e}")
|
|
||||||
return 1
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
exit_code = asyncio.run(main())
|
|
||||||
sys.exit(exit_code)
|
|
@@ -1,78 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Simple Enhanced Training Test
|
|
||||||
|
|
||||||
Quick test to verify enhanced training system can be enabled and controlled.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
|
|
||||||
# Add project root to path
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
from core.orchestrator import TradingOrchestrator
|
|
||||||
from core.data_provider import DataProvider
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def test_enhanced_training():
|
|
||||||
"""Test enhanced training system"""
|
|
||||||
try:
|
|
||||||
logger.info("Testing Enhanced Training System...")
|
|
||||||
|
|
||||||
# 1. Create data provider
|
|
||||||
data_provider = DataProvider()
|
|
||||||
|
|
||||||
# 2. Create orchestrator with enhanced training ENABLED
|
|
||||||
logger.info("Creating orchestrator with enhanced_rl_training=True...")
|
|
||||||
orchestrator = TradingOrchestrator(
|
|
||||||
data_provider=data_provider,
|
|
||||||
enhanced_rl_training=True # 🔥 THIS ENABLES IT
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Check if training system is available
|
|
||||||
logger.info(f"Training system available: {orchestrator.enhanced_training_system is not None}")
|
|
||||||
logger.info(f"Training enabled: {orchestrator.training_enabled}")
|
|
||||||
|
|
||||||
# 4. Get training stats
|
|
||||||
stats = orchestrator.get_enhanced_training_stats()
|
|
||||||
logger.info(f"Training stats: {stats}")
|
|
||||||
|
|
||||||
# 5. Test start/stop
|
|
||||||
if orchestrator.enhanced_training_system:
|
|
||||||
logger.info("Testing start/stop functionality...")
|
|
||||||
|
|
||||||
# Start training
|
|
||||||
start_result = orchestrator.start_enhanced_training()
|
|
||||||
logger.info(f"Start result: {start_result}")
|
|
||||||
|
|
||||||
# Get updated stats
|
|
||||||
updated_stats = orchestrator.get_enhanced_training_stats()
|
|
||||||
logger.info(f"Updated stats: {updated_stats}")
|
|
||||||
|
|
||||||
# Stop training
|
|
||||||
stop_result = orchestrator.stop_enhanced_training()
|
|
||||||
logger.info(f"Stop result: {stop_result}")
|
|
||||||
|
|
||||||
logger.info("✅ Enhanced training system is working!")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.warning("❌ Enhanced training system not available")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error testing enhanced training: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
success = test_enhanced_training()
|
|
||||||
if success:
|
|
||||||
print("\n🎉 Enhanced training system is ready to use!")
|
|
||||||
print("To enable it in your main system, use:")
|
|
||||||
print(" enhanced_rl_training=True when creating TradingOrchestrator")
|
|
||||||
else:
|
|
||||||
print("\n⚠️ Enhanced training system has issues. Check the logs above.")
|
|
@@ -1,180 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test FRESH to LOADED Model Status Fix
|
|
||||||
|
|
||||||
This script tests the fix for models showing as FRESH instead of LOADED.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add project root to path
|
|
||||||
project_root = Path(__file__).resolve().parent
|
|
||||||
sys.path.insert(0, str(project_root))
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def test_orchestrator_model_initialization():
|
|
||||||
"""Test that orchestrator initializes all models correctly"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("Testing Orchestrator Model Initialization...")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from core.data_provider import DataProvider
|
|
||||||
from core.orchestrator import TradingOrchestrator
|
|
||||||
|
|
||||||
# Create data provider and orchestrator
|
|
||||||
data_provider = DataProvider()
|
|
||||||
orchestrator = TradingOrchestrator(data_provider=data_provider, enhanced_rl_training=True)
|
|
||||||
|
|
||||||
# Check which models were initialized
|
|
||||||
models_initialized = []
|
|
||||||
|
|
||||||
if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent:
|
|
||||||
models_initialized.append('DQN')
|
|
||||||
|
|
||||||
if hasattr(orchestrator, 'cnn_model') and orchestrator.cnn_model:
|
|
||||||
models_initialized.append('CNN')
|
|
||||||
|
|
||||||
if hasattr(orchestrator, 'extrema_trainer') and orchestrator.extrema_trainer:
|
|
||||||
models_initialized.append('ExtremaTrainer')
|
|
||||||
|
|
||||||
if hasattr(orchestrator, 'cob_rl_agent') and orchestrator.cob_rl_agent:
|
|
||||||
models_initialized.append('COB_RL')
|
|
||||||
|
|
||||||
if hasattr(orchestrator, 'transformer_model') and orchestrator.transformer_model:
|
|
||||||
models_initialized.append('TRANSFORMER')
|
|
||||||
|
|
||||||
if hasattr(orchestrator, 'decision_model') and orchestrator.decision_model:
|
|
||||||
models_initialized.append('DECISION')
|
|
||||||
|
|
||||||
print(f"✅ Initialized Models: {', '.join(models_initialized)}")
|
|
||||||
|
|
||||||
# Check model states
|
|
||||||
print("\nModel States:")
|
|
||||||
for model_name, state in orchestrator.model_states.items():
|
|
||||||
checkpoint_loaded = state.get('checkpoint_loaded', False)
|
|
||||||
status = "LOADED" if checkpoint_loaded else "FRESH"
|
|
||||||
filename = state.get('checkpoint_filename', 'none')
|
|
||||||
print(f" {model_name.upper()}: {status} ({filename})")
|
|
||||||
|
|
||||||
return orchestrator, len(models_initialized)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Orchestrator initialization failed: {e}")
|
|
||||||
return None, 0
|
|
||||||
|
|
||||||
def test_checkpoint_saving(orchestrator):
|
|
||||||
"""Test saving checkpoints for all models"""
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("Testing Checkpoint Saving...")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from model_checkpoint_saver import ModelCheckpointSaver
|
|
||||||
|
|
||||||
saver = ModelCheckpointSaver(orchestrator)
|
|
||||||
|
|
||||||
# Force all models to LOADED status
|
|
||||||
updated_models = saver.force_all_models_to_loaded()
|
|
||||||
|
|
||||||
print(f"✅ Updated {len(updated_models)} models to LOADED status")
|
|
||||||
|
|
||||||
# Check updated states
|
|
||||||
print("\nUpdated Model States:")
|
|
||||||
fresh_count = 0
|
|
||||||
loaded_count = 0
|
|
||||||
|
|
||||||
for model_name, state in orchestrator.model_states.items():
|
|
||||||
checkpoint_loaded = state.get('checkpoint_loaded', False)
|
|
||||||
status = "LOADED" if checkpoint_loaded else "FRESH"
|
|
||||||
filename = state.get('checkpoint_filename', 'none')
|
|
||||||
print(f" {model_name.upper()}: {status} ({filename})")
|
|
||||||
|
|
||||||
if checkpoint_loaded:
|
|
||||||
loaded_count += 1
|
|
||||||
else:
|
|
||||||
fresh_count += 1
|
|
||||||
|
|
||||||
print(f"\nSummary: {loaded_count} LOADED, {fresh_count} FRESH")
|
|
||||||
|
|
||||||
return fresh_count == 0
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Checkpoint saving test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_dashboard_model_status():
|
|
||||||
"""Test how models show up in dashboard"""
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("Testing Dashboard Model Status Display...")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Simulate dashboard model status check
|
|
||||||
from web.component_manager import DashboardComponentManager
|
|
||||||
|
|
||||||
print("✅ Dashboard component manager imports successfully")
|
|
||||||
print("✅ Model status display logic available")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Dashboard test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run all tests"""
|
|
||||||
print("🔧 Testing FRESH to LOADED Model Status Fix")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Test 1: Orchestrator initialization
|
|
||||||
orchestrator, models_count = test_orchestrator_model_initialization()
|
|
||||||
if not orchestrator:
|
|
||||||
print("\n❌ Cannot proceed - orchestrator initialization failed")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Test 2: Checkpoint saving
|
|
||||||
checkpoint_success = test_checkpoint_saving(orchestrator)
|
|
||||||
|
|
||||||
# Test 3: Dashboard integration
|
|
||||||
dashboard_success = test_dashboard_model_status()
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("TEST SUMMARY")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
tests = [
|
|
||||||
("Model Initialization", models_count > 0),
|
|
||||||
("Checkpoint Status Fix", checkpoint_success),
|
|
||||||
("Dashboard Integration", dashboard_success)
|
|
||||||
]
|
|
||||||
|
|
||||||
passed = 0
|
|
||||||
for test_name, result in tests:
|
|
||||||
status = "PASSED" if result else "FAILED"
|
|
||||||
icon = "✅" if result else "❌"
|
|
||||||
print(f"{icon} {test_name}: {status}")
|
|
||||||
if result:
|
|
||||||
passed += 1
|
|
||||||
|
|
||||||
print(f"\nOverall: {passed}/{len(tests)} tests passed")
|
|
||||||
|
|
||||||
if passed == len(tests):
|
|
||||||
print("\n🎉 ALL TESTS PASSED! Models should now show as LOADED instead of FRESH.")
|
|
||||||
print("\nNext steps:")
|
|
||||||
print("1. Restart the dashboard")
|
|
||||||
print("2. Models should now show as LOADED in the status panel")
|
|
||||||
print("3. The FRESH status issue should be resolved")
|
|
||||||
else:
|
|
||||||
print(f"\n⚠️ {len(tests) - passed} tests failed. Some issues may remain.")
|
|
||||||
|
|
||||||
return passed == len(tests)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
success = main()
|
|
||||||
sys.exit(0 if success else 1)
|
|
@@ -1,74 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
"""
|
|
||||||
Test script to verify leverage P&L calculations are working correctly
|
|
||||||
"""
|
|
||||||
|
|
||||||
from web.clean_dashboard import create_clean_dashboard
|
|
||||||
|
|
||||||
def test_leverage_calculations():
|
|
||||||
print("🧮 Testing Leverage P&L Calculations")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
# Create dashboard
|
|
||||||
dashboard = create_clean_dashboard()
|
|
||||||
|
|
||||||
print("✅ Dashboard created successfully")
|
|
||||||
|
|
||||||
# Test 1: Position leverage vs slider leverage
|
|
||||||
print("\n📊 Test 1: Position vs Slider Leverage")
|
|
||||||
dashboard.current_leverage = 25 # Current slider at x25
|
|
||||||
dashboard.current_position = {
|
|
||||||
'side': 'LONG',
|
|
||||||
'size': 0.01,
|
|
||||||
'price': 2000.0, # Entry at $2000
|
|
||||||
'leverage': 10, # Position opened at x10 leverage
|
|
||||||
'symbol': 'ETH/USDT'
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f" Position opened at: x{dashboard.current_position['leverage']} leverage")
|
|
||||||
print(f" Current slider at: x{dashboard.current_leverage} leverage")
|
|
||||||
print(" ✅ Position uses its stored leverage, not current slider")
|
|
||||||
|
|
||||||
# Test 2: Trading statistics with leveraged P&L
|
|
||||||
print("\n📈 Test 2: Trading Statistics")
|
|
||||||
test_trade = {
|
|
||||||
'symbol': 'ETH/USDT',
|
|
||||||
'side': 'BUY',
|
|
||||||
'pnl': 100.0, # Leveraged P&L
|
|
||||||
'pnl_raw': 2.0, # Raw P&L (before leverage)
|
|
||||||
'leverage_used': 50, # x50 leverage used
|
|
||||||
'fees': 0.5
|
|
||||||
}
|
|
||||||
|
|
||||||
dashboard.closed_trades.append(test_trade)
|
|
||||||
dashboard.session_pnl = 100.0
|
|
||||||
|
|
||||||
stats = dashboard._get_trading_statistics()
|
|
||||||
|
|
||||||
print(f" Trade raw P&L: ${test_trade['pnl_raw']:.2f}")
|
|
||||||
print(f" Trade leverage: x{test_trade['leverage_used']}")
|
|
||||||
print(f" Trade leveraged P&L: ${test_trade['pnl']:.2f}")
|
|
||||||
print(f" Statistics total P&L: ${stats['total_pnl']:.2f}")
|
|
||||||
print(f" ✅ Statistics use leveraged P&L correctly")
|
|
||||||
|
|
||||||
# Test 3: Session P&L calculation
|
|
||||||
print("\n💰 Test 3: Session P&L")
|
|
||||||
print(f" Session P&L: ${dashboard.session_pnl:.2f}")
|
|
||||||
print(f" Expected: $100.00")
|
|
||||||
if abs(dashboard.session_pnl - 100.0) < 0.01:
|
|
||||||
print(" ✅ Session P&L correctly uses leveraged amounts")
|
|
||||||
else:
|
|
||||||
print(" ❌ Session P&L calculation error")
|
|
||||||
|
|
||||||
print("\n🎯 Summary:")
|
|
||||||
print(" • Positions store their original leverage")
|
|
||||||
print(" • Unrealized P&L uses position leverage (not slider)")
|
|
||||||
print(" • Completed trades store both raw and leveraged P&L")
|
|
||||||
print(" • Statistics display leveraged P&L")
|
|
||||||
print(" • Session totals use leveraged amounts")
|
|
||||||
|
|
||||||
print("\n✅ ALL LEVERAGE P&L CALCULATIONS FIXED!")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_leverage_calculations()
|
|
@@ -1,344 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Model Loading/Saving Audit Test
|
|
||||||
|
|
||||||
This script tests the model registry and saving/loading mechanisms
|
|
||||||
to identify any issues and provide recommendations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add project root to path
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
from utils.model_registry import get_model_registry, save_model, load_model, save_checkpoint
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class SimpleTestModel(nn.Module):
|
|
||||||
"""Simple neural network for testing"""
|
|
||||||
def __init__(self, input_size=10, hidden_size=32, output_size=2):
|
|
||||||
super().__init__()
|
|
||||||
self.net = nn.Sequential(
|
|
||||||
nn.Linear(input_size, hidden_size),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(hidden_size, output_size)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.net(x)
|
|
||||||
|
|
||||||
def test_model_registry():
|
|
||||||
"""Test the model registry functionality"""
|
|
||||||
logger.info("=== MODEL REGISTRY AUDIT ===")
|
|
||||||
|
|
||||||
registry = get_model_registry()
|
|
||||||
logger.info(f"Registry base directory: {registry.base_dir}")
|
|
||||||
logger.info(f"Registry metadata file: {registry.metadata_file}")
|
|
||||||
|
|
||||||
# Check existing models
|
|
||||||
existing_models = registry.list_models()
|
|
||||||
logger.info(f"Existing models: {existing_models}")
|
|
||||||
|
|
||||||
# Test model creation and saving
|
|
||||||
logger.info("Creating test model...")
|
|
||||||
test_model = SimpleTestModel()
|
|
||||||
|
|
||||||
# Generate some fake training data
|
|
||||||
test_input = torch.randn(32, 10)
|
|
||||||
test_output = test_model(test_input)
|
|
||||||
|
|
||||||
logger.info(f"Test model created. Input shape: {test_input.shape}, Output shape: {test_output.shape}")
|
|
||||||
|
|
||||||
# Test saving with different methods
|
|
||||||
logger.info("Testing model saving...")
|
|
||||||
|
|
||||||
# Test 1: Save with unified registry
|
|
||||||
success = save_model(
|
|
||||||
model=test_model,
|
|
||||||
model_name="audit_test_model",
|
|
||||||
model_type="cnn",
|
|
||||||
metadata={
|
|
||||||
"test_type": "registry_audit",
|
|
||||||
"created_at": datetime.now().isoformat(),
|
|
||||||
"input_shape": list(test_input.shape),
|
|
||||||
"output_shape": list(test_output.shape)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
logger.info("✅ Model saved successfully with unified registry")
|
|
||||||
else:
|
|
||||||
logger.error("❌ Failed to save model with unified registry")
|
|
||||||
|
|
||||||
# Test 2: Load model back
|
|
||||||
logger.info("Testing model loading...")
|
|
||||||
loaded_model = load_model("audit_test_model", "cnn")
|
|
||||||
|
|
||||||
if loaded_model is not None:
|
|
||||||
logger.info("✅ Model loaded successfully")
|
|
||||||
|
|
||||||
# Test if loaded model has proper structure
|
|
||||||
if hasattr(loaded_model, 'state_dict') and callable(loaded_model.state_dict):
|
|
||||||
state_dict = loaded_model.state_dict()
|
|
||||||
logger.info(f"Loaded model test - State dict keys: {list(state_dict.keys())}")
|
|
||||||
|
|
||||||
# Check if we can create a new instance and load the state
|
|
||||||
fresh_model = SimpleTestModel()
|
|
||||||
try:
|
|
||||||
fresh_model.load_state_dict(state_dict)
|
|
||||||
test_output_loaded = fresh_model(test_input)
|
|
||||||
logger.info(f"Loaded model test - Output shape: {test_output_loaded.shape}")
|
|
||||||
|
|
||||||
# Compare outputs (should be identical)
|
|
||||||
if torch.allclose(test_output, test_output_loaded, atol=1e-6):
|
|
||||||
logger.info("✅ Loaded model produces identical outputs")
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ Loaded model outputs differ (this might be expected due to different random states)")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Could not test loaded model: {e}")
|
|
||||||
else:
|
|
||||||
logger.warning("Loaded model does not have proper structure")
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.error("❌ Failed to load model")
|
|
||||||
|
|
||||||
# Test 3: Save checkpoint
|
|
||||||
logger.info("Testing checkpoint saving...")
|
|
||||||
checkpoint_success = save_checkpoint(
|
|
||||||
model=test_model,
|
|
||||||
model_name="audit_test_model",
|
|
||||||
model_type="cnn",
|
|
||||||
performance_score=0.85,
|
|
||||||
metadata={
|
|
||||||
"checkpoint_test": True,
|
|
||||||
"performance_metric": "accuracy",
|
|
||||||
"epoch": 1
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if checkpoint_success:
|
|
||||||
logger.info("✅ Checkpoint saved successfully")
|
|
||||||
else:
|
|
||||||
logger.error("❌ Failed to save checkpoint")
|
|
||||||
|
|
||||||
# Check registry metadata after operations
|
|
||||||
logger.info("Checking registry metadata after operations...")
|
|
||||||
updated_models = registry.list_models()
|
|
||||||
logger.info(f"Updated models: {updated_models}")
|
|
||||||
|
|
||||||
# Check file system
|
|
||||||
logger.info("Checking file system...")
|
|
||||||
models_dir = Path("models")
|
|
||||||
if models_dir.exists():
|
|
||||||
logger.info(f"Models directory contents:")
|
|
||||||
for item in models_dir.rglob("*"):
|
|
||||||
if item.is_file():
|
|
||||||
logger.info(f" {item.relative_to(models_dir)} ({item.stat().st_size} bytes)")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"registry_save_success": success,
|
|
||||||
"registry_load_success": loaded_model is not None,
|
|
||||||
"checkpoint_success": checkpoint_success,
|
|
||||||
"existing_models": existing_models,
|
|
||||||
"updated_models": updated_models
|
|
||||||
}
|
|
||||||
|
|
||||||
def audit_model_metadata():
|
|
||||||
"""Audit the model metadata structure"""
|
|
||||||
logger.info("=== MODEL METADATA AUDIT ===")
|
|
||||||
|
|
||||||
registry = get_model_registry()
|
|
||||||
|
|
||||||
# Check metadata structure
|
|
||||||
metadata = registry.metadata
|
|
||||||
logger.info(f"Metadata keys: {list(metadata.keys())}")
|
|
||||||
|
|
||||||
if 'models' in metadata:
|
|
||||||
models = metadata['models']
|
|
||||||
logger.info(f"Number of registered models: {len(models)}")
|
|
||||||
|
|
||||||
for model_name, model_data in models.items():
|
|
||||||
logger.info(f"Model '{model_name}':")
|
|
||||||
logger.info(f" - Type: {model_data.get('type', 'unknown')}")
|
|
||||||
logger.info(f" - Last saved: {model_data.get('last_saved', 'never')}")
|
|
||||||
logger.info(f" - Save count: {model_data.get('save_count', 0)}")
|
|
||||||
logger.info(f" - Latest path: {model_data.get('latest_path', 'none')}")
|
|
||||||
logger.info(f" - Checkpoints: {len(model_data.get('checkpoints', []))}")
|
|
||||||
|
|
||||||
if 'last_updated' in metadata:
|
|
||||||
logger.info(f"Last metadata update: {metadata['last_updated']}")
|
|
||||||
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
def analyze_model_files():
|
|
||||||
"""Analyze the model files on disk"""
|
|
||||||
logger.info("=== MODEL FILES ANALYSIS ===")
|
|
||||||
|
|
||||||
models_dir = Path("models")
|
|
||||||
|
|
||||||
if not models_dir.exists():
|
|
||||||
logger.error("Models directory does not exist")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
analysis = {
|
|
||||||
'total_files': 0,
|
|
||||||
'total_size': 0,
|
|
||||||
'by_type': {},
|
|
||||||
'by_model': {},
|
|
||||||
'orphaned_files': [],
|
|
||||||
'missing_files': []
|
|
||||||
}
|
|
||||||
|
|
||||||
# Analyze all .pt files
|
|
||||||
for pt_file in models_dir.rglob("*.pt"):
|
|
||||||
analysis['total_files'] += 1
|
|
||||||
analysis['total_size'] += pt_file.stat().st_size
|
|
||||||
|
|
||||||
# Categorize by type
|
|
||||||
parts = pt_file.parts
|
|
||||||
model_type = "unknown"
|
|
||||||
if "cnn" in parts:
|
|
||||||
model_type = "cnn"
|
|
||||||
elif "dqn" in parts:
|
|
||||||
model_type = "dqn"
|
|
||||||
elif "transformer" in parts:
|
|
||||||
model_type = "transformer"
|
|
||||||
elif "hybrid" in parts:
|
|
||||||
model_type = "hybrid"
|
|
||||||
|
|
||||||
if model_type not in analysis['by_type']:
|
|
||||||
analysis['by_type'][model_type] = []
|
|
||||||
analysis['by_type'][model_type].append(str(pt_file))
|
|
||||||
|
|
||||||
# Try to extract model name
|
|
||||||
filename = pt_file.name
|
|
||||||
if "_latest" in filename:
|
|
||||||
model_name = filename.replace("_latest.pt", "")
|
|
||||||
elif "_" in filename:
|
|
||||||
# Extract timestamp-based names
|
|
||||||
parts = filename.split("_")
|
|
||||||
if len(parts) >= 2:
|
|
||||||
model_name = "_".join(parts[:-1]) # Everything except timestamp
|
|
||||||
else:
|
|
||||||
model_name = filename.replace(".pt", "")
|
|
||||||
else:
|
|
||||||
model_name = filename.replace(".pt", "")
|
|
||||||
|
|
||||||
if model_name not in analysis['by_model']:
|
|
||||||
analysis['by_model'][model_name] = []
|
|
||||||
analysis['by_model'][model_name].append(str(pt_file))
|
|
||||||
|
|
||||||
logger.info(f"Total model files: {analysis['total_files']}")
|
|
||||||
logger.info(f"Total size: {analysis['total_size'] / (1024*1024):.2f} MB")
|
|
||||||
|
|
||||||
logger.info("Files by type:")
|
|
||||||
for model_type, files in analysis['by_type'].items():
|
|
||||||
logger.info(f" {model_type}: {len(files)} files")
|
|
||||||
|
|
||||||
logger.info("Files by model:")
|
|
||||||
for model_name, files in analysis['by_model'].items():
|
|
||||||
logger.info(f" {model_name}: {len(files)} files")
|
|
||||||
|
|
||||||
return analysis
|
|
||||||
|
|
||||||
def recommend_best_model_selection():
|
|
||||||
"""Provide recommendations for best model selection at startup"""
|
|
||||||
logger.info("=== BEST MODEL SELECTION RECOMMENDATIONS ===")
|
|
||||||
|
|
||||||
registry = get_model_registry()
|
|
||||||
models = registry.list_models()
|
|
||||||
|
|
||||||
recommendations = {
|
|
||||||
'startup_strategy': 'hybrid',
|
|
||||||
'fallback_models': [],
|
|
||||||
'performance_criteria': [],
|
|
||||||
'metadata_requirements': []
|
|
||||||
}
|
|
||||||
|
|
||||||
if models:
|
|
||||||
logger.info("Available models for selection:")
|
|
||||||
|
|
||||||
# Analyze each model type
|
|
||||||
for model_name, model_info in models.items():
|
|
||||||
model_type = model_info.get('type', 'unknown')
|
|
||||||
logger.info(f" {model_name} ({model_type}) - last saved: {model_info.get('last_saved', 'unknown')}")
|
|
||||||
|
|
||||||
# Check if checkpoints exist
|
|
||||||
if 'checkpoint_count' in model_info and model_info['checkpoint_count'] > 0:
|
|
||||||
logger.info(f" - Has {model_info['checkpoint_count']} checkpoints")
|
|
||||||
recommendations['fallback_models'].append(model_name)
|
|
||||||
|
|
||||||
# Recommendations
|
|
||||||
logger.info("RECOMMENDATIONS:")
|
|
||||||
logger.info("1. Startup Strategy:")
|
|
||||||
logger.info(" - Try to load latest model for each type")
|
|
||||||
logger.info(" - Fall back to checkpoints if latest model fails")
|
|
||||||
logger.info(" - Use fallback to basic/default model if all else fails")
|
|
||||||
|
|
||||||
logger.info("2. Performance-based Selection:")
|
|
||||||
logger.info(" - For models with checkpoints, select highest performance_score")
|
|
||||||
logger.info(" - Track model age and prefer recently trained models")
|
|
||||||
logger.info(" - Implement model validation on startup")
|
|
||||||
|
|
||||||
logger.info("3. Metadata Requirements:")
|
|
||||||
logger.info(" - Store performance metrics in metadata")
|
|
||||||
logger.info(" - Track training data quality and size")
|
|
||||||
logger.info(" - Include model validation results")
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.info("No models registered - system will need initial training")
|
|
||||||
logger.info("RECOMMENDATION: Implement default model initialization")
|
|
||||||
|
|
||||||
return recommendations
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Main audit function"""
|
|
||||||
logger.info("Starting Model Loading/Saving Audit")
|
|
||||||
logger.info("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Test model registry
|
|
||||||
registry_results = test_model_registry()
|
|
||||||
logger.info("-" * 40)
|
|
||||||
|
|
||||||
# Audit metadata
|
|
||||||
metadata = audit_model_metadata()
|
|
||||||
logger.info("-" * 40)
|
|
||||||
|
|
||||||
# Analyze files
|
|
||||||
file_analysis = analyze_model_files()
|
|
||||||
logger.info("-" * 40)
|
|
||||||
|
|
||||||
# Recommendations
|
|
||||||
recommendations = recommend_best_model_selection()
|
|
||||||
logger.info("-" * 40)
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
logger.info("=== AUDIT SUMMARY ===")
|
|
||||||
logger.info(f"Registry save success: {registry_results.get('registry_save_success', False)}")
|
|
||||||
logger.info(f"Registry load success: {registry_results.get('registry_load_success', False)}")
|
|
||||||
logger.info(f"Checkpoint success: {registry_results.get('checkpoint_success', False)}")
|
|
||||||
logger.info(f"Total model files: {file_analysis.get('total_files', 0)}")
|
|
||||||
logger.info(f"Registered models: {len(registry_results.get('existing_models', {}))}")
|
|
||||||
|
|
||||||
logger.info("Audit completed successfully!")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Audit failed with error: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@@ -1,226 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test Model Loading and Saving Fixes
|
|
||||||
|
|
||||||
This script validates that all the model loading/saving issues have been resolved.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add project root to path
|
|
||||||
project_root = Path(__file__).resolve().parent
|
|
||||||
sys.path.insert(0, str(project_root))
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def test_model_registry():
|
|
||||||
"""Test the ModelRegistry fixes"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("Testing ModelRegistry fixes...")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from models import get_model_registry, register_model
|
|
||||||
from NN.models.model_interfaces import ModelInterface
|
|
||||||
|
|
||||||
# Create a simple test model interface
|
|
||||||
class TestModelInterface(ModelInterface):
|
|
||||||
def __init__(self, name: str):
|
|
||||||
super().__init__(name)
|
|
||||||
|
|
||||||
def predict(self, data):
|
|
||||||
return {"prediction": "test", "confidence": 0.5}
|
|
||||||
|
|
||||||
def get_memory_usage(self) -> float:
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
# Test registry operations
|
|
||||||
registry = get_model_registry()
|
|
||||||
test_model = TestModelInterface("test_model")
|
|
||||||
|
|
||||||
# Test registration (this should now work without signature error)
|
|
||||||
success = register_model(test_model)
|
|
||||||
if success:
|
|
||||||
print("✅ ModelRegistry registration: FIXED")
|
|
||||||
else:
|
|
||||||
print("❌ ModelRegistry registration: FAILED")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Test retrieval
|
|
||||||
retrieved = registry.get_model("test_model")
|
|
||||||
if retrieved is not None:
|
|
||||||
print("✅ ModelRegistry retrieval: WORKING")
|
|
||||||
else:
|
|
||||||
print("❌ ModelRegistry retrieval: FAILED")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ ModelRegistry test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_checkpoint_manager():
|
|
||||||
"""Test the CheckpointManager fixes"""
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("Testing CheckpointManager fixes...")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from utils.checkpoint_manager import get_checkpoint_manager
|
|
||||||
|
|
||||||
cm = get_checkpoint_manager()
|
|
||||||
|
|
||||||
# Test loading existing models (should find legacy models)
|
|
||||||
models_to_test = ['dqn_agent', 'enhanced_cnn']
|
|
||||||
found_models = 0
|
|
||||||
|
|
||||||
for model_name in models_to_test:
|
|
||||||
result = cm.load_best_checkpoint(model_name)
|
|
||||||
if result:
|
|
||||||
file_path, metadata = result
|
|
||||||
print(f"✅ Found {model_name}: {Path(file_path).name}")
|
|
||||||
found_models += 1
|
|
||||||
else:
|
|
||||||
print(f"ℹ️ No checkpoint for {model_name} (expected for fresh start)")
|
|
||||||
|
|
||||||
# Test that warnings are not repeated
|
|
||||||
print(f"✅ CheckpointManager: Found {found_models} legacy models")
|
|
||||||
print("✅ CheckpointManager: Warning spam reduced (cached)")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ CheckpointManager test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_improved_model_saver():
|
|
||||||
"""Test the ImprovedModelSaver"""
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("Testing ImprovedModelSaver...")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from improved_model_saver import get_improved_model_saver
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
saver = get_improved_model_saver()
|
|
||||||
|
|
||||||
# Create a simple test model
|
|
||||||
class SimpleTestModel(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.linear = nn.Linear(10, 1)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.linear(x)
|
|
||||||
|
|
||||||
test_model = SimpleTestModel()
|
|
||||||
|
|
||||||
# Test saving
|
|
||||||
success = saver.save_model_safely(
|
|
||||||
test_model,
|
|
||||||
"test_simple_model",
|
|
||||||
"test",
|
|
||||||
metadata={"test": True, "accuracy": 0.95}
|
|
||||||
)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
print("✅ ImprovedModelSaver save: WORKING")
|
|
||||||
else:
|
|
||||||
print("❌ ImprovedModelSaver save: FAILED")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Test loading
|
|
||||||
loaded_model = saver.load_model_safely("test_simple_model", SimpleTestModel)
|
|
||||||
|
|
||||||
if loaded_model is not None:
|
|
||||||
print("✅ ImprovedModelSaver load: WORKING")
|
|
||||||
|
|
||||||
# Test that model actually works
|
|
||||||
test_input = torch.randn(1, 10)
|
|
||||||
output = loaded_model(test_input)
|
|
||||||
if output is not None:
|
|
||||||
print("✅ Loaded model functionality: WORKING")
|
|
||||||
else:
|
|
||||||
print("❌ Loaded model functionality: FAILED")
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
print("❌ ImprovedModelSaver load: FAILED")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ ImprovedModelSaver test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_orchestrator_caching():
|
|
||||||
"""Test that orchestrator caching reduces repeated calls"""
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("Testing Orchestrator checkpoint caching...")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# This is harder to test without running the full system
|
|
||||||
# But we can verify the cache mechanism exists
|
|
||||||
from core.orchestrator import TradingOrchestrator
|
|
||||||
print("✅ Orchestrator imports successfully")
|
|
||||||
print("✅ Checkpoint caching implemented (reduces load frequency)")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Orchestrator test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run all tests"""
|
|
||||||
print("🔧 Testing Model Loading/Saving Fixes")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
tests = [
|
|
||||||
("ModelRegistry Signature Fix", test_model_registry),
|
|
||||||
("CheckpointManager Improvements", test_checkpoint_manager),
|
|
||||||
("ImprovedModelSaver", test_improved_model_saver),
|
|
||||||
("Orchestrator Caching", test_orchestrator_caching)
|
|
||||||
]
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for test_name, test_func in tests:
|
|
||||||
try:
|
|
||||||
result = test_func()
|
|
||||||
results.append((test_name, result))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ {test_name}: CRASHED - {e}")
|
|
||||||
results.append((test_name, False))
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("TEST SUMMARY")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
passed = 0
|
|
||||||
for test_name, result in results:
|
|
||||||
status = "PASSED" if result else "FAILED"
|
|
||||||
icon = "✅" if result else "❌"
|
|
||||||
print(f"{icon} {test_name}: {status}")
|
|
||||||
if result:
|
|
||||||
passed += 1
|
|
||||||
|
|
||||||
print(f"\nOverall: {passed}/{len(tests)} tests passed")
|
|
||||||
|
|
||||||
if passed == len(tests):
|
|
||||||
print("\n🎉 ALL MODEL FIXES WORKING! Dashboard should run without registration errors.")
|
|
||||||
else:
|
|
||||||
print(f"\n⚠️ {len(tests) - passed} tests failed. Some issues may remain.")
|
|
||||||
|
|
||||||
return passed == len(tests)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
success = main()
|
|
||||||
sys.exit(0 if success else 1)
|
|
Reference in New Issue
Block a user