fix transformer checkpoint handling and model loss measurement

Dobromir Popov
2025-11-10 12:41:39 +02:00
parent 86ae8b499b
commit a2d34c6d7c
2 changed files with 138 additions and 1 deletion

@@ -15,6 +15,7 @@ import logging
 import uuid
 import time
 import threading
+import os
 from typing import Dict, List, Optional, Any
 from dataclasses import dataclass
 from datetime import datetime, timedelta, timezone
@@ -1447,6 +1448,71 @@ class RealTrainingAdapter:
             logger.error(traceback.format_exc())
             return None
 
+    def _find_best_checkpoint(self, checkpoint_dir: str, metric: str = 'accuracy') -> Optional[str]:
+        """Find the best checkpoint based on a metric"""
+        try:
+            if not os.path.exists(checkpoint_dir):
+                return None
+
+            checkpoints = []
+            for filename in os.listdir(checkpoint_dir):
+                if filename.endswith('.pt'):
+                    filepath = os.path.join(checkpoint_dir, filename)
+                    try:
+                        checkpoint = torch.load(filepath, map_location='cpu')
+                        checkpoints.append({
+                            'path': filepath,
+                            'metric_value': checkpoint.get(metric, 0),
+                            'epoch': checkpoint.get('epoch', 0)
+                        })
+                    except Exception as e:
+                        logger.debug(f"Could not load checkpoint {filename}: {e}")
+
+            if not checkpoints:
+                return None
+
+            # Sort by metric (higher is better for accuracy)
+            checkpoints.sort(key=lambda x: x['metric_value'], reverse=True)
+            return checkpoints[0]['path']
+
+        except Exception as e:
+            logger.error(f"Error finding best checkpoint: {e}")
+            return None
+
+    def _cleanup_old_checkpoints(self, checkpoint_dir: str, keep_best: int = 5, metric: str = 'accuracy'):
+        """Keep only the best N checkpoints"""
+        try:
+            if not os.path.exists(checkpoint_dir):
+                return
+
+            checkpoints = []
+            for filename in os.listdir(checkpoint_dir):
+                if filename.endswith('.pt'):
+                    filepath = os.path.join(checkpoint_dir, filename)
+                    try:
+                        checkpoint = torch.load(filepath, map_location='cpu')
+                        checkpoints.append({
+                            'path': filepath,
+                            'metric_value': checkpoint.get(metric, 0),
+                            'epoch': checkpoint.get('epoch', 0)
+                        })
+                    except Exception as e:
+                        logger.debug(f"Could not load checkpoint {filename}: {e}")
+
+            # Sort by metric (higher is better)
+            checkpoints.sort(key=lambda x: x['metric_value'], reverse=True)
+
+            # Delete checkpoints beyond keep_best
+            for checkpoint in checkpoints[keep_best:]:
+                try:
+                    os.remove(checkpoint['path'])
+                    logger.debug(f"Removed old checkpoint: {checkpoint['path']}")
+                except Exception as e:
+                    logger.warning(f"Could not remove checkpoint: {e}")
+
+        except Exception as e:
+            logger.error(f"Error cleaning up checkpoints: {e}")
+
     def _train_transformer_real(self, session: TrainingSession, training_data: List[Dict]):
         """
         Train Transformer model using orchestrator's existing training infrastructure
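
The two helpers above implement selection and retention the same way: each .pt file is opened with torch.load(..., map_location='cpu') to read its stored metadata, and candidates are ranked by the chosen metric. A minimal standalone sketch of that ranking rule, assuming illustrative dummy checkpoints under a tmp/ckpts directory (neither the directory nor the dummy payloads are part of the commit):

import os
import torch

# Illustrative setup: write a few checkpoints carrying only the metadata
# fields the selector reads ('epoch' and 'accuracy').
checkpoint_dir = "tmp/ckpts"
os.makedirs(checkpoint_dir, exist_ok=True)
for epoch, acc in enumerate([0.41, 0.55, 0.48], start=1):
    torch.save({'epoch': epoch, 'accuracy': acc},
               os.path.join(checkpoint_dir, f"epoch{epoch}.pt"))

# Same ranking rule as _find_best_checkpoint: higher metric value wins.
candidates = [os.path.join(checkpoint_dir, f)
              for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]
best = max(candidates,
           key=lambda p: torch.load(p, map_location='cpu').get('accuracy', 0))
print(best)  # tmp/ckpts/epoch2.pt
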
@@ -1466,6 +1532,25 @@ class RealTrainingAdapter:
logger.info(f"Using orchestrator's TradingTransformerTrainer") logger.info(f"Using orchestrator's TradingTransformerTrainer")
logger.info(f" Trainer type: {type(trainer).__name__}") logger.info(f" Trainer type: {type(trainer).__name__}")
# Load best checkpoint if available to continue training
try:
checkpoint_dir = "models/checkpoints/transformer"
best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
if best_checkpoint_path and os.path.exists(best_checkpoint_path):
checkpoint = torch.load(best_checkpoint_path)
trainer.model.load_state_dict(checkpoint['model_state_dict'])
trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
logger.info(f" Loaded checkpoint from epoch {checkpoint.get('epoch', 0)}")
logger.info(f" Previous best: Loss={checkpoint.get('loss', 0):.6f}, Accuracy={checkpoint.get('accuracy', 0):.2%}")
else:
logger.info(" No previous checkpoint found, starting fresh")
except Exception as e:
logger.warning(f" Failed to load checkpoint: {e}")
logger.info(" Starting with fresh model weights")
# Use the trainer's train_step method for individual samples # Use the trainer's train_step method for individual samples
if hasattr(trainer, 'train_step'): if hasattr(trainer, 'train_step'):
logger.info(" Using trainer.train_step() method") logger.info(" Using trainer.train_step() method")
@@ -1549,13 +1634,14 @@ class RealTrainingAdapter:
                 if result is not None:
                     batch_loss = result.get('total_loss', 0.0)
                     batch_accuracy = result.get('accuracy', 0.0)
+                    batch_candle_accuracy = result.get('candle_accuracy', 0.0)
                     epoch_loss += batch_loss
                     epoch_accuracy += batch_accuracy
                     num_batches += 1
 
                     # Log first batch and every 10th batch for debugging
                     if (i + 1) == 1 or (i + 1) % 10 == 0:
-                        logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Accuracy: {batch_accuracy:.4f}")
+                        logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}")
                 else:
                     logger.warning(f" Batch {i + 1} returned None result - skipping")
@@ -1577,6 +1663,32 @@ class RealTrainingAdapter:
                 session.current_epoch = epoch + 1
                 session.current_loss = avg_loss
 
+                # Save checkpoint after each epoch
+                try:
+                    checkpoint_dir = "models/checkpoints/transformer"
+                    os.makedirs(checkpoint_dir, exist_ok=True)
+
+                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                    checkpoint_path = os.path.join(checkpoint_dir, f"transformer_epoch{epoch+1}_{timestamp}.pt")
+
+                    torch.save({
+                        'epoch': epoch + 1,
+                        'model_state_dict': trainer.model.state_dict(),
+                        'optimizer_state_dict': trainer.optimizer.state_dict(),
+                        'scheduler_state_dict': trainer.scheduler.state_dict(),
+                        'loss': avg_loss,
+                        'accuracy': avg_accuracy,
+                        'learning_rate': trainer.scheduler.get_last_lr()[0]
+                    }, checkpoint_path)
+
+                    logger.info(f" Saved checkpoint: {checkpoint_path}")
+
+                    # Keep only best 5 checkpoints based on accuracy
+                    self._cleanup_old_checkpoints(checkpoint_dir, keep_best=5, metric='accuracy')
+                except Exception as e:
+                    logger.warning(f" Failed to save checkpoint: {e}")
+
                 # Clear CUDA cache after each epoch
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
@@ -1586,6 +1698,16 @@ class RealTrainingAdapter:
             session.final_loss = session.current_loss
             session.accuracy = avg_accuracy
 
+            # Log best checkpoint info
+            try:
+                checkpoint_dir = "models/checkpoints/transformer"
+                best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
+                if best_checkpoint_path:
+                    checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
+                    logger.info(f" Best checkpoint: epoch {checkpoint.get('epoch', 0)}, accuracy: {checkpoint.get('accuracy', 0):.2%}")
+            except Exception as e:
+                logger.debug(f"Could not load best checkpoint info: {e}")
+
             logger.info(f" Training complete: Loss = {session.final_loss:.6f}, Accuracy = {session.accuracy:.2%}")
         else:


@@ -1267,6 +1267,20 @@ class TradingTransformerTrainer:
             with torch.no_grad():
                 predictions = torch.argmax(outputs['action_logits'], dim=-1)
                 accuracy = (predictions == batch['actions']).float().mean()
 
+                # Calculate candle prediction accuracy (price direction)
+                candle_accuracy = 0.0
+                if 'next_candles' in outputs and 'future_prices' in batch:
+                    # Use 1m timeframe prediction as primary
+                    if '1m' in outputs['next_candles']:
+                        predicted_candle = outputs['next_candles']['1m']  # [batch, 5]
+                        # Predicted close is the 4th value (index 3)
+                        predicted_close_change = predicted_candle[:, 3]  # Predicted close price change
+                        actual_close_change = batch['future_prices']  # Actual price change ratio
+
+                        # Check if direction matches (both positive or both negative)
+                        direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
+                        candle_accuracy = direction_match.mean().item()
+
             # Extract values and delete tensors to free memory
             result = {
@@ -1274,6 +1288,7 @@ class TradingTransformerTrainer:
                 'action_loss': action_loss.item(),
                 'price_loss': price_loss.item(),
                 'accuracy': accuracy.item(),
+                'candle_accuracy': candle_accuracy,
                 'learning_rate': self.scheduler.get_last_lr()[0]
             }
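
The candle_accuracy value added to the result dict is a directional hit rate: the sign of the predicted close change is compared with the sign of the realized price change, and the mean of the matches is reported. A self-contained sketch with made-up tensors:

import torch

# Illustrative tensors: predicted close-price change (index 3 of the 5-value
# candle prediction) and the realized price-change ratio for the same samples.
predicted_close_change = torch.tensor([0.004, -0.002, 0.001, -0.003])
actual_close_change = torch.tensor([0.006, 0.001, 0.002, -0.001])

# A prediction counts as a hit when both changes point the same way.
direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
candle_accuracy = direction_match.mean().item()
print(f"Candle direction accuracy: {candle_accuracy:.2%}")  # 75.00%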