checkpoint manager and handling

This commit is contained in:
Dobromir Popov
2025-06-24 21:59:23 +03:00
parent 706eb13912
commit ab8c94d735
8 changed files with 1170 additions and 29 deletions

View File

@ -19,6 +19,11 @@ from collections import deque
import numpy as np
import pandas as pd
# Import checkpoint management
import torch
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.training_integration import get_training_integration
logger = logging.getLogger(__name__)
@dataclass
@ -57,7 +62,7 @@ class TrainingSession:
class NegativeCaseTrainer:
"""
Intensive trainer focused on learning from losing trades
Intensive trainer focused on learning from losing trades with checkpoint management
Features:
- Stores all losing trades as negative cases
@ -65,15 +70,25 @@ class NegativeCaseTrainer:
- Simultaneous inference and training
- Persistent storage in testcases/negative
- Priority-based training (bigger losses = higher priority)
- Checkpoint management for training progress
"""
def __init__(self, storage_dir: str = "testcases/negative"):
def __init__(self, storage_dir: str = "testcases/negative",
model_name: str = "negative_case_trainer", enable_checkpoints: bool = True):
self.storage_dir = storage_dir
self.stored_cases: List[NegativeCase] = []
self.training_queue = deque(maxlen=1000)
self.training_lock = threading.Lock()
self.inference_lock = threading.Lock()
# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
self.training_integration = get_training_integration() if enable_checkpoints else None
self.training_session_count = 0
self.best_loss_reduction = 0.0
self.checkpoint_frequency = 25 # Save checkpoint every 25 training sessions
# Training configuration
self.max_concurrent_training = 3 # Max parallel training sessions
self.intensive_training_epochs = 50 # Epochs per negative case
@ -93,12 +108,17 @@ class NegativeCaseTrainer:
self._initialize_storage()
self._load_existing_cases()
# Load best checkpoint if available
if self.enable_checkpoints:
self.load_best_checkpoint()
# Start background training thread
self.training_thread = threading.Thread(target=self._background_training_loop, daemon=True)
self.training_thread.start()
logger.info(f"NegativeCaseTrainer initialized with {len(self.stored_cases)} existing cases")
logger.info(f"Storage directory: {self.storage_dir}")
logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}")
logger.info("Background training thread started")
def _initialize_storage(self):
@ -469,4 +489,107 @@ class NegativeCaseTrainer:
logger.warning(f"Added {len(self.stored_cases)} cases to retraining queue")
except Exception as e:
logger.error(f"Error retraining all cases: {e}")
logger.error(f"Error retraining all cases: {e}")
def load_best_checkpoint(self):
    """Restore trainer statistics from the best saved checkpoint, if any.

    Asks the checkpoint manager for the best checkpoint registered under
    ``self.model_name`` and copies every known statistic found in it onto
    this trainer. Keys that are absent from the checkpoint leave the
    current attribute value untouched. Loading is best-effort: any failure
    is logged as a warning and never raised to the caller.

    Returns:
        bool: True when a checkpoint was found and applied. False when
        checkpointing is disabled, no checkpoint exists, the loaded payload
        is not a dict, or loading failed. This mirrors ``save_checkpoint``,
        which also reports success as a bool.
    """
    # Statistics that save_checkpoint() places in its checkpoint data and
    # that are restored verbatim onto attributes of the same name.
    restorable_keys = (
        'training_session_count',
        'best_loss_reduction',
        'total_cases_processed',
        'total_training_time',
        'accuracy_improvements',
    )

    try:
        if not self.enable_checkpoints:
            return False

        # Module-level helper from utils.checkpoint_manager (not this method).
        result = load_best_checkpoint(self.model_name)
        if not result:
            return False

        file_path, metadata = result

        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from trusted storage should ever be loaded here. Consider
        # weights_only=True if the supported torch versions allow it.
        checkpoint = torch.load(file_path, map_location='cpu')

        # Fail with a clear message instead of a confusing TypeError from
        # the membership tests below when the payload is not a dict.
        if not isinstance(checkpoint, dict):
            logger.warning(
                f"Checkpoint for {self.model_name} has unexpected type "
                f"{type(checkpoint).__name__}; training state not restored"
            )
            return False

        # Load training state
        for key in restorable_keys:
            if key in checkpoint:
                setattr(self, key, checkpoint[key])

        logger.info(f"Loaded NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
        logger.info(f"Session: {self.training_session_count}, Best loss reduction: {self.best_loss_reduction:.4f}")
        return True

    except Exception as e:
        # Deliberately broad: a bad checkpoint must not prevent the trainer
        # from starting with fresh statistics.
        logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
        return False
def save_checkpoint(self, loss_improvement: float = 0.0, force_save: bool = False):
    """Persist trainer statistics through the checkpoint manager.

    While checkpoints are enabled, every call increments
    ``training_session_count``. A save is attempted when ``force_save`` is
    set, when ``loss_improvement`` exceeds ``best_loss_reduction``, or when
    the session count is a multiple of ``checkpoint_frequency``.

    Args:
        loss_improvement: Candidate loss reduction; it becomes the new
            ``best_loss_reduction`` when it is strictly greater.
        force_save: Attempt a save regardless of improvement or interval;
            also forwarded to the checkpoint manager.

    Returns:
        True when the checkpoint manager returned metadata for the save,
        False otherwise (disabled, nothing to save, or an error occurred).
    """
    try:
        if not self.enable_checkpoints:
            return False

        self.training_session_count += 1

        # Track the best loss reduction seen so far.
        is_new_best = loss_improvement > self.best_loss_reduction
        if is_new_best:
            self.best_loss_reduction = loss_improvement

        # Nothing to do unless forced, improved, or at a periodic interval.
        if (
            not force_save
            and not is_new_best
            and self.training_session_count % self.checkpoint_frequency != 0
        ):
            return False

        # Snapshot of the trainer statistics and configuration.
        state = {}
        state['training_session_count'] = self.training_session_count
        state['best_loss_reduction'] = self.best_loss_reduction
        state['total_cases_processed'] = self.total_cases_processed
        state['total_training_time'] = self.total_training_time
        state['accuracy_improvements'] = self.accuracy_improvements
        state['storage_dir'] = self.storage_dir
        state['max_concurrent_training'] = self.max_concurrent_training
        state['intensive_training_epochs'] = self.intensive_training_epochs

        # Mean accuracy improvement; 0.0 when none were recorded.
        if self.accuracy_improvements:
            mean_gain = sum(self.accuracy_improvements) / len(self.accuracy_improvements)
        else:
            mean_gain = 0.0

        # Cases processed per unit of training time; 0.0 before any time accrues.
        if self.total_training_time > 0:
            efficiency = self.total_cases_processed / self.total_training_time
        else:
            efficiency = 0.0

        metrics = {
            'loss_reduction': self.best_loss_reduction,
            'avg_accuracy_improvement': mean_gain,
            'total_cases_processed': self.total_cases_processed,
            'training_efficiency': efficiency,
        }
        run_info = {
            'session': self.training_session_count,
            'cases_processed': self.total_cases_processed,
            'training_time_hours': self.total_training_time / 3600,
        }

        # Module-level helper from utils.checkpoint_manager (not this
        # method); a plain data dict is passed in place of a model object.
        saved = save_checkpoint(
            model=state,
            model_name=self.model_name,
            model_type="negative_case_trainer",
            performance_metrics=metrics,
            training_metadata=run_info,
            force_save=force_save,
        )
        if not saved:
            return False

        logger.info(f"Saved NegativeCaseTrainer checkpoint: {saved.checkpoint_id}")
        return True

    except Exception as e:
        logger.error(f"Error saving NegativeCaseTrainer checkpoint: {e}")
        return False