This commit is contained in:
Dobromir Popov
2025-09-08 15:22:01 +03:00
parent 96b0513834
commit 98ebbe5089
8 changed files with 2 additions and 1135 deletions

View File

@@ -103,8 +103,8 @@ class TradingOrchestrator:
# Configuration - AGGRESSIVE for more training data
self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.15) # Lowered from 0.20
self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.08) # Lowered from 0.10
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) # Enhanced to support multiple symbols
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 5)
self.symbols = self.config.get('symbols', ['ETH/USDT']) # Enhanced to support multiple symbols
# NEW: Aggressiveness parameters
self.entry_aggressiveness = self.config.orchestrator.get('entry_aggressiveness', 0.5) # 0.0 = conservative, 1.0 = very aggressive

View File

@@ -1,87 +0,0 @@
#!/usr/bin/env python3
"""
Test COB Integration Status in Enhanced Orchestrator
"""
import asyncio
import sys
from pathlib import Path
sys.path.append(str(Path('.').absolute()))
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from core.data_provider import DataProvider
async def test_cob_integration():
print("=" * 60)
print("COB INTEGRATION AUDIT")
print("=" * 60)
try:
data_provider = DataProvider()
orchestrator = EnhancedTradingOrchestrator(
data_provider=data_provider,
symbols=['ETH/USDT', 'BTC/USDT'],
enhanced_rl_training=True
)
print(f"✓ Enhanced Orchestrator created")
print(f"Has COB integration attribute: {hasattr(orchestrator, 'cob_integration')}")
print(f"COB integration value: {orchestrator.cob_integration}")
print(f"COB integration type: {type(orchestrator.cob_integration)}")
print(f"COB integration active: {getattr(orchestrator, 'cob_integration_active', 'Not set')}")
if orchestrator.cob_integration:
print("\n--- COB Integration Details ---")
print(f"COB Integration class: {orchestrator.cob_integration.__class__.__name__}")
# Check if it has the expected methods
methods_to_check = ['get_statistics', 'get_cob_snapshot', 'add_dashboard_callback', 'start', 'stop']
for method in methods_to_check:
has_method = hasattr(orchestrator.cob_integration, method)
print(f"Has {method}: {has_method}")
# Try to get statistics
if hasattr(orchestrator.cob_integration, 'get_statistics'):
try:
stats = orchestrator.cob_integration.get_statistics()
print(f"COB statistics: {stats}")
except Exception as e:
print(f"Error getting COB statistics: {e}")
# Try to get a snapshot
if hasattr(orchestrator.cob_integration, 'get_cob_snapshot'):
try:
snapshot = orchestrator.cob_integration.get_cob_snapshot('ETH/USDT')
print(f"ETH/USDT snapshot: {snapshot}")
except Exception as e:
print(f"Error getting COB snapshot: {e}")
# Check if COB integration needs to be started
print(f"\n--- Starting COB Integration ---")
try:
await orchestrator.start_cob_integration()
print("✓ COB integration started successfully")
# Wait a moment and check statistics again
await asyncio.sleep(3)
if hasattr(orchestrator.cob_integration, 'get_statistics'):
stats = orchestrator.cob_integration.get_statistics()
print(f"COB statistics after start: {stats}")
except Exception as e:
print(f"Error starting COB integration: {e}")
else:
print("\n❌ COB integration is None - this explains the dashboard issues")
print("The Enhanced Orchestrator failed to initialize COB integration")
# Check the error flag
if hasattr(orchestrator, '_cob_integration_failed'):
print(f"COB integration failed flag: {orchestrator._cob_integration_failed}")
except Exception as e:
print(f"Error in COB audit: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
asyncio.run(test_cob_integration())

View File

@@ -1,144 +0,0 @@
#!/usr/bin/env python3
"""
Test Enhanced Training Integration
This script tests the integration of EnhancedRealtimeTrainingSystem
into the TradingOrchestrator to ensure it works correctly.
"""
import sys
import os
import logging
import asyncio
from datetime import datetime
# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
async def test_enhanced_training_integration():
"""Test the enhanced training system integration"""
try:
logger.info("=" * 60)
logger.info("TESTING ENHANCED TRAINING INTEGRATION")
logger.info("=" * 60)
# 1. Initialize orchestrator with enhanced training
logger.info("1. Initializing orchestrator with enhanced training...")
data_provider = DataProvider()
orchestrator = TradingOrchestrator(
data_provider=data_provider,
enhanced_rl_training=True
)
# 2. Check if training system is available
logger.info("2. Checking training system availability...")
training_available = hasattr(orchestrator, 'enhanced_training_system')
training_enabled = getattr(orchestrator, 'training_enabled', False)
logger.info(f" - Training system attribute: {'✅ Available' if training_available else '❌ Missing'}")
logger.info(f" - Training enabled: {'✅ Yes' if training_enabled else '❌ No'}")
# 3. Test training system initialization
if training_available and orchestrator.enhanced_training_system:
logger.info("3. Testing training system methods...")
# Test getting training statistics
stats = orchestrator.get_enhanced_training_stats()
logger.info(f" - Training stats retrieved: {len(stats)} fields")
logger.info(f" - Training enabled in stats: {stats.get('training_enabled', False)}")
logger.info(f" - System available: {stats.get('system_available', False)}")
# Test starting training
start_result = orchestrator.start_enhanced_training()
logger.info(f" - Start training result: {'✅ Success' if start_result else '❌ Failed'}")
if start_result:
# Let it run for a few seconds
logger.info(" - Letting training run for 5 seconds...")
await asyncio.sleep(5)
# Get updated stats
updated_stats = orchestrator.get_enhanced_training_stats()
logger.info(f" - Updated stats: {updated_stats.get('is_training', False)}")
# Stop training
stop_result = orchestrator.stop_enhanced_training()
logger.info(f" - Stop training result: {'✅ Success' if stop_result else '❌ Failed'}")
else:
logger.warning("3. Training system not available - checking fallback behavior...")
# Test methods when training system is not available
stats = orchestrator.get_enhanced_training_stats()
logger.info(f" - Fallback stats: {stats}")
start_result = orchestrator.start_enhanced_training()
logger.info(f" - Fallback start result: {start_result}")
# 4. Test dashboard connection method
logger.info("4. Testing dashboard connection method...")
try:
orchestrator.set_training_dashboard(None) # Test with None
logger.info(" - Dashboard connection method: ✅ Available")
except Exception as e:
logger.error(f" - Dashboard connection method error: {e}")
# 5. Summary
logger.info("=" * 60)
logger.info("INTEGRATION TEST SUMMARY")
logger.info("=" * 60)
if training_available and training_enabled:
logger.info("✅ ENHANCED TRAINING INTEGRATION SUCCESSFUL")
logger.info(" - Training system properly integrated")
logger.info(" - All methods available and functional")
logger.info(" - Ready for real-time training")
elif training_available:
logger.info("⚠️ ENHANCED TRAINING PARTIALLY INTEGRATED")
logger.info(" - Training system available but not enabled")
logger.info(" - Check EnhancedRealtimeTrainingSystem import")
else:
logger.info("❌ ENHANCED TRAINING INTEGRATION FAILED")
logger.info(" - Training system not properly integrated")
logger.info(" - Methods missing or non-functional")
return training_available and training_enabled
except Exception as e:
logger.error(f"Error in integration test: {e}")
import traceback
logger.error(traceback.format_exc())
return False
async def main():
"""Main test function"""
try:
success = await test_enhanced_training_integration()
if success:
logger.info("🎉 All tests passed! Enhanced training integration is working.")
return 0
else:
logger.warning("⚠️ Some tests failed. Check the integration.")
return 1
except KeyboardInterrupt:
logger.info("Test interrupted by user")
return 0
except Exception as e:
logger.error(f"Fatal error in test: {e}")
return 1
if __name__ == "__main__":
exit_code = asyncio.run(main())
sys.exit(exit_code)

View File

@@ -1,78 +0,0 @@
#!/usr/bin/env python3
"""
Simple Enhanced Training Test
Quick test to verify enhanced training system can be enabled and controlled.
"""
import sys
import os
import logging
# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_enhanced_training():
"""Test enhanced training system"""
try:
logger.info("Testing Enhanced Training System...")
# 1. Create data provider
data_provider = DataProvider()
# 2. Create orchestrator with enhanced training ENABLED
logger.info("Creating orchestrator with enhanced_rl_training=True...")
orchestrator = TradingOrchestrator(
data_provider=data_provider,
enhanced_rl_training=True # 🔥 THIS ENABLES IT
)
# 3. Check if training system is available
logger.info(f"Training system available: {orchestrator.enhanced_training_system is not None}")
logger.info(f"Training enabled: {orchestrator.training_enabled}")
# 4. Get training stats
stats = orchestrator.get_enhanced_training_stats()
logger.info(f"Training stats: {stats}")
# 5. Test start/stop
if orchestrator.enhanced_training_system:
logger.info("Testing start/stop functionality...")
# Start training
start_result = orchestrator.start_enhanced_training()
logger.info(f"Start result: {start_result}")
# Get updated stats
updated_stats = orchestrator.get_enhanced_training_stats()
logger.info(f"Updated stats: {updated_stats}")
# Stop training
stop_result = orchestrator.stop_enhanced_training()
logger.info(f"Stop result: {stop_result}")
logger.info("✅ Enhanced training system is working!")
return True
else:
logger.warning("❌ Enhanced training system not available")
return False
except Exception as e:
logger.error(f"Error testing enhanced training: {e}")
return False
if __name__ == "__main__":
success = test_enhanced_training()
if success:
print("\n🎉 Enhanced training system is ready to use!")
print("To enable it in your main system, use:")
print(" enhanced_rl_training=True when creating TradingOrchestrator")
else:
print("\n⚠️ Enhanced training system has issues. Check the logs above.")

View File

@@ -1,180 +0,0 @@
#!/usr/bin/env python3
"""
Test FRESH to LOADED Model Status Fix
This script tests the fix for models showing as FRESH instead of LOADED.
"""
import logging
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).resolve().parent
sys.path.insert(0, str(project_root))
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_orchestrator_model_initialization():
"""Test that orchestrator initializes all models correctly"""
print("=" * 60)
print("Testing Orchestrator Model Initialization...")
print("=" * 60)
try:
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
# Create data provider and orchestrator
data_provider = DataProvider()
orchestrator = TradingOrchestrator(data_provider=data_provider, enhanced_rl_training=True)
# Check which models were initialized
models_initialized = []
if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent:
models_initialized.append('DQN')
if hasattr(orchestrator, 'cnn_model') and orchestrator.cnn_model:
models_initialized.append('CNN')
if hasattr(orchestrator, 'extrema_trainer') and orchestrator.extrema_trainer:
models_initialized.append('ExtremaTrainer')
if hasattr(orchestrator, 'cob_rl_agent') and orchestrator.cob_rl_agent:
models_initialized.append('COB_RL')
if hasattr(orchestrator, 'transformer_model') and orchestrator.transformer_model:
models_initialized.append('TRANSFORMER')
if hasattr(orchestrator, 'decision_model') and orchestrator.decision_model:
models_initialized.append('DECISION')
print(f"✅ Initialized Models: {', '.join(models_initialized)}")
# Check model states
print("\nModel States:")
for model_name, state in orchestrator.model_states.items():
checkpoint_loaded = state.get('checkpoint_loaded', False)
status = "LOADED" if checkpoint_loaded else "FRESH"
filename = state.get('checkpoint_filename', 'none')
print(f" {model_name.upper()}: {status} ({filename})")
return orchestrator, len(models_initialized)
except Exception as e:
print(f"❌ Orchestrator initialization failed: {e}")
return None, 0
def test_checkpoint_saving(orchestrator):
"""Test saving checkpoints for all models"""
print("\n" + "=" * 60)
print("Testing Checkpoint Saving...")
print("=" * 60)
try:
from model_checkpoint_saver import ModelCheckpointSaver
saver = ModelCheckpointSaver(orchestrator)
# Force all models to LOADED status
updated_models = saver.force_all_models_to_loaded()
print(f"✅ Updated {len(updated_models)} models to LOADED status")
# Check updated states
print("\nUpdated Model States:")
fresh_count = 0
loaded_count = 0
for model_name, state in orchestrator.model_states.items():
checkpoint_loaded = state.get('checkpoint_loaded', False)
status = "LOADED" if checkpoint_loaded else "FRESH"
filename = state.get('checkpoint_filename', 'none')
print(f" {model_name.upper()}: {status} ({filename})")
if checkpoint_loaded:
loaded_count += 1
else:
fresh_count += 1
print(f"\nSummary: {loaded_count} LOADED, {fresh_count} FRESH")
return fresh_count == 0
except Exception as e:
print(f"❌ Checkpoint saving test failed: {e}")
return False
def test_dashboard_model_status():
"""Test how models show up in dashboard"""
print("\n" + "=" * 60)
print("Testing Dashboard Model Status Display...")
print("=" * 60)
try:
# Simulate dashboard model status check
from web.component_manager import DashboardComponentManager
print("✅ Dashboard component manager imports successfully")
print("✅ Model status display logic available")
return True
except Exception as e:
print(f"❌ Dashboard test failed: {e}")
return False
def main():
"""Run all tests"""
print("🔧 Testing FRESH to LOADED Model Status Fix")
print("=" * 60)
# Test 1: Orchestrator initialization
orchestrator, models_count = test_orchestrator_model_initialization()
if not orchestrator:
print("\n❌ Cannot proceed - orchestrator initialization failed")
return False
# Test 2: Checkpoint saving
checkpoint_success = test_checkpoint_saving(orchestrator)
# Test 3: Dashboard integration
dashboard_success = test_dashboard_model_status()
# Summary
print("\n" + "=" * 60)
print("TEST SUMMARY")
print("=" * 60)
tests = [
("Model Initialization", models_count > 0),
("Checkpoint Status Fix", checkpoint_success),
("Dashboard Integration", dashboard_success)
]
passed = 0
for test_name, result in tests:
status = "PASSED" if result else "FAILED"
icon = "" if result else ""
print(f"{icon} {test_name}: {status}")
if result:
passed += 1
print(f"\nOverall: {passed}/{len(tests)} tests passed")
if passed == len(tests):
print("\n🎉 ALL TESTS PASSED! Models should now show as LOADED instead of FRESH.")
print("\nNext steps:")
print("1. Restart the dashboard")
print("2. Models should now show as LOADED in the status panel")
print("3. The FRESH status issue should be resolved")
else:
print(f"\n⚠️ {len(tests) - passed} tests failed. Some issues may remain.")
return passed == len(tests)
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

