Merge /ANNOTATE/core into /core
Fix chart updates
core/NO_SIMULATION_POLICY.md (new file, 72 lines)
@@ -0,0 +1,72 @@
# NO SIMULATION CODE POLICY

## CRITICAL RULE: NEVER CREATE SIMULATION CODE

**Date:** 2025-10-23
**Status:** PERMANENT POLICY

## What Was Removed

We deleted `ANNOTATE/core/training_simulator.py`, which contained simulation/mock training code.

## Why This Is Critical

1. **Real Training Only**: We have REAL training implementations in:
   - `NN/training/enhanced_realtime_training.py` - Real-time training system
   - `NN/training/model_manager.py` - Model checkpoint management
   - `core/unified_training_manager.py` - Unified training orchestration
   - `core/orchestrator.py` - Core model training methods

2. **No Shortcuts**: Simulation code creates technical debt and masks real issues

3. **Production Quality**: All code must be production-ready, not simulated

## What To Use Instead

### For Model Training

Use the real training implementations:

```python
# Use EnhancedRealtimeTrainingSystem for real-time training
from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem

# Use UnifiedTrainingManager for coordinated training
from core.unified_training_manager import UnifiedTrainingManager

# Use the orchestrator's built-in training methods
orchestrator.train_models()
```

### For Model Management

```python
# Use ModelManager for checkpoint management
from NN.training.model_manager import ModelManager

# Use CheckpointManager for saving/loading
from utils.checkpoint_manager import get_checkpoint_manager
```

## If You Need Training Features

1. **Extend existing real implementations** - Don't create new simulation code
2. **Add to orchestrator** - Put training logic in the orchestrator
3. **Use UnifiedTrainingManager** - For coordinated multi-model training
4. **Integrate with EnhancedRealtimeTrainingSystem** - For online learning

## NEVER DO THIS

- Create files with "simulator", "simulation", "mock", or "fake" in the name
- Use placeholder/dummy training loops
- Return fake metrics or results
- Skip actual model training

## ALWAYS DO THIS

- Use real model training methods
- Integrate with existing training systems
- Save real checkpoints
- Track real metrics
- Handle real data

---

**Remember**: If data is unavailable, return None/empty/error - NEVER simulate it!
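To make that last rule concrete, here is a minimal sketch of the intended pattern. It wraps `DataProvider.get_historical_data` (which this commit uses throughout); the wrapper function name is illustrative, not part of the codebase:

```python
from typing import Optional
import pandas as pd

def get_recent_candles(data_provider, symbol: str, timeframe: str, limit: int = 200) -> Optional[pd.DataFrame]:
    """Return real candles from the provider, or None - never synthetic data."""
    try:
        df = data_provider.get_historical_data(symbol=symbol, timeframe=timeframe, limit=limit)
    except Exception:
        return None  # surface "no data" to the caller instead of fabricating candles
    if df is None or df.empty:
        return None  # empty/None result, NOT a simulated DataFrame
    return df
```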
core/annotation_manager.py (new file, 469 lines)
@@ -0,0 +1,469 @@
"""
Annotation Manager - Manages trade annotations and test case generation

Handles storage, retrieval, and test case generation from manual trade annotations.
Stores annotations in both JSON (legacy) and SQLite (with full market data).
"""

import json
import uuid
import sys
from pathlib import Path
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, asdict
import logging
import pytz

# Add parent directory to path for imports
parent_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(parent_dir))

logger = logging.getLogger(__name__)

# Import DuckDB storage
try:
    from core.duckdb_storage import DuckDBStorage
    DUCKDB_AVAILABLE = True
except ImportError:
    DUCKDB_AVAILABLE = False
    logger.warning("DuckDB storage not available for annotations")


@dataclass
class TradeAnnotation:
    """Represents a manually marked trade"""
    annotation_id: str
    symbol: str
    timeframe: str
    entry: Dict[str, Any]  # {timestamp, price, index}
    exit: Dict[str, Any]  # {timestamp, price, index}
    direction: str  # 'LONG' or 'SHORT'
    profit_loss_pct: float
    notes: str = ""
    created_at: str = None
    market_context: Dict[str, Any] = None

    def __post_init__(self):
        if self.created_at is None:
            self.created_at = datetime.now(pytz.UTC).isoformat()
        if self.market_context is None:
            self.market_context = {}


class AnnotationManager:
    """Manages trade annotations and test case generation"""

    def __init__(self, storage_path: str = "ANNOTATE/data/annotations"):
        """Initialize annotation manager"""
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)

        self.annotations_file = self.storage_path / "annotations_db.json"
        self.test_cases_dir = self.storage_path.parent / "test_cases"
        self.test_cases_dir.mkdir(parents=True, exist_ok=True)

        self.annotations_db = self._load_annotations()

        # Initialize DuckDB storage for complete annotation data
        self.duckdb_storage: Optional[DuckDBStorage] = None
        if DUCKDB_AVAILABLE:
            try:
                self.duckdb_storage = DuckDBStorage()
                logger.info("DuckDB storage initialized for annotations")
            except Exception as e:
                logger.warning(f"Could not initialize DuckDB storage: {e}")

        logger.info(f"AnnotationManager initialized with storage: {self.storage_path}")

    def _load_annotations(self) -> Dict[str, List[Dict]]:
        """Load annotations from storage"""
        if self.annotations_file.exists():
            try:
                with open(self.annotations_file, 'r') as f:
                    data = json.load(f)
                    logger.info(f"Loaded {len(data.get('annotations', []))} annotations")
                    return data
            except Exception as e:
                logger.error(f"Error loading annotations: {e}")
                return {"annotations": [], "metadata": {}}
        else:
            return {"annotations": [], "metadata": {}}

    def _save_annotations(self):
        """Save annotations to storage"""
        try:
            # Update metadata
            self.annotations_db["metadata"] = {
                "total_annotations": len(self.annotations_db["annotations"]),
                "last_updated": datetime.now(pytz.UTC).isoformat()
            }

            with open(self.annotations_file, 'w') as f:
                json.dump(self.annotations_db, f, indent=2)

            logger.info(f"Saved {len(self.annotations_db['annotations'])} annotations")
        except Exception as e:
            logger.error(f"Error saving annotations: {e}")
            raise

    def create_annotation(self, entry_point: Dict, exit_point: Dict,
                          symbol: str, timeframe: str,
                          entry_market_state: Dict = None,
                          exit_market_state: Dict = None) -> TradeAnnotation:
        """Create new trade annotation"""
        # Calculate direction and P&L
        entry_price = entry_point['price']
        exit_price = exit_point['price']

        if exit_price > entry_price:
            direction = 'LONG'
            profit_loss_pct = ((exit_price - entry_price) / entry_price) * 100
        else:
            direction = 'SHORT'
            profit_loss_pct = ((entry_price - exit_price) / entry_price) * 100

        # Store complete market context for training
        market_context = {
            'entry_state': entry_market_state or {},
            'exit_state': exit_market_state or {}
        }

        annotation = TradeAnnotation(
            annotation_id=str(uuid.uuid4()),
            symbol=symbol,
            timeframe=timeframe,
            entry=entry_point,
            exit=exit_point,
            direction=direction,
            profit_loss_pct=profit_loss_pct,
            market_context=market_context
        )

        logger.info(f"Created annotation: {annotation.annotation_id} ({direction}, {profit_loss_pct:.2f}%)")
        logger.info(f"  Entry state: {len(entry_market_state or {})} timeframes")
        logger.info(f"  Exit state: {len(exit_market_state or {})} timeframes")
        return annotation

    def save_annotation(self, annotation: TradeAnnotation,
                        market_snapshots: Dict = None,
                        model_predictions: List[Dict] = None):
        """
        Save annotation to storage (JSON + SQLite)

        Args:
            annotation: TradeAnnotation object
            market_snapshots: Dict of {timeframe: DataFrame} with OHLCV data
            model_predictions: List of model predictions at annotation time
        """
        # Convert to dict
        ann_dict = asdict(annotation)

        # Add to JSON database (legacy)
        self.annotations_db["annotations"].append(ann_dict)

        # Save to JSON file
        self._save_annotations()

        # Save to DuckDB with complete market data
        if self.duckdb_storage and market_snapshots:
            try:
                self.duckdb_storage.store_annotation(
                    annotation_id=annotation.annotation_id,
                    annotation_data=ann_dict,
                    market_snapshots=market_snapshots,
                    model_predictions=model_predictions
                )
                logger.info(f"Saved annotation {annotation.annotation_id} to DuckDB with {len(market_snapshots)} timeframes")
            except Exception as e:
                logger.error(f"Could not save annotation to DuckDB: {e}")

        logger.info(f"Saved annotation: {annotation.annotation_id}")

    def get_annotations(self, symbol: str = None,
                        timeframe: str = None) -> List[TradeAnnotation]:
        """Retrieve annotations with optional filtering"""
        annotations = self.annotations_db.get("annotations", [])

        # Filter by symbol
        if symbol:
            annotations = [a for a in annotations if a.get('symbol') == symbol]

        # Filter by timeframe
        if timeframe:
            annotations = [a for a in annotations if a.get('timeframe') == timeframe]

        # Convert to TradeAnnotation objects
        result = []
        for ann_dict in annotations:
            try:
                annotation = TradeAnnotation(**ann_dict)
                result.append(annotation)
            except Exception as e:
                logger.error(f"Error converting annotation: {e}")

        return result

    def delete_annotation(self, annotation_id: str) -> bool:
        """
        Delete annotation and its associated test case file

        Args:
            annotation_id: ID of annotation to delete

        Returns:
            bool: True if annotation was deleted, False if not found

        Raises:
            Exception: If there's an error during deletion
        """
        original_count = len(self.annotations_db["annotations"])
        self.annotations_db["annotations"] = [
            a for a in self.annotations_db["annotations"]
            if a.get('annotation_id') != annotation_id
        ]

        if len(self.annotations_db["annotations"]) < original_count:
            # Annotation was found and removed
            self._save_annotations()

            # Also delete the associated test case file
            test_case_file = self.test_cases_dir / f"annotation_{annotation_id}.json"
            if test_case_file.exists():
                try:
                    test_case_file.unlink()
                    logger.info(f"Deleted test case file: {test_case_file}")
                except Exception as e:
                    logger.error(f"Error deleting test case file {test_case_file}: {e}")
                    # Don't fail the whole operation if test case deletion fails

            logger.info(f"Deleted annotation: {annotation_id}")
            return True
        else:
            logger.warning(f"Annotation not found: {annotation_id}")
            return False

    def clear_all_annotations(self, symbol: str = None):
        """
        Clear all annotations (optionally filtered by symbol)
        More efficient than deleting one by one

        Args:
            symbol: Optional symbol filter. If None, clears all annotations.

        Returns:
            int: Number of annotations deleted
        """
        # Get annotations to delete
        if symbol:
            annotations_to_delete = [
                a for a in self.annotations_db["annotations"]
                if a.get('symbol') == symbol
            ]
            # Keep annotations for other symbols
            self.annotations_db["annotations"] = [
                a for a in self.annotations_db["annotations"]
                if a.get('symbol') != symbol
            ]
        else:
            annotations_to_delete = self.annotations_db["annotations"].copy()
            self.annotations_db["annotations"] = []

        deleted_count = len(annotations_to_delete)

        if deleted_count > 0:
            # Save updated annotations database
            self._save_annotations()

            # Delete associated test case files
            for annotation in annotations_to_delete:
                annotation_id = annotation.get('annotation_id')
                test_case_file = self.test_cases_dir / f"annotation_{annotation_id}.json"
                if test_case_file.exists():
                    try:
                        test_case_file.unlink()
                        logger.debug(f"Deleted test case file: {test_case_file}")
                    except Exception as e:
                        logger.error(f"Error deleting test case file {test_case_file}: {e}")

            logger.info(f"Cleared {deleted_count} annotations" + (f" for symbol {symbol}" if symbol else ""))

        return deleted_count

    def generate_test_case(self, annotation: TradeAnnotation, data_provider=None, auto_save: bool = True) -> Dict:
        """
        Generate lightweight test case metadata (no OHLCV data stored)
        OHLCV data will be fetched dynamically from cache/database during training

        Args:
            annotation: TradeAnnotation object
            data_provider: Optional DataProvider instance (not used for storage)

        Returns:
            Test case metadata dictionary
        """
        test_case = {
            "test_case_id": f"annotation_{annotation.annotation_id}",
            "symbol": annotation.symbol,
            "timestamp": annotation.entry['timestamp'],
            "action": "BUY" if annotation.direction == "LONG" else "SELL",
            "expected_outcome": {
                "direction": annotation.direction,
                "profit_loss_pct": annotation.profit_loss_pct,
                "holding_period_seconds": self._calculate_holding_period(annotation),
                "exit_price": annotation.exit['price'],
                "entry_price": annotation.entry['price']
            },
            "annotation_metadata": {
                "annotator": "manual",
                "confidence": 1.0,
                "notes": annotation.notes,
                "created_at": annotation.created_at,
                "timeframe": annotation.timeframe
            },
            "training_config": {
                "context_window_minutes": 5,  # ±5 minutes around entry/exit
                "timeframes": ["1s", "1m", "1h", "1d"],
                "data_source": "cache"  # Will fetch from cache/database
            }
        }

        # Save lightweight test case metadata to file if auto_save is True
        if auto_save:
            test_case_file = self.test_cases_dir / f"{test_case['test_case_id']}.json"
            with open(test_case_file, 'w') as f:
                json.dump(test_case, f, indent=2)
            logger.info(f"Saved test case metadata to: {test_case_file}")

        logger.info(f"Generated lightweight test case: {test_case['test_case_id']} (OHLCV data will be fetched dynamically)")
        return test_case

    def get_all_test_cases(self, symbol: Optional[str] = None) -> List[Dict]:
        """
        Load all test cases from disk

        Args:
            symbol: Optional symbol filter (e.g., 'ETH/USDT'). If provided, only returns
                    test cases for that symbol. Critical for avoiding cross-symbol training.

        Returns:
            List of test case dictionaries
        """
        test_cases = []

        if not self.test_cases_dir.exists():
            return test_cases

        for test_case_file in self.test_cases_dir.glob("annotation_*.json"):
            try:
                with open(test_case_file, 'r') as f:
                    test_case = json.load(f)

                # CRITICAL: Filter by symbol to avoid training on wrong symbol
                if symbol:
                    test_case_symbol = test_case.get('symbol', '')
                    if test_case_symbol != symbol:
                        logger.debug(f"Skipping {test_case_file.name}: symbol {test_case_symbol} != {symbol}")
                        continue

                test_cases.append(test_case)
            except Exception as e:
                logger.error(f"Error loading test case {test_case_file}: {e}")

        if symbol:
            logger.info(f"Loaded {len(test_cases)} test cases for symbol {symbol}")
        else:
            logger.info(f"Loaded {len(test_cases)} test cases (all symbols)")
        return test_cases

    def _calculate_holding_period(self, annotation: TradeAnnotation) -> float:
        """Calculate holding period in seconds"""
        try:
            entry_time = datetime.fromisoformat(annotation.entry['timestamp'].replace('Z', '+00:00'))
            exit_time = datetime.fromisoformat(annotation.exit['timestamp'].replace('Z', '+00:00'))
            return (exit_time - entry_time).total_seconds()
        except Exception as e:
            logger.error(f"Error calculating holding period: {e}")
            return 0.0

    def _generate_training_labels(self, market_state: Dict, entry_time: datetime,
                                  exit_time: datetime, direction: str) -> Dict:
        """
        Generate training labels for each timestamp in the market data.
        This helps the model learn WHERE to signal and WHERE NOT to signal.

        Labels:
        - 0 = NO SIGNAL (before entry or after exit)
        - 1 = ENTRY SIGNAL (at entry time)
        - 2 = HOLD (between entry and exit)
        - 3 = EXIT SIGNAL (at exit time)
        """
        labels = {}

        # Use 1m timeframe as reference for labeling
        if 'ohlcv_1m' in market_state and 'timestamps' in market_state['ohlcv_1m']:
            timestamps = market_state['ohlcv_1m']['timestamps']

            label_list = []
            for ts_str in timestamps:
                try:
                    ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
                    # Make timezone-aware to match entry_time
                    if ts.tzinfo is None:
                        ts = pytz.UTC.localize(ts)

                    # Determine label based on position relative to entry/exit
                    if abs((ts - entry_time).total_seconds()) < 60:  # Within 1 minute of entry
                        label = 1  # ENTRY SIGNAL
                    elif abs((ts - exit_time).total_seconds()) < 60:  # Within 1 minute of exit
                        label = 3  # EXIT SIGNAL
                    elif entry_time < ts < exit_time:  # Between entry and exit
                        label = 2  # HOLD
                    else:  # Before entry or after exit
                        label = 0  # NO SIGNAL

                    label_list.append(label)

                except Exception as e:
                    logger.error(f"Error parsing timestamp {ts_str}: {e}")
                    label_list.append(0)

            labels['labels_1m'] = label_list
            labels['direction'] = direction
            labels['entry_timestamp'] = entry_time.strftime('%Y-%m-%d %H:%M:%S')
            labels['exit_timestamp'] = exit_time.strftime('%Y-%m-%d %H:%M:%S')

            logger.info(f"Generated {len(label_list)} training labels: "
                        f"{label_list.count(0)} NO_SIGNAL, "
                        f"{label_list.count(1)} ENTRY, "
                        f"{label_list.count(2)} HOLD, "
                        f"{label_list.count(3)} EXIT")

        return labels

    def export_annotations(self, annotations: List[TradeAnnotation] = None,
                           format_type: str = 'json') -> Path:
        """Export annotations to file"""
        if annotations is None:
            annotations = self.get_annotations()

        # Convert to dicts
        export_data = [asdict(ann) for ann in annotations]

        # Create export file
        timestamp = datetime.now(pytz.UTC).strftime('%Y%m%d_%H%M%S')
        export_file = self.storage_path / f"export_{timestamp}.{format_type}"

        if format_type == 'json':
            with open(export_file, 'w') as f:
                json.dump(export_data, f, indent=2)
        elif format_type == 'csv':
            import csv
            with open(export_file, 'w', newline='') as f:
                if export_data:
                    writer = csv.DictWriter(f, fieldnames=export_data[0].keys())
                    writer.writeheader()
                    writer.writerows(export_data)

        logger.info(f"Exported {len(annotations)} annotations to {export_file}")
        return export_file
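For reviewers, a minimal usage sketch of the new manager, using only the methods shown above; the entry/exit payload values are illustrative (real ones come from the annotation UI):

```python
from core.annotation_manager import AnnotationManager

manager = AnnotationManager()

# Illustrative entry/exit payloads
entry = {"timestamp": "2025-10-23T10:00:00Z", "price": 1850.0, "index": 120}
exit_ = {"timestamp": "2025-10-23T10:25:00Z", "price": 1872.5, "index": 145}

annotation = manager.create_annotation(entry, exit_, symbol="ETH/USDT", timeframe="1m")
manager.save_annotation(annotation)                  # JSON always; DuckDB only when market_snapshots are passed
test_case = manager.generate_test_case(annotation)   # lightweight metadata, no OHLCV stored
print(test_case["expected_outcome"]["profit_loss_pct"])
```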
@@ -758,4 +758,62 @@ def create_model_output(model_type: str, model_name: str, symbol: str,
        predictions=predictions,
        hidden_states=hidden_states or {},
        metadata=metadata or {}
    )


@dataclass
class InferenceFrameReference:
    """
    Reference to inference data stored in DuckDB with human-readable prediction outputs.
    No copying - just store timestamp ranges and query when needed.

    Moved from ANNOTATE/core to main core for unified architecture.
    """
    inference_id: str  # Unique ID for this inference
    symbol: str
    timeframe: str
    prediction_timestamp: datetime  # When prediction was made

    # Reference to data in DuckDB (timestamp range)
    data_range_start: datetime  # Start of 600-candle window
    data_range_end: datetime  # End of 600-candle window

    target_timestamp: Optional[datetime] = None  # When result will be available (for candles)

    # Normalization parameters (small, can be stored)
    norm_params: Dict[str, Dict[str, float]] = field(default_factory=dict)

    # ENHANCED: Human-readable prediction outputs
    predicted_action: Optional[str] = None  # 'BUY', 'SELL', 'HOLD'
    predicted_candle: Optional[Dict[str, List[float]]] = None  # {timeframe: [O,H,L,C,V]}
    predicted_price: Optional[float] = None  # Main predicted price
    confidence: float = 0.0

    # Model metadata for decision making
    model_type: str = 'transformer'  # 'transformer', 'cnn', 'dqn'
    prediction_steps: int = 1  # Number of steps predicted ahead

    # Training status
    trained: bool = False
    training_timestamp: Optional[datetime] = None
    training_loss: Optional[float] = None
    training_accuracy: Optional[float] = None

    # Actual results (filled when candle completes)
    actual_candle: Optional[List[float]] = None  # [O,H,L,C,V]
    actual_price: Optional[float] = None
    prediction_error: Optional[float] = None  # |predicted - actual|
    direction_correct: Optional[bool] = None  # Did we predict direction correctly?


@dataclass
class TrainingSession:
    """Real training session tracking - moved from ANNOTATE/core"""
    training_id: str
    symbol: str
    timeframe: str
    model_type: str
    start_time: datetime
    end_time: Optional[datetime] = None
    status: str = 'running'  # 'running', 'completed', 'failed'
    loss: Optional[float] = None
    accuracy: Optional[float] = None
    samples_trained: int = 0
    error_message: Optional[str] = None
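A minimal sketch of the intended lifecycle of an `InferenceFrameReference`, with illustrative values (the orchestrator methods later in this commit do the same thing in production):

```python
from datetime import datetime, timedelta

frame = InferenceFrameReference(
    inference_id="demo-1",
    symbol="ETH/USDT",
    timeframe="1m",
    prediction_timestamp=datetime.utcnow(),
    data_range_start=datetime.utcnow() - timedelta(hours=10),  # 600-candle window
    data_range_end=datetime.utcnow(),
    predicted_action="BUY",
    predicted_price=1875.0,
    confidence=0.62,
)

# Later, when the candle completes, fill in the actual results
frame.actual_price = 1878.5
frame.prediction_error = abs(frame.predicted_price - frame.actual_price)
frame.direction_correct = True
```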
@@ -4372,3 +4372,78 @@ class DataProvider:
        except Exception as e:
            logger.error(f"Error getting report data for multiple pairs: {e}")
            return {}

    # ===== ANNOTATION UI SUPPORT METHODS =====
    # Added to support ANNOTATE app without duplicate data_loader

    def get_data_for_annotation(self, symbol: str, timeframe: str,
                                start_time: Optional[datetime] = None,
                                end_time: Optional[datetime] = None,
                                limit: int = 2500,
                                direction: str = 'latest') -> Optional[pd.DataFrame]:
        """
        Get data specifically for annotation UI needs
        Combines functionality from the old HistoricalDataLoader
        """
        try:
            # For live updates (small limit, direction='latest', no time range)
            is_live_update = (direction == 'latest' and not start_time and not end_time and limit <= 5)

            if is_live_update:
                # Use get_latest_candles for live updates (combines cached + real-time)
                logger.debug(f"Getting live candles for annotation UI: {symbol} {timeframe}")
                return self.get_latest_candles(symbol, timeframe, limit)

            # For historical data with time range
            if start_time or end_time:
                # Use DuckDB for historical queries
                if self.duckdb_storage:
                    df = self.duckdb_storage.get_ohlcv_data(
                        symbol=symbol,
                        timeframe=timeframe,
                        start_time=start_time,
                        end_time=end_time,
                        limit=limit,
                        direction=direction
                    )
                    if df is not None and not df.empty:
                        return df

                # Fallback to API if DuckDB doesn't have the data
                logger.info(f"Fetching historical data from API for annotation: {symbol} {timeframe}")
                return self.get_historical_data(symbol, timeframe, limit, refresh=True)

            # For regular data requests
            return self.get_historical_data(symbol, timeframe, limit)

        except Exception as e:
            logger.error(f"Error getting data for annotation: {e}")
            return None

    def get_multi_timeframe_data_for_annotation(self, symbol: str,
                                                timeframes: List[str],
                                                start_time: Optional[datetime] = None,
                                                end_time: Optional[datetime] = None,
                                                limit: int = 2500) -> Dict[str, pd.DataFrame]:
        """Get data for multiple timeframes at once for annotation UI"""
        result = {}

        for timeframe in timeframes:
            df = self.get_data_for_annotation(
                symbol=symbol,
                timeframe=timeframe,
                start_time=start_time,
                end_time=end_time,
                limit=limit
            )

            if df is not None and not df.empty:
                result[timeframe] = df

        logger.info(f"Loaded annotation data for {len(result)}/{len(timeframes)} timeframes")
        return result

    def disable_startup_mode(self):
        """Disable startup mode - annotation UI compatibility method"""
        # This was used by the old data_loader, now we just ensure fresh data
        logger.info("Annotation UI requested fresh data mode")
        pass  # Main DataProvider always provides fresh data when requested
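A short sketch of how the annotation UI is expected to call these DataProvider methods; `data_provider` is the app's existing provider instance, and the symbol and time window are illustrative:

```python
from datetime import datetime, timedelta, timezone

end = datetime.now(timezone.utc)
start = end - timedelta(hours=2)

# Historical window for the chart (DuckDB first, API fallback)
df_1m = data_provider.get_data_for_annotation("ETH/USDT", "1m", start_time=start, end_time=end)

# Live tail for chart updates (limit <= 5 with no time range routes to get_latest_candles)
latest = data_provider.get_data_for_annotation("ETH/USDT", "1s", limit=2)

# All timeframes at once for the annotation view
frames = data_provider.get_multi_timeframe_data_for_annotation(
    "ETH/USDT", ["1s", "1m", "1h", "1d"], start_time=start, end_time=end
)
```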
core/live_pivot_trainer.py (new file, 322 lines)
@@ -0,0 +1,322 @@
"""
Live Pivot Trainer - Automatic Training on L2 Pivot Points

This module monitors live 1s and 1m charts for L2 pivot points (peaks/troughs)
and automatically creates training samples when they occur.

Integrates with:
- Williams Market Structure for pivot detection
- Real Training Adapter for model training
- Data Provider for live market data
"""

import logging
import threading
import time
from typing import Dict, List, Optional, Tuple
from datetime import datetime, timezone
from collections import deque
import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


class LivePivotTrainer:
    """
    Monitors live charts for L2 pivots and automatically trains models

    Features:
    - Detects L2 pivot points on 1s and 1m timeframes
    - Creates training samples automatically
    - Trains models in background without blocking inference
    - Tracks training history to avoid duplicate training
    """

    def __init__(self, orchestrator, data_provider, training_adapter):
        """
        Initialize Live Pivot Trainer

        Args:
            orchestrator: TradingOrchestrator instance
            data_provider: DataProvider for market data
            training_adapter: RealTrainingAdapter for training
        """
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.training_adapter = training_adapter

        # Tracking
        self.running = False
        self.trained_pivots = deque(maxlen=1000)  # Track last 1000 trained pivots
        self.pivot_history = {
            '1s': deque(maxlen=100),
            '1m': deque(maxlen=100)
        }

        # Configuration
        self.check_interval = 5  # Check for new pivots every 5 seconds
        self.min_pivot_spacing = 60  # Minimum 60 seconds between training on same timeframe
        self.last_training_time = {
            '1s': 0,
            '1m': 0
        }

        # Williams Market Structure for pivot detection
        try:
            from core.williams_market_structure import WilliamsMarketStructure
            # Fix: WilliamsMarketStructure.__init__ does not accept num_levels
            # It defaults to 5 levels internally
            self.williams_1s = WilliamsMarketStructure()
            self.williams_1m = WilliamsMarketStructure()
            logger.info("Williams Market Structure initialized for pivot detection")
        except Exception as e:
            logger.error(f"Failed to initialize Williams Market Structure: {e}")
            self.williams_1s = None
            self.williams_1m = None

        logger.info("LivePivotTrainer initialized")

    def start(self, symbol: str = 'ETH/USDT'):
        """Start monitoring for L2 pivots"""
        if self.running:
            logger.warning("LivePivotTrainer already running")
            return

        self.running = True
        self.symbol = symbol

        # Start monitoring thread
        thread = threading.Thread(
            target=self._monitoring_loop,
            args=(symbol,),
            daemon=True
        )
        thread.start()

        logger.info(f"LivePivotTrainer started for {symbol}")

    def stop(self):
        """Stop monitoring"""
        self.running = False
        logger.info("LivePivotTrainer stopped")

    def _monitoring_loop(self, symbol: str):
        """Main monitoring loop - checks for new L2 pivots"""
        logger.info(f"LivePivotTrainer monitoring loop started for {symbol}")

        while self.running:
            try:
                # Check 1s timeframe
                self._check_timeframe_for_pivots(symbol, '1s')

                # Check 1m timeframe
                self._check_timeframe_for_pivots(symbol, '1m')

                # Sleep before next check
                time.sleep(self.check_interval)

            except Exception as e:
                logger.error(f"Error in LivePivotTrainer monitoring loop: {e}")
                time.sleep(10)  # Wait longer on error

    def _check_timeframe_for_pivots(self, symbol: str, timeframe: str):
        """
        Check a specific timeframe for new L2 pivots

        Args:
            symbol: Trading symbol
            timeframe: '1s' or '1m'
        """
        try:
            # Rate limiting - don't train too frequently on same timeframe
            current_time = time.time()
            if current_time - self.last_training_time[timeframe] < self.min_pivot_spacing:
                return

            # Get recent candles
            candles = self.data_provider.get_historical_data(
                symbol=symbol,
                timeframe=timeframe,
                limit=200  # Need enough candles to detect pivots
            )

            if candles is None or candles.empty:
                logger.debug(f"No candles available for {symbol} {timeframe}")
                return

            # Detect pivots using Williams Market Structure
            williams = self.williams_1s if timeframe == '1s' else self.williams_1m
            if williams is None:
                return

            # Prepare data for Williams Market Structure
            # Convert DataFrame to numpy array format
            df = candles.copy()
            ohlcv_array = df[['open', 'high', 'low', 'close', 'volume']].copy()

            # Handle timestamp conversion based on index type
            if isinstance(df.index, pd.DatetimeIndex):
                # Convert ns to ms
                timestamps = df.index.astype(np.int64) // 10**6
            else:
                # Assume it's already timestamp or handle accordingly
                timestamps = df.index

            ohlcv_array.insert(0, 'timestamp', timestamps)
            ohlcv_array = ohlcv_array.to_numpy()

            # Calculate pivots
            pivot_levels = williams.calculate_recursive_pivot_points(ohlcv_array)

            if not pivot_levels or 2 not in pivot_levels:
                return

            # Get Level 2 pivots
            l2_trend_level = pivot_levels[2]
            l2_pivots_objs = l2_trend_level.pivot_points

            if not l2_pivots_objs:
                return

            # Check for new L2 pivots (not in history)
            new_pivots = []
            for p in l2_pivots_objs:
                # Convert pivot object to dict for compatibility
                pivot_dict = {
                    'timestamp': p.timestamp,  # Keep as datetime object for compatibility
                    'price': p.price,
                    'type': p.pivot_type,
                    'strength': p.strength
                }

                pivot_id = f"{symbol}_{timeframe}_{pivot_dict['timestamp']}_{pivot_dict['type']}"

                if pivot_id not in self.trained_pivots:
                    new_pivots.append(pivot_dict)
                    self.trained_pivots.append(pivot_id)

            if new_pivots:
                logger.info(f"Found {len(new_pivots)} new L2 pivots on {symbol} {timeframe}")

                # Train on the most recent pivot
                latest_pivot = new_pivots[-1]
                self._train_on_pivot(symbol, timeframe, latest_pivot, candles)

                self.last_training_time[timeframe] = current_time

        except Exception as e:
            logger.error(f"Error checking {timeframe} for pivots: {e}")

    def _train_on_pivot(self, symbol: str, timeframe: str, pivot: Dict, candles):
        """
        Create training sample from pivot and train model

        Args:
            symbol: Trading symbol
            timeframe: Timeframe of pivot
            pivot: Pivot point data
            candles: DataFrame with OHLCV data
        """
        try:
            logger.info(f"Training on L2 {pivot['type']} pivot @ {pivot['price']} on {symbol} {timeframe}")

            # Determine trade direction based on pivot type
            if pivot['type'] == 'high':
                # High pivot = potential SHORT entry
                direction = 'SHORT'
                action = 'SELL'
            else:
                # Low pivot = potential LONG entry
                direction = 'LONG'
                action = 'BUY'

            # Create training sample
            training_sample = {
                'test_case_id': f"live_pivot_{symbol}_{timeframe}_{pivot['timestamp']}",
                'symbol': symbol,
                'timestamp': pivot['timestamp'],
                'action': action,
                'expected_outcome': {
                    'direction': direction,
                    'entry_price': pivot['price'],
                    'exit_price': None,  # Will be determined by model
                    'profit_loss_pct': 0.0,  # Unknown yet
                    'holding_period_seconds': 300  # 5 minutes default
                },
                'training_config': {
                    'timeframes': ['1s', '1m', '1h', '1d'],
                    'candles_per_timeframe': 200
                },
                'annotation_metadata': {
                    'source': 'live_pivot_detection',
                    'pivot_level': 'L2',
                    'pivot_type': pivot['type'],
                    'confidence': pivot.get('strength', 1.0)
                }
            }

            # Train model in background (non-blocking)
            thread = threading.Thread(
                target=self._background_training,
                args=(training_sample,),
                daemon=True
            )
            thread.start()

            logger.info(f"Started background training on L2 pivot")

        except Exception as e:
            logger.error(f"Error training on pivot: {e}")

    def _background_training(self, training_sample: Dict):
        """
        Execute training in background thread

        Args:
            training_sample: Training sample data
        """
        try:
            # Use Transformer model for live pivot training
            model_name = 'Transformer'

            logger.info(f"Background training started for {training_sample['test_case_id']}")

            # Start training session
            training_id = self.training_adapter.start_training(
                model_name=model_name,
                test_cases=[training_sample]
            )

            logger.info(f"Live pivot training session started: {training_id}")

            # Monitor training (optional - could poll status)
            # For now, just fire and forget

        except Exception as e:
            logger.error(f"Error in background training: {e}")

    def get_stats(self) -> Dict:
        """Get training statistics"""
        return {
            'running': self.running,
            'total_trained_pivots': len(self.trained_pivots),
            'last_training_1s': self.last_training_time.get('1s', 0),
            'last_training_1m': self.last_training_time.get('1m', 0),
            'pivot_history_1s': len(self.pivot_history['1s']),
            'pivot_history_1m': len(self.pivot_history['1m'])
        }


# Global instance
_live_pivot_trainer = None


def get_live_pivot_trainer(orchestrator=None, data_provider=None, training_adapter=None):
    """Get or create global LivePivotTrainer instance"""
    global _live_pivot_trainer

    if _live_pivot_trainer is None and all([orchestrator, data_provider, training_adapter]):
        _live_pivot_trainer = LivePivotTrainer(orchestrator, data_provider, training_adapter)

    return _live_pivot_trainer
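A minimal sketch of how the trainer is expected to be wired up at startup; the `orchestrator`, `data_provider`, and `training_adapter` instances are assumed to be created elsewhere by the main app:

```python
from core.live_pivot_trainer import get_live_pivot_trainer

trainer = get_live_pivot_trainer(
    orchestrator=orchestrator,
    data_provider=data_provider,
    training_adapter=training_adapter,
)

if trainer:
    trainer.start(symbol='ETH/USDT')   # background thread, checks every 5 seconds
    print(trainer.get_stats())         # e.g. {'running': True, 'total_trained_pivots': 0, ...}
    # ... later, on shutdown
    trainer.stop()
```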
@@ -70,6 +70,7 @@ from NN.models.model_interfaces import (
 
 from .config import get_config
 from .data_provider import DataProvider
+from .data_models import InferenceFrameReference, TrainingSession
 from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
 
 # Import COB integration for real-time market microstructure data
@@ -513,20 +514,12 @@ class TradingOrchestrator:
         self.inference_logger = None  # Will be initialized later if needed
         self.db_manager = None  # Will be initialized later if needed
 
-        # Inference Training Coordinator - manages inference frame references and training events
-        # Integrated into orchestrator to reduce duplication and centralize coordination
-        self.inference_training_coordinator = None
-        try:
-            from ANNOTATE.core.inference_training_system import InferenceTrainingCoordinator
-            duckdb_storage = getattr(self.data_provider, 'duckdb_storage', None)
-            self.inference_training_coordinator = InferenceTrainingCoordinator(
-                data_provider=self.data_provider,
-                duckdb_storage=duckdb_storage
-            )
-            logger.info("InferenceTrainingCoordinator initialized in orchestrator")
-        except Exception as e:
-            logger.warning(f"Could not initialize InferenceTrainingCoordinator: {e}")
-            self.inference_training_coordinator = None
+        # Integrated Training Coordination (moved from ANNOTATE/core for unified architecture)
+        # Manages inference frame references and training events directly in orchestrator
+        self.training_event_subscribers = []
+        self.inference_frames = {}  # Store inference frames by ID
+        self.training_sessions = {}  # Track active training sessions
+        logger.info("Integrated training coordination initialized in orchestrator")
 
         # CRITICAL: Initialize model_states dictionary to track model performance
         self.model_states: Dict[str, Dict[str, Any]] = {
@@ -2965,3 +2958,169 @@ class TradingOrchestrator:
        except Exception as e:
            logger.error(f"Error clearing predictions: {e}")

    # ===== INTEGRATED TRAINING COORDINATION METHODS =====
    # Moved from ANNOTATE/core/inference_training_system.py for unified architecture

    def subscribe_training_events(self, callback, event_types: List[str]):
        """Subscribe to training events (candle completion, pivot events, etc.)"""
        try:
            subscriber = {
                'callback': callback,
                'event_types': event_types,
                'id': f"subscriber_{len(self.training_event_subscribers)}"
            }
            self.training_event_subscribers.append(subscriber)
            logger.info(f"Registered training event subscriber for events: {event_types}")
        except Exception as e:
            logger.error(f"Error subscribing to training events: {e}")

    def store_inference_frame(self, symbol: str, timeframe: str, prediction_data: Dict) -> str:
        """Store inference frame reference for later training"""
        try:
            from uuid import uuid4

            inference_id = str(uuid4())

            # Create inference frame reference
            frame_ref = InferenceFrameReference(
                inference_id=inference_id,
                symbol=symbol,
                timeframe=timeframe,
                prediction_timestamp=datetime.now(),
                predicted_action=prediction_data.get('action'),
                predicted_price=prediction_data.get('predicted_price'),
                confidence=prediction_data.get('confidence', 0.0),
                model_type=prediction_data.get('model_type', 'transformer'),
                data_range_start=prediction_data.get('data_range_start', datetime.now() - timedelta(hours=1)),
                data_range_end=prediction_data.get('data_range_end', datetime.now())
            )

            # Store in memory
            self.inference_frames[inference_id] = frame_ref

            # Store in DuckDB if available
            if hasattr(self.data_provider, 'duckdb_storage') and self.data_provider.duckdb_storage:
                try:
                    # Store inference frame in DuckDB for persistence
                    # This would be implemented based on the DuckDB schema
                    pass
                except Exception as e:
                    logger.debug(f"Could not store inference frame in DuckDB: {e}")

            logger.debug(f"Stored inference frame: {inference_id} for {symbol} {timeframe}")
            return inference_id

        except Exception as e:
            logger.error(f"Error storing inference frame: {e}")
            return ""

    def trigger_training_on_event(self, event_type: str, event_data: Dict):
        """Trigger training based on events (candle completion, pivot detection, etc.)"""
        try:
            # Notify all subscribers interested in this event type
            for subscriber in self.training_event_subscribers:
                if event_type in subscriber['event_types']:
                    try:
                        subscriber['callback'](event_type, event_data)
                    except Exception as e:
                        logger.error(f"Error in training event callback: {e}")

            logger.debug(f"Triggered training event: {event_type}")

        except Exception as e:
            logger.error(f"Error triggering training event: {e}")

    def start_training_session(self, symbol: str, timeframe: str, model_type: str) -> str:
        """Start a new training session"""
        try:
            from uuid import uuid4

            session_id = str(uuid4())

            session = TrainingSession(
                training_id=session_id,
                symbol=symbol,
                timeframe=timeframe,
                model_type=model_type,
                start_time=datetime.now(),
                status='running'
            )

            self.training_sessions[session_id] = session
            logger.info(f"Started training session: {session_id} for {symbol} {timeframe} {model_type}")

            return session_id

        except Exception as e:
            logger.error(f"Error starting training session: {e}")
            return ""

    def complete_training_session(self, session_id: str, loss: float = None, accuracy: float = None, samples_trained: int = 0):
        """Complete a training session with results"""
        try:
            if session_id in self.training_sessions:
                session = self.training_sessions[session_id]
                session.end_time = datetime.now()
                session.status = 'completed'
                session.loss = loss
                session.accuracy = accuracy
                session.samples_trained = samples_trained

                logger.info(f"Completed training session: {session_id} - Loss: {loss}, Accuracy: {accuracy}, Samples: {samples_trained}")
            else:
                logger.warning(f"Training session not found: {session_id}")

        except Exception as e:
            logger.error(f"Error completing training session: {e}")

    def get_training_session_status(self, session_id: str) -> Optional[Dict]:
        """Get status of a training session"""
        try:
            if session_id in self.training_sessions:
                session = self.training_sessions[session_id]
                return {
                    'training_id': session.training_id,
                    'symbol': session.symbol,
                    'timeframe': session.timeframe,
                    'model_type': session.model_type,
                    'status': session.status,
                    'start_time': session.start_time.isoformat() if session.start_time else None,
                    'end_time': session.end_time.isoformat() if session.end_time else None,
                    'loss': session.loss,
                    'accuracy': session.accuracy,
                    'samples_trained': session.samples_trained
                }
            return None

        except Exception as e:
            logger.error(f"Error getting training session status: {e}")
            return None

    def get_inference_frame(self, inference_id: str) -> Optional[InferenceFrameReference]:
        """Get stored inference frame by ID"""
        return self.inference_frames.get(inference_id)

    def update_inference_frame_results(self, inference_id: str, actual_candle: List[float], actual_price: float):
        """Update inference frame with actual results for training"""
        try:
            if inference_id in self.inference_frames:
                frame_ref = self.inference_frames[inference_id]
                frame_ref.actual_candle = actual_candle
                frame_ref.actual_price = actual_price

                # Calculate prediction error
                if frame_ref.predicted_price and actual_price:
                    frame_ref.prediction_error = abs(frame_ref.predicted_price - actual_price)

                # Check direction correctness
                if frame_ref.predicted_action and len(actual_candle) >= 4:
                    open_price, close_price = actual_candle[0], actual_candle[3]
                    actual_direction = 'BUY' if close_price > open_price else 'SELL' if close_price < open_price else 'HOLD'
                    frame_ref.direction_correct = (frame_ref.predicted_action == actual_direction)

                logger.debug(f"Updated inference frame results: {inference_id}")
            else:
                logger.warning(f"Inference frame not found: {inference_id}")

        except Exception as e:
            logger.error(f"Error updating inference frame results: {e}")
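A short sketch of the intended flow through these new orchestrator methods; `orchestrator` is an initialized TradingOrchestrator, and the callback, event payload, and numeric values are illustrative:

```python
def on_candle_complete(event_type, event_data):
    # Train on the candle that just closed; payload shape is illustrative
    print(f"{event_type}: {event_data.get('symbol')} {event_data.get('timeframe')}")

orchestrator.subscribe_training_events(on_candle_complete, event_types=['candle_complete'])

# Record a prediction, then later reconcile it with the actual candle
inference_id = orchestrator.store_inference_frame(
    'ETH/USDT', '1m',
    {'action': 'BUY', 'predicted_price': 1875.0, 'confidence': 0.62}
)

session_id = orchestrator.start_training_session('ETH/USDT', '1m', 'transformer')
orchestrator.trigger_training_on_event('candle_complete', {'symbol': 'ETH/USDT', 'timeframe': '1m'})
orchestrator.update_inference_frame_results(
    inference_id, actual_candle=[1874.0, 1880.0, 1872.0, 1878.5, 1250.0], actual_price=1878.5
)
orchestrator.complete_training_session(session_id, loss=0.041, accuracy=0.57, samples_trained=1)
print(orchestrator.get_training_session_status(session_id))
```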
core/real_training_adapter.py (new file, 5093 lines)
File diff suppressed because it is too large.
core/training_data_fetcher.py (new file, 299 lines)
@@ -0,0 +1,299 @@
"""
Training Data Fetcher - Dynamic OHLCV data retrieval for model training

Fetches ±5 minutes of OHLCV data around annotated events from cache/database
instead of storing it in JSON files. This allows efficient training on optimal timing.
"""

import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
import pandas as pd
import numpy as np
import pytz

logger = logging.getLogger(__name__)


class TrainingDataFetcher:
    """
    Fetches training data dynamically from cache/database for annotated events.

    Key Features:
    - Fetches ±5 minutes of OHLCV data around entry/exit points
    - Generates training labels for optimal timing detection
    - Supports multiple timeframes (1s, 1m, 1h, 1d)
    - Efficient memory usage (no JSON storage)
    - Real-time data from cache/database
    """

    def __init__(self, data_provider):
        """
        Initialize training data fetcher

        Args:
            data_provider: DataProvider instance for fetching OHLCV data
        """
        self.data_provider = data_provider
        logger.info("TrainingDataFetcher initialized")

    def fetch_training_data_for_annotation(self, annotation: Dict,
                                           context_window_minutes: int = 5) -> Dict[str, Any]:
        """
        Fetch complete training data for an annotation

        Args:
            annotation: Annotation metadata (from annotations_db.json)
            context_window_minutes: Minutes before/after entry to include

        Returns:
            Dict with market_state, training_labels, and expected_outcome
        """
        try:
            # Parse timestamps
            entry_time = datetime.fromisoformat(annotation['entry']['timestamp'].replace('Z', '+00:00'))
            exit_time = datetime.fromisoformat(annotation['exit']['timestamp'].replace('Z', '+00:00'))

            symbol = annotation['symbol']
            direction = annotation['direction']

            logger.info(f"Fetching training data for {symbol} at {entry_time} (±{context_window_minutes}min)")

            # Fetch OHLCV data for all timeframes around entry time
            market_state = self._fetch_market_state_at_time(
                symbol=symbol,
                timestamp=entry_time,
                context_window_minutes=context_window_minutes
            )

            # Generate training labels for optimal timing detection
            training_labels = self._generate_timing_labels(
                market_state=market_state,
                entry_time=entry_time,
                exit_time=exit_time,
                direction=direction
            )

            # Prepare expected outcome
            expected_outcome = {
                "direction": direction,
                "profit_loss_pct": annotation['profit_loss_pct'],
                "entry_price": annotation['entry']['price'],
                "exit_price": annotation['exit']['price'],
                "holding_period_seconds": (exit_time - entry_time).total_seconds()
            }

            return {
                "test_case_id": f"annotation_{annotation['annotation_id']}",
                "symbol": symbol,
                "timestamp": annotation['entry']['timestamp'],
                "action": "BUY" if direction == "LONG" else "SELL",
                "market_state": market_state,
                "training_labels": training_labels,
                "expected_outcome": expected_outcome,
                "annotation_metadata": {
                    "annotator": "manual",
                    "confidence": 1.0,
                    "notes": annotation.get('notes', ''),
                    "created_at": annotation.get('created_at'),
                    "timeframe": annotation.get('timeframe', '1m')
                }
            }

        except Exception as e:
            logger.error(f"Error fetching training data for annotation: {e}")
            import traceback
            traceback.print_exc()
            return {}

    def _fetch_market_state_at_time(self, symbol: str, timestamp: datetime,
                                    context_window_minutes: int) -> Dict[str, Any]:
        """
        Fetch market state at specific time from cache/database

        Args:
            symbol: Trading symbol
            timestamp: Target timestamp
            context_window_minutes: Minutes before/after to include

        Returns:
            Dict with OHLCV data for all timeframes
        """
        try:
            # Use data provider's method to get market state
            market_state = self.data_provider.get_market_state_at_time(
                symbol=symbol,
                timestamp=timestamp,
                context_window_minutes=context_window_minutes
            )

            logger.info(f"Fetched market state with {len(market_state)} timeframes")
            return market_state

        except Exception as e:
            logger.error(f"Error fetching market state: {e}")
            return {}

    def _generate_timing_labels(self, market_state: Dict, entry_time: datetime,
                                exit_time: datetime, direction: str) -> Dict[str, Any]:
        """
        Generate training labels for optimal timing detection

        Labels help model learn:
        - WHEN to enter (optimal entry timing)
        - WHEN to exit (optimal exit timing)
        - WHEN NOT to trade (avoid bad timing)

        Args:
            market_state: OHLCV data for all timeframes
            entry_time: Entry timestamp
            exit_time: Exit timestamp
            direction: Trade direction (LONG/SHORT)

        Returns:
            Dict with training labels for each timeframe
        """
        labels = {
            'direction': direction,
            'entry_timestamp': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
            'exit_timestamp': exit_time.strftime('%Y-%m-%d %H:%M:%S')
        }

        # Generate labels for each timeframe
        timeframes = ['1s', '1m', '1h', '1d']

        for tf in timeframes:
            tf_key = f'ohlcv_{tf}'
            if tf_key in market_state and 'timestamps' in market_state[tf_key]:
                timestamps = market_state[tf_key]['timestamps']

                label_list = []
                entry_idx = -1
                exit_idx = -1

                for i, ts_str in enumerate(timestamps):
                    try:
                        ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
                        # Make timezone-aware
                        if ts.tzinfo is None:
                            ts = pytz.UTC.localize(ts)

                        # Make entry_time and exit_time timezone-aware if needed
                        if entry_time.tzinfo is None:
                            entry_time = pytz.UTC.localize(entry_time)
                        if exit_time.tzinfo is None:
                            exit_time = pytz.UTC.localize(exit_time)

                        # Determine label based on timing
                        if abs((ts - entry_time).total_seconds()) < 60:  # Within 1 minute of entry
                            label = 1  # OPTIMAL ENTRY TIMING
                            entry_idx = i
                        elif abs((ts - exit_time).total_seconds()) < 60:  # Within 1 minute of exit
                            label = 3  # OPTIMAL EXIT TIMING
                            exit_idx = i
                        elif entry_time < ts < exit_time:  # Between entry and exit
                            label = 2  # HOLD POSITION
                        else:  # Before entry or after exit
                            label = 0  # NO ACTION (avoid trading)

                        label_list.append(label)

                    except Exception as e:
                        logger.error(f"Error parsing timestamp {ts_str}: {e}")
                        label_list.append(0)

                labels[f'labels_{tf}'] = label_list
                labels[f'entry_index_{tf}'] = entry_idx
                labels[f'exit_index_{tf}'] = exit_idx

                # Log label distribution
                label_counts = {0: 0, 1: 0, 2: 0, 3: 0}
                for label in label_list:
                    label_counts[label] += 1

                logger.info(f"Generated {tf} labels: {label_counts[0]} NO_ACTION, "
                            f"{label_counts[1]} ENTRY, {label_counts[2]} HOLD, {label_counts[3]} EXIT")

        return labels

    def fetch_training_batch(self, annotations: List[Dict],
                             context_window_minutes: int = 5) -> List[Dict[str, Any]]:
        """
        Fetch training data for multiple annotations

        Args:
            annotations: List of annotation metadata
            context_window_minutes: Minutes before/after entry to include

        Returns:
            List of training data dictionaries
        """
        training_data = []

        logger.info(f"Fetching training batch for {len(annotations)} annotations")

        for annotation in annotations:
            try:
                training_sample = self.fetch_training_data_for_annotation(
                    annotation, context_window_minutes
                )

                if training_sample:
                    training_data.append(training_sample)
                else:
                    logger.warning(f"Failed to fetch training data for annotation {annotation.get('annotation_id')}")

            except Exception as e:
                logger.error(f"Error processing annotation {annotation.get('annotation_id')}: {e}")

        logger.info(f"Successfully fetched training data for {len(training_data)}/{len(annotations)} annotations")
        return training_data

    def get_training_statistics(self, training_data: List[Dict]) -> Dict[str, Any]:
        """
        Get statistics about training data

        Args:
            training_data: List of training data samples

        Returns:
            Dict with training statistics
        """
        if not training_data:
            return {}

        stats = {
            'total_samples': len(training_data),
            'symbols': {},
            'directions': {'LONG': 0, 'SHORT': 0},
            'avg_profit_loss': 0.0,
            'timeframes_available': set()
        }

        total_pnl = 0.0

        for sample in training_data:
            symbol = sample.get('symbol', 'UNKNOWN')
            direction = sample.get('expected_outcome', {}).get('direction', 'UNKNOWN')
            pnl = sample.get('expected_outcome', {}).get('profit_loss_pct', 0.0)

            # Count symbols
            stats['symbols'][symbol] = stats['symbols'].get(symbol, 0) + 1

            # Count directions
            if direction in stats['directions']:
                stats['directions'][direction] += 1

            # Accumulate P&L
            total_pnl += pnl

            # Check available timeframes
            market_state = sample.get('market_state', {})
            for key in market_state.keys():
                if key.startswith('ohlcv_'):
                    stats['timeframes_available'].add(key.replace('ohlcv_', ''))

        stats['avg_profit_loss'] = total_pnl / len(training_data)
        stats['timeframes_available'] = list(stats['timeframes_available'])

        return stats
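A minimal sketch of how the fetcher ties the pieces in this commit together. It assumes a DataProvider that implements `get_market_state_at_time` (called by the fetcher itself) and uses the AnnotationManager added earlier; the `data_provider` instance comes from the app:

```python
from core.annotation_manager import AnnotationManager
from core.training_data_fetcher import TrainingDataFetcher

manager = AnnotationManager()
fetcher = TrainingDataFetcher(data_provider)

# Annotations are stored as plain dicts in the JSON database
annotations = manager.annotations_db["annotations"]

batch = fetcher.fetch_training_batch(annotations, context_window_minutes=5)
stats = fetcher.get_training_statistics(batch)
print(stats["total_samples"], stats["directions"], stats["timeframes_available"])
```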