merge annotate /ANNOTATE/core into /core.
fix chart updates
@@ -1,72 +0,0 @@
# NO SIMULATION CODE POLICY

## CRITICAL RULE: NEVER CREATE SIMULATION CODE

**Date:** 2025-10-23
**Status:** PERMANENT POLICY

## What Was Removed

We deleted `ANNOTATE/core/training_simulator.py`, which contained simulation/mock training code.

## Why This Is Critical

1. **Real Training Only**: We have REAL training implementations in:
   - `NN/training/enhanced_realtime_training.py` - Real-time training system
   - `NN/training/model_manager.py` - Model checkpoint management
   - `core/unified_training_manager.py` - Unified training orchestration
   - `core/orchestrator.py` - Core model training methods

2. **No Shortcuts**: Simulation code creates technical debt and masks real issues
3. **Production Quality**: All code must be production-ready, not simulated

## What To Use Instead

### For Model Training
Use the real training implementations:

```python
# Use EnhancedRealtimeTrainingSystem for real-time training
from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem

# Use UnifiedTrainingManager for coordinated training
from core.unified_training_manager import UnifiedTrainingManager

# Use orchestrator's built-in training methods
orchestrator.train_models()
```

### For Model Management
```python
# Use ModelManager for checkpoint management
from NN.training.model_manager import ModelManager

# Use CheckpointManager for saving/loading
from utils.checkpoint_manager import get_checkpoint_manager
```

## If You Need Training Features

1. **Extend existing real implementations** - Don't create new simulation code
2. **Add to orchestrator** - Put training logic in the orchestrator (see the sketch after this list)
3. **Use UnifiedTrainingManager** - For coordinated multi-model training
4. **Integrate with EnhancedRealtimeTrainingSystem** - For online learning
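As a rough illustration of item 2, new training logic should hang off the real orchestrator entry points instead of a standalone simulator. This is a minimal sketch: only `orchestrator.train_models()` is taken from the examples above; the `run_training_cycle` helper and the `data_provider` attribute check are illustrative assumptions, not existing APIs.

```python
# Minimal sketch, assuming an orchestrator that exposes train_models() as shown above.
def run_training_cycle(orchestrator) -> bool:
    """Run one real training pass; never substitute simulated results."""
    data_provider = getattr(orchestrator, "data_provider", None)  # assumed attribute name
    if data_provider is None:
        return False  # no real data source wired in - refuse to train rather than simulate
    orchestrator.train_models()  # real training entry point from the policy above
    return True
```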
## NEVER DO THIS

- Create files with "simulator", "simulation", "mock", "fake" in the name
- Use placeholder/dummy training loops
- Return fake metrics or results
- Skip actual model training

## ALWAYS DO THIS

- Use real model training methods
- Integrate with existing training systems
- Save real checkpoints
- Track real metrics
- Handle real data

---

**Remember**: If data is unavailable, return None/empty/error - NEVER simulate it!
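To make that last rule concrete, here is a minimal sketch of a data accessor that surfaces gaps instead of fabricating candles. The `fetch_real_ohlcv` wrapper is hypothetical; the `get_historical_data(symbol=..., timeframe=..., limit=...)` call mirrors the DataProvider usage that appears in the loaders below.

```python
from typing import Optional

import pandas as pd


def fetch_real_ohlcv(data_provider, symbol: str, timeframe: str,
                     limit: int = 500) -> Optional[pd.DataFrame]:
    """Return real OHLCV data, or None when nothing is available - never synthetic bars."""
    df = data_provider.get_historical_data(symbol=symbol, timeframe=timeframe, limit=limit)
    if df is None or df.empty:
        return None  # caller handles the gap; do not generate fake candles or metrics
    return df
```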
@@ -1,5 +0,0 @@
"""
ANNOTATE Core Module

Core business logic for the Manual Trade Annotation UI
"""
@@ -1,469 +0,0 @@
"""
Annotation Manager - Manages trade annotations and test case generation

Handles storage, retrieval, and test case generation from manual trade annotations.
Stores annotations in both JSON (legacy) and DuckDB (with full market data).
"""
import json
import uuid
import sys
from pathlib import Path
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, asdict
import logging
import pytz

# Add parent directory to path for imports
parent_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(parent_dir))

logger = logging.getLogger(__name__)

# Import DuckDB storage
try:
    from core.duckdb_storage import DuckDBStorage
    DUCKDB_AVAILABLE = True
except ImportError:
    DUCKDB_AVAILABLE = False
    logger.warning("DuckDB storage not available for annotations")


@dataclass
class TradeAnnotation:
    """Represents a manually marked trade"""
    annotation_id: str
    symbol: str
    timeframe: str
    entry: Dict[str, Any]  # {timestamp, price, index}
    exit: Dict[str, Any]  # {timestamp, price, index}
    direction: str  # 'LONG' or 'SHORT'
    profit_loss_pct: float
    notes: str = ""
    created_at: str = None
    market_context: Dict[str, Any] = None

    def __post_init__(self):
        if self.created_at is None:
            self.created_at = datetime.now(pytz.UTC).isoformat()
        if self.market_context is None:
            self.market_context = {}


class AnnotationManager:
    """Manages trade annotations and test case generation"""

    def __init__(self, storage_path: str = "ANNOTATE/data/annotations"):
        """Initialize annotation manager"""
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)

        self.annotations_file = self.storage_path / "annotations_db.json"
        self.test_cases_dir = self.storage_path.parent / "test_cases"
        self.test_cases_dir.mkdir(parents=True, exist_ok=True)

        self.annotations_db = self._load_annotations()

        # Initialize DuckDB storage for complete annotation data
        self.duckdb_storage: Optional[DuckDBStorage] = None
        if DUCKDB_AVAILABLE:
            try:
                self.duckdb_storage = DuckDBStorage()
                logger.info("DuckDB storage initialized for annotations")
            except Exception as e:
                logger.warning(f"Could not initialize DuckDB storage: {e}")

        logger.info(f"AnnotationManager initialized with storage: {self.storage_path}")

    def _load_annotations(self) -> Dict[str, List[Dict]]:
        """Load annotations from storage"""
        if self.annotations_file.exists():
            try:
                with open(self.annotations_file, 'r') as f:
                    data = json.load(f)
                logger.info(f"Loaded {len(data.get('annotations', []))} annotations")
                return data
            except Exception as e:
                logger.error(f"Error loading annotations: {e}")
                return {"annotations": [], "metadata": {}}
        else:
            return {"annotations": [], "metadata": {}}

    def _save_annotations(self):
        """Save annotations to storage"""
        try:
            # Update metadata
            self.annotations_db["metadata"] = {
                "total_annotations": len(self.annotations_db["annotations"]),
                "last_updated": datetime.now(pytz.UTC).isoformat()
            }

            with open(self.annotations_file, 'w') as f:
                json.dump(self.annotations_db, f, indent=2)

            logger.info(f"Saved {len(self.annotations_db['annotations'])} annotations")
        except Exception as e:
            logger.error(f"Error saving annotations: {e}")
            raise

    def create_annotation(self, entry_point: Dict, exit_point: Dict,
                          symbol: str, timeframe: str,
                          entry_market_state: Dict = None,
                          exit_market_state: Dict = None) -> TradeAnnotation:
        """Create new trade annotation"""
        # Calculate direction and P&L
        entry_price = entry_point['price']
        exit_price = exit_point['price']

        if exit_price > entry_price:
            direction = 'LONG'
            profit_loss_pct = ((exit_price - entry_price) / entry_price) * 100
        else:
            direction = 'SHORT'
            profit_loss_pct = ((entry_price - exit_price) / entry_price) * 100

        # Store complete market context for training
        market_context = {
            'entry_state': entry_market_state or {},
            'exit_state': exit_market_state or {}
        }

        annotation = TradeAnnotation(
            annotation_id=str(uuid.uuid4()),
            symbol=symbol,
            timeframe=timeframe,
            entry=entry_point,
            exit=exit_point,
            direction=direction,
            profit_loss_pct=profit_loss_pct,
            market_context=market_context
        )

        logger.info(f"Created annotation: {annotation.annotation_id} ({direction}, {profit_loss_pct:.2f}%)")
        logger.info(f"  Entry state: {len(entry_market_state or {})} timeframes")
        logger.info(f"  Exit state: {len(exit_market_state or {})} timeframes")
        return annotation

    def save_annotation(self, annotation: TradeAnnotation,
                        market_snapshots: Dict = None,
                        model_predictions: List[Dict] = None):
"""
|
||||
Save annotation to storage (JSON + SQLite)
|
||||
|
||||
Args:
|
||||
annotation: TradeAnnotation object
|
||||
market_snapshots: Dict of {timeframe: DataFrame} with OHLCV data
|
||||
model_predictions: List of model predictions at annotation time
|
||||
"""
|
||||
        # Convert to dict
        ann_dict = asdict(annotation)

        # Add to JSON database (legacy)
        self.annotations_db["annotations"].append(ann_dict)

        # Save to JSON file
        self._save_annotations()

        # Save to DuckDB with complete market data
        if self.duckdb_storage and market_snapshots:
            try:
                self.duckdb_storage.store_annotation(
                    annotation_id=annotation.annotation_id,
                    annotation_data=ann_dict,
                    market_snapshots=market_snapshots,
                    model_predictions=model_predictions
                )
                logger.info(f"Saved annotation {annotation.annotation_id} to DuckDB with {len(market_snapshots)} timeframes")
            except Exception as e:
                logger.error(f"Could not save annotation to DuckDB: {e}")

        logger.info(f"Saved annotation: {annotation.annotation_id}")

    def get_annotations(self, symbol: str = None,
                        timeframe: str = None) -> List[TradeAnnotation]:
        """Retrieve annotations with optional filtering"""
        annotations = self.annotations_db.get("annotations", [])

        # Filter by symbol
        if symbol:
            annotations = [a for a in annotations if a.get('symbol') == symbol]

        # Filter by timeframe
        if timeframe:
            annotations = [a for a in annotations if a.get('timeframe') == timeframe]

        # Convert to TradeAnnotation objects
        result = []
        for ann_dict in annotations:
            try:
                annotation = TradeAnnotation(**ann_dict)
                result.append(annotation)
            except Exception as e:
                logger.error(f"Error converting annotation: {e}")

        return result

    def delete_annotation(self, annotation_id: str) -> bool:
        """
        Delete annotation and its associated test case file

        Args:
            annotation_id: ID of annotation to delete

        Returns:
            bool: True if annotation was deleted, False if not found

        Raises:
            Exception: If there's an error during deletion
        """
        original_count = len(self.annotations_db["annotations"])
        self.annotations_db["annotations"] = [
            a for a in self.annotations_db["annotations"]
            if a.get('annotation_id') != annotation_id
        ]

        if len(self.annotations_db["annotations"]) < original_count:
            # Annotation was found and removed
            self._save_annotations()

            # Also delete the associated test case file
            test_case_file = self.test_cases_dir / f"annotation_{annotation_id}.json"
            if test_case_file.exists():
                try:
                    test_case_file.unlink()
                    logger.info(f"Deleted test case file: {test_case_file}")
                except Exception as e:
                    logger.error(f"Error deleting test case file {test_case_file}: {e}")
                    # Don't fail the whole operation if test case deletion fails

            logger.info(f"Deleted annotation: {annotation_id}")
            return True
        else:
            logger.warning(f"Annotation not found: {annotation_id}")
            return False

    def clear_all_annotations(self, symbol: str = None):
        """
        Clear all annotations (optionally filtered by symbol)
        More efficient than deleting one by one

        Args:
            symbol: Optional symbol filter. If None, clears all annotations.

        Returns:
            int: Number of annotations deleted
        """
        # Get annotations to delete
        if symbol:
            annotations_to_delete = [
                a for a in self.annotations_db["annotations"]
                if a.get('symbol') == symbol
            ]
            # Keep annotations for other symbols
            self.annotations_db["annotations"] = [
                a for a in self.annotations_db["annotations"]
                if a.get('symbol') != symbol
            ]
        else:
            annotations_to_delete = self.annotations_db["annotations"].copy()
            self.annotations_db["annotations"] = []

        deleted_count = len(annotations_to_delete)

        if deleted_count > 0:
            # Save updated annotations database
            self._save_annotations()

            # Delete associated test case files
            for annotation in annotations_to_delete:
                annotation_id = annotation.get('annotation_id')
                test_case_file = self.test_cases_dir / f"annotation_{annotation_id}.json"
                if test_case_file.exists():
                    try:
                        test_case_file.unlink()
                        logger.debug(f"Deleted test case file: {test_case_file}")
                    except Exception as e:
                        logger.error(f"Error deleting test case file {test_case_file}: {e}")

            logger.info(f"Cleared {deleted_count} annotations" + (f" for symbol {symbol}" if symbol else ""))

        return deleted_count
    def generate_test_case(self, annotation: TradeAnnotation, data_provider=None, auto_save: bool = True) -> Dict:
        """
        Generate lightweight test case metadata (no OHLCV data stored)
        OHLCV data will be fetched dynamically from cache/database during training

        Args:
            annotation: TradeAnnotation object
            data_provider: Optional DataProvider instance (not used for storage)

        Returns:
            Test case metadata dictionary
        """
        test_case = {
            "test_case_id": f"annotation_{annotation.annotation_id}",
            "symbol": annotation.symbol,
            "timestamp": annotation.entry['timestamp'],
            "action": "BUY" if annotation.direction == "LONG" else "SELL",
            "expected_outcome": {
                "direction": annotation.direction,
                "profit_loss_pct": annotation.profit_loss_pct,
                "holding_period_seconds": self._calculate_holding_period(annotation),
                "exit_price": annotation.exit['price'],
                "entry_price": annotation.entry['price']
            },
            "annotation_metadata": {
                "annotator": "manual",
                "confidence": 1.0,
                "notes": annotation.notes,
                "created_at": annotation.created_at,
                "timeframe": annotation.timeframe
            },
            "training_config": {
                "context_window_minutes": 5,  # ±5 minutes around entry/exit
                "timeframes": ["1s", "1m", "1h", "1d"],
                "data_source": "cache"  # Will fetch from cache/database
            }
        }

        # Save lightweight test case metadata to file if auto_save is True
        if auto_save:
            test_case_file = self.test_cases_dir / f"{test_case['test_case_id']}.json"
            with open(test_case_file, 'w') as f:
                json.dump(test_case, f, indent=2)
            logger.info(f"Saved test case metadata to: {test_case_file}")

        logger.info(f"Generated lightweight test case: {test_case['test_case_id']} (OHLCV data will be fetched dynamically)")
        return test_case

    def get_all_test_cases(self, symbol: Optional[str] = None) -> List[Dict]:
        """
        Load all test cases from disk

        Args:
            symbol: Optional symbol filter (e.g., 'ETH/USDT'). If provided, only returns
                    test cases for that symbol. Critical for avoiding cross-symbol training.

        Returns:
            List of test case dictionaries
        """
        test_cases = []

        if not self.test_cases_dir.exists():
            return test_cases

        for test_case_file in self.test_cases_dir.glob("annotation_*.json"):
            try:
                with open(test_case_file, 'r') as f:
                    test_case = json.load(f)

                # CRITICAL: Filter by symbol to avoid training on wrong symbol
                if symbol:
                    test_case_symbol = test_case.get('symbol', '')
                    if test_case_symbol != symbol:
                        logger.debug(f"Skipping {test_case_file.name}: symbol {test_case_symbol} != {symbol}")
                        continue

                test_cases.append(test_case)
            except Exception as e:
                logger.error(f"Error loading test case {test_case_file}: {e}")

        if symbol:
            logger.info(f"Loaded {len(test_cases)} test cases for symbol {symbol}")
        else:
            logger.info(f"Loaded {len(test_cases)} test cases (all symbols)")
        return test_cases

    def _calculate_holding_period(self, annotation: TradeAnnotation) -> float:
        """Calculate holding period in seconds"""
        try:
            entry_time = datetime.fromisoformat(annotation.entry['timestamp'].replace('Z', '+00:00'))
            exit_time = datetime.fromisoformat(annotation.exit['timestamp'].replace('Z', '+00:00'))
            return (exit_time - entry_time).total_seconds()
        except Exception as e:
            logger.error(f"Error calculating holding period: {e}")
            return 0.0

    def _generate_training_labels(self, market_state: Dict, entry_time: datetime,
                                  exit_time: datetime, direction: str) -> Dict:
        """
        Generate training labels for each timestamp in the market data.
        This helps the model learn WHERE to signal and WHERE NOT to signal.

        Labels:
        - 0 = NO SIGNAL (before entry or after exit)
        - 1 = ENTRY SIGNAL (at entry time)
        - 2 = HOLD (between entry and exit)
        - 3 = EXIT SIGNAL (at exit time)
        """
        labels = {}

        # Use 1m timeframe as reference for labeling
        if 'ohlcv_1m' in market_state and 'timestamps' in market_state['ohlcv_1m']:
            timestamps = market_state['ohlcv_1m']['timestamps']

            label_list = []
            for ts_str in timestamps:
                try:
                    ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
                    # Make timezone-aware to match entry_time
                    if ts.tzinfo is None:
                        ts = pytz.UTC.localize(ts)

                    # Determine label based on position relative to entry/exit
                    if abs((ts - entry_time).total_seconds()) < 60:  # Within 1 minute of entry
                        label = 1  # ENTRY SIGNAL
                    elif abs((ts - exit_time).total_seconds()) < 60:  # Within 1 minute of exit
                        label = 3  # EXIT SIGNAL
                    elif entry_time < ts < exit_time:  # Between entry and exit
                        label = 2  # HOLD
                    else:  # Before entry or after exit
                        label = 0  # NO SIGNAL

                    label_list.append(label)

                except Exception as e:
                    logger.error(f"Error parsing timestamp {ts_str}: {e}")
                    label_list.append(0)

            labels['labels_1m'] = label_list
            labels['direction'] = direction
            labels['entry_timestamp'] = entry_time.strftime('%Y-%m-%d %H:%M:%S')
            labels['exit_timestamp'] = exit_time.strftime('%Y-%m-%d %H:%M:%S')

            logger.info(f"Generated {len(label_list)} training labels: "
                        f"{label_list.count(0)} NO_SIGNAL, "
                        f"{label_list.count(1)} ENTRY, "
                        f"{label_list.count(2)} HOLD, "
                        f"{label_list.count(3)} EXIT")

        return labels

    def export_annotations(self, annotations: List[TradeAnnotation] = None,
                           format_type: str = 'json') -> Path:
        """Export annotations to file"""
        if annotations is None:
            annotations = self.get_annotations()

        # Convert to dicts
        export_data = [asdict(ann) for ann in annotations]

        # Create export file
        timestamp = datetime.now(pytz.UTC).strftime('%Y%m%d_%H%M%S')
        export_file = self.storage_path / f"export_{timestamp}.{format_type}"

        if format_type == 'json':
            with open(export_file, 'w') as f:
                json.dump(export_data, f, indent=2)
        elif format_type == 'csv':
            import csv
            with open(export_file, 'w', newline='') as f:
                if export_data:
                    writer = csv.DictWriter(f, fieldnames=export_data[0].keys())
                    writer.writeheader()
                    writer.writerows(export_data)

        logger.info(f"Exported {len(annotations)} annotations to {export_file}")
        return export_file
@@ -1,737 +0,0 @@
|
||||
"""
|
||||
Historical Data Loader - Integrates with existing DataProvider
|
||||
|
||||
Provides data loading and caching for the annotation UI, ensuring the same
|
||||
data quality and structure used by training and inference systems.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HistoricalDataLoader:
|
||||
"""
|
||||
Loads historical data from the main system's DataProvider
|
||||
Ensures consistency with training/inference data
|
||||
"""
|
||||
|
||||
def __init__(self, data_provider):
|
||||
"""
|
||||
Initialize with existing DataProvider
|
||||
|
||||
Args:
|
||||
data_provider: Instance of core.data_provider.DataProvider
|
||||
"""
|
||||
self.data_provider = data_provider
|
||||
self.cache_dir = Path("ANNOTATE/data/cache")
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Cache for recently loaded data
|
||||
self.memory_cache = {}
|
||||
self.cache_ttl = timedelta(minutes=5)
|
||||
|
||||
# Startup mode - allow stale cache for faster loading
|
||||
self.startup_mode = True
|
||||
|
||||
logger.info("HistoricalDataLoader initialized with existing DataProvider (startup mode: ON)")
|
||||
|
||||
def get_data(self, symbol: str, timeframe: str,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
limit: int = 2500,
|
||||
direction: str = 'latest') -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
Get historical data for symbol and timeframe
|
||||
|
||||
Args:
|
||||
symbol: Trading pair (e.g., 'ETH/USDT')
|
||||
timeframe: Timeframe (e.g., '1s', '1m', '1h', '1d')
|
||||
start_time: Start time for data range
|
||||
end_time: End time for data range
|
||||
limit: Maximum number of candles to return
|
||||
direction: 'latest' (most recent), 'before' (older data), 'after' (newer data)
|
||||
|
||||
Returns:
|
||||
DataFrame with OHLCV data or None if unavailable
|
||||
"""
|
||||
start_time_ms = time.time()
|
||||
|
||||
# Check memory cache first (exclude direction from cache key for infinite scroll)
|
||||
cache_key = f"{symbol}_{timeframe}_{start_time}_{end_time}_{limit}"
|
||||
|
||||
# Determine TTL based on timeframe
|
||||
current_ttl = self.cache_ttl
|
||||
if timeframe == '1s':
|
||||
current_ttl = timedelta(seconds=1)
|
||||
elif timeframe == '1m':
|
||||
current_ttl = timedelta(seconds=5)
|
||||
|
||||
# For 'after' direction (incremental updates), we should force a refresh if cache is stale
|
||||
# or simply bypass cache for 1s/1m to ensure we get the absolute latest
|
||||
bypass_cache = (direction == 'after' and timeframe in ['1s', '1m'])
|
||||
|
||||
if cache_key in self.memory_cache and direction == 'latest' and not bypass_cache:
|
||||
cached_data, cached_time = self.memory_cache[cache_key]
|
||||
if datetime.now() - cached_time < current_ttl:
|
||||
# For 1s/1m, we want to return immediately if valid
|
||||
if timeframe not in ['1s', '1m']:
|
||||
elapsed_ms = (time.time() - start_time_ms) * 1000
|
||||
logger.debug(f"Memory cache hit for {symbol} {timeframe} ({elapsed_ms:.1f}ms)")
|
||||
return cached_data
|
||||
|
||||
try:
|
||||
# FORCE refresh for 1s/1m if requesting latest data OR incremental update
|
||||
# Also force refresh for live updates (small limit + direction='latest' + no time range)
|
||||
is_live_update = (direction == 'latest' and not start_time and not end_time and limit <= 5)
|
||||
force_refresh = (timeframe in ['1s', '1m'] and (bypass_cache or (not start_time and not end_time))) or is_live_update
|
||||
|
||||
if is_live_update:
|
||||
logger.debug(f"Live update detected for {symbol} {timeframe} (limit={limit}, direction={direction}) - forcing refresh")
|
||||
|
||||
# Try to get data from DataProvider's cached data first (most efficient)
|
||||
if hasattr(self.data_provider, 'cached_data'):
|
||||
with self.data_provider.data_lock:
|
||||
cached_df = self.data_provider.cached_data.get(symbol, {}).get(timeframe)
|
||||
|
||||
if cached_df is not None and not cached_df.empty:
|
||||
# If time range is specified, check if cached data covers it
|
||||
use_cached_data = True
|
||||
if start_time or end_time:
|
||||
if isinstance(cached_df.index, pd.DatetimeIndex):
|
||||
cache_start = cached_df.index.min()
|
||||
cache_end = cached_df.index.max()
|
||||
|
||||
# Check if requested range is within cached range
|
||||
if start_time and start_time < cache_start:
|
||||
use_cached_data = False
|
||||
elif end_time and end_time > cache_end:
|
||||
use_cached_data = False
|
||||
elif start_time and end_time:
|
||||
# Both specified - check if range overlaps
|
||||
if end_time < cache_start or start_time > cache_end:
|
||||
use_cached_data = False
|
||||
|
||||
# Use cached data if we have enough candles and it covers the range
|
||||
if use_cached_data and len(cached_df) >= min(limit, 100): # Use cached if we have at least 100 candles
|
||||
elapsed_ms = (time.time() - start_time_ms) * 1000
|
||||
logger.debug(f" DataProvider cache hit for {symbol} {timeframe} ({len(cached_df)} candles, {elapsed_ms:.1f}ms)")
|
||||
|
||||
# Filter by time range with direction support
|
||||
filtered_df = self._filter_by_time_range(
|
||||
cached_df.copy(),
|
||||
start_time,
|
||||
end_time,
|
||||
direction,
|
||||
limit
|
||||
)
|
||||
|
||||
# Only return cached data if filter produced results
|
||||
if filtered_df is not None and not filtered_df.empty:
|
||||
# Cache in memory
|
||||
self.memory_cache[cache_key] = (filtered_df, datetime.now())
|
||||
return filtered_df
|
||||
# If filter returned empty, fall through to fetch from DuckDB/API
|
||||
|
||||
# Try unified storage first if available
|
||||
if hasattr(self.data_provider, 'is_unified_storage_enabled') and \
|
||||
self.data_provider.is_unified_storage_enabled():
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
# Get data from unified storage
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# If we have a specific time range, get historical data
|
||||
if start_time or end_time:
|
||||
target_time = end_time if end_time else start_time
|
||||
inference_data = loop.run_until_complete(
|
||||
self.data_provider.get_inference_data_unified(
|
||||
symbol,
|
||||
timestamp=target_time,
|
||||
context_window_minutes=60
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Get latest real-time data
|
||||
inference_data = loop.run_until_complete(
|
||||
self.data_provider.get_inference_data_unified(symbol)
|
||||
)
|
||||
|
||||
# Extract the requested timeframe
|
||||
df = inference_data.get_timeframe_data(timeframe)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
# Limit number of candles
|
||||
if len(df) > limit:
|
||||
df = df.tail(limit)
|
||||
|
||||
# Cache in memory
|
||||
self.memory_cache[cache_key] = (df.copy(), datetime.now())
|
||||
|
||||
logger.info(f"Loaded {len(df)} candles from unified storage for {symbol} {timeframe}")
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Unified storage not available, falling back to cached data: {e}")
|
||||
|
||||
# Fallback to existing cached data method (duplicate check - should not reach here if first check worked)
|
||||
# This is kept for backward compatibility but should rarely execute
|
||||
if hasattr(self.data_provider, 'cached_data'):
|
||||
if symbol in self.data_provider.cached_data:
|
||||
if timeframe in self.data_provider.cached_data[symbol]:
|
||||
df = self.data_provider.cached_data[symbol][timeframe]
|
||||
|
||||
if df is not None and not df.empty:
|
||||
# Check if cached data covers the requested time range
|
||||
use_cached_data = True
|
||||
if start_time or end_time:
|
||||
if isinstance(df.index, pd.DatetimeIndex):
|
||||
cache_start = df.index.min()
|
||||
cache_end = df.index.max()
|
||||
|
||||
if start_time and start_time < cache_start:
|
||||
use_cached_data = False
|
||||
elif end_time and end_time > cache_end:
|
||||
use_cached_data = False
|
||||
elif start_time and end_time:
|
||||
if end_time < cache_start or start_time > cache_end:
|
||||
use_cached_data = False
|
||||
|
||||
if use_cached_data:
|
||||
# Filter by time range with direction support
|
||||
df = self._filter_by_time_range(
|
||||
df.copy(),
|
||||
start_time,
|
||||
end_time,
|
||||
direction,
|
||||
limit
|
||||
)
|
||||
|
||||
# Only return if filter produced results
|
||||
if df is not None and not df.empty:
|
||||
# Cache in memory
|
||||
self.memory_cache[cache_key] = (df.copy(), datetime.now())
|
||||
|
||||
logger.info(f"Loaded {len(df)} candles for {symbol} {timeframe}")
|
||||
return df
|
||||
# If filter returned empty or range not covered, fall through to fetch from DuckDB/API
|
||||
|
||||
# Check DuckDB first for historical data (always check for infinite scroll)
|
||||
if self.data_provider.duckdb_storage and (start_time or end_time):
|
||||
logger.info(f"Checking DuckDB for {symbol} {timeframe} historical data (direction={direction})")
|
||||
df = self.data_provider.duckdb_storage.get_ohlcv_data(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
direction=direction
|
||||
)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
elapsed_ms = (time.time() - start_time_ms) * 1000
|
||||
logger.info(f" DuckDB hit for {symbol} {timeframe} ({len(df)} candles, {elapsed_ms:.1f}ms)")
|
||||
# Cache in memory
|
||||
self.memory_cache[cache_key] = (df.copy(), datetime.now())
|
||||
return df
|
||||
else:
|
||||
logger.info(f"No data in DuckDB, fetching from exchange API for {symbol} {timeframe}")
|
||||
|
||||
# Fetch from exchange API with time range
|
||||
df = self._fetch_from_exchange_api(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
direction=direction
|
||||
)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
elapsed_ms = (time.time() - start_time_ms) * 1000
|
||||
logger.info(f"Exchange API hit for {symbol} {timeframe} ({len(df)} candles, {elapsed_ms:.1f}ms)")
|
||||
|
||||
# Store in DuckDB for future use
|
||||
if self.data_provider.duckdb_storage:
|
||||
stored_count = self.data_provider.duckdb_storage.store_ohlcv_data(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
df=df
|
||||
)
|
||||
logger.info(f"Stored {stored_count} new candles in DuckDB")
|
||||
|
||||
# Cache in memory
|
||||
self.memory_cache[cache_key] = (df.copy(), datetime.now())
|
||||
return df
|
||||
else:
|
||||
logger.warning(f"No data available from exchange API for {symbol} {timeframe}")
|
||||
return None
|
||||
|
||||
# Fallback: Use DataProvider for latest data (startup mode or no time range)
|
||||
if self.startup_mode and not (start_time or end_time) and not force_refresh:
|
||||
logger.info(f"Loading data for {symbol} {timeframe} (startup mode: allow stale cache)")
|
||||
df = self.data_provider.get_historical_data(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
limit=limit,
|
||||
allow_stale_cache=True
|
||||
)
|
||||
elif is_live_update:
|
||||
# For live updates, use get_latest_candles which combines cached + real-time data
|
||||
logger.debug(f"Getting live candles (cached + real-time) for {symbol} {timeframe}")
|
||||
df = self.data_provider.get_latest_candles(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# Log the latest candle timestamp to help debug stale data
|
||||
if df is not None and not df.empty:
|
||||
latest_timestamp = df.index[-1] if hasattr(df.index, '__getitem__') else df.iloc[-1].name
|
||||
logger.debug(f"Live update for {symbol} {timeframe}: latest candle at {latest_timestamp}")
|
||||
else:
|
||||
# Fetch from API and store in DuckDB (no time range specified)
|
||||
# For 1s/1m, logging every request is too verbose, use debug
|
||||
if timeframe in ['1s', '1m']:
|
||||
logger.debug(f"Fetching latest data from API for {symbol} {timeframe}")
|
||||
else:
|
||||
logger.info(f"Fetching latest data from API for {symbol} {timeframe}")
|
||||
|
||||
df = self.data_provider.get_historical_data(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
limit=limit,
|
||||
refresh=True # Force API fetch
|
||||
)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
# Filter by time range with direction support
|
||||
df = self._filter_by_time_range(
|
||||
df.copy(),
|
||||
start_time,
|
||||
end_time,
|
||||
direction,
|
||||
limit
|
||||
)
|
||||
|
||||
# Cache in memory
|
||||
self.memory_cache[cache_key] = (df.copy(), datetime.now())
|
||||
|
||||
logger.info(f"Fetched {len(df)} candles for {symbol} {timeframe}")
|
||||
return df
|
||||
|
||||
logger.warning(f"No data available for {symbol} {timeframe}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading data for {symbol} {timeframe}: {e}")
|
||||
return None
|
||||
|
||||
def _fetch_from_exchange_api(self, symbol: str, timeframe: str,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
limit: int = 1000,
|
||||
direction: str = 'latest') -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
Fetch historical data from exchange API (Binance/MEXC) with time range support
|
||||
|
||||
Args:
|
||||
symbol: Trading pair
|
||||
timeframe: Timeframe
|
||||
start_time: Start time for data range
|
||||
end_time: End time for data range
|
||||
limit: Maximum number of candles
|
||||
direction: 'latest', 'before', or 'after'
|
||||
|
||||
Returns:
|
||||
DataFrame with OHLCV data or None
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
from core.api_rate_limiter import get_rate_limiter
|
||||
|
||||
# Convert symbol format for Binance
|
||||
binance_symbol = symbol.replace('/', '').upper()
|
||||
|
||||
# Convert timeframe
|
||||
timeframe_map = {
|
||||
'1s': '1s', '1m': '1m', '5m': '5m', '15m': '15m', '30m': '30m',
|
||||
'1h': '1h', '4h': '4h', '1d': '1d'
|
||||
}
|
||||
binance_timeframe = timeframe_map.get(timeframe, '1m')
|
||||
|
||||
# Build initial API parameters
|
||||
params = {
|
||||
'symbol': binance_symbol,
|
||||
'interval': binance_timeframe
|
||||
}
|
||||
|
||||
# Add time range parameters if specified
|
||||
if direction == 'before' and end_time:
|
||||
params['endTime'] = int(end_time.timestamp() * 1000)
|
||||
elif direction == 'after' and start_time:
|
||||
params['startTime'] = int(start_time.timestamp() * 1000)
|
||||
elif start_time:
|
||||
params['startTime'] = int(start_time.timestamp() * 1000)
|
||||
if end_time and direction != 'before':
|
||||
params['endTime'] = int(end_time.timestamp() * 1000)
|
||||
|
||||
# Use rate limiter
|
||||
rate_limiter = get_rate_limiter()
|
||||
url = "https://api.binance.com/api/v3/klines"
|
||||
|
||||
logger.info(f"Fetching from Binance: {symbol} {timeframe} (direction={direction}, limit={limit})")
|
||||
|
||||
# Pagination variables
|
||||
all_dfs = []
|
||||
total_fetched = 0
|
||||
is_fetching_forward = (direction == 'after')
|
||||
|
||||
# Fetch loop
|
||||
while total_fetched < limit:
|
||||
# Calculate batch limit (max 1000 per request)
|
||||
batch_limit = min(limit - total_fetched, 1000)
|
||||
params['limit'] = batch_limit
|
||||
|
||||
response = rate_limiter.make_request('binance_api', url, 'GET', params=params)
|
||||
|
||||
if response is None or response.status_code != 200:
|
||||
if total_fetched == 0:
|
||||
logger.warning(f"Binance API failed, trying MEXC...")
|
||||
return self._fetch_from_mexc_with_time_range(
|
||||
symbol, timeframe, start_time, end_time, limit, direction
|
||||
)
|
||||
else:
|
||||
logger.warning("Binance API failed during pagination, returning partial data")
|
||||
break
|
||||
|
||||
data = response.json()
|
||||
|
||||
if not data:
|
||||
if total_fetched == 0:
|
||||
logger.warning(f"No data returned from Binance for {symbol} {timeframe}")
|
||||
return None
|
||||
else:
|
||||
break
|
||||
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(data, columns=[
|
||||
'timestamp', 'open', 'high', 'low', 'close', 'volume',
|
||||
'close_time', 'quote_volume', 'trades', 'taker_buy_base',
|
||||
'taker_buy_quote', 'ignore'
|
||||
])
|
||||
|
||||
# Process columns
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
|
||||
for col in ['open', 'high', 'low', 'close', 'volume']:
|
||||
df[col] = df[col].astype(float)
|
||||
|
||||
# Keep only OHLCV columns
|
||||
df = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
|
||||
df = df.set_index('timestamp')
|
||||
df = df.sort_index()
|
||||
|
||||
if df.empty:
|
||||
break
|
||||
|
||||
all_dfs.append(df)
|
||||
total_fetched += len(df)
|
||||
|
||||
# Prepare for next batch
|
||||
if total_fetched >= limit:
|
||||
break
|
||||
|
||||
# Update params for next iteration
|
||||
if is_fetching_forward:
|
||||
# Next batch starts after the last candle
|
||||
last_ts = df.index[-1]
|
||||
params['startTime'] = int(last_ts.value / 10**6) + 1
|
||||
# Check if we exceeded end_time
|
||||
if 'endTime' in params and params['startTime'] > params['endTime']:
|
||||
break
|
||||
else:
|
||||
# Next batch ends before the first candle
|
||||
first_ts = df.index[0]
|
||||
params['endTime'] = int(first_ts.value / 10**6) - 1
|
||||
# Check if we exceeded start_time
|
||||
if 'startTime' in params and params['endTime'] < params['startTime']:
|
||||
break
|
||||
|
||||
# Combine all batches
|
||||
if not all_dfs:
|
||||
return None
|
||||
|
||||
final_df = pd.concat(all_dfs)
|
||||
final_df = final_df.sort_index()
|
||||
final_df = final_df[~final_df.index.duplicated(keep='first')]
|
||||
|
||||
logger.info(f" Fetched {len(final_df)} candles from Binance for {symbol} {timeframe} (requested {limit})")
|
||||
return final_df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching from exchange API: {e}")
|
||||
return None
|
||||
|
||||
def _fetch_from_mexc_with_time_range(self, symbol: str, timeframe: str,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
limit: int = 1000,
|
||||
direction: str = 'latest') -> Optional[pd.DataFrame]:
|
||||
"""Fetch from MEXC with time range support (fallback)"""
|
||||
try:
|
||||
# MEXC implementation would go here
|
||||
# For now, just return None to indicate unavailable
|
||||
logger.warning("MEXC time range fetch not implemented yet")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching from MEXC: {e}")
|
||||
return None
|
||||
|
||||
def _filter_by_time_range(self, df: pd.DataFrame,
|
||||
start_time: Optional[datetime],
|
||||
end_time: Optional[datetime],
|
||||
direction: str = 'latest',
|
||||
limit: int = 500) -> pd.DataFrame:
|
||||
"""
|
||||
Filter DataFrame by time range with direction support
|
||||
|
||||
Args:
|
||||
df: DataFrame to filter
|
||||
start_time: Start time filter
|
||||
end_time: End time filter
|
||||
direction: 'latest', 'before', or 'after'
|
||||
limit: Maximum number of candles
|
||||
|
||||
Returns:
|
||||
Filtered DataFrame
|
||||
"""
|
||||
try:
|
||||
# Ensure df index is datetime and timezone-aware (UTC)
|
||||
if not isinstance(df.index, pd.DatetimeIndex):
|
||||
df.index = pd.to_datetime(df.index, utc=True)
|
||||
elif df.index.tz is None:
|
||||
df.index = df.index.tz_localize('UTC')
|
||||
else:
|
||||
# If already aware but not UTC, convert
|
||||
if str(df.index.tz) != 'UTC' and str(df.index.tz) != 'datetime.timezone.utc':
|
||||
df.index = df.index.tz_convert('UTC')
|
||||
|
||||
# Ensure start_time/end_time are UTC
|
||||
if start_time and start_time.tzinfo is None:
|
||||
start_time = start_time.replace(tzinfo=timezone.utc)
|
||||
elif start_time:
|
||||
start_time = start_time.astimezone(timezone.utc)
|
||||
|
||||
if end_time and end_time.tzinfo is None:
|
||||
end_time = end_time.replace(tzinfo=timezone.utc)
|
||||
elif end_time:
|
||||
end_time = end_time.astimezone(timezone.utc)
|
||||
|
||||
if direction == 'before' and end_time:
|
||||
# Get candles BEFORE end_time
|
||||
df = df[df.index < end_time]
|
||||
# Return the most recent N candles before end_time
|
||||
df = df.tail(limit)
|
||||
elif direction == 'after' and start_time:
|
||||
# Get candles AFTER start_time
|
||||
df = df[df.index > start_time]
|
||||
# Return the oldest N candles after start_time
|
||||
df = df.head(limit)
|
||||
else:
|
||||
# Default: filter by range
|
||||
if start_time:
|
||||
df = df[df.index >= start_time]
|
||||
if end_time:
|
||||
df = df[df.index <= end_time]
|
||||
# Return most recent candles
|
||||
if len(df) > limit:
|
||||
df = df.tail(limit)
|
||||
|
||||
return df
|
||||
except Exception as e:
|
||||
logger.error(f"Error filtering data: {e}")
|
||||
# Fallback: return original or empty
|
||||
return df if not df.empty else pd.DataFrame()
|
||||
|
||||
def get_multi_timeframe_data(self, symbol: str,
|
||||
timeframes: List[str],
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
limit: int = 2500) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
Get data for multiple timeframes at once
|
||||
|
||||
Args:
|
||||
symbol: Trading pair
|
||||
timeframes: List of timeframes
|
||||
start_time: Start time for data range
|
||||
end_time: End time for data range
|
||||
limit: Maximum number of candles per timeframe
|
||||
|
||||
Returns:
|
||||
Dictionary mapping timeframe to DataFrame
|
||||
"""
|
||||
result = {}
|
||||
|
||||
for timeframe in timeframes:
|
||||
df = self.get_data(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
if df is not None:
|
||||
result[timeframe] = df
|
||||
|
||||
logger.info(f"Loaded data for {len(result)}/{len(timeframes)} timeframes")
|
||||
return result
|
||||
|
||||
def prefetch_data(self, symbol: str, timeframes: List[str], limit: int = 1000):
|
||||
"""
|
||||
Prefetch data for smooth scrolling
|
||||
|
||||
Args:
|
||||
symbol: Trading pair
|
||||
timeframes: List of timeframes to prefetch
|
||||
limit: Number of candles to prefetch
|
||||
"""
|
||||
logger.info(f"Prefetching data for {symbol}: {timeframes}")
|
||||
|
||||
for timeframe in timeframes:
|
||||
self.get_data(symbol, timeframe, limit=limit)
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear memory cache"""
|
||||
self.memory_cache.clear()
|
||||
logger.info("Memory cache cleared")
|
||||
|
||||
def disable_startup_mode(self):
|
||||
"""Disable startup mode to fetch fresh data"""
|
||||
self.startup_mode = False
|
||||
logger.info("Startup mode disabled - will fetch fresh data on next request")
|
||||
|
||||
def get_data_boundaries(self, symbol: str, timeframe: str) -> Tuple[Optional[datetime], Optional[datetime]]:
|
||||
"""
|
||||
Get the earliest and latest available data timestamps
|
||||
|
||||
Args:
|
||||
symbol: Trading pair
|
||||
timeframe: Timeframe
|
||||
|
||||
Returns:
|
||||
Tuple of (earliest_time, latest_time) or (None, None) if no data
|
||||
"""
|
||||
try:
|
||||
df = self.get_data(symbol, timeframe, limit=10000)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
return (df.index.min(), df.index.max())
|
||||
|
||||
return (None, None)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting data boundaries: {e}")
|
||||
return (None, None)
|
||||
|
||||
|
||||
class TimeRangeManager:
|
||||
"""Manages time range calculations and data prefetching"""
|
||||
|
||||
def __init__(self, data_loader: HistoricalDataLoader):
|
||||
"""
|
||||
Initialize with data loader
|
||||
|
||||
Args:
|
||||
data_loader: HistoricalDataLoader instance
|
||||
"""
|
||||
self.data_loader = data_loader
|
||||
|
||||
# Time range presets in seconds
|
||||
self.range_presets = {
|
||||
'1h': 3600,
|
||||
'4h': 14400,
|
||||
'1d': 86400,
|
||||
'1w': 604800,
|
||||
'1M': 2592000
|
||||
}
|
||||
|
||||
logger.info("TimeRangeManager initialized")
|
||||
|
||||
def calculate_time_range(self, center_time: datetime,
|
||||
range_preset: str) -> Tuple[datetime, datetime]:
|
||||
"""
|
||||
Calculate start and end times for a range preset
|
||||
|
||||
Args:
|
||||
center_time: Center point of the range
|
||||
range_preset: Range preset ('1h', '4h', '1d', '1w', '1M')
|
||||
|
||||
Returns:
|
||||
Tuple of (start_time, end_time)
|
||||
"""
|
||||
range_seconds = self.range_presets.get(range_preset, 86400)
|
||||
half_range = timedelta(seconds=range_seconds / 2)
|
||||
|
||||
start_time = center_time - half_range
|
||||
end_time = center_time + half_range
|
||||
|
||||
return (start_time, end_time)
|
||||
|
||||
def get_navigation_increment(self, range_preset: str) -> timedelta:
|
||||
"""
|
||||
Get time increment for navigation (10% of range)
|
||||
|
||||
Args:
|
||||
range_preset: Range preset
|
||||
|
||||
Returns:
|
||||
timedelta for navigation increment
|
||||
"""
|
||||
range_seconds = self.range_presets.get(range_preset, 86400)
|
||||
increment_seconds = range_seconds / 10
|
||||
|
||||
return timedelta(seconds=increment_seconds)
|
||||
|
||||
def prefetch_adjacent_ranges(self, symbol: str, timeframes: List[str],
|
||||
center_time: datetime, range_preset: str):
|
||||
"""
|
||||
Prefetch data for adjacent time ranges for smooth scrolling
|
||||
|
||||
Args:
|
||||
symbol: Trading pair
|
||||
timeframes: List of timeframes
|
||||
center_time: Current center time
|
||||
range_preset: Current range preset
|
||||
"""
|
||||
increment = self.get_navigation_increment(range_preset)
|
||||
|
||||
# Prefetch previous range
|
||||
prev_center = center_time - increment
|
||||
prev_start, prev_end = self.calculate_time_range(prev_center, range_preset)
|
||||
|
||||
# Prefetch next range
|
||||
next_center = center_time + increment
|
||||
next_start, next_end = self.calculate_time_range(next_center, range_preset)
|
||||
|
||||
logger.debug(f"Prefetching adjacent ranges for {symbol}")
|
||||
|
||||
# Prefetch in background (non-blocking)
|
||||
import threading
|
||||
|
||||
def prefetch():
|
||||
for timeframe in timeframes:
|
||||
self.data_loader.get_data(symbol, timeframe, prev_start, prev_end)
|
||||
self.data_loader.get_data(symbol, timeframe, next_start, next_end)
|
||||
|
||||
thread = threading.Thread(target=prefetch, daemon=True)
|
||||
thread.start()
|
||||
@@ -1,389 +0,0 @@
|
||||
"""
|
||||
Event-Driven Inference Training System
|
||||
|
||||
This system provides:
|
||||
1. Reference-based inference frame storage (no 600-candle copies)
|
||||
2. Subscription system for candle completion and pivot events
|
||||
3. Flexible training methods (backprop for Transformer, others for different models)
|
||||
4. Integration with DuckDB for efficient data retrieval
|
||||
|
||||
Architecture:
|
||||
- Inference frames stored as references (timestamp ranges) in DuckDB
|
||||
- Training adapter subscribes to data provider events
|
||||
- Time-based triggers: candle completion (known result time)
|
||||
- Event-based triggers: pivot points (L2L, L2H, etc. - unknown timing)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Dict, List, Optional, Callable, Tuple, Any
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrainingTriggerType(Enum):
|
||||
"""Types of training triggers"""
|
||||
CANDLE_COMPLETION = "candle_completion" # Time-based: next candle closes
|
||||
PIVOT_EVENT = "pivot_event" # Event-based: pivot detected (L2L, L2H, etc.)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceFrameReference:
|
||||
"""
|
||||
Reference to inference data stored in DuckDB with human-readable prediction outputs.
|
||||
No copying - just store timestamp ranges and query when needed.
|
||||
"""
|
||||
inference_id: str # Unique ID for this inference
|
||||
symbol: str
|
||||
timeframe: str
|
||||
prediction_timestamp: datetime # When prediction was made
|
||||
target_timestamp: Optional[datetime] = None # When result will be available (for candles)
|
||||
|
||||
# Reference to data in DuckDB (timestamp range)
|
||||
data_range_start: datetime # Start of 600-candle window
|
||||
data_range_end: datetime # End of 600-candle window
|
||||
|
||||
# Normalization parameters (small, can be stored)
|
||||
norm_params: Dict[str, Dict[str, float]] = field(default_factory=dict)
|
||||
|
||||
# ENHANCED: Human-readable prediction outputs
|
||||
predicted_action: Optional[str] = None # 'BUY', 'SELL', 'HOLD'
|
||||
predicted_candle: Optional[Dict[str, List[float]]] = None # {timeframe: [O,H,L,C,V]}
|
||||
predicted_price: Optional[float] = None # Main predicted price
|
||||
confidence: float = 0.0
|
||||
|
||||
# Model metadata for decision making
|
||||
model_type: str = 'transformer' # 'transformer', 'cnn', 'dqn'
|
||||
prediction_steps: int = 1 # Number of steps predicted ahead
|
||||
|
||||
# Training status
|
||||
trained: bool = False
|
||||
training_timestamp: Optional[datetime] = None
|
||||
training_loss: Optional[float] = None
|
||||
training_accuracy: Optional[float] = None
|
||||
|
||||
# Actual results (filled when candle completes)
|
||||
actual_candle: Optional[List[float]] = None # [O,H,L,C,V]
|
||||
actual_price: Optional[float] = None
|
||||
prediction_error: Optional[float] = None # |predicted - actual|
|
||||
direction_correct: Optional[bool] = None # Did we predict direction correctly?
|
||||
|
||||
|
||||
@dataclass
|
||||
class PivotEvent:
|
||||
"""Pivot point event for training"""
|
||||
symbol: str
|
||||
timeframe: str
|
||||
timestamp: datetime
|
||||
pivot_type: str # 'L2L', 'L2H', 'L3L', 'L3H', etc.
|
||||
price: float
|
||||
level: int # Pivot level (2, 3, 4, etc.)
|
||||
strength: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class CandleCompletionEvent:
|
||||
"""Candle completion event for training"""
|
||||
symbol: str
|
||||
timeframe: str
|
||||
timestamp: datetime # When candle closed
|
||||
ohlcv: Dict[str, float] # {'open', 'high', 'low', 'close', 'volume'}
|
||||
|
||||
|
||||
class TrainingEventSubscriber:
|
||||
"""
|
||||
Subscriber interface for training events.
|
||||
Training adapters implement this to receive callbacks.
|
||||
"""
|
||||
|
||||
def on_candle_completion(self, event: CandleCompletionEvent, inference_ref: Optional[InferenceFrameReference]) -> None:
|
||||
"""
|
||||
Called when a candle completes.
|
||||
|
||||
Args:
|
||||
event: Candle completion event with actual OHLCV
|
||||
inference_ref: Reference to inference frame if available (for this candle)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def on_pivot_event(self, event: PivotEvent, inference_refs: List[InferenceFrameReference]) -> None:
|
||||
"""
|
||||
Called when a pivot point is detected.
|
||||
|
||||
Args:
|
||||
event: Pivot event (L2L, L2H, etc.)
|
||||
inference_refs: List of inference frames that predicted this pivot
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class InferenceTrainingCoordinator:
|
||||
"""
|
||||
Coordinates inference frame storage and training event distribution.
|
||||
|
||||
NOTE: This should be integrated into TradingOrchestrator to reduce duplication.
|
||||
The orchestrator already manages models, training, and predictions, so it's the
|
||||
natural place for inference-training coordination.
|
||||
|
||||
Responsibilities:
|
||||
1. Store inference frame references (not copies)
|
||||
2. Register training subscriptions (candle/pivot events)
|
||||
3. Match inference frames to actual results
|
||||
4. Trigger training callbacks
|
||||
"""
|
||||
|
||||
def __init__(self, data_provider, duckdb_storage=None):
|
||||
"""
|
||||
Initialize coordinator
|
||||
|
||||
Args:
|
||||
data_provider: DataProvider instance for event subscriptions
|
||||
duckdb_storage: DuckDBStorage instance for data retrieval
|
||||
"""
|
||||
self.data_provider = data_provider
|
||||
self.duckdb_storage = duckdb_storage
|
||||
|
||||
# Store inference frame references (by inference_id)
|
||||
self.inference_frames: Dict[str, InferenceFrameReference] = {}
|
||||
|
||||
# Index by target timestamp for candle matching
|
||||
self.candle_inferences: Dict[Tuple[str, str, datetime], List[str]] = {} # (symbol, timeframe, timestamp) -> [inference_ids]
|
||||
|
||||
# Index by pivot type for pivot matching
|
||||
self.pivot_subscriptions: Dict[Tuple[str, str, str], List[str]] = {} # (symbol, timeframe, pivot_type) -> [inference_ids]
|
||||
|
||||
# Training subscribers
|
||||
self.training_subscribers: List[TrainingEventSubscriber] = []
|
||||
|
||||
# Thread safety
|
||||
self.lock = threading.RLock()
|
||||
|
||||
logger.info("InferenceTrainingCoordinator initialized")
|
||||
|
||||
    def register_inference_frame(self, inference_ref: InferenceFrameReference) -> None:
        """
        Register an inference frame reference (stored in DuckDB, not copied).

        Args:
            inference_ref: Reference to inference data
        """
        with self.lock:
            self.inference_frames[inference_ref.inference_id] = inference_ref

            # Index by target timestamp for candle matching
            if inference_ref.target_timestamp:
                key = (inference_ref.symbol, inference_ref.timeframe, inference_ref.target_timestamp)
                if key not in self.candle_inferences:
                    self.candle_inferences[key] = []
                self.candle_inferences[key].append(inference_ref.inference_id)

            logger.debug(f"Registered inference frame: {inference_ref.inference_id} for {inference_ref.symbol} {inference_ref.timeframe}")

    def subscribe_to_candle_completion(self, subscriber: TrainingEventSubscriber,
                                       symbol: str, timeframe: str) -> None:
        """
        Subscribe to candle completion events for a symbol/timeframe.

        Args:
            subscriber: Training subscriber
            symbol: Trading symbol
            timeframe: Timeframe (1m, 5m, etc.)
        """
        with self.lock:
            if subscriber not in self.training_subscribers:
                self.training_subscribers.append(subscriber)

            # Register with data provider for candle completion callbacks
            if hasattr(self.data_provider, 'subscribe_candle_completion'):
                self.data_provider.subscribe_candle_completion(
                    callback=lambda event: self._handle_candle_completion(event),
                    symbol=symbol,
                    timeframe=timeframe
                )

            logger.info(f"Subscribed to candle completion: {symbol} {timeframe}")

    def subscribe_to_pivot_events(self, subscriber: TrainingEventSubscriber,
                                  symbol: str, timeframe: str,
                                  pivot_types: List[str]) -> None:
        """
        Subscribe to pivot events (L2L, L2H, etc.).

        Args:
            subscriber: Training subscriber
            symbol: Trading symbol
            timeframe: Timeframe
            pivot_types: List of pivot types to subscribe to (e.g., ['L2L', 'L2H', 'L3L'])
        """
        with self.lock:
            if subscriber not in self.training_subscribers:
                self.training_subscribers.append(subscriber)

            # Register pivot subscriptions
            for pivot_type in pivot_types:
                key = (symbol, timeframe, pivot_type)
                if key not in self.pivot_subscriptions:
                    self.pivot_subscriptions[key] = []
                # Store subscriber reference (we'll match inference frames later)

            # Register with data provider for pivot callbacks
            if hasattr(self.data_provider, 'subscribe_pivot_events'):
                self.data_provider.subscribe_pivot_events(
                    callback=lambda event: self._handle_pivot_event(event),
                    symbol=symbol,
                    timeframe=timeframe,
                    pivot_types=pivot_types
                )

            logger.info(f"Subscribed to pivot events: {symbol} {timeframe} {pivot_types}")

    def _handle_pivot_event(self, event: PivotEvent) -> None:
        """Handle pivot event from data provider and trigger training"""
        with self.lock:
            # Find matching inference frames (predictions made before this pivot)
            # Look for predictions within a reasonable window (e.g., last 5 minutes)
            window_start = event.timestamp - timedelta(minutes=5)

            matching_refs = []
            for inference_ref in self.inference_frames.values():
                if (inference_ref.symbol == event.symbol and
                        inference_ref.timeframe == event.timeframe and
                        inference_ref.prediction_timestamp >= window_start and
                        not inference_ref.trained):
                    matching_refs.append(inference_ref)

            # Notify subscribers
            for subscriber in self.training_subscribers:
                try:
                    subscriber.on_pivot_event(event, matching_refs)
                    # Mark as trained
                    for ref in matching_refs:
                        ref.trained = True
                        ref.training_timestamp = datetime.now(timezone.utc)
                except Exception as e:
                    logger.error(f"Error in pivot event callback: {e}", exc_info=True)

    def _handle_candle_completion(self, event: CandleCompletionEvent) -> None:
        """Handle candle completion event and trigger training"""
        with self.lock:
            # Find matching inference frames
            key = (event.symbol, event.timeframe, event.timestamp)
            inference_ids = self.candle_inferences.get(key, [])

            # Get inference references
            inference_refs = [self.inference_frames[iid] for iid in inference_ids
                              if iid in self.inference_frames and not self.inference_frames[iid].trained]

            # Notify subscribers
            for subscriber in self.training_subscribers:
                for inference_ref in inference_refs:
                    try:
                        subscriber.on_candle_completion(event, inference_ref)
                        # Mark as trained
                        inference_ref.trained = True
                        inference_ref.training_timestamp = datetime.now(timezone.utc)
                    except Exception as e:
                        logger.error(f"Error in candle completion callback: {e}", exc_info=True)

    def get_inference_data(self, inference_ref: InferenceFrameReference) -> Optional[Dict]:
        """
        Retrieve inference data from DuckDB using the reference.

        This queries DuckDB efficiently using the timestamp range stored in the reference.
        No copying - data is retrieved on-demand when training is triggered.

        Args:
            inference_ref: Reference to inference frame

        Returns:
            Dict with model inputs (price_data_1m, price_data_1h, etc.) or None
        """
        if not self.data_provider:
            logger.warning("Data provider not available for inference data retrieval")
            return None

        try:
            import torch
            import numpy as np

            # Query data provider for OHLCV data (it uses DuckDB internally)
            # This is efficient - DuckDB handles the query
            model_inputs = {}

            # Use norm_params from reference if available, otherwise calculate
            norm_params = inference_ref.norm_params.copy() if inference_ref.norm_params else {}

            for tf in ['1s', '1m', '1h', '1d']:
                # Get 600 candles - data_provider queries DuckDB efficiently
                df = self.data_provider.get_historical_data(
                    symbol=inference_ref.symbol,
                    timeframe=tf,
                    limit=600
                )

                if df is not None and len(df) >= 600:
                    # Take last 600 candles
                    df = df.tail(600)

                    # Extract OHLCV arrays
                    opens = df['open'].values.astype(np.float32)
                    highs = df['high'].values.astype(np.float32)
                    lows = df['low'].values.astype(np.float32)
                    closes = df['close'].values.astype(np.float32)
                    volumes = df['volume'].values.astype(np.float32)

                    # Stack OHLCV [seq_len, 5]
                    ohlcv = np.stack([opens, highs, lows, closes, volumes], axis=-1)

                    # Calculate normalization params if not stored
                    if tf not in norm_params:
                        price_min = np.min(ohlcv[:, :4])
                        price_max = np.max(ohlcv[:, :4])
                        volume_min = np.min(ohlcv[:, 4])
                        volume_max = np.max(ohlcv[:, 4])

                        if price_max == price_min:
                            price_max += 1.0
                        if volume_max == volume_min:
                            volume_max += 1.0

                        norm_params[tf] = {
                            'price_min': float(price_min),
                            'price_max': float(price_max),
                            'volume_min': float(volume_min),
                            'volume_max': float(volume_max)
                        }

                    # Normalize using params
                    params = norm_params[tf]
                    price_min = params['price_min']
                    price_max = params['price_max']
                    vol_min = params['volume_min']
                    vol_max = params['volume_max']

                    ohlcv[:, :4] = (ohlcv[:, :4] - price_min) / (price_max - price_min)
                    ohlcv[:, 4] = (ohlcv[:, 4] - vol_min) / (vol_max - vol_min)

                    # Convert to tensor [1, seq_len, 5]
                    candles_tensor = torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0)
                    model_inputs[f'price_data_{tf}'] = candles_tensor

            # Store norm_params in reference for future use
            inference_ref.norm_params = norm_params

            # Add placeholder data for other inputs
            device = next(iter(model_inputs.values())).device if model_inputs else torch.device('cpu')
            model_inputs['tech_data'] = torch.zeros(1, 40, dtype=torch.float32, device=device)
            model_inputs['market_data'] = torch.zeros(1, 30, dtype=torch.float32, device=device)
            model_inputs['cob_data'] = torch.zeros(1, 600, 100, dtype=torch.float32, device=device)

            return model_inputs

        except Exception as e:
            logger.error(f"Error retrieving inference data: {e}", exc_info=True)
            return None
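The callbacks above imply a subscriber contract: anything passed to `subscribe_to_pivot_events` / `subscribe_to_candle_completion` must expose `on_pivot_event(event, matching_refs)` and `on_candle_completion(event, inference_ref)`. A minimal sketch of such a subscriber follows; it does not inherit from `TrainingEventSubscriber` because the exact base class is not shown here, and the logging bodies are illustrative only.

```python
import logging

logger = logging.getLogger(__name__)


class LoggingTrainingSubscriber:
    """Minimal subscriber sketch - implements the two callbacks the coordinator invokes."""

    def on_pivot_event(self, event, matching_refs):
        # Receives the pivot event plus all untrained inference frames in the lookback window
        logger.info(f"Pivot {event.symbol} {event.timeframe}: {len(matching_refs)} frames to train")

    def on_candle_completion(self, event, inference_ref):
        # Called once per matching inference frame when its target candle closes
        logger.info(f"Candle closed {event.symbol} {event.timeframe}: train {inference_ref.inference_id}")
```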
@@ -1,322 +0,0 @@
"""
Live Pivot Trainer - Automatic Training on L2 Pivot Points

This module monitors live 1s and 1m charts for L2 pivot points (peaks/troughs)
and automatically creates training samples when they occur.

Integrates with:
- Williams Market Structure for pivot detection
- Real Training Adapter for model training
- Data Provider for live market data
"""

import logging
import threading
import time
from typing import Dict, List, Optional, Tuple
from datetime import datetime, timezone
from collections import deque

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


class LivePivotTrainer:
    """
    Monitors live charts for L2 pivots and automatically trains models

    Features:
    - Detects L2 pivot points on 1s and 1m timeframes
    - Creates training samples automatically
    - Trains models in background without blocking inference
    - Tracks training history to avoid duplicate training
    """

    def __init__(self, orchestrator, data_provider, training_adapter):
        """
        Initialize Live Pivot Trainer

        Args:
            orchestrator: TradingOrchestrator instance
            data_provider: DataProvider for market data
            training_adapter: RealTrainingAdapter for training
        """
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.training_adapter = training_adapter

        # Tracking
        self.running = False
        self.trained_pivots = deque(maxlen=1000)  # Track last 1000 trained pivots
        self.pivot_history = {
            '1s': deque(maxlen=100),
            '1m': deque(maxlen=100)
        }

        # Configuration
        self.check_interval = 5  # Check for new pivots every 5 seconds
        self.min_pivot_spacing = 60  # Minimum 60 seconds between training on same timeframe
        self.last_training_time = {
            '1s': 0,
            '1m': 0
        }

        # Williams Market Structure for pivot detection
        try:
            from core.williams_market_structure import WilliamsMarketStructure
            # Fix: WilliamsMarketStructure.__init__ does not accept num_levels
            # It defaults to 5 levels internally
            self.williams_1s = WilliamsMarketStructure()
            self.williams_1m = WilliamsMarketStructure()
            logger.info("Williams Market Structure initialized for pivot detection")
        except Exception as e:
            logger.error(f"Failed to initialize Williams Market Structure: {e}")
            self.williams_1s = None
            self.williams_1m = None

        logger.info("LivePivotTrainer initialized")

    def start(self, symbol: str = 'ETH/USDT'):
        """Start monitoring for L2 pivots"""
        if self.running:
            logger.warning("LivePivotTrainer already running")
            return

        self.running = True
        self.symbol = symbol

        # Start monitoring thread
        thread = threading.Thread(
            target=self._monitoring_loop,
            args=(symbol,),
            daemon=True
        )
        thread.start()

        logger.info(f"LivePivotTrainer started for {symbol}")

    def stop(self):
        """Stop monitoring"""
        self.running = False
        logger.info("LivePivotTrainer stopped")

    def _monitoring_loop(self, symbol: str):
        """Main monitoring loop - checks for new L2 pivots"""
        logger.info(f"LivePivotTrainer monitoring loop started for {symbol}")

        while self.running:
            try:
                # Check 1s timeframe
                self._check_timeframe_for_pivots(symbol, '1s')

                # Check 1m timeframe
                self._check_timeframe_for_pivots(symbol, '1m')

                # Sleep before next check
                time.sleep(self.check_interval)

            except Exception as e:
                logger.error(f"Error in LivePivotTrainer monitoring loop: {e}")
                time.sleep(10)  # Wait longer on error

    def _check_timeframe_for_pivots(self, symbol: str, timeframe: str):
        """
        Check a specific timeframe for new L2 pivots

        Args:
            symbol: Trading symbol
            timeframe: '1s' or '1m'
        """
        try:
            # Rate limiting - don't train too frequently on same timeframe
            current_time = time.time()
            if current_time - self.last_training_time[timeframe] < self.min_pivot_spacing:
                return

            # Get recent candles
            candles = self.data_provider.get_historical_data(
                symbol=symbol,
                timeframe=timeframe,
                limit=200  # Need enough candles to detect pivots
            )

            if candles is None or candles.empty:
                logger.debug(f"No candles available for {symbol} {timeframe}")
                return

            # Detect pivots using Williams Market Structure
            williams = self.williams_1s if timeframe == '1s' else self.williams_1m
            if williams is None:
                return

            # Prepare data for Williams Market Structure
            # Convert DataFrame to numpy array format
            df = candles.copy()
            ohlcv_array = df[['open', 'high', 'low', 'close', 'volume']].copy()

            # Handle timestamp conversion based on index type
            if isinstance(df.index, pd.DatetimeIndex):
                # Convert ns to ms
                timestamps = df.index.astype(np.int64) // 10**6
            else:
                # Assume it's already a timestamp or handle accordingly
                timestamps = df.index

            ohlcv_array.insert(0, 'timestamp', timestamps)
            ohlcv_array = ohlcv_array.to_numpy()

            # Calculate pivots
            pivot_levels = williams.calculate_recursive_pivot_points(ohlcv_array)

            if not pivot_levels or 2 not in pivot_levels:
                return

            # Get Level 2 pivots
            l2_trend_level = pivot_levels[2]
            l2_pivots_objs = l2_trend_level.pivot_points

            if not l2_pivots_objs:
                return

            # Check for new L2 pivots (not in history)
            new_pivots = []
            for p in l2_pivots_objs:
                # Convert pivot object to dict for compatibility
                pivot_dict = {
                    'timestamp': p.timestamp,  # Keep as datetime object for compatibility
                    'price': p.price,
                    'type': p.pivot_type,
                    'strength': p.strength
                }

                pivot_id = f"{symbol}_{timeframe}_{pivot_dict['timestamp']}_{pivot_dict['type']}"

                if pivot_id not in self.trained_pivots:
                    new_pivots.append(pivot_dict)
                    self.trained_pivots.append(pivot_id)

            if new_pivots:
                logger.info(f"Found {len(new_pivots)} new L2 pivots on {symbol} {timeframe}")

                # Train on the most recent pivot
                latest_pivot = new_pivots[-1]
                self._train_on_pivot(symbol, timeframe, latest_pivot, candles)

                self.last_training_time[timeframe] = current_time

        except Exception as e:
            logger.error(f"Error checking {timeframe} for pivots: {e}")

    def _train_on_pivot(self, symbol: str, timeframe: str, pivot: Dict, candles):
        """
        Create training sample from pivot and train model

        Args:
            symbol: Trading symbol
            timeframe: Timeframe of pivot
            pivot: Pivot point data
            candles: DataFrame with OHLCV data
        """
        try:
            logger.info(f"Training on L2 {pivot['type']} pivot @ {pivot['price']} on {symbol} {timeframe}")

            # Determine trade direction based on pivot type
            if pivot['type'] == 'high':
                # High pivot = potential SHORT entry
                direction = 'SHORT'
                action = 'SELL'
            else:
                # Low pivot = potential LONG entry
                direction = 'LONG'
                action = 'BUY'

            # Create training sample
            training_sample = {
                'test_case_id': f"live_pivot_{symbol}_{timeframe}_{pivot['timestamp']}",
                'symbol': symbol,
                'timestamp': pivot['timestamp'],
                'action': action,
                'expected_outcome': {
                    'direction': direction,
                    'entry_price': pivot['price'],
                    'exit_price': None,  # Will be determined by model
                    'profit_loss_pct': 0.0,  # Unknown yet
                    'holding_period_seconds': 300  # 5 minutes default
                },
                'training_config': {
                    'timeframes': ['1s', '1m', '1h', '1d'],
                    'candles_per_timeframe': 200
                },
                'annotation_metadata': {
                    'source': 'live_pivot_detection',
                    'pivot_level': 'L2',
                    'pivot_type': pivot['type'],
                    'confidence': pivot.get('strength', 1.0)
                }
            }

            # Train model in background (non-blocking)
            thread = threading.Thread(
                target=self._background_training,
                args=(training_sample,),
                daemon=True
            )
            thread.start()

            logger.info("Started background training on L2 pivot")

        except Exception as e:
            logger.error(f"Error training on pivot: {e}")

    def _background_training(self, training_sample: Dict):
        """
        Execute training in background thread

        Args:
            training_sample: Training sample data
        """
        try:
            # Use Transformer model for live pivot training
            model_name = 'Transformer'

            logger.info(f"Background training started for {training_sample['test_case_id']}")

            # Start training session
            training_id = self.training_adapter.start_training(
                model_name=model_name,
                test_cases=[training_sample]
            )

            logger.info(f"Live pivot training session started: {training_id}")

            # Monitor training (optional - could poll status)
            # For now, just fire and forget

        except Exception as e:
            logger.error(f"Error in background training: {e}")

    def get_stats(self) -> Dict:
        """Get training statistics"""
        return {
            'running': self.running,
            'total_trained_pivots': len(self.trained_pivots),
            'last_training_1s': self.last_training_time.get('1s', 0),
            'last_training_1m': self.last_training_time.get('1m', 0),
            'pivot_history_1s': len(self.pivot_history['1s']),
            'pivot_history_1m': len(self.pivot_history['1m'])
        }


# Global instance
_live_pivot_trainer = None


def get_live_pivot_trainer(orchestrator=None, data_provider=None, training_adapter=None):
    """Get or create global LivePivotTrainer instance"""
    global _live_pivot_trainer

    if _live_pivot_trainer is None and all([orchestrator, data_provider, training_adapter]):
        _live_pivot_trainer = LivePivotTrainer(orchestrator, data_provider, training_adapter)

    return _live_pivot_trainer
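For context, a minimal usage sketch of this (now-removed) trainer. The import path is illustrative, and the three collaborator objects are assumed to come from the main system's bootstrap rather than being constructed here.

```python
# Minimal usage sketch; the module path and collaborator objects are assumptions.
from ANNOTATE.core.live_pivot_trainer import get_live_pivot_trainer


def start_live_pivot_training(orchestrator, data_provider, training_adapter):
    trainer = get_live_pivot_trainer(
        orchestrator=orchestrator,
        data_provider=data_provider,
        training_adapter=training_adapter,
    )
    if trainer:
        trainer.start(symbol='ETH/USDT')  # spawns the daemon monitoring thread
    return trainer
```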
@@ -1 +0,0 @@
once there are 2 low or 2 high Level 2 pivots AFTER the trend line prediction, we should fit a trend line through them and do backpropagation to adjust our model's trend predictions
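A rough sketch of that idea: fit a line through two same-side L2 pivots and penalize the model's predicted trend slope against the realized slope. Names like `predicted_slope` and the `(timestamp, price)` pivot format are illustrative assumptions, not existing APIs.

```python
import torch


def trend_line_from_l2_pivots(p1, p2):
    """Fit a line through two same-side L2 pivots; returns (slope, intercept).

    p1, p2 are (unix_ts_seconds, price) tuples - illustrative format only.
    """
    (t1, y1), (t2, y2) = p1, p2
    slope = (y2 - y1) / (t2 - t1)
    intercept = y1 - slope * t1
    return slope, intercept


def trend_correction_loss(predicted_slope: torch.Tensor, p1, p2) -> torch.Tensor:
    """MSE between the model's predicted trend slope and the realized L2 trend line."""
    actual_slope, _ = trend_line_from_l2_pivots(p1, p2)
    target = torch.tensor(actual_slope, dtype=predicted_slope.dtype)
    return torch.nn.functional.mse_loss(predicted_slope, target)


# Example: two L2 lows arrive after the prediction -> backprop against the realized slope
# loss = trend_correction_loss(model_output['trend_slope'], (t1, low1), (t2, low2))
# loss.backward(); optimizer.step()
```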
File diff suppressed because it is too large
@@ -1,299 +0,0 @@
"""
Training Data Fetcher - Dynamic OHLCV data retrieval for model training

Fetches ±5 minutes of OHLCV data around annotated events from cache/database
instead of storing it in JSON files. This allows efficient training on optimal timing.
"""

import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
import pandas as pd
import numpy as np
import pytz

logger = logging.getLogger(__name__)


class TrainingDataFetcher:
    """
    Fetches training data dynamically from cache/database for annotated events.

    Key Features:
    - Fetches ±5 minutes of OHLCV data around entry/exit points
    - Generates training labels for optimal timing detection
    - Supports multiple timeframes (1s, 1m, 1h, 1d)
    - Efficient memory usage (no JSON storage)
    - Real-time data from cache/database
    """

    def __init__(self, data_provider):
        """
        Initialize training data fetcher

        Args:
            data_provider: DataProvider instance for fetching OHLCV data
        """
        self.data_provider = data_provider
        logger.info("TrainingDataFetcher initialized")

    def fetch_training_data_for_annotation(self, annotation: Dict,
                                           context_window_minutes: int = 5) -> Dict[str, Any]:
        """
        Fetch complete training data for an annotation

        Args:
            annotation: Annotation metadata (from annotations_db.json)
            context_window_minutes: Minutes before/after entry to include

        Returns:
            Dict with market_state, training_labels, and expected_outcome
        """
        try:
            # Parse timestamps
            entry_time = datetime.fromisoformat(annotation['entry']['timestamp'].replace('Z', '+00:00'))
            exit_time = datetime.fromisoformat(annotation['exit']['timestamp'].replace('Z', '+00:00'))

            symbol = annotation['symbol']
            direction = annotation['direction']

            logger.info(f"Fetching training data for {symbol} at {entry_time} (±{context_window_minutes}min)")

            # Fetch OHLCV data for all timeframes around entry time
            market_state = self._fetch_market_state_at_time(
                symbol=symbol,
                timestamp=entry_time,
                context_window_minutes=context_window_minutes
            )

            # Generate training labels for optimal timing detection
            training_labels = self._generate_timing_labels(
                market_state=market_state,
                entry_time=entry_time,
                exit_time=exit_time,
                direction=direction
            )

            # Prepare expected outcome
            expected_outcome = {
                "direction": direction,
                "profit_loss_pct": annotation['profit_loss_pct'],
                "entry_price": annotation['entry']['price'],
                "exit_price": annotation['exit']['price'],
                "holding_period_seconds": (exit_time - entry_time).total_seconds()
            }

            return {
                "test_case_id": f"annotation_{annotation['annotation_id']}",
                "symbol": symbol,
                "timestamp": annotation['entry']['timestamp'],
                "action": "BUY" if direction == "LONG" else "SELL",
                "market_state": market_state,
                "training_labels": training_labels,
                "expected_outcome": expected_outcome,
                "annotation_metadata": {
                    "annotator": "manual",
                    "confidence": 1.0,
                    "notes": annotation.get('notes', ''),
                    "created_at": annotation.get('created_at'),
                    "timeframe": annotation.get('timeframe', '1m')
                }
            }

        except Exception as e:
            logger.error(f"Error fetching training data for annotation: {e}")
            import traceback
            traceback.print_exc()
            return {}

    def _fetch_market_state_at_time(self, symbol: str, timestamp: datetime,
                                    context_window_minutes: int) -> Dict[str, Any]:
        """
        Fetch market state at specific time from cache/database

        Args:
            symbol: Trading symbol
            timestamp: Target timestamp
            context_window_minutes: Minutes before/after to include

        Returns:
            Dict with OHLCV data for all timeframes
        """
        try:
            # Use data provider's method to get market state
            market_state = self.data_provider.get_market_state_at_time(
                symbol=symbol,
                timestamp=timestamp,
                context_window_minutes=context_window_minutes
            )

            logger.info(f"Fetched market state with {len(market_state)} timeframes")
            return market_state

        except Exception as e:
            logger.error(f"Error fetching market state: {e}")
            return {}

    def _generate_timing_labels(self, market_state: Dict, entry_time: datetime,
                                exit_time: datetime, direction: str) -> Dict[str, Any]:
        """
        Generate training labels for optimal timing detection

        Labels help model learn:
        - WHEN to enter (optimal entry timing)
        - WHEN to exit (optimal exit timing)
        - WHEN NOT to trade (avoid bad timing)

        Args:
            market_state: OHLCV data for all timeframes
            entry_time: Entry timestamp
            exit_time: Exit timestamp
            direction: Trade direction (LONG/SHORT)

        Returns:
            Dict with training labels for each timeframe
        """
        labels = {
            'direction': direction,
            'entry_timestamp': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
            'exit_timestamp': exit_time.strftime('%Y-%m-%d %H:%M:%S')
        }

        # Generate labels for each timeframe
        timeframes = ['1s', '1m', '1h', '1d']

        for tf in timeframes:
            tf_key = f'ohlcv_{tf}'
            if tf_key in market_state and 'timestamps' in market_state[tf_key]:
                timestamps = market_state[tf_key]['timestamps']

                label_list = []
                entry_idx = -1
                exit_idx = -1

                for i, ts_str in enumerate(timestamps):
                    try:
                        ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
                        # Make timezone-aware
                        if ts.tzinfo is None:
                            ts = pytz.UTC.localize(ts)

                        # Make entry_time and exit_time timezone-aware if needed
                        if entry_time.tzinfo is None:
                            entry_time = pytz.UTC.localize(entry_time)
                        if exit_time.tzinfo is None:
                            exit_time = pytz.UTC.localize(exit_time)

                        # Determine label based on timing
                        if abs((ts - entry_time).total_seconds()) < 60:  # Within 1 minute of entry
                            label = 1  # OPTIMAL ENTRY TIMING
                            entry_idx = i
                        elif abs((ts - exit_time).total_seconds()) < 60:  # Within 1 minute of exit
                            label = 3  # OPTIMAL EXIT TIMING
                            exit_idx = i
                        elif entry_time < ts < exit_time:  # Between entry and exit
                            label = 2  # HOLD POSITION
                        else:  # Before entry or after exit
                            label = 0  # NO ACTION (avoid trading)

                        label_list.append(label)

                    except Exception as e:
                        logger.error(f"Error parsing timestamp {ts_str}: {e}")
                        label_list.append(0)

                labels[f'labels_{tf}'] = label_list
                labels[f'entry_index_{tf}'] = entry_idx
                labels[f'exit_index_{tf}'] = exit_idx

                # Log label distribution
                label_counts = {0: 0, 1: 0, 2: 0, 3: 0}
                for label in label_list:
                    label_counts[label] += 1

                logger.info(f"Generated {tf} labels: {label_counts[0]} NO_ACTION, "
                            f"{label_counts[1]} ENTRY, {label_counts[2]} HOLD, {label_counts[3]} EXIT")

        return labels

    def fetch_training_batch(self, annotations: List[Dict],
                             context_window_minutes: int = 5) -> List[Dict[str, Any]]:
        """
        Fetch training data for multiple annotations

        Args:
            annotations: List of annotation metadata
            context_window_minutes: Minutes before/after entry to include

        Returns:
            List of training data dictionaries
        """
        training_data = []

        logger.info(f"Fetching training batch for {len(annotations)} annotations")

        for annotation in annotations:
            try:
                training_sample = self.fetch_training_data_for_annotation(
                    annotation, context_window_minutes
                )

                if training_sample:
                    training_data.append(training_sample)
                else:
                    logger.warning(f"Failed to fetch training data for annotation {annotation.get('annotation_id')}")

            except Exception as e:
                logger.error(f"Error processing annotation {annotation.get('annotation_id')}: {e}")

        logger.info(f"Successfully fetched training data for {len(training_data)}/{len(annotations)} annotations")
        return training_data

    def get_training_statistics(self, training_data: List[Dict]) -> Dict[str, Any]:
        """
        Get statistics about training data

        Args:
            training_data: List of training data samples

        Returns:
            Dict with training statistics
        """
        if not training_data:
            return {}

        stats = {
            'total_samples': len(training_data),
            'symbols': {},
            'directions': {'LONG': 0, 'SHORT': 0},
            'avg_profit_loss': 0.0,
            'timeframes_available': set()
        }

        total_pnl = 0.0

        for sample in training_data:
            symbol = sample.get('symbol', 'UNKNOWN')
            direction = sample.get('expected_outcome', {}).get('direction', 'UNKNOWN')
            pnl = sample.get('expected_outcome', {}).get('profit_loss_pct', 0.0)

            # Count symbols
            stats['symbols'][symbol] = stats['symbols'].get(symbol, 0) + 1

            # Count directions
            if direction in stats['directions']:
                stats['directions'][direction] += 1

            # Accumulate P&L
            total_pnl += pnl

            # Check available timeframes
            market_state = sample.get('market_state', {})
            for key in market_state.keys():
                if key.startswith('ohlcv_'):
                    stats['timeframes_available'].add(key.replace('ohlcv_', ''))

        stats['avg_profit_loss'] = total_pnl / len(training_data)
        stats['timeframes_available'] = list(stats['timeframes_available'])

        return stats
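To make the 0/1/2/3 timing-label scheme above concrete, here is a small self-contained example applying the same rule to a toy list of 1m timestamps (the dates and trade times are illustrative only):

```python
from datetime import datetime, timedelta, timezone

# Toy 1m candle timestamps around an annotated trade (illustrative values only)
entry = datetime(2025, 10, 23, 12, 5, tzinfo=timezone.utc)
exit_ = datetime(2025, 10, 23, 12, 9, tzinfo=timezone.utc)
candles = [datetime(2025, 10, 23, 12, 0, tzinfo=timezone.utc) + timedelta(minutes=i) for i in range(12)]


def timing_label(ts, entry, exit_):
    """Same scheme as _generate_timing_labels: 1=ENTRY, 3=EXIT, 2=HOLD, 0=NO_ACTION."""
    if abs((ts - entry).total_seconds()) < 60:
        return 1
    if abs((ts - exit_).total_seconds()) < 60:
        return 3
    if entry < ts < exit_:
        return 2
    return 0


print([timing_label(ts, entry, exit_) for ts in candles])
# -> [0, 0, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0]
```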
@@ -1,12 +0,0 @@
The problem we have is duplicate implementations.

We should have only one data provider implementation, in the main /core folder, and extend it there if we need more functionality.

We need to fully move the Inference Training Coordinator functions into the Orchestrator - both classes have overlapping responsibilities and only one should exist.

InferenceFrameReference should also live in core/data_models.py.

We do not need a "core" folder in the ANNOTATE app; we should refactor and move those classes into the main /core folder. This is a design flaw - there should naturally be only one "core".
The purpose of the ANNOTATE app is to provide a UI for creating test cases and annotating data, and also for running inference and training.
All implementations should live in the main system and only be referenced and used from the ANNOTATE app.
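For reference, a sketch of what InferenceFrameReference might look like once moved into core/data_models.py. The fields are inferred from how the coordinator code above uses the object; the exact definition in the codebase may differ.

```python
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Optional


@dataclass
class InferenceFrameReference:
    """Lightweight pointer to inference data stored in DuckDB (fields inferred from usage)."""
    inference_id: str
    symbol: str
    timeframe: str
    prediction_timestamp: datetime
    target_timestamp: Optional[datetime] = None
    norm_params: Dict[str, Dict[str, float]] = field(default_factory=dict)
    trained: bool = False
    training_timestamp: Optional[datetime] = None
```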
@@ -48,7 +48,7 @@ sys.path.insert(0, str(annotate_dir))
try:
    from core.annotation_manager import AnnotationManager
    from core.real_training_adapter import RealTrainingAdapter
    from core.data_loader import HistoricalDataLoader, TimeRangeManager
    # Using main DataProvider directly instead of duplicate data_loader
except ImportError:
    # Try alternative import path
    import importlib.util
@@ -71,15 +71,9 @@ except ImportError:
    train_spec.loader.exec_module(train_module)
    RealTrainingAdapter = train_module.RealTrainingAdapter

    # Load data_loader
    data_spec = importlib.util.spec_from_file_location(
        "data_loader",
        annotate_dir / "core" / "data_loader.py"
    )
    data_module = importlib.util.module_from_spec(data_spec)
    data_spec.loader.exec_module(data_module)
    HistoricalDataLoader = data_module.HistoricalDataLoader
    TimeRangeManager = data_module.TimeRangeManager
    # Using main DataProvider directly - no need for duplicate data_loader
    HistoricalDataLoader = None
    TimeRangeManager = None

# Setup logging - configure before any logging occurs
log_dir = Path(__file__).parent.parent / 'logs'
@@ -745,7 +739,17 @@ class AnnotationDashboard:
        ])

        # Initialize core components (skip initial load for fast startup)
        self.data_provider = DataProvider(skip_initial_load=True) if DataProvider else None
        try:
            if DataProvider:
                config = get_config()
                self.data_provider = DataProvider(skip_initial_load=True)
                logger.info("DataProvider initialized successfully")
            else:
                self.data_provider = None
                logger.warning("DataProvider class not available")
        except Exception as e:
            logger.error(f"Failed to initialize DataProvider: {e}")
            self.data_provider = None

        # Enable unified storage for real-time data access
        if self.data_provider:
@@ -780,15 +784,15 @@ class AnnotationDashboard:
        else:
            logger.info("Auto-load disabled. Models available for lazy loading: " + ", ".join(self.available_models))

        # Initialize data loader with existing DataProvider
        self.data_loader = HistoricalDataLoader(self.data_provider) if self.data_provider else None
        self.time_range_manager = TimeRangeManager(self.data_loader) if self.data_loader else None
        # Use main DataProvider directly instead of duplicate data_loader
        self.data_loader = None  # Deprecated - using data_provider directly
        self.time_range_manager = None  # Deprecated

        # Setup routes
        self._setup_routes()

        # Start background data refresh after startup
        if self.data_loader:
        if self.data_provider:
            self._start_background_data_refresh()

        logger.info("Annotation Dashboard initialized")
@@ -1105,7 +1109,8 @@ class AnnotationDashboard:
        logger.info(" Starting one-time background data refresh (fetching only recent missing data)")

        # Disable startup mode to fetch fresh data
        self.data_loader.disable_startup_mode()
        if self.data_provider:
            self.data_provider.disable_startup_mode()

        # Use the new on-demand refresh method
        logger.info("Using on-demand refresh for recent data")
@@ -1374,15 +1379,14 @@ class AnnotationDashboard:

            pivot_logger.info(f"Recalculating pivots for {symbol} {timeframe} using backend data")

            if not self.data_loader:
            if not self.data_provider:
                return jsonify({
                    'success': False,
                    'error': {'code': 'DATA_LOADER_UNAVAILABLE', 'message': 'Data loader not available'}
                    'error': {'code': 'DATA_PROVIDER_UNAVAILABLE', 'message': 'Data provider not available'}
                })

            # Fetch latest data from data_loader (which should have the updated cache/DB from previous calls)
            # We get enough history for proper pivot calculation
            df = self.data_loader.get_data(
            # Fetch latest data from data_provider for pivot calculation
            df = self.data_provider.get_data_for_annotation(
                symbol=symbol,
                timeframe=timeframe,
                limit=2500,  # Enough for context
@@ -1423,14 +1427,14 @@ class AnnotationDashboard:

            webui_logger.info(f"Chart data GET request: {symbol} {timeframe} limit={limit}")

            if not self.data_loader:
            if not self.data_provider:
                return jsonify({
                    'success': False,
                    'error': {'code': 'DATA_LOADER_UNAVAILABLE', 'message': 'Data loader not available'}
                    'error': {'code': 'DATA_PROVIDER_UNAVAILABLE', 'message': 'Data provider not available'}
                })

            # Fetch data using data loader
            df = self.data_loader.get_data(
            # Fetch data using main data provider
            df = self.data_provider.get_data_for_annotation(
                symbol=symbol,
                timeframe=timeframe,
                limit=limit,
@@ -1486,12 +1490,12 @@ class AnnotationDashboard:
            if end_time_str:
                webui_logger.info(f" end_time: {end_time_str}")

            if not self.data_loader:
            if not self.data_provider:
                return jsonify({
                    'success': False,
                    'error': {
                        'code': 'DATA_LOADER_UNAVAILABLE',
                        'message': 'Data loader not available'
                        'code': 'DATA_PROVIDER_UNAVAILABLE',
                        'message': 'Data provider not available'
                    }
                })

@@ -1499,14 +1503,14 @@ class AnnotationDashboard:
            start_time = datetime.fromisoformat(start_time_str.replace('Z', '+00:00')) if start_time_str else None
            end_time = datetime.fromisoformat(end_time_str.replace('Z', '+00:00')) if end_time_str else None

            # Fetch data for each timeframe using data loader
            # Fetch data for each timeframe using data provider
            # This will automatically:
            # 1. Check DuckDB first
            # 2. Fetch from API if not in cache
            # 3. Store in DuckDB for future use
            chart_data = {}
            for timeframe in timeframes:
                df = self.data_loader.get_data(
                df = self.data_provider.get_data_for_annotation(
                    symbol=symbol,
                    timeframe=timeframe,
                    start_time=start_time,
@@ -1625,7 +1629,7 @@ class AnnotationDashboard:

            # Collect market snapshots for SQLite storage
            market_snapshots = {}
            if self.data_loader:
            if self.data_provider:
                try:
                    # Get OHLCV data for all timeframes around the annotation time
                    entry_time = datetime.fromisoformat(data['entry']['timestamp'].replace('Z', '+00:00'))
@@ -1636,7 +1640,7 @@ class AnnotationDashboard:
                    end_time = exit_time + timedelta(minutes=5)

                    for timeframe in ['1s', '1m', '1h', '1d']:
                        df = self.data_loader.get_data(
                        df = self.data_provider.get_data_for_annotation(
                            symbol=data['symbol'],
                            timeframe=timeframe,
                            start_time=start_time,
@@ -2530,11 +2534,11 @@ class AnnotationDashboard:
                'prediction': None
            }

            # Get latest candle for the requested timeframe using data_loader
            if self.data_loader:
            # Get latest candle for the requested timeframe using data_provider
            if self.data_provider:
                try:
                    # Get latest candle from data_loader
                    df = self.data_loader.get_data(symbol, timeframe, limit=2, direction='latest')
                    # Get latest candle from data_provider (includes real-time data)
                    df = self.data_provider.get_data_for_annotation(symbol, timeframe, limit=2, direction='latest')
                    if df is not None and not df.empty:
                        latest_candle = df.iloc[-1]
@@ -2567,9 +2571,9 @@ class AnnotationDashboard:
                            'is_confirmed': is_confirmed
                        }
                except Exception as e:
                    logger.debug(f"Error getting latest candle from data_loader: {e}", exc_info=True)
                    logger.debug(f"Error getting latest candle from data_provider: {e}", exc_info=True)
            else:
                logger.debug("Data loader not available for live updates")
                logger.debug("Data provider not available for live updates")

            # Get latest model predictions
            if self.orchestrator:
@@ -2641,10 +2645,10 @@ class AnnotationDashboard:
            }

            # Get latest candle for each requested timeframe
            if self.data_loader:
            if self.data_provider:
                for timeframe in timeframes:
                    try:
                        df = self.data_loader.get_data(symbol, timeframe, limit=2, direction='latest')
                        df = self.data_provider.get_data_for_annotation(symbol, timeframe, limit=2, direction='latest')
                        if df is not None and not df.empty:
                            latest_candle = df.iloc[-1]
@@ -3301,15 +3305,17 @@ class AnnotationDashboard:
        for tf in required_tfs + optional_tfs:
            try:
                # Fetch enough candles (600 for training, but accept less)
                df = self.data_loader.get_data(
                    symbol=symbol,
                    timeframe=tf,
                    end_time=dt,
                    limit=600,
                    direction='before'
                ) if self.data_loader else None
                df = None
                if self.data_provider:
                    df = self.data_provider.get_data_for_annotation(
                        symbol=symbol,
                        timeframe=tf,
                        end_time=dt,
                        limit=600,
                        direction='before'
                    )

                # Fallback to data provider if data_loader not available
                # Fallback to regular historical data if annotation method fails
                if df is None or df.empty:
                    if self.data_provider:
                        df = self.data_provider.get_historical_data(symbol, tf, limit=600, refresh=False)