#!/usr/bin/env python3
"""
Test Training Data Collection and Checkpoint Storage

This script tests if the training system is working correctly and storing
checkpoints.
"""

import os
import sys
import logging
import asyncio
from pathlib import Path
from datetime import datetime

# Add project root to path so core/ and utils/ are importable when this
# script is run directly from the repository checkout.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from core.config import get_config, setup_logging
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
from utils.checkpoint_manager import get_checkpoint_manager

# Setup logging
setup_logging()
logger = logging.getLogger(__name__)


async def test_training_system() -> bool:
    """Exercise the orchestrator and report whether new checkpoints appeared.

    Makes one trading decision per configured symbol, then diffs the
    checkpoint-manager stats taken before and after.

    Returns:
        True if at least one new checkpoint was created, False otherwise.
    """
    logger.info("Testing training system and checkpoint storage...")

    # Initialize components
    data_provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)

    # Get checkpoint manager
    checkpoint_manager = get_checkpoint_manager()

    # Ensure the checkpoint directory exists before scanning it.
    # NOTE(review): this path is hard-coded here; presumably it matches the
    # checkpoint manager's configured directory — confirm they agree.
    checkpoint_dir = Path("models/saved")
    if not checkpoint_dir.exists():
        logger.warning(f"Checkpoint directory {checkpoint_dir} does not exist. Creating...")
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Snapshot checkpoint stats BEFORE training so we can diff afterwards.
    checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
    logger.info(f"Found {checkpoint_stats['total_checkpoints']} existing checkpoints.")
    logger.info(f"Total checkpoint size: {checkpoint_stats['total_size_mb']:.2f} MB")

    # List the five most recently modified checkpoint files.
    checkpoint_files = list(checkpoint_dir.glob("*.pt"))
    if checkpoint_files:
        logger.info("Recent checkpoint files:")
        recent = sorted(checkpoint_files, key=lambda f: f.stat().st_mtime, reverse=True)[:5]
        for i, file in enumerate(recent):
            stat = file.stat()  # stat once; avoids a second syscall per file
            file_size = stat.st_size / (1024 * 1024)  # bytes -> MB
            modified_time = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
            logger.info(f" {i+1}. {file.name} ({file_size:.2f} MB, modified: {modified_time})")
    else:
        logger.warning("No checkpoint files found.")

    # Trigger training by making a trading decision for every symbol.
    logger.info("\nTesting training by making trading decisions...")
    symbols = orchestrator.symbols
    for symbol in symbols:
        logger.info(f"Making trading decision for {symbol}...")
        decision = await orchestrator.make_trading_decision(symbol)
        if decision:
            logger.info(f"Decision for {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
        else:
            logger.warning(f"No decision made for {symbol}.")

    # Diff checkpoint counts to see whether training produced new ones.
    new_checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
    new_checkpoints = new_checkpoint_stats['total_checkpoints'] - checkpoint_stats['total_checkpoints']
    if new_checkpoints > 0:
        logger.info(f"\nSuccess! {new_checkpoints} new checkpoints were created.")
        logger.info("Training system is working correctly.")
    else:
        # Not necessarily a failure: checkpoints are only written when the
        # orchestrator's own saving threshold is met.
        logger.warning("\nNo new checkpoints were created.")
        logger.warning("This could be normal if the training threshold wasn't met.")
        logger.warning("Check the orchestrator's checkpoint saving logic.")

    # Report per-model checkpoint/load status for debugging.
    model_states = orchestrator.get_model_states()
    logger.info("\nModel states:")
    for model_name, state in model_states.items():
        checkpoint_loaded = state.get('checkpoint_loaded', False)
        checkpoint_filename = state.get('checkpoint_filename', 'none')
        current_loss = state.get('current_loss', None)
        status = "LOADED" if checkpoint_loaded else "FRESH"
        loss_str = f"{current_loss:.4f}" if current_loss is not None else "N/A"
        logger.info(f" {model_name}: {status}, Loss: {loss_str}, Checkpoint: {checkpoint_filename}")

    return new_checkpoints > 0


async def main() -> int:
    """Run the training-system test and return a process exit code (0 = pass)."""
    logger.info("=" * 70)
    logger.info("TRAINING SYSTEM TEST")
    logger.info("=" * 70)

    success = await test_training_system()
    if success:
        logger.info("\nTraining system test passed!")
        return 0
    logger.warning("\nTraining system test completed with warnings.")
    logger.info("Check the logs for details.")
    return 1


if __name__ == "__main__":
    sys.exit(asyncio.run(main()))