Files
gogo2/test_training_fixes.py
Dobromir Popov 1894d453c9 timezones
2025-07-27 20:43:28 +03:00

118 lines
4.5 KiB
Python

#!/usr/bin/env python3
"""
Test Training Fixes
This script tests the fixes for CNN adapter and DQN training issues.
"""
import asyncio
import time
import numpy as np
from datetime import datetime
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
async def test_training_fixes():
    """Smoke-test the CNN adapter and DQN training fixes against a live orchestrator.

    This is a diagnostic script, not an assertion-based unit test. Each step
    prints a ✅ / ⚠️ / ❌ status line and carries on, so one broken component
    does not hide the results for the others.

    Steps:
      1. Build a ``DataProvider`` and a ``TradingOrchestrator``, then wait a
         few seconds for initialization.
      2-3. Report whether ``cnn_adapter`` and ``rl_agent`` (and its
         ``policy_net``) are present.
      4. Request predictions for ETH/USDT and list the CNN ones.
      5. Feed synthetic training records through ``_train_model_on_outcome``.
      6-7. Print the orchestrator's model statistics and flag very low losses.
    """
    print("=== Testing Training Fixes ===")

    # Initialize orchestrator
    print("1. Initializing orchestrator...")
    data_provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider=data_provider)

    # NOTE(review): fixed delay; assumes initialization settles within ~3 s.
    await asyncio.sleep(3)

    _check_components(orchestrator)
    await _check_cnn_predictions(orchestrator, 'ETH/USDT')
    await _exercise_training(orchestrator, 'ETH/USDT')
    _report_model_statistics(orchestrator)

    print("\n✅ Training fixes test completed!")


def _check_components(orchestrator):
    """Print whether the CNN adapter and the DQN agent (with policy net) exist."""
    print("\n2. Checking CNN adapter initialization:")
    # getattr(..., None) covers both "attribute missing" and "attribute is None".
    # An explicit None test avoids misreporting a present-but-falsy object.
    cnn_adapter = getattr(orchestrator, 'cnn_adapter', None)
    if cnn_adapter is not None:
        print(" ✅ CNN adapter is properly initialized")
        print(f" CNN adapter type: {type(cnn_adapter)}")
    else:
        print(" ❌ CNN adapter is None or missing")

    print("\n3. Checking DQN agent initialization:")
    rl_agent = getattr(orchestrator, 'rl_agent', None)
    if rl_agent is None:
        print(" ❌ DQN agent is None or missing")
        return
    print(" ✅ DQN agent is properly initialized")
    print(f" DQN agent type: {type(rl_agent)}")
    # A policy_net attribute that exists but is None is as unusable as a missing one.
    if getattr(rl_agent, 'policy_net', None) is not None:
        print(" ✅ DQN policy network is available")
    else:
        print(" ❌ DQN policy network is missing")


async def _check_cnn_predictions(orchestrator, symbol):
    """Fetch all predictions for *symbol* and print those from CNN models.

    A model counts as CNN when its ``model_name`` contains "cnn"
    (case-insensitive). Any exception is reported and swallowed so the
    remaining diagnostic steps still run.
    """
    print("\n4. Testing CNN predictions:")
    try:
        predictions = await orchestrator._get_all_predictions(symbol)
        cnn_predictions = [p for p in predictions if 'cnn' in p.model_name.lower()]
        if cnn_predictions:
            print(f" ✅ Got {len(cnn_predictions)} CNN predictions")
            for pred in cnn_predictions:
                print(f" CNN prediction: {pred.action} (confidence: {pred.confidence:.3f})")
        else:
            print(" ❌ No CNN predictions received")
    except Exception as e:
        print(f" ❌ CNN prediction failed: {e}")


def _make_training_record(model_name, action, confidence, symbol):
    """Build one synthetic training record with a random input vector."""
    return {
        'model_name': model_name,
        # NOTE(review): 7850 presumably matches the models' input size — confirm.
        'model_input': np.random.randn(7850),
        'prediction': {'action': action, 'confidence': confidence},
        'symbol': symbol,
        # NOTE(review): naive local timestamp; confirm whether the orchestrator
        # expects timezone-aware datetimes.
        'timestamp': datetime.now(),
    }


async def _exercise_training(orchestrator, symbol):
    """Run three rounds of synthetic training for the CNN and DQN models.

    Each record is passed to ``orchestrator._train_model_on_outcome``. Its
    boolean result is reported, and any exception is printed without
    aborting the remaining records.
    """
    print("\n5. Testing training with validation:")
    for i in range(3):
        print(f" Training iteration {i+1}/3...")
        # Create training records for different models
        training_records = [
            _make_training_record('enhanced_cnn', 'BUY', 0.7, symbol),
            _make_training_record('dqn_agent', 'SELL', 0.8, symbol),
        ]
        for record in training_records:
            try:
                # Positional outcome args are passed exactly as before;
                # their meaning is defined by the orchestrator.
                success = await orchestrator._train_model_on_outcome(
                    record, True, 0.5, 1.0
                )
                if success:
                    print(f" ✅ Training succeeded for {record['model_name']}")
                else:
                    print(f" ⚠️ Training failed for {record['model_name']}")
            except Exception as e:
                print(f" ❌ Training error for {record['model_name']}: {e}")
        # Short pause between iterations.
        await asyncio.sleep(1)


def _report_model_statistics(orchestrator):
    """Log detailed model statistics, then summarize training time and loss.

    Models whose current loss is below 0.001 are flagged as possibly overfit.
    """
    # Show final statistics
    print("\n6. Final model statistics:")
    orchestrator.log_model_statistics(detailed=True)

    # Check for overfitting warnings
    print("\n7. Checking for training quality:")
    summary = orchestrator.get_model_statistics_summary()
    for model_name, stats in summary.items():
        if not stats['total_trainings'] > 0:
            continue
        print(f" {model_name}: {stats['total_trainings']} trainings, "
              f"avg time: {stats['average_training_time_ms']:.1f}ms")
        loss = stats['current_loss']
        if loss is None:
            continue
        if loss < 0.001:
            print(f" ⚠️ {model_name} has very low loss ({loss:.6f}) - check for overfitting")
        else:
            print(f"{model_name} has reasonable loss ({loss:.6f})")
if __name__ == "__main__":
    # Script entry point: build the diagnostic coroutine and drive it to
    # completion on a fresh event loop.
    entry_coro = test_training_fixes()
    asyncio.run(entry_coro)