fix transformer checkpointing and model loss measurement
@@ -15,6 +15,7 @@ import logging
import uuid
import time
import threading
import os
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
@@ -1447,6 +1448,71 @@ class RealTrainingAdapter:
            logger.error(traceback.format_exc())
            return None

    def _find_best_checkpoint(self, checkpoint_dir: str, metric: str = 'accuracy') -> Optional[str]:
        """Find the best checkpoint based on a metric"""
        try:
            if not os.path.exists(checkpoint_dir):
                return None

            checkpoints = []
            for filename in os.listdir(checkpoint_dir):
                if filename.endswith('.pt'):
                    filepath = os.path.join(checkpoint_dir, filename)
                    try:
                        checkpoint = torch.load(filepath, map_location='cpu')
                        checkpoints.append({
                            'path': filepath,
                            'metric_value': checkpoint.get(metric, 0),
                            'epoch': checkpoint.get('epoch', 0)
                        })
                    except Exception as e:
                        logger.debug(f"Could not load checkpoint {filename}: {e}")

            if not checkpoints:
                return None

            # Sort by metric (higher is better for accuracy)
            checkpoints.sort(key=lambda x: x['metric_value'], reverse=True)
            return checkpoints[0]['path']

        except Exception as e:
            logger.error(f"Error finding best checkpoint: {e}")
            return None

    def _cleanup_old_checkpoints(self, checkpoint_dir: str, keep_best: int = 5, metric: str = 'accuracy'):
        """Keep only the best N checkpoints"""
        try:
            if not os.path.exists(checkpoint_dir):
                return

            checkpoints = []
            for filename in os.listdir(checkpoint_dir):
                if filename.endswith('.pt'):
                    filepath = os.path.join(checkpoint_dir, filename)
                    try:
                        checkpoint = torch.load(filepath, map_location='cpu')
                        checkpoints.append({
                            'path': filepath,
                            'metric_value': checkpoint.get(metric, 0),
                            'epoch': checkpoint.get('epoch', 0)
                        })
                    except Exception as e:
                        logger.debug(f"Could not load checkpoint {filename}: {e}")

            # Sort by metric (higher is better)
            checkpoints.sort(key=lambda x: x['metric_value'], reverse=True)

            # Delete checkpoints beyond keep_best
            for checkpoint in checkpoints[keep_best:]:
                try:
                    os.remove(checkpoint['path'])
                    logger.debug(f"Removed old checkpoint: {checkpoint['path']}")
                except Exception as e:
                    logger.warning(f"Could not remove checkpoint: {e}")

        except Exception as e:
            logger.error(f"Error cleaning up checkpoints: {e}")

    def _train_transformer_real(self, session: TrainingSession, training_data: List[Dict]):
        """
        Train Transformer model using orchestrator's existing training infrastructure
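The two helpers above only read lightweight metadata ('accuracy', 'epoch') from each .pt file, so their behaviour can be sanity-checked with stub checkpoints. A minimal sketch, assuming adapter is an already-constructed RealTrainingAdapter and torch is importable (file names and values below are invented):

    import os
    import tempfile
    import torch

    # Write three stub checkpoints carrying only the metadata the helpers read.
    tmp_dir = tempfile.mkdtemp()
    for epoch, acc in enumerate([0.41, 0.57, 0.49], start=1):
        torch.save({'epoch': epoch, 'accuracy': acc},
                   os.path.join(tmp_dir, f"transformer_epoch{epoch}.pt"))

    best = adapter._find_best_checkpoint(tmp_dir, metric='accuracy')
    print(best)  # path of the epoch-2 file (highest accuracy, 0.57)

    adapter._cleanup_old_checkpoints(tmp_dir, keep_best=2, metric='accuracy')
    print(sorted(os.listdir(tmp_dir)))  # only the two highest-accuracy files remain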
@@ -1466,6 +1532,25 @@ class RealTrainingAdapter:
        logger.info(f"Using orchestrator's TradingTransformerTrainer")
        logger.info(f" Trainer type: {type(trainer).__name__}")

        # Load best checkpoint if available to continue training
        try:
            checkpoint_dir = "models/checkpoints/transformer"
            best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')

            if best_checkpoint_path and os.path.exists(best_checkpoint_path):
                checkpoint = torch.load(best_checkpoint_path)
                trainer.model.load_state_dict(checkpoint['model_state_dict'])
                trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

                logger.info(f" Loaded checkpoint from epoch {checkpoint.get('epoch', 0)}")
                logger.info(f" Previous best: Loss={checkpoint.get('loss', 0):.6f}, Accuracy={checkpoint.get('accuracy', 0):.2%}")
            else:
                logger.info(" No previous checkpoint found, starting fresh")
        except Exception as e:
            logger.warning(f" Failed to load checkpoint: {e}")
            logger.info(" Starting with fresh model weights")

        # Use the trainer's train_step method for individual samples
        if hasattr(trainer, 'train_step'):
            logger.info(" Using trainer.train_step() method")
@@ -1549,13 +1634,14 @@ class RealTrainingAdapter:
                    if result is not None:
                        batch_loss = result.get('total_loss', 0.0)
                        batch_accuracy = result.get('accuracy', 0.0)
                        batch_candle_accuracy = result.get('candle_accuracy', 0.0)
                        epoch_loss += batch_loss
                        epoch_accuracy += batch_accuracy
                        num_batches += 1

                        # Log first batch and every 10th batch for debugging
                        if (i + 1) == 1 or (i + 1) % 10 == 0:
                            logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Accuracy: {batch_accuracy:.4f}")
                            logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}")
                    else:
                        logger.warning(f" Batch {i + 1} returned None result - skipping")
@@ -1577,6 +1663,32 @@ class RealTrainingAdapter:
                session.current_epoch = epoch + 1
                session.current_loss = avg_loss

                # Save checkpoint after each epoch
                try:
                    checkpoint_dir = "models/checkpoints/transformer"
                    os.makedirs(checkpoint_dir, exist_ok=True)

                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    checkpoint_path = os.path.join(checkpoint_dir, f"transformer_epoch{epoch+1}_{timestamp}.pt")

                    torch.save({
                        'epoch': epoch + 1,
                        'model_state_dict': trainer.model.state_dict(),
                        'optimizer_state_dict': trainer.optimizer.state_dict(),
                        'scheduler_state_dict': trainer.scheduler.state_dict(),
                        'loss': avg_loss,
                        'accuracy': avg_accuracy,
                        'learning_rate': trainer.scheduler.get_last_lr()[0]
                    }, checkpoint_path)

                    logger.info(f" Saved checkpoint: {checkpoint_path}")

                    # Keep only best 5 checkpoints based on accuracy
                    self._cleanup_old_checkpoints(checkpoint_dir, keep_best=5, metric='accuracy')

                except Exception as e:
                    logger.warning(f" Failed to save checkpoint: {e}")

                # Clear CUDA cache after each epoch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
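Note that avg_loss and avg_accuracy, stored in the checkpoint above, are computed outside the lines shown in this hunk. Given the epoch_loss / epoch_accuracy / num_batches accumulators updated in the batch loop, they are presumably per-epoch means along these lines (a sketch under that assumption, not the exact code from this file):

    # Assumed per-epoch averages feeding the checkpoint dict above.
    avg_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
    avg_accuracy = epoch_accuracy / num_batches if num_batches > 0 else 0.0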
@@ -1586,6 +1698,16 @@ class RealTrainingAdapter:
            session.final_loss = session.current_loss
            session.accuracy = avg_accuracy

            # Log best checkpoint info
            try:
                checkpoint_dir = "models/checkpoints/transformer"
                best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
                if best_checkpoint_path:
                    checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
                    logger.info(f" Best checkpoint: epoch {checkpoint.get('epoch', 0)}, accuracy: {checkpoint.get('accuracy', 0):.2%}")
            except Exception as e:
                logger.debug(f"Could not load best checkpoint info: {e}")

            logger.info(f" Training complete: Loss = {session.final_loss:.6f}, Accuracy = {session.accuracy:.2%}")

        else:
@@ -1267,6 +1267,20 @@ class TradingTransformerTrainer:
        with torch.no_grad():
            predictions = torch.argmax(outputs['action_logits'], dim=-1)
            accuracy = (predictions == batch['actions']).float().mean()

            # Calculate candle prediction accuracy (price direction)
            candle_accuracy = 0.0
            if 'next_candles' in outputs and 'future_prices' in batch:
                # Use 1m timeframe prediction as primary
                if '1m' in outputs['next_candles']:
                    predicted_candle = outputs['next_candles']['1m']  # [batch, 5]
                    # Predicted close is the 4th value (index 3)
                    predicted_close_change = predicted_candle[:, 3]  # Predicted close price change
                    actual_close_change = batch['future_prices']  # Actual price change ratio

                    # Check if direction matches (both positive or both negative)
                    direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
                    candle_accuracy = direction_match.mean().item()

        # Extract values and delete tensors to free memory
        result = {
@@ -1274,6 +1288,7 @@ class TradingTransformerTrainer:
            'action_loss': action_loss.item(),
            'price_loss': price_loss.item(),
            'accuracy': accuracy.item(),
            'candle_accuracy': candle_accuracy,
            'learning_rate': self.scheduler.get_last_lr()[0]
        }
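The candle_accuracy metric added above is a pure sign comparison between the predicted close change and the actual future price change, so it can be illustrated on toy tensors (values below are invented):

    import torch

    # Three samples: the predicted direction matches the actual move in two of them.
    predicted_close_change = torch.tensor([0.004, -0.002, 0.001])
    actual_close_change = torch.tensor([0.010, -0.003, -0.002])

    direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
    candle_accuracy = direction_match.mean().item()
    print(candle_accuracy)  # 0.666..., i.e. roughly 66.7% direction accuracy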