checkbox manager and handling
This commit is contained in:
@ -18,6 +18,14 @@ from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
from collections import deque
|
||||
import os
|
||||
import pickle
|
||||
import json
|
||||
|
||||
# Import checkpoint management
|
||||
import torch
|
||||
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
|
||||
from utils.training_integration import get_training_integration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -44,9 +52,10 @@ class ContextData:
|
||||
last_update: datetime
|
||||
|
||||
class ExtremaTrainer:
|
||||
"""Reusable extrema detection and training functionality"""
|
||||
"""Reusable extrema detection and training functionality with checkpoint management"""
|
||||
|
||||
def __init__(self, data_provider, symbols: List[str], window_size: int = 10):
|
||||
def __init__(self, data_provider, symbols: List[str], window_size: int = 10,
|
||||
model_name: str = "extrema_trainer", enable_checkpoints: bool = True):
|
||||
"""
|
||||
Initialize the extrema trainer
|
||||
|
||||
@ -54,11 +63,21 @@ class ExtremaTrainer:
|
||||
data_provider: Data provider instance
|
||||
symbols: List of symbols to track
|
||||
window_size: Window size for extrema detection (default 10)
|
||||
model_name: Name for checkpoint management
|
||||
enable_checkpoints: Whether to enable checkpoint management
|
||||
"""
|
||||
self.data_provider = data_provider
|
||||
self.symbols = symbols
|
||||
self.window_size = window_size
|
||||
|
||||
# Checkpoint management
|
||||
self.model_name = model_name
|
||||
self.enable_checkpoints = enable_checkpoints
|
||||
self.training_integration = get_training_integration() if enable_checkpoints else None
|
||||
self.training_session_count = 0
|
||||
self.best_detection_accuracy = 0.0
|
||||
self.checkpoint_frequency = 50 # Save checkpoint every 50 training sessions
|
||||
|
||||
# Extrema tracking
|
||||
self.detected_extrema = {symbol: deque(maxlen=1000) for symbol in symbols}
|
||||
self.extrema_training_queue = deque(maxlen=500)
|
||||
@ -78,8 +97,125 @@ class ExtremaTrainer:
|
||||
self.min_confidence_threshold = 0.3 # Train on opportunities with at least 30% confidence
|
||||
self.max_confidence_threshold = 0.95 # Cap confidence at 95%
|
||||
|
||||
# Performance tracking
|
||||
self.training_stats = {
|
||||
'total_extrema_detected': 0,
|
||||
'successful_predictions': 0,
|
||||
'failed_predictions': 0,
|
||||
'detection_accuracy': 0.0,
|
||||
'last_training_time': None
|
||||
}
|
||||
|
||||
# Load best checkpoint if available
|
||||
if self.enable_checkpoints:
|
||||
self.load_best_checkpoint()
|
||||
|
||||
logger.info(f"ExtremaTrainer initialized for symbols: {symbols}")
|
||||
logger.info(f"Window size: {window_size}, Context update frequency: {self.context_update_frequency}s")
|
||||
logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}")
|
||||
|
||||
def load_best_checkpoint(self):
|
||||
"""Load the best checkpoint for this extrema trainer"""
|
||||
try:
|
||||
if not self.enable_checkpoints:
|
||||
return
|
||||
|
||||
result = load_best_checkpoint(self.model_name)
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
checkpoint = torch.load(file_path, map_location='cpu')
|
||||
|
||||
# Load training state
|
||||
if 'training_session_count' in checkpoint:
|
||||
self.training_session_count = checkpoint['training_session_count']
|
||||
if 'best_detection_accuracy' in checkpoint:
|
||||
self.best_detection_accuracy = checkpoint['best_detection_accuracy']
|
||||
if 'training_stats' in checkpoint:
|
||||
self.training_stats = checkpoint['training_stats']
|
||||
if 'detected_extrema' in checkpoint:
|
||||
# Convert back to deques
|
||||
for symbol, extrema_list in checkpoint['detected_extrema'].items():
|
||||
if symbol in self.detected_extrema:
|
||||
self.detected_extrema[symbol] = deque(extrema_list, maxlen=1000)
|
||||
|
||||
logger.info(f"Loaded ExtremaTrainer checkpoint: {metadata.checkpoint_id}")
|
||||
logger.info(f"Session: {self.training_session_count}, Best accuracy: {self.best_detection_accuracy:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
|
||||
|
||||
def save_checkpoint(self, force_save: bool = False):
|
||||
"""Save checkpoint if performance improved or forced"""
|
||||
try:
|
||||
if not self.enable_checkpoints:
|
||||
return False
|
||||
|
||||
self.training_session_count += 1
|
||||
|
||||
# Calculate current detection accuracy
|
||||
total_predictions = self.training_stats['successful_predictions'] + self.training_stats['failed_predictions']
|
||||
current_accuracy = (
|
||||
self.training_stats['successful_predictions'] / total_predictions
|
||||
if total_predictions > 0 else 0.0
|
||||
)
|
||||
|
||||
# Update best accuracy
|
||||
improved = False
|
||||
if current_accuracy > self.best_detection_accuracy:
|
||||
self.best_detection_accuracy = current_accuracy
|
||||
improved = True
|
||||
|
||||
# Save checkpoint if improved, forced, or at regular intervals
|
||||
should_save = (
|
||||
force_save or
|
||||
improved or
|
||||
self.training_session_count % self.checkpoint_frequency == 0
|
||||
)
|
||||
|
||||
if should_save:
|
||||
# Prepare checkpoint data
|
||||
checkpoint_data = {
|
||||
'training_session_count': self.training_session_count,
|
||||
'best_detection_accuracy': self.best_detection_accuracy,
|
||||
'training_stats': self.training_stats,
|
||||
'detected_extrema': {
|
||||
symbol: list(extrema_deque)
|
||||
for symbol, extrema_deque in self.detected_extrema.items()
|
||||
},
|
||||
'window_size': self.window_size,
|
||||
'symbols': self.symbols
|
||||
}
|
||||
|
||||
# Create performance metrics for checkpoint manager
|
||||
performance_metrics = {
|
||||
'accuracy': current_accuracy,
|
||||
'total_extrema_detected': self.training_stats['total_extrema_detected'],
|
||||
'successful_predictions': self.training_stats['successful_predictions']
|
||||
}
|
||||
|
||||
# Save using checkpoint manager
|
||||
metadata = save_checkpoint(
|
||||
model=checkpoint_data, # We're saving data dict instead of model
|
||||
model_name=self.model_name,
|
||||
model_type="extrema_trainer",
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata={
|
||||
'session': self.training_session_count,
|
||||
'symbols': self.symbols,
|
||||
'window_size': self.window_size
|
||||
},
|
||||
force_save=force_save
|
||||
)
|
||||
|
||||
if metadata:
|
||||
logger.info(f"Saved ExtremaTrainer checkpoint: {metadata.checkpoint_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving ExtremaTrainer checkpoint: {e}")
|
||||
return False
|
||||
|
||||
def initialize_context_data(self) -> Dict[str, bool]:
|
||||
"""Initialize 200-candle 1m context data for all symbols"""
|
||||
|
Reference in New Issue
Block a user