118 lines
4.5 KiB
Python
118 lines
4.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test Training Fixes

This script tests the fixes for CNN adapter and DQN training issues.
|
|
"""
|
|
|
|
import asyncio
|
|
import time
|
|
import numpy as np
|
|
from datetime import datetime
|
|
from core.orchestrator import TradingOrchestrator
|
|
from core.data_provider import DataProvider
|
|
|
|
async def test_training_fixes():
    """Run a seven-step smoke check of the orchestrator's model training.

    Steps, each reported on stdout:
      1. Build a ``DataProvider`` and a ``TradingOrchestrator``, then pause
         three seconds before inspecting it.
      2. Report whether ``orchestrator.cnn_adapter`` exists and is truthy.
      3. Report whether ``orchestrator.rl_agent`` exists and is truthy, and
         whether it exposes a ``policy_net`` attribute.
      4. Request predictions for 'ETH/USDT' and list those whose model name
         contains "cnn".
      5. Run three rounds of ``_train_model_on_outcome`` on one synthetic
         record per model, pausing one second after each round.
      6. Ask the orchestrator to log detailed model statistics.
      7. Print per-model training counts and flag any current loss below
         0.001 as a possible overfitting signal.

    Exceptions raised in steps 4 and 5 are caught and printed; exceptions
    anywhere else propagate to the caller. Returns None.
    """
    print("=== Testing Training Fixes ===")

    # Step 1: construct the objects under test.
    print("1. Initializing orchestrator...")
    data_provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider=data_provider)

    # Pause before inspecting the orchestrator's attributes.
    await asyncio.sleep(3)

    # Step 2: CNN adapter presence.
    print("\n2. Checking CNN adapter initialization:")
    if not getattr(orchestrator, 'cnn_adapter', None):
        print(" ❌ CNN adapter is None or missing")
    else:
        print(" ✅ CNN adapter is properly initialized")
        print(f" CNN adapter type: {type(orchestrator.cnn_adapter)}")

    # Step 3: DQN agent presence and its policy network.
    print("\n3. Checking DQN agent initialization:")
    if not getattr(orchestrator, 'rl_agent', None):
        print(" ❌ DQN agent is None or missing")
    else:
        print(" ✅ DQN agent is properly initialized")
        print(f" DQN agent type: {type(orchestrator.rl_agent)}")
        policy_net_present = hasattr(orchestrator.rl_agent, 'policy_net')
        print(
            " ✅ DQN policy network is available"
            if policy_net_present
            else " ❌ DQN policy network is missing"
        )

    # Step 4: CNN predictions. The whole step is inside the try so that any
    # failure (fetching, filtering or formatting) is reported, not raised.
    print("\n4. Testing CNN predictions:")
    try:
        all_predictions = await orchestrator._get_all_predictions('ETH/USDT')
        cnn_predictions = [
            candidate
            for candidate in all_predictions
            if 'cnn' in candidate.model_name.lower()
        ]
        if not cnn_predictions:
            print(" ❌ No CNN predictions received")
        else:
            print(f" ✅ Got {len(cnn_predictions)} CNN predictions")
            for pred in cnn_predictions:
                print(f" CNN prediction: {pred.action} (confidence: {pred.confidence:.3f})")
    except Exception as e:
        print(f" ❌ CNN prediction failed: {e}")

    # Step 5: training rounds on synthetic records.
    print("\n5. Testing training with validation:")
    # (model name, predicted action, confidence) for each synthetic record.
    model_specs = (
        ('enhanced_cnn', 'BUY', 0.7),
        ('dqn_agent', 'SELL', 0.8),
    )
    for i in range(3):
        print(f" Training iteration {i+1}/3...")

        # Fresh random inputs and timestamps are generated every round.
        training_records = [
            {
                'model_name': name,
                'model_input': np.random.randn(7850),
                'prediction': {'action': action, 'confidence': confidence},
                'symbol': 'ETH/USDT',
                'timestamp': datetime.now(),
            }
            for name, action, confidence in model_specs
        ]

        for record in training_records:
            try:
                success = await orchestrator._train_model_on_outcome(
                    record, True, 0.5, 1.0
                )
            except Exception as e:
                print(f" ❌ Training error for {record['model_name']}: {e}")
                continue

            if success:
                print(f" ✅ Training succeeded for {record['model_name']}")
            else:
                print(f" ⚠️ Training failed for {record['model_name']}")

        await asyncio.sleep(1)

    # Step 6: let the orchestrator report its own statistics.
    print("\n6. Final model statistics:")
    orchestrator.log_model_statistics(detailed=True)

    # Step 7: summarise training counts and flag suspiciously low losses.
    print("\n7. Checking for training quality:")
    summary = orchestrator.get_model_statistics_summary()
    for model_name, stats in summary.items():
        if not stats['total_trainings'] > 0:
            continue  # nothing trained for this model, nothing to report

        print(f" {model_name}: {stats['total_trainings']} trainings, "
              f"avg time: {stats['average_training_time_ms']:.1f}ms")

        if stats['current_loss'] is None:
            continue

        if stats['current_loss'] < 0.001:
            print(f" ⚠️ {model_name} has very low loss ({stats['current_loss']:.6f}) - check for overfitting")
        else:
            print(f" ✅ {model_name} has reasonable loss ({stats['current_loss']:.6f})")

    print("\n✅ Training fixes test completed!")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async smoke test to completion on a new
    # event loop created by asyncio.run.
    asyncio.run(test_training_fixes())