fix transformer checkpoint handling and model loss measurement

Dobromir Popov
2025-11-10 12:41:39 +02:00
parent 86ae8b499b
commit a2d34c6d7c
2 changed files with 138 additions and 1 deletion


@@ -15,6 +15,7 @@ import logging
import uuid
import time
import threading
import os
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
@@ -1447,6 +1448,71 @@ class RealTrainingAdapter:
            logger.error(traceback.format_exc())
            return None

    def _find_best_checkpoint(self, checkpoint_dir: str, metric: str = 'accuracy') -> Optional[str]:
        """Find the best checkpoint based on a metric"""
        try:
            if not os.path.exists(checkpoint_dir):
                return None

            checkpoints = []
            for filename in os.listdir(checkpoint_dir):
                if filename.endswith('.pt'):
                    filepath = os.path.join(checkpoint_dir, filename)
                    try:
                        checkpoint = torch.load(filepath, map_location='cpu')
                        checkpoints.append({
                            'path': filepath,
                            'metric_value': checkpoint.get(metric, 0),
                            'epoch': checkpoint.get('epoch', 0)
                        })
                    except Exception as e:
                        logger.debug(f"Could not load checkpoint {filename}: {e}")

            if not checkpoints:
                return None

            # Sort by metric (higher is better for accuracy)
            checkpoints.sort(key=lambda x: x['metric_value'], reverse=True)
            return checkpoints[0]['path']

        except Exception as e:
            logger.error(f"Error finding best checkpoint: {e}")
            return None
    def _cleanup_old_checkpoints(self, checkpoint_dir: str, keep_best: int = 5, metric: str = 'accuracy'):
        """Keep only the best N checkpoints"""
        try:
            if not os.path.exists(checkpoint_dir):
                return

            checkpoints = []
            for filename in os.listdir(checkpoint_dir):
                if filename.endswith('.pt'):
                    filepath = os.path.join(checkpoint_dir, filename)
                    try:
                        checkpoint = torch.load(filepath, map_location='cpu')
                        checkpoints.append({
                            'path': filepath,
                            'metric_value': checkpoint.get(metric, 0),
                            'epoch': checkpoint.get('epoch', 0)
                        })
                    except Exception as e:
                        logger.debug(f"Could not load checkpoint {filename}: {e}")

            # Sort by metric (higher is better)
            checkpoints.sort(key=lambda x: x['metric_value'], reverse=True)

            # Delete checkpoints beyond keep_best
            for checkpoint in checkpoints[keep_best:]:
                try:
                    os.remove(checkpoint['path'])
                    logger.debug(f"Removed old checkpoint: {checkpoint['path']}")
                except Exception as e:
                    logger.warning(f"Could not remove checkpoint: {e}")

        except Exception as e:
            logger.error(f"Error cleaning up checkpoints: {e}")
    def _train_transformer_real(self, session: TrainingSession, training_data: List[Dict]):
        """
        Train Transformer model using orchestrator's existing training infrastructure
@@ -1466,6 +1532,25 @@ class RealTrainingAdapter:
logger.info(f"Using orchestrator's TradingTransformerTrainer")
logger.info(f" Trainer type: {type(trainer).__name__}")
# Load best checkpoint if available to continue training
try:
checkpoint_dir = "models/checkpoints/transformer"
best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
if best_checkpoint_path and os.path.exists(best_checkpoint_path):
checkpoint = torch.load(best_checkpoint_path)
trainer.model.load_state_dict(checkpoint['model_state_dict'])
trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
logger.info(f" Loaded checkpoint from epoch {checkpoint.get('epoch', 0)}")
logger.info(f" Previous best: Loss={checkpoint.get('loss', 0):.6f}, Accuracy={checkpoint.get('accuracy', 0):.2%}")
else:
logger.info(" No previous checkpoint found, starting fresh")
except Exception as e:
logger.warning(f" Failed to load checkpoint: {e}")
logger.info(" Starting with fresh model weights")
# Use the trainer's train_step method for individual samples
if hasattr(trainer, 'train_step'):
logger.info(" Using trainer.train_step() method")
@@ -1549,13 +1634,14 @@ class RealTrainingAdapter:
                    if result is not None:
                        batch_loss = result.get('total_loss', 0.0)
                        batch_accuracy = result.get('accuracy', 0.0)
                        batch_candle_accuracy = result.get('candle_accuracy', 0.0)

                        epoch_loss += batch_loss
                        epoch_accuracy += batch_accuracy
                        num_batches += 1

                        # Log first batch and every 10th batch for debugging
                        if (i + 1) == 1 or (i + 1) % 10 == 0:
                            logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Accuracy: {batch_accuracy:.4f}")
                            logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}")
                    else:
                        logger.warning(f" Batch {i + 1} returned None result - skipping")
@@ -1577,6 +1663,32 @@ class RealTrainingAdapter:
                session.current_epoch = epoch + 1
                session.current_loss = avg_loss

                # Save checkpoint after each epoch
                try:
                    checkpoint_dir = "models/checkpoints/transformer"
                    os.makedirs(checkpoint_dir, exist_ok=True)

                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    checkpoint_path = os.path.join(checkpoint_dir, f"transformer_epoch{epoch+1}_{timestamp}.pt")

                    torch.save({
                        'epoch': epoch + 1,
                        'model_state_dict': trainer.model.state_dict(),
                        'optimizer_state_dict': trainer.optimizer.state_dict(),
                        'scheduler_state_dict': trainer.scheduler.state_dict(),
                        'loss': avg_loss,
                        'accuracy': avg_accuracy,
                        'learning_rate': trainer.scheduler.get_last_lr()[0]
                    }, checkpoint_path)

                    logger.info(f" Saved checkpoint: {checkpoint_path}")

                    # Keep only best 5 checkpoints based on accuracy
                    self._cleanup_old_checkpoints(checkpoint_dir, keep_best=5, metric='accuracy')
                except Exception as e:
                    logger.warning(f" Failed to save checkpoint: {e}")

                # Clear CUDA cache after each epoch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
@@ -1586,6 +1698,16 @@ class RealTrainingAdapter:
            session.final_loss = session.current_loss
            session.accuracy = avg_accuracy

            # Log best checkpoint info
            try:
                checkpoint_dir = "models/checkpoints/transformer"
                best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
                if best_checkpoint_path:
                    checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
                    logger.info(f" Best checkpoint: epoch {checkpoint.get('epoch', 0)}, accuracy: {checkpoint.get('accuracy', 0):.2%}")
            except Exception as e:
                logger.debug(f"Could not load best checkpoint info: {e}")

            logger.info(f" Training complete: Loss = {session.final_loss:.6f}, Accuracy = {session.accuracy:.2%}")
        else:

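For reference, the checkpoint bookkeeping added above (per-epoch .pt files tagged with metrics, best-by-accuracy selection, pruning) can be exercised in isolation. The following is a minimal sketch, not the commit's code: a toy nn.Linear model stands in for the transformer, tmp_checkpoints is an illustrative path, and only the checkpoint keys mirror this commit's torch.save call.

# Minimal, self-contained sketch of the checkpoint lifecycle: save per-epoch
# checkpoints with made-up metrics, rank them by 'accuracy', and prune the rest.
import os
import glob
import torch
import torch.nn as nn

checkpoint_dir = "tmp_checkpoints"          # illustrative path, not the repo's
os.makedirs(checkpoint_dir, exist_ok=True)

model = nn.Linear(4, 3)                     # stand-in for the transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

# Save a few "epochs" with fake metrics, using the same keys as the commit's torch.save call
for epoch, acc in enumerate([0.41, 0.55, 0.48], start=1):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': 1.0 - acc,
        'accuracy': acc,
        'learning_rate': scheduler.get_last_lr()[0],
    }, os.path.join(checkpoint_dir, f"transformer_epoch{epoch}.pt"))

def rank_checkpoints(path, metric='accuracy'):
    """Return (metric_value, filepath) pairs sorted best-first, as _find_best_checkpoint does."""
    scored = []
    for fp in glob.glob(os.path.join(path, "*.pt")):
        ckpt = torch.load(fp, map_location='cpu')
        scored.append((ckpt.get(metric, 0), fp))
    return sorted(scored, reverse=True)

ranked = rank_checkpoints(checkpoint_dir)
best_metric, best_path = ranked[0]
print(f"best checkpoint: {best_path} (accuracy {best_metric:.2%})")

# Keep only the top-N files, mirroring _cleanup_old_checkpoints(keep_best=5)
keep_best = 1
for _, fp in ranked[keep_best:]:
    os.remove(fp)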

@@ -1268,12 +1268,27 @@ class TradingTransformerTrainer:
            predictions = torch.argmax(outputs['action_logits'], dim=-1)
            accuracy = (predictions == batch['actions']).float().mean()

            # Calculate candle prediction accuracy (price direction)
            candle_accuracy = 0.0
            if 'next_candles' in outputs and 'future_prices' in batch:
                # Use 1m timeframe prediction as primary
                if '1m' in outputs['next_candles']:
                    predicted_candle = outputs['next_candles']['1m']  # [batch, 5]
                    # Predicted close is the 4th value (index 3)
                    predicted_close_change = predicted_candle[:, 3]  # Predicted close price change
                    actual_close_change = batch['future_prices']  # Actual price change ratio

                    # Check if direction matches (both positive or both negative)
                    direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
                    candle_accuracy = direction_match.mean().item()

            # Extract values and delete tensors to free memory
            result = {
                'total_loss': total_loss.item(),
                'action_loss': action_loss.item(),
                'price_loss': price_loss.item(),
                'accuracy': accuracy.item(),
                'candle_accuracy': candle_accuracy,
                'learning_rate': self.scheduler.get_last_lr()[0]
            }
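To make the new candle_accuracy metric concrete, here is a small worked example with hand-made tensors. The numbers are illustrative only and assume, as the code above does, that both the predicted close (index 3 of the 1m candle vector) and batch['future_prices'] are signed change ratios.

# Tiny worked example of the direction-match accuracy used for candle_accuracy.
# Shapes mimic a batch of 4 predicted 1m candles [O, H, L, C, V]; values are made up.
import torch

predicted_candles = torch.tensor([
    [0.1, 0.3, -0.1,  0.20, 1.0],   # predicted close change +0.20
    [0.0, 0.2, -0.2, -0.05, 1.0],   # predicted close change -0.05
    [0.1, 0.4,  0.0,  0.10, 1.0],   # predicted close change +0.10
    [0.0, 0.1, -0.3, -0.15, 1.0],   # predicted close change -0.15
])
future_prices = torch.tensor([0.18, 0.02, 0.07, -0.22])  # actual close change ratios

predicted_close_change = predicted_candles[:, 3]          # index 3 = close
direction_match = (torch.sign(predicted_close_change) == torch.sign(future_prices)).float()
candle_accuracy = direction_match.mean().item()
print(candle_accuracy)  # 0.75: rows 1, 3, 4 agree in sign; row 2 does not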