#!/usr/bin/env python3
"""
Comprehensive Checkpoint Management Integration

This script demonstrates how to integrate the checkpoint management system
across all training pipelines in the gogo2 project.

Features:
- DQN Agent training with automatic checkpointing
- CNN Model training with checkpoint management
- ExtremaTrainer with checkpoint persistence
- NegativeCaseTrainer with checkpoint integration
- Unified training orchestration with checkpoint coordination
"""
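
# Example invocation (assuming this file is saved as
# integrate_checkpoint_management.py at the project root -- the filename is
# illustrative, not confirmed by the source):
#
#   python integrate_checkpoint_management.py
#
# The script runs one-minute training cycles until interrupted (Ctrl+C or
# SIGTERM) and force-saves a final checkpoint for every component on shutdown.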

import asyncio
import logging
import time
import signal
import numpy as np
from datetime import datetime
from pathlib import Path
from typing import Dict, Any

# Create the log directory before logging.basicConfig constructs the
# FileHandler below; without this, a fresh checkout with no logs/ directory
# raises FileNotFoundError at import time.
Path("logs").mkdir(parents=True, exist_ok=True)

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/checkpoint_integration.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Import checkpoint management
from utils.checkpoint_manager import get_checkpoint_manager, get_checkpoint_stats
from utils.training_integration import get_training_integration

# Import training components
from NN.models.dqn_agent import DQNAgent
from NN.models.cnn_model import CNNModelTrainer, create_enhanced_cnn_model
from core.extrema_trainer import ExtremaTrainer
from core.negative_case_trainer import NegativeCaseTrainer
from core.data_provider import DataProvider
from core.config import get_config
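
# Expected behavior of the checkpoint helpers, inferred from how they are
# used below (the authoritative definitions live in utils/):
#   - get_checkpoint_manager(): returns the shared manager that persists and
#     rotates checkpoints for all registered models
#   - get_checkpoint_stats(): returns a dict with 'total_checkpoints',
#     'total_size_mb', and a per-model 'models' mapping (see
#     _log_checkpoint_statistics below)
#   - get_training_integration(): returns the glue object trainers use to
#     hook their save/load calls into the manager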


class CheckpointIntegratedTrainingSystem:
    """Unified training system with comprehensive checkpoint management"""

    def __init__(self):
        """Initialize the checkpoint-integrated training system"""
        self.config = get_config()
        self.running = False

        # Checkpoint management
        self.checkpoint_manager = get_checkpoint_manager()
        self.training_integration = get_training_integration()

        # Data provider
        self.data_provider = DataProvider(
            symbols=['ETH/USDT', 'BTC/USDT'],
            timeframes=['1s', '1m', '1h', '1d']
        )

        # Training components with checkpoint management
        self.dqn_agent = None
        self.cnn_trainer = None
        self.extrema_trainer = None
        self.negative_case_trainer = None

        # Training statistics
        self.training_stats = {
            'start_time': None,
            'total_training_sessions': 0,
            'checkpoints_saved': 0,
            'models_loaded': 0,
            'best_performances': {}
        }

        logger.info("Checkpoint-Integrated Training System initialized")

    async def initialize_components(self):
        """Initialize all training components with checkpoint management"""
        try:
            logger.info("Initializing training components with checkpoint management...")

            # Initialize data provider
            await self.data_provider.start_real_time_streaming()
            logger.info("Data provider streaming started")

            # Initialize DQN Agent with checkpoint management
            logger.info("Initializing DQN Agent with checkpoints...")
            self.dqn_agent = DQNAgent(
                state_shape=(100,),  # Example state shape
                n_actions=3,
                model_name="integrated_dqn_agent",
                enable_checkpoints=True
            )
            logger.info("✅ DQN Agent initialized with checkpoint management")

            # Initialize CNN Model with checkpoint management
            logger.info("Initializing CNN Model with checkpoints...")
            cnn_model, self.cnn_trainer = create_enhanced_cnn_model(
                input_size=60,
                feature_dim=50,
                output_size=3
            )
            # Update trainer with checkpoint management
            self.cnn_trainer.model_name = "integrated_cnn_model"
            self.cnn_trainer.enable_checkpoints = True
            self.cnn_trainer.training_integration = self.training_integration
            logger.info("✅ CNN Model initialized with checkpoint management")

            # Initialize ExtremaTrainer with checkpoint management
            logger.info("Initializing ExtremaTrainer with checkpoints...")
            self.extrema_trainer = ExtremaTrainer(
                data_provider=self.data_provider,
                symbols=['ETH/USDT', 'BTC/USDT'],
                model_name="integrated_extrema_trainer",
                enable_checkpoints=True
            )
            await self.extrema_trainer.initialize_context_data()
            logger.info("✅ ExtremaTrainer initialized with checkpoint management")

            # Initialize NegativeCaseTrainer with checkpoint management
            logger.info("Initializing NegativeCaseTrainer with checkpoints...")
            self.negative_case_trainer = NegativeCaseTrainer(
                model_name="integrated_negative_case_trainer",
                enable_checkpoints=True
            )
            logger.info("✅ NegativeCaseTrainer initialized with checkpoint management")

            # Load existing checkpoints for all components
            self.training_stats['models_loaded'] = await self._load_all_checkpoints()

            logger.info("All training components initialized successfully")

        except Exception as e:
            logger.error(f"Error initializing components: {e}")
            raise

    async def _load_all_checkpoints(self) -> int:
        """Load checkpoints for all training components"""
        loaded_count = 0

        try:
            # DQN Agent checkpoint loading is handled in __init__
            if hasattr(self.dqn_agent, 'episode_count') and self.dqn_agent.episode_count > 0:
                loaded_count += 1
                logger.info(f"DQN Agent resumed from episode {self.dqn_agent.episode_count}")

            # CNN Trainer checkpoint loading is handled in __init__
            if hasattr(self.cnn_trainer, 'epoch_count') and self.cnn_trainer.epoch_count > 0:
                loaded_count += 1
                logger.info(f"CNN Trainer resumed from epoch {self.cnn_trainer.epoch_count}")

            # ExtremaTrainer checkpoint loading is handled in __init__
            if hasattr(self.extrema_trainer, 'training_session_count') and self.extrema_trainer.training_session_count > 0:
                loaded_count += 1
                logger.info(f"ExtremaTrainer resumed from session {self.extrema_trainer.training_session_count}")

            # NegativeCaseTrainer checkpoint loading is handled in __init__
            if hasattr(self.negative_case_trainer, 'training_session_count') and self.negative_case_trainer.training_session_count > 0:
                loaded_count += 1
                logger.info(f"NegativeCaseTrainer resumed from session {self.negative_case_trainer.training_session_count}")

            return loaded_count

        except Exception as e:
            logger.error(f"Error loading checkpoints: {e}")
            return 0
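
    # Each component restores itself in its own __init__ (when
    # enable_checkpoints=True) and is only expected to expose a progress
    # counter afterwards -- episode_count, epoch_count, or
    # training_session_count. The hasattr checks above keep this method
    # tolerant of components that do not implement resume at all.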

    async def run_integrated_training_loop(self):
        """Run the integrated training loop with checkpoint coordination"""
        logger.info("Starting integrated training loop with checkpoint management...")

        self.running = True
        self.training_stats['start_time'] = datetime.now()

        training_cycle = 0

        try:
            while self.running:
                training_cycle += 1
                cycle_start = time.time()

                logger.info(f"=== Training Cycle {training_cycle} ===")

                # DQN Training
                dqn_results = await self._train_dqn_agent()

                # CNN Training
                cnn_results = await self._train_cnn_model()

                # Extrema Detection Training
                extrema_results = await self._train_extrema_detector()

                # Negative Case Training (runs in background)
                negative_results = await self._process_negative_cases()

                # Coordinate checkpoint saving
                await self._coordinate_checkpoint_saving(
                    dqn_results, cnn_results, extrema_results, negative_results
                )

                # Update statistics
                self.training_stats['total_training_sessions'] += 1

                # Log cycle summary
                cycle_duration = time.time() - cycle_start
                logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")

                # Wait before next cycle
                await asyncio.sleep(60)  # 1-minute cycles

        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
        except Exception as e:
            logger.error(f"Error in training loop: {e}")
        finally:
            await self.shutdown()
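
    # Every _train_* / _process_* step below returns a small status dict
    # ({'status': ..., 'checkpoint_saved': bool, plus per-component metrics})
    # rather than raising, so one failing component cannot abort the cycle;
    # _coordinate_checkpoint_saving consumes these dicts after each pass.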

    async def _train_dqn_agent(self) -> Dict[str, Any]:
        """Train DQN agent with automatic checkpointing"""
        try:
            if not self.dqn_agent:
                return {'status': 'skipped', 'reason': 'no_agent'}

            # Simulate DQN training episode
            episode_reward = 0.0

            # Add some training experiences (simulate real training)
            for _ in range(10):  # Simulate 10 training steps
                state = np.random.randn(100).astype(np.float32)
                action = np.random.randint(0, 3)
                reward = np.random.randn() * 0.1
                next_state = np.random.randn(100).astype(np.float32)
                done = np.random.random() < 0.1

                self.dqn_agent.remember(state, action, reward, next_state, done)
                episode_reward += reward

            # Train if enough experiences
            loss = 0.0
            if len(self.dqn_agent.memory) >= self.dqn_agent.batch_size:
                loss = self.dqn_agent.replay()

            # Save checkpoint (automatic based on performance)
            checkpoint_saved = self.dqn_agent.save_checkpoint(episode_reward)

            if checkpoint_saved:
                self.training_stats['checkpoints_saved'] += 1

            return {
                'status': 'completed',
                'episode_reward': episode_reward,
                'loss': loss,
                'checkpoint_saved': checkpoint_saved,
                'episode': self.dqn_agent.episode_count
            }

        except Exception as e:
            logger.error(f"Error training DQN agent: {e}")
            return {'status': 'error', 'error': str(e)}
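
    # save_checkpoint(episode_reward) is assumed to gate persistence on
    # performance rather than saving unconditionally -- roughly the pattern
    # below (illustrative only; the real logic lives in DQNAgent):
    #
    #   def save_checkpoint(self, performance, force_save=False):
    #       if force_save or performance > self.best_performance:
    #           self.best_performance = max(self.best_performance, performance)
    #           self.checkpoint_manager.save(self.model_name, self.state_dict())
    #           return True
    #       return False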

    async def _train_cnn_model(self) -> Dict[str, Any]:
        """Train CNN model with automatic checkpointing"""
        try:
            if not self.cnn_trainer:
                return {'status': 'skipped', 'reason': 'no_trainer'}

            # Simulate CNN training step (torch is imported lazily so the
            # rest of the script works without it; numpy is already imported
            # at module level)
            import torch

            batch_size = 32
            input_size = 60
            feature_dim = 50

            # Generate synthetic training data
            x = torch.randn(batch_size, input_size, feature_dim)
            y = torch.randint(0, 3, (batch_size,))

            # Training step
            results = self.cnn_trainer.train_step(x, y)

            # Simulate validation. Note: this reuses train_step, so the model
            # is also updated on the "validation" batch; a real evaluation
            # pass would run the model under torch.no_grad() instead.
            val_x = torch.randn(16, input_size, feature_dim)
            val_y = torch.randint(0, 3, (16,))
            val_results = self.cnn_trainer.train_step(val_x, val_y)

            # Save checkpoint (automatic based on performance)
            checkpoint_saved = self.cnn_trainer.save_checkpoint(
                train_accuracy=results.get('accuracy', 0.5),
                val_accuracy=val_results.get('accuracy', 0.5),
                train_loss=results.get('total_loss', 1.0),
                val_loss=val_results.get('total_loss', 1.0)
            )

            if checkpoint_saved:
                self.training_stats['checkpoints_saved'] += 1

            return {
                'status': 'completed',
                'train_accuracy': results.get('accuracy', 0.5),
                'val_accuracy': val_results.get('accuracy', 0.5),
                'train_loss': results.get('total_loss', 1.0),
                'val_loss': val_results.get('total_loss', 1.0),
                'checkpoint_saved': checkpoint_saved,
                'epoch': self.cnn_trainer.epoch_count
            }

        except Exception as e:
            logger.error(f"Error training CNN model: {e}")
            return {'status': 'error', 'error': str(e)}

    async def _train_extrema_detector(self) -> Dict[str, Any]:
        """Train extrema detector with automatic checkpointing"""
        try:
            if not self.extrema_trainer:
                return {'status': 'skipped', 'reason': 'no_trainer'}

            # Update context data and detect extrema
            update_results = self.extrema_trainer.update_context_data()

            # Get training data
            extrema_data = self.extrema_trainer.get_extrema_training_data(count=10)

            # Simulate training accuracy improvement
            if extrema_data:
                self.extrema_trainer.training_stats['total_extrema_detected'] += len(extrema_data)
                self.extrema_trainer.training_stats['successful_predictions'] += len(extrema_data) // 2
                self.extrema_trainer.training_stats['failed_predictions'] += len(extrema_data) // 2

            # Save checkpoint (automatic based on performance)
            checkpoint_saved = self.extrema_trainer.save_checkpoint()

            if checkpoint_saved:
                self.training_stats['checkpoints_saved'] += 1

            return {
                'status': 'completed',
                'extrema_detected': len(extrema_data),
                'context_updates': sum(1 for success in update_results.values() if success),
                'checkpoint_saved': checkpoint_saved,
                'session': self.extrema_trainer.training_session_count
            }

        except Exception as e:
            logger.error(f"Error training extrema detector: {e}")
            return {'status': 'error', 'error': str(e)}
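
    # update_context_data() is assumed to return a mapping of symbol ->
    # success flag (that is what the sum(...) above relies on), e.g.
    # {'ETH/USDT': True, 'BTC/USDT': False} when one feed update fails.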

    async def _process_negative_cases(self) -> Dict[str, Any]:
        """Process negative cases with automatic checkpointing"""
        try:
            if not self.negative_case_trainer:
                return {'status': 'skipped', 'reason': 'no_trainer'}

            # Simulate adding a negative case
            if np.random.random() < 0.1:  # 10% chance of negative case
                trade_info = {
                    'symbol': 'ETH/USDT',
                    'action': 'BUY',
                    'price': 2000.0,
                    'pnl': -50.0,  # Loss
                    'value': 1000.0,
                    'confidence': 0.7,
                    'timestamp': datetime.now()
                }

                market_data = {
                    'exit_price': 1950.0,
                    'state_before': {},
                    'state_after': {},
                    'tick_data': [],
                    'technical_indicators': {}
                }

                case_id = self.negative_case_trainer.add_losing_trade(trade_info, market_data)

                # Simulate loss improvement
                loss_improvement = np.random.random() * 0.1

                # Save checkpoint (automatic based on performance)
                checkpoint_saved = self.negative_case_trainer.save_checkpoint(loss_improvement)

                if checkpoint_saved:
                    self.training_stats['checkpoints_saved'] += 1

                return {
                    'status': 'completed',
                    'case_added': case_id,
                    'loss_improvement': loss_improvement,
                    'checkpoint_saved': checkpoint_saved,
                    'session': self.negative_case_trainer.training_session_count
                }
            else:
                return {'status': 'no_cases'}

        except Exception as e:
            logger.error(f"Error processing negative cases: {e}")
            return {'status': 'error', 'error': str(e)}
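
    # add_losing_trade(trade_info, market_data) is assumed to queue the case
    # for background retraining and return a case identifier; the "runs in
    # background" note in the training loop refers to that behavior.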

    async def _coordinate_checkpoint_saving(self, dqn_results: Dict, cnn_results: Dict,
                                            extrema_results: Dict, negative_results: Dict):
        """Coordinate checkpoint saving across all components"""
        try:
            # Count successful checkpoints (bools sum as 0/1)
            checkpoints_saved = sum([
                dqn_results.get('checkpoint_saved', False),
                cnn_results.get('checkpoint_saved', False),
                extrema_results.get('checkpoint_saved', False),
                negative_results.get('checkpoint_saved', False)
            ])

            if checkpoints_saved > 0:
                logger.info(f"Saved {checkpoints_saved} checkpoints this cycle")

            # Update best performances
            if 'episode_reward' in dqn_results:
                current_best = self.training_stats['best_performances'].get('dqn_reward', float('-inf'))
                if dqn_results['episode_reward'] > current_best:
                    self.training_stats['best_performances']['dqn_reward'] = dqn_results['episode_reward']

            if 'val_accuracy' in cnn_results:
                current_best = self.training_stats['best_performances'].get('cnn_accuracy', 0.0)
                if cnn_results['val_accuracy'] > current_best:
                    self.training_stats['best_performances']['cnn_accuracy'] = cnn_results['val_accuracy']

            # Log checkpoint statistics every 10 cycles
            if self.training_stats['total_training_sessions'] % 10 == 0:
                await self._log_checkpoint_statistics()

        except Exception as e:
            logger.error(f"Error coordinating checkpoint saving: {e}")

    async def _log_checkpoint_statistics(self):
        """Log comprehensive checkpoint statistics"""
        try:
            stats = get_checkpoint_stats()

            logger.info("=== Checkpoint Statistics ===")
            logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
            logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")
            logger.info(f"Models managed: {len(stats['models'])}")

            for model_name, model_stats in stats['models'].items():
                logger.info(f"  {model_name}: {model_stats['checkpoint_count']} checkpoints, "
                            f"{model_stats['total_size_mb']:.2f} MB, "
                            f"best: {model_stats['best_performance']:.4f}")

            logger.info(f"Training sessions: {self.training_stats['total_training_sessions']}")
            logger.info(f"Checkpoints saved: {self.training_stats['checkpoints_saved']}")
            logger.info(f"Best performances: {self.training_stats['best_performances']}")

        except Exception as e:
            logger.error(f"Error logging checkpoint statistics: {e}")

    async def shutdown(self):
        """Shutdown the training system and save final checkpoints"""
        logger.info("Shutting down checkpoint-integrated training system...")

        self.running = False

        try:
            # Force save checkpoints for all components
            if self.dqn_agent:
                self.dqn_agent.save_checkpoint(0.0, force_save=True)

            if self.cnn_trainer:
                self.cnn_trainer.save_checkpoint(0.0, 0.0, 0.0, 0.0, force_save=True)

            if self.extrema_trainer:
                self.extrema_trainer.save_checkpoint(force_save=True)

            if self.negative_case_trainer:
                self.negative_case_trainer.save_checkpoint(force_save=True)

            # Final statistics
            await self._log_checkpoint_statistics()

            logger.info("Checkpoint-integrated training system shutdown complete")

        except Exception as e:
            logger.error(f"Error during shutdown: {e}")


async def main():
    """Main function to run the checkpoint-integrated training system"""
    logger.info("🚀 Starting Checkpoint-Integrated Training System")

    # Create and initialize the training system
    training_system = CheckpointIntegratedTrainingSystem()

    # Setup signal handlers for graceful shutdown. The handler only flips the
    # running flag: calling asyncio.create_task() from a plain signal handler
    # is unreliable (there may be no usable running loop at that point), and
    # the finally block below performs the actual shutdown once the training
    # loop has exited.
    def signal_handler(signum, frame):
        logger.info("Received shutdown signal")
        training_system.running = False

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        # Initialize components
        await training_system.initialize_components()

        # Run the integrated training loop
        await training_system.run_integrated_training_loop()

    except Exception as e:
        logger.error(f"Error in main: {e}")
        raise
    finally:
        await training_system.shutdown()

    logger.info("✅ Checkpoint management integration complete!")
    logger.info("All training pipelines now support automatic checkpointing")


if __name__ == "__main__":
    # Ensure logs directory exists (also created at import time, before the
    # logging handlers are set up)
    Path("logs").mkdir(exist_ok=True)

    # Run the checkpoint-integrated training system
    asyncio.run(main())