BIG CLEANUP
@@ -1,490 +0,0 @@
#!/usr/bin/env python3
"""
Enhanced RL Training Integration - Comprehensive Fix

This script addresses the critical RL training audit issues:
1. Massive input data gap (99.25% missing) - implements the full 13,400-feature state
2. Disconnected training pipeline - provides proper data-flow integration
3. Missing enhanced state builder - connects the orchestrator to the dashboard
4. Reward calculation issues - ensures enhanced pivot-based rewards
5. Williams market structure integration - proper feature extraction
6. Real-time data integration - feeds live market data to RL

Usage:
    python enhanced_rl_training_integration.py
"""

import os
import sys
import asyncio
import logging
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Any

# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from core.config import setup_logging, get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from core.trading_executor import TradingExecutor
from web.clean_dashboard import CleanTradingDashboard as TradingDashboard
from utils.tensorboard_logger import TensorBoardLogger

logger = logging.getLogger(__name__)

class EnhancedRLTrainingIntegrator:
    """
    Comprehensive RL training integrator.

    Fixes all audit issues by ensuring proper data flow and feature completeness.
    """

    def __init__(self):
        """Initialize the enhanced RL training integrator"""
        # Setup logging
        setup_logging()
        logger.info("=" * 70)
        logger.info("ENHANCED RL TRAINING INTEGRATION - COMPREHENSIVE FIX")
        logger.info("=" * 70)

        # Get configuration
        self.config = get_config()

        # Initialize core components
        self.data_provider = DataProvider()
        self.enhanced_orchestrator = None
        self.trading_executor = TradingExecutor()
        self.dashboard = None

        # Training metrics
        self.training_stats = {
            'total_episodes': 0,
            'successful_state_builds': 0,
            'enhanced_reward_calculations': 0,
            'comprehensive_features_used': 0,
            'pivot_features_extracted': 0,
            'cob_features_available': 0
        }

        # Initialize TensorBoard logger
        experiment_name = f"enhanced_rl_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.tb_logger = TensorBoardLogger(
            log_dir="runs",
            experiment_name=experiment_name,
            enabled=True
        )
        logger.info(f"TensorBoard logging enabled for experiment: {experiment_name}")

        logger.info("Enhanced RL Training Integrator initialized")

    async def start_integration(self):
        """Start the comprehensive RL training integration"""
        try:
            logger.info("Starting comprehensive RL training integration...")

            # 1. Initialize enhanced orchestrator with comprehensive features
            await self._initialize_enhanced_orchestrator()

            # 2. Create enhanced dashboard with proper connections
            await self._create_enhanced_dashboard()

            # 3. Verify comprehensive state building
            await self._verify_comprehensive_state_building()

            # 4. Test enhanced reward calculation
            await self._test_enhanced_reward_calculation()

            # 5. Validate Williams market structure integration
            await self._validate_williams_integration()

            # 6. Start live training with comprehensive features
            await self._start_live_comprehensive_training()

            logger.info("=" * 70)
            logger.info("COMPREHENSIVE RL TRAINING INTEGRATION COMPLETE")
            logger.info("=" * 70)
            self._log_integration_stats()

        except Exception as e:
            logger.error(f"Error in RL training integration: {e}")
            import traceback
            logger.error(traceback.format_exc())

    async def _initialize_enhanced_orchestrator(self):
        """Initialize enhanced orchestrator with comprehensive RL capabilities"""
        try:
            logger.info("[STEP 1] Initializing Enhanced Orchestrator...")

            # Create enhanced orchestrator with RL training enabled
            self.enhanced_orchestrator = EnhancedTradingOrchestrator(
                data_provider=self.data_provider,
                symbols=['ETH/USDT', 'BTC/USDT'],
                enhanced_rl_training=True,
                model_registry={}  # Will be populated as needed
            )

            # Start COB integration for real-time market microstructure
            await self.enhanced_orchestrator.start_cob_integration()

            # Start real-time processing
            await self.enhanced_orchestrator.start_realtime_processing()

            logger.info("[SUCCESS] Enhanced Orchestrator initialized with:")
            logger.info("  - Comprehensive RL state building: ENABLED")
            logger.info("  - Enhanced pivot-based rewards: ENABLED")
            logger.info("  - COB integration: ENABLED")
            logger.info("  - Williams market structure: ENABLED")
            logger.info("  - Real-time tick processing: ENABLED")

        except Exception as e:
            logger.error(f"Error initializing enhanced orchestrator: {e}")
            raise

    async def _create_enhanced_dashboard(self):
        """Create dashboard with enhanced orchestrator connections"""
        try:
            logger.info("[STEP 2] Creating Enhanced Dashboard...")

            # Create trading dashboard with the enhanced orchestrator
            self.dashboard = TradingDashboard(
                data_provider=self.data_provider,
                orchestrator=self.enhanced_orchestrator,  # Use enhanced orchestrator
                trading_executor=self.trading_executor
            )

            # Verify enhanced connections
            has_comprehensive_state_builder = hasattr(self.dashboard.orchestrator, 'build_comprehensive_rl_state')
            has_enhanced_reward_calc = hasattr(self.dashboard.orchestrator, 'calculate_enhanced_pivot_reward')
            has_symbol_correlation = hasattr(self.dashboard.orchestrator, '_get_symbol_correlation')

            logger.info("[SUCCESS] Enhanced Dashboard created with:")
            logger.info(f"  - Comprehensive state builder: {'AVAILABLE' if has_comprehensive_state_builder else 'MISSING'}")
            logger.info(f"  - Enhanced reward calculation: {'AVAILABLE' if has_enhanced_reward_calc else 'MISSING'}")
            logger.info(f"  - Symbol correlation analysis: {'AVAILABLE' if has_symbol_correlation else 'MISSING'}")

            if not all([has_comprehensive_state_builder, has_enhanced_reward_calc, has_symbol_correlation]):
                logger.warning("Some enhanced features are missing - training will fall back to basic mode")
            else:
                logger.info("  - ALL ENHANCED FEATURES AVAILABLE!")

        except Exception as e:
            logger.error(f"Error creating enhanced dashboard: {e}")
            raise

    async def _verify_comprehensive_state_building(self):
        """Verify that comprehensive RL state building works correctly"""
        try:
            logger.info("[STEP 3] Verifying Comprehensive State Building...")

            # Test comprehensive state building for ETH
            eth_state = self.enhanced_orchestrator.build_comprehensive_rl_state('ETH/USDT')

            if eth_state is not None:
                logger.info(f"[SUCCESS] ETH comprehensive state built: {len(eth_state)} features")

                # Verify feature count
                if len(eth_state) == 13400:
                    logger.info("  - PERFECT: Exactly 13,400 features as required!")
                    self.training_stats['comprehensive_features_used'] += 1
                else:
                    logger.warning(f"  - MISMATCH: Expected 13,400 features, got {len(eth_state)}")

                # Analyze feature distribution
                self._analyze_state_features(eth_state)
                self.training_stats['successful_state_builds'] += 1

            else:
                logger.error("  - FAILED: Comprehensive state building returned None")

            # Test the BTC reference state
            btc_state = self.enhanced_orchestrator.build_comprehensive_rl_state('BTC/USDT')
            if btc_state is not None:
                logger.info(f"[SUCCESS] BTC reference state built: {len(btc_state)} features")
                self.training_stats['successful_state_builds'] += 1

        except Exception as e:
            logger.error(f"Error verifying comprehensive state building: {e}")

    def _analyze_state_features(self, state_vector: np.ndarray):
        """Analyze the comprehensive state feature distribution"""
        try:
            # Calculate feature statistics
            non_zero_features = np.count_nonzero(state_vector)
            zero_features = len(state_vector) - non_zero_features
            feature_mean = np.mean(state_vector)
            feature_std = np.std(state_vector)
            feature_min = np.min(state_vector)
            feature_max = np.max(state_vector)

            logger.info("  - Feature Analysis:")
            logger.info(f"    * Non-zero features: {non_zero_features:,} ({non_zero_features/len(state_vector)*100:.1f}%)")
            logger.info(f"    * Zero features: {zero_features:,} ({zero_features/len(state_vector)*100:.1f}%)")
            logger.info(f"    * Mean: {feature_mean:.6f}")
            logger.info(f"    * Std: {feature_std:.6f}")
            logger.info(f"    * Range: [{feature_min:.6f}, {feature_max:.6f}]")

            # Log feature statistics to TensorBoard
            step = self.training_stats['total_episodes']
            self.tb_logger.log_scalars('Features/Distribution', {
                'non_zero_percentage': non_zero_features/len(state_vector)*100,
                'mean': feature_mean,
                'std': feature_std,
                'min': feature_min,
                'max': feature_max
            }, step)

            # Log feature histogram to TensorBoard
            self.tb_logger.log_histogram('Features/Values', state_vector, step)

            # Check if features are properly distributed
            if non_zero_features > len(state_vector) * 0.1:  # At least 10% non-zero
                logger.info("    * GOOD: Features are well distributed")
            else:
                logger.warning("    * WARNING: Too many zero features - data may be incomplete")

        except Exception as e:
            logger.warning(f"Error analyzing state features: {e}")

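    # A minimal sketch of the TensorBoardLogger interface this script relies on
    # (log_scalar, log_scalars, log_histogram), assuming it wraps
    # torch.utils.tensorboard.SummaryWriter; the real class lives in
    # utils/tensorboard_logger.py and may differ.
    #
    #     from torch.utils.tensorboard import SummaryWriter
    #
    #     class TensorBoardLogger:
    #         def __init__(self, log_dir="runs", experiment_name="exp", enabled=True):
    #             self.writer = SummaryWriter(f"{log_dir}/{experiment_name}") if enabled else None
    #
    #         def log_scalar(self, tag, value, step):
    #             if self.writer:
    #                 self.writer.add_scalar(tag, value, step)
    #
    #         def log_scalars(self, main_tag, tag_value_dict, step):
    #             if self.writer:
    #                 self.writer.add_scalars(main_tag, tag_value_dict, step)
    #
    #         def log_histogram(self, tag, values, step):
    #             if self.writer:
    #                 self.writer.add_histogram(tag, values, step)
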
    async def _test_enhanced_reward_calculation(self):
        """Test enhanced pivot-based reward calculation"""
        try:
            logger.info("[STEP 4] Testing Enhanced Reward Calculation...")

            # Create mock trade data for testing
            trade_decision = {
                'action': 'BUY',
                'confidence': 0.75,
                'price': 2500.0,
                'timestamp': datetime.now()
            }

            trade_outcome = {
                'net_pnl': 50.0,
                'exit_price': 2550.0,
                'duration': timedelta(minutes=15)
            }

            # Market data for reward calculation
            market_data = {
                'volatility': 0.03,
                'order_flow_direction': 'bullish',
                'order_flow_strength': 0.8
            }

            # Test enhanced reward calculation
            if hasattr(self.enhanced_orchestrator, 'calculate_enhanced_pivot_reward'):
                enhanced_reward = self.enhanced_orchestrator.calculate_enhanced_pivot_reward(
                    trade_decision, market_data, trade_outcome
                )

                logger.info(f"[SUCCESS] Enhanced reward calculated: {enhanced_reward:.3f}")
                logger.info("  - Enhanced pivot-based reward system: WORKING")
                self.training_stats['enhanced_reward_calculations'] += 1

                # Log reward metrics to TensorBoard
                step = self.training_stats['enhanced_reward_calculations']
                self.tb_logger.log_scalar('Rewards/Enhanced', enhanced_reward, step)

                # Log reward components to TensorBoard
                self.tb_logger.log_scalars('Rewards/Components', {
                    'pnl_component': trade_outcome['net_pnl'],
                    'confidence': trade_decision['confidence'],
                    'volatility': market_data['volatility'],
                    'order_flow_strength': market_data['order_flow_strength']
                }, step)

            else:
                logger.error("  - FAILED: Enhanced reward calculation method not available")

        except Exception as e:
            logger.error(f"Error testing enhanced reward calculation: {e}")

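    # A hypothetical sketch of what calculate_enhanced_pivot_reward might
    # compute, for illustration only; the real method is implemented by the
    # orchestrator and is not shown in this script. The idea: weight PnL by
    # decision confidence and boost trades aligned with order flow.
    #
    #     def calculate_enhanced_pivot_reward(trade_decision, market_data, trade_outcome):
    #         reward = trade_outcome['net_pnl'] / 100.0       # normalize PnL
    #         reward *= trade_decision['confidence']          # weight by confidence
    #         bullish = market_data['order_flow_direction'] == 'bullish'
    #         aligned = (trade_decision['action'] == 'BUY') == bullish
    #         if aligned:
    #             reward *= 1.0 + market_data['order_flow_strength']
    #         return reward
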
    async def _validate_williams_integration(self):
        """Validate Williams market structure integration"""
        try:
            logger.info("[STEP 5] Validating Williams Market Structure Integration...")

            # Test Williams pivot feature extraction
            try:
                from training.williams_market_structure import extract_pivot_features, analyze_pivot_context

                # Get test market data
                df = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=100)

                if df is not None and not df.empty:
                    # Test pivot feature extraction
                    pivot_features = extract_pivot_features(df)

                    if pivot_features is not None:
                        logger.info(f"[SUCCESS] Williams pivot features extracted: {len(pivot_features)} features")
                        self.training_stats['pivot_features_extracted'] += 1

                        # Test pivot context analysis
                        market_data = {'ohlcv_data': df}
                        pivot_context = analyze_pivot_context(
                            market_data, datetime.now(), 'BUY'
                        )

                        if pivot_context is not None:
                            logger.info("[SUCCESS] Williams pivot context analysis: WORKING")
                            logger.info(f"  - Near pivot: {pivot_context.get('near_pivot', False)}")
                            logger.info(f"  - Pivot strength: {pivot_context.get('pivot_strength', 0):.3f}")
                        else:
                            logger.warning("  - Williams pivot context analysis returned None")
                    else:
                        logger.warning("  - Williams pivot feature extraction returned None")
                else:
                    logger.warning("  - No market data available for Williams testing")

            except ImportError:
                logger.error("  - Williams market structure module not available")
            except Exception as e:
                logger.error(f"  - Error in Williams integration: {e}")

        except Exception as e:
            logger.error(f"Error validating Williams integration: {e}")

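    # A minimal, hypothetical sketch of the extract_pivot_features interface
    # exercised above: detect swing highs/lows over a rolling window and return
    # a few summary features. The real implementation in
    # training/williams_market_structure.py may differ substantially.
    #
    #     def extract_pivot_features(df, window=5):
    #         highs, lows = df['high'].values, df['low'].values
    #         pivot_highs = [i for i in range(window, len(df) - window)
    #                        if highs[i] == max(highs[i - window:i + window + 1])]
    #         pivot_lows = [i for i in range(window, len(df) - window)
    #                       if lows[i] == min(lows[i - window:i + window + 1])]
    #         last_close = df['close'].iloc[-1]
    #         return np.array([
    #             len(pivot_highs), len(pivot_lows),
    #             highs[pivot_highs[-1]] / last_close if pivot_highs else 1.0,
    #             lows[pivot_lows[-1]] / last_close if pivot_lows else 1.0,
    #         ])
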
    async def _start_live_comprehensive_training(self):
        """Start live training with comprehensive feature integration"""
        try:
            logger.info("[STEP 6] Starting Live Comprehensive Training...")

            # Run a few training iterations to verify integration
            for iteration in range(5):
                logger.info(f"Training iteration {iteration + 1}/5")

                # Make coordinated decisions using the enhanced orchestrator
                decisions = await self.enhanced_orchestrator.make_coordinated_decisions()

                # Track iteration metrics for TensorBoard
                iteration_metrics = {
                    'decisions_count': len(decisions),
                    'confidence_avg': 0.0,
                    'state_size_avg': 0.0,
                    'successful_states': 0
                }

                # Process each decision
                for symbol, decision in decisions.items():
                    if decision:
                        logger.info(f"  {symbol}: {decision.action} (confidence: {decision.confidence:.3f})")

                        # Track confidence for TensorBoard
                        iteration_metrics['confidence_avg'] += decision.confidence

                        # Build the comprehensive state for this decision
                        comprehensive_state = self.enhanced_orchestrator.build_comprehensive_rl_state(symbol)

                        if comprehensive_state is not None:
                            state_size = len(comprehensive_state)
                            logger.info(f"  - Comprehensive state: {state_size} features")
                            self.training_stats['total_episodes'] += 1

                            # Track state size for TensorBoard
                            iteration_metrics['state_size_avg'] += state_size
                            iteration_metrics['successful_states'] += 1

                            # Log individual state metrics to TensorBoard
                            self.tb_logger.log_state_metrics(
                                symbol=symbol,
                                state_info={
                                    'size': state_size,
                                    'quality': 1.0 if state_size == 13400 else 0.8,
                                    'feature_counts': {
                                        'total': state_size,
                                        'non_zero': np.count_nonzero(comprehensive_state)
                                    }
                                },
                                step=self.training_stats['total_episodes']
                            )
                        else:
                            logger.warning(f"  - Failed to build comprehensive state for {symbol}")

                # Calculate averages for TensorBoard
                if decisions:
                    iteration_metrics['confidence_avg'] /= len(decisions)

                if iteration_metrics['successful_states'] > 0:
                    iteration_metrics['state_size_avg'] /= iteration_metrics['successful_states']

                # Log iteration metrics to TensorBoard
                self.tb_logger.log_scalars('Training/Iteration', {
                    'iteration': iteration + 1,
                    'decisions_count': iteration_metrics['decisions_count'],
                    'confidence_avg': iteration_metrics['confidence_avg'],
                    'state_size_avg': iteration_metrics['state_size_avg'],
                    'successful_states': iteration_metrics['successful_states']
                }, iteration + 1)

                # Wait between iterations
                await asyncio.sleep(2)

            logger.info("[SUCCESS] Live comprehensive training demonstration complete")

        except Exception as e:
            logger.error(f"Error in live comprehensive training: {e}")

    def _log_integration_stats(self):
        """Log comprehensive integration statistics"""
        logger.info("INTEGRATION STATISTICS:")
        logger.info(f"  - Total training episodes: {self.training_stats['total_episodes']}")
        logger.info(f"  - Successful state builds: {self.training_stats['successful_state_builds']}")
        logger.info(f"  - Enhanced reward calculations: {self.training_stats['enhanced_reward_calculations']}")
        logger.info(f"  - Comprehensive features used: {self.training_stats['comprehensive_features_used']}")
        logger.info(f"  - Pivot features extracted: {self.training_stats['pivot_features_extracted']}")

        # Calculate success rates
        state_success_rate = 0
        if self.training_stats['total_episodes'] > 0:
            state_success_rate = self.training_stats['successful_state_builds'] / self.training_stats['total_episodes'] * 100
            logger.info(f"  - State building success rate: {state_success_rate:.1f}%")

        # Log final statistics to TensorBoard
        self.tb_logger.log_scalars('Integration/Statistics', {
            'total_episodes': self.training_stats['total_episodes'],
            'successful_state_builds': self.training_stats['successful_state_builds'],
            'enhanced_reward_calculations': self.training_stats['enhanced_reward_calculations'],
            'comprehensive_features_used': self.training_stats['comprehensive_features_used'],
            'pivot_features_extracted': self.training_stats['pivot_features_extracted'],
            'state_success_rate': state_success_rate
        }, 0)  # Use step 0 for final summary stats

        # Integration status
        if self.training_stats['comprehensive_features_used'] > 0:
            logger.info("STATUS: COMPREHENSIVE RL TRAINING INTEGRATION SUCCESSFUL! ✅")
            logger.info("The system is now using the full 13,400-feature comprehensive state.")

            # Log success status to TensorBoard
            self.tb_logger.log_scalar('Integration/Success', 1.0, 0)
        else:
            logger.warning("STATUS: Integration partially successful - some fallbacks may occur")

            # Log partial success status to TensorBoard
            self.tb_logger.log_scalar('Integration/Success', 0.5, 0)

async def main():
    """Main entry point"""
    try:
        # Create and run the enhanced RL training integrator
        integrator = EnhancedRLTrainingIntegrator()
        await integrator.start_integration()

        logger.info("Enhanced RL training integration completed successfully!")
        return 0

    except KeyboardInterrupt:
        logger.info("Integration interrupted by user")
        return 0
    except Exception as e:
        logger.error(f"Fatal error in integration: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return 1


if __name__ == "__main__":
    exit_code = asyncio.run(main())
    sys.exit(exit_code)

@@ -1,148 +0,0 @@
#!/usr/bin/env python3
"""
Example: Using the Checkpoint Management System
"""

import logging
import torch
import torch.nn as nn
import numpy as np
from datetime import datetime

from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint, get_checkpoint_manager
from utils.training_integration import get_training_integration

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ExampleCNN(nn.Module):
    def __init__(self, input_channels=5, num_classes=3):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

def example_cnn_training():
    logger.info("=== CNN Training Example ===")

    model = ExampleCNN()
    training_integration = get_training_integration()

    for epoch in range(5):  # Simulate 5 epochs
        # Simulate training metrics
        train_loss = 2.0 - (epoch * 0.15) + np.random.normal(0, 0.1)
        train_acc = 0.3 + (epoch * 0.06) + np.random.normal(0, 0.02)
        val_loss = train_loss + np.random.normal(0, 0.05)
        val_acc = train_acc - 0.05 + np.random.normal(0, 0.02)

        # Clamp values to realistic ranges
        train_acc = max(0.0, min(1.0, train_acc))
        val_acc = max(0.0, min(1.0, val_acc))
        train_loss = max(0.1, train_loss)
        val_loss = max(0.1, val_loss)

        logger.info(f"Epoch {epoch+1}: train_acc={train_acc:.3f}, val_acc={val_acc:.3f}")

        # Save checkpoint
        saved = training_integration.save_cnn_checkpoint(
            cnn_model=model,
            model_name="example_cnn",
            epoch=epoch + 1,
            train_accuracy=train_acc,
            val_accuracy=val_acc,
            train_loss=train_loss,
            val_loss=val_loss,
            training_time_hours=0.1 * (epoch + 1)
        )

        if saved:
            logger.info(f"  Checkpoint saved for epoch {epoch+1}")
        else:
            logger.info("  Checkpoint not saved (performance did not improve)")

    # Load the best checkpoint
    logger.info("\nLoading best checkpoint...")
    best_result = load_best_checkpoint("example_cnn")
    if best_result:
        file_path, metadata = best_result
        logger.info(f"Best checkpoint: {metadata.checkpoint_id}")
        logger.info(f"Performance score: {metadata.performance_score:.4f}")

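# A hedged usage sketch: restore the best saved weights for inference. The
# checkpoint file layout (a 'model_state_dict' key) is an assumption here,
# not confirmed by this example - check utils/checkpoint_manager.py for the
# actual format before relying on it.
def example_restore_best():
    best_result = load_best_checkpoint("example_cnn")
    if best_result:
        file_path, _metadata = best_result
        checkpoint = torch.load(file_path, map_location="cpu")
        model = ExampleCNN()
        model.load_state_dict(checkpoint["model_state_dict"])  # assumed key
        model.eval()  # ready for inference
        return model
    return None
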
def example_manual_checkpoint():
    logger.info("\n=== Manual Checkpoint Example ===")

    model = nn.Linear(10, 3)

    performance_metrics = {
        'accuracy': 0.85,
        'val_accuracy': 0.82,
        'loss': 0.45,
        'val_loss': 0.48
    }

    training_metadata = {
        'epoch': 25,
        'training_time_hours': 2.5,
        'total_parameters': sum(p.numel() for p in model.parameters())
    }

    logger.info("Saving checkpoint manually...")
    metadata = save_checkpoint(
        model=model,
        model_name="example_manual",
        model_type="cnn",
        performance_metrics=performance_metrics,
        training_metadata=training_metadata,
        force_save=True
    )

    if metadata:
        logger.info(f"  Manual checkpoint saved: {metadata.checkpoint_id}")
        logger.info(f"  Performance score: {metadata.performance_score:.4f}")

def show_checkpoint_stats():
    logger.info("\n=== Checkpoint Statistics ===")

    checkpoint_manager = get_checkpoint_manager()
    stats = checkpoint_manager.get_checkpoint_stats()

    logger.info(f"Total models: {stats['total_models']}")
    logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
    logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")

    for model_name, model_stats in stats['models'].items():
        logger.info(f"\n{model_name}:")
        logger.info(f"  Checkpoints: {model_stats['checkpoint_count']}")
        logger.info(f"  Size: {model_stats['total_size_mb']:.2f} MB")
        logger.info(f"  Best performance: {model_stats['best_performance']:.4f}")

def main():
    logger.info("Checkpoint Management System Examples")
    logger.info("=" * 50)

    try:
        example_cnn_training()
        example_manual_checkpoint()
        show_checkpoint_stats()

        logger.info("\nAll examples completed successfully!")
        logger.info("\nTo use in your training:")
        logger.info("1. Import: from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint")
        logger.info("2. Or use: from utils.training_integration import get_training_integration")
        logger.info("3. Save checkpoints during training with performance metrics")
        logger.info("4. Load best checkpoints for inference or continued training")

    except Exception as e:
        logger.error(f"Error in examples: {e}")
        raise


if __name__ == "__main__":
    main()

@@ -1,517 +0,0 @@
#!/usr/bin/env python3
"""
Comprehensive Checkpoint Management Integration

This script demonstrates how to integrate the checkpoint management system
across all training pipelines in the gogo2 project.

Features:
- DQN agent training with automatic checkpointing
- CNN model training with checkpoint management
- ExtremaTrainer with checkpoint persistence
- NegativeCaseTrainer with checkpoint integration
- Unified training orchestration with checkpoint coordination
"""

import asyncio
import logging
import time
import signal
import sys
import numpy as np
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List

# Setup logging; the log directory must exist before the FileHandler is
# created, so ensure it here (the mkdir in __main__ runs too late for this)
Path('logs').mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('logs/checkpoint_integration.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Import checkpoint management
from utils.checkpoint_manager import get_checkpoint_manager, get_checkpoint_stats
from utils.training_integration import get_training_integration

# Import training components
from NN.models.dqn_agent import DQNAgent
from NN.models.standardized_cnn import StandardizedCNN
from core.extrema_trainer import ExtremaTrainer
from core.negative_case_trainer import NegativeCaseTrainer
from core.data_provider import DataProvider
from core.config import get_config

class CheckpointIntegratedTrainingSystem:
    """Unified training system with comprehensive checkpoint management"""

    def __init__(self):
        """Initialize the checkpoint-integrated training system"""
        self.config = get_config()
        self.running = False

        # Checkpoint management
        self.checkpoint_manager = get_checkpoint_manager()
        self.training_integration = get_training_integration()

        # Data provider
        self.data_provider = DataProvider(
            symbols=['ETH/USDT', 'BTC/USDT'],
            timeframes=['1s', '1m', '1h', '1d']
        )

        # Training components with checkpoint management
        self.dqn_agent = None
        self.cnn_trainer = None
        self.extrema_trainer = None
        self.negative_case_trainer = None

        # Training statistics
        self.training_stats = {
            'start_time': None,
            'total_training_sessions': 0,
            'checkpoints_saved': 0,
            'models_loaded': 0,
            'best_performances': {}
        }

        logger.info("Checkpoint-Integrated Training System initialized")

    async def initialize_components(self):
        """Initialize all training components with checkpoint management"""
        try:
            logger.info("Initializing training components with checkpoint management...")

            # Initialize data provider
            await self.data_provider.start_real_time_streaming()
            logger.info("Data provider streaming started")

            # Initialize DQN agent with checkpoint management
            logger.info("Initializing DQN Agent with checkpoints...")
            self.dqn_agent = DQNAgent(
                state_shape=(100,),  # Example state shape
                n_actions=3,
                model_name="integrated_dqn_agent",
                enable_checkpoints=True
            )
            logger.info("✅ DQN Agent initialized with checkpoint management")

            # Initialize StandardizedCNN model with checkpoint management
            # (stored as self.cnn_trainer, which the training loop below uses)
            logger.info("Initializing StandardizedCNN Model with checkpoints...")
            self.cnn_trainer = StandardizedCNN(model_name="integrated_cnn_model")
            logger.info("✅ StandardizedCNN Model initialized with checkpoint management")

            # Initialize ExtremaTrainer with checkpoint management
            logger.info("Initializing ExtremaTrainer with checkpoints...")
            self.extrema_trainer = ExtremaTrainer(
                data_provider=self.data_provider,
                symbols=['ETH/USDT', 'BTC/USDT'],
                model_name="integrated_extrema_trainer",
                enable_checkpoints=True
            )
            await self.extrema_trainer.initialize_context_data()
            logger.info("✅ ExtremaTrainer initialized with checkpoint management")

            # Initialize NegativeCaseTrainer with checkpoint management
            logger.info("Initializing NegativeCaseTrainer with checkpoints...")
            self.negative_case_trainer = NegativeCaseTrainer(
                model_name="integrated_negative_case_trainer",
                enable_checkpoints=True
            )
            logger.info("✅ NegativeCaseTrainer initialized with checkpoint management")

            # Load existing checkpoints for all components
            self.training_stats['models_loaded'] = await self._load_all_checkpoints()

            logger.info("All training components initialized successfully")

        except Exception as e:
            logger.error(f"Error initializing components: {e}")
            raise

    async def _load_all_checkpoints(self) -> int:
        """Load checkpoints for all training components"""
        loaded_count = 0

        try:
            # DQN agent checkpoint loading is handled in __init__
            if hasattr(self.dqn_agent, 'episode_count') and self.dqn_agent.episode_count > 0:
                loaded_count += 1
                logger.info(f"DQN Agent resumed from episode {self.dqn_agent.episode_count}")

            # CNN trainer checkpoint loading is handled in __init__
            if hasattr(self.cnn_trainer, 'epoch_count') and self.cnn_trainer.epoch_count > 0:
                loaded_count += 1
                logger.info(f"CNN Trainer resumed from epoch {self.cnn_trainer.epoch_count}")

            # ExtremaTrainer checkpoint loading is handled in __init__
            if hasattr(self.extrema_trainer, 'training_session_count') and self.extrema_trainer.training_session_count > 0:
                loaded_count += 1
                logger.info(f"ExtremaTrainer resumed from session {self.extrema_trainer.training_session_count}")

            # NegativeCaseTrainer checkpoint loading is handled in __init__
            if hasattr(self.negative_case_trainer, 'training_session_count') and self.negative_case_trainer.training_session_count > 0:
                loaded_count += 1
                logger.info(f"NegativeCaseTrainer resumed from session {self.negative_case_trainer.training_session_count}")

            return loaded_count

        except Exception as e:
            logger.error(f"Error loading checkpoints: {e}")
            return 0

    async def run_integrated_training_loop(self):
        """Run the integrated training loop with checkpoint coordination"""
        logger.info("Starting integrated training loop with checkpoint management...")

        self.running = True
        self.training_stats['start_time'] = datetime.now()

        training_cycle = 0

        try:
            while self.running:
                training_cycle += 1
                cycle_start = time.time()

                logger.info(f"=== Training Cycle {training_cycle} ===")

                # DQN training
                dqn_results = await self._train_dqn_agent()

                # CNN training
                cnn_results = await self._train_cnn_model()

                # Extrema detection training
                extrema_results = await self._train_extrema_detector()

                # Negative case training (runs in background)
                negative_results = await self._process_negative_cases()

                # Coordinate checkpoint saving
                await self._coordinate_checkpoint_saving(
                    dqn_results, cnn_results, extrema_results, negative_results
                )

                # Update statistics
                self.training_stats['total_training_sessions'] += 1

                # Log cycle summary
                cycle_duration = time.time() - cycle_start
                logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")

                # Wait before the next cycle
                await asyncio.sleep(60)  # 1-minute cycles

        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
        except Exception as e:
            logger.error(f"Error in training loop: {e}")
        finally:
            await self.shutdown()

    async def _train_dqn_agent(self) -> Dict[str, Any]:
        """Train the DQN agent with automatic checkpointing"""
        try:
            if not self.dqn_agent:
                return {'status': 'skipped', 'reason': 'no_agent'}

            # Simulate a DQN training episode
            episode_reward = 0.0

            # Add some training experiences (simulate real training)
            for _ in range(10):  # Simulate 10 training steps
                state = np.random.randn(100).astype(np.float32)
                action = np.random.randint(0, 3)
                reward = np.random.randn() * 0.1
                next_state = np.random.randn(100).astype(np.float32)
                done = np.random.random() < 0.1

                self.dqn_agent.remember(state, action, reward, next_state, done)
                episode_reward += reward

            # Train if there are enough experiences
            loss = 0.0
            if len(self.dqn_agent.memory) >= self.dqn_agent.batch_size:
                loss = self.dqn_agent.replay()

            # Save checkpoint (automatic, based on performance)
            checkpoint_saved = self.dqn_agent.save_checkpoint(episode_reward)

            if checkpoint_saved:
                self.training_stats['checkpoints_saved'] += 1

            return {
                'status': 'completed',
                'episode_reward': episode_reward,
                'loss': loss,
                'checkpoint_saved': checkpoint_saved,
                'episode': self.dqn_agent.episode_count
            }

        except Exception as e:
            logger.error(f"Error training DQN agent: {e}")
            return {'status': 'error', 'error': str(e)}

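    # A minimal sketch of the remember/replay storage assumed above, for
    # readers unfamiliar with the DQNAgent API; the real class is in
    # NN/models/dqn_agent.py and is more elaborate (networks, target updates).
    #
    #     from collections import deque
    #     import random
    #
    #     class ReplayMemorySketch:
    #         def __init__(self, capacity=10000, batch_size=32):
    #             self.memory = deque(maxlen=capacity)
    #             self.batch_size = batch_size
    #
    #         def remember(self, state, action, reward, next_state, done):
    #             self.memory.append((state, action, reward, next_state, done))
    #
    #         def sample(self):
    #             return random.sample(self.memory, self.batch_size)
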
    async def _train_cnn_model(self) -> Dict[str, Any]:
        """Train the CNN model with automatic checkpointing"""
        try:
            if not self.cnn_trainer:
                return {'status': 'skipped', 'reason': 'no_trainer'}

            # Simulate a CNN training step
            import torch  # local import; torch is not needed at module level

            batch_size = 32
            input_size = 60
            feature_dim = 50

            # Generate synthetic training data
            x = torch.randn(batch_size, input_size, feature_dim)
            y = torch.randint(0, 3, (batch_size,))

            # Training step
            results = self.cnn_trainer.train_step(x, y)

            # Simulate validation
            val_x = torch.randn(16, input_size, feature_dim)
            val_y = torch.randint(0, 3, (16,))
            val_results = self.cnn_trainer.train_step(val_x, val_y)

            # Save checkpoint (automatic, based on performance)
            checkpoint_saved = self.cnn_trainer.save_checkpoint(
                train_accuracy=results.get('accuracy', 0.5),
                val_accuracy=val_results.get('accuracy', 0.5),
                train_loss=results.get('total_loss', 1.0),
                val_loss=val_results.get('total_loss', 1.0)
            )

            if checkpoint_saved:
                self.training_stats['checkpoints_saved'] += 1

            return {
                'status': 'completed',
                'train_accuracy': results.get('accuracy', 0.5),
                'val_accuracy': val_results.get('accuracy', 0.5),
                'train_loss': results.get('total_loss', 1.0),
                'val_loss': val_results.get('total_loss', 1.0),
                'checkpoint_saved': checkpoint_saved,
                'epoch': self.cnn_trainer.epoch_count
            }

        except Exception as e:
            logger.error(f"Error training CNN model: {e}")
            return {'status': 'error', 'error': str(e)}

    async def _train_extrema_detector(self) -> Dict[str, Any]:
        """Train the extrema detector with automatic checkpointing"""
        try:
            if not self.extrema_trainer:
                return {'status': 'skipped', 'reason': 'no_trainer'}

            # Update context data and detect extrema
            update_results = self.extrema_trainer.update_context_data()

            # Get training data
            extrema_data = self.extrema_trainer.get_extrema_training_data(count=10)

            # Simulate training accuracy improvement
            if extrema_data:
                self.extrema_trainer.training_stats['total_extrema_detected'] += len(extrema_data)
                self.extrema_trainer.training_stats['successful_predictions'] += len(extrema_data) // 2
                self.extrema_trainer.training_stats['failed_predictions'] += len(extrema_data) // 2

            # Save checkpoint (automatic, based on performance)
            checkpoint_saved = self.extrema_trainer.save_checkpoint()

            if checkpoint_saved:
                self.training_stats['checkpoints_saved'] += 1

            return {
                'status': 'completed',
                'extrema_detected': len(extrema_data),
                'context_updates': sum(1 for success in update_results.values() if success),
                'checkpoint_saved': checkpoint_saved,
                'session': self.extrema_trainer.training_session_count
            }

        except Exception as e:
            logger.error(f"Error training extrema detector: {e}")
            return {'status': 'error', 'error': str(e)}

    async def _process_negative_cases(self) -> Dict[str, Any]:
        """Process negative cases with automatic checkpointing"""
        try:
            if not self.negative_case_trainer:
                return {'status': 'skipped', 'reason': 'no_trainer'}

            # Simulate adding a negative case
            if np.random.random() < 0.1:  # 10% chance of a negative case
                trade_info = {
                    'symbol': 'ETH/USDT',
                    'action': 'BUY',
                    'price': 2000.0,
                    'pnl': -50.0,  # Loss
                    'value': 1000.0,
                    'confidence': 0.7,
                    'timestamp': datetime.now()
                }

                market_data = {
                    'exit_price': 1950.0,
                    'state_before': {},
                    'state_after': {},
                    'tick_data': [],
                    'technical_indicators': {}
                }

                case_id = self.negative_case_trainer.add_losing_trade(trade_info, market_data)

                # Simulate loss improvement
                loss_improvement = np.random.random() * 0.1

                # Save checkpoint (automatic, based on performance)
                checkpoint_saved = self.negative_case_trainer.save_checkpoint(loss_improvement)

                if checkpoint_saved:
                    self.training_stats['checkpoints_saved'] += 1

                return {
                    'status': 'completed',
                    'case_added': case_id,
                    'loss_improvement': loss_improvement,
                    'checkpoint_saved': checkpoint_saved,
                    'session': self.negative_case_trainer.training_session_count
                }
            else:
                return {'status': 'no_cases'}

        except Exception as e:
            logger.error(f"Error processing negative cases: {e}")
            return {'status': 'error', 'error': str(e)}

    async def _coordinate_checkpoint_saving(self, dqn_results: Dict, cnn_results: Dict,
                                            extrema_results: Dict, negative_results: Dict):
        """Coordinate checkpoint saving across all components"""
        try:
            # Count successful checkpoints
            checkpoints_saved = sum([
                dqn_results.get('checkpoint_saved', False),
                cnn_results.get('checkpoint_saved', False),
                extrema_results.get('checkpoint_saved', False),
                negative_results.get('checkpoint_saved', False)
            ])

            if checkpoints_saved > 0:
                logger.info(f"Saved {checkpoints_saved} checkpoints this cycle")

            # Update best performances
            if 'episode_reward' in dqn_results:
                current_best = self.training_stats['best_performances'].get('dqn_reward', float('-inf'))
                if dqn_results['episode_reward'] > current_best:
                    self.training_stats['best_performances']['dqn_reward'] = dqn_results['episode_reward']

            if 'val_accuracy' in cnn_results:
                current_best = self.training_stats['best_performances'].get('cnn_accuracy', 0.0)
                if cnn_results['val_accuracy'] > current_best:
                    self.training_stats['best_performances']['cnn_accuracy'] = cnn_results['val_accuracy']

            # Log checkpoint statistics every 10 cycles
            if self.training_stats['total_training_sessions'] % 10 == 0:
                await self._log_checkpoint_statistics()

        except Exception as e:
            logger.error(f"Error coordinating checkpoint saving: {e}")

    async def _log_checkpoint_statistics(self):
        """Log comprehensive checkpoint statistics"""
        try:
            stats = get_checkpoint_stats()

            logger.info("=== Checkpoint Statistics ===")
            logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
            logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")
            logger.info(f"Models managed: {len(stats['models'])}")

            for model_name, model_stats in stats['models'].items():
                logger.info(f"  {model_name}: {model_stats['checkpoint_count']} checkpoints, "
                            f"{model_stats['total_size_mb']:.2f} MB, "
                            f"best: {model_stats['best_performance']:.4f}")

            logger.info(f"Training sessions: {self.training_stats['total_training_sessions']}")
            logger.info(f"Checkpoints saved: {self.training_stats['checkpoints_saved']}")
            logger.info(f"Best performances: {self.training_stats['best_performances']}")

        except Exception as e:
            logger.error(f"Error logging checkpoint statistics: {e}")

    async def shutdown(self):
        """Shutdown the training system and save final checkpoints"""
        logger.info("Shutting down checkpoint-integrated training system...")

        self.running = False

        try:
            # Force-save checkpoints for all components
            if self.dqn_agent:
                self.dqn_agent.save_checkpoint(0.0, force_save=True)

            if self.cnn_trainer:
                self.cnn_trainer.save_checkpoint(0.0, 0.0, 0.0, 0.0, force_save=True)

            if self.extrema_trainer:
                self.extrema_trainer.save_checkpoint(force_save=True)

            if self.negative_case_trainer:
                self.negative_case_trainer.save_checkpoint(force_save=True)

            # Final statistics
            await self._log_checkpoint_statistics()

            logger.info("Checkpoint-integrated training system shutdown complete")

        except Exception as e:
            logger.error(f"Error during shutdown: {e}")

async def main():
    """Main function to run the checkpoint-integrated training system"""
    logger.info("🚀 Starting Checkpoint-Integrated Training System")

    # Create and initialize the training system
    training_system = CheckpointIntegratedTrainingSystem()

    # Setup signal handlers for graceful shutdown
    def signal_handler(signum, frame):
        logger.info("Received shutdown signal")
        asyncio.create_task(training_system.shutdown())

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        # Initialize components
        await training_system.initialize_components()

        # Run the integrated training loop
        await training_system.run_integrated_training_loop()

    except Exception as e:
        logger.error(f"Error in main: {e}")
        raise
    finally:
        await training_system.shutdown()

    logger.info("✅ Checkpoint management integration complete!")
    logger.info("All training pipelines now support automatic checkpointing")


if __name__ == "__main__":
    # Ensure the logs directory exists
    Path("logs").mkdir(exist_ok=True)

    # Run the checkpoint-integrated training system
    asyncio.run(main())
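
# Design note: calling asyncio.create_task() from a signal.signal() handler is
# fragile, since the handler runs outside the event loop's normal control flow.
# On POSIX systems, a sketch of the more robust alternative (an assumption on
# our part, not part of the original script) would register handlers on the
# running loop inside main():
#
#     loop = asyncio.get_running_loop()
#     for sig in (signal.SIGINT, signal.SIGTERM):
#         loop.add_signal_handler(
#             sig, lambda: asyncio.create_task(training_system.shutdown())
#         )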