Files
gogo2/ANNOTATE/core/annotation_manager.py
2025-10-18 23:26:54 +03:00

295 lines
12 KiB
Python

"""
Annotation Manager - Manages trade annotations and test case generation
Handles storage, retrieval, and test case generation from manual trade annotations.
"""
import json
import uuid
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, asdict
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class TradeAnnotation:
    """Represents a manually marked trade.

    ``created_at`` and ``market_context`` may be omitted or passed as
    ``None``; ``__post_init__`` then replaces them with the current local
    time in ISO-8601 form and a fresh empty dict respectively, so every
    instance ends up with its own ``market_context`` object.
    """

    annotation_id: str
    symbol: str
    timeframe: str
    entry: Dict[str, Any]  # {timestamp, price, index}
    exit: Dict[str, Any]  # {timestamp, price, index}
    direction: str  # 'LONG' or 'SHORT'
    profit_loss_pct: float
    notes: str = ""
    # None is a sentinel for "fill in during __post_init__".
    created_at: Optional[str] = None
    market_context: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Fill in defaults that cannot be expressed as plain field defaults."""
        if self.created_at is None:
            self.created_at = datetime.now().isoformat()
        if self.market_context is None:
            # A new dict per instance avoids sharing one mutable default.
            self.market_context = {}
class AnnotationManager:
    """Manages trade annotations and test case generation.

    Annotations live in memory in ``self.annotations_db`` (a dict holding an
    ``"annotations"`` list of plain dicts and a ``"metadata"`` dict) and are
    persisted as JSON to ``annotations_db.json`` inside ``storage_path``.
    Generated test cases are written to a ``test_cases`` directory next to
    ``storage_path``.
    """

    def __init__(self, storage_path: str = "ANNOTATE/data/annotations"):
        """Initialize annotation manager.

        Creates the storage and test-case directories when they are missing
        and loads any previously saved annotations.

        Args:
            storage_path: Directory that holds ``annotations_db.json``.
        """
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.annotations_file = self.storage_path / "annotations_db.json"
        self.test_cases_dir = self.storage_path.parent / "test_cases"
        self.test_cases_dir.mkdir(parents=True, exist_ok=True)
        self.annotations_db = self._load_annotations()
        logger.info(f"AnnotationManager initialized with storage: {self.storage_path}")

    def _load_annotations(self) -> Dict[str, Any]:
        """Load annotations from storage.

        Returns:
            The stored database dict. It always contains an ``"annotations"``
            list and a ``"metadata"`` dict; an empty database is returned
            when the file is missing, unreadable or has an unexpected shape.
        """
        if not self.annotations_file.exists():
            return {"annotations": [], "metadata": {}}
        try:
            with open(self.annotations_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if not isinstance(data, dict):
                raise ValueError("annotations file does not contain a JSON object")
            # Older/hand-edited files may lack a key; the rest of the class
            # indexes both unconditionally.
            data.setdefault("annotations", [])
            data.setdefault("metadata", {})
            if not isinstance(data["annotations"], list):
                raise ValueError("'annotations' is not a list")
            logger.info(f"Loaded {len(data.get('annotations', []))} annotations")
            return data
        except Exception as e:
            logger.error(f"Error loading annotations: {e}")
            return {"annotations": [], "metadata": {}}

    def _save_annotations(self):
        """Save annotations to storage.

        The database is first written to a temporary file in the same
        directory and then moved over the real file, so a failure while
        serialising never leaves a truncated ``annotations_db.json`` behind.

        Raises:
            Exception: Whatever error occurred while writing; it is logged
                and re-raised.
        """
        tmp_file = self.annotations_file.with_name(self.annotations_file.name + ".tmp")
        try:
            # Update metadata
            self.annotations_db["metadata"] = {
                "total_annotations": len(self.annotations_db["annotations"]),
                "last_updated": datetime.now().isoformat()
            }
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(self.annotations_db, f, indent=2)
            tmp_file.replace(self.annotations_file)
            logger.info(f"Saved {len(self.annotations_db['annotations'])} annotations")
        except Exception as e:
            logger.error(f"Error saving annotations: {e}")
            try:
                tmp_file.unlink()
            except OSError:
                # No temp file to remove (or removal failed); the original
                # error is the one worth reporting.
                pass
            raise

    def create_annotation(self, entry_point: Dict, exit_point: Dict,
                          symbol: str, timeframe: str,
                          entry_market_state: Optional[Dict] = None,
                          exit_market_state: Optional[Dict] = None) -> "TradeAnnotation":
        """Create new trade annotation.

        The direction is inferred from the prices: an exit above the entry is
        a ``'LONG'``, anything else (including equal prices) is a ``'SHORT'``.
        The P&L percentage is measured in the inferred direction, so it is
        never negative. The annotation is returned but not persisted; call
        ``save_annotation`` for that.

        Args:
            entry_point: Entry marker; must contain ``'price'``.
            exit_point: Exit marker; must contain ``'price'``.
            symbol: Traded symbol.
            timeframe: Timeframe the annotation was made on.
            entry_market_state: Optional market state captured at entry.
            exit_market_state: Optional market state captured at exit.

        Returns:
            The newly built TradeAnnotation.

        Raises:
            KeyError: If either point lacks ``'price'``.
            ZeroDivisionError: If the entry price is zero.
        """
        # Calculate direction and P&L
        entry_price = entry_point['price']
        exit_price = exit_point['price']
        if exit_price > entry_price:
            direction = 'LONG'
            profit_loss_pct = ((exit_price - entry_price) / entry_price) * 100
        else:
            direction = 'SHORT'
            profit_loss_pct = ((entry_price - exit_price) / entry_price) * 100

        # Store complete market context for training
        market_context = {
            'entry_state': entry_market_state or {},
            'exit_state': exit_market_state or {}
        }
        annotation = TradeAnnotation(
            annotation_id=str(uuid.uuid4()),
            symbol=symbol,
            timeframe=timeframe,
            entry=entry_point,
            exit=exit_point,
            direction=direction,
            profit_loss_pct=profit_loss_pct,
            market_context=market_context
        )
        logger.info(f"Created annotation: {annotation.annotation_id} ({direction}, {profit_loss_pct:.2f}%)")
        logger.info(f" Entry state: {len(entry_market_state or {})} timeframes")
        logger.info(f" Exit state: {len(exit_market_state or {})} timeframes")
        return annotation

    def save_annotation(self, annotation: "TradeAnnotation"):
        """Save annotation to storage.

        If persisting fails, the annotation is removed from the in-memory
        database again so memory and disk stay in agreement, and the error
        is re-raised.

        Args:
            annotation: The annotation to append and persist.
        """
        # Convert to dict
        ann_dict = asdict(annotation)
        # Add to database
        annotations = self.annotations_db["annotations"]
        annotations.append(ann_dict)
        # Save to file
        try:
            self._save_annotations()
        except Exception:
            # Roll back: otherwise an unserialisable record would stay in
            # memory and make every later save fail as well.
            if annotations and annotations[-1] is ann_dict:
                annotations.pop()
            raise
        logger.info(f"Saved annotation: {annotation.annotation_id}")

    def get_annotations(self, symbol: Optional[str] = None,
                        timeframe: Optional[str] = None) -> List["TradeAnnotation"]:
        """Retrieve annotations with optional filtering.

        Args:
            symbol: When given, only annotations with this symbol are returned.
            timeframe: When given, only annotations with this timeframe are
                returned.

        Returns:
            Matching annotations as TradeAnnotation objects. Stored records
            that cannot be converted are logged and skipped.
        """
        annotations = self.annotations_db.get("annotations", [])
        # Filter by symbol
        if symbol:
            annotations = [a for a in annotations if a.get('symbol') == symbol]
        # Filter by timeframe
        if timeframe:
            annotations = [a for a in annotations if a.get('timeframe') == timeframe]
        # Convert to TradeAnnotation objects
        result = []
        for ann_dict in annotations:
            try:
                result.append(TradeAnnotation(**ann_dict))
            except Exception as e:
                logger.error(f"Error converting annotation: {e}")
        return result

    def delete_annotation(self, annotation_id: str):
        """Delete annotation.

        Removes every stored record whose ``annotation_id`` matches and
        persists the result. A warning is logged when nothing matched. If
        persisting fails, the in-memory database is restored and the error
        is re-raised.

        Args:
            annotation_id: Identifier of the annotation to remove.
        """
        existing = self.annotations_db["annotations"]
        remaining = [a for a in existing if a.get('annotation_id') != annotation_id]
        if len(remaining) == len(existing):
            logger.warning(f"Annotation not found: {annotation_id}")
            return
        self.annotations_db["annotations"] = remaining
        try:
            self._save_annotations()
        except Exception:
            # Keep memory consistent with what is still on disk.
            self.annotations_db["annotations"] = existing
            raise
        logger.info(f"Deleted annotation: {annotation_id}")

    def generate_test_case(self, annotation: "TradeAnnotation", data_provider=None) -> Dict:
        """
        Generate test case from annotation in realtime format

        The market state is taken from the annotation's stored
        ``market_context`` when that holds any non-empty state; stored
        context does not require a data provider. Otherwise, if a
        ``data_provider`` is given, OHLCV data up to the entry time is
        fetched from it. If neither source yields data the market state is
        an empty dict. The test case is also written as JSON to the
        test-case directory.

        Args:
            annotation: TradeAnnotation object
            data_provider: Optional DataProvider instance to fetch market context

        Returns:
            Test case dictionary in realtime format
        """
        test_case = {
            "test_case_id": f"annotation_{annotation.annotation_id}",
            "symbol": annotation.symbol,
            "timestamp": annotation.entry['timestamp'],
            "action": "BUY" if annotation.direction == "LONG" else "SELL",
            "market_state": {},
            "expected_outcome": {
                "direction": annotation.direction,
                "profit_loss_pct": annotation.profit_loss_pct,
                "holding_period_seconds": self._calculate_holding_period(annotation),
                "exit_price": annotation.exit['price'],
                "entry_price": annotation.entry['price']
            },
            "annotation_metadata": {
                "annotator": "manual",
                "confidence": 1.0,
                "notes": annotation.notes,
                "created_at": annotation.created_at,
                "timeframe": annotation.timeframe
            }
        }

        # create_annotation always stores the 'entry_state'/'exit_state'
        # keys, so the context dict itself is truthy even when both states
        # are empty; look at the values to decide whether it is usable.
        stored_context = annotation.market_context or {}
        if any(stored_context.values()):
            test_case["market_state"] = stored_context
        elif data_provider:
            # Fetch market state at entry time
            try:
                market_state = self._fetch_market_state(annotation, data_provider)
                test_case["market_state"] = market_state
                logger.info(f"Populated market state with {len(market_state)} timeframes")
            except Exception as e:
                logger.error(f"Error fetching market state: {e}")
                test_case["market_state"] = {}
        else:
            test_case["market_state"] = {}

        # Save test case to file
        test_case_file = self.test_cases_dir / f"{test_case['test_case_id']}.json"
        with open(test_case_file, 'w', encoding='utf-8') as f:
            json.dump(test_case, f, indent=2)
        logger.info(f"Generated test case: {test_case['test_case_id']}")
        return test_case

    def _fetch_market_state(self, annotation: "TradeAnnotation", data_provider) -> Dict[str, Any]:
        """Fetch OHLCV data up to the annotation's entry time from data_provider.

        Returns a dict keyed ``ohlcv_<timeframe>`` for each timeframe that has
        at least one row at or before the entry time. Errors propagate to the
        caller.
        """
        entry_time = datetime.fromisoformat(annotation.entry['timestamp'].replace('Z', '+00:00'))
        market_state = {}
        # Fetch OHLCV data for all timeframes
        for tf in ['1s', '1m', '1h', '1d']:
            df = data_provider.get_historical_data(
                symbol=annotation.symbol,
                timeframe=tf,
                limit=100
            )
            if df is None or df.empty:
                continue
            # Filter to data before entry time
            # NOTE(review): presumably df.index is a DatetimeIndex whose
            # timezone-awareness matches entry_time; a mismatch would raise
            # here - confirm against the data provider.
            df = df[df.index <= entry_time]
            if df.empty:
                continue
            # Convert to list format
            market_state[f'ohlcv_{tf}'] = {
                'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
                'open': df['open'].tolist(),
                'high': df['high'].tolist(),
                'low': df['low'].tolist(),
                'close': df['close'].tolist(),
                'volume': df['volume'].tolist()
            }
        return market_state

    def _calculate_holding_period(self, annotation: "TradeAnnotation") -> float:
        """Calculate holding period in seconds.

        Returns the difference between the exit and entry timestamps, or
        ``0.0`` (after logging the error) when either timestamp is missing,
        cannot be parsed, or the two cannot be subtracted.
        """
        try:
            entry_time = datetime.fromisoformat(annotation.entry['timestamp'].replace('Z', '+00:00'))
            exit_time = datetime.fromisoformat(annotation.exit['timestamp'].replace('Z', '+00:00'))
            return (exit_time - entry_time).total_seconds()
        except Exception as e:
            logger.error(f"Error calculating holding period: {e}")
            return 0.0

    def export_annotations(self, annotations: Optional[List["TradeAnnotation"]] = None,
                           format_type: str = 'json') -> Path:
        """Export annotations to file.

        Args:
            annotations: Annotations to export; all stored annotations when
                ``None``.
            format_type: ``'json'`` or ``'csv'``. In CSV output the header is
                taken from the first record and nested values are written
                using their string form.

        Returns:
            Path of the written export file inside ``storage_path``.

        Raises:
            ValueError: If ``format_type`` is not supported. Nothing is
                written in that case.
        """
        if format_type not in ('json', 'csv'):
            # Fail loudly instead of reporting success for a file that was
            # never written.
            raise ValueError(f"Unsupported export format: {format_type!r}")
        if annotations is None:
            annotations = self.get_annotations()
        # Convert to dicts
        export_data = [asdict(ann) for ann in annotations]
        # Create export file
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        export_file = self.storage_path / f"export_{timestamp}.{format_type}"
        if format_type == 'json':
            with open(export_file, 'w', encoding='utf-8') as f:
                json.dump(export_data, f, indent=2)
        else:
            import csv
            with open(export_file, 'w', newline='', encoding='utf-8') as f:
                if export_data:
                    writer = csv.DictWriter(f, fieldnames=export_data[0].keys())
                    writer.writeheader()
                    writer.writerows(export_data)
        logger.info(f"Exported {len(annotations)} annotations to {export_file}")
        return export_file