#!/usr/bin/env python3
"""
Test Enhanced Training Integration

This script tests the integration of EnhancedRealtimeTrainingSystem
into the TradingOrchestrator to ensure it works correctly.
"""
|
|
|
|
import sys
|
|
import os
|
|
import logging
|
|
import asyncio
|
|
from datetime import datetime
|
|
|
|
# Make the directory containing this script importable so that the
# project-local ``core`` package resolves when the script is run directly.
_PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_PROJECT_ROOT)
|
|
|
|
from core.orchestrator import TradingOrchestrator
|
|
from core.data_provider import DataProvider
|
|
|
|
# Root logging setup: INFO level, timestamped records tagged with the
# emitting logger's name and severity.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
|
|
|
|
async def test_enhanced_training_integration():
    """Exercise the enhanced-training hooks exposed by ``TradingOrchestrator``.

    The steps below are reported through the module logger:

    1. Build a ``TradingOrchestrator`` with ``enhanced_rl_training=True``.
    2. Check for the ``enhanced_training_system`` attribute and read the
       ``training_enabled`` flag (defaulting to ``False`` when absent).
    3. If the training system attribute exists and is truthy: read the stats,
       start training, let it run for five seconds, re-read the stats and
       stop training. Training that was started is always stopped, even if
       the wait or the stats call fails or the task is cancelled. Otherwise
       the stats/start methods are called to log their fallback results.
    4. Call ``set_training_dashboard(None)`` and log whether it raised.
    5. Log a summary.

    Returns:
        A truthy value when the orchestrator has the
        ``enhanced_training_system`` attribute and ``training_enabled`` is
        truthy; a falsy value otherwise. ``False`` when any step raises an
        ``Exception`` (the error is logged with its traceback).
    """
    try:
        logger.info("=" * 60)
        logger.info("TESTING ENHANCED TRAINING INTEGRATION")
        logger.info("=" * 60)

        # 1. Initialize orchestrator with enhanced training
        logger.info("1. Initializing orchestrator with enhanced training...")
        data_provider = DataProvider()
        orchestrator = TradingOrchestrator(
            data_provider=data_provider,
            enhanced_rl_training=True
        )

        # 2. Check if training system is available
        logger.info("2. Checking training system availability...")
        training_available = hasattr(orchestrator, 'enhanced_training_system')
        training_enabled = getattr(orchestrator, 'training_enabled', False)

        logger.info(f" - Training system attribute: {'✅ Available' if training_available else '❌ Missing'}")
        logger.info(f" - Training enabled: {'✅ Yes' if training_enabled else '❌ No'}")

        # 3. Test training system initialization
        if training_available and orchestrator.enhanced_training_system:
            logger.info("3. Testing training system methods...")

            # Test getting training statistics
            stats = orchestrator.get_enhanced_training_stats()
            logger.info(f" - Training stats retrieved: {len(stats)} fields")
            logger.info(f" - Training enabled in stats: {stats.get('training_enabled', False)}")
            logger.info(f" - System available: {stats.get('system_available', False)}")

            # Test starting training
            start_result = orchestrator.start_enhanced_training()
            logger.info(f" - Start training result: {'✅ Success' if start_result else '❌ Failed'}")

            if start_result:
                try:
                    # Let it run for a few seconds
                    logger.info(" - Letting training run for 5 seconds...")
                    await asyncio.sleep(5)

                    # Get updated stats
                    updated_stats = orchestrator.get_enhanced_training_stats()
                    logger.info(f" - Updated stats: {updated_stats.get('is_training', False)}")
                finally:
                    # Stop training. This runs in ``finally`` so that training
                    # we started is never left running when the wait or the
                    # stats call fails, or when the task is cancelled
                    # (e.g. on Ctrl-C).
                    stop_result = orchestrator.stop_enhanced_training()
                    logger.info(f" - Stop training result: {'✅ Success' if stop_result else '❌ Failed'}")

        else:
            logger.warning("3. Training system not available - checking fallback behavior...")

            # Test methods when training system is not available
            stats = orchestrator.get_enhanced_training_stats()
            logger.info(f" - Fallback stats: {stats}")

            start_result = orchestrator.start_enhanced_training()
            logger.info(f" - Fallback start result: {start_result}")

        # 4. Test dashboard connection method
        logger.info("4. Testing dashboard connection method...")
        try:
            orchestrator.set_training_dashboard(None)  # Test with None
            logger.info(" - Dashboard connection method: ✅ Available")
        except Exception as e:
            # A failure here is reported but does not abort the test.
            logger.error(f" - Dashboard connection method error: {e}")

        # 5. Summary
        logger.info("=" * 60)
        logger.info("INTEGRATION TEST SUMMARY")
        logger.info("=" * 60)

        if training_available and training_enabled:
            logger.info("✅ ENHANCED TRAINING INTEGRATION SUCCESSFUL")
            logger.info(" - Training system properly integrated")
            logger.info(" - All methods available and functional")
            logger.info(" - Ready for real-time training")
        elif training_available:
            logger.info("⚠️ ENHANCED TRAINING PARTIALLY INTEGRATED")
            logger.info(" - Training system available but not enabled")
            logger.info(" - Check EnhancedRealtimeTrainingSystem import")
        else:
            logger.info("❌ ENHANCED TRAINING INTEGRATION FAILED")
            logger.info(" - Training system not properly integrated")
            logger.info(" - Methods missing or non-functional")

        return training_available and training_enabled

    except Exception as e:
        # logger.exception logs at ERROR level and appends the active traceback.
        logger.exception(f"Error in integration test: {e}")
        return False
|
|
|
|
async def main():
    """Run the integration test and map its outcome to a process exit code.

    Returns:
        ``0`` when the test reports success or a ``KeyboardInterrupt`` is
        caught here; ``1`` when the test reports failure or any other
        ``Exception`` is raised.
    """
    try:
        if await test_enhanced_training_integration():
            logger.info("🎉 All tests passed! Enhanced training integration is working.")
            return 0

        logger.warning("⚠️ Some tests failed. Check the integration.")
        return 1
    except KeyboardInterrupt:
        logger.info("Test interrupted by user")
        return 0
    except Exception as exc:
        logger.error(f"Fatal error in test: {exc}")
        return 1
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async test and exit with its status code.
    try:
        exit_code = asyncio.run(main())
    except KeyboardInterrupt:
        # On Ctrl-C, asyncio.run() cancels the main task and re-raises
        # KeyboardInterrupt here rather than inside the coroutine, so the
        # interrupt is handled at the call site with the same message and
        # exit code that main() uses for interrupts.
        logger.info("Test interrupted by user")
        exit_code = 0
    sys.exit(exit_code)