409 lines
16 KiB
Python
409 lines
16 KiB
Python
#!/usr/bin/env python3
"""
Training System Validation

This script validates that the core training system is working correctly.
``TrainingSystemValidator`` runs six checks in order:

1. Data provider is supplying historical data
2. Orchestrator can be created and exposes the required methods
3. State building is working (13,400 features expected)
4. Reward calculation is functioning
5. Models can be loaded and make predictions
6. Training loop (trading decisions) can run without errors

Focus: Core functionality validation, not performance optimization
"""

import os
import sys
import asyncio
import logging
import numpy as np
from datetime import datetime
from pathlib import Path

# Add project root to path.
# Path(__file__).parent is the directory containing this script, which is
# assumed to be the project root. It must be on sys.path *before* the
# ``core.*`` imports below so they resolve when the script is run directly.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from core.config import setup_logging, get_config
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
# NOTE(review): TradingExecutor is not referenced by any check in this script.
from core.trading_executor import TradingExecutor

# Setup logging: call the project's logging helper once at import time,
# before the module-level logger is created.
setup_logging()
logger = logging.getLogger(__name__)
|
|
|
|
class TrainingSystemValidator:
    """
    Validates core training system functionality.

    ``run_validation`` executes six checks in order and records a boolean per
    check in ``validation_results``:

    1. data_provider      - historical candles can be fetched
    2. orchestrator       - TradingOrchestrator builds and exposes required methods
    3. state_building     - the comprehensive RL state can be built for ETH/USDT
    4. reward_calculation - the enhanced pivot reward can be computed on mock data
    5. model_loading      - RL / CNN models are present (and the RL agent can predict)
    6. training_loop      - coordinated and single-symbol trading decisions run

    Each check catches its own exceptions and marks itself failed, so one
    broken component does not abort the rest of the suite.
    """

    # Symbols and timeframes probed by the data-provider check.
    DATA_SYMBOLS = ('ETH/USDT', 'BTC/USDT')
    DATA_TIMEFRAMES = ('1m', '1h')
    # Number of features the comprehensive RL state is expected to contain.
    EXPECTED_STATE_SIZE = 13400
    # Orchestrator methods that must exist for the orchestrator check to pass.
    REQUIRED_ORCHESTRATOR_METHODS = (
        'make_trading_decision',
        'build_comprehensive_rl_state',
        'make_coordinated_decisions',
    )

    def __init__(self):
        """Load the project config and reset every check result to failed."""
        self.config = get_config()
        self.validation_results = {
            'data_provider': False,
            'orchestrator': False,
            'state_building': False,
            'reward_calculation': False,
            'model_loading': False,
            'training_loop': False,
        }

        # Components, created by the individual checks.
        self.data_provider = None
        self.orchestrator = None
        self.trading_executor = None

        logger.info("Training System Validator initialized")

    @staticmethod
    async def _maybe_await(value):
        """
        Return ``value``, awaiting it first if it is awaitable.

        Lets the checks call provider/orchestrator methods without depending on
        whether those methods are plain functions or coroutines.
        """
        import inspect

        if inspect.isawaitable(value):
            return await value
        return value

    async def run_validation(self):
        """
        Run the complete validation suite and log a summary report.

        Returns:
            float: fraction of checks that passed (0.0-1.0). The report is
            produced even if an unexpected error aborts the suite; checks that
            did not get to run count as failed.
        """
        logger.info("=" * 60)
        logger.info("TRAINING SYSTEM VALIDATION")
        logger.info("=" * 60)

        # Order matters: later checks reuse the data provider / orchestrator
        # created by the first two.
        checks = (
            self._validate_data_provider,       # 1. Data Provider
            self._validate_orchestrator,        # 2. Orchestrator
            self._validate_state_building,      # 3. State Building
            self._validate_reward_calculation,  # 4. Reward Calculation
            self._validate_model_loading,       # 5. Model Loading
            self._validate_training_loop,       # 6. Training Loop
        )

        try:
            for check in checks:
                await check()
        except Exception as e:
            # Each check traps its own errors; reaching this point means
            # something unexpected escaped. logger.exception adds the traceback.
            logger.exception(f"Validation failed: {e}")

        # Generate final report
        return self._generate_validation_report()

    async def _validate_data_provider(self):
        """
        Check that historical candles can be fetched for every probed pair.

        Every symbol/timeframe pair is tried (instead of stopping at the first
        gap) so the log lists all missing feeds. The check fails if any pair
        returned no data; real-time streaming support is informational only.
        """
        try:
            logger.info("[1/6] Validating Data Provider...")

            # Initialize data provider
            self.data_provider = DataProvider()

            # Test historical data fetching
            missing = []
            for symbol in self.DATA_SYMBOLS:
                for timeframe in self.DATA_TIMEFRAMES:
                    df = await self._maybe_await(
                        self.data_provider.get_historical_data(symbol, timeframe, limit=100)
                    )
                    if df is not None and not df.empty:
                        logger.info(f" ✓ {symbol} {timeframe}: {len(df)} candles")
                    else:
                        logger.warning(f" ✗ {symbol} {timeframe}: No data")
                        missing.append(f"{symbol} {timeframe}")

            if missing:
                self.validation_results['data_provider'] = False
                logger.error(f" ✗ Data Provider validation FAILED: no data for {', '.join(missing)}")
                return

            # Test real-time data capabilities (does not affect pass/fail)
            if hasattr(self.data_provider, 'start_real_time_streaming'):
                logger.info(" ✓ Real-time streaming available")
            else:
                logger.warning(" ✗ Real-time streaming not available")

            self.validation_results['data_provider'] = True
            logger.info(" ✓ Data Provider validation PASSED")

        except Exception as e:
            logger.error(f" ✗ Data Provider validation FAILED: {e}")
            self.validation_results['data_provider'] = False

    async def _validate_orchestrator(self):
        """
        Create the orchestrator and check that it exposes the required methods.

        ``self.orchestrator`` is assigned before the method check, so later
        checks still get an orchestrator even if this check fails. Missing
        RL/CNN models are logged as warnings only.
        """
        try:
            logger.info("[2/6] Validating Orchestrator...")

            # Initialize orchestrator (data_provider may be None if check 1 crashed)
            self.orchestrator = TradingOrchestrator(
                data_provider=self.data_provider,
                enhanced_rl_training=True,
            )

            # Check every required method so all gaps are reported at once.
            missing = [
                name for name in self.REQUIRED_ORCHESTRATOR_METHODS
                if not hasattr(self.orchestrator, name)
            ]
            for name in self.REQUIRED_ORCHESTRATOR_METHODS:
                if name in missing:
                    logger.warning(f" ✗ Method '{name}' missing")
                else:
                    logger.info(f" ✓ Method '{name}' available")

            if missing:
                self.validation_results['orchestrator'] = False
                logger.error(f" ✗ Orchestrator validation FAILED: missing {', '.join(missing)}")
                return

            # Check model initialization (informational)
            if getattr(self.orchestrator, 'rl_agent', None):
                logger.info(" ✓ RL Agent initialized")
            else:
                logger.warning(" ✗ RL Agent not initialized")

            if getattr(self.orchestrator, 'cnn_model', None):
                logger.info(" ✓ CNN Model initialized")
            else:
                logger.warning(" ✗ CNN Model not initialized")

            self.validation_results['orchestrator'] = True
            logger.info(" ✓ Orchestrator validation PASSED")

        except Exception as e:
            logger.error(f" ✗ Orchestrator validation FAILED: {e}")
            self.validation_results['orchestrator'] = False

    async def _validate_state_building(self):
        """
        Build the comprehensive RL state for ETH/USDT and sanity-check it.

        Fails if the orchestrator or its ``build_comprehensive_rl_state`` method
        is missing, or if the state is None or empty. Feature count and
        non-zero density are logged; a count other than EXPECTED_STATE_SIZE
        or low density is only a warning.
        """
        try:
            logger.info("[3/6] Validating State Building...")

            if not self.orchestrator:
                logger.error(" ✗ Orchestrator not available")
                return

            build_state = getattr(self.orchestrator, 'build_comprehensive_rl_state', None)
            if build_state is None:
                logger.error(" ✗ build_comprehensive_rl_state method not available")
                return

            # Test state building for ETH/USDT
            state = await self._maybe_await(build_state('ETH/USDT'))
            if state is None:
                logger.error(" ✗ State building returned None")
                return

            # Flatten so size and density count every feature even if the state
            # comes back batched / multi-dimensional (len() would only count
            # the first axis).
            features = np.asarray(state).ravel()
            state_size = features.size
            logger.info(f" ✓ ETH state built: {state_size} features")

            if state_size == 0:
                # Also guards the density division below.
                logger.error(" ✗ State building returned an empty state")
                return

            if state_size == self.EXPECTED_STATE_SIZE:
                logger.info(f" ✓ Perfect: Exactly {self.EXPECTED_STATE_SIZE:,} features as expected")
            elif state_size > 1000:
                logger.info(f" ✓ Good: {state_size} features (comprehensive)")
            else:
                logger.warning(f" ⚠ Limited: Only {state_size} features")

            # Analyze feature quality
            non_zero_features = int(np.count_nonzero(features))
            non_zero_percent = non_zero_features / state_size * 100
            logger.info(f" ✓ Non-zero features: {non_zero_features:,} ({non_zero_percent:.1f}%)")

            if non_zero_percent > 10:
                logger.info(" ✓ Good feature distribution")
            else:
                logger.warning(" ⚠ Low feature density - may indicate data issues")

            self.validation_results['state_building'] = True
            logger.info(" ✓ State Building validation PASSED")

        except Exception as e:
            logger.error(f" ✗ State Building validation FAILED: {e}")
            self.validation_results['state_building'] = False

    async def _validate_reward_calculation(self):
        """
        Exercise the orchestrator's enhanced pivot reward on mock inputs.

        A missing method or a None reward is logged as a warning; only an
        exception (or a missing orchestrator) fails the check.
        """
        try:
            logger.info("[4/6] Validating Reward Calculation...")

            if not self.orchestrator:
                logger.error(" ✗ Orchestrator not available")
                return

            calculate_reward = getattr(self.orchestrator, 'calculate_enhanced_pivot_reward', None)
            if calculate_reward is not None:
                # Mock inputs: a confident BUY that closed in profit.
                trade_decision = {
                    'action': 'BUY',
                    'confidence': 0.75,
                    'price': 2500.0,
                    'timestamp': datetime.now(),
                }
                market_data = {
                    'volatility': 0.03,
                    'order_flow_direction': 'bullish',
                    'order_flow_strength': 0.8,
                }
                trade_outcome = {
                    'net_pnl': 50.0,
                    'exit_price': 2550.0,
                }

                reward = await self._maybe_await(
                    calculate_reward(trade_decision, market_data, trade_outcome)
                )

                if reward is not None:
                    logger.info(f" ✓ Enhanced reward calculated: {reward:.3f}")
                else:
                    logger.warning(" ⚠ Enhanced reward calculation returned None")
            else:
                logger.warning(" ⚠ Enhanced reward calculation not available")

            # NOTE(review): placeholder - no basic reward function is actually
            # exercised here; this line is logged unconditionally.
            logger.info(" ✓ Basic reward calculation available")

            self.validation_results['reward_calculation'] = True
            logger.info(" ✓ Reward Calculation validation PASSED")

        except Exception as e:
            logger.error(f" ✗ Reward Calculation validation FAILED: {e}")
            self.validation_results['reward_calculation'] = False

    async def _validate_model_loading(self):
        """
        Check that the RL agent and CNN model are present.

        If the RL agent has ``predict``, it is called once on the real
        ETH/USDT state from ``orchestrator._get_rl_state``; a prediction error
        is logged as a warning and does not fail the check. Only an unexpected
        exception (or a missing orchestrator) fails it.
        """
        try:
            logger.info("[5/6] Validating Model Loading...")

            if not self.orchestrator:
                logger.error(" ✗ Orchestrator not available")
                return

            # Check RL Agent
            rl_agent = getattr(self.orchestrator, 'rl_agent', None)
            if rl_agent:
                logger.info(" ✓ RL Agent loaded")

                if hasattr(rl_agent, 'predict'):
                    try:
                        # Use real state from orchestrator instead of dummy data
                        real_state = await self._maybe_await(
                            self.orchestrator._get_rl_state('ETH/USDT')
                        )
                        if real_state is not None:
                            # Only the call's success matters; the output is not inspected.
                            await self._maybe_await(rl_agent.predict(real_state))
                            logger.info(" ✓ RL Agent can make predictions with real data")
                        else:
                            logger.warning(" ⚠ No real state available for RL prediction test")
                    except Exception as e:
                        logger.warning(f" ⚠ RL Agent prediction failed: {e}")
                else:
                    logger.warning(" ⚠ RL Agent predict method not available")
            else:
                logger.warning(" ⚠ RL Agent not loaded")

            # Check CNN Model
            cnn_model = getattr(self.orchestrator, 'cnn_model', None)
            if cnn_model:
                logger.info(" ✓ CNN Model loaded")

                # NOTE(review): only the presence of `predict` is checked; no
                # prediction is actually run for the CNN.
                if hasattr(cnn_model, 'predict'):
                    logger.info(" ✓ CNN Model can make predictions")
                else:
                    logger.warning(" ⚠ CNN Model predict method not available")
            else:
                logger.warning(" ⚠ CNN Model not loaded")

            self.validation_results['model_loading'] = True
            logger.info(" ✓ Model Loading validation PASSED")

        except Exception as e:
            logger.error(f" ✗ Model Loading validation FAILED: {e}")
            self.validation_results['model_loading'] = False

    async def _validate_training_loop(self):
        """
        Check that the orchestrator can produce trading decisions.

        Runs ``make_coordinated_decisions()`` (its result is iterated with
        ``.items()``, i.e. treated as a symbol -> decision mapping) and
        ``make_trading_decision('ETH/USDT')``. Missing methods and empty
        results are warnings; only an exception fails the check.
        """
        try:
            logger.info("[6/6] Validating Training Loop...")

            if not self.orchestrator:
                logger.error(" ✗ Orchestrator not available")
                return

            # Test making coordinated decisions
            make_coordinated = getattr(self.orchestrator, 'make_coordinated_decisions', None)
            if make_coordinated is not None:
                decisions = await self._maybe_await(make_coordinated())

                if decisions:
                    logger.info(f" ✓ Coordinated decisions made: {len(decisions)} symbols")
                    for symbol, decision in decisions.items():
                        if decision:
                            logger.info(f" - {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
                        else:
                            logger.info(f" - {symbol}: No decision")
                else:
                    logger.warning(" ⚠ No coordinated decisions made")
            else:
                logger.warning(" ⚠ make_coordinated_decisions method not available")

            # Test individual trading decision
            make_decision = getattr(self.orchestrator, 'make_trading_decision', None)
            if make_decision is not None:
                decision = await self._maybe_await(make_decision('ETH/USDT'))

                if decision:
                    logger.info(f" ✓ Trading decision made: {decision.action} (confidence: {decision.confidence:.3f})")
                else:
                    logger.info(" ✓ No trading decision (normal behavior)")
            else:
                logger.warning(" ⚠ make_trading_decision method not available")

            self.validation_results['training_loop'] = True
            logger.info(" ✓ Training Loop validation PASSED")

        except Exception as e:
            logger.error(f" ✗ Training Loop validation FAILED: {e}")
            self.validation_results['training_loop'] = False

    def _generate_validation_report(self):
        """
        Log a per-check PASS/FAIL summary and an overall verdict.

        The verdict is "ALL PASSED" at 100%, "MOSTLY PASSED" at >= 80%,
        otherwise "VALIDATION FAILED".

        Returns:
            float: passed checks / total checks.
        """
        logger.info("=" * 60)
        logger.info("VALIDATION REPORT")
        logger.info("=" * 60)

        passed_tests = sum(map(bool, self.validation_results.values()))
        total_tests = len(self.validation_results)

        logger.info(f"Tests Passed: {passed_tests}/{total_tests}")
        logger.info("")

        for test_name, result in self.validation_results.items():
            status = "✓ PASS" if result else "✗ FAIL"
            logger.info(f"{test_name.replace('_', ' ').title()}: {status}")

        logger.info("")

        if passed_tests == total_tests:
            logger.info("🎉 ALL VALIDATIONS PASSED - Training system is ready!")
        elif passed_tests >= total_tests * 0.8:
            logger.info("⚠️ MOSTLY PASSED - Training system is mostly functional")
        else:
            logger.error("❌ VALIDATION FAILED - Training system needs fixes")

        logger.info("=" * 60)

        return passed_tests / total_tests
|
|
|
|
async def main():
    """
    Main validation function: run the full validation suite.

    Returns:
        The value returned by ``TrainingSystemValidator.run_validation()``
        (the fraction of checks that passed, or None if it reports nothing),
        ``0.0`` if the suite raised unexpectedly, or ``None`` if interrupted.
    """
    try:
        validator = TrainingSystemValidator()
        return await validator.run_validation()

    except KeyboardInterrupt:
        logger.info("Validation interrupted by user")
        return None
    except Exception as e:
        # logger.exception logs the message together with the traceback.
        logger.exception(f"Validation error: {e}")
        return 0.0


if __name__ == "__main__":
    try:
        score = asyncio.run(main())
    except KeyboardInterrupt:
        # Under asyncio.run, Ctrl+C usually surfaces here rather than inside main().
        logger.info("Validation interrupted by user")
        sys.exit(130)  # conventional 128 + SIGINT

    # Exit non-zero when any check failed (or the suite crashed) so shell
    # scripts and CI can detect a broken training system. A None score exits 0.
    if score is not None and score < 1.0:
        sys.exit(1)