From a2d34c6d7c190d2ba5a900327de8327afb353dbd Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Mon, 10 Nov 2025 12:41:39 +0200
Subject: [PATCH] fix transformer checkpoint and model loss measure

---
 ANNOTATE/core/real_training_adapter.py    | 124 +++++++++++++++++++++-
 NN/models/advanced_transformer_trading.py |  15 +++
 2 files changed, 138 insertions(+), 1 deletion(-)

diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py
index 59dcc64..caac6b8 100644
--- a/ANNOTATE/core/real_training_adapter.py
+++ b/ANNOTATE/core/real_training_adapter.py
@@ -15,6 +15,7 @@ import logging
 import uuid
 import time
 import threading
+import os
 from typing import Dict, List, Optional, Any
 from dataclasses import dataclass
 from datetime import datetime, timedelta, timezone
@@ -1447,6 +1448,71 @@ class RealTrainingAdapter:
             logger.error(traceback.format_exc())
             return None
 
+    def _find_best_checkpoint(self, checkpoint_dir: str, metric: str = 'accuracy') -> Optional[str]:
+        """Find the best checkpoint based on a metric"""
+        try:
+            if not os.path.exists(checkpoint_dir):
+                return None
+
+            checkpoints = []
+            for filename in os.listdir(checkpoint_dir):
+                if filename.endswith('.pt'):
+                    filepath = os.path.join(checkpoint_dir, filename)
+                    try:
+                        checkpoint = torch.load(filepath, map_location='cpu')
+                        checkpoints.append({
+                            'path': filepath,
+                            'metric_value': checkpoint.get(metric, 0),
+                            'epoch': checkpoint.get('epoch', 0)
+                        })
+                    except Exception as e:
+                        logger.debug(f"Could not load checkpoint {filename}: {e}")
+
+            if not checkpoints:
+                return None
+
+            # Sort by metric (higher is better for accuracy)
+            checkpoints.sort(key=lambda x: x['metric_value'], reverse=True)
+            return checkpoints[0]['path']
+
+        except Exception as e:
+            logger.error(f"Error finding best checkpoint: {e}")
+            return None
+
+    def _cleanup_old_checkpoints(self, checkpoint_dir: str, keep_best: int = 5, metric: str = 'accuracy'):
+        """Keep only the best N checkpoints"""
+        try:
+            if not os.path.exists(checkpoint_dir):
+                return
+
+            checkpoints = []
+            for filename in os.listdir(checkpoint_dir):
+                if filename.endswith('.pt'):
+                    filepath = os.path.join(checkpoint_dir, filename)
+                    try:
+                        checkpoint = torch.load(filepath, map_location='cpu')
+                        checkpoints.append({
+                            'path': filepath,
+                            'metric_value': checkpoint.get(metric, 0),
+                            'epoch': checkpoint.get('epoch', 0)
+                        })
+                    except Exception as e:
+                        logger.debug(f"Could not load checkpoint {filename}: {e}")
+
+            # Sort by metric (higher is better)
+            checkpoints.sort(key=lambda x: x['metric_value'], reverse=True)
+
+            # Delete checkpoints beyond keep_best
+            for checkpoint in checkpoints[keep_best:]:
+                try:
+                    os.remove(checkpoint['path'])
+                    logger.debug(f"Removed old checkpoint: {checkpoint['path']}")
+                except Exception as e:
+                    logger.warning(f"Could not remove checkpoint: {e}")
+
+        except Exception as e:
+            logger.error(f"Error cleaning up checkpoints: {e}")
+
     def _train_transformer_real(self, session: TrainingSession, training_data: List[Dict]):
         """
         Train Transformer model using orchestrator's existing training infrastructure
@@ -1466,6 +1532,25 @@ class RealTrainingAdapter:
         logger.info(f"Using orchestrator's TradingTransformerTrainer")
         logger.info(f" Trainer type: {type(trainer).__name__}")
 
+        # Load best checkpoint if available to continue training
+        try:
+            checkpoint_dir = "models/checkpoints/transformer"
+            best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
+
+            if best_checkpoint_path and os.path.exists(best_checkpoint_path):
+                checkpoint = torch.load(best_checkpoint_path)
+                trainer.model.load_state_dict(checkpoint['model_state_dict'])
+                trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+                logger.info(f" Loaded checkpoint from epoch {checkpoint.get('epoch', 0)}")
+                logger.info(f" Previous best: Loss={checkpoint.get('loss', 0):.6f}, Accuracy={checkpoint.get('accuracy', 0):.2%}")
+            else:
+                logger.info(" No previous checkpoint found, starting fresh")
+        except Exception as e:
+            logger.warning(f" Failed to load checkpoint: {e}")
+            logger.info(" Starting with fresh model weights")
+
         # Use the trainer's train_step method for individual samples
         if hasattr(trainer, 'train_step'):
             logger.info(" Using trainer.train_step() method")
@@ -1549,13 +1634,14 @@ class RealTrainingAdapter:
                     if result is not None:
                         batch_loss = result.get('total_loss', 0.0)
                         batch_accuracy = result.get('accuracy', 0.0)
+                        batch_candle_accuracy = result.get('candle_accuracy', 0.0)
                         epoch_loss += batch_loss
                         epoch_accuracy += batch_accuracy
                         num_batches += 1
 
                         # Log first batch and every 10th batch for debugging
                         if (i + 1) == 1 or (i + 1) % 10 == 0:
-                            logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Accuracy: {batch_accuracy:.4f}")
+                            logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}")
                     else:
                         logger.warning(f" Batch {i + 1} returned None result - skipping")
 
@@ -1577,6 +1663,32 @@ class RealTrainingAdapter:
                 session.current_epoch = epoch + 1
                 session.current_loss = avg_loss
 
+                # Save checkpoint after each epoch
+                try:
+                    checkpoint_dir = "models/checkpoints/transformer"
+                    os.makedirs(checkpoint_dir, exist_ok=True)
+
+                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                    checkpoint_path = os.path.join(checkpoint_dir, f"transformer_epoch{epoch+1}_{timestamp}.pt")
+
+                    torch.save({
+                        'epoch': epoch + 1,
+                        'model_state_dict': trainer.model.state_dict(),
+                        'optimizer_state_dict': trainer.optimizer.state_dict(),
+                        'scheduler_state_dict': trainer.scheduler.state_dict(),
+                        'loss': avg_loss,
+                        'accuracy': avg_accuracy,
+                        'learning_rate': trainer.scheduler.get_last_lr()[0]
+                    }, checkpoint_path)
+
+                    logger.info(f" Saved checkpoint: {checkpoint_path}")
+
+                    # Keep only best 5 checkpoints based on accuracy
+                    self._cleanup_old_checkpoints(checkpoint_dir, keep_best=5, metric='accuracy')
+
+                except Exception as e:
+                    logger.warning(f" Failed to save checkpoint: {e}")
+
                 # Clear CUDA cache after each epoch
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
@@ -1586,6 +1698,16 @@ class RealTrainingAdapter:
             session.final_loss = session.current_loss
             session.accuracy = avg_accuracy
 
+            # Log best checkpoint info
+            try:
+                checkpoint_dir = "models/checkpoints/transformer"
+                best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
+                if best_checkpoint_path:
+                    checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
+                    logger.info(f" Best checkpoint: epoch {checkpoint.get('epoch', 0)}, accuracy: {checkpoint.get('accuracy', 0):.2%}")
+            except Exception as e:
+                logger.debug(f"Could not load best checkpoint info: {e}")
+
             logger.info(f" Training complete: Loss = {session.final_loss:.6f}, Accuracy = {session.accuracy:.2%}")
 
         else:
diff --git a/NN/models/advanced_transformer_trading.py b/NN/models/advanced_transformer_trading.py
index f629ef9..90d2efe 100644
--- a/NN/models/advanced_transformer_trading.py
+++ b/NN/models/advanced_transformer_trading.py
@@ -1267,6 +1267,20 @@ class TradingTransformerTrainer:
         with torch.no_grad():
             predictions = torch.argmax(outputs['action_logits'], dim=-1)
             accuracy = (predictions == batch['actions']).float().mean()
+
+        # Calculate candle prediction accuracy (price direction)
+        candle_accuracy = 0.0
+        if 'next_candles' in outputs and 'future_prices' in batch:
+            # Use 1m timeframe prediction as primary
+            if '1m' in outputs['next_candles']:
+                predicted_candle = outputs['next_candles']['1m']  # [batch, 5]
+                # Predicted close is the 4th value (index 3)
+                predicted_close_change = predicted_candle[:, 3]  # Predicted close price change
+                actual_close_change = batch['future_prices']  # Actual price change ratio
+
+                # Check if direction matches (both positive or both negative)
+                direction_match = (torch.sign(predicted_close_change) == torch.sign(actual_close_change)).float()
+                candle_accuracy = direction_match.mean().item()
 
         # Extract values and delete tensors to free memory
         result = {
@@ -1274,6 +1288,7 @@ class TradingTransformerTrainer:
             'action_loss': action_loss.item(),
             'price_loss': price_loss.item(),
             'accuracy': accuracy.item(),
+            'candle_accuracy': candle_accuracy,
             'learning_rate': self.scheduler.get_last_lr()[0]
         }
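
Note (illustrative sketch, not part of the patch): each per-epoch file written by the torch.save() call above is a plain dict whose metadata keys ('epoch', 'loss', 'accuracy', 'learning_rate') sit alongside the three state dicts, and _find_best_checkpoint() / _cleanup_old_checkpoints() simply rank the .pt files by one of those keys. A minimal standalone way to inspect that ranking, assuming only that PyTorch is installed and that the checkpoint directory matches the "models/checkpoints/transformer" path used in the patch:

    import os
    import torch

    checkpoint_dir = "models/checkpoints/transformer"

    # Read only the metadata keys saved by the patch; map_location='cpu' keeps the GPU out of it.
    entries = []
    for name in os.listdir(checkpoint_dir):
        if not name.endswith('.pt'):
            continue
        ckpt = torch.load(os.path.join(checkpoint_dir, name), map_location='cpu')
        entries.append((ckpt.get('accuracy', 0), ckpt.get('epoch', 0), name))

    # Higher accuracy ranks first, mirroring the reverse=True sort in the new helpers.
    entries.sort(reverse=True)
    for acc, epoch, name in entries:
        print(f"epoch {epoch}: accuracy {acc:.2%} ({name})")

The first entry printed is the file _find_best_checkpoint() would return; anything past the fifth entry is what _cleanup_old_checkpoints(keep_best=5) would delete at the end of an epoch.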