From bd213c44e0821bd4d0f15c928b16acd30d85371d Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Sat, 25 Oct 2025 00:47:56 +0300
Subject: [PATCH] train wip

---
 ANNOTATE/core/real_training_adapter.py | 444 +++++++++++++++++++++++--
 1 file changed, 413 insertions(+), 31 deletions(-)

diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py
index ad48a56..1bf33de 100644
--- a/ANNOTATE/core/real_training_adapter.py
+++ b/ANNOTATE/core/real_training_adapter.py
@@ -325,8 +325,8 @@ class RealTrainingAdapter:
 
                 logger.debug(f"  Test case {i+1}: has_market_state={bool(market_state)}, has_expected_outcome={bool(expected_outcome)}")
 
-                # Create POSITIVE sample (where model SHOULD trade)
-                positive_sample = {
+                # Create ENTRY sample (where model SHOULD enter a trade)
+                entry_sample = {
                     'market_state': market_state,
                     'action': test_case.get('action'),
                     'direction': expected_outcome.get('direction'),
@@ -334,12 +334,40 @@ class RealTrainingAdapter:
                     'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
                     'entry_price': expected_outcome.get('entry_price'),
                     'exit_price': expected_outcome.get('exit_price'),
                     'timestamp': test_case.get('timestamp'),
-                    'label': 'TRADE',  # Positive label
+                    'label': 'ENTRY',  # Entry signal
                     'repetitions': training_repetitions
                 }
-                training_data.append(positive_sample)
-                logger.debug(f"  ✅ Positive sample: {positive_sample['direction']} @ {positive_sample['entry_price']} -> {positive_sample['exit_price']} ({positive_sample['profit_loss_pct']:.2f}%)")
+                training_data.append(entry_sample)
+                logger.debug(f"  ✅ Entry sample: {entry_sample['direction']} @ {entry_sample['entry_price']}")
+
+                # Create HOLD samples (one per candle while the position is open)
+                # This teaches the model to maintain the position until exit
+                hold_samples = self._create_hold_samples(
+                    test_case=test_case,
+                    market_state=market_state,
+                    repetitions=training_repetitions // 4  # Quarter reps for hold samples
+                )
+
+                training_data.extend(hold_samples)
+                logger.debug(f"  📊 Added {len(hold_samples)} HOLD samples (during position)")
+
+                # Create EXIT sample (where model SHOULD exit the trade)
+                exit_timestamp = test_case.get('annotation_metadata', {}).get('exit_timestamp')
+                if exit_timestamp:
+                    exit_sample = {
+                        'market_state': market_state,  # TODO: Get market state at exit time
+                        'action': 'CLOSE',
+                        'direction': expected_outcome.get('direction'),
+                        'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
+                        'entry_price': expected_outcome.get('entry_price'),
+                        'exit_price': expected_outcome.get('exit_price'),
+                        'timestamp': exit_timestamp,
+                        'label': 'EXIT',  # Exit signal
+                        'repetitions': training_repetitions
+                    }
+                    training_data.append(exit_sample)
+                    logger.debug(f"  ✅ Exit sample @ {exit_sample['exit_price']} ({exit_sample['profit_loss_pct']:.2f}%)")
 
                 # Create NEGATIVE samples (where model should NOT trade)
                 # These are candles before and after the signal
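# [Editor's sketch - not part of the commit] How one annotation fans out into
# labeled samples under the hunk above: a minimal illustration, assuming 1m
# candles, a hypothetical trade entered at 10:00 and exited at 10:03, and a
# 2-candle negative window on each side.
def sketch_label_timeline():
    """Per-candle labels produced for a single annotated trade."""
    return [
        ('09:58', 'NO_TRADE'),  # negative sample before the entry
        ('09:59', 'NO_TRADE'),
        ('10:00', 'ENTRY'),     # full training_repetitions
        ('10:01', 'HOLD'),      # training_repetitions // 4
        ('10:02', 'HOLD'),
        ('10:03', 'EXIT'),      # full training_repetitions
        ('10:04', 'NO_TRADE'),  # negative sample after the exit
        ('10:05', 'NO_TRADE'),
    ]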
@@ -356,19 +384,110 @@ class RealTrainingAdapter:
 
             except Exception as e:
                 logger.error(f"❌ Error preparing test case {i+1}: {e}")
 
-        total_positive = sum(1 for s in training_data if s.get('label') == 'TRADE')
-        total_negative = sum(1 for s in training_data if s.get('label') == 'NO_TRADE')
+        total_entry = sum(1 for s in training_data if s.get('label') == 'ENTRY')
+        total_hold = sum(1 for s in training_data if s.get('label') == 'HOLD')
+        total_exit = sum(1 for s in training_data if s.get('label') == 'EXIT')
+        total_no_trade = sum(1 for s in training_data if s.get('label') == 'NO_TRADE')
 
         logger.info(f"✅ Prepared {len(training_data)} training samples from {len(test_cases)} test cases")
-        logger.info(f"   Positive samples (TRADE): {total_positive}")
-        logger.info(f"   Negative samples (NO_TRADE): {total_negative}")
-        logger.info(f"   Ratio: 1:{total_negative/total_positive:.1f} (positive:negative)")
+        logger.info(f"   ENTRY samples: {total_entry}")
+        logger.info(f"   HOLD samples: {total_hold}")
+        logger.info(f"   EXIT samples: {total_exit}")
+        logger.info(f"   NO_TRADE samples: {total_no_trade}")
+
+        if total_entry > 0:
+            logger.info(f"   Ratio: 1:{total_no_trade/total_entry:.1f} (entry:no_trade)")
 
         if len(training_data) < len(test_cases):
             logger.warning(f"⚠️ Skipped {len(test_cases) - len(training_data)} test cases due to missing data")
 
         return training_data
 
+    def _create_hold_samples(self, test_case: Dict, market_state: Dict, repetitions: int) -> List[Dict]:
+        """
+        Create HOLD training samples for every candle while the position is open.
+
+        This teaches the model to:
+        1. Maintain the position (not exit early)
+        2. Recognize that the trade is still valid
+        3. Wait for the optimal exit point
+
+        Args:
+            test_case: Test case with entry/exit info
+            market_state: Market state data
+            repetitions: Number of times to repeat each hold sample
+
+        Returns:
+            List of HOLD training samples
+        """
+        hold_samples = []
+
+        try:
+            from datetime import datetime, timedelta
+            import pytz  # needed for the tz-aware comparisons below
+
+            # Get entry and exit timestamps
+            entry_timestamp = test_case.get('timestamp')
+            expected_outcome = test_case.get('expected_outcome', {})
+
+            # Calculate exit timestamp from the holding period
+            holding_period_seconds = expected_outcome.get('holding_period_seconds', 0)
+            if holding_period_seconds == 0:
+                logger.debug("  No holding period, skipping HOLD samples")
+                return hold_samples
+
+            # Parse entry timestamp
+            try:
+                if 'T' in entry_timestamp:
+                    entry_time = datetime.fromisoformat(entry_timestamp.replace('Z', '+00:00'))
+                else:
+                    entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M:%S')
+                    entry_time = entry_time.replace(tzinfo=pytz.UTC)
+            except Exception as e:
+                logger.warning(f"Could not parse entry timestamp '{entry_timestamp}': {e}")
+                return hold_samples
+
+            exit_time = entry_time + timedelta(seconds=holding_period_seconds)
+
+            # Get 1m timeframe timestamps
+            timeframes = market_state.get('timeframes', {})
+            if '1m' not in timeframes:
+                return hold_samples
+
+            timestamps = timeframes['1m'].get('timestamps', [])
+
+            # Find all candles between entry and exit
+            for idx, ts_str in enumerate(timestamps):
+                ts = datetime.fromisoformat(ts_str.replace(' ', 'T').replace('Z', '+00:00'))
+                if ts.tzinfo is None:
+                    ts = ts.replace(tzinfo=pytz.UTC)  # avoid naive/aware comparison errors
+
+                # If this candle is between entry and exit (exclusive)
+                if entry_time < ts < exit_time:
+                    # Create market state snapshot at this candle
+                    hold_market_state = self._create_market_state_snapshot(market_state, idx)
+
+                    hold_sample = {
+                        'market_state': hold_market_state,
+                        'action': 'HOLD',
+                        'direction': expected_outcome.get('direction'),
+                        'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
+                        'entry_price': expected_outcome.get('entry_price'),
+                        'exit_price': expected_outcome.get('exit_price'),
+                        'timestamp': ts_str,
+                        'label': 'HOLD',  # Hold position
+                        'repetitions': repetitions,
+                        'in_position': True  # Flag indicating we're in a position
+                    }
+
+                    hold_samples.append(hold_sample)
+
+            logger.debug(f"  Created {len(hold_samples)} HOLD samples between entry and exit")
+
+        except Exception as e:
+            logger.error(f"Error creating HOLD samples: {e}")
+            import traceback
+            logger.error(traceback.format_exc())
+
+        return hold_samples
+
     def _create_negative_samples(self, market_state: Dict, signal_timestamp: str,
                                   window_size: int, repetitions: int) -> List[Dict]:
         """
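# [Editor's sketch - not part of the commit] _create_hold_samples above and
# _create_negative_samples below duplicate the same ISO-vs-space timestamp
# parsing. A minimal sketch of a shared helper both could call, assuming pytz
# is available in this codebase:
from datetime import datetime
import pytz

def parse_utc_timestamp(ts: str) -> datetime:
    """Parse ISO ('2025-10-25T00:47:56Z') or space-separated
    ('2025-10-25 00:47:56') timestamps into aware UTC datetimes."""
    if 'T' in ts:
        dt = datetime.fromisoformat(ts.replace('Z', '+00:00'))
    else:
        dt = datetime.strptime(ts, '%Y-%m-%d %H:%M:%S')
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=pytz.UTC)
    return dt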
@@ -400,17 +519,39 @@ class RealTrainingAdapter:
 
             # Find the index of the signal timestamp
             from datetime import datetime
-            signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
+            import pytz
+
+            # Parse signal timestamp - handle different formats
+            try:
+                if 'T' in signal_timestamp:
+                    signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
+                else:
+                    signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M:%S')
+                    signal_time = signal_time.replace(tzinfo=pytz.UTC)
+            except Exception as e:
+                logger.warning(f"Could not parse signal timestamp '{signal_timestamp}': {e}")
+                return negative_samples
 
             signal_index = None
             for idx, ts_str in enumerate(timestamps):
-                ts = datetime.fromisoformat(ts_str.replace(' ', 'T'))
-                if abs((ts - signal_time).total_seconds()) < 60:  # Within 1 minute
-                    signal_index = idx
-                    break
+                try:
+                    # Parse timestamp from market data
+                    if 'T' in ts_str:
+                        ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00'))
+                    else:
+                        ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
+                        ts = ts.replace(tzinfo=pytz.UTC)
+
+                    # Match within 1 minute
+                    if abs((ts - signal_time).total_seconds()) < 60:
+                        signal_index = idx
+                        logger.debug(f"  Found signal at index {idx}: {ts_str}")
+                        break
+                except Exception:
+                    continue
 
             if signal_index is None:
-                logger.warning(f"Could not find signal timestamp in market data")
+                logger.warning(f"Could not find signal timestamp {signal_timestamp} in market data")
+                logger.warning(f"  Market data has {len(timestamps)} timestamps from {timestamps[0] if timestamps else 'N/A'} to {timestamps[-1] if timestamps else 'N/A'}")
                 return negative_samples
 
             # Create negative samples from candles before and after the signal
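# [Editor's sketch - not part of the commit] The hunk above matches the signal
# to a candle with a linear scan and a 60-second tolerance. A sketch of an
# O(log n) alternative using bisect, assuming the candle timestamps are
# already parsed, tz-aware, and sorted:
from bisect import bisect_left
from datetime import datetime
from typing import List, Optional

def find_signal_index(candles: List[datetime], signal: datetime,
                      tolerance_s: float = 60.0) -> Optional[int]:
    """Return the index of the candle closest to `signal`, or None if the
    nearest candle is more than `tolerance_s` seconds away."""
    pos = bisect_left(candles, signal)
    best = None
    for idx in (pos - 1, pos):  # neighbours straddling the insertion point
        if 0 <= idx < len(candles):
            delta = abs((candles[idx] - signal).total_seconds())
            if delta < tolerance_s and (best is None or delta < best[0]):
                best = (delta, idx)
    return best[1] if best else None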
@@ -596,6 +737,224 @@ class RealTrainingAdapter:
             state_size = agent.state_size if hasattr(agent, 'state_size') else 100
             return [0.0] * state_size
 
+    def _convert_annotation_to_transformer_batch(self, training_sample: Dict) -> Dict[str, 'torch.Tensor']:
+        """
+        Convert annotation training sample to transformer model input format
+
+        The transformer expects:
+        - price_data: [batch, seq_len, features] - OHLCV sequences
+        - cob_data: [batch, seq_len, cob_features] - COB (order book) data
+        - tech_data: [batch, features] - Technical indicators
+        - market_data: [batch, features] - Market context
+        - actions: [batch] - Target actions (0=HOLD, 1=BUY, 2=SELL)
+        - future_prices: [batch] - Future price targets
+        - trade_success: [batch] - Whether the trade was profitable
+        """
+        import torch
+        import numpy as np
+
+        try:
+            market_state = training_sample.get('market_state', {})
+
+            # Extract OHLCV data from ALL timeframes
+            timeframes = market_state.get('timeframes', {})
+
+            # Collect data from all available timeframes
+            all_price_data = []
+            timeframe_order = ['1s', '1m', '1h', '1d']  # Process in order
+
+            for tf in timeframe_order:
+                if tf not in timeframes:
+                    continue
+
+                tf_data = timeframes[tf]
+
+                # Convert to numpy arrays
+                opens = np.array(tf_data.get('open', []), dtype=np.float32)
+                highs = np.array(tf_data.get('high', []), dtype=np.float32)
+                lows = np.array(tf_data.get('low', []), dtype=np.float32)
+                closes = np.array(tf_data.get('close', []), dtype=np.float32)
+                volumes = np.array(tf_data.get('volume', []), dtype=np.float32)
+
+                if len(closes) > 0:
+                    # Stack OHLCV for this timeframe [seq_len, 5]
+                    tf_price_data = np.stack([opens, highs, lows, closes, volumes], axis=-1)
+                    all_price_data.append(tf_price_data)
+
+            if not all_price_data:
+                logger.warning("No price data in any timeframe")
+                return None
+
+            # Concatenate all timeframes along the sequence dimension
+            # This gives the model multi-timeframe context
+            price_data = np.concatenate(all_price_data, axis=0)
+
+            # Add batch dimension [1, total_seq_len, 5]
+            price_data = torch.tensor(price_data, dtype=torch.float32).unsqueeze(0)
+
+            # Get primary timeframe for reference (first available if 1m is missing)
+            primary_tf = '1m' if '1m' in timeframes else next(tf for tf in timeframe_order if tf in timeframes)
+            closes = np.array(timeframes[primary_tf].get('close', []), dtype=np.float32)
+
+            # Create placeholder COB data (zeros if not available)
+            # COB data shape: [1, seq_len, cob_features]
+            # Transformer expects 100 COB features (as defined in TransformerConfig)
+            cob_data = torch.zeros(1, len(closes), 100, dtype=torch.float32)
+
+            # Create technical indicators (simple ones for now)
+            # tech_data shape: [1, features]
+            tech_features = []
+
+            # Price vs 20-period SMA
+            if len(closes) >= 20:
+                sma_20 = np.mean(closes[-20:])
+                tech_features.append(closes[-1] / sma_20 - 1.0)
+            else:
+                tech_features.append(0.0)
+
+            # Most recent return
+            if len(closes) >= 2:
+                returns = (closes[-1] - closes[-2]) / closes[-2]
+                tech_features.append(returns)
+            else:
+                tech_features.append(0.0)
+
+            # Volatility (std/mean of the last 20 closes)
+            if len(closes) >= 20:
+                volatility = np.std(closes[-20:]) / np.mean(closes[-20:])
+                tech_features.append(volatility)
+            else:
+                tech_features.append(0.0)
+
+            # Pad tech_features to the transformer's expected size (40 features)
+            while len(tech_features) < 40:
+                tech_features.append(0.0)
+
+            tech_data = torch.tensor([tech_features[:40]], dtype=torch.float32)  # Exactly 40 features
+
+            # Create market context data with pivot points
+            # market_data shape: [1, features]
+            market_features = []
+
+            # Volume ratio vs the 20-period average
+            primary_volumes = np.array(timeframes[primary_tf].get('volume', []), dtype=np.float32)
+            if len(primary_volumes) >= 20:
+                vol_ratio = primary_volumes[-1] / np.mean(primary_volumes[-20:])
+                market_features.append(vol_ratio)
+            else:
+                market_features.append(1.0)
+
+            # 20-period price range relative to the current price
+            primary_highs = np.array(timeframes[primary_tf].get('high', []), dtype=np.float32)
+            primary_lows = np.array(timeframes[primary_tf].get('low', []), dtype=np.float32)
+            if len(primary_highs) >= 20 and len(primary_lows) >= 20:
+                price_range = (np.max(primary_highs[-20:]) - np.min(primary_lows[-20:])) / closes[-1]
+                market_features.append(price_range)
+            else:
+                market_features.append(0.0)
+
+            # Pivot point features from recent price action
+            if len(primary_highs) >= 5 and len(primary_lows) >= 5:
+                # Pivot Point = (High + Low + Close) / 3
+                pivot = (primary_highs[-1] + primary_lows[-1] + closes[-1]) / 3.0
+
+                # Support and resistance levels
+                r1 = 2 * pivot - primary_lows[-1]   # Resistance 1
+                s1 = 2 * pivot - primary_highs[-1]  # Support 1
+
+                # Normalize relative to the current price
+                pivot_distance = (closes[-1] - pivot) / closes[-1]
+                r1_distance = (closes[-1] - r1) / closes[-1]
+                s1_distance = (closes[-1] - s1) / closes[-1]
+
+                market_features.extend([pivot_distance, r1_distance, s1_distance])
+            else:
+                market_features.extend([0.0, 0.0, 0.0])
+
+            # Williams pivot levels, if present in the market state
+            pivot_markers = market_state.get('pivot_markers', {})
+            if pivot_markers:
+                # Count pivot levels within 2% of the current price
+                num_support = len([p for p in pivot_markers.get('support_levels', []) if abs(p - closes[-1]) / closes[-1] < 0.02])
+                num_resistance = len([p for p in pivot_markers.get('resistance_levels', []) if abs(p - closes[-1]) / closes[-1] < 0.02])
+                market_features.extend([float(num_support), float(num_resistance)])
+            else:
+                market_features.extend([0.0, 0.0])
+
+            # Pad market_features to the transformer's expected size (30 features)
+            while len(market_features) < 30:
+                market_features.append(0.0)
+
+            market_data = torch.tensor([market_features[:30]], dtype=torch.float32)  # Exactly 30 features
+
+            # Convert action to tensor
+            # 0 = HOLD/NO_TRADE, 1 = BUY (LONG), 2 = SELL (SHORT)
+            action_label = training_sample.get('label', 'TRADE')
+            direction = training_sample.get('direction', 'NONE')
+            in_position = training_sample.get('in_position', False)
+
+            if action_label == 'NO_TRADE':
+                action = 0  # HOLD - no position
+            elif action_label == 'HOLD':
+                action = 0  # HOLD - maintain position
+            elif action_label == 'ENTRY':
+                if direction == 'LONG':
+                    action = 1  # BUY
+                elif direction == 'SHORT':
+                    action = 2  # SELL
+                else:
+                    action = 0
+            elif action_label == 'EXIT':
+                # Exit is the opposite of entry
+                if direction == 'LONG':
+                    action = 2  # SELL to close long
+                elif direction == 'SHORT':
+                    action = 1  # BUY to close short
+                else:
+                    action = 0
+            elif direction == 'LONG':
+                action = 1  # BUY
+            elif direction == 'SHORT':
+                action = 2  # SELL
+            else:
+                action = 0  # HOLD
+
+            actions = torch.tensor([action], dtype=torch.long)
+
+            # Future price target
+            entry_price = training_sample.get('entry_price')
+            exit_price = training_sample.get('exit_price')
+
+            if exit_price and entry_price:
+                future_price = exit_price
+            else:
+                future_price = closes[-1]  # Current price for HOLD
+
+            future_prices = torch.tensor([future_price], dtype=torch.float32)
+
+            # Trade success (1.0 if profitable, 0.0 otherwise)
+            profit_loss_pct = training_sample.get('profit_loss_pct', 0.0)
+            trade_success = torch.tensor([1.0 if profit_loss_pct > 0 else 0.0], dtype=torch.float32)
+
+            # Return batch dictionary
+            batch = {
+                'price_data': price_data,
+                'cob_data': cob_data,
+                'tech_data': tech_data,
+                'market_data': market_data,
+                'actions': actions,
+                'future_prices': future_prices,
+                'trade_success': trade_success
+            }
+
+            return batch
+
+        except Exception as e:
+            logger.error(f"Error converting annotation to transformer batch: {e}")
+            import traceback
+            logger.error(traceback.format_exc())
+            return None
+
     def _train_transformer_real(self, session: TrainingSession, training_data: List[Dict]):
         """
         Train Transformer model using orchestrator's existing training infrastructure
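# [Editor's sketch - not part of the commit] The label-to-action branching at
# the end of the hunk above, condensed into a lookup table for reference
# (0 = HOLD, 1 = BUY, 2 = SELL; the legacy direction-only fallback is omitted):
def map_label_to_action(label: str, direction: str) -> int:
    """ENTRY follows the trade direction; EXIT takes the opposite side."""
    table = {
        ('ENTRY', 'LONG'): 1,   # BUY to open a long
        ('ENTRY', 'SHORT'): 2,  # SELL to open a short
        ('EXIT', 'LONG'): 2,    # SELL to close a long
        ('EXIT', 'SHORT'): 1,   # BUY to close a short
    }
    return table.get((label, direction), 0)  # NO_TRADE/HOLD default to 0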
@@ -618,40 +977,63 @@ class RealTrainingAdapter:
 
             # Use the trainer's train_step method for individual samples
             if hasattr(trainer, 'train_step'):
                 logger.info("  Using trainer.train_step() method")
+                logger.info("  Converting annotation data to transformer format...")
 
                 import torch
 
-                # Train using train_step for each sample
+                # Convert all training samples to transformer format
+                converted_batches = []
+                for i, data in enumerate(training_data):
+                    batch = self._convert_annotation_to_transformer_batch(data)
+                    if batch is not None:
+                        # Repeat based on the repetitions parameter
+                        repetitions = data.get('repetitions', 1)
+                        for _ in range(repetitions):
+                            converted_batches.append(batch)
+                    else:
+                        logger.warning(f"  Failed to convert sample {i+1}")
+
+                if not converted_batches:
+                    raise Exception("No valid training batches after conversion")
+
+                logger.info(f"  ✅ Converted {len(training_data)} samples to {len(converted_batches)} training batches")
+
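# [Editor's sketch - not part of the commit] Appending each converted batch
# `repetitions` times grows memory and epoch length linearly. An alternative
# that keeps one copy per sample and weights the draw instead, assuming the
# batches could be served through a PyTorch DataLoader:
import torch
from torch.utils.data import WeightedRandomSampler
from typing import List

def make_repetition_sampler(repetitions: List[int], draws_per_epoch: int) -> WeightedRandomSampler:
    """Draw sample i with probability proportional to repetitions[i]."""
    weights = torch.tensor(repetitions, dtype=torch.double)
    return WeightedRandomSampler(weights, num_samples=draws_per_epoch, replacement=True)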
+                # Train using train_step for each batch
                 for epoch in range(session.total_epochs):
                     epoch_loss = 0.0
-                    num_samples = 0
+                    epoch_accuracy = 0.0
+                    num_batches = 0
 
-                    for i, data in enumerate(training_data):
+                    for i, batch in enumerate(converted_batches):
                         try:
-                            # Call the trainer's train_step method
-                            loss = trainer.train_step(data)
+                            # Call the trainer's train_step method with the proper batch format
+                            result = trainer.train_step(batch)
 
-                            if loss is not None:
-                                epoch_loss += float(loss)
-                                num_samples += 1
+                            if result is not None:
+                                epoch_loss += result.get('total_loss', 0.0)
+                                epoch_accuracy += result.get('accuracy', 0.0)
+                                num_batches += 1
 
-                                if (i + 1) % 10 == 0:
-                                    logger.debug(f"  Sample {i + 1}/{len(training_data)}, Loss: {loss:.6f}")
+                                if (i + 1) % 100 == 0:
+                                    logger.debug(f"  Batch {i + 1}/{len(converted_batches)}, Loss: {result.get('total_loss', 0.0):.6f}")
 
                         except Exception as e:
-                            logger.error(f"  Error in sample {i + 1}: {e}")
+                            logger.error(f"  Error in batch {i + 1}: {e}")
+                            import traceback
+                            logger.error(traceback.format_exc())
                             continue
 
-                    avg_loss = epoch_loss / num_samples if num_samples > 0 else 0.0
+                    avg_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
+                    avg_accuracy = epoch_accuracy / num_batches if num_batches > 0 else 0.0
                     session.current_epoch = epoch + 1
                     session.current_loss = avg_loss
 
-                    logger.info(f"  Epoch {epoch + 1}/{session.total_epochs}, Avg Loss: {avg_loss:.6f} ({num_samples} samples)")
+                    logger.info(f"  Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")
 
                 session.final_loss = session.current_loss
-                session.accuracy = 0.85  # TODO: Calculate actual accuracy
+                session.accuracy = avg_accuracy
 
-                logger.info(f"  Training complete: Loss = {session.final_loss:.6f}")
+                logger.info(f"  Training complete: Loss = {session.final_loss:.6f}, Accuracy = {session.accuracy:.2%}")
 
             else:
                 raise Exception(f"Transformer trainer does not have a train_step() method. Available methods: {[m for m in dir(trainer) if not m.startswith('_')]}")
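# [Editor's sketch - not part of the commit] The loop above assumes that
# trainer.train_step(batch) returns a dict with 'total_loss' and 'accuracy'
# keys. A minimal sketch of a train_step honouring that contract (the model,
# optimizer, and criterion names are hypothetical):
import torch

def train_step(model, optimizer, criterion, batch):
    """One optimization step; returns the metrics _train_transformer_real reads."""
    optimizer.zero_grad()
    logits = model(batch['price_data'], batch['cob_data'],
                   batch['tech_data'], batch['market_data'])  # [1, 3] action logits
    loss = criterion(logits, batch['actions'])
    loss.backward()
    optimizer.step()
    accuracy = (logits.argmax(dim=-1) == batch['actions']).float().mean()
    return {'total_loss': loss.item(), 'accuracy': accuracy.item()}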