timezones
test_training_fixes.py (new file, 118 lines)
@@ -0,0 +1,118 @@
#!/usr/bin/env python3
"""
Test Training Fixes

This script tests the fixes for CNN adapter and DQN training issues.
"""

import asyncio
import time
import numpy as np
from datetime import datetime
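# NOTE: core.orchestrator and core.data_provider are assumed to be project-local
# modules, so this script should be run from the repository root to import them.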
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider


async def test_training_fixes():
    """Test the training fixes"""
    print("=== Testing Training Fixes ===")

    # Initialize orchestrator
    print("1. Initializing orchestrator...")
    data_provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider=data_provider)

    # Wait for initialization
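    # A fixed 3-second sleep is assumed to be long enough for the orchestrator
    # to finish any background model setup before we inspect its attributes.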
    await asyncio.sleep(3)

    # Check CNN adapter initialization
    print("\n2. Checking CNN adapter initialization:")
    if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter:
        print(" ✅ CNN adapter is properly initialized")
        print(f" CNN adapter type: {type(orchestrator.cnn_adapter)}")
    else:
        print(" ❌ CNN adapter is None or missing")

    # Check DQN agent initialization
    print("\n3. Checking DQN agent initialization:")
    if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent:
        print(" ✅ DQN agent is properly initialized")
        print(f" DQN agent type: {type(orchestrator.rl_agent)}")
        if hasattr(orchestrator.rl_agent, 'policy_net'):
            print(" ✅ DQN policy network is available")
        else:
            print(" ❌ DQN policy network is missing")
    else:
        print(" ❌ DQN agent is None or missing")

    # Test CNN predictions
    print("\n4. Testing CNN predictions:")
    try:
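        # _get_all_predictions is an internal orchestrator method; it is called
        # directly here purely for test purposes.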
        predictions = await orchestrator._get_all_predictions('ETH/USDT')
        cnn_predictions = [p for p in predictions if 'cnn' in p.model_name.lower()]
        if cnn_predictions:
            print(f" ✅ Got {len(cnn_predictions)} CNN predictions")
            for pred in cnn_predictions:
                print(f" CNN prediction: {pred.action} (confidence: {pred.confidence:.3f})")
        else:
            print(" ❌ No CNN predictions received")
    except Exception as e:
        print(f" ❌ CNN prediction failed: {e}")

    # Test training with validation
    print("\n5. Testing training with validation:")
    for i in range(3):
        print(f" Training iteration {i+1}/3...")

        # Create training records for different models
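        # Synthetic records: np.random.randn(7850) stands in for a real feature
        # vector (7850 presumably matches the models' expected input size).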
        training_records = [
            {
                'model_name': 'enhanced_cnn',
                'model_input': np.random.randn(7850),
                'prediction': {'action': 'BUY', 'confidence': 0.7},
                'symbol': 'ETH/USDT',
                'timestamp': datetime.now()
            },
            {
                'model_name': 'dqn_agent',
                'model_input': np.random.randn(7850),
                'prediction': {'action': 'SELL', 'confidence': 0.8},
                'symbol': 'ETH/USDT',
                'timestamp': datetime.now()
            }
        ]

        for record in training_records:
            try:
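                # The positional arguments after the record are not named here;
                # they appear to be an outcome flag and two numeric parameters
                # whose meaning is defined by _train_model_on_outcome.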
                success = await orchestrator._train_model_on_outcome(
                    record, True, 0.5, 1.0
                )
                if success:
                    print(f" ✅ Training succeeded for {record['model_name']}")
                else:
                    print(f" ⚠️ Training failed for {record['model_name']}")
            except Exception as e:
                print(f" ❌ Training error for {record['model_name']}: {e}")

        await asyncio.sleep(1)

    # Show final statistics
    print("\n6. Final model statistics:")
    orchestrator.log_model_statistics(detailed=True)

    # Check for overfitting warnings
    print("\n7. Checking for training quality:")
    summary = orchestrator.get_model_statistics_summary()
    for model_name, stats in summary.items():
        if stats['total_trainings'] > 0:
            print(f" {model_name}: {stats['total_trainings']} trainings, "
                  f"avg time: {stats['average_training_time_ms']:.1f}ms")
            if stats['current_loss'] is not None:
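                # Near-zero loss on random synthetic inputs is treated as a sign
                # of memorization/overfitting rather than genuine learning.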
                if stats['current_loss'] < 0.001:
                    print(f" ⚠️ {model_name} has very low loss ({stats['current_loss']:.6f}) - check for overfitting")
                else:
                    print(f" ✅ {model_name} has reasonable loss ({stats['current_loss']:.6f})")

    print("\n✅ Training fixes test completed!")


if __name__ == "__main__":
    asyncio.run(test_training_fixes())