#!/usr/bin/env python3
"""
Test Model Training Implementation

This script exercises the improved model training pipeline end to end: it runs
predictions to collect training samples, simulates trade outcomes, feeds them
back through the orchestrator's training hooks, and reports the resulting
per-model statistics.
"""

import asyncio
import time
import numpy as np
from datetime import datetime

from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider


async def test_model_training():
    """Test the improved model training system"""
    print("=== Testing Model Training System ===")

    # Initialize orchestrator
    print("1. Initializing orchestrator...")
    data_provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider=data_provider)

    # Wait for initialization
    await asyncio.sleep(3)

    # Show initial model statistics
    print("\n2. Initial model statistics:")
    orchestrator.log_model_statistics()

    # Run predictions to generate training data
    print("\n3. Running predictions to generate training data...")
    predictions_data = []

    for i in range(3):
        print(f"   Running prediction batch {i+1}/3...")
        predictions = await orchestrator._get_all_predictions('ETH/USDT')
        print(f"   Got {len(predictions)} predictions")

        # Store prediction data for training simulation
        for pred in predictions:
            predictions_data.append({
                'model_name': pred.model_name,
                'prediction': {
                    'action': pred.action,
                    'confidence': pred.confidence
                },
                'timestamp': pred.timestamp,
                'symbol': 'ETH/USDT'
            })

        await asyncio.sleep(1)

    print(f"\n4. Collected {len(predictions_data)} predictions for training")

    # Simulate training with different outcomes
    print("\n5. Testing training with simulated outcomes...")

    for i, pred_data in enumerate(predictions_data[:6]):  # Test first 6 predictions
        # Simulate market outcome
        was_correct = i % 2 == 0  # Alternate between correct and incorrect
        price_change_pct = 0.5 if was_correct else -0.3
        sophisticated_reward = 1.0 if was_correct else -0.5

        # Create training record
        training_record = {
            'model_name': pred_data['model_name'],
            'model_input': np.random.randn(7850),  # Simulate model input
            'prediction': pred_data['prediction'],
            'symbol': pred_data['symbol'],
            'timestamp': pred_data['timestamp']
        }

        print(f"   Training {pred_data['model_name']}: "
              f"action={pred_data['prediction']['action']}, "
              f"correct={was_correct}, reward={sophisticated_reward}")

        # Test the training method
        try:
            await orchestrator._train_model_on_outcome(
                training_record, was_correct, price_change_pct, sophisticated_reward
            )
            print(f"   ✅ Training completed for {pred_data['model_name']}")
        except Exception as e:
            print(f"   ❌ Training failed for {pred_data['model_name']}: {e}")

    # Show updated statistics
    print("\n6. Updated model statistics after training:")
    orchestrator.log_model_statistics(detailed=True)

    # Test specific model training methods
    print("\n7. Testing specific model training methods...")

    # Test DQN training
    if 'dqn_agent' in orchestrator.model_statistics:
        print("   Testing DQN agent training...")
        dqn_record = {
            'model_name': 'dqn_agent',
            'model_input': np.random.randn(7850),
            'prediction': {'action': 'BUY', 'confidence': 0.8},
            'symbol': 'ETH/USDT',
            'timestamp': datetime.now()
        }
        try:
            await orchestrator._train_model_on_outcome(dqn_record, True, 0.5, 1.0)
            print("   ✅ DQN training test passed")
        except Exception as e:
            print(f"   ❌ DQN training test failed: {e}")

    # Test CNN training
    if 'enhanced_cnn' in orchestrator.model_statistics:
        print("   Testing CNN model training...")
        cnn_record = {
            'model_name': 'enhanced_cnn',
            'model_input': np.random.randn(7850),
            'prediction': {'action': 'SELL', 'confidence': 0.6},
            'symbol': 'ETH/USDT',
            'timestamp': datetime.now()
        }
        try:
            await orchestrator._train_model_on_outcome(cnn_record, False, -0.3, -0.5)
            print("   ✅ CNN training test passed")
        except Exception as e:
            print(f"   ❌ CNN training test failed: {e}")

    # Show final statistics
    print("\n8. Final model statistics:")
    summary = orchestrator.get_model_statistics_summary()
    for model_name, stats in summary.items():
        print(f"   {model_name}:")
        print(f"     Inferences: {stats['total_inferences']}")
        print(f"     Rate: {stats['inference_rate_per_minute']:.1f}/min")
        print(f"     Current loss: {stats['current_loss']}")
        print(f"     Last prediction: {stats['last_prediction']} ({stats['last_confidence']})")

    print("\n✅ Model training test completed!")


if __name__ == "__main__":
    asyncio.run(test_model_training())