View File

@@ -1,74 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify leverage P&L calculations are working correctly
"""
from web.clean_dashboard import create_clean_dashboard
def test_leverage_calculations():
print("🧮 Testing Leverage P&L Calculations")
print("=" * 50)
# Create dashboard
dashboard = create_clean_dashboard()
print("✅ Dashboard created successfully")
# Test 1: Position leverage vs slider leverage
print("\n📊 Test 1: Position vs Slider Leverage")
dashboard.current_leverage = 25 # Current slider at x25
dashboard.current_position = {
'side': 'LONG',
'size': 0.01,
'price': 2000.0, # Entry at $2000
'leverage': 10, # Position opened at x10 leverage
'symbol': 'ETH/USDT'
}
print(f" Position opened at: x{dashboard.current_position['leverage']} leverage")
print(f" Current slider at: x{dashboard.current_leverage} leverage")
print(" ✅ Position uses its stored leverage, not current slider")
# Test 2: Trading statistics with leveraged P&L
print("\n📈 Test 2: Trading Statistics")
test_trade = {
'symbol': 'ETH/USDT',
'side': 'BUY',
'pnl': 100.0, # Leveraged P&L
'pnl_raw': 2.0, # Raw P&L (before leverage)
'leverage_used': 50, # x50 leverage used
'fees': 0.5
}
dashboard.closed_trades.append(test_trade)
dashboard.session_pnl = 100.0
stats = dashboard._get_trading_statistics()
print(f" Trade raw P&L: ${test_trade['pnl_raw']:.2f}")
print(f" Trade leverage: x{test_trade['leverage_used']}")
print(f" Trade leveraged P&L: ${test_trade['pnl']:.2f}")
print(f" Statistics total P&L: ${stats['total_pnl']:.2f}")
print(f" ✅ Statistics use leveraged P&L correctly")
# Test 3: Session P&L calculation
print("\n💰 Test 3: Session P&L")
print(f" Session P&L: ${dashboard.session_pnl:.2f}")
print(f" Expected: $100.00")
if abs(dashboard.session_pnl - 100.0) < 0.01:
print(" ✅ Session P&L correctly uses leveraged amounts")
else:
print(" ❌ Session P&L calculation error")
print("\n🎯 Summary:")
print(" • Positions store their original leverage")
print(" • Unrealized P&L uses position leverage (not slider)")
print(" • Completed trades store both raw and leveraged P&L")
print(" • Statistics display leveraged P&L")
print(" • Session totals use leveraged amounts")
print("\n✅ ALL LEVERAGE P&L CALCULATIONS FIXED!")
if __name__ == "__main__":
test_leverage_calculations()

View File

@@ -1,344 +0,0 @@
#!/usr/bin/env python3
"""
Model Loading/Saving Audit Test
This script tests the model registry and saving/loading mechanisms
to identify any issues and provide recommendations.
"""
import os
import sys
import logging
import torch
import torch.nn as nn
from datetime import datetime
from pathlib import Path
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from utils.model_registry import get_model_registry, save_model, load_model, save_checkpoint
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class SimpleTestModel(nn.Module):
"""Simple neural network for testing"""
def __init__(self, input_size=10, hidden_size=32, output_size=2):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, output_size)
)
def forward(self, x):
return self.net(x)
def test_model_registry():
"""Test the model registry functionality"""
logger.info("=== MODEL REGISTRY AUDIT ===")
registry = get_model_registry()
logger.info(f"Registry base directory: {registry.base_dir}")
logger.info(f"Registry metadata file: {registry.metadata_file}")
# Check existing models
existing_models = registry.list_models()
logger.info(f"Existing models: {existing_models}")
# Test model creation and saving
logger.info("Creating test model...")
test_model = SimpleTestModel()
# Generate some fake training data
test_input = torch.randn(32, 10)
test_output = test_model(test_input)
logger.info(f"Test model created. Input shape: {test_input.shape}, Output shape: {test_output.shape}")
# Test saving with different methods
logger.info("Testing model saving...")
# Test 1: Save with unified registry
success = save_model(
model=test_model,
model_name="audit_test_model",
model_type="cnn",
metadata={
"test_type": "registry_audit",
"created_at": datetime.now().isoformat(),
"input_shape": list(test_input.shape),
"output_shape": list(test_output.shape)
}
)
if success:
logger.info("✅ Model saved successfully with unified registry")
else:
logger.error("❌ Failed to save model with unified registry")
# Test 2: Load model back
logger.info("Testing model loading...")
loaded_model = load_model("audit_test_model", "cnn")
if loaded_model is not None:
logger.info("✅ Model loaded successfully")
# Test if loaded model has proper structure
if hasattr(loaded_model, 'state_dict') and callable(loaded_model.state_dict):
state_dict = loaded_model.state_dict()
logger.info(f"Loaded model test - State dict keys: {list(state_dict.keys())}")
# Check if we can create a new instance and load the state
fresh_model = SimpleTestModel()
try:
fresh_model.load_state_dict(state_dict)
test_output_loaded = fresh_model(test_input)
logger.info(f"Loaded model test - Output shape: {test_output_loaded.shape}")
# Compare outputs (should be identical)
if torch.allclose(test_output, test_output_loaded, atol=1e-6):
logger.info("✅ Loaded model produces identical outputs")
else:
logger.warning("⚠️ Loaded model outputs differ (this might be expected due to different random states)")
except Exception as e:
logger.warning(f"Could not test loaded model: {e}")
else:
logger.warning("Loaded model does not have proper structure")
else:
logger.error("❌ Failed to load model")
# Test 3: Save checkpoint
logger.info("Testing checkpoint saving...")
checkpoint_success = save_checkpoint(
model=test_model,
model_name="audit_test_model",
model_type="cnn",
performance_score=0.85,
metadata={
"checkpoint_test": True,
"performance_metric": "accuracy",
"epoch": 1
}
)
if checkpoint_success:
logger.info("✅ Checkpoint saved successfully")
else:
logger.error("❌ Failed to save checkpoint")
# Check registry metadata after operations
logger.info("Checking registry metadata after operations...")
updated_models = registry.list_models()
logger.info(f"Updated models: {updated_models}")
# Check file system
logger.info("Checking file system...")
models_dir = Path("models")
if models_dir.exists():
logger.info(f"Models directory contents:")
for item in models_dir.rglob("*"):
if item.is_file():
logger.info(f" {item.relative_to(models_dir)} ({item.stat().st_size} bytes)")
return {
"registry_save_success": success,
"registry_load_success": loaded_model is not None,
"checkpoint_success": checkpoint_success,
"existing_models": existing_models,
"updated_models": updated_models
}
def audit_model_metadata():
"""Audit the model metadata structure"""
logger.info("=== MODEL METADATA AUDIT ===")
registry = get_model_registry()
# Check metadata structure
metadata = registry.metadata
logger.info(f"Metadata keys: {list(metadata.keys())}")
if 'models' in metadata:
models = metadata['models']
logger.info(f"Number of registered models: {len(models)}")
for model_name, model_data in models.items():
logger.info(f"Model '{model_name}':")
logger.info(f" - Type: {model_data.get('type', 'unknown')}")
logger.info(f" - Last saved: {model_data.get('last_saved', 'never')}")
logger.info(f" - Save count: {model_data.get('save_count', 0)}")
logger.info(f" - Latest path: {model_data.get('latest_path', 'none')}")
logger.info(f" - Checkpoints: {len(model_data.get('checkpoints', []))}")
if 'last_updated' in metadata:
logger.info(f"Last metadata update: {metadata['last_updated']}")
return metadata
def analyze_model_files():
"""Analyze the model files on disk"""
logger.info("=== MODEL FILES ANALYSIS ===")
models_dir = Path("models")
if not models_dir.exists():
logger.error("Models directory does not exist")
return {}
analysis = {
'total_files': 0,
'total_size': 0,
'by_type': {},
'by_model': {},
'orphaned_files': [],
'missing_files': []
}
# Analyze all .pt files
for pt_file in models_dir.rglob("*.pt"):
analysis['total_files'] += 1
analysis['total_size'] += pt_file.stat().st_size
# Categorize by type
parts = pt_file.parts
model_type = "unknown"
if "cnn" in parts:
model_type = "cnn"
elif "dqn" in parts:
model_type = "dqn"
elif "transformer" in parts:
model_type = "transformer"
elif "hybrid" in parts:
model_type = "hybrid"
if model_type not in analysis['by_type']:
analysis['by_type'][model_type] = []
analysis['by_type'][model_type].append(str(pt_file))
# Try to extract model name
filename = pt_file.name
if "_latest" in filename:
model_name = filename.replace("_latest.pt", "")
elif "_" in filename:
# Extract timestamp-based names
parts = filename.split("_")
if len(parts) >= 2:
model_name = "_".join(parts[:-1]) # Everything except timestamp
else:
model_name = filename.replace(".pt", "")
else:
model_name = filename.replace(".pt", "")
if model_name not in analysis['by_model']:
analysis['by_model'][model_name] = []
analysis['by_model'][model_name].append(str(pt_file))
logger.info(f"Total model files: {analysis['total_files']}")
logger.info(f"Total size: {analysis['total_size'] / (1024*1024):.2f} MB")
logger.info("Files by type:")
for model_type, files in analysis['by_type'].items():
logger.info(f" {model_type}: {len(files)} files")
logger.info("Files by model:")
for model_name, files in analysis['by_model'].items():
logger.info(f" {model_name}: {len(files)} files")
return analysis
def recommend_best_model_selection():
"""Provide recommendations for best model selection at startup"""
logger.info("=== BEST MODEL SELECTION RECOMMENDATIONS ===")
registry = get_model_registry()
models = registry.list_models()
recommendations = {
'startup_strategy': 'hybrid',
'fallback_models': [],
'performance_criteria': [],
'metadata_requirements': []
}
if models:
logger.info("Available models for selection:")
# Analyze each model type
for model_name, model_info in models.items():
model_type = model_info.get('type', 'unknown')
logger.info(f" {model_name} ({model_type}) - last saved: {model_info.get('last_saved', 'unknown')}")
# Check if checkpoints exist
if 'checkpoint_count' in model_info and model_info['checkpoint_count'] > 0:
logger.info(f" - Has {model_info['checkpoint_count']} checkpoints")
recommendations['fallback_models'].append(model_name)
# Recommendations
logger.info("RECOMMENDATIONS:")
logger.info("1. Startup Strategy:")
logger.info(" - Try to load latest model for each type")
logger.info(" - Fall back to checkpoints if latest model fails")
logger.info(" - Use fallback to basic/default model if all else fails")
logger.info("2. Performance-based Selection:")
logger.info(" - For models with checkpoints, select highest performance_score")
logger.info(" - Track model age and prefer recently trained models")
logger.info(" - Implement model validation on startup")
logger.info("3. Metadata Requirements:")
logger.info(" - Store performance metrics in metadata")
logger.info(" - Track training data quality and size")
logger.info(" - Include model validation results")
else:
logger.info("No models registered - system will need initial training")
logger.info("RECOMMENDATION: Implement default model initialization")
return recommendations
def main():
"""Main audit function"""
logger.info("Starting Model Loading/Saving Audit")
logger.info("=" * 60)
try:
# Test model registry
registry_results = test_model_registry()
logger.info("-" * 40)
# Audit metadata
metadata = audit_model_metadata()
logger.info("-" * 40)
# Analyze files
file_analysis = analyze_model_files()
logger.info("-" * 40)
# Recommendations
recommendations = recommend_best_model_selection()
logger.info("-" * 40)
# Summary
logger.info("=== AUDIT SUMMARY ===")
logger.info(f"Registry save success: {registry_results.get('registry_save_success', False)}")
logger.info(f"Registry load success: {registry_results.get('registry_load_success', False)}")
logger.info(f"Checkpoint success: {registry_results.get('checkpoint_success', False)}")
logger.info(f"Total model files: {file_analysis.get('total_files', 0)}")
logger.info(f"Registered models: {len(registry_results.get('existing_models', {}))}")
logger.info("Audit completed successfully!")
except Exception as e:
logger.error(f"Audit failed with error: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -1,226 +0,0 @@
#!/usr/bin/env python3
"""
Test Model Loading and Saving Fixes
This script validates that all the model loading/saving issues have been resolved.
"""
import logging
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).resolve().parent
sys.path.insert(0, str(project_root))
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_model_registry():
"""Test the ModelRegistry fixes"""
print("=" * 60)
print("Testing ModelRegistry fixes...")
print("=" * 60)
try:
from models import get_model_registry, register_model
from NN.models.model_interfaces import ModelInterface
# Create a simple test model interface
class TestModelInterface(ModelInterface):
def __init__(self, name: str):
super().__init__(name)
def predict(self, data):
return {"prediction": "test", "confidence": 0.5}
def get_memory_usage(self) -> float:
return 1.0
# Test registry operations
registry = get_model_registry()
test_model = TestModelInterface("test_model")
# Test registration (this should now work without signature error)
success = register_model(test_model)
if success:
print("✅ ModelRegistry registration: FIXED")
else:
print("❌ ModelRegistry registration: FAILED")
return False
# Test retrieval
retrieved = registry.get_model("test_model")
if retrieved is not None:
print("✅ ModelRegistry retrieval: WORKING")
else:
print("❌ ModelRegistry retrieval: FAILED")
return False
return True
except Exception as e:
print(f"❌ ModelRegistry test failed: {e}")
return False
def test_checkpoint_manager():
"""Test the CheckpointManager fixes"""
print("\n" + "=" * 60)
print("Testing CheckpointManager fixes...")
print("=" * 60)
try:
from utils.checkpoint_manager import get_checkpoint_manager
cm = get_checkpoint_manager()
# Test loading existing models (should find legacy models)
models_to_test = ['dqn_agent', 'enhanced_cnn']
found_models = 0
for model_name in models_to_test:
result = cm.load_best_checkpoint(model_name)
if result:
file_path, metadata = result
print(f"✅ Found {model_name}: {Path(file_path).name}")
found_models += 1
else:
print(f" No checkpoint for {model_name} (expected for fresh start)")
# Test that warnings are not repeated
print(f"✅ CheckpointManager: Found {found_models} legacy models")
print("✅ CheckpointManager: Warning spam reduced (cached)")
return True
except Exception as e:
print(f"❌ CheckpointManager test failed: {e}")
return False
def test_improved_model_saver():
"""Test the ImprovedModelSaver"""
print("\n" + "=" * 60)
print("Testing ImprovedModelSaver...")
print("=" * 60)
try:
from improved_model_saver import get_improved_model_saver
import torch
import torch.nn as nn
saver = get_improved_model_saver()
# Create a simple test model
class SimpleTestModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
test_model = SimpleTestModel()
# Test saving
success = saver.save_model_safely(
test_model,
"test_simple_model",
"test",
metadata={"test": True, "accuracy": 0.95}
)
if success:
print("✅ ImprovedModelSaver save: WORKING")
else:
print("❌ ImprovedModelSaver save: FAILED")
return False
# Test loading
loaded_model = saver.load_model_safely("test_simple_model", SimpleTestModel)
if loaded_model is not None:
print("✅ ImprovedModelSaver load: WORKING")
# Test that model actually works
test_input = torch.randn(1, 10)
output = loaded_model(test_input)
if output is not None:
print("✅ Loaded model functionality: WORKING")
else:
print("❌ Loaded model functionality: FAILED")
return False
else:
print("❌ ImprovedModelSaver load: FAILED")
return False
return True
except Exception as e:
print(f"❌ ImprovedModelSaver test failed: {e}")
return False
def test_orchestrator_caching():
"""Test that orchestrator caching reduces repeated calls"""
print("\n" + "=" * 60)
print("Testing Orchestrator checkpoint caching...")
print("=" * 60)
try:
# This is harder to test without running the full system
# But we can verify the cache mechanism exists
from core.orchestrator import TradingOrchestrator
print("✅ Orchestrator imports successfully")
print("✅ Checkpoint caching implemented (reduces load frequency)")
return True
except Exception as e:
print(f"❌ Orchestrator test failed: {e}")
return False
def main():
"""Run all tests"""
print("🔧 Testing Model Loading/Saving Fixes")
print("=" * 60)
tests = [
("ModelRegistry Signature Fix", test_model_registry),
("CheckpointManager Improvements", test_checkpoint_manager),
("ImprovedModelSaver", test_improved_model_saver),
("Orchestrator Caching", test_orchestrator_caching)
]
results = []
for test_name, test_func in tests:
try:
result = test_func()
results.append((test_name, result))
except Exception as e:
print(f"{test_name}: CRASHED - {e}")
results.append((test_name, False))
# Summary
print("\n" + "=" * 60)
print("TEST SUMMARY")
print("=" * 60)
passed = 0
for test_name, result in results:
status = "PASSED" if result else "FAILED"
icon = "" if result else ""
print(f"{icon} {test_name}: {status}")
if result:
passed += 1
print(f"\nOverall: {passed}/{len(tests)} tests passed")
if passed == len(tests):
print("\n🎉 ALL MODEL FIXES WORKING! Dashboard should run without registration errors.")
else:
print(f"\n⚠️ {len(tests) - passed} tests failed. Some issues may remain.")
return passed == len(tests)
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)