merge annotate /ANNOTATE/core into /core.
fix chart updates
core/annotation_manager.py (new file, 469 lines)
@@ -0,0 +1,469 @@
"""
Annotation Manager - Manages trade annotations and test case generation

Handles storage, retrieval, and test case generation from manual trade annotations.
Stores annotations in both JSON (legacy) and DuckDB (with full market data).
"""

import json
import uuid
import sys
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, asdict
import logging
import pytz

# Add parent directory to path for imports
parent_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(parent_dir))

logger = logging.getLogger(__name__)

# Import DuckDB storage
try:
    from core.duckdb_storage import DuckDBStorage
    DUCKDB_AVAILABLE = True
except ImportError:
    DUCKDB_AVAILABLE = False
    logger.warning("DuckDB storage not available for annotations")

@dataclass
class TradeAnnotation:
    """Represents a manually marked trade"""
    annotation_id: str
    symbol: str
    timeframe: str
    entry: Dict[str, Any]  # {timestamp, price, index}
    exit: Dict[str, Any]  # {timestamp, price, index}
    direction: str  # 'LONG' or 'SHORT'
    profit_loss_pct: float
    notes: str = ""
    created_at: Optional[str] = None
    market_context: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        if self.created_at is None:
            self.created_at = datetime.now(pytz.UTC).isoformat()
        if self.market_context is None:
            self.market_context = {}

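# Illustrative sketch (not part of the original commit): constructing a
# TradeAnnotation directly. All field values below are hypothetical;
# __post_init__ backfills created_at and market_context when omitted.
#
#   ann = TradeAnnotation(
#       annotation_id="demo-0001",
#       symbol="ETH/USDT",
#       timeframe="1m",
#       entry={"timestamp": "2024-01-01T12:00:00+00:00", "price": 2200.0, "index": 120},
#       exit={"timestamp": "2024-01-01T12:30:00+00:00", "price": 2244.0, "index": 150},
#       direction="LONG",
#       profit_loss_pct=2.0,
#   )
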
class AnnotationManager:
    """Manages trade annotations and test case generation"""

    def __init__(self, storage_path: str = "ANNOTATE/data/annotations"):
        """Initialize annotation manager"""
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)

        self.annotations_file = self.storage_path / "annotations_db.json"
        self.test_cases_dir = self.storage_path.parent / "test_cases"
        self.test_cases_dir.mkdir(parents=True, exist_ok=True)

        self.annotations_db = self._load_annotations()

        # Initialize DuckDB storage for complete annotation data
        self.duckdb_storage: Optional[DuckDBStorage] = None
        if DUCKDB_AVAILABLE:
            try:
                self.duckdb_storage = DuckDBStorage()
                logger.info("DuckDB storage initialized for annotations")
            except Exception as e:
                logger.warning(f"Could not initialize DuckDB storage: {e}")

        logger.info(f"AnnotationManager initialized with storage: {self.storage_path}")

    def _load_annotations(self) -> Dict[str, List[Dict]]:
        """Load annotations from storage"""
        if self.annotations_file.exists():
            try:
                with open(self.annotations_file, 'r') as f:
                    data = json.load(f)
                logger.info(f"Loaded {len(data.get('annotations', []))} annotations")
                return data
            except Exception as e:
                logger.error(f"Error loading annotations: {e}")
                return {"annotations": [], "metadata": {}}
        else:
            return {"annotations": [], "metadata": {}}

    def _save_annotations(self):
        """Save annotations to storage"""
        try:
            # Update metadata
            self.annotations_db["metadata"] = {
                "total_annotations": len(self.annotations_db["annotations"]),
                "last_updated": datetime.now(pytz.UTC).isoformat()
            }

            with open(self.annotations_file, 'w') as f:
                json.dump(self.annotations_db, f, indent=2)

            logger.info(f"Saved {len(self.annotations_db['annotations'])} annotations")
        except Exception as e:
            logger.error(f"Error saving annotations: {e}")
            raise

    def create_annotation(self, entry_point: Dict, exit_point: Dict,
                          symbol: str, timeframe: str,
                          entry_market_state: Optional[Dict] = None,
                          exit_market_state: Optional[Dict] = None) -> TradeAnnotation:
        """Create new trade annotation"""
        # Calculate direction and P&L. Direction is inferred from the price
        # move, so this scheme assumes annotated trades are profitable.
        entry_price = entry_point['price']
        exit_price = exit_point['price']

        if exit_price > entry_price:
            direction = 'LONG'
            profit_loss_pct = ((exit_price - entry_price) / entry_price) * 100
        else:
            direction = 'SHORT'
            profit_loss_pct = ((entry_price - exit_price) / entry_price) * 100

        # Store complete market context for training
        market_context = {
            'entry_state': entry_market_state or {},
            'exit_state': exit_market_state or {}
        }

        annotation = TradeAnnotation(
            annotation_id=str(uuid.uuid4()),
            symbol=symbol,
            timeframe=timeframe,
            entry=entry_point,
            exit=exit_point,
            direction=direction,
            profit_loss_pct=profit_loss_pct,
            market_context=market_context
        )

        logger.info(f"Created annotation: {annotation.annotation_id} ({direction}, {profit_loss_pct:.2f}%)")
        logger.info(f"  Entry state: {len(entry_market_state or {})} timeframes")
        logger.info(f"  Exit state: {len(exit_market_state or {})} timeframes")
        return annotation

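    # Worked example (illustrative, hypothetical prices): entry at 100.0 and
    # exit at 105.0 infers LONG with profit_loss_pct = (105 - 100) / 100 * 100
    # = 5.0; entry at 100.0 and exit at 97.0 infers SHORT with
    # profit_loss_pct = (100 - 97) / 100 * 100 = 3.0. Because direction is
    # derived from the price move alone, losing trades cannot be expressed
    # with this scheme.
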
    def save_annotation(self, annotation: TradeAnnotation,
                        market_snapshots: Optional[Dict] = None,
                        model_predictions: Optional[List[Dict]] = None):
        """
        Save annotation to storage (JSON + DuckDB)

        Args:
            annotation: TradeAnnotation object
            market_snapshots: Dict of {timeframe: DataFrame} with OHLCV data
            model_predictions: List of model predictions at annotation time
        """
        # Convert to dict
        ann_dict = asdict(annotation)

        # Add to JSON database (legacy)
        self.annotations_db["annotations"].append(ann_dict)

        # Save to JSON file
        self._save_annotations()

        # Save to DuckDB with complete market data
        if self.duckdb_storage and market_snapshots:
            try:
                self.duckdb_storage.store_annotation(
                    annotation_id=annotation.annotation_id,
                    annotation_data=ann_dict,
                    market_snapshots=market_snapshots,
                    model_predictions=model_predictions
                )
                logger.info(f"Saved annotation {annotation.annotation_id} to DuckDB with {len(market_snapshots)} timeframes")
            except Exception as e:
                logger.error(f"Could not save annotation to DuckDB: {e}")

        logger.info(f"Saved annotation: {annotation.annotation_id}")

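    # Usage sketch (illustrative, not from the original commit; values are
    # hypothetical):
    #
    #   manager = AnnotationManager()
    #   ann = manager.create_annotation(
    #       entry_point={"timestamp": "2024-01-01T12:00:00+00:00", "price": 2200.0, "index": 120},
    #       exit_point={"timestamp": "2024-01-01T12:30:00+00:00", "price": 2244.0, "index": 150},
    #       symbol="ETH/USDT",
    #       timeframe="1m",
    #   )
    #   # market_snapshots is a {timeframe: DataFrame} dict (e.g. from a data
    #   # provider); omitting it skips the DuckDB write and keeps JSON only.
    #   manager.save_annotation(ann, market_snapshots=None)
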
    def get_annotations(self, symbol: Optional[str] = None,
                        timeframe: Optional[str] = None) -> List[TradeAnnotation]:
        """Retrieve annotations with optional filtering"""
        annotations = self.annotations_db.get("annotations", [])

        # Filter by symbol
        if symbol:
            annotations = [a for a in annotations if a.get('symbol') == symbol]

        # Filter by timeframe
        if timeframe:
            annotations = [a for a in annotations if a.get('timeframe') == timeframe]

        # Convert to TradeAnnotation objects
        result = []
        for ann_dict in annotations:
            try:
                annotation = TradeAnnotation(**ann_dict)
                result.append(annotation)
            except Exception as e:
                logger.error(f"Error converting annotation: {e}")

        return result

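    # Filtering sketch (illustrative, continuing the example above): both
    # filters are optional and combine.
    #
    #   eth_1m = manager.get_annotations(symbol="ETH/USDT", timeframe="1m")
    #   everything = manager.get_annotations()
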
    def delete_annotation(self, annotation_id: str) -> bool:
        """
        Delete annotation and its associated test case file

        Args:
            annotation_id: ID of annotation to delete

        Returns:
            bool: True if annotation was deleted, False if not found

        Raises:
            Exception: If there's an error during deletion
        """
        original_count = len(self.annotations_db["annotations"])
        self.annotations_db["annotations"] = [
            a for a in self.annotations_db["annotations"]
            if a.get('annotation_id') != annotation_id
        ]

        if len(self.annotations_db["annotations"]) < original_count:
            # Annotation was found and removed
            self._save_annotations()

            # Also delete the associated test case file
            test_case_file = self.test_cases_dir / f"annotation_{annotation_id}.json"
            if test_case_file.exists():
                try:
                    test_case_file.unlink()
                    logger.info(f"Deleted test case file: {test_case_file}")
                except Exception as e:
                    logger.error(f"Error deleting test case file {test_case_file}: {e}")
                    # Don't fail the whole operation if test case deletion fails

            logger.info(f"Deleted annotation: {annotation_id}")
            return True
        else:
            logger.warning(f"Annotation not found: {annotation_id}")
            return False

    def clear_all_annotations(self, symbol: Optional[str] = None) -> int:
        """
        Clear all annotations (optionally filtered by symbol).
        More efficient than deleting one by one.

        Args:
            symbol: Optional symbol filter. If None, clears all annotations.

        Returns:
            int: Number of annotations deleted
        """
        # Get annotations to delete
        if symbol:
            annotations_to_delete = [
                a for a in self.annotations_db["annotations"]
                if a.get('symbol') == symbol
            ]
            # Keep annotations for other symbols
            self.annotations_db["annotations"] = [
                a for a in self.annotations_db["annotations"]
                if a.get('symbol') != symbol
            ]
        else:
            annotations_to_delete = self.annotations_db["annotations"].copy()
            self.annotations_db["annotations"] = []

        deleted_count = len(annotations_to_delete)

        if deleted_count > 0:
            # Save updated annotations database
            self._save_annotations()

            # Delete associated test case files
            for annotation in annotations_to_delete:
                annotation_id = annotation.get('annotation_id')
                test_case_file = self.test_cases_dir / f"annotation_{annotation_id}.json"
                if test_case_file.exists():
                    try:
                        test_case_file.unlink()
                        logger.debug(f"Deleted test case file: {test_case_file}")
                    except Exception as e:
                        logger.error(f"Error deleting test case file {test_case_file}: {e}")

        logger.info(f"Cleared {deleted_count} annotations" + (f" for symbol {symbol}" if symbol else ""))

        return deleted_count

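    # Usage sketch (illustrative): clear_all_annotations rewrites
    # annotations_db.json once and removes the matching test case files in a
    # single pass, rather than re-serializing the database per deletion as
    # repeated delete_annotation calls would.
    #
    #   removed = manager.clear_all_annotations(symbol="ETH/USDT")
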
    def generate_test_case(self, annotation: TradeAnnotation, data_provider=None, auto_save: bool = True) -> Dict:
        """
        Generate lightweight test case metadata (no OHLCV data stored).
        OHLCV data is fetched dynamically from cache/database during training.

        Args:
            annotation: TradeAnnotation object
            data_provider: Optional DataProvider instance (not used for storage)
            auto_save: If True, write the test case metadata file to disk

        Returns:
            Test case metadata dictionary
        """
        test_case = {
            "test_case_id": f"annotation_{annotation.annotation_id}",
            "symbol": annotation.symbol,
            "timestamp": annotation.entry['timestamp'],
            "action": "BUY" if annotation.direction == "LONG" else "SELL",
            "expected_outcome": {
                "direction": annotation.direction,
                "profit_loss_pct": annotation.profit_loss_pct,
                "holding_period_seconds": self._calculate_holding_period(annotation),
                "exit_price": annotation.exit['price'],
                "entry_price": annotation.entry['price']
            },
            "annotation_metadata": {
                "annotator": "manual",
                "confidence": 1.0,
                "notes": annotation.notes,
                "created_at": annotation.created_at,
                "timeframe": annotation.timeframe
            },
            "training_config": {
                "context_window_minutes": 5,  # ±5 minutes around entry/exit
                "timeframes": ["1s", "1m", "1h", "1d"],
                "data_source": "cache"  # Will fetch from cache/database
            }
        }

        # Save lightweight test case metadata to file if auto_save is True
        if auto_save:
            test_case_file = self.test_cases_dir / f"{test_case['test_case_id']}.json"
            with open(test_case_file, 'w') as f:
                json.dump(test_case, f, indent=2)
            logger.info(f"Saved test case metadata to: {test_case_file}")

        logger.info(f"Generated lightweight test case: {test_case['test_case_id']} (OHLCV data will be fetched dynamically)")
        return test_case

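    # Usage sketch (illustrative): generating a test case writes
    # ANNOTATE/data/test_cases/annotation_<id>.json unless auto_save=False.
    #
    #   tc = manager.generate_test_case(ann)
    #   tc["action"]  # "BUY" for LONG annotations, "SELL" for SHORT
    #   tc["expected_outcome"]["holding_period_seconds"]  # 1800.0 for the 30-minute example above
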
    def get_all_test_cases(self, symbol: Optional[str] = None) -> List[Dict]:
        """
        Load all test cases from disk

        Args:
            symbol: Optional symbol filter (e.g., 'ETH/USDT'). If provided, only returns
                    test cases for that symbol. Critical for avoiding cross-symbol training.

        Returns:
            List of test case dictionaries
        """
        test_cases = []

        if not self.test_cases_dir.exists():
            return test_cases

        for test_case_file in self.test_cases_dir.glob("annotation_*.json"):
            try:
                with open(test_case_file, 'r') as f:
                    test_case = json.load(f)

                # CRITICAL: Filter by symbol to avoid training on wrong symbol
                if symbol:
                    test_case_symbol = test_case.get('symbol', '')
                    if test_case_symbol != symbol:
                        logger.debug(f"Skipping {test_case_file.name}: symbol {test_case_symbol} != {symbol}")
                        continue

                test_cases.append(test_case)
            except Exception as e:
                logger.error(f"Error loading test case {test_case_file}: {e}")

        if symbol:
            logger.info(f"Loaded {len(test_cases)} test cases for symbol {symbol}")
        else:
            logger.info(f"Loaded {len(test_cases)} test cases (all symbols)")
        return test_cases

    def _calculate_holding_period(self, annotation: TradeAnnotation) -> float:
        """Calculate holding period in seconds"""
        try:
            entry_time = datetime.fromisoformat(annotation.entry['timestamp'].replace('Z', '+00:00'))
            exit_time = datetime.fromisoformat(annotation.exit['timestamp'].replace('Z', '+00:00'))
            return (exit_time - entry_time).total_seconds()
        except Exception as e:
            logger.error(f"Error calculating holding period: {e}")
            return 0.0

    def _generate_training_labels(self, market_state: Dict, entry_time: datetime,
                                  exit_time: datetime, direction: str) -> Dict:
        """
        Generate training labels for each timestamp in the market data.
        This helps the model learn WHERE to signal and WHERE NOT to signal.

        Labels:
        - 0 = NO SIGNAL (before entry or after exit)
        - 1 = ENTRY SIGNAL (at entry time)
        - 2 = HOLD (between entry and exit)
        - 3 = EXIT SIGNAL (at exit time)
        """
        labels = {}

        # Use 1m timeframe as reference for labeling
        if 'ohlcv_1m' in market_state and 'timestamps' in market_state['ohlcv_1m']:
            timestamps = market_state['ohlcv_1m']['timestamps']

            label_list = []
            for ts_str in timestamps:
                try:
                    ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
                    # Make timezone-aware to match entry_time
                    if ts.tzinfo is None:
                        ts = pytz.UTC.localize(ts)

                    # Determine label based on position relative to entry/exit
                    if abs((ts - entry_time).total_seconds()) < 60:  # Within 1 minute of entry
                        label = 1  # ENTRY SIGNAL
                    elif abs((ts - exit_time).total_seconds()) < 60:  # Within 1 minute of exit
                        label = 3  # EXIT SIGNAL
                    elif entry_time < ts < exit_time:  # Between entry and exit
                        label = 2  # HOLD
                    else:  # Before entry or after exit
                        label = 0  # NO SIGNAL

                    label_list.append(label)

                except Exception as e:
                    logger.error(f"Error parsing timestamp {ts_str}: {e}")
                    label_list.append(0)

            labels['labels_1m'] = label_list
            labels['direction'] = direction
            labels['entry_timestamp'] = entry_time.strftime('%Y-%m-%d %H:%M:%S')
            labels['exit_timestamp'] = exit_time.strftime('%Y-%m-%d %H:%M:%S')

            logger.info(f"Generated {len(label_list)} training labels: "
                        f"{label_list.count(0)} NO_SIGNAL, "
                        f"{label_list.count(1)} ENTRY, "
                        f"{label_list.count(2)} HOLD, "
                        f"{label_list.count(3)} EXIT")

        return labels

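    # Worked example (illustrative): for 1m candles at 12:00 through 12:05
    # with entry_time=12:01:00 and exit_time=12:04:00, the labels come out as
    #
    #   12:00 -> 0 (NO SIGNAL)   12:02 -> 2 (HOLD)   12:04 -> 3 (EXIT)
    #   12:01 -> 1 (ENTRY)       12:03 -> 2 (HOLD)   12:05 -> 0 (NO SIGNAL)
    #
    # The windows are strict (< 60s), so a candle exactly 60 seconds from
    # entry/exit falls through to the HOLD or NO SIGNAL branches; the entry
    # and exit checks are evaluated first, so they win over the HOLD range.
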
    def export_annotations(self, annotations: Optional[List[TradeAnnotation]] = None,
                           format_type: str = 'json') -> Path:
        """Export annotations to file"""
        if annotations is None:
            annotations = self.get_annotations()

        # Convert to dicts
        export_data = [asdict(ann) for ann in annotations]

        # Create export file
        timestamp = datetime.now(pytz.UTC).strftime('%Y%m%d_%H%M%S')
        export_file = self.storage_path / f"export_{timestamp}.{format_type}"

        if format_type == 'json':
            with open(export_file, 'w') as f:
                json.dump(export_data, f, indent=2)
        elif format_type == 'csv':
            import csv
            with open(export_file, 'w', newline='') as f:
                if export_data:
                    writer = csv.DictWriter(f, fieldnames=export_data[0].keys())
                    writer.writeheader()
                    writer.writerows(export_data)

        logger.info(f"Exported {len(annotations)} annotations to {export_file}")
        return export_file
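
# Export sketch (illustrative, not from the original commit): JSON is the
# default format. The CSV branch writes nested dicts (entry, exit,
# market_context) via their str() representation, so JSON is the safer
# round-trip format.
#
#   path = manager.export_annotations(format_type="json")
#   # e.g. ANNOTATE/data/annotations/export_20240101_120000.json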