diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py
index ee06ca9..8d2397d 100644
--- a/ANNOTATE/core/real_training_adapter.py
+++ b/ANNOTATE/core/real_training_adapter.py
@@ -4090,20 +4090,37 @@ class RealTrainingAdapter:
                 return None
 
             # Override the future candle target with actual candle data
-            actual = prediction_sample['actual_candle']  # [O, H, L, C]
+            actual = prediction_sample['actual_candle']  # [O, H, L, C, V] or [O, H, L, C]
 
             # Create target tensor for the specific timeframe
             import torch
-            device = batch['prices_1m'].device if 'prices_1m' in batch else torch.device('cpu')
+            # Get device from any available tensor in batch
+            device = torch.device('cpu')
+            for key in ['price_data_1m', 'price_data_1h', 'price_data_1d', 'prices_1m']:
+                if key in batch and batch[key] is not None:
+                    device = batch[key].device
+                    break
 
-            # Target candle: [O, H, L, C, V] - we don't have actual volume, use predicted
-            target_candle = [
-                actual[0],  # Open
-                actual[1],  # High
-                actual[2],  # Low
-                actual[3],  # Close
-                prediction_sample['predicted_candle'][4]  # Volume (from prediction)
-            ]
+            # Target candle: [O, H, L, C, V]
+            # Use actual volume if available, otherwise use predicted volume
+            if len(actual) >= 5:
+                target_candle = [
+                    float(actual[0]),  # Open
+                    float(actual[1]),  # High
+                    float(actual[2]),  # Low
+                    float(actual[3]),  # Close
+                    float(actual[4])   # Volume (from actual)
+                ]
+            else:
+                # Fallback: use predicted volume if actual doesn't have it
+                predicted = prediction_sample.get('predicted_candle', [0, 0, 0, 0, 0])
+                target_candle = [
+                    float(actual[0]),  # Open
+                    float(actual[1]),  # High
+                    float(actual[2]),  # Low
+                    float(actual[3]),  # Close
+                    float(predicted[4] if len(predicted) > 4 else 0.0)  # Volume (from prediction)
+                ]
 
             # Add to batch based on timeframe
             if timeframe == '1s':
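Review note: the hunk above boils down to one rule for the volume component: prefer actual volume, fall back to predicted volume, default to zero. A minimal standalone sketch of that rule; `build_target_candle` is a hypothetical helper named for illustration, not a function in this patch:

```python
# Sketch of the volume-fallback rule from the hunk above (hypothetical helper).
def build_target_candle(actual, predicted):
    """Return [O, H, L, C, V], preferring actual volume over predicted."""
    ohlc = [float(x) for x in actual[:4]]
    if len(actual) >= 5:
        volume = float(actual[4])  # actual volume is available
    else:
        # Fall back to predicted volume, or 0.0 if the prediction lacks it too
        volume = float(predicted[4]) if len(predicted) > 4 else 0.0
    return ohlc + [volume]

assert build_target_candle([1, 2, 0.5, 1.5, 42.0], [1, 2, 0.5, 1.5, 7.0])[4] == 42.0
assert build_target_candle([1, 2, 0.5, 1.5], [1, 2, 0.5, 1.5, 7.0])[4] == 7.0
```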
diff --git a/ANNOTATE/web/app.py b/ANNOTATE/web/app.py
index a316499..ce75741 100644
--- a/ANNOTATE/web/app.py
+++ b/ANNOTATE/web/app.py
@@ -2826,6 +2826,57 @@ class AnnotationDashboard:
                     'error': str(e)
                 }), 500
 
+        @self.server.route('/api/training-metrics', methods=['GET'])
+        def get_training_metrics():
+            """Get current training metrics for display (loss, accuracy, etc.)"""
+            try:
+                metrics = {
+                    'loss': 0.0,
+                    'accuracy': 0.0,
+                    'steps': 0,
+                    'recent_history': []
+                }
+
+                # Get metrics from training adapter if available
+                if self.training_adapter and hasattr(self.training_adapter, 'realtime_training_metrics'):
+                    rt_metrics = self.training_adapter.realtime_training_metrics
+                    metrics['loss'] = rt_metrics.get('last_loss', 0.0)
+                    metrics['accuracy'] = rt_metrics.get('last_accuracy', 0.0)
+                    metrics['steps'] = rt_metrics.get('total_steps', 0)
+
+                # Get incremental training metrics
+                if hasattr(self, '_incremental_training_steps'):
+                    metrics['incremental_steps'] = self._incremental_training_steps
+                if hasattr(self, '_training_metrics_history') and self._training_metrics_history:
+                    # Get last 10 metrics for display
+                    metrics['recent_history'] = self._training_metrics_history[-10:]
+                    # Update current metrics from most recent
+                    latest = self._training_metrics_history[-1]
+                    metrics['loss'] = latest.get('loss', metrics['loss'])
+                    metrics['accuracy'] = latest.get('accuracy', metrics['accuracy'])
+
+                # Get metrics from orchestrator trainer if available
+                if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer_trainer'):
+                    trainer = self.orchestrator.primary_transformer_trainer
+                    if trainer and hasattr(trainer, 'training_history'):
+                        history = trainer.training_history
+                        if history.get('train_loss'):
+                            metrics['loss'] = history['train_loss'][-1] if history['train_loss'] else metrics['loss']
+                        if history.get('train_accuracy'):
+                            metrics['accuracy'] = history['train_accuracy'][-1] if history['train_accuracy'] else metrics['accuracy']
+
+                return jsonify({
+                    'success': True,
+                    'metrics': metrics
+                })
+
+            except Exception as e:
+                logger.error(f"Error getting training metrics: {e}")
+                return jsonify({
+                    'success': False,
+                    'error': str(e)
+                }), 500
+
         @self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
         def train_manual():
             """Manually trigger training on current candle with specified action"""
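Review note: the new endpoint responds with `{'success': True, 'metrics': {...}}`, where `metrics` carries `loss`, `accuracy`, `steps`, optionally `incremental_steps`, and up to ten `recent_history` entries. A client-side polling sketch for reference; the base URL and interval are assumptions, not values from this patch:

```python
# Hypothetical poller for the new /api/training-metrics endpoint.
import time
import requests  # third-party: pip install requests

def poll_training_metrics(base_url="http://localhost:8050", interval_s=5.0):
    """Print loss/accuracy/steps from /api/training-metrics on a fixed interval."""
    while True:
        payload = requests.get(f"{base_url}/api/training-metrics", timeout=10).json()
        if payload.get("success"):
            m = payload["metrics"]
            print(f"steps={m['steps']} loss={m['loss']:.4f} accuracy={m['accuracy']:.4f}")
        time.sleep(interval_s)
```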
@@ -3074,11 +3125,17 @@ class AnnotationDashboard:
             # We need to fetch the market state at that timestamp
             symbol = 'ETH/USDT'  # TODO: Get from active trading pair
 
+            # Ensure actual_candle has volume (frontend sends [O, H, L, C, V])
+            actual_candle = list(actual) if isinstance(actual, (list, tuple)) else actual
+            if len(actual_candle) == 4:
+                # If only 4 values, add volume from predicted (fallback)
+                actual_candle.append(predicted[4] if len(predicted) > 4 else 0.0)
+
             training_sample = {
                 'symbol': symbol,
                 'timestamp': timestamp,
                 'predicted_candle': predicted,  # [O, H, L, C, V]
-                'actual_candle': actual,  # [O, H, L, C]
+                'actual_candle': actual_candle,  # [O, H, L, C, V] - ensure 5 values
                 'errors': errors,
                 'accuracy': accuracy,
                 'direction_correct': direction_correct,
@@ -3088,22 +3145,38 @@ class AnnotationDashboard:
             # Get market state at that timestamp
             try:
                 market_state = self._fetch_market_state_at_timestamp(symbol, timestamp, timeframe)
+                if not market_state or 'timeframes' not in market_state:
+                    logger.warning(f"Could not fetch market state for {symbol} at {timestamp}")
+                    return None
                 training_sample['market_state'] = market_state
             except Exception as e:
                 logger.warning(f"Could not fetch market state: {e}")
-                return
+                return None
 
             # Convert to transformer batch format
             batch = self.training_adapter._convert_prediction_to_batch(training_sample, timeframe)
             if not batch:
                 logger.warning("Could not convert validated prediction to training batch")
-                return
+                return None
 
             # Train on this batch with sample weighting
+            # CRITICAL: Use training lock to prevent concurrent access
             import torch
-            with torch.enable_grad():
-                trainer.model.train()
-                result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
+            import threading
+
+            # Try to acquire training lock with timeout
+            if hasattr(self.training_adapter, '_training_lock'):
+                lock_acquired = self.training_adapter._training_lock.acquire(timeout=5.0)
+                if not lock_acquired:
+                    logger.warning("Could not acquire training lock within 5 seconds - skipping incremental training")
+                    return None
+            else:
+                lock_acquired = False
+
+            try:
+                with torch.enable_grad():
+                    trainer.model.train()
+                    result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
 
                 if result:
                     loss = result.get('total_loss', 0)
@@ -3151,46 +3224,126 @@ class AnnotationDashboard:
                         'steps': self._incremental_training_steps,
                         'sample_weight': sample_weight
                     }
+                else:
+                    logger.warning("Training step returned no result")
+                    return None
+            finally:
+                # CRITICAL: Always release the lock
+                if lock_acquired and hasattr(self.training_adapter, '_training_lock'):
+                    self.training_adapter._training_lock.release()
 
         except Exception as e:
             logger.error(f"Error in incremental training: {e}", exc_info=True)
+            # Ensure lock is released even on error
+            if 'lock_acquired' in locals() and lock_acquired and hasattr(self.training_adapter, '_training_lock'):
+                try:
+                    self.training_adapter._training_lock.release()
+                except:
+                    pass
             return None
 
     def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
         """Fetch market state at a specific timestamp for training"""
         try:
-            from datetime import datetime
+            from datetime import datetime, timezone
            import pandas as pd
 
-            # Parse timestamp
-            ts = pd.Timestamp(timestamp)
+            # Parse timestamp - ensure it's timezone-aware
+            if isinstance(timestamp, str):
+                ts = pd.Timestamp(timestamp)
+                if ts.tz is None:
+                    ts = ts.tz_localize('UTC')
+            else:
+                ts = pd.Timestamp(timestamp)
+                if ts.tz is None:
+                    ts = ts.tz_localize('UTC')
 
-            # Get historical data for multiple timeframes
+            # Use data provider's method to get market state at that time
+            # This ensures we get the proper format with all required timeframes
+            if self.data_provider and hasattr(self.data_provider, 'get_market_state_at_time'):
+                try:
+                    # Convert to datetime if needed
+                    if isinstance(ts, pd.Timestamp):
+                        dt = ts.to_pydatetime()
+                    else:
+                        dt = ts
+
+                    # Get market state with context window (need enough candles for training)
+                    market_state = self.data_provider.get_market_state_at_time(
+                        symbol=symbol,
+                        timestamp=dt,
+                        context_window_minutes=600  # Get 600 minutes of context for 1m candles
+                    )
+
+                    if market_state and 'timeframes' in market_state:
+                        logger.debug(f"Fetched market state with {len(market_state.get('timeframes', {}))} timeframes")
+                        return market_state
+                    else:
+                        logger.warning("Market state returned empty or invalid format")
+                except Exception as e:
+                    logger.warning(f"Could not use data provider method: {e}")
+
+            # Fallback: manually fetch data for each timeframe
             market_state = {'timeframes': {}, 'secondary_timeframes': {}}
 
-            for tf in ['1s', '1m', '1h']:
+            # REQUIRED timeframes for transformer: 1m, 1h, 1d (1s is optional)
+            # Need at least 50 candles, preferably 600
+            required_tfs = ['1m', '1h', '1d']
+            optional_tfs = ['1s']
+
+            for tf in required_tfs + optional_tfs:
                 try:
-                    df = self.data_provider.get_historical_data(symbol, tf, limit=200)
+                    # Fetch enough candles (600 for training, but accept less)
+                    df = self.data_loader.get_data(
+                        symbol=symbol,
+                        timeframe=tf,
+                        end_time=dt,
+                        limit=600,
+                        direction='before'
+                    ) if self.data_loader else None
+
+                    # Fallback to data provider if data_loader not available
+                    if df is None or df.empty:
+                        if self.data_provider:
+                            df = self.data_provider.get_historical_data(symbol, tf, limit=600, refresh=False)
+
                     if df is not None and not df.empty:
-                        # Find data up to (but not including) the target timestamp
+                        # Filter to data before the target timestamp
                         df_before = df[df.index < ts]
-                        if not df_before.empty:
-                            recent = df_before.tail(200)
+                        if df_before.empty:
+                            # If no data before timestamp, use all available data
+                            df_before = df
+
+                        # Take last 600 candles (or all if less)
+                        recent = df_before.tail(600)
+
+                        if len(recent) >= 50:  # Minimum required
                             market_state['timeframes'][tf] = {
-                                'timestamps': self._format_timestamps_utc(recent.index),
                                 'open': recent['open'].tolist(),
                                 'high': recent['high'].tolist(),
                                 'low': recent['low'].tolist(),
                                 'close': recent['close'].tolist(),
                                 'volume': recent['volume'].tolist()
                             }
+                            logger.debug(f"Fetched {len(recent)} candles for {tf} timeframe")
+                        else:
+                            if tf in required_tfs:
+                                logger.warning(f"Required timeframe {tf} has only {len(recent)} candles (need at least 50)")
+                            else:
+                                logger.debug(f"Optional timeframe {tf} has only {len(recent)} candles, skipping")
                 except Exception as e:
                     logger.warning(f"Could not fetch {tf} data: {e}")
 
+            # Validate we have required timeframes
+            missing_required = [tf for tf in required_tfs if tf not in market_state['timeframes']]
+            if missing_required:
+                logger.warning(f"Missing required timeframes: {missing_required}")
+                return {}
+
             return market_state
 
         except Exception as e:
-            logger.error(f"Error fetching market state: {e}")
+            logger.error(f"Error fetching market state: {e}", exc_info=True)
             return {}
 
     def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
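Two review notes on the app.py changes above. First, in `_fetch_market_state_at_timestamp` both branches of the `isinstance(timestamp, str)` check are identical, and `dt` is assigned only inside the `get_market_state_at_time` branch, so the manual fallback's `end_time=dt` can raise `NameError` when that method is unavailable. Second, the training lock is released in the `finally:` block and again, defensively, in the outer exception handler behind a bare `except: pass`; a context manager would centralize acquire/release and rule out a double release. A sketch of that alternative shape, assuming `_training_lock` is a plain `threading.Lock` (an equivalent pattern, not what the patch does):

```python
# Alternative lock-guard shape (sketch); assumes a plain threading.Lock.
import threading
from contextlib import contextmanager

@contextmanager
def acquire_with_timeout(lock: threading.Lock, timeout: float):
    """Yield True if the lock was acquired within `timeout`; always release on exit."""
    acquired = lock.acquire(timeout=timeout)
    try:
        yield acquired
    finally:
        if acquired:
            lock.release()

training_lock = threading.Lock()
with acquire_with_timeout(training_lock, 5.0) as ok:
    if ok:
        pass  # run trainer.train_step(...) here; skip instead of blocking if not ok
```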
logger.warning(f"Could not fetch {tf} data: {e}") + # Validate we have required timeframes + missing_required = [tf for tf in required_tfs if tf not in market_state['timeframes']] + if missing_required: + logger.warning(f"Missing required timeframes: {missing_required}") + return {} + return market_state except Exception as e: - logger.error(f"Error fetching market state: {e}") + logger.error(f"Error fetching market state: {e}", exc_info=True) return {} def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1): diff --git a/ANNOTATE/web/templates/components/training_panel.html b/ANNOTATE/web/templates/components/training_panel.html index 6a76151..1f41a60 100644 --- a/ANNOTATE/web/templates/components/training_panel.html +++ b/ANNOTATE/web/templates/components/training_panel.html @@ -85,15 +85,15 @@