logging channels; real-time training step tracking and checkpoint storage
@@ -138,9 +138,25 @@ class RealTrainingAdapter:
        self.data_provider = data_provider
        self.training_sessions: Dict[str, TrainingSession] = {}

        # Real-time training tracking
        self.realtime_training_metrics = {
            'total_steps': 0,
            'total_loss': 0.0,
            'total_accuracy': 0.0,
            'best_loss': float('inf'),
            'best_accuracy': 0.0,
            'last_checkpoint_step': 0,
            'checkpoint_frequency': 100,  # Save every N steps
            'losses': [],                 # Rolling window
            'accuracies': []              # Rolling window
        }

        # Import real training systems
        self._import_training_systems()

        # Load best realtime checkpoint if available
        self._load_best_realtime_checkpoint()

        logger.info("RealTrainingAdapter initialized - NO SIMULATION, REAL TRAINING ONLY")

    def _import_training_systems(self):
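Note: the 'losses' and 'accuracies' lists above act as bounded rolling windows that the per-candle training path trims by hand to the last 100 entries. A minimal sketch of equivalent bookkeeping using collections.deque with maxlen (an alternative shown for illustration, not what the adapter itself does):

from collections import deque

class RollingMetrics:
    """Illustrative sketch: bounded rolling window for loss/accuracy averages."""

    def __init__(self, window: int = 100):
        # deque(maxlen=...) drops the oldest entry automatically, replacing the manual pop(0) calls
        self.losses = deque(maxlen=window)
        self.accuracies = deque(maxlen=window)

    def update(self, loss: float, accuracy: float) -> tuple:
        self.losses.append(loss)
        self.accuracies.append(accuracy)
        return (sum(self.losses) / len(self.losses),
                sum(self.accuracies) / len(self.accuracies))

# Usage: avg_loss, avg_accuracy = metrics.update(0.42, 0.61)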
@@ -2410,6 +2426,11 @@ class RealTrainingAdapter:
            'train_every_candle': train_every_candle,
            'timeframe': timeframe,
            'data_provider': data_provider,
            'metrics': {
                'accuracy': 0.0,
                'loss': 0.0,
                'steps': 0
            },
            'last_candle_time': None
        }
@@ -2711,6 +2732,13 @@ class RealTrainingAdapter:
                model_name = session['model_name']
                if model_name == 'Transformer':
                    self._train_transformer_on_sample(training_sample)

                    # Update session metrics with latest realtime metrics
                    if len(self.realtime_training_metrics['losses']) > 0:
                        session['metrics']['loss'] = sum(self.realtime_training_metrics['losses']) / len(self.realtime_training_metrics['losses'])
                        session['metrics']['accuracy'] = sum(self.realtime_training_metrics['accuracies']) / len(self.realtime_training_metrics['accuracies'])
                        session['metrics']['steps'] = self.realtime_training_metrics['total_steps']

                logger.info(f"Trained on candle: {symbol} {timeframe} @ {completed_candle.name} (change: {price_change:+.2%})")

            except Exception as e:
@@ -2740,7 +2768,7 @@ class RealTrainingAdapter:
            return {}

    def _train_transformer_on_sample(self, training_sample: Dict):
        """Train transformer on a single sample with checkpoint saving"""
        try:
            if not self.orchestrator:
                return
@@ -2760,12 +2788,249 @@ class RealTrainingAdapter:
            with torch.enable_grad():
                trainer.model.train()
                result = trainer.train_step(batch, accumulate_gradients=False)

                if result:
                    loss = result.get('total_loss', 0)
                    accuracy = result.get('accuracy', 0)

                    # Update metrics tracking
                    self.realtime_training_metrics['total_steps'] += 1
                    self.realtime_training_metrics['total_loss'] += loss
                    self.realtime_training_metrics['total_accuracy'] += accuracy

                    # Maintain rolling window (last 100 steps)
                    self.realtime_training_metrics['losses'].append(loss)
                    self.realtime_training_metrics['accuracies'].append(accuracy)
                    if len(self.realtime_training_metrics['losses']) > 100:
                        self.realtime_training_metrics['losses'].pop(0)
                        self.realtime_training_metrics['accuracies'].pop(0)

                    # Calculate rolling averages
                    avg_loss = sum(self.realtime_training_metrics['losses']) / len(self.realtime_training_metrics['losses'])
                    avg_accuracy = sum(self.realtime_training_metrics['accuracies']) / len(self.realtime_training_metrics['accuracies'])

                    logger.info(f"Per-candle training: Loss={loss:.4f} (avg: {avg_loss:.4f}), Acc={accuracy:.2%} (avg: {avg_accuracy:.2%})")

                    # Check if the model improved (triggers a checkpoint save)
                    improved = False
                    if avg_loss < self.realtime_training_metrics['best_loss']:
                        self.realtime_training_metrics['best_loss'] = avg_loss
                        improved = True
                        logger.info(f"NEW BEST LOSS: {avg_loss:.4f}")

                    if avg_accuracy > self.realtime_training_metrics['best_accuracy']:
                        self.realtime_training_metrics['best_accuracy'] = avg_accuracy
                        improved = True
                        logger.info(f"NEW BEST ACCURACY: {avg_accuracy:.2%}")

                    # Save checkpoint if improved or every N steps
                    steps_since_checkpoint = self.realtime_training_metrics['total_steps'] - self.realtime_training_metrics['last_checkpoint_step']

                    if improved or steps_since_checkpoint >= self.realtime_training_metrics['checkpoint_frequency']:
                        self._save_realtime_checkpoint(
                            trainer=trainer,
                            step=self.realtime_training_metrics['total_steps'],
                            loss=avg_loss,
                            accuracy=avg_accuracy,
                            improved=improved
                        )
                        self.realtime_training_metrics['last_checkpoint_step'] = self.realtime_training_metrics['total_steps']

        except Exception as e:
            logger.warning(f"Error training transformer on sample: {e}")
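The checkpoint trigger above fires when either rolling average improves or checkpoint_frequency steps have passed since the last save. A standalone sketch of that decision, using a hypothetical helper (not part of the adapter) operating on the same metrics dict:

def should_checkpoint(metrics: dict, avg_loss: float, avg_accuracy: float) -> tuple:
    """Return (save, improved) following the same policy as _train_transformer_on_sample."""
    improved = False
    if avg_loss < metrics['best_loss']:
        metrics['best_loss'] = avg_loss
        improved = True
    if avg_accuracy > metrics['best_accuracy']:
        metrics['best_accuracy'] = avg_accuracy
        improved = True

    steps_since = metrics['total_steps'] - metrics['last_checkpoint_step']
    return improved or steps_since >= metrics['checkpoint_frequency'], improved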
    def _save_realtime_checkpoint(self, trainer, step: int, loss: float, accuracy: float, improved: bool = False):
        """
        Save a checkpoint during real-time training.

        Args:
            trainer: Model trainer instance
            step: Current training step
            loss: Current average loss
            accuracy: Current average accuracy
            improved: Whether this is an improvement checkpoint
        """
        try:
            import torch
            import os
            from datetime import datetime

            checkpoint_dir = "models/checkpoints/transformer/realtime"
            os.makedirs(checkpoint_dir, exist_ok=True)

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            checkpoint_type = "BEST" if improved else "periodic"
            checkpoint_path = os.path.join(checkpoint_dir, f"realtime_{checkpoint_type}_step{step}_{timestamp}.pt")

            # Save checkpoint
            torch.save({
                'step': step,
                'model_state_dict': trainer.model.state_dict(),
                'optimizer_state_dict': trainer.optimizer.state_dict(),
                'scheduler_state_dict': trainer.scheduler.state_dict() if hasattr(trainer, 'scheduler') else None,
                'loss': loss,
                'accuracy': accuracy,
                'learning_rate': trainer.scheduler.get_last_lr()[0] if hasattr(trainer, 'scheduler') else trainer.optimizer.param_groups[0]['lr'],
                'training_type': 'realtime_per_candle',
                'metrics': {
                    'total_steps': self.realtime_training_metrics['total_steps'],
                    'best_loss': self.realtime_training_metrics['best_loss'],
                    'best_accuracy': self.realtime_training_metrics['best_accuracy'],
                    'rolling_losses': self.realtime_training_metrics['losses'][-10:],  # Last 10
                    'rolling_accuracies': self.realtime_training_metrics['accuracies'][-10:]
                }
            }, checkpoint_path)

            logger.info(f"SAVED REALTIME CHECKPOINT: {checkpoint_path}")
            logger.info(f"  Step: {step}, Loss: {loss:.4f}, Acc: {accuracy:.2%}, Improved: {improved}")

            # Save metadata to database
            try:
                from utils.database_manager import get_database_manager, CheckpointMetadata

                db_manager = get_database_manager()
                checkpoint_id = f"realtime_step{step}_{timestamp}"

                metadata = CheckpointMetadata(
                    checkpoint_id=checkpoint_id,
                    model_name="transformer_realtime",
                    model_type="transformer",
                    timestamp=datetime.now(),
                    performance_metrics={
                        'loss': float(loss),
                        'accuracy': float(accuracy),
                        'step': step,
                        'best_loss': float(self.realtime_training_metrics['best_loss']),
                        'best_accuracy': float(self.realtime_training_metrics['best_accuracy'])
                    },
                    training_metadata={
                        'training_type': 'realtime_per_candle',
                        'total_steps': self.realtime_training_metrics['total_steps'],
                        'checkpoint_type': checkpoint_type
                    },
                    file_path=checkpoint_path,
                    file_size_mb=os.path.getsize(checkpoint_path) / (1024 * 1024),
                    is_active=True
                )

                if db_manager.save_checkpoint_metadata(metadata):
                    logger.info(f"Saved checkpoint metadata to database: {checkpoint_id}")
            except Exception as meta_error:
                logger.warning(f"Could not save checkpoint metadata: {meta_error}")

            # Cleanup: keep only the best 10 checkpoints
            if improved:
                self._cleanup_realtime_checkpoints(checkpoint_dir, keep_best=10)

        except Exception as e:
            logger.error(f"Error saving realtime checkpoint: {e}")
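Each realtime checkpoint bundles model, optimizer, and optional scheduler state together with the rolling metrics, so it can be reloaded or inspected on its own. A rough usage sketch (the file name below is illustrative, not a real artifact):

import torch

# Illustrative path; real files follow realtime_{BEST|periodic}_step{N}_{timestamp}.pt
ckpt = torch.load(
    "models/checkpoints/transformer/realtime/realtime_BEST_step1200_20250101_120000.pt",
    map_location="cpu",
)
print(ckpt['step'], ckpt['loss'], ckpt['accuracy'])
print(ckpt['metrics']['rolling_losses'])  # last 10 rolling losses at save time

# model.load_state_dict(ckpt['model_state_dict'])          # if a model instance is at hand
# optimizer.load_state_dict(ckpt['optimizer_state_dict'])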
    def _cleanup_realtime_checkpoints(self, checkpoint_dir: str, keep_best: int = 10):
        """Keep only the best N realtime checkpoints"""
        try:
            import os
            import torch

            if not os.path.exists(checkpoint_dir):
                return

            checkpoints = []
            for filename in os.listdir(checkpoint_dir):
                if filename.endswith('.pt') and filename.startswith('realtime_'):
                    filepath = os.path.join(checkpoint_dir, filename)
                    try:
                        checkpoint = torch.load(filepath, map_location='cpu')
                        checkpoints.append({
                            'path': filepath,
                            'loss': checkpoint.get('loss', float('inf')),
                            'accuracy': checkpoint.get('accuracy', 0),
                            'step': checkpoint.get('step', 0),
                            'is_best': 'BEST' in filename
                        })
                    except Exception as e:
                        logger.debug(f"Could not load checkpoint {filename}: {e}")

            # Sort by accuracy (higher is better), then by loss (lower is better)
            checkpoints.sort(key=lambda x: (x['accuracy'], -x['loss']), reverse=True)

            # Keep the best N checkpoints; remove the rest
            for checkpoint in checkpoints[keep_best:]:
                try:
                    os.remove(checkpoint['path'])
                    logger.debug(f"Removed old realtime checkpoint: {os.path.basename(checkpoint['path'])}")
                except Exception as e:
                    logger.warning(f"Could not remove checkpoint: {e}")

        except Exception as e:
            logger.error(f"Error cleaning up realtime checkpoints: {e}")
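Both cleanup and resume rank checkpoints by accuracy first (higher wins) and then by loss (lower wins) via the same sort key. A tiny sketch with made-up numbers showing the resulting order:

checkpoints = [
    {'path': 'a.pt', 'accuracy': 0.58, 'loss': 0.91},
    {'path': 'b.pt', 'accuracy': 0.61, 'loss': 1.20},
    {'path': 'c.pt', 'accuracy': 0.61, 'loss': 0.87},
]
checkpoints.sort(key=lambda x: (x['accuracy'], -x['loss']), reverse=True)
print([c['path'] for c in checkpoints])  # ['c.pt', 'b.pt', 'a.pt']: higher accuracy first, ties broken by lower loss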
    def _load_best_realtime_checkpoint(self):
        """Load the best realtime checkpoint on startup to resume training"""
        try:
            import torch
            import os

            checkpoint_dir = "models/checkpoints/transformer/realtime"
            if not os.path.exists(checkpoint_dir):
                logger.info("No realtime checkpoints found, starting fresh")
                return

            # Find best checkpoint
            checkpoints = []
            for filename in os.listdir(checkpoint_dir):
                if filename.endswith('.pt') and filename.startswith('realtime_'):
                    filepath = os.path.join(checkpoint_dir, filename)
                    try:
                        checkpoint = torch.load(filepath, map_location='cpu')
                        checkpoints.append({
                            'path': filepath,
                            'loss': checkpoint.get('loss', float('inf')),
                            'accuracy': checkpoint.get('accuracy', 0),
                            'step': checkpoint.get('step', 0),
                            'checkpoint': checkpoint
                        })
                    except Exception as e:
                        logger.debug(f"Could not load checkpoint {filename}: {e}")

            if not checkpoints:
                logger.info("No valid realtime checkpoints found")
                return

            # Sort by accuracy, then by loss
            checkpoints.sort(key=lambda x: (x['accuracy'], -x['loss']), reverse=True)
            best = checkpoints[0]

            # Restore metrics from checkpoint
            if 'metrics' in best['checkpoint']:
                saved_metrics = best['checkpoint']['metrics']
                self.realtime_training_metrics['total_steps'] = saved_metrics.get('total_steps', 0)
                self.realtime_training_metrics['best_loss'] = saved_metrics.get('best_loss', float('inf'))
                self.realtime_training_metrics['best_accuracy'] = saved_metrics.get('best_accuracy', 0.0)
                self.realtime_training_metrics['losses'] = saved_metrics.get('rolling_losses', [])
                self.realtime_training_metrics['accuracies'] = saved_metrics.get('rolling_accuracies', [])
                self.realtime_training_metrics['last_checkpoint_step'] = best['step']

            # Load model weights if orchestrator is available
            if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer_trainer'):
                trainer = self.orchestrator.primary_transformer_trainer
                if trainer and trainer.model:
                    trainer.model.load_state_dict(best['checkpoint']['model_state_dict'])
                    trainer.optimizer.load_state_dict(best['checkpoint']['optimizer_state_dict'])
                    if 'scheduler_state_dict' in best['checkpoint'] and best['checkpoint']['scheduler_state_dict']:
                        trainer.scheduler.load_state_dict(best['checkpoint']['scheduler_state_dict'])

                    logger.info("RESUMED REALTIME TRAINING from checkpoint:")
                    logger.info(f"  Step: {best['step']}, Loss: {best['loss']:.4f}, Acc: {best['accuracy']:.2%}")
                    logger.info(f"  Path: {os.path.basename(best['path'])}")
                else:
                    logger.info("Found realtime checkpoint but trainer not available yet")
            else:
                logger.info("Found realtime checkpoint but orchestrator not available yet")

        except Exception as e:
            logger.warning(f"Error loading realtime checkpoint: {e}")
            logger.info("Starting realtime training from scratch")
    def _get_sleep_time_for_timeframe(self, timeframe: str) -> float:
        """Get appropriate sleep time based on timeframe"""
        timeframe_seconds = {
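The diff is truncated here, so the body of timeframe_seconds is not shown. A hypothetical sketch of the kind of mapping such a helper typically uses; the keys and values below are assumptions, not the commit's actual table:

# Hypothetical values; the real mapping in _get_sleep_time_for_timeframe is cut off in this diff.
TIMEFRAME_SECONDS = {
    '1s': 1,
    '1m': 60,
    '5m': 300,
    '15m': 900,
    '1h': 3600,
    '1d': 86400,
}

def sleep_time_for(timeframe: str) -> float:
    # Fall back to one minute for unknown timeframes.
    return float(TIMEFRAME_SECONDS.get(timeframe, 60))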