wip training and inference stats
This commit is contained in:
139
test_model_training.py
Normal file
139
test_model_training.py
Normal file
@ -0,0 +1,139 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Model Training Implementation
|
||||
|
||||
This script tests the improved model training functionality.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
async def test_model_training():
|
||||
"""Test the improved model training system"""
|
||||
print("=== Testing Model Training System ===")
|
||||
|
||||
# Initialize orchestrator
|
||||
print("1. Initializing orchestrator...")
|
||||
data_provider = DataProvider()
|
||||
orchestrator = TradingOrchestrator(data_provider=data_provider)
|
||||
|
||||
# Wait for initialization
|
||||
await asyncio.sleep(3)
|
||||
|
||||
# Show initial model statistics
|
||||
print("\n2. Initial model statistics:")
|
||||
orchestrator.log_model_statistics()
|
||||
|
||||
# Run predictions to generate training data
|
||||
print("\n3. Running predictions to generate training data...")
|
||||
predictions_data = []
|
||||
|
||||
for i in range(3):
|
||||
print(f" Running prediction batch {i+1}/3...")
|
||||
predictions = await orchestrator._get_all_predictions('ETH/USDT')
|
||||
print(f" Got {len(predictions)} predictions")
|
||||
|
||||
# Store prediction data for training simulation
|
||||
for pred in predictions:
|
||||
predictions_data.append({
|
||||
'model_name': pred.model_name,
|
||||
'prediction': {
|
||||
'action': pred.action,
|
||||
'confidence': pred.confidence
|
||||
},
|
||||
'timestamp': pred.timestamp,
|
||||
'symbol': 'ETH/USDT'
|
||||
})
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
print(f"\n4. Collected {len(predictions_data)} predictions for training")
|
||||
|
||||
# Simulate training with different outcomes
|
||||
print("\n5. Testing training with simulated outcomes...")
|
||||
|
||||
for i, pred_data in enumerate(predictions_data[:6]): # Test first 6 predictions
|
||||
# Simulate market outcome
|
||||
was_correct = i % 2 == 0 # Alternate between correct and incorrect
|
||||
price_change_pct = 0.5 if was_correct else -0.3
|
||||
sophisticated_reward = 1.0 if was_correct else -0.5
|
||||
|
||||
# Create training record
|
||||
training_record = {
|
||||
'model_name': pred_data['model_name'],
|
||||
'model_input': np.random.randn(7850), # Simulate model input
|
||||
'prediction': pred_data['prediction'],
|
||||
'symbol': pred_data['symbol'],
|
||||
'timestamp': pred_data['timestamp']
|
||||
}
|
||||
|
||||
print(f" Training {pred_data['model_name']}: "
|
||||
f"action={pred_data['prediction']['action']}, "
|
||||
f"correct={was_correct}, reward={sophisticated_reward}")
|
||||
|
||||
# Test the training method
|
||||
try:
|
||||
await orchestrator._train_model_on_outcome(
|
||||
training_record, was_correct, price_change_pct, sophisticated_reward
|
||||
)
|
||||
print(f" ✅ Training completed for {pred_data['model_name']}")
|
||||
except Exception as e:
|
||||
print(f" ❌ Training failed for {pred_data['model_name']}: {e}")
|
||||
|
||||
# Show updated statistics
|
||||
print("\n6. Updated model statistics after training:")
|
||||
orchestrator.log_model_statistics(detailed=True)
|
||||
|
||||
# Test specific model training methods
|
||||
print("\n7. Testing specific model training methods...")
|
||||
|
||||
# Test DQN training
|
||||
if 'dqn_agent' in orchestrator.model_statistics:
|
||||
print(" Testing DQN agent training...")
|
||||
dqn_record = {
|
||||
'model_name': 'dqn_agent',
|
||||
'model_input': np.random.randn(7850),
|
||||
'prediction': {'action': 'BUY', 'confidence': 0.8},
|
||||
'symbol': 'ETH/USDT',
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
try:
|
||||
await orchestrator._train_model_on_outcome(dqn_record, True, 0.5, 1.0)
|
||||
print(" ✅ DQN training test passed")
|
||||
except Exception as e:
|
||||
print(f" ❌ DQN training test failed: {e}")
|
||||
|
||||
# Test CNN training
|
||||
if 'enhanced_cnn' in orchestrator.model_statistics:
|
||||
print(" Testing CNN model training...")
|
||||
cnn_record = {
|
||||
'model_name': 'enhanced_cnn',
|
||||
'model_input': np.random.randn(7850),
|
||||
'prediction': {'action': 'SELL', 'confidence': 0.6},
|
||||
'symbol': 'ETH/USDT',
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
try:
|
||||
await orchestrator._train_model_on_outcome(cnn_record, False, -0.3, -0.5)
|
||||
print(" ✅ CNN training test passed")
|
||||
except Exception as e:
|
||||
print(f" ❌ CNN training test failed: {e}")
|
||||
|
||||
# Show final statistics
|
||||
print("\n8. Final model statistics:")
|
||||
summary = orchestrator.get_model_statistics_summary()
|
||||
for model_name, stats in summary.items():
|
||||
print(f" {model_name}:")
|
||||
print(f" Inferences: {stats['total_inferences']}")
|
||||
print(f" Rate: {stats['inference_rate_per_minute']:.1f}/min")
|
||||
print(f" Current loss: {stats['current_loss']}")
|
||||
print(f" Last prediction: {stats['last_prediction']} ({stats['last_confidence']})")
|
||||
|
||||
print("\n✅ Model training test completed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_model_training())
|
Reference in New Issue
Block a user