checkbox manager and handling
This commit is contained in:
@ -18,6 +18,14 @@ from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
from collections import deque
|
||||
import os
|
||||
import pickle
|
||||
import json
|
||||
|
||||
# Import checkpoint management
|
||||
import torch
|
||||
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
|
||||
from utils.training_integration import get_training_integration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -44,9 +52,10 @@ class ContextData:
|
||||
last_update: datetime
|
||||
|
||||
class ExtremaTrainer:
|
||||
"""Reusable extrema detection and training functionality"""
|
||||
"""Reusable extrema detection and training functionality with checkpoint management"""
|
||||
|
||||
def __init__(self, data_provider, symbols: List[str], window_size: int = 10):
|
||||
def __init__(self, data_provider, symbols: List[str], window_size: int = 10,
|
||||
model_name: str = "extrema_trainer", enable_checkpoints: bool = True):
|
||||
"""
|
||||
Initialize the extrema trainer
|
||||
|
||||
@ -54,11 +63,21 @@ class ExtremaTrainer:
|
||||
data_provider: Data provider instance
|
||||
symbols: List of symbols to track
|
||||
window_size: Window size for extrema detection (default 10)
|
||||
model_name: Name for checkpoint management
|
||||
enable_checkpoints: Whether to enable checkpoint management
|
||||
"""
|
||||
self.data_provider = data_provider
|
||||
self.symbols = symbols
|
||||
self.window_size = window_size
|
||||
|
||||
# Checkpoint management
|
||||
self.model_name = model_name
|
||||
self.enable_checkpoints = enable_checkpoints
|
||||
self.training_integration = get_training_integration() if enable_checkpoints else None
|
||||
self.training_session_count = 0
|
||||
self.best_detection_accuracy = 0.0
|
||||
self.checkpoint_frequency = 50 # Save checkpoint every 50 training sessions
|
||||
|
||||
# Extrema tracking
|
||||
self.detected_extrema = {symbol: deque(maxlen=1000) for symbol in symbols}
|
||||
self.extrema_training_queue = deque(maxlen=500)
|
||||
@ -78,8 +97,125 @@ class ExtremaTrainer:
|
||||
self.min_confidence_threshold = 0.3 # Train on opportunities with at least 30% confidence
|
||||
self.max_confidence_threshold = 0.95 # Cap confidence at 95%
|
||||
|
||||
# Performance tracking
|
||||
self.training_stats = {
|
||||
'total_extrema_detected': 0,
|
||||
'successful_predictions': 0,
|
||||
'failed_predictions': 0,
|
||||
'detection_accuracy': 0.0,
|
||||
'last_training_time': None
|
||||
}
|
||||
|
||||
# Load best checkpoint if available
|
||||
if self.enable_checkpoints:
|
||||
self.load_best_checkpoint()
|
||||
|
||||
logger.info(f"ExtremaTrainer initialized for symbols: {symbols}")
|
||||
logger.info(f"Window size: {window_size}, Context update frequency: {self.context_update_frequency}s")
|
||||
logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}")
|
||||
|
||||
def load_best_checkpoint(self):
|
||||
"""Load the best checkpoint for this extrema trainer"""
|
||||
try:
|
||||
if not self.enable_checkpoints:
|
||||
return
|
||||
|
||||
result = load_best_checkpoint(self.model_name)
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
checkpoint = torch.load(file_path, map_location='cpu')
|
||||
|
||||
# Load training state
|
||||
if 'training_session_count' in checkpoint:
|
||||
self.training_session_count = checkpoint['training_session_count']
|
||||
if 'best_detection_accuracy' in checkpoint:
|
||||
self.best_detection_accuracy = checkpoint['best_detection_accuracy']
|
||||
if 'training_stats' in checkpoint:
|
||||
self.training_stats = checkpoint['training_stats']
|
||||
if 'detected_extrema' in checkpoint:
|
||||
# Convert back to deques
|
||||
for symbol, extrema_list in checkpoint['detected_extrema'].items():
|
||||
if symbol in self.detected_extrema:
|
||||
self.detected_extrema[symbol] = deque(extrema_list, maxlen=1000)
|
||||
|
||||
logger.info(f"Loaded ExtremaTrainer checkpoint: {metadata.checkpoint_id}")
|
||||
logger.info(f"Session: {self.training_session_count}, Best accuracy: {self.best_detection_accuracy:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
|
||||
|
||||
def save_checkpoint(self, force_save: bool = False):
|
||||
"""Save checkpoint if performance improved or forced"""
|
||||
try:
|
||||
if not self.enable_checkpoints:
|
||||
return False
|
||||
|
||||
self.training_session_count += 1
|
||||
|
||||
# Calculate current detection accuracy
|
||||
total_predictions = self.training_stats['successful_predictions'] + self.training_stats['failed_predictions']
|
||||
current_accuracy = (
|
||||
self.training_stats['successful_predictions'] / total_predictions
|
||||
if total_predictions > 0 else 0.0
|
||||
)
|
||||
|
||||
# Update best accuracy
|
||||
improved = False
|
||||
if current_accuracy > self.best_detection_accuracy:
|
||||
self.best_detection_accuracy = current_accuracy
|
||||
improved = True
|
||||
|
||||
# Save checkpoint if improved, forced, or at regular intervals
|
||||
should_save = (
|
||||
force_save or
|
||||
improved or
|
||||
self.training_session_count % self.checkpoint_frequency == 0
|
||||
)
|
||||
|
||||
if should_save:
|
||||
# Prepare checkpoint data
|
||||
checkpoint_data = {
|
||||
'training_session_count': self.training_session_count,
|
||||
'best_detection_accuracy': self.best_detection_accuracy,
|
||||
'training_stats': self.training_stats,
|
||||
'detected_extrema': {
|
||||
symbol: list(extrema_deque)
|
||||
for symbol, extrema_deque in self.detected_extrema.items()
|
||||
},
|
||||
'window_size': self.window_size,
|
||||
'symbols': self.symbols
|
||||
}
|
||||
|
||||
# Create performance metrics for checkpoint manager
|
||||
performance_metrics = {
|
||||
'accuracy': current_accuracy,
|
||||
'total_extrema_detected': self.training_stats['total_extrema_detected'],
|
||||
'successful_predictions': self.training_stats['successful_predictions']
|
||||
}
|
||||
|
||||
# Save using checkpoint manager
|
||||
metadata = save_checkpoint(
|
||||
model=checkpoint_data, # We're saving data dict instead of model
|
||||
model_name=self.model_name,
|
||||
model_type="extrema_trainer",
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata={
|
||||
'session': self.training_session_count,
|
||||
'symbols': self.symbols,
|
||||
'window_size': self.window_size
|
||||
},
|
||||
force_save=force_save
|
||||
)
|
||||
|
||||
if metadata:
|
||||
logger.info(f"Saved ExtremaTrainer checkpoint: {metadata.checkpoint_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving ExtremaTrainer checkpoint: {e}")
|
||||
return False
|
||||
|
||||
def initialize_context_data(self) -> Dict[str, bool]:
|
||||
"""Initialize 200-candle 1m context data for all symbols"""
|
||||
|
@ -19,6 +19,11 @@ from collections import deque
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# Import checkpoint management
|
||||
import torch
|
||||
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
|
||||
from utils.training_integration import get_training_integration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
@ -57,7 +62,7 @@ class TrainingSession:
|
||||
|
||||
class NegativeCaseTrainer:
|
||||
"""
|
||||
Intensive trainer focused on learning from losing trades
|
||||
Intensive trainer focused on learning from losing trades with checkpoint management
|
||||
|
||||
Features:
|
||||
- Stores all losing trades as negative cases
|
||||
@ -65,15 +70,25 @@ class NegativeCaseTrainer:
|
||||
- Simultaneous inference and training
|
||||
- Persistent storage in testcases/negative
|
||||
- Priority-based training (bigger losses = higher priority)
|
||||
- Checkpoint management for training progress
|
||||
"""
|
||||
|
||||
def __init__(self, storage_dir: str = "testcases/negative"):
|
||||
def __init__(self, storage_dir: str = "testcases/negative",
|
||||
model_name: str = "negative_case_trainer", enable_checkpoints: bool = True):
|
||||
self.storage_dir = storage_dir
|
||||
self.stored_cases: List[NegativeCase] = []
|
||||
self.training_queue = deque(maxlen=1000)
|
||||
self.training_lock = threading.Lock()
|
||||
self.inference_lock = threading.Lock()
|
||||
|
||||
# Checkpoint management
|
||||
self.model_name = model_name
|
||||
self.enable_checkpoints = enable_checkpoints
|
||||
self.training_integration = get_training_integration() if enable_checkpoints else None
|
||||
self.training_session_count = 0
|
||||
self.best_loss_reduction = 0.0
|
||||
self.checkpoint_frequency = 25 # Save checkpoint every 25 training sessions
|
||||
|
||||
# Training configuration
|
||||
self.max_concurrent_training = 3 # Max parallel training sessions
|
||||
self.intensive_training_epochs = 50 # Epochs per negative case
|
||||
@ -93,12 +108,17 @@ class NegativeCaseTrainer:
|
||||
self._initialize_storage()
|
||||
self._load_existing_cases()
|
||||
|
||||
# Load best checkpoint if available
|
||||
if self.enable_checkpoints:
|
||||
self.load_best_checkpoint()
|
||||
|
||||
# Start background training thread
|
||||
self.training_thread = threading.Thread(target=self._background_training_loop, daemon=True)
|
||||
self.training_thread.start()
|
||||
|
||||
logger.info(f"NegativeCaseTrainer initialized with {len(self.stored_cases)} existing cases")
|
||||
logger.info(f"Storage directory: {self.storage_dir}")
|
||||
logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}")
|
||||
logger.info("Background training thread started")
|
||||
|
||||
def _initialize_storage(self):
|
||||
@ -469,4 +489,107 @@ class NegativeCaseTrainer:
|
||||
logger.warning(f"Added {len(self.stored_cases)} cases to retraining queue")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retraining all cases: {e}")
|
||||
logger.error(f"Error retraining all cases: {e}")
|
||||
|
||||
def load_best_checkpoint(self):
|
||||
"""Load the best checkpoint for this negative case trainer"""
|
||||
try:
|
||||
if not self.enable_checkpoints:
|
||||
return
|
||||
|
||||
result = load_best_checkpoint(self.model_name)
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
checkpoint = torch.load(file_path, map_location='cpu')
|
||||
|
||||
# Load training state
|
||||
if 'training_session_count' in checkpoint:
|
||||
self.training_session_count = checkpoint['training_session_count']
|
||||
if 'best_loss_reduction' in checkpoint:
|
||||
self.best_loss_reduction = checkpoint['best_loss_reduction']
|
||||
if 'total_cases_processed' in checkpoint:
|
||||
self.total_cases_processed = checkpoint['total_cases_processed']
|
||||
if 'total_training_time' in checkpoint:
|
||||
self.total_training_time = checkpoint['total_training_time']
|
||||
if 'accuracy_improvements' in checkpoint:
|
||||
self.accuracy_improvements = checkpoint['accuracy_improvements']
|
||||
|
||||
logger.info(f"Loaded NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
|
||||
logger.info(f"Session: {self.training_session_count}, Best loss reduction: {self.best_loss_reduction:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
|
||||
|
||||
def save_checkpoint(self, loss_improvement: float = 0.0, force_save: bool = False):
|
||||
"""Save checkpoint if performance improved or forced"""
|
||||
try:
|
||||
if not self.enable_checkpoints:
|
||||
return False
|
||||
|
||||
self.training_session_count += 1
|
||||
|
||||
# Update best loss reduction
|
||||
improved = False
|
||||
if loss_improvement > self.best_loss_reduction:
|
||||
self.best_loss_reduction = loss_improvement
|
||||
improved = True
|
||||
|
||||
# Save checkpoint if improved, forced, or at regular intervals
|
||||
should_save = (
|
||||
force_save or
|
||||
improved or
|
||||
self.training_session_count % self.checkpoint_frequency == 0
|
||||
)
|
||||
|
||||
if should_save:
|
||||
# Prepare checkpoint data
|
||||
checkpoint_data = {
|
||||
'training_session_count': self.training_session_count,
|
||||
'best_loss_reduction': self.best_loss_reduction,
|
||||
'total_cases_processed': self.total_cases_processed,
|
||||
'total_training_time': self.total_training_time,
|
||||
'accuracy_improvements': self.accuracy_improvements,
|
||||
'storage_dir': self.storage_dir,
|
||||
'max_concurrent_training': self.max_concurrent_training,
|
||||
'intensive_training_epochs': self.intensive_training_epochs
|
||||
}
|
||||
|
||||
# Create performance metrics for checkpoint manager
|
||||
avg_accuracy_improvement = (
|
||||
sum(self.accuracy_improvements) / len(self.accuracy_improvements)
|
||||
if self.accuracy_improvements else 0.0
|
||||
)
|
||||
|
||||
performance_metrics = {
|
||||
'loss_reduction': self.best_loss_reduction,
|
||||
'avg_accuracy_improvement': avg_accuracy_improvement,
|
||||
'total_cases_processed': self.total_cases_processed,
|
||||
'training_efficiency': (
|
||||
self.total_cases_processed / self.total_training_time
|
||||
if self.total_training_time > 0 else 0.0
|
||||
)
|
||||
}
|
||||
|
||||
# Save using checkpoint manager
|
||||
metadata = save_checkpoint(
|
||||
model=checkpoint_data, # We're saving data dict instead of model
|
||||
model_name=self.model_name,
|
||||
model_type="negative_case_trainer",
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata={
|
||||
'session': self.training_session_count,
|
||||
'cases_processed': self.total_cases_processed,
|
||||
'training_time_hours': self.total_training_time / 3600
|
||||
},
|
||||
force_save=force_save
|
||||
)
|
||||
|
||||
if metadata:
|
||||
logger.info(f"Saved NegativeCaseTrainer checkpoint: {metadata.checkpoint_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving NegativeCaseTrainer checkpoint: {e}")
|
||||
return False
|
Reference in New Issue
Block a user