240 lines
9.0 KiB
Python
240 lines
9.0 KiB
Python
"""
|
|
Annotation Manager - Manages trade annotations and test case generation
|
|
|
|
Handles storage, retrieval, and test case generation from manual trade annotations.
|
|
"""
|
|
|
|
import csv
import json
import logging
import uuid
from dataclasses import dataclass, asdict, fields
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional, Any
|
|
|
|
# Logger for this module, named after the module itself (``__name__``).
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@dataclass
class TradeAnnotation:
    """A single manually marked trade.

    ``created_at`` and ``market_context`` default to ``None`` and are filled
    in by ``__post_init__``: ``created_at`` becomes the current local time as
    an ISO 8601 string, and ``market_context`` becomes a fresh empty dict.
    Because the check is ``is None``, an explicit ``None`` passed by a caller
    (for example, from a stored record) is replaced the same way.
    """

    annotation_id: str
    symbol: str
    timeframe: str
    entry: Dict[str, Any]  # {timestamp, price, index}
    exit: Dict[str, Any]  # {timestamp, price, index}
    direction: str  # 'LONG' or 'SHORT'
    profit_loss_pct: float
    notes: str = ""
    # Annotated Optional because the dataclass default really is None;
    # __post_init__ swaps in the concrete value.
    created_at: Optional[str] = None
    market_context: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Replace ``None`` placeholders with per-instance defaults."""
        if self.created_at is None:
            self.created_at = datetime.now().isoformat()
        if self.market_context is None:
            # A new dict for every instance, so instances never share state.
            self.market_context = {}
|
|
|
|
|
|
class AnnotationManager:
    """Manage trade annotations and the test cases generated from them.

    Annotations are persisted in ``<storage_path>/annotations_db.json`` as a
    JSON object of the form ``{"annotations": [...], "metadata": {...}}``,
    where each list entry is ``dataclasses.asdict`` of a ``TradeAnnotation``.
    Generated test cases are written one JSON file per case into a
    ``test_cases`` directory that sits next to (not inside) ``storage_path``.
    """

    def __init__(self, storage_path: str = "TESTCASES/data/annotations"):
        """Create the storage directories if needed and load the database.

        Args:
            storage_path: Directory that holds ``annotations_db.json`` and
                export files. Its parent directory receives ``test_cases``.
        """
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)

        self.annotations_file = self.storage_path / "annotations_db.json"
        # Test cases live beside, not inside, the annotations directory.
        self.test_cases_dir = self.storage_path.parent / "test_cases"
        self.test_cases_dir.mkdir(parents=True, exist_ok=True)

        self.annotations_db = self._load_annotations()

        logger.info(f"AnnotationManager initialized with storage: {self.storage_path}")

    # ------------------------------------------------------------------
    # Persistence helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _empty_db() -> Dict[str, Any]:
        """Return a new, empty database structure."""
        return {"annotations": [], "metadata": {}}

    @staticmethod
    def _write_json_atomic(path: Path, data: Any) -> None:
        """Write ``data`` to ``path`` as indented JSON via a temp file + rename.

        ``json.dump`` streams its output, so a serialization error or crash
        part-way through would leave a truncated file. Writing to a sibling
        ``*.tmp`` file first and then renaming it over ``path`` leaves the
        previous contents of ``path`` untouched on any failure.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            tmp_path.replace(path)
        except BaseException:
            try:
                tmp_path.unlink()
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised below
            raise

    def _quarantine_db_file(self) -> None:
        """Best-effort: rename an unusable database file so a later save cannot overwrite it."""
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        backup = self.annotations_file.with_name(f"{self.annotations_file.name}.corrupt-{stamp}")
        try:
            self.annotations_file.replace(backup)
            logger.warning(f"Moved unreadable annotations file to {backup}")
        except OSError as e:
            logger.error(f"Could not back up unreadable annotations file: {e}")

    def _load_annotations(self) -> Dict[str, Any]:
        """Load the annotations database from ``annotations_file``.

        Returns:
            The stored database, guaranteed to have a list under
            ``"annotations"`` and a dict under ``"metadata"``. An empty
            database is returned when the file is missing or unusable.

        A file that cannot be parsed, or whose top level is not the expected
        structure, is moved aside before an empty database is returned.
        Otherwise the next save would silently replace it and the stored
        annotations would be lost.
        """
        if not self.annotations_file.exists():
            return self._empty_db()

        try:
            with open(self.annotations_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except OSError as e:
            logger.error(f"Error loading annotations: {e}")
            return self._empty_db()
        except ValueError as e:  # JSONDecodeError / UnicodeDecodeError
            logger.error(f"Error loading annotations: {e}")
            self._quarantine_db_file()
            return self._empty_db()

        if not isinstance(data, dict) or not isinstance(data.get("annotations", []), list):
            logger.error("Error loading annotations: unexpected database structure")
            self._quarantine_db_file()
            return self._empty_db()

        # Normalize so the rest of the class can index these keys directly.
        data.setdefault("annotations", [])
        if not isinstance(data.get("metadata"), dict):
            data["metadata"] = {}

        logger.info(f"Loaded {len(data['annotations'])} annotations")
        return data

    def _save_annotations(self):
        """Refresh the metadata block and persist the database to disk.

        Raises:
            Exception: any serialization or I/O error is logged and re-raised.
        """
        try:
            self.annotations_db["metadata"] = {
                "total_annotations": len(self.annotations_db["annotations"]),
                "last_updated": datetime.now().isoformat()
            }

            self._write_json_atomic(self.annotations_file, self.annotations_db)

            logger.info(f"Saved {len(self.annotations_db['annotations'])} annotations")
        except Exception as e:
            logger.error(f"Error saving annotations: {e}")
            raise

    # ------------------------------------------------------------------
    # Annotation CRUD
    # ------------------------------------------------------------------

    def create_annotation(self, entry_point: Dict, exit_point: Dict,
                          symbol: str, timeframe: str) -> TradeAnnotation:
        """Build (but do not persist) a ``TradeAnnotation`` from two marked points.

        The direction is inferred from the prices. An exit above the entry is
        ``'LONG'``; anything else, including an unchanged price, is
        ``'SHORT'``. For positive prices, ``profit_loss_pct`` is therefore
        never negative.

        Args:
            entry_point: Entry point dict; must contain ``'price'``.
            exit_point: Exit point dict; must contain ``'price'``.
            symbol: Instrument symbol.
            timeframe: Timeframe the points were marked on.

        Returns:
            A new annotation with a random UUID4 ``annotation_id``.

        Raises:
            KeyError: if either point lacks ``'price'``.
            ZeroDivisionError: if the entry price is 0.
        """
        entry_price = entry_point['price']
        exit_price = exit_point['price']

        if exit_price > entry_price:
            direction = 'LONG'
            profit_loss_pct = ((exit_price - entry_price) / entry_price) * 100
        else:
            direction = 'SHORT'
            profit_loss_pct = ((entry_price - exit_price) / entry_price) * 100

        annotation = TradeAnnotation(
            annotation_id=str(uuid.uuid4()),
            symbol=symbol,
            timeframe=timeframe,
            entry=entry_point,
            exit=exit_point,
            direction=direction,
            profit_loss_pct=profit_loss_pct
        )

        logger.info(f"Created annotation: {annotation.annotation_id} ({direction}, {profit_loss_pct:.2f}%)")
        return annotation

    def save_annotation(self, annotation: TradeAnnotation):
        """Insert or update ``annotation`` in the database and persist it.

        If an annotation with the same ``annotation_id`` is already stored, it
        is replaced in place rather than appended a second time. If writing to
        disk fails, the in-memory list is restored and the error is re-raised.
        """
        ann_dict = asdict(annotation)
        previous = self.annotations_db["annotations"]

        updated = list(previous)
        for i, existing in enumerate(updated):
            if existing.get('annotation_id') == annotation.annotation_id:
                updated[i] = ann_dict
                break
        else:
            updated.append(ann_dict)

        self.annotations_db["annotations"] = updated
        try:
            self._save_annotations()
        except Exception:
            # Keep memory consistent with what is actually on disk.
            self.annotations_db["annotations"] = previous
            raise

        logger.info(f"Saved annotation: {annotation.annotation_id}")

    def get_annotations(self, symbol: Optional[str] = None,
                        timeframe: Optional[str] = None) -> List[TradeAnnotation]:
        """Return stored annotations, optionally filtered.

        Args:
            symbol: If truthy, keep only annotations with this symbol.
            timeframe: If truthy, keep only annotations with this timeframe.

        Returns:
            Matching annotations as ``TradeAnnotation`` objects. Keys that are
            not ``TradeAnnotation`` fields are ignored. Records that still
            cannot be converted (for example, a missing required field) are
            logged and skipped.
        """
        annotations = self.annotations_db.get("annotations", [])

        if symbol:
            annotations = [a for a in annotations if a.get('symbol') == symbol]
        if timeframe:
            annotations = [a for a in annotations if a.get('timeframe') == timeframe]

        known_fields = {fld.name for fld in fields(TradeAnnotation)}
        result = []
        for ann_dict in annotations:
            try:
                kwargs = {k: v for k, v in ann_dict.items() if k in known_fields}
                result.append(TradeAnnotation(**kwargs))
            except (AttributeError, TypeError) as e:
                logger.error(f"Error converting annotation: {e}")

        return result

    def delete_annotation(self, annotation_id: str) -> bool:
        """Delete every stored annotation whose ``annotation_id`` matches.

        Returns:
            True if something was deleted and the database was saved; False
            if no annotation had that id (nothing is written in that case).

        Raises:
            Exception: errors from saving are re-raised after the in-memory
                list is restored.
        """
        previous = self.annotations_db["annotations"]
        remaining = [a for a in previous if a.get('annotation_id') != annotation_id]

        if len(remaining) == len(previous):
            logger.warning(f"Annotation not found: {annotation_id}")
            return False

        self.annotations_db["annotations"] = remaining
        try:
            self._save_annotations()
        except Exception:
            self.annotations_db["annotations"] = previous
            raise

        logger.info(f"Deleted annotation: {annotation_id}")
        return True

    # ------------------------------------------------------------------
    # Test-case generation and export
    # ------------------------------------------------------------------

    def generate_test_case(self, annotation: TradeAnnotation) -> Dict:
        """Build a test-case dict from ``annotation`` and write it to disk.

        The file is ``<test_cases_dir>/annotation_<annotation_id>.json``.
        ``market_state`` currently holds empty placeholders only.

        Returns:
            The test-case dict that was written.

        Raises:
            KeyError: if the entry lacks ``'timestamp'`` or the exit lacks
                ``'price'``.
        """
        # This will be populated with actual market data in Task 2
        test_case = {
            "test_case_id": f"annotation_{annotation.annotation_id}",
            "symbol": annotation.symbol,
            "timestamp": annotation.entry['timestamp'],
            "action": "BUY" if annotation.direction == "LONG" else "SELL",
            "market_state": {
                # Will be populated with BaseDataInput structure
                "ohlcv_1s": [],
                "ohlcv_1m": [],
                "ohlcv_1h": [],
                "ohlcv_1d": [],
                "cob_data": {},
                "technical_indicators": {},
                "pivot_points": []
            },
            "expected_outcome": {
                "direction": annotation.direction,
                "profit_loss_pct": annotation.profit_loss_pct,
                "holding_period_seconds": self._calculate_holding_period(annotation),
                "exit_price": annotation.exit['price']
            },
            "annotation_metadata": {
                "annotator": "manual",
                "confidence": 1.0,
                "notes": annotation.notes,
                "created_at": annotation.created_at
            }
        }

        test_case_file = self.test_cases_dir / f"{test_case['test_case_id']}.json"
        self._write_json_atomic(test_case_file, test_case)

        logger.info(f"Generated test case: {test_case['test_case_id']}")
        return test_case

    def _calculate_holding_period(self, annotation: TradeAnnotation) -> float:
        """Return seconds between entry and exit timestamps, or 0.0 on failure.

        Timestamps are parsed with ``datetime.fromisoformat``. A trailing
        ``'Z'`` is rewritten to ``'+00:00'``, which older Python versions need.
        Missing, unparseable, or mixed naive/aware timestamps are logged and
        yield 0.0.
        """
        try:
            entry_time = datetime.fromisoformat(annotation.entry['timestamp'].replace('Z', '+00:00'))
            exit_time = datetime.fromisoformat(annotation.exit['timestamp'].replace('Z', '+00:00'))
            return (exit_time - entry_time).total_seconds()
        except (KeyError, AttributeError, TypeError, ValueError) as e:
            logger.error(f"Error calculating holding period: {e}")
            return 0.0

    def export_annotations(self, annotations: Optional[List[TradeAnnotation]] = None,
                           format_type: str = 'json') -> Path:
        """Export annotations to ``<storage_path>/export_<YYYYmmdd_HHMMSS>.<format_type>``.

        Args:
            annotations: Annotations to export; defaults to all stored ones.
            format_type: ``'json'`` writes a list of annotation dicts.
                ``'csv'`` writes one row per annotation, with a header row
                always present and nested dict/list values JSON-encoded.

        Returns:
            Path of the written file.

        Raises:
            ValueError: if ``format_type`` is not ``'json'`` or ``'csv'``.
        """
        if format_type not in ('json', 'csv'):
            raise ValueError(f"Unsupported export format: {format_type!r} (expected 'json' or 'csv')")

        if annotations is None:
            annotations = self.get_annotations()

        export_data = [asdict(ann) for ann in annotations]

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        export_file = self.storage_path / f"export_{timestamp}.{format_type}"

        if format_type == 'json':
            with open(export_file, 'w', encoding='utf-8') as f:
                json.dump(export_data, f, indent=2)
        else:
            fieldnames = [fld.name for fld in fields(TradeAnnotation)]
            with open(export_file, 'w', newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                # JSON-encode nested values so cells are parseable, not Python reprs.
                writer.writerows(
                    {key: json.dumps(value) if isinstance(value, (dict, list)) else value
                     for key, value in row.items()}
                    for row in export_data
                )

        logger.info(f"Exported {len(annotations)} annotations to {export_file}")
        return export_file
|