#!/usr/bin/env python3
"""
Test Training Data Collection and Checkpoint Storage

This script tests if the training system is working correctly and storing checkpoints.
"""
|
|
|
|
import os
|
|
import sys
|
|
import logging
|
|
import asyncio
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
|
|
# Add the directory containing this script to the import path so the
# project's ``core`` and ``utils`` packages can be imported below.
# ``resolve()`` makes the entry absolute, so it does not depend on the
# current working directory or on how the script path was spelled when
# the interpreter was launched.
project_root = Path(__file__).resolve().parent
sys.path.insert(0, str(project_root))
|
|
|
|
from core.config import get_config, setup_logging
|
|
from core.orchestrator import TradingOrchestrator
|
|
from core.data_provider import DataProvider
|
|
from utils.checkpoint_manager import get_checkpoint_manager
|
|
|
|
# Configure logging through the project's helper, then create the logger
# that the rest of this module writes to.
setup_logging()
logger = logging.getLogger(__name__)
|
|
|
|
async def test_training_system():
    """Check whether a round of trading decisions produces new checkpoints.

    Builds a ``DataProvider`` and a ``TradingOrchestrator`` (with
    ``enhanced_rl_training=True``), records the checkpoint manager's
    statistics, logs up to five of the most recently modified ``*.pt``
    files found directly in ``models/saved``, requests one trading decision
    for every symbol in ``orchestrator.symbols``, and then compares the
    checkpoint count with the earlier snapshot. The per-model state reported
    by the orchestrator is logged at the end.

    Returns:
        bool: True if the reported checkpoint count grew between the two
        snapshots, False otherwise.
    """
    logger.info("Testing training system and checkpoint storage...")

    # Initialize components
    data_provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
    checkpoint_manager = get_checkpoint_manager()

    # Relative path: it is resolved against the current working directory.
    checkpoint_dir = Path("models/saved")
    if not checkpoint_dir.exists():
        logger.warning(f"Checkpoint directory {checkpoint_dir} does not exist. Creating...")
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Baseline snapshot, taken before any decision is requested, so it can
    # be compared with the snapshot taken after the decision loop.
    checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
    logger.info(f"Found {checkpoint_stats['total_checkpoints']} existing checkpoints.")
    logger.info(f"Total checkpoint size: {checkpoint_stats['total_size_mb']:.2f} MB")

    # List checkpoint files, newest first, capped at five entries.
    checkpoint_files = list(checkpoint_dir.glob("*.pt"))
    if checkpoint_files:
        logger.info("Recent checkpoint files:")
        # stat() every file exactly once and reuse the result for sorting,
        # size and timestamp, so all three come from the same snapshot.
        snapshots = [(path, path.stat()) for path in checkpoint_files]
        snapshots.sort(key=lambda entry: entry[1].st_mtime, reverse=True)
        for rank, (path, info) in enumerate(snapshots[:5], start=1):
            size_mb = info.st_size / (1024 * 1024)  # bytes -> MB
            modified_time = datetime.fromtimestamp(info.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
            logger.info(f" {rank}. {path.name} ({size_mb:.2f} MB, modified: {modified_time})")
    else:
        logger.warning("No checkpoint files found.")

    # Test training by making trading decisions
    logger.info("\nTesting training by making trading decisions...")
    for symbol in orchestrator.symbols:
        logger.info(f"Making trading decision for {symbol}...")
        decision = await orchestrator.make_trading_decision(symbol)

        if decision:
            logger.info(f"Decision for {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
        else:
            logger.warning(f"No decision made for {symbol}.")

    # Check if new checkpoints were created. A zero or negative difference
    # is reported as "no new checkpoints".
    new_checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
    new_checkpoints = new_checkpoint_stats['total_checkpoints'] - checkpoint_stats['total_checkpoints']

    if new_checkpoints > 0:
        logger.info(f"\nSuccess! {new_checkpoints} new checkpoints were created.")
        logger.info("Training system is working correctly.")
    else:
        logger.warning("\nNo new checkpoints were created.")
        logger.warning("This could be normal if the training threshold wasn't met.")
        logger.warning("Check the orchestrator's checkpoint saving logic.")

    # Check model states
    model_states = orchestrator.get_model_states()
    logger.info("\nModel states:")
    for model_name, state in model_states.items():
        checkpoint_loaded = state.get('checkpoint_loaded', False)
        checkpoint_filename = state.get('checkpoint_filename', 'none')
        current_loss = state.get('current_loss', None)

        status = "LOADED" if checkpoint_loaded else "FRESH"
        # A missing loss is shown as "N/A" instead of being formatted.
        loss_str = f"{current_loss:.4f}" if current_loss is not None else "N/A"

        logger.info(f" {model_name}: {status}, Loss: {loss_str}, Checkpoint: {checkpoint_filename}")

    return new_checkpoints > 0
|
|
|
|
async def main():
    """Run the training-system check and turn its result into an exit code.

    Returns:
        int: 0 when ``test_training_system()`` returns a truthy value,
        1 otherwise.
    """
    banner = "=" * 70
    logger.info(banner)
    logger.info("TRAINING SYSTEM TEST")
    logger.info(banner)

    passed = await test_training_system()

    # Guard clause: report the non-passing outcome first and bail out.
    if not passed:
        logger.warning("\nTraining system test completed with warnings.")
        logger.info("Check the logs for details.")
        return 1

    logger.info("\nTraining system test passed!")
    return 0
|
|
|
|
if __name__ == "__main__":
    # Drive the async entry point to completion and use its return value
    # as the process exit status.
    exit_code = asyncio.run(main())
    sys.exit(exit_code)