wip training and inference stats

This commit is contained in:
Dobromir Popov
2025-07-27 19:20:23 +03:00
parent 2a21878ed5
commit 87c0dc8ac4
8 changed files with 833 additions and 84 deletions

139
test_model_training.py Normal file
View File

@ -0,0 +1,139 @@
#!/usr/bin/env python3
"""
Test Model Training Implementation
This script tests the improved model training functionality.
"""
import asyncio
import time
import numpy as np
from datetime import datetime
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
async def test_model_training():
"""Test the improved model training system"""
print("=== Testing Model Training System ===")
# Initialize orchestrator
print("1. Initializing orchestrator...")
data_provider = DataProvider()
orchestrator = TradingOrchestrator(data_provider=data_provider)
# Wait for initialization
await asyncio.sleep(3)
# Show initial model statistics
print("\n2. Initial model statistics:")
orchestrator.log_model_statistics()
# Run predictions to generate training data
print("\n3. Running predictions to generate training data...")
predictions_data = []
for i in range(3):
print(f" Running prediction batch {i+1}/3...")
predictions = await orchestrator._get_all_predictions('ETH/USDT')
print(f" Got {len(predictions)} predictions")
# Store prediction data for training simulation
for pred in predictions:
predictions_data.append({
'model_name': pred.model_name,
'prediction': {
'action': pred.action,
'confidence': pred.confidence
},
'timestamp': pred.timestamp,
'symbol': 'ETH/USDT'
})
await asyncio.sleep(1)
print(f"\n4. Collected {len(predictions_data)} predictions for training")
# Simulate training with different outcomes
print("\n5. Testing training with simulated outcomes...")
for i, pred_data in enumerate(predictions_data[:6]): # Test first 6 predictions
# Simulate market outcome
was_correct = i % 2 == 0 # Alternate between correct and incorrect
price_change_pct = 0.5 if was_correct else -0.3
sophisticated_reward = 1.0 if was_correct else -0.5
# Create training record
training_record = {
'model_name': pred_data['model_name'],
'model_input': np.random.randn(7850), # Simulate model input
'prediction': pred_data['prediction'],
'symbol': pred_data['symbol'],
'timestamp': pred_data['timestamp']
}
print(f" Training {pred_data['model_name']}: "
f"action={pred_data['prediction']['action']}, "
f"correct={was_correct}, reward={sophisticated_reward}")
# Test the training method
try:
await orchestrator._train_model_on_outcome(
training_record, was_correct, price_change_pct, sophisticated_reward
)
print(f" ✅ Training completed for {pred_data['model_name']}")
except Exception as e:
print(f" ❌ Training failed for {pred_data['model_name']}: {e}")
# Show updated statistics
print("\n6. Updated model statistics after training:")
orchestrator.log_model_statistics(detailed=True)
# Test specific model training methods
print("\n7. Testing specific model training methods...")
# Test DQN training
if 'dqn_agent' in orchestrator.model_statistics:
print(" Testing DQN agent training...")
dqn_record = {
'model_name': 'dqn_agent',
'model_input': np.random.randn(7850),
'prediction': {'action': 'BUY', 'confidence': 0.8},
'symbol': 'ETH/USDT',
'timestamp': datetime.now()
}
try:
await orchestrator._train_model_on_outcome(dqn_record, True, 0.5, 1.0)
print(" ✅ DQN training test passed")
except Exception as e:
print(f" ❌ DQN training test failed: {e}")
# Test CNN training
if 'enhanced_cnn' in orchestrator.model_statistics:
print(" Testing CNN model training...")
cnn_record = {
'model_name': 'enhanced_cnn',
'model_input': np.random.randn(7850),
'prediction': {'action': 'SELL', 'confidence': 0.6},
'symbol': 'ETH/USDT',
'timestamp': datetime.now()
}
try:
await orchestrator._train_model_on_outcome(cnn_record, False, -0.3, -0.5)
print(" ✅ CNN training test passed")
except Exception as e:
print(f" ❌ CNN training test failed: {e}")
# Show final statistics
print("\n8. Final model statistics:")
summary = orchestrator.get_model_statistics_summary()
for model_name, stats in summary.items():
print(f" {model_name}:")
print(f" Inferences: {stats['total_inferences']}")
print(f" Rate: {stats['inference_rate_per_minute']:.1f}/min")
print(f" Current loss: {stats['current_loss']}")
print(f" Last prediction: {stats['last_prediction']} ({stats['last_confidence']})")
print("\n✅ Model training test completed!")
if __name__ == "__main__":
asyncio.run(test_model_training())