Fix transformer checkpoint handling and model loss measurement
@@ -15,6 +15,7 @@ import logging
 import uuid
 import time
 import threading
+import os
 from typing import Dict, List, Optional, Any
 from dataclasses import dataclass
 from datetime import datetime, timedelta, timezone
@@ -1447,6 +1448,71 @@ class RealTrainingAdapter:
             logger.error(traceback.format_exc())
             return None
 
+    def _find_best_checkpoint(self, checkpoint_dir: str, metric: str = 'accuracy') -> Optional[str]:
+        """Find the best checkpoint based on a metric"""
+        try:
+            if not os.path.exists(checkpoint_dir):
+                return None
+
+            checkpoints = []
+            for filename in os.listdir(checkpoint_dir):
+                if filename.endswith('.pt'):
+                    filepath = os.path.join(checkpoint_dir, filename)
+                    try:
+                        checkpoint = torch.load(filepath, map_location='cpu')
+                        checkpoints.append({
+                            'path': filepath,
+                            'metric_value': checkpoint.get(metric, 0),
+                            'epoch': checkpoint.get('epoch', 0)
+                        })
+                    except Exception as e:
+                        logger.debug(f"Could not load checkpoint {filename}: {e}")
+
+            if not checkpoints:
+                return None
+
+            # Sort by metric (higher is better for accuracy)
+            checkpoints.sort(key=lambda x: x['metric_value'], reverse=True)
+            return checkpoints[0]['path']
+
+        except Exception as e:
+            logger.error(f"Error finding best checkpoint: {e}")
+            return None
+
+    def _cleanup_old_checkpoints(self, checkpoint_dir: str, keep_best: int = 5, metric: str = 'accuracy'):
+        """Keep only the best N checkpoints"""
+        try:
+            if not os.path.exists(checkpoint_dir):
+                return
+
+            checkpoints = []
+            for filename in os.listdir(checkpoint_dir):
+                if filename.endswith('.pt'):
+                    filepath = os.path.join(checkpoint_dir, filename)
+                    try:
+                        checkpoint = torch.load(filepath, map_location='cpu')
+                        checkpoints.append({
+                            'path': filepath,
+                            'metric_value': checkpoint.get(metric, 0),
+                            'epoch': checkpoint.get('epoch', 0)
+                        })
+                    except Exception as e:
+                        logger.debug(f"Could not load checkpoint {filename}: {e}")
+
+            # Sort by metric (higher is better)
+            checkpoints.sort(key=lambda x: x['metric_value'], reverse=True)
+
+            # Delete checkpoints beyond keep_best
+            for checkpoint in checkpoints[keep_best:]:
+                try:
+                    os.remove(checkpoint['path'])
+                    logger.debug(f"Removed old checkpoint: {checkpoint['path']}")
+                except Exception as e:
+                    logger.warning(f"Could not remove checkpoint: {e}")
+
+        except Exception as e:
+            logger.error(f"Error cleaning up checkpoints: {e}")
+
     def _train_transformer_real(self, session: TrainingSession, training_data: List[Dict]):
         """
         Train Transformer model using orchestrator's existing training infrastructure
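
The two helpers above rank checkpoint files by a metric stored at the top level of each saved dict. A minimal sketch of the checkpoint layout they assume, with illustrative values (the stand-in model and file name are not from this commit):

import torch
import torch.nn as nn

# Stand-in model; the real code saves trainer.model.state_dict()
model = nn.Linear(4, 3)
torch.save({
    'epoch': 3,
    'model_state_dict': model.state_dict(),
    'loss': 0.42,
    'accuracy': 0.61,  # key read back by checkpoint.get(metric, 0)
}, 'transformer_epoch3_example.pt')

ckpt = torch.load('transformer_epoch3_example.pt', map_location='cpu')
print(ckpt.get('accuracy', 0), ckpt.get('epoch', 0))  # 0.61 3
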
@@ -1466,6 +1532,25 @@ class RealTrainingAdapter:
             logger.info(f"Using orchestrator's TradingTransformerTrainer")
             logger.info(f" Trainer type: {type(trainer).__name__}")
 
+            # Load best checkpoint if available to continue training
+            try:
+                checkpoint_dir = "models/checkpoints/transformer"
+                best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
+
+                if best_checkpoint_path and os.path.exists(best_checkpoint_path):
+                    checkpoint = torch.load(best_checkpoint_path)
+                    trainer.model.load_state_dict(checkpoint['model_state_dict'])
+                    trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                    trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+                    logger.info(f" Loaded checkpoint from epoch {checkpoint.get('epoch', 0)}")
+                    logger.info(f" Previous best: Loss={checkpoint.get('loss', 0):.6f}, Accuracy={checkpoint.get('accuracy', 0):.2%}")
+                else:
+                    logger.info(" No previous checkpoint found, starting fresh")
+            except Exception as e:
+                logger.warning(f" Failed to load checkpoint: {e}")
+                logger.info(" Starting with fresh model weights")
+
             # Use the trainer's train_step method for individual samples
             if hasattr(trainer, 'train_step'):
                 logger.info(" Using trainer.train_step() method")
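
For reference, a minimal self-contained sketch of the resume pattern used above, on a toy model, optimizer, and scheduler (the file name is illustrative). Note that the helper loads with map_location='cpu' while the resume path above calls torch.load() without it, so the device the checkpoint was saved on must still be available there:

import torch
import torch.nn as nn

model = nn.Linear(4, 3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

torch.save({
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
}, 'resume_example.pt')

# Restore all three pieces of training state, then continue training
ckpt = torch.load('resume_example.pt', map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
print(f"Resuming from epoch {ckpt.get('epoch', 0)}")
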
@@ -1549,13 +1634,14 @@ class RealTrainingAdapter:
                     if result is not None:
                         batch_loss = result.get('total_loss', 0.0)
                         batch_accuracy = result.get('accuracy', 0.0)
+                        batch_candle_accuracy = result.get('candle_accuracy', 0.0)
                         epoch_loss += batch_loss
                         epoch_accuracy += batch_accuracy
                         num_batches += 1
 
                         # Log first batch and every 10th batch for debugging
                         if (i + 1) == 1 or (i + 1) % 10 == 0:
-                            logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Accuracy: {batch_accuracy:.4f}")
+                            logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}")
                     else:
                         logger.warning(f" Batch {i + 1} returned None result - skipping")
 
@@ -1577,6 +1663,32 @@ class RealTrainingAdapter:
                     session.current_epoch = epoch + 1
                     session.current_loss = avg_loss
 
+                    # Save checkpoint after each epoch
+                    try:
+                        checkpoint_dir = "models/checkpoints/transformer"
+                        os.makedirs(checkpoint_dir, exist_ok=True)
+
+                        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                        checkpoint_path = os.path.join(checkpoint_dir, f"transformer_epoch{epoch+1}_{timestamp}.pt")
+
+                        torch.save({
+                            'epoch': epoch + 1,
+                            'model_state_dict': trainer.model.state_dict(),
+                            'optimizer_state_dict': trainer.optimizer.state_dict(),
+                            'scheduler_state_dict': trainer.scheduler.state_dict(),
+                            'loss': avg_loss,
+                            'accuracy': avg_accuracy,
+                            'learning_rate': trainer.scheduler.get_last_lr()[0]
+                        }, checkpoint_path)
+
+                        logger.info(f" Saved checkpoint: {checkpoint_path}")
+
+                        # Keep only best 5 checkpoints based on accuracy
+                        self._cleanup_old_checkpoints(checkpoint_dir, keep_best=5, metric='accuracy')
+
+                    except Exception as e:
+                        logger.warning(f" Failed to save checkpoint: {e}")
+
                     # Clear CUDA cache after each epoch
                     if torch.cuda.is_available():
                         torch.cuda.empty_cache()
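
The per-epoch file name above combines the epoch number with a save timestamp, while the cleanup that follows ranks files by their stored accuracy rather than by recency. A tiny illustration of the name template (the printed time is made up):

import os
from datetime import datetime

checkpoint_dir = "models/checkpoints/transformer"
epoch = 0  # first epoch of a zero-based loop
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_path = os.path.join(checkpoint_dir, f"transformer_epoch{epoch+1}_{timestamp}.pt")
print(checkpoint_path)  # e.g. models/checkpoints/transformer/transformer_epoch1_20250101_093000.pt
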
@@ -1586,6 +1698,16 @@ class RealTrainingAdapter:
                 session.final_loss = session.current_loss
                 session.accuracy = avg_accuracy
 
+                # Log best checkpoint info
+                try:
+                    checkpoint_dir = "models/checkpoints/transformer"
+                    best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
+                    if best_checkpoint_path:
+                        checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
+                        logger.info(f" Best checkpoint: epoch {checkpoint.get('epoch', 0)}, accuracy: {checkpoint.get('accuracy', 0):.2%}")
+                except Exception as e:
+                    logger.debug(f"Could not load best checkpoint info: {e}")
+
                 logger.info(f" Training complete: Loss = {session.final_loss:.6f}, Accuracy = {session.accuracy:.2%}")
 
             else:
@@ -1268,12 +1268,27 @@ class TradingTransformerTrainer:
         predictions = torch.argmax(outputs['action_logits'], dim=-1)
         accuracy = (predictions == batch['actions']).float().mean()
 
+        # Calculate candle prediction accuracy (price direction)
+        candle_accuracy = 0.0
+        if 'next_candles' in outputs and 'future_prices' in batch:
+            # Use 1m timeframe prediction as primary
+            if '1m' in outputs['next_candles']:
+                predicted_candle = outputs['next_candles']['1m']  # [batch, 5]
+                # Predicted close is the 4th value (index 3)
+                predicted_close_change = predicted_candle[:, 3]  # Predicted close price change
+                actual_close_change = batch['future_prices']  # Actual price change ratio
+
+                # Check if direction matches (both positive or both negative)
+                direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
+                candle_accuracy = direction_match.mean().item()
+
         # Extract values and delete tensors to free memory
         result = {
             'total_loss': total_loss.item(),
             'action_loss': action_loss.item(),
             'price_loss': price_loss.item(),
             'accuracy': accuracy.item(),
+            'candle_accuracy': candle_accuracy,
             'learning_rate': self.scheduler.get_last_lr()[0]
         }
 
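
The candle accuracy added here counts a prediction as correct when the predicted and actual close changes share the same sign. A small standalone sketch of that direction check on made-up tensors, following the [batch, 5] candle layout the diff assumes (close change at index 3):

import torch

predicted_candle = torch.tensor([[0.10, 0.20, -0.10,  0.03, 1.0],
                                 [0.00, 0.10, -0.20, -0.05, 2.0]])  # [batch, 5]
future_prices = torch.tensor([0.01, 0.04])  # actual close change ratios

predicted_close_change = predicted_candle[:, 3]
direction_match = (torch.sign(predicted_close_change) == torch.sign(future_prices)).float()
print(direction_match.mean().item())  # 0.5 -> one of two directions matched
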