""" Annotation Manager - Manages trade annotations and test case generation Handles storage, retrieval, and test case generation from manual trade annotations. """ import json import uuid from pathlib import Path from datetime import datetime from typing import List, Dict, Optional, Any from dataclasses import dataclass, asdict import logging logger = logging.getLogger(__name__) @dataclass class TradeAnnotation: """Represents a manually marked trade""" annotation_id: str symbol: str timeframe: str entry: Dict[str, Any] # {timestamp, price, index} exit: Dict[str, Any] # {timestamp, price, index} direction: str # 'LONG' or 'SHORT' profit_loss_pct: float notes: str = "" created_at: str = None market_context: Dict[str, Any] = None def __post_init__(self): if self.created_at is None: self.created_at = datetime.now().isoformat() if self.market_context is None: self.market_context = {} class AnnotationManager: """Manages trade annotations and test case generation""" def __init__(self, storage_path: str = "ANNOTATE/data/annotations"): """Initialize annotation manager""" self.storage_path = Path(storage_path) self.storage_path.mkdir(parents=True, exist_ok=True) self.annotations_file = self.storage_path / "annotations_db.json" self.test_cases_dir = self.storage_path.parent / "test_cases" self.test_cases_dir.mkdir(parents=True, exist_ok=True) self.annotations_db = self._load_annotations() logger.info(f"AnnotationManager initialized with storage: {self.storage_path}") def _load_annotations(self) -> Dict[str, List[Dict]]: """Load annotations from storage""" if self.annotations_file.exists(): try: with open(self.annotations_file, 'r') as f: data = json.load(f) logger.info(f"Loaded {len(data.get('annotations', []))} annotations") return data except Exception as e: logger.error(f"Error loading annotations: {e}") return {"annotations": [], "metadata": {}} else: return {"annotations": [], "metadata": {}} def _save_annotations(self): """Save annotations to storage""" try: # Update metadata self.annotations_db["metadata"] = { "total_annotations": len(self.annotations_db["annotations"]), "last_updated": datetime.now().isoformat() } with open(self.annotations_file, 'w') as f: json.dump(self.annotations_db, f, indent=2) logger.info(f"Saved {len(self.annotations_db['annotations'])} annotations") except Exception as e: logger.error(f"Error saving annotations: {e}") raise def create_annotation(self, entry_point: Dict, exit_point: Dict, symbol: str, timeframe: str, entry_market_state: Dict = None, exit_market_state: Dict = None) -> TradeAnnotation: """Create new trade annotation""" # Calculate direction and P&L entry_price = entry_point['price'] exit_price = exit_point['price'] if exit_price > entry_price: direction = 'LONG' profit_loss_pct = ((exit_price - entry_price) / entry_price) * 100 else: direction = 'SHORT' profit_loss_pct = ((entry_price - exit_price) / entry_price) * 100 # Store complete market context for training market_context = { 'entry_state': entry_market_state or {}, 'exit_state': exit_market_state or {} } annotation = TradeAnnotation( annotation_id=str(uuid.uuid4()), symbol=symbol, timeframe=timeframe, entry=entry_point, exit=exit_point, direction=direction, profit_loss_pct=profit_loss_pct, market_context=market_context ) logger.info(f"Created annotation: {annotation.annotation_id} ({direction}, {profit_loss_pct:.2f}%)") logger.info(f" Entry state: {len(entry_market_state or {})} timeframes") logger.info(f" Exit state: {len(exit_market_state or {})} timeframes") return annotation def save_annotation(self, annotation: TradeAnnotation): """Save annotation to storage""" # Convert to dict ann_dict = asdict(annotation) # Add to database self.annotations_db["annotations"].append(ann_dict) # Save to file self._save_annotations() logger.info(f"Saved annotation: {annotation.annotation_id}") def get_annotations(self, symbol: str = None, timeframe: str = None) -> List[TradeAnnotation]: """Retrieve annotations with optional filtering""" annotations = self.annotations_db.get("annotations", []) # Filter by symbol if symbol: annotations = [a for a in annotations if a.get('symbol') == symbol] # Filter by timeframe if timeframe: annotations = [a for a in annotations if a.get('timeframe') == timeframe] # Convert to TradeAnnotation objects result = [] for ann_dict in annotations: try: annotation = TradeAnnotation(**ann_dict) result.append(annotation) except Exception as e: logger.error(f"Error converting annotation: {e}") return result def delete_annotation(self, annotation_id: str): """Delete annotation""" original_count = len(self.annotations_db["annotations"]) self.annotations_db["annotations"] = [ a for a in self.annotations_db["annotations"] if a.get('annotation_id') != annotation_id ] if len(self.annotations_db["annotations"]) < original_count: self._save_annotations() logger.info(f"Deleted annotation: {annotation_id}") else: logger.warning(f"Annotation not found: {annotation_id}") def generate_test_case(self, annotation: TradeAnnotation, data_provider=None, auto_save: bool = True) -> Dict: """ Generate test case from annotation in realtime format Args: annotation: TradeAnnotation object data_provider: Optional DataProvider instance to fetch market context Returns: Test case dictionary in realtime format """ test_case = { "test_case_id": f"annotation_{annotation.annotation_id}", "symbol": annotation.symbol, "timestamp": annotation.entry['timestamp'], "action": "BUY" if annotation.direction == "LONG" else "SELL", "market_state": {}, "expected_outcome": { "direction": annotation.direction, "profit_loss_pct": annotation.profit_loss_pct, "holding_period_seconds": self._calculate_holding_period(annotation), "exit_price": annotation.exit['price'], "entry_price": annotation.entry['price'] }, "annotation_metadata": { "annotator": "manual", "confidence": 1.0, "notes": annotation.notes, "created_at": annotation.created_at, "timeframe": annotation.timeframe } } # Populate market state with ±5 minutes of data for negative examples if data_provider: try: entry_time = datetime.fromisoformat(annotation.entry['timestamp'].replace('Z', '+00:00')) exit_time = datetime.fromisoformat(annotation.exit['timestamp'].replace('Z', '+00:00')) # Calculate time window: ±5 minutes around entry time_window_before = timedelta(minutes=5) time_window_after = timedelta(minutes=5) start_time = entry_time - time_window_before end_time = entry_time + time_window_after logger.info(f"Fetching market data from {start_time} to {end_time} (±5min around entry)") # Fetch OHLCV data for all timeframes timeframes = ['1s', '1m', '1h', '1d'] market_state = {} for tf in timeframes: # Get data for the time window df = data_provider.get_historical_data( symbol=annotation.symbol, timeframe=tf, limit=1000 # Get enough data to cover ±5 minutes ) if df is not None and not df.empty: # Filter to time window df_window = df[(df.index >= start_time) & (df.index <= end_time)] if not df_window.empty: # Convert to list format market_state[f'ohlcv_{tf}'] = { 'timestamps': df_window.index.strftime('%Y-%m-%d %H:%M:%S').tolist(), 'open': df_window['open'].tolist(), 'high': df_window['high'].tolist(), 'low': df_window['low'].tolist(), 'close': df_window['close'].tolist(), 'volume': df_window['volume'].tolist() } logger.info(f" {tf}: {len(df_window)} candles in ±5min window") # Add training labels for each timestamp # This helps model learn WHERE to signal and WHERE NOT to signal market_state['training_labels'] = self._generate_training_labels( market_state, entry_time, exit_time, annotation.direction ) test_case["market_state"] = market_state logger.info(f"Populated market state with {len(market_state)-1} timeframes + training labels") except Exception as e: logger.error(f"Error fetching market state: {e}") import traceback traceback.print_exc() test_case["market_state"] = {} else: logger.warning("No data_provider available, market_state will be empty") test_case["market_state"] = {} # Save test case to file if auto_save is True if auto_save: test_case_file = self.test_cases_dir / f"{test_case['test_case_id']}.json" with open(test_case_file, 'w') as f: json.dump(test_case, f, indent=2) logger.info(f"Saved test case to: {test_case_file}") logger.info(f"Generated test case: {test_case['test_case_id']}") return test_case def get_all_test_cases(self) -> List[Dict]: """Load all test cases from disk""" test_cases = [] if not self.test_cases_dir.exists(): return test_cases for test_case_file in self.test_cases_dir.glob("annotation_*.json"): try: with open(test_case_file, 'r') as f: test_case = json.load(f) test_cases.append(test_case) except Exception as e: logger.error(f"Error loading test case {test_case_file}: {e}") logger.info(f"Loaded {len(test_cases)} test cases from disk") return test_cases def _calculate_holding_period(self, annotation: TradeAnnotation) -> float: """Calculate holding period in seconds""" try: entry_time = datetime.fromisoformat(annotation.entry['timestamp'].replace('Z', '+00:00')) exit_time = datetime.fromisoformat(annotation.exit['timestamp'].replace('Z', '+00:00')) return (exit_time - entry_time).total_seconds() except Exception as e: logger.error(f"Error calculating holding period: {e}") return 0.0 def _generate_training_labels(self, market_state: Dict, entry_time: datetime, exit_time: datetime, direction: str) -> Dict: """ Generate training labels for each timestamp in the market data. This helps the model learn WHERE to signal and WHERE NOT to signal. Labels: - 0 = NO SIGNAL (before entry or after exit) - 1 = ENTRY SIGNAL (at entry time) - 2 = HOLD (between entry and exit) - 3 = EXIT SIGNAL (at exit time) """ labels = {} # Use 1m timeframe as reference for labeling if 'ohlcv_1m' in market_state and 'timestamps' in market_state['ohlcv_1m']: timestamps = market_state['ohlcv_1m']['timestamps'] label_list = [] for ts_str in timestamps: try: ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S') # Determine label based on position relative to entry/exit if abs((ts - entry_time).total_seconds()) < 60: # Within 1 minute of entry label = 1 # ENTRY SIGNAL elif abs((ts - exit_time).total_seconds()) < 60: # Within 1 minute of exit label = 3 # EXIT SIGNAL elif entry_time < ts < exit_time: # Between entry and exit label = 2 # HOLD else: # Before entry or after exit label = 0 # NO SIGNAL label_list.append(label) except Exception as e: logger.error(f"Error parsing timestamp {ts_str}: {e}") label_list.append(0) labels['labels_1m'] = label_list labels['direction'] = direction labels['entry_timestamp'] = entry_time.strftime('%Y-%m-%d %H:%M:%S') labels['exit_timestamp'] = exit_time.strftime('%Y-%m-%d %H:%M:%S') logger.info(f"Generated {len(label_list)} training labels: " f"{label_list.count(0)} NO_SIGNAL, " f"{label_list.count(1)} ENTRY, " f"{label_list.count(2)} HOLD, " f"{label_list.count(3)} EXIT") return labels def export_annotations(self, annotations: List[TradeAnnotation] = None, format_type: str = 'json') -> Path: """Export annotations to file""" if annotations is None: annotations = self.get_annotations() # Convert to dicts export_data = [asdict(ann) for ann in annotations] # Create export file timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') export_file = self.storage_path / f"export_{timestamp}.{format_type}" if format_type == 'json': with open(export_file, 'w') as f: json.dump(export_data, f, indent=2) elif format_type == 'csv': import csv with open(export_file, 'w', newline='') as f: if export_data: writer = csv.DictWriter(f, fieldnames=export_data[0].keys()) writer.writeheader() writer.writerows(export_data) logger.info(f"Exported {len(annotations)} annotations to {export_file}") return export_file