470 lines
19 KiB
Python
470 lines
19 KiB
Python
"""
|
|
Annotation Manager - Manages trade annotations and test case generation
|
|
|
|
Handles storage, retrieval, and test case generation from manual trade annotations.
|
|
Stores annotations in both JSON (legacy) and SQLite (with full market data).
|
|
"""
|
|
|
|
import json
import logging
import sys
import uuid
from dataclasses import asdict, dataclass, fields
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional

import pytz
|
|
|
|
# Make the project root importable (this module lives three levels below it).
parent_dir = Path(__file__).parents[2]
sys.path[:0] = [str(parent_dir)]

logger = logging.getLogger(__name__)
|
|
|
|
# Import DuckDB storage
# Optional dependency: when the import fails, DUCKDB_AVAILABLE stays False and
# AnnotationManager falls back to JSON-only persistence (it checks this flag
# before touching DuckDB).
try:
    from core.duckdb_storage import DuckDBStorage
    DUCKDB_AVAILABLE = True
except ImportError:
    DUCKDB_AVAILABLE = False
    logger.warning("DuckDB storage not available for annotations")
|
|
|
|
|
|
@dataclass
class TradeAnnotation:
    """Represents a manually marked trade.

    Required fields describe the trade itself; ``created_at`` and
    ``market_context`` are filled in by ``__post_init__`` when omitted.
    """
    annotation_id: str
    symbol: str
    timeframe: str
    entry: Dict[str, Any]  # {timestamp, price, index}
    exit: Dict[str, Any]  # {timestamp, price, index}
    direction: str  # 'LONG' or 'SHORT'
    profit_loss_pct: float
    notes: str = ""
    # ISO-8601 string; defaulted to "now" (naive local time) in __post_init__.
    created_at: Optional[str] = None
    # {entry_state, exit_state} snapshots; defaulted to {} in __post_init__.
    # None default (not {}) avoids the shared-mutable-default pitfall.
    market_context: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        if self.created_at is None:
            self.created_at = datetime.now().isoformat()
        if self.market_context is None:
            # Fresh dict per instance so annotations never share context state.
            self.market_context = {}
