129 lines
4.7 KiB
Python
129 lines
4.7 KiB
Python
"""
Test script to verify model training works correctly
"""
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
import logging
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
|
|
from core.data_provider import DataProvider
|
|
from core.orchestrator import TradingOrchestrator
|
|
from ANNOTATE.core.annotation_manager import AnnotationManager
|
|
from ANNOTATE.core.real_training_adapter import RealTrainingAdapter
|
|
|
|
def _detect_available_models(orchestrator):
    """Report which model slots on the orchestrator are populated.

    Looks at the ``rl_agent``, ``cnn_model`` and ``primary_transformer``
    attributes. A slot counts as available when the attribute exists and is
    truthy. One line is printed per available model.

    Returns:
        A list of display names ('DQN', 'CNN', 'Transformer'), in that order,
        for the slots that are populated.
    """
    model_slots = (
        ('rl_agent', 'DQN'),
        ('cnn_model', 'CNN'),
        ('primary_transformer', 'Transformer'),
    )
    found = []
    for attr_name, label in model_slots:
        # getattr with a default covers both "attribute missing" and
        # "attribute present but falsy" in one check.
        if getattr(orchestrator, attr_name, None):
            found.append(label)
            print(f" {label} model available")
    return found


def _report_transformer_trainer(orchestrator):
    """Print the transformer trainer's type and its first public methods, if set."""
    trainer = getattr(orchestrator, 'primary_transformer_trainer', None)
    if not trainer:
        return

    print(f" Transformer trainer available: {type(trainer).__name__}")

    # Public callables only; the listing is truncated to the first ten names.
    methods = [
        name for name in dir(trainer)
        if not name.startswith('_') and callable(getattr(trainer, name))
    ]
    print(f" 📋 Trainer methods: {', '.join(methods[:10])}...")


def _monitor_training(training_adapter, training_id, timeout_seconds=30):
    """Poll training progress once per second until it finishes or time runs out.

    Prints a summary when the reported status becomes 'completed' or 'failed',
    a one-line in-place update while it is 'running', and a "still running"
    report if neither terminal status is seen within ``timeout_seconds`` polls.
    """
    import time

    for _ in range(timeout_seconds):
        time.sleep(1)
        progress = training_adapter.get_training_progress(training_id)
        status = progress['status']

        if status == 'completed':
            print("\n Training completed!")
            print(f" Final loss: {progress['final_loss']:.6f}")
            print(f" Accuracy: {progress['accuracy']:.2%}")
            print(f" Duration: {progress['duration_seconds']:.2f}s")
            return

        if status == 'failed':
            print("\n Training failed!")
            print(f" Error: {progress['error']}")
            return

        if status == 'running':
            # '\r' keeps the progress on a single, continuously updated line.
            print(
                f" Epoch {progress['current_epoch']}/{progress['total_epochs']}, "
                f"Loss: {progress['current_loss']:.6f}",
                end='\r',
            )

    # Reached only when no terminal status was observed during the window.
    print(f"\n Training still running after {timeout_seconds} seconds")
    progress = training_adapter.get_training_progress(training_id)
    print(f" Status: {progress['status']}")
    print(f" Epoch: {progress['current_epoch']}/{progress['total_epochs']}")


def test_training():
    """Test the complete training flow and print a step-by-step report.

    Steps performed:
      1. Create a DataProvider and a TradingOrchestrator.
      2. Initialize the orchestrator's ML models and list the ones present.
      3. Create a RealTrainingAdapter.
      4. Load all test cases from an AnnotationManager.
      5. Start training for the 'Transformer' model.
      6. Monitor progress for up to 30 seconds.

    Returns early (without the closing banner) when no model or no test case
    is available. Exceptions raised while starting or monitoring training are
    caught and printed with a traceback; exceptions from the setup steps
    (1-4) are not caught here and propagate to the caller.
    """
    model_name = 'Transformer'
    banner = "=" * 80

    print(banner)
    print("Testing Model Training Flow")
    print(banner)

    # Step 1: Initialize components
    print("\n1. Initializing components...")
    data_provider = DataProvider()
    print(" DataProvider initialized")

    orchestrator = TradingOrchestrator(
        data_provider=data_provider,
        enhanced_rl_training=True,
    )
    print(" Orchestrator initialized")

    # Step 2: Initialize ML models
    print("\n2. Initializing ML models...")
    orchestrator._initialize_ml_models()

    available_models = _detect_available_models(orchestrator)
    _report_transformer_trainer(orchestrator)

    if not available_models:
        print(" No models available!")
        return

    print(f"\n Available models: {', '.join(available_models)}")

    # The model trained below is fixed, so surface the mismatch when it was
    # not among the detected ones. Only a warning: whether the adapter needs
    # orchestrator.primary_transformer is not visible from this script.
    if model_name not in available_models:
        print(f" Warning: {model_name} model was not detected; training may fail")

    # Step 3: Initialize training adapter
    print("\n3. Initializing training adapter...")
    training_adapter = RealTrainingAdapter(orchestrator, data_provider)
    print(" Training adapter initialized")

    # Step 4: Load test cases
    print("\n4. Loading test cases...")
    annotation_manager = AnnotationManager()
    test_cases = annotation_manager.get_all_test_cases()
    print(f" Loaded {len(test_cases)} test cases")

    if not test_cases:
        print(" No test cases available - create some annotations first!")
        return

    # Step 5: Start training
    print(f"\n5. Starting training with {model_name} model...")
    print(f" Test cases: {len(test_cases)}")

    try:
        training_id = training_adapter.start_training(
            model_name=model_name,
            test_cases=test_cases,
        )
        print(f" Training started: {training_id}")

        # Step 6: Monitor training progress
        print("\n6. Monitoring training progress...")
        _monitor_training(training_adapter, training_id)

    except Exception as e:
        # Diagnostic script boundary: report the failure in full rather than
        # aborting before the closing banner.
        print(f" Training failed with exception: {e}")
        import traceback
        traceback.print_exc()

    print("\n" + banner)
    print("Test Complete")
    print(banner)
|
|
|
|
# Script entry point: run the training flow check only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    test_training()
|