Files
gogo2/test_training_data_collection.py
Dobromir Popov df17a99247 wip
2025-07-23 13:39:41 +03:00

118 lines
4.4 KiB
Python

#!/usr/bin/env python3
"""
Test Training Data Collection and Checkpoint Storage
This script tests if the training system is working correctly and storing checkpoints.
"""
import os
import sys
import logging
import asyncio
from pathlib import Path
from datetime import datetime
# Add project root to path.
# Put the directory containing this script at the front of sys.path so the
# project-local packages imported below (core.*, utils.*) can be found when
# the script is launched directly. This MUST stay above those imports.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
# NOTE(review): get_config is imported but not used anywhere in the visible code.
from core.config import get_config, setup_logging
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
from utils.checkpoint_manager import get_checkpoint_manager
# Setup logging.
# setup_logging() comes from core.config and presumably installs the project's
# handlers/levels — TODO confirm. It runs before this module's logger is used.
setup_logging()
logger = logging.getLogger(__name__)
def _log_recent_checkpoint_files(checkpoint_dir, limit=5):
    """Log the ``limit`` most recently modified ``*.pt`` files under ``checkpoint_dir``.

    The search is recursive, so checkpoints kept in per-model subdirectories
    are listed as well as files directly inside ``checkpoint_dir``.
    """
    entries = []
    for path in checkpoint_dir.rglob("*.pt"):
        try:
            # Stat once per file; reuse the result for sorting and reporting.
            entries.append((path, path.stat()))
        except FileNotFoundError:
            # The file disappeared (e.g. pruned) between listing and stat.
            continue

    if not entries:
        logger.warning("No checkpoint files found.")
        return

    entries.sort(key=lambda item: item[1].st_mtime, reverse=True)
    logger.info("Recent checkpoint files:")
    for rank, (path, info) in enumerate(entries[:limit], start=1):
        size_mb = info.st_size / (1024 * 1024)  # bytes -> MB
        modified = datetime.fromtimestamp(info.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
        logger.info(
            f" {rank}. {path.relative_to(checkpoint_dir)} "
            f"({size_mb:.2f} MB, modified: {modified})"
        )


async def _run_trading_decisions(orchestrator):
    """Request one trading decision per symbol in ``orchestrator.symbols``.

    Returns:
        The list of symbols whose ``make_trading_decision`` call raised.
    """
    failed_symbols = []
    for symbol in orchestrator.symbols:
        logger.info(f"Making trading decision for {symbol}...")
        try:
            decision = await orchestrator.make_trading_decision(symbol)
        except Exception:
            # Keep going so one failing symbol does not hide the rest of the
            # report; the failure is still reflected in the overall result.
            logger.exception(f"Trading decision for {symbol} raised an exception.")
            failed_symbols.append(symbol)
            continue
        if decision:
            logger.info(f"Decision for {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")
        else:
            logger.warning(f"No decision made for {symbol}.")
    return failed_symbols


def _log_model_states(orchestrator):
    """Log load status, current loss and checkpoint filename for each model reported by the orchestrator."""
    model_states = orchestrator.get_model_states()
    logger.info("\nModel states:")
    for model_name, state in model_states.items():
        status = "LOADED" if state.get('checkpoint_loaded', False) else "FRESH"
        current_loss = state.get('current_loss', None)
        loss_str = f"{current_loss:.4f}" if current_loss is not None else "N/A"
        checkpoint_filename = state.get('checkpoint_filename', 'none')
        logger.info(f" {model_name}: {status}, Loss: {loss_str}, Checkpoint: {checkpoint_filename}")


async def test_training_system():
    """Check that the training system runs and stores checkpoints.

    Steps:
      1. Build a ``DataProvider`` and a ``TradingOrchestrator`` with
         ``enhanced_rl_training=True``, and fetch the checkpoint manager.
      2. Ensure ``models/saved`` exists (relative to the current working
         directory), then log the manager's checkpoint stats and the most
         recently modified ``*.pt`` files found under that directory.
      3. Ask the orchestrator for one trading decision per symbol.
      4. Compare the manager's checkpoint count before and after, then log
         each model's state from ``orchestrator.get_model_states()``.

    NOTE(review): if the checkpoint manager prunes old checkpoints, the total
    count may not grow even when a new checkpoint was written — confirm.
    NOTE(review): the ``test_`` prefix lets pytest collect this coroutine;
    it is meant to be driven by ``main()``.

    Returns:
        True if the checkpoint count increased and no trading decision
        raised; False otherwise.
    """
    logger.info("Testing training system and checkpoint storage...")

    # Initialize components
    data_provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
    checkpoint_manager = get_checkpoint_manager()

    # Make sure the checkpoint directory exists before inspecting it.
    checkpoint_dir = Path("models/saved")
    if not checkpoint_dir.exists():
        logger.warning(f"Checkpoint directory {checkpoint_dir} does not exist. Creating...")
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Snapshot existing checkpoints so we can detect newly created ones later.
    checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
    logger.info(f"Found {checkpoint_stats['total_checkpoints']} existing checkpoints.")
    logger.info(f"Total checkpoint size: {checkpoint_stats['total_size_mb']:.2f} MB")

    _log_recent_checkpoint_files(checkpoint_dir)

    # Exercise training by making trading decisions.
    logger.info("\nTesting training by making trading decisions...")
    failed_symbols = await _run_trading_decisions(orchestrator)

    # Check if new checkpoints were created.
    new_checkpoint_stats = checkpoint_manager.get_checkpoint_stats()
    new_checkpoints = new_checkpoint_stats['total_checkpoints'] - checkpoint_stats['total_checkpoints']
    if new_checkpoints > 0:
        logger.info(f"\nSuccess! {new_checkpoints} new checkpoints were created.")
        logger.info("Training system is working correctly.")
    else:
        logger.warning("\nNo new checkpoints were created.")
        logger.warning("This could be normal if the training threshold wasn't met.")
        logger.warning("Check the orchestrator's checkpoint saving logic.")

    if failed_symbols:
        logger.warning(f"Trading decisions failed for: {', '.join(map(str, failed_symbols))}")

    _log_model_states(orchestrator)

    return new_checkpoints > 0 and not failed_symbols
async def main():
    """Run the training-system check and map its outcome to a process exit code.

    Returns:
        0 when the check reports success, 1 when it finishes with warnings.
    """
    separator = "=" * 70
    for banner_line in (separator, "TRAINING SYSTEM TEST", separator):
        logger.info(banner_line)

    passed = await test_training_system()
    if not passed:
        logger.warning("\nTraining system test completed with warnings.")
        logger.info("Check the logs for details.")
        return 1

    logger.info("\nTraining system test passed!")
    return 0
if __name__ == "__main__":
    # Drive the async entry point; its 0/1 result becomes the process exit status.
    exit_status = asyncio.run(main())
    sys.exit(exit_status)