|
|
|
|
|
|
class AnnotationManager:
    """Manages trade annotations and test case generation.

    Annotations are persisted twice:
      - a JSON file (legacy, metadata only), always
      - DuckDB (with full market snapshots), when the optional storage backend
        is importable and initializes successfully
    """

    def __init__(self, storage_path: str = "ANNOTATE/data/annotations"):
        """Initialize annotation manager.

        Args:
            storage_path: Directory holding the JSON annotation database.
                Test case files are written to a sibling "test_cases" dir.
        """
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)

        self.annotations_file = self.storage_path / "annotations_db.json"
        self.test_cases_dir = self.storage_path.parent / "test_cases"
        self.test_cases_dir.mkdir(parents=True, exist_ok=True)

        self.annotations_db = self._load_annotations()

        # Initialize DuckDB storage for complete annotation data.
        # BUG FIX: the annotation must be a string literal. Annotations on
        # attribute targets ARE evaluated at runtime (PEP 526), so an unquoted
        # Optional[DuckDBStorage] raised NameError here whenever the DuckDB
        # import failed and the name DuckDBStorage was never defined.
        self.duckdb_storage: "Optional[DuckDBStorage]" = None
        if DUCKDB_AVAILABLE:
            try:
                self.duckdb_storage = DuckDBStorage()
                logger.info("DuckDB storage initialized for annotations")
            except Exception as e:
                # Best-effort: DuckDB is optional, JSON persistence still works.
                logger.warning(f"Could not initialize DuckDB storage: {e}")

        logger.info(f"AnnotationManager initialized with storage: {self.storage_path}")

    def _load_annotations(self) -> Dict[str, List[Dict]]:
        """Load the annotation database from the JSON file.

        Returns:
            Dict with "annotations" (list of annotation dicts) and "metadata".
            A missing or unreadable file yields an empty database.
        """
        if self.annotations_file.exists():
            try:
                with open(self.annotations_file, 'r') as f:
                    data = json.load(f)
                logger.info(f"Loaded {len(data.get('annotations', []))} annotations")
                return data
            except Exception as e:
                logger.error(f"Error loading annotations: {e}")
                return {"annotations": [], "metadata": {}}
        else:
            return {"annotations": [], "metadata": {}}

    def _save_annotations(self):
        """Write the annotation database to the JSON file, refreshing metadata.

        Raises:
            Exception: Re-raised after logging if the file cannot be written.
        """
        try:
            # Update metadata
            self.annotations_db["metadata"] = {
                "total_annotations": len(self.annotations_db["annotations"]),
                "last_updated": datetime.now().isoformat()
            }

            with open(self.annotations_file, 'w') as f:
                json.dump(self.annotations_db, f, indent=2)

            logger.info(f"Saved {len(self.annotations_db['annotations'])} annotations")
        except Exception as e:
            logger.error(f"Error saving annotations: {e}")
            raise

    def create_annotation(self, entry_point: Dict, exit_point: Dict,
                         symbol: str, timeframe: str,
                         entry_market_state: Optional[Dict] = None,
                         exit_market_state: Optional[Dict] = None) -> TradeAnnotation:
        """Create a new trade annotation (not yet persisted; see save_annotation).

        Direction is inferred from the price move: a higher exit price is
        recorded as a LONG, otherwise as a SHORT, so profit_loss_pct is
        always >= 0 for the inferred direction.

        Args:
            entry_point: {timestamp, price, index} of the entry
            exit_point: {timestamp, price, index} of the exit
            symbol: Trading pair, e.g. 'ETH/USDT'
            timeframe: Chart timeframe the trade was marked on
            entry_market_state: Optional {timeframe: state} snapshot at entry
            exit_market_state: Optional {timeframe: state} snapshot at exit

        Returns:
            Fully populated TradeAnnotation with a fresh UUID.
        """
        # Calculate direction and P&L
        entry_price = entry_point['price']
        exit_price = exit_point['price']

        if exit_price > entry_price:
            direction = 'LONG'
            profit_loss_pct = ((exit_price - entry_price) / entry_price) * 100
        else:
            # NOTE: a zero price move falls through to SHORT with 0% P&L.
            direction = 'SHORT'
            profit_loss_pct = ((entry_price - exit_price) / entry_price) * 100

        # Store complete market context for training
        market_context = {
            'entry_state': entry_market_state or {},
            'exit_state': exit_market_state or {}
        }

        annotation = TradeAnnotation(
            annotation_id=str(uuid.uuid4()),
            symbol=symbol,
            timeframe=timeframe,
            entry=entry_point,
            exit=exit_point,
            direction=direction,
            profit_loss_pct=profit_loss_pct,
            market_context=market_context
        )

        logger.info(f"Created annotation: {annotation.annotation_id} ({direction}, {profit_loss_pct:.2f}%)")
        logger.info(f"  Entry state: {len(entry_market_state or {})} timeframes")
        logger.info(f"  Exit state: {len(exit_market_state or {})} timeframes")
        return annotation

    def save_annotation(self, annotation: TradeAnnotation,
                       market_snapshots: Optional[Dict] = None,
                       model_predictions: Optional[List[Dict]] = None):
        """
        Save annotation to storage (JSON + DuckDB)

        Args:
            annotation: TradeAnnotation object
            market_snapshots: Dict of {timeframe: DataFrame} with OHLCV data
            model_predictions: List of model predictions at annotation time
        """
        # Convert to dict
        ann_dict = asdict(annotation)

        # Add to JSON database (legacy)
        self.annotations_db["annotations"].append(ann_dict)

        # Save to JSON file
        self._save_annotations()

        # Save to DuckDB with complete market data (best-effort: a DuckDB
        # failure must not lose the JSON copy already written above)
        if self.duckdb_storage and market_snapshots:
            try:
                self.duckdb_storage.store_annotation(
                    annotation_id=annotation.annotation_id,
                    annotation_data=ann_dict,
                    market_snapshots=market_snapshots,
                    model_predictions=model_predictions
                )
                logger.info(f"Saved annotation {annotation.annotation_id} to DuckDB with {len(market_snapshots)} timeframes")
            except Exception as e:
                logger.error(f"Could not save annotation to DuckDB: {e}")

        logger.info(f"Saved annotation: {annotation.annotation_id}")

    def get_annotations(self, symbol: Optional[str] = None,
                       timeframe: Optional[str] = None) -> List[TradeAnnotation]:
        """Retrieve annotations with optional symbol/timeframe filtering.

        Returns:
            List of TradeAnnotation objects. Records that cannot be converted
            are logged and skipped.
        """
        annotations = self.annotations_db.get("annotations", [])

        # Filter by symbol
        if symbol:
            annotations = [a for a in annotations if a.get('symbol') == symbol]

        # Filter by timeframe
        if timeframe:
            annotations = [a for a in annotations if a.get('timeframe') == timeframe]

        # Convert to TradeAnnotation objects.
        # BUG FIX: drop keys the dataclass does not declare before constructing,
        # so records written by an older/newer schema no longer raise TypeError
        # and get silently dropped.
        known_fields = {f.name for f in fields(TradeAnnotation)}
        result = []
        for ann_dict in annotations:
            try:
                annotation = TradeAnnotation(
                    **{k: v for k, v in ann_dict.items() if k in known_fields}
                )
                result.append(annotation)
            except Exception as e:
                logger.error(f"Error converting annotation: {e}")

        return result

    def delete_annotation(self, annotation_id: str) -> bool:
        """
        Delete annotation and its associated test case file

        Args:
            annotation_id: ID of annotation to delete

        Returns:
            bool: True if annotation was deleted, False if not found

        Raises:
            Exception: If there's an error during deletion
        """
        original_count = len(self.annotations_db["annotations"])
        self.annotations_db["annotations"] = [
            a for a in self.annotations_db["annotations"]
            if a.get('annotation_id') != annotation_id
        ]

        if len(self.annotations_db["annotations"]) < original_count:
            # Annotation was found and removed
            self._save_annotations()

            # Also delete the associated test case file
            test_case_file = self.test_cases_dir / f"annotation_{annotation_id}.json"
            if test_case_file.exists():
                try:
                    test_case_file.unlink()
                    logger.info(f"Deleted test case file: {test_case_file}")
                except Exception as e:
                    logger.error(f"Error deleting test case file {test_case_file}: {e}")
                    # Don't fail the whole operation if test case deletion fails

            logger.info(f"Deleted annotation: {annotation_id}")
            return True
        else:
            logger.warning(f"Annotation not found: {annotation_id}")
            return False

    def clear_all_annotations(self, symbol: Optional[str] = None):
        """
        Clear all annotations (optionally filtered by symbol)
        More efficient than deleting one by one

        Args:
            symbol: Optional symbol filter. If None, clears all annotations.

        Returns:
            int: Number of annotations deleted
        """
        # Get annotations to delete
        if symbol:
            annotations_to_delete = [
                a for a in self.annotations_db["annotations"]
                if a.get('symbol') == symbol
            ]
            # Keep annotations for other symbols
            self.annotations_db["annotations"] = [
                a for a in self.annotations_db["annotations"]
                if a.get('symbol') != symbol
            ]
        else:
            annotations_to_delete = self.annotations_db["annotations"].copy()
            self.annotations_db["annotations"] = []

        deleted_count = len(annotations_to_delete)

        if deleted_count > 0:
            # Save updated annotations database
            self._save_annotations()

            # Delete associated test case files
            for annotation in annotations_to_delete:
                annotation_id = annotation.get('annotation_id')
                test_case_file = self.test_cases_dir / f"annotation_{annotation_id}.json"
                if test_case_file.exists():
                    try:
                        test_case_file.unlink()
                        logger.debug(f"Deleted test case file: {test_case_file}")
                    except Exception as e:
                        logger.error(f"Error deleting test case file {test_case_file}: {e}")

            logger.info(f"Cleared {deleted_count} annotations" + (f" for symbol {symbol}" if symbol else ""))

        return deleted_count

    def generate_test_case(self, annotation: TradeAnnotation, data_provider=None, auto_save: bool = True) -> Dict:
        """
        Generate lightweight test case metadata (no OHLCV data stored)
        OHLCV data will be fetched dynamically from cache/database during training

        Args:
            annotation: TradeAnnotation object
            data_provider: Optional DataProvider instance (not used for storage)
            auto_save: When True, also write the metadata to
                test_cases_dir/annotation_<id>.json

        Returns:
            Test case metadata dictionary
        """
        test_case = {
            "test_case_id": f"annotation_{annotation.annotation_id}",
            "symbol": annotation.symbol,
            "timestamp": annotation.entry['timestamp'],
            # LONG maps to BUY, anything else to SELL
            "action": "BUY" if annotation.direction == "LONG" else "SELL",
            "expected_outcome": {
                "direction": annotation.direction,
                "profit_loss_pct": annotation.profit_loss_pct,
                "holding_period_seconds": self._calculate_holding_period(annotation),
                "exit_price": annotation.exit['price'],
                "entry_price": annotation.entry['price']
            },
            "annotation_metadata": {
                "annotator": "manual",
                "confidence": 1.0,
                "notes": annotation.notes,
                "created_at": annotation.created_at,
                "timeframe": annotation.timeframe
            },
            "training_config": {
                "context_window_minutes": 5,  # ±5 minutes around entry/exit
                "timeframes": ["1s", "1m", "1h", "1d"],
                "data_source": "cache"  # Will fetch from cache/database
            }
        }

        # Save lightweight test case metadata to file if auto_save is True
        if auto_save:
            test_case_file = self.test_cases_dir / f"{test_case['test_case_id']}.json"
            with open(test_case_file, 'w') as f:
                json.dump(test_case, f, indent=2)
            logger.info(f"Saved test case metadata to: {test_case_file}")

        logger.info(f"Generated lightweight test case: {test_case['test_case_id']} (OHLCV data will be fetched dynamically)")
        return test_case

    def get_all_test_cases(self, symbol: Optional[str] = None) -> List[Dict]:
        """
        Load all test cases from disk

        Args:
            symbol: Optional symbol filter (e.g., 'ETH/USDT'). If provided, only returns
                test cases for that symbol. Critical for avoiding cross-symbol training.

        Returns:
            List of test case dictionaries
        """
        test_cases = []

        if not self.test_cases_dir.exists():
            return test_cases

        for test_case_file in self.test_cases_dir.glob("annotation_*.json"):
            try:
                with open(test_case_file, 'r') as f:
                    test_case = json.load(f)

                # CRITICAL: Filter by symbol to avoid training on wrong symbol
                if symbol:
                    test_case_symbol = test_case.get('symbol', '')
                    if test_case_symbol != symbol:
                        logger.debug(f"Skipping {test_case_file.name}: symbol {test_case_symbol} != {symbol}")
                        continue

                test_cases.append(test_case)
            except Exception as e:
                logger.error(f"Error loading test case {test_case_file}: {e}")

        if symbol:
            logger.info(f"Loaded {len(test_cases)} test cases for symbol {symbol}")
        else:
            logger.info(f"Loaded {len(test_cases)} test cases (all symbols)")
        return test_cases

    def _calculate_holding_period(self, annotation: TradeAnnotation) -> float:
        """Calculate holding period in seconds (0.0 on unparsable timestamps)."""
        try:
            # 'Z' suffix is normalized to '+00:00' for datetime.fromisoformat
            entry_time = datetime.fromisoformat(annotation.entry['timestamp'].replace('Z', '+00:00'))
            exit_time = datetime.fromisoformat(annotation.exit['timestamp'].replace('Z', '+00:00'))
            return (exit_time - entry_time).total_seconds()
        except Exception as e:
            logger.error(f"Error calculating holding period: {e}")
            return 0.0

    def _generate_training_labels(self, market_state: Dict, entry_time: datetime,
                                  exit_time: datetime, direction: str) -> Dict:
        """
        Generate training labels for each timestamp in the market data.
        This helps the model learn WHERE to signal and WHERE NOT to signal.

        Labels:
        - 0 = NO SIGNAL (before entry or after exit)
        - 1 = ENTRY SIGNAL (at entry time)
        - 2 = HOLD (between entry and exit)
        - 3 = EXIT SIGNAL (at exit time)

        Args:
            market_state: Expects market_state['ohlcv_1m']['timestamps'] as
                '%Y-%m-%d %H:%M:%S' strings (1m timeframe is the reference).
            entry_time: Timezone-aware entry timestamp.
            exit_time: Timezone-aware exit timestamp.
            direction: 'LONG' or 'SHORT' (passed through into the result).
        """
        labels = {}

        # Use 1m timeframe as reference for labeling
        if 'ohlcv_1m' in market_state and 'timestamps' in market_state['ohlcv_1m']:
            timestamps = market_state['ohlcv_1m']['timestamps']

            label_list = []
            for ts_str in timestamps:
                try:
                    ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
                    # Make timezone-aware to match entry_time
                    if ts.tzinfo is None:
                        ts = pytz.UTC.localize(ts)

                    # Determine label based on position relative to entry/exit;
                    # the entry/exit windows take precedence over HOLD
                    if abs((ts - entry_time).total_seconds()) < 60:  # Within 1 minute of entry
                        label = 1  # ENTRY SIGNAL
                    elif abs((ts - exit_time).total_seconds()) < 60:  # Within 1 minute of exit
                        label = 3  # EXIT SIGNAL
                    elif entry_time < ts < exit_time:  # Between entry and exit
                        label = 2  # HOLD
                    else:  # Before entry or after exit
                        label = 0  # NO SIGNAL

                    label_list.append(label)

                except Exception as e:
                    logger.error(f"Error parsing timestamp {ts_str}: {e}")
                    label_list.append(0)

            labels['labels_1m'] = label_list
            labels['direction'] = direction
            labels['entry_timestamp'] = entry_time.strftime('%Y-%m-%d %H:%M:%S')
            labels['exit_timestamp'] = exit_time.strftime('%Y-%m-%d %H:%M:%S')

            logger.info(f"Generated {len(label_list)} training labels: "
                        f"{label_list.count(0)} NO_SIGNAL, "
                        f"{label_list.count(1)} ENTRY, "
                        f"{label_list.count(2)} HOLD, "
                        f"{label_list.count(3)} EXIT")

        return labels

    def export_annotations(self, annotations: Optional[List[TradeAnnotation]] = None,
                          format_type: str = 'json') -> Path:
        """Export annotations to a timestamped file in the storage directory.

        Args:
            annotations: Annotations to export; defaults to all stored ones.
            format_type: 'json' or 'csv'.

        Returns:
            Path of the written export file.

        Raises:
            ValueError: On an unsupported format_type (previously the method
                silently returned a path to a file it never wrote).
        """
        if annotations is None:
            annotations = self.get_annotations()

        # Convert to dicts
        export_data = [asdict(ann) for ann in annotations]

        # Create export file
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        export_file = self.storage_path / f"export_{timestamp}.{format_type}"

        if format_type == 'json':
            with open(export_file, 'w') as f:
                json.dump(export_data, f, indent=2)
        elif format_type == 'csv':
            import csv
            with open(export_file, 'w', newline='') as f:
                if export_data:
                    writer = csv.DictWriter(f, fieldnames=export_data[0].keys())
                    writer.writeheader()
                    writer.writerows(export_data)
        else:
            raise ValueError(f"Unsupported export format: {format_type}")

        logger.info(f"Exported {len(annotations)} annotations to {export_file}")
        return export_file