save/load data annotations
This commit is contained in:
@@ -172,7 +172,7 @@ class AnnotationManager:
|
||||
else:
|
||||
logger.warning(f"Annotation not found: {annotation_id}")
|
||||
|
||||
def generate_test_case(self, annotation: TradeAnnotation, data_provider=None) -> Dict:
|
||||
def generate_test_case(self, annotation: TradeAnnotation, data_provider=None, auto_save: bool = True) -> Dict:
|
||||
"""
|
||||
Generate test case from annotation in realtime format
|
||||
|
||||
@@ -205,57 +205,99 @@ class AnnotationManager:
|
||||
}
|
||||
}
|
||||
|
||||
# Populate market state if data_provider is available
|
||||
if data_provider and annotation.market_context:
|
||||
test_case["market_state"] = annotation.market_context
|
||||
elif data_provider:
|
||||
# Fetch market state at entry time
|
||||
# Populate market state with ±5 minutes of data for negative examples
|
||||
if data_provider:
|
||||
try:
|
||||
entry_time = datetime.fromisoformat(annotation.entry['timestamp'].replace('Z', '+00:00'))
|
||||
exit_time = datetime.fromisoformat(annotation.exit['timestamp'].replace('Z', '+00:00'))
|
||||
|
||||
# Calculate time window: ±5 minutes around entry
|
||||
time_window_before = timedelta(minutes=5)
|
||||
time_window_after = timedelta(minutes=5)
|
||||
|
||||
start_time = entry_time - time_window_before
|
||||
end_time = entry_time + time_window_after
|
||||
|
||||
logger.info(f"Fetching market data from {start_time} to {end_time} (±5min around entry)")
|
||||
|
||||
# Fetch OHLCV data for all timeframes
|
||||
timeframes = ['1s', '1m', '1h', '1d']
|
||||
market_state = {}
|
||||
|
||||
for tf in timeframes:
|
||||
# Get data for the time window
|
||||
df = data_provider.get_historical_data(
|
||||
symbol=annotation.symbol,
|
||||
timeframe=tf,
|
||||
limit=100
|
||||
limit=1000 # Get enough data to cover ±5 minutes
|
||||
)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
# Filter to data before entry time
|
||||
df = df[df.index <= entry_time]
|
||||
# Filter to time window
|
||||
df_window = df[(df.index >= start_time) & (df.index <= end_time)]
|
||||
|
||||
if not df.empty:
|
||||
if not df_window.empty:
|
||||
# Convert to list format
|
||||
market_state[f'ohlcv_{tf}'] = {
|
||||
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
'open': df['open'].tolist(),
|
||||
'high': df['high'].tolist(),
|
||||
'low': df['low'].tolist(),
|
||||
'close': df['close'].tolist(),
|
||||
'volume': df['volume'].tolist()
|
||||
'timestamps': df_window.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
'open': df_window['open'].tolist(),
|
||||
'high': df_window['high'].tolist(),
|
||||
'low': df_window['low'].tolist(),
|
||||
'close': df_window['close'].tolist(),
|
||||
'volume': df_window['volume'].tolist()
|
||||
}
|
||||
|
||||
logger.info(f" {tf}: {len(df_window)} candles in ±5min window")
|
||||
|
||||
# Add training labels for each timestamp
|
||||
# This helps model learn WHERE to signal and WHERE NOT to signal
|
||||
market_state['training_labels'] = self._generate_training_labels(
|
||||
market_state,
|
||||
entry_time,
|
||||
exit_time,
|
||||
annotation.direction
|
||||
)
|
||||
|
||||
test_case["market_state"] = market_state
|
||||
logger.info(f"Populated market state with {len(market_state)} timeframes")
|
||||
logger.info(f"Populated market state with {len(market_state)-1} timeframes + training labels")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching market state: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
test_case["market_state"] = {}
|
||||
else:
|
||||
logger.warning("No data_provider available, market_state will be empty")
|
||||
test_case["market_state"] = {}
|
||||
|
||||
# Save test case to file
|
||||
test_case_file = self.test_cases_dir / f"{test_case['test_case_id']}.json"
|
||||
with open(test_case_file, 'w') as f:
|
||||
json.dump(test_case, f, indent=2)
|
||||
# Save test case to file if auto_save is True
|
||||
if auto_save:
|
||||
test_case_file = self.test_cases_dir / f"{test_case['test_case_id']}.json"
|
||||
with open(test_case_file, 'w') as f:
|
||||
json.dump(test_case, f, indent=2)
|
||||
logger.info(f"Saved test case to: {test_case_file}")
|
||||
|
||||
logger.info(f"Generated test case: {test_case['test_case_id']}")
|
||||
return test_case
|
||||
|
||||
def get_all_test_cases(self) -> List[Dict]:
    """
    Read the persisted test cases back from ``self.test_cases_dir``.

    Only files matching ``annotation_*.json`` are considered. A file that
    cannot be opened or decoded is logged and skipped, so one bad file does
    not abort the whole load.

    Returns:
        The decoded test cases in the order the directory glob yields them,
        or an empty list when the directory does not exist.
    """
    loaded = []
    cases_dir = self.test_cases_dir

    if cases_dir.exists():
        for case_path in cases_dir.glob("annotation_*.json"):
            try:
                with open(case_path, 'r') as handle:
                    loaded.append(json.load(handle))
            except Exception as e:
                # Best-effort load: report the broken file and keep going.
                logger.error(f"Error loading test case {case_path}: {e}")

        # The summary is only logged when the directory was actually scanned.
        logger.info(f"Loaded {len(loaded)} test cases from disk")

    return loaded
|
||||
|
||||
def _calculate_holding_period(self, annotation: TradeAnnotation) -> float:
|
||||
"""Calculate holding period in seconds"""
|
||||
try:
|
||||
@@ -266,6 +308,58 @@ class AnnotationManager:
|
||||
logger.error(f"Error calculating holding period: {e}")
|
||||
return 0.0
|
||||
|
||||
def _generate_training_labels(self, market_state: Dict, entry_time: datetime,
|
||||
exit_time: datetime, direction: str) -> Dict:
|
||||
"""
|
||||
Generate training labels for each timestamp in the market data.
|
||||
This helps the model learn WHERE to signal and WHERE NOT to signal.
|
||||
|
||||
Labels:
|
||||
- 0 = NO SIGNAL (before entry or after exit)
|
||||
- 1 = ENTRY SIGNAL (at entry time)
|
||||
- 2 = HOLD (between entry and exit)
|
||||
- 3 = EXIT SIGNAL (at exit time)
|
||||
"""
|
||||
labels = {}
|
||||
|
||||
# Use 1m timeframe as reference for labeling
|
||||
if 'ohlcv_1m' in market_state and 'timestamps' in market_state['ohlcv_1m']:
|
||||
timestamps = market_state['ohlcv_1m']['timestamps']
|
||||
|
||||
label_list = []
|
||||
for ts_str in timestamps:
|
||||
try:
|
||||
ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
|
||||
|
||||
# Determine label based on position relative to entry/exit
|
||||
if abs((ts - entry_time).total_seconds()) < 60: # Within 1 minute of entry
|
||||
label = 1 # ENTRY SIGNAL
|
||||
elif abs((ts - exit_time).total_seconds()) < 60: # Within 1 minute of exit
|
||||
label = 3 # EXIT SIGNAL
|
||||
elif entry_time < ts < exit_time: # Between entry and exit
|
||||
label = 2 # HOLD
|
||||
else: # Before entry or after exit
|
||||
label = 0 # NO SIGNAL
|
||||
|
||||
label_list.append(label)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing timestamp {ts_str}: {e}")
|
||||
label_list.append(0)
|
||||
|
||||
labels['labels_1m'] = label_list
|
||||
labels['direction'] = direction
|
||||
labels['entry_timestamp'] = entry_time.strftime('%Y-%m-%d %H:%M:%S')
|
||||
labels['exit_timestamp'] = exit_time.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
logger.info(f"Generated {len(label_list)} training labels: "
|
||||
f"{label_list.count(0)} NO_SIGNAL, "
|
||||
f"{label_list.count(1)} ENTRY, "
|
||||
f"{label_list.count(2)} HOLD, "
|
||||
f"{label_list.count(3)} EXIT")
|
||||
|
||||
return labels
|
||||
|
||||
def export_annotations(self, annotations: List[TradeAnnotation] = None,
|
||||
format_type: str = 'json') -> Path:
|
||||
"""Export annotations to file"""
|
||||
|
||||
Reference in New Issue
Block a user