checkpoint manager and handling

Dobromir Popov
2025-06-24 21:59:23 +03:00
parent 706eb13912
commit ab8c94d735
8 changed files with 1170 additions and 29 deletions

View File

@@ -18,6 +18,14 @@ from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from collections import deque
import os
import pickle
import json
# Import checkpoint management
import torch
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.training_integration import get_training_integration
logger = logging.getLogger(__name__)
@@ -44,9 +52,10 @@ class ContextData:
last_update: datetime
class ExtremaTrainer:
"""Reusable extrema detection and training functionality"""
"""Reusable extrema detection and training functionality with checkpoint management"""
def __init__(self, data_provider, symbols: List[str], window_size: int = 10,
             model_name: str = "extrema_trainer", enable_checkpoints: bool = True):
"""
Initialize the extrema trainer
@@ -54,11 +63,21 @@ class ExtremaTrainer:
data_provider: Data provider instance
symbols: List of symbols to track
window_size: Window size for extrema detection (default 10)
model_name: Name for checkpoint management
enable_checkpoints: Whether to enable checkpoint management
"""
self.data_provider = data_provider
self.symbols = symbols
self.window_size = window_size
# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
self.training_integration = get_training_integration() if enable_checkpoints else None
self.training_session_count = 0
self.best_detection_accuracy = 0.0
self.checkpoint_frequency = 50 # Save checkpoint every 50 training sessions
# Extrema tracking
self.detected_extrema = {symbol: deque(maxlen=1000) for symbol in symbols}
self.extrema_training_queue = deque(maxlen=500)
@@ -78,8 +97,125 @@ class ExtremaTrainer:
self.min_confidence_threshold = 0.3 # Train on opportunities with at least 30% confidence
self.max_confidence_threshold = 0.95 # Cap confidence at 95%
# Performance tracking
self.training_stats = {
'total_extrema_detected': 0,
'successful_predictions': 0,
'failed_predictions': 0,
'detection_accuracy': 0.0,
'last_training_time': None
}
# Load best checkpoint if available
if self.enable_checkpoints:
self.load_best_checkpoint()
logger.info(f"ExtremaTrainer initialized for symbols: {symbols}")
logger.info(f"Window size: {window_size}, Context update frequency: {self.context_update_frequency}s")
logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}")
def load_best_checkpoint(self):
"""Load the best checkpoint for this extrema trainer"""
try:
if not self.enable_checkpoints:
return
result = load_best_checkpoint(self.model_name)
if result:
file_path, metadata = result
checkpoint = torch.load(file_path, map_location='cpu')
# Load training state
if 'training_session_count' in checkpoint:
self.training_session_count = checkpoint['training_session_count']
if 'best_detection_accuracy' in checkpoint:
self.best_detection_accuracy = checkpoint['best_detection_accuracy']
if 'training_stats' in checkpoint:
self.training_stats = checkpoint['training_stats']
if 'detected_extrema' in checkpoint:
# Convert back to deques
for symbol, extrema_list in checkpoint['detected_extrema'].items():
if symbol in self.detected_extrema:
self.detected_extrema[symbol] = deque(extrema_list, maxlen=1000)
logger.info(f"Loaded ExtremaTrainer checkpoint: {metadata.checkpoint_id}")
logger.info(f"Session: {self.training_session_count}, Best accuracy: {self.best_detection_accuracy:.4f}")
except Exception as e:
logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
def save_checkpoint(self, force_save: bool = False):
"""Save checkpoint if performance improved or forced"""
try:
if not self.enable_checkpoints:
return False
self.training_session_count += 1
# Calculate current detection accuracy
total_predictions = self.training_stats['successful_predictions'] + self.training_stats['failed_predictions']
current_accuracy = (
self.training_stats['successful_predictions'] / total_predictions
if total_predictions > 0 else 0.0
)
# Update best accuracy
improved = False
if current_accuracy > self.best_detection_accuracy:
self.best_detection_accuracy = current_accuracy
improved = True
# Save checkpoint if improved, forced, or at regular intervals
should_save = (
force_save or
improved or
self.training_session_count % self.checkpoint_frequency == 0
)
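# Example cadence: with checkpoint_frequency = 50, a session that neither improves
# accuracy nor passes force_save=True is still persisted on sessions 50, 100, 150, ...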
if should_save:
# Prepare checkpoint data
checkpoint_data = {
'training_session_count': self.training_session_count,
'best_detection_accuracy': self.best_detection_accuracy,
'training_stats': self.training_stats,
'detected_extrema': {
symbol: list(extrema_deque)
for symbol, extrema_deque in self.detected_extrema.items()
},
'window_size': self.window_size,
'symbols': self.symbols
}
# Create performance metrics for checkpoint manager
performance_metrics = {
'accuracy': current_accuracy,
'total_extrema_detected': self.training_stats['total_extrema_detected'],
'successful_predictions': self.training_stats['successful_predictions']
}
# Save using checkpoint manager
metadata = save_checkpoint(
model=checkpoint_data, # We're saving data dict instead of model
model_name=self.model_name,
model_type="extrema_trainer",
performance_metrics=performance_metrics,
training_metadata={
'session': self.training_session_count,
'symbols': self.symbols,
'window_size': self.window_size
},
force_save=force_save
)
if metadata:
logger.info(f"Saved ExtremaTrainer checkpoint: {metadata.checkpoint_id}")
return True
return False
except Exception as e:
logger.error(f"Error saving ExtremaTrainer checkpoint: {e}")
return False
def initialize_context_data(self) -> Dict[str, bool]:
"""Initialize 200-candle 1m context data for all symbols"""

View File

@@ -19,6 +19,11 @@ from collections import deque
import numpy as np
import pandas as pd
# Import checkpoint management
import torch
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.training_integration import get_training_integration
logger = logging.getLogger(__name__)
@dataclass
@@ -57,7 +62,7 @@ class TrainingSession:
class NegativeCaseTrainer:
"""
Intensive trainer focused on learning from losing trades with checkpoint management
Features:
- Stores all losing trades as negative cases
@@ -65,15 +70,25 @@ class NegativeCaseTrainer:
- Simultaneous inference and training
- Persistent storage in testcases/negative
- Priority-based training (bigger losses = higher priority)
- Checkpoint management for training progress
"""
def __init__(self, storage_dir: str = "testcases/negative",
             model_name: str = "negative_case_trainer", enable_checkpoints: bool = True):
self.storage_dir = storage_dir
self.stored_cases: List[NegativeCase] = []
self.training_queue = deque(maxlen=1000)
self.training_lock = threading.Lock()
self.inference_lock = threading.Lock()
# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
self.training_integration = get_training_integration() if enable_checkpoints else None
self.training_session_count = 0
self.best_loss_reduction = 0.0
self.checkpoint_frequency = 25 # Save checkpoint every 25 training sessions
# Training configuration
self.max_concurrent_training = 3 # Max parallel training sessions
self.intensive_training_epochs = 50 # Epochs per negative case
@@ -93,12 +108,17 @@ class NegativeCaseTrainer:
self._initialize_storage()
self._load_existing_cases()
# Load best checkpoint if available
if self.enable_checkpoints:
self.load_best_checkpoint()
# Start background training thread
self.training_thread = threading.Thread(target=self._background_training_loop, daemon=True)
self.training_thread.start()
logger.info(f"NegativeCaseTrainer initialized with {len(self.stored_cases)} existing cases")
logger.info(f"Storage directory: {self.storage_dir}")
logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}")
logger.info("Background training thread started")
def _initialize_storage(self):
@@ -469,4 +489,107 @@ class NegativeCaseTrainer:
logger.warning(f"Added {len(self.stored_cases)} cases to retraining queue")
except Exception as e:
logger.error(f"Error retraining all cases: {e}")
logger.error(f"Error retraining all cases: {e}")
def load_best_checkpoint(self):
"""Load the best checkpoint for this negative case trainer"""
try:
if not self.enable_checkpoints:
return
result = load_best_checkpoint(self.model_name)
if result:
file_path, metadata = result
checkpoint = torch.load(file_path, map_location='cpu')
# Load training state
if 'training_session_count' in checkpoint:
self.training_session_count = checkpoint['training_session_count']
if 'best_loss_reduction' in checkpoint:
self.best_loss_reduction = checkpoint['best_loss_reduction']
if 'total_cases_processed' in checkpoint:
self.total_cases_processed = checkpoint['total_cases_processed']
if 'total_training_time' in checkpoint:
self.total_training_time = checkpoint['total_training_time']
if 'accuracy_improvements' in checkpoint:
self.accuracy_improvements = checkpoint['accuracy_improvements']
logger.info(f"Loaded NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
logger.info(f"Session: {self.training_session_count}, Best loss reduction: {self.best_loss_reduction:.4f}")
except Exception as e:
logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
def save_checkpoint(self, loss_improvement: float = 0.0, force_save: bool = False):
"""Save checkpoint if performance improved or forced"""
try:
if not self.enable_checkpoints:
return False
self.training_session_count += 1
# Update best loss reduction
improved = False
if loss_improvement > self.best_loss_reduction:
self.best_loss_reduction = loss_improvement
improved = True
# Save checkpoint if improved, forced, or at regular intervals
should_save = (
force_save or
improved or
self.training_session_count % self.checkpoint_frequency == 0
)
if should_save:
# Prepare checkpoint data
checkpoint_data = {
'training_session_count': self.training_session_count,
'best_loss_reduction': self.best_loss_reduction,
'total_cases_processed': self.total_cases_processed,
'total_training_time': self.total_training_time,
'accuracy_improvements': self.accuracy_improvements,
'storage_dir': self.storage_dir,
'max_concurrent_training': self.max_concurrent_training,
'intensive_training_epochs': self.intensive_training_epochs
}
# Create performance metrics for checkpoint manager
avg_accuracy_improvement = (
sum(self.accuracy_improvements) / len(self.accuracy_improvements)
if self.accuracy_improvements else 0.0
)
performance_metrics = {
'loss_reduction': self.best_loss_reduction,
'avg_accuracy_improvement': avg_accuracy_improvement,
'total_cases_processed': self.total_cases_processed,
'training_efficiency': (
self.total_cases_processed / self.total_training_time
if self.total_training_time > 0 else 0.0
)
}
# Save using checkpoint manager
metadata = save_checkpoint(
model=checkpoint_data, # We're saving data dict instead of model
model_name=self.model_name,
model_type="negative_case_trainer",
performance_metrics=performance_metrics,
training_metadata={
'session': self.training_session_count,
'cases_processed': self.total_cases_processed,
'training_time_hours': self.total_training_time / 3600
},
force_save=force_save
)
if metadata:
logger.info(f"Saved NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
return True
return False
except Exception as e:
logger.error(f"Error saving NegativeCaseTrainer checkpoint: {e}")
return False
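Finally, a sketch of how the training side might drive the new hook; the loss values are placeholders standing in for whatever an intensive training session actually measures:

    loss_before, loss_after = 0.142, 0.118                  # assumed measurements from one session
    improvement = loss_before - loss_after                   # 0.024
    trainer.save_checkpoint(loss_improvement=improvement)    # saved if this beats best_loss_reduction;
                                                             # otherwise only every 25th session or with force_save=True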