checkbox manager and handling

This commit is contained in:
Dobromir Popov
2025-06-24 21:59:23 +03:00
parent 706eb13912
commit ab8c94d735
8 changed files with 1170 additions and 29 deletions

View File

@ -0,0 +1,525 @@
#!/usr/bin/env python3
"""
Comprehensive Checkpoint Management Integration
This script demonstrates how to integrate the checkpoint management system
across all training pipelines in the gogo2 project.
Features:
- DQN Agent training with automatic checkpointing
- CNN Model training with checkpoint management
- ExtremaTrainer with checkpoint persistence
- NegativeCaseTrainer with checkpoint integration
- Unified training orchestration with checkpoint coordination
"""
import asyncio
import logging
import time
import signal
import sys
import numpy as np
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('logs/checkpoint_integration.log'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
# Import checkpoint management
from utils.checkpoint_manager import get_checkpoint_manager, get_checkpoint_stats
from utils.training_integration import get_training_integration
# Import training components
from NN.models.dqn_agent import DQNAgent
from NN.models.cnn_model import CNNModelTrainer, create_enhanced_cnn_model
from core.extrema_trainer import ExtremaTrainer
from core.negative_case_trainer import NegativeCaseTrainer
from core.data_provider import DataProvider
from core.config import get_config
class CheckpointIntegratedTrainingSystem:
"""Unified training system with comprehensive checkpoint management"""
def __init__(self):
"""Initialize the checkpoint-integrated training system"""
self.config = get_config()
self.running = False
# Checkpoint management
self.checkpoint_manager = get_checkpoint_manager()
self.training_integration = get_training_integration()
# Data provider
self.data_provider = DataProvider(
symbols=['ETH/USDT', 'BTC/USDT'],
timeframes=['1s', '1m', '1h', '1d']
)
# Training components with checkpoint management
self.dqn_agent = None
self.cnn_trainer = None
self.extrema_trainer = None
self.negative_case_trainer = None
# Training statistics
self.training_stats = {
'start_time': None,
'total_training_sessions': 0,
'checkpoints_saved': 0,
'models_loaded': 0,
'best_performances': {}
}
logger.info("Checkpoint-Integrated Training System initialized")
async def initialize_components(self):
"""Initialize all training components with checkpoint management"""
try:
logger.info("Initializing training components with checkpoint management...")
# Initialize data provider
await self.data_provider.start_real_time_streaming()
logger.info("Data provider streaming started")
# Initialize DQN Agent with checkpoint management
logger.info("Initializing DQN Agent with checkpoints...")
self.dqn_agent = DQNAgent(
state_shape=(100,), # Example state shape
n_actions=3,
model_name="integrated_dqn_agent",
enable_checkpoints=True
)
logger.info("✅ DQN Agent initialized with checkpoint management")
# Initialize CNN Model with checkpoint management
logger.info("Initializing CNN Model with checkpoints...")
cnn_model, self.cnn_trainer = create_enhanced_cnn_model(
input_size=60,
feature_dim=50,
output_size=3
)
# Update trainer with checkpoint management
self.cnn_trainer.model_name = "integrated_cnn_model"
self.cnn_trainer.enable_checkpoints = True
self.cnn_trainer.training_integration = self.training_integration
logger.info("✅ CNN Model initialized with checkpoint management")
# Initialize ExtremaTrainer with checkpoint management
logger.info("Initializing ExtremaTrainer with checkpoints...")
self.extrema_trainer = ExtremaTrainer(
data_provider=self.data_provider,
symbols=['ETH/USDT', 'BTC/USDT'],
model_name="integrated_extrema_trainer",
enable_checkpoints=True
)
await self.extrema_trainer.initialize_context_data()
logger.info("✅ ExtremaTrainer initialized with checkpoint management")
# Initialize NegativeCaseTrainer with checkpoint management
logger.info("Initializing NegativeCaseTrainer with checkpoints...")
self.negative_case_trainer = NegativeCaseTrainer(
model_name="integrated_negative_case_trainer",
enable_checkpoints=True
)
logger.info("✅ NegativeCaseTrainer initialized with checkpoint management")
# Load existing checkpoints for all components
self.training_stats['models_loaded'] = await self._load_all_checkpoints()
logger.info("All training components initialized successfully")
except Exception as e:
logger.error(f"Error initializing components: {e}")
raise
async def _load_all_checkpoints(self) -> int:
"""Load checkpoints for all training components"""
loaded_count = 0
try:
# DQN Agent checkpoint loading is handled in __init__
if hasattr(self.dqn_agent, 'episode_count') and self.dqn_agent.episode_count > 0:
loaded_count += 1
logger.info(f"DQN Agent resumed from episode {self.dqn_agent.episode_count}")
# CNN Trainer checkpoint loading is handled in __init__
if hasattr(self.cnn_trainer, 'epoch_count') and self.cnn_trainer.epoch_count > 0:
loaded_count += 1
logger.info(f"CNN Trainer resumed from epoch {self.cnn_trainer.epoch_count}")
# ExtremaTrainer checkpoint loading is handled in __init__
if hasattr(self.extrema_trainer, 'training_session_count') and self.extrema_trainer.training_session_count > 0:
loaded_count += 1
logger.info(f"ExtremaTrainer resumed from session {self.extrema_trainer.training_session_count}")
# NegativeCaseTrainer checkpoint loading is handled in __init__
if hasattr(self.negative_case_trainer, 'training_session_count') and self.negative_case_trainer.training_session_count > 0:
loaded_count += 1
logger.info(f"NegativeCaseTrainer resumed from session {self.negative_case_trainer.training_session_count}")
return loaded_count
except Exception as e:
logger.error(f"Error loading checkpoints: {e}")
return 0
async def run_integrated_training_loop(self):
"""Run the integrated training loop with checkpoint coordination"""
logger.info("Starting integrated training loop with checkpoint management...")
self.running = True
self.training_stats['start_time'] = datetime.now()
training_cycle = 0
try:
while self.running:
training_cycle += 1
cycle_start = time.time()
logger.info(f"=== Training Cycle {training_cycle} ===")
# DQN Training
dqn_results = await self._train_dqn_agent()
# CNN Training
cnn_results = await self._train_cnn_model()
# Extrema Detection Training
extrema_results = await self._train_extrema_detector()
# Negative Case Training (runs in background)
negative_results = await self._process_negative_cases()
# Coordinate checkpoint saving
await self._coordinate_checkpoint_saving(
dqn_results, cnn_results, extrema_results, negative_results
)
# Update statistics
self.training_stats['total_training_sessions'] += 1
# Log cycle summary
cycle_duration = time.time() - cycle_start
logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")
# Wait before next cycle
await asyncio.sleep(60) # 1-minute cycles
except KeyboardInterrupt:
logger.info("Training interrupted by user")
except Exception as e:
logger.error(f"Error in training loop: {e}")
finally:
await self.shutdown()
async def _train_dqn_agent(self) -> Dict[str, Any]:
"""Train DQN agent with automatic checkpointing"""
try:
if not self.dqn_agent:
return {'status': 'skipped', 'reason': 'no_agent'}
# Simulate DQN training episode
episode_reward = 0.0
# Add some training experiences (simulate real training)
for _ in range(10): # Simulate 10 training steps
state = np.random.randn(100).astype(np.float32)
action = np.random.randint(0, 3)
reward = np.random.randn() * 0.1
next_state = np.random.randn(100).astype(np.float32)
done = np.random.random() < 0.1
self.dqn_agent.remember(state, action, reward, next_state, done)
episode_reward += reward
# Train if enough experiences
loss = 0.0
if len(self.dqn_agent.memory) >= self.dqn_agent.batch_size:
loss = self.dqn_agent.replay()
# Save checkpoint (automatic based on performance)
checkpoint_saved = self.dqn_agent.save_checkpoint(episode_reward)
if checkpoint_saved:
self.training_stats['checkpoints_saved'] += 1
return {
'status': 'completed',
'episode_reward': episode_reward,
'loss': loss,
'checkpoint_saved': checkpoint_saved,
'episode': self.dqn_agent.episode_count
}
except Exception as e:
logger.error(f"Error training DQN agent: {e}")
return {'status': 'error', 'error': str(e)}
async def _train_cnn_model(self) -> Dict[str, Any]:
"""Train CNN model with automatic checkpointing"""
try:
if not self.cnn_trainer:
return {'status': 'skipped', 'reason': 'no_trainer'}
# Simulate CNN training step
import torch
import numpy as np
batch_size = 32
input_size = 60
feature_dim = 50
# Generate synthetic training data
x = torch.randn(batch_size, input_size, feature_dim)
y = torch.randint(0, 3, (batch_size,))
# Training step
results = self.cnn_trainer.train_step(x, y)
# Simulate validation
val_x = torch.randn(16, input_size, feature_dim)
val_y = torch.randint(0, 3, (16,))
val_results = self.cnn_trainer.train_step(val_x, val_y)
# Save checkpoint (automatic based on performance)
checkpoint_saved = self.cnn_trainer.save_checkpoint(
train_accuracy=results.get('accuracy', 0.5),
val_accuracy=val_results.get('accuracy', 0.5),
train_loss=results.get('total_loss', 1.0),
val_loss=val_results.get('total_loss', 1.0)
)
if checkpoint_saved:
self.training_stats['checkpoints_saved'] += 1
return {
'status': 'completed',
'train_accuracy': results.get('accuracy', 0.5),
'val_accuracy': val_results.get('accuracy', 0.5),
'train_loss': results.get('total_loss', 1.0),
'val_loss': val_results.get('total_loss', 1.0),
'checkpoint_saved': checkpoint_saved,
'epoch': self.cnn_trainer.epoch_count
}
except Exception as e:
logger.error(f"Error training CNN model: {e}")
return {'status': 'error', 'error': str(e)}
async def _train_extrema_detector(self) -> Dict[str, Any]:
"""Train extrema detector with automatic checkpointing"""
try:
if not self.extrema_trainer:
return {'status': 'skipped', 'reason': 'no_trainer'}
# Update context data and detect extrema
update_results = self.extrema_trainer.update_context_data()
# Get training data
extrema_data = self.extrema_trainer.get_extrema_training_data(count=10)
# Simulate training accuracy improvement
if extrema_data:
self.extrema_trainer.training_stats['total_extrema_detected'] += len(extrema_data)
self.extrema_trainer.training_stats['successful_predictions'] += len(extrema_data) // 2
self.extrema_trainer.training_stats['failed_predictions'] += len(extrema_data) // 2
# Save checkpoint (automatic based on performance)
checkpoint_saved = self.extrema_trainer.save_checkpoint()
if checkpoint_saved:
self.training_stats['checkpoints_saved'] += 1
return {
'status': 'completed',
'extrema_detected': len(extrema_data),
'context_updates': sum(1 for success in update_results.values() if success),
'checkpoint_saved': checkpoint_saved,
'session': self.extrema_trainer.training_session_count
}
except Exception as e:
logger.error(f"Error training extrema detector: {e}")
return {'status': 'error', 'error': str(e)}
async def _process_negative_cases(self) -> Dict[str, Any]:
"""Process negative cases with automatic checkpointing"""
try:
if not self.negative_case_trainer:
return {'status': 'skipped', 'reason': 'no_trainer'}
# Simulate adding a negative case
if np.random.random() < 0.1: # 10% chance of negative case
trade_info = {
'symbol': 'ETH/USDT',
'action': 'BUY',
'price': 2000.0,
'pnl': -50.0, # Loss
'value': 1000.0,
'confidence': 0.7,
'timestamp': datetime.now()
}
market_data = {
'exit_price': 1950.0,
'state_before': {},
'state_after': {},
'tick_data': [],
'technical_indicators': {}
}
case_id = self.negative_case_trainer.add_losing_trade(trade_info, market_data)
# Simulate loss improvement
loss_improvement = np.random.random() * 0.1
# Save checkpoint (automatic based on performance)
checkpoint_saved = self.negative_case_trainer.save_checkpoint(loss_improvement)
if checkpoint_saved:
self.training_stats['checkpoints_saved'] += 1
return {
'status': 'completed',
'case_added': case_id,
'loss_improvement': loss_improvement,
'checkpoint_saved': checkpoint_saved,
'session': self.negative_case_trainer.training_session_count
}
else:
return {'status': 'no_cases'}
except Exception as e:
logger.error(f"Error processing negative cases: {e}")
return {'status': 'error', 'error': str(e)}
async def _coordinate_checkpoint_saving(self, dqn_results: Dict, cnn_results: Dict,
extrema_results: Dict, negative_results: Dict):
"""Coordinate checkpoint saving across all components"""
try:
# Count successful checkpoints
checkpoints_saved = sum([
dqn_results.get('checkpoint_saved', False),
cnn_results.get('checkpoint_saved', False),
extrema_results.get('checkpoint_saved', False),
negative_results.get('checkpoint_saved', False)
])
if checkpoints_saved > 0:
logger.info(f"Saved {checkpoints_saved} checkpoints this cycle")
# Update best performances
if 'episode_reward' in dqn_results:
current_best = self.training_stats['best_performances'].get('dqn_reward', float('-inf'))
if dqn_results['episode_reward'] > current_best:
self.training_stats['best_performances']['dqn_reward'] = dqn_results['episode_reward']
if 'val_accuracy' in cnn_results:
current_best = self.training_stats['best_performances'].get('cnn_accuracy', 0.0)
if cnn_results['val_accuracy'] > current_best:
self.training_stats['best_performances']['cnn_accuracy'] = cnn_results['val_accuracy']
# Log checkpoint statistics every 10 cycles
if self.training_stats['total_training_sessions'] % 10 == 0:
await self._log_checkpoint_statistics()
except Exception as e:
logger.error(f"Error coordinating checkpoint saving: {e}")
async def _log_checkpoint_statistics(self):
"""Log comprehensive checkpoint statistics"""
try:
stats = get_checkpoint_stats()
logger.info("=== Checkpoint Statistics ===")
logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")
logger.info(f"Models managed: {len(stats['models'])}")
for model_name, model_stats in stats['models'].items():
logger.info(f" {model_name}: {model_stats['checkpoint_count']} checkpoints, "
f"{model_stats['total_size_mb']:.2f} MB, "
f"best: {model_stats['best_performance']:.4f}")
logger.info(f"Training sessions: {self.training_stats['total_training_sessions']}")
logger.info(f"Checkpoints saved: {self.training_stats['checkpoints_saved']}")
logger.info(f"Best performances: {self.training_stats['best_performances']}")
except Exception as e:
logger.error(f"Error logging checkpoint statistics: {e}")
async def shutdown(self):
"""Shutdown the training system and save final checkpoints"""
logger.info("Shutting down checkpoint-integrated training system...")
self.running = False
try:
# Force save checkpoints for all components
if self.dqn_agent:
self.dqn_agent.save_checkpoint(0.0, force_save=True)
if self.cnn_trainer:
self.cnn_trainer.save_checkpoint(0.0, 0.0, 0.0, 0.0, force_save=True)
if self.extrema_trainer:
self.extrema_trainer.save_checkpoint(force_save=True)
if self.negative_case_trainer:
self.negative_case_trainer.save_checkpoint(force_save=True)
# Final statistics
await self._log_checkpoint_statistics()
logger.info("Checkpoint-integrated training system shutdown complete")
except Exception as e:
logger.error(f"Error during shutdown: {e}")
async def main():
"""Main function to run the checkpoint-integrated training system"""
logger.info("🚀 Starting Checkpoint-Integrated Training System")
# Create and initialize the training system
training_system = CheckpointIntegratedTrainingSystem()
# Setup signal handlers for graceful shutdown
def signal_handler(signum, frame):
logger.info("Received shutdown signal")
asyncio.create_task(training_system.shutdown())
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:
# Initialize components
await training_system.initialize_components()
# Run the integrated training loop
await training_system.run_integrated_training_loop()
except Exception as e:
logger.error(f"Error in main: {e}")
raise
finally:
await training_system.shutdown()
logger.info("✅ Checkpoint management integration complete!")
logger.info("All training pipelines now support automatic checkpointing")
if __name__ == "__main__":
# Ensure logs directory exists
Path("logs").mkdir(exist_ok=True)
# Run the checkpoint-integrated training system
asyncio.run(main())