From 27039c70a39914701845d9edb1f0437742c60c33 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Mon, 10 Nov 2025 20:12:22 +0200 Subject: [PATCH] T model trend prediction added --- ANNOTATE/core/real_training_adapter.py | 205 ++++++++++++++-------- NN/models/advanced_transformer_trading.py | 17 +- 2 files changed, 152 insertions(+), 70 deletions(-) diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py index b2f81c6..7ffe071 100644 --- a/ANNOTATE/core/real_training_adapter.py +++ b/ANNOTATE/core/real_training_adapter.py @@ -574,16 +574,17 @@ class RealTrainingAdapter: training_data.append(entry_sample) logger.info(f" Test case {i+1}: ENTRY sample - {entry_sample['direction']} @ {entry_sample['entry_price']}") - # Create HOLD samples (every candle while position is open) + # Create HOLD samples (every 13 candles while position is open) # This teaches the model to maintain the position until exit hold_samples = self._create_hold_samples( test_case=test_case, - market_state=market_state + market_state=market_state, + sample_interval=13 # One sample every 13 candles ) training_data.extend(hold_samples) if hold_samples: - logger.info(f" Test case {i+1}: Added {len(hold_samples)} HOLD samples (during position)") + logger.info(f" Test case {i+1}: Added {len(hold_samples)} HOLD samples (every 13 candles during position)") # Create EXIT sample (where model SHOULD exit trade) # Exit info is in expected_outcome, not annotation_metadata @@ -606,21 +607,20 @@ class RealTrainingAdapter: logger.info(f" Test case {i+1}: EXIT sample @ {exit_price} ({expected_outcome.get('profit_loss_pct', 0):.2f}%)") # Create NEGATIVE samples (where model should NOT trade) - # These are candles before and after the signal (±15 candles) + # 5 candles before entry + 5 candles after exit = 10 NO_TRADE samples per annotation # This teaches the model to recognize when NOT to enter negative_samples = self._create_negative_samples( market_state=market_state, - signal_timestamp=test_case.get('timestamp'), - window_size=negative_samples_window + entry_timestamp=test_case.get('timestamp'), + exit_timestamp=None, # Will be calculated from holding period + holding_period_seconds=expected_outcome.get('holding_period_seconds', 0), + samples_before=5, # 5 candles before entry + samples_after=5 # 5 candles after exit ) training_data.extend(negative_samples) if negative_samples: - logger.info(f" Test case {i+1}: Added {len(negative_samples)} NO_TRADE samples (±{negative_samples_window} candles)") - # Show breakdown of before/after - before_count = sum(1 for s in negative_samples if 'before' in str(s.get('timestamp', ''))) - after_count = len(negative_samples) - before_count - logger.info(f" -> {before_count} before signal, {after_count} after signal") + logger.info(f" Test case {i+1}: Added {len(negative_samples)} NO_TRADE samples (5 before entry + 5 after exit)") except Exception as e: logger.error(f" Error preparing test case {i+1}: {e}") @@ -644,9 +644,9 @@ class RealTrainingAdapter: return training_data - def _create_hold_samples(self, test_case: Dict, market_state: Dict) -> List[Dict]: + def _create_hold_samples(self, test_case: Dict, market_state: Dict, sample_interval: int = 13) -> List[Dict]: """ - Create HOLD training samples for every candle while position is open + Create HOLD training samples at intervals while position is open This teaches the model to: 1. 
Maintain the position (not exit early) @@ -656,6 +656,7 @@ class RealTrainingAdapter: Args: test_case: Test case with entry/exit info market_state: Market state data + sample_interval: Create one sample every N candles (default: 13) Returns: List of HOLD training samples @@ -691,7 +692,8 @@ class RealTrainingAdapter: timestamps = timeframes['1m'].get('timestamps', []) - # Find all candles between entry and exit + # Find all candles between entry and exit, sample every N candles + candles_in_position = [] for idx, ts_str in enumerate(timestamps): # Parse timestamp using unified parser try: @@ -702,24 +704,43 @@ class RealTrainingAdapter: # If this candle is between entry and exit (exclusive) if entry_time < ts < exit_time: - # Create market state snapshot at this candle - hold_market_state = self._create_market_state_snapshot(market_state, idx) - - hold_sample = { - 'market_state': hold_market_state, - 'action': 'HOLD', - 'direction': expected_outcome.get('direction'), - 'profit_loss_pct': expected_outcome.get('profit_loss_pct'), - 'entry_price': expected_outcome.get('entry_price'), - 'exit_price': expected_outcome.get('exit_price'), - 'timestamp': ts_str, - 'label': 'HOLD', # Hold position - 'in_position': True # Flag indicating we're in a position - } - - hold_samples.append(hold_sample) + candles_in_position.append((idx, ts_str, ts)) - logger.debug(f" Created {len(hold_samples)} HOLD samples between entry and exit") + # Sample every Nth candle (e.g., every 13 candles) + for i in range(0, len(candles_in_position), sample_interval): + idx, ts_str, ts = candles_in_position[i] + + # Create market state snapshot at this candle + hold_market_state = self._create_market_state_snapshot(market_state, idx) + + # Calculate current unrealized PnL at this point + entry_price = expected_outcome.get('entry_price', 0) + current_price = timeframes['1m']['close'][idx] if idx < len(timeframes['1m']['close']) else entry_price + direction = expected_outcome.get('direction') + + if entry_price > 0 and current_price > 0: + if direction == 'LONG': + current_pnl = (current_price - entry_price) / entry_price * 100 + else: # SHORT + current_pnl = (entry_price - current_price) / entry_price * 100 + else: + current_pnl = 0.0 + + hold_sample = { + 'market_state': hold_market_state, + 'action': 'HOLD', + 'direction': direction, + 'profit_loss_pct': current_pnl, # Current unrealized PnL + 'entry_price': entry_price, + 'exit_price': expected_outcome.get('exit_price'), + 'timestamp': ts_str, + 'label': 'HOLD', # Hold position + 'in_position': True # Flag indicating we're in a position + } + + hold_samples.append(hold_sample) + + logger.debug(f" Created {len(hold_samples)} HOLD samples (every {sample_interval} candles, {len(candles_in_position)} total candles in position)") except Exception as e: logger.error(f"Error creating HOLD samples: {e}") @@ -728,24 +749,30 @@ class RealTrainingAdapter: return hold_samples - def _create_negative_samples(self, market_state: Dict, signal_timestamp: str, - window_size: int) -> List[Dict]: + def _create_negative_samples(self, market_state: Dict, entry_timestamp: str, + exit_timestamp: Optional[str], holding_period_seconds: int, + samples_before: int = 5, samples_after: int = 5) -> List[Dict]: """ - Create negative training samples from candles around the signal + Create negative training samples from candles before entry and after exit These samples teach the model when NOT to trade - crucial for reducing false signals! 
Args: market_state: Market state with OHLCV data - signal_timestamp: Timestamp of the actual signal - window_size: Number of candles before/after signal to use + entry_timestamp: Timestamp of entry signal + exit_timestamp: Timestamp of exit signal (optional, calculated from holding period) + holding_period_seconds: Duration of the trade in seconds + samples_before: Number of candles before entry (default: 5) + samples_after: Number of candles after exit (default: 5) Returns: - List of negative training samples + List of negative training samples (NO_TRADE) """ negative_samples = [] try: + from datetime import timedelta + # Get timestamps from market state (use 1m timeframe as reference) timeframes = market_state.get('timeframes', {}) if '1m' not in timeframes: @@ -756,55 +783,65 @@ class RealTrainingAdapter: if not timestamps: return negative_samples - # Find the index of the signal timestamp - from datetime import datetime - - # Parse signal timestamp using unified parser + # Parse entry timestamp try: - signal_time = parse_timestamp_to_utc(signal_timestamp) + entry_time = parse_timestamp_to_utc(entry_timestamp) except Exception as e: - logger.warning(f"Could not parse signal timestamp '{signal_timestamp}': {e}") + logger.warning(f"Could not parse entry timestamp '{entry_timestamp}': {e}") return negative_samples - signal_index = None + # Calculate exit time + exit_time = entry_time + timedelta(seconds=holding_period_seconds) + + # Find entry and exit indices + entry_index = None + exit_index = None + for idx, ts_str in enumerate(timestamps): try: - # Parse timestamp using unified parser ts = parse_timestamp_to_utc(ts_str) - # Match within 1 minute - if abs((ts - signal_time).total_seconds()) < 60: - signal_index = idx - logger.debug(f" Found signal at index {idx}: {ts_str}") + # Match entry within 1 minute + if entry_index is None and abs((ts - entry_time).total_seconds()) < 60: + entry_index = idx + + # Match exit within 1 minute + if exit_index is None and abs((ts - exit_time).total_seconds()) < 60: + exit_index = idx + + if entry_index is not None and exit_index is not None: break except Exception as e: continue - if signal_index is None: - logger.warning(f"Could not find signal timestamp {signal_timestamp} in market data") - logger.warning(f" Market data has {len(timestamps)} timestamps from {timestamps[0] if timestamps else 'N/A'} to {timestamps[-1] if timestamps else 'N/A'}") + if entry_index is None: + logger.warning(f"Could not find entry timestamp in market data") return negative_samples - # Create negative samples from candles before and after the signal - # BEFORE signal: candles at signal_index - window_size to signal_index - 1 - # AFTER signal: candles at signal_index + 1 to signal_index + window_size + # If exit not found, estimate it + if exit_index is None: + # Estimate: 1 minute per candle + candles_in_trade = holding_period_seconds // 60 + exit_index = min(entry_index + candles_in_trade, len(timestamps) - 1) + logger.debug(f" Estimated exit index: {exit_index} ({candles_in_trade} candles)") + # Create NO_TRADE samples: 5 before entry + 5 after exit negative_indices = [] - # Before signal - for offset in range(1, window_size + 1): - idx = signal_index - offset + # 5 candles BEFORE entry + for offset in range(1, samples_before + 1): + idx = entry_index - offset if 0 <= idx < len(timestamps): - negative_indices.append(idx) + negative_indices.append(('before_entry', idx)) - # After signal - for offset in range(1, window_size + 1): - idx = signal_index + offset + # 5 candles 
AFTER exit + for offset in range(1, samples_after + 1): + idx = exit_index + offset if 0 <= idx < len(timestamps): - negative_indices.append(idx) + negative_indices.append(('after_exit', idx)) - # Create negative samples for each index - for idx in negative_indices: + # Create negative samples + for location, idx in negative_indices: # Create a market state snapshot at this timestamp negative_market_state = self._create_market_state_snapshot(market_state, idx) @@ -816,12 +853,13 @@ class RealTrainingAdapter: 'entry_price': None, 'exit_price': None, 'timestamp': timestamps[idx], - 'label': 'NO_TRADE' # Negative label + 'label': 'NO_TRADE', # Negative label + 'in_position': False # Not in position } negative_samples.append(negative_sample) - logger.debug(f" Created {len(negative_samples)} negative samples from ±{window_size} candles") + logger.debug(f" Created {len(negative_samples)} NO_TRADE samples ({samples_before} before entry + {samples_after} after exit)") except Exception as e: logger.error(f"Error creating negative samples: {e}") @@ -1423,6 +1461,33 @@ class RealTrainingAdapter: # FIXED: Ensure shape is [1, 1] not [1] to match BCELoss requirements trade_success = torch.tensor([[1.0 if profit_loss_pct > 0 else 0.0]], dtype=torch.float32) # [1, 1] + # NEW: Trend vector target for trend analysis optimization + # Calculate expected trend from entry to exit + direction = training_sample.get('direction', 'NONE') + + if direction == 'LONG': + # Upward trend: positive angle, positive direction + trend_angle = 0.785 # ~45 degrees in radians (pi/4) + trend_direction = 1.0 # Upward + elif direction == 'SHORT': + # Downward trend: negative angle, negative direction + trend_angle = -0.785 # ~-45 degrees + trend_direction = -1.0 # Downward + else: + # No trend + trend_angle = 0.0 + trend_direction = 0.0 + + # Steepness based on profit potential + if exit_price and entry_price and entry_price > 0: + price_change_pct = abs((exit_price - entry_price) / entry_price) + trend_steepness = min(price_change_pct * 10, 1.0) # Normalize to [0, 1] + else: + trend_steepness = 0.0 + + # Create trend target tensor [batch, 3]: [angle, steepness, direction] + trend_target = torch.tensor([[trend_angle, trend_steepness, trend_direction]], dtype=torch.float32) # [1, 3] + # Return batch dictionary with ALL timeframes batch = { # Multi-timeframe price data @@ -1440,8 +1505,9 @@ class RealTrainingAdapter: # Training targets 'actions': actions, # [1] - 'future_prices': future_prices, # [1] + 'future_prices': future_prices, # [1, 1] 'trade_success': trade_success, # [1, 1] + 'trend_target': trend_target, # [1, 3] - NEW: [angle, steepness, direction] # Legacy support (use 1m as default) 'price_data': price_data_1m if price_data_1m is not None else ref_data @@ -1646,13 +1712,14 @@ class RealTrainingAdapter: batch_loss = result.get('total_loss', 0.0) batch_accuracy = result.get('accuracy', 0.0) batch_candle_accuracy = result.get('candle_accuracy', 0.0) + batch_trend_loss = result.get('trend_loss', 0.0) epoch_loss += batch_loss epoch_accuracy += batch_accuracy num_batches += 1 # Log first batch and every 5th batch for debugging if (i + 1) == 1 or (i + 1) % 5 == 0: - logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}") + logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}") else: logger.warning(f" 
Batch {i + 1} returned None result - skipping") diff --git a/NN/models/advanced_transformer_trading.py b/NN/models/advanced_transformer_trading.py index 6eadb63..23bd7be 100644 --- a/NN/models/advanced_transformer_trading.py +++ b/NN/models/advanced_transformer_trading.py @@ -1203,8 +1203,22 @@ class TradingTransformerTrainer: price_loss = self.price_criterion(price_pred, price_target) + # NEW: Trend analysis loss (if trend_target provided) + trend_loss = torch.tensor(0.0, device=self.device) + if 'trend_target' in batch and 'trend_analysis' in outputs: + trend_pred = torch.cat([ + outputs['trend_analysis']['angle_radians'], + outputs['trend_analysis']['steepness'], + outputs['trend_analysis']['direction'] + ], dim=1) # [batch, 3] + + trend_target = batch['trend_target'] + if trend_pred.shape == trend_target.shape: + trend_loss = self.price_criterion(trend_pred, trend_target) + logger.debug(f"Trend loss: {trend_loss.item():.6f} (pred={trend_pred[0].tolist()}, target={trend_target[0].tolist()})") + # Start with base losses - avoid inplace operations on computation graph - total_loss = action_loss + 0.1 * price_loss # Weight auxiliary task + total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss # Weight auxiliary tasks # CRITICAL FIX: Scale loss for gradient accumulation # This prevents gradient explosion when accumulating over multiple batches @@ -1308,6 +1322,7 @@ class TradingTransformerTrainer: 'total_loss': total_loss.item(), 'action_loss': action_loss.item(), 'price_loss': price_loss.item(), + 'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0, # NEW 'accuracy': accuracy.item(), 'candle_accuracy': candle_accuracy, 'learning_rate': self.scheduler.get_last_lr()[0]
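
Editorial note (illustration only, not part of the patch): the sketch below re-creates the sampling scheme the adapter now uses, namely one HOLD sample every 13 candles while the position is open, each labeled with the unrealized PnL at that candle, plus 5 NO_TRADE candles before entry and 5 after exit. The helper names (select_sample_indices, unrealized_pnl_pct) are hypothetical; 1-minute candles and timezone-aware timestamps are assumed, matching the adapter's 1m reference timeframe.

from datetime import datetime, timedelta, timezone

def select_sample_indices(timestamps, entry_time, holding_period_seconds,
                          sample_interval=13, samples_before=5, samples_after=5):
    # Candles strictly between entry and exit become HOLD candidates;
    # one sample is kept every `sample_interval` candles (13 in the patch).
    exit_time = entry_time + timedelta(seconds=holding_period_seconds)
    in_position = [i for i, ts in enumerate(timestamps) if entry_time < ts < exit_time]
    hold_idx = in_position[::sample_interval]

    # Nearest candle to entry; exit estimated at 1 minute per candle,
    # mirroring the fallback in _create_negative_samples.
    entry_idx = min(range(len(timestamps)),
                    key=lambda i: abs((timestamps[i] - entry_time).total_seconds()))
    exit_idx = min(entry_idx + holding_period_seconds // 60, len(timestamps) - 1)

    # 5 NO_TRADE candles before entry + 5 after exit
    no_trade_idx = [entry_idx - k for k in range(1, samples_before + 1) if entry_idx - k >= 0]
    no_trade_idx += [exit_idx + k for k in range(1, samples_after + 1) if exit_idx + k < len(timestamps)]
    return hold_idx, no_trade_idx

def unrealized_pnl_pct(direction, entry_price, current_price):
    # Unrealized PnL label attached to each HOLD sample
    if entry_price <= 0 or current_price <= 0:
        return 0.0
    if direction == 'LONG':
        return (current_price - entry_price) / entry_price * 100
    return (entry_price - current_price) / entry_price * 100  # SHORT

base = datetime(2025, 11, 10, 12, 0, tzinfo=timezone.utc)
candles = [base + timedelta(minutes=i) for i in range(120)]
holds, no_trades = select_sample_indices(candles, base + timedelta(minutes=30), 40 * 60)
print(len(holds), len(no_trades))  # 3 HOLD indices, 10 NO_TRADE indices

With this toy data, a 40-minute trade yields 3 HOLD samples and 10 NO_TRADE samples, consistent with the "every 13 candles" and "5 before entry + 5 after exit" counts the adapter logs.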
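
A second illustration covers the new trend head itself: how the [angle, steepness, direction] target is built from an annotation and folded into the total loss with the 0.05 auxiliary weight. build_trend_target is a hypothetical helper mirroring the code added in _convert_annotation_to_transformer_batch; the model's trend_analysis heads are stubbed with zero tensors of assumed shape [batch, 1], and price_criterion is assumed to be nn.MSELoss, which this hunk does not show.

import torch
import torch.nn as nn

def build_trend_target(direction, entry_price, exit_price):
    # Reconstruction of the patch's trend target: [angle, steepness, direction]
    if direction == 'LONG':
        angle, trend_dir = 0.785, 1.0    # ~ +45 degrees (pi/4), upward
    elif direction == 'SHORT':
        angle, trend_dir = -0.785, -1.0  # ~ -45 degrees, downward
    else:
        angle, trend_dir = 0.0, 0.0      # no trend
    # Steepness scales with the trade's price move, clipped to [0, 1]
    if entry_price and exit_price and entry_price > 0:
        steepness = min(abs(exit_price - entry_price) / entry_price * 10, 1.0)
    else:
        steepness = 0.0
    return torch.tensor([[angle, steepness, trend_dir]], dtype=torch.float32)  # [1, 3]

# Stand-ins for outputs['trend_analysis'] heads (each assumed to be [batch, 1])
angle_pred, steep_pred, dir_pred = torch.zeros(1, 1), torch.zeros(1, 1), torch.zeros(1, 1)
trend_pred = torch.cat([angle_pred, steep_pred, dir_pred], dim=1)              # [1, 3]

trend_target = build_trend_target('LONG', entry_price=100.0, exit_price=103.0)  # steepness = 0.3
criterion = nn.MSELoss()                                                         # price_criterion assumed MSE
trend_loss = criterion(trend_pred, trend_target)

action_loss, price_loss = torch.tensor(0.7), torch.tensor(0.02)                  # placeholder base losses
total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss                  # weighting from the patch
print(round(trend_loss.item(), 4), round(total_loss.item(), 4))

Because trend_loss enters total_loss at only 0.05, a badly wrong trend prediction nudges the gradient without overwhelming the action loss, in line with how price_loss is already weighted at 0.1.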