gogo2/test_training.py

"""
Test script to verify model training works correctly
"""
import sys
import time
import traceback
import logging
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
from ANNOTATE.core.annotation_manager import AnnotationManager
from ANNOTATE.core.real_training_adapter import RealTrainingAdapter
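
# NOTE (layout assumption): the project imports above are expected to resolve
# when this script lives in the gogo2/ repository root next to the core/ and
# ANNOTATE/ packages, because sys.path is extended with this file's parent
# directory rather than relying on an installed package.
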
def test_training():
    """Test the complete training flow"""
    print("=" * 80)
    print("Testing Model Training Flow")
    print("=" * 80)

    # Step 1: Initialize components
    print("\n1. Initializing components...")

    data_provider = DataProvider()
    print(" DataProvider initialized")

    orchestrator = TradingOrchestrator(
        data_provider=data_provider,
        enhanced_rl_training=True
    )
    print(" Orchestrator initialized")
    # Step 2: Initialize ML models
    print("\n2. Initializing ML models...")
    orchestrator._initialize_ml_models()

    # Check what models are available
    available_models = []

    if hasattr(orchestrator, 'rl_agent') and orchestrator.rl_agent:
        available_models.append('DQN')
        print(" DQN model available")

    if hasattr(orchestrator, 'cnn_model') and orchestrator.cnn_model:
        available_models.append('CNN')
        print(" CNN model available")

    if hasattr(orchestrator, 'primary_transformer') and orchestrator.primary_transformer:
        available_models.append('Transformer')
        print(" Transformer model available")

    # Check if trainer is available
    if hasattr(orchestrator, 'primary_transformer_trainer') and orchestrator.primary_transformer_trainer:
        trainer = orchestrator.primary_transformer_trainer
        print(f" Transformer trainer available: {type(trainer).__name__}")

        # List available methods
        methods = [m for m in dir(trainer) if not m.startswith('_') and callable(getattr(trainer, m))]
        print(f" 📋 Trainer methods: {', '.join(methods[:10])}...")

    if not available_models:
        print(" No models available!")
        return

    print(f"\n Available models: {', '.join(available_models)}")
    # Step 3: Initialize training adapter
    print("\n3. Initializing training adapter...")
    training_adapter = RealTrainingAdapter(orchestrator, data_provider)
    print(" Training adapter initialized")

    # Step 4: Load test cases
    print("\n4. Loading test cases...")
    annotation_manager = AnnotationManager()
    test_cases = annotation_manager.get_all_test_cases()
    print(f" Loaded {len(test_cases)} test cases")

    if len(test_cases) == 0:
        print(" No test cases available - create some annotations first!")
        return
    # Step 5: Start training
    print("\n5. Starting training with Transformer model...")
    print(f" Test cases: {len(test_cases)}")

    try:
        training_id = training_adapter.start_training(
            model_name='Transformer',
            test_cases=test_cases
        )
        print(f" Training started: {training_id}")

        # Step 6: Monitor training progress
        print("\n6. Monitoring training progress...")

        for _ in range(30):  # Poll once per second for up to 30 seconds
            time.sleep(1)
            progress = training_adapter.get_training_progress(training_id)

            if progress['status'] == 'completed':
                print("\n Training completed!")
                print(f" Final loss: {progress['final_loss']:.6f}")
                print(f" Accuracy: {progress['accuracy']:.2%}")
                print(f" Duration: {progress['duration_seconds']:.2f}s")
                break
            elif progress['status'] == 'failed':
                print("\n Training failed!")
                print(f" Error: {progress['error']}")
                break
            elif progress['status'] == 'running':
                print(f" Epoch {progress['current_epoch']}/{progress['total_epochs']}, Loss: {progress['current_loss']:.6f}", end='\r')
        else:
            # for-else: reached only if all 30 polls finished without a break
            print("\n Training still running after 30 seconds")
            progress = training_adapter.get_training_progress(training_id)
            print(f" Status: {progress['status']}")
            print(f" Epoch: {progress['current_epoch']}/{progress['total_epochs']}")

    except Exception as e:
        print(f" Training failed with exception: {e}")
        traceback.print_exc()

    print("\n" + "=" * 80)
    print("Test Complete")
    print("=" * 80)


if __name__ == "__main__":
    test_training()
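
# Usage sketch (assuming the layout noted above and that annotations have
# already been created through the ANNOTATE tooling):
#
#   python test_training.py
#
# The script polls the training job for up to 30 seconds and then prints the
# last reported status; it does not attempt to cancel or wait out a job that
# is still running.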