"""
Historical data replay manager implementation.

Provides configurable playback of historical market data with session management.
"""
|
|
|
|
import asyncio
import json
import logging
import uuid
from dataclasses import replace
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Callable, Any, Union

from ..config import Config
from ..interfaces.replay_manager import ReplayManager
from ..models.core import ReplaySession, ReplayStatus, OrderBookSnapshot, TradeEvent
from ..storage.storage_manager import StorageManager
from ..utils.exceptions import ReplayError, ValidationError
from ..utils.logging import get_logger, set_correlation_id
from ..utils.timing import get_current_timestamp
|
|
|
|
# Module-level logger, obtained from the project's logging helper and keyed
# by this module's import name.
logger = get_logger(__name__)
|
|
|
|
|
|
class HistoricalReplayManager(ReplayManager):
    """
    Implementation of historical data replay functionality.

    Provides:
    - Session-based replay management
    - Configurable playback speeds
    - Real-time data streaming
    - Session controls (start/pause/stop/seek)
    - Data filtering by symbol and exchange

    Sessions live in memory only (``self.sessions``); each running session
    is driven by one asyncio task registered in ``self.session_tasks``.
    """

    # Longest span a single replay request may cover.
    MAX_REPLAY_RANGE = timedelta(days=30)
    # Oldest permitted start time, measured back from the current timestamp.
    MAX_REPLAY_AGE = timedelta(days=365)
    # Number of rows fetched per query while streaming.
    STREAM_CHUNK_SIZE = 1000

    def __init__(self, storage_manager: StorageManager, config: Config):
        """
        Initialize replay manager.

        Args:
            storage_manager: Storage manager for data access
            config: System configuration
        """
        self.storage_manager = storage_manager
        self.config = config

        # Session management
        self.sessions: Dict[str, ReplaySession] = {}
        self.session_tasks: Dict[str, asyncio.Task] = {}
        # session_id -> {'data': [...], 'status': [...]}
        self.session_callbacks: Dict[str, Dict[str, List[Callable]]] = {}

        # Performance tracking
        self.stats = {
            'sessions_created': 0,
            'sessions_completed': 0,
            'sessions_failed': 0,
            'total_events_replayed': 0,
            'avg_replay_speed': 0.0
        }

        logger.info("Historical replay manager initialized")

    def create_replay_session(self, start_time: datetime, end_time: datetime,
                              speed: float = 1.0, symbols: Optional[List[str]] = None,
                              exchanges: Optional[List[str]] = None) -> str:
        """
        Create a new replay session in the CREATED state.

        Args:
            start_time: First timestamp to replay.
            end_time: Last timestamp to replay.
            speed: Playback multiplier; must be positive.
            symbols: Optional symbol filter (empty/None means no filter).
            exchanges: Optional exchange filter (empty/None means no filter).

        Returns:
            The generated session id.

        Raises:
            ReplayError: If validation fails or the session cannot be created.
        """
        try:
            set_correlation_id()

            # Validate parameters
            validation_errors = list(
                self.validate_replay_request(start_time, end_time, symbols, exchanges)
            )
            # The stream loop divides by the speed, so zero or negative values
            # must be rejected here, exactly as set_replay_speed() does.
            if speed <= 0:
                validation_errors.append("Speed must be positive")
            if validation_errors:
                raise ValidationError(f"Invalid replay request: {', '.join(validation_errors)}")

            # Generate session ID
            session_id = str(uuid.uuid4())

            # Create session
            session = ReplaySession(
                session_id=session_id,
                start_time=start_time,
                end_time=end_time,
                current_time=start_time,
                speed=speed,
                status=ReplayStatus.CREATED,
                symbols=symbols or [],
                exchanges=exchanges or [],
                created_at=get_current_timestamp(),
                events_replayed=0,
                total_events=0,
                progress=0.0
            )

            # Store session
            self.sessions[session_id] = session
            self.session_callbacks[session_id] = {
                'data': [],
                'status': []
            }

            self.stats['sessions_created'] += 1

            logger.info(f"Created replay session {session_id} for {start_time} to {end_time}")
            return session_id

        except Exception as e:
            logger.error(f"Failed to create replay session: {e}")
            raise ReplayError(f"Session creation failed: {e}") from e

    async def start_replay(self, session_id: str) -> None:
        """
        Start (or restart) streaming for a session.

        Marks the session RUNNING, notifies status callbacks and spawns the
        background replay task. A session that is already RUNNING is left
        untouched.

        Raises:
            ReplayError: If the session is unknown or the task cannot start.
        """
        try:
            set_correlation_id()

            if session_id not in self.sessions:
                raise ReplayError(f"Session {session_id} not found")

            session = self.sessions[session_id]

            if session.status == ReplayStatus.RUNNING:
                logger.warning(f"Session {session_id} is already running")
                return

            # Update session status
            session.status = ReplayStatus.RUNNING
            session.started_at = get_current_timestamp()

            # Notify status callbacks
            await self._notify_status_callbacks(session_id, ReplayStatus.RUNNING)

            # Start replay task
            task = asyncio.create_task(self._replay_task(session_id))
            self.session_tasks[session_id] = task

            logger.info(f"Started replay session {session_id}")

        except Exception as e:
            logger.error(f"Failed to start replay session {session_id}: {e}")
            await self._set_session_error(session_id, str(e))
            raise ReplayError(f"Failed to start replay: {e}") from e

    async def pause_replay(self, session_id: str) -> None:
        """
        Pause a running session.

        The replay task is cancelled; the session keeps its ``current_time``
        so that resume_replay() can continue from there.

        Raises:
            ReplayError: If the session is unknown or pausing fails.
        """
        try:
            if session_id not in self.sessions:
                raise ReplayError(f"Session {session_id} not found")

            session = self.sessions[session_id]

            if session.status != ReplayStatus.RUNNING:
                logger.warning(f"Session {session_id} is not running")
                return

            # Update session status
            session.status = ReplayStatus.PAUSED
            session.paused_at = get_current_timestamp()

            # Cancel replay task
            task = self.session_tasks.pop(session_id, None)
            if task is not None:
                task.cancel()

            # Notify status callbacks
            await self._notify_status_callbacks(session_id, ReplayStatus.PAUSED)

            logger.info(f"Paused replay session {session_id}")

        except Exception as e:
            logger.error(f"Failed to pause replay session {session_id}: {e}")
            raise ReplayError(f"Failed to pause replay: {e}") from e

    async def resume_replay(self, session_id: str) -> None:
        """
        Resume a paused session from its current position.

        Raises:
            ReplayError: If the session is unknown or resuming fails.
        """
        try:
            if session_id not in self.sessions:
                raise ReplayError(f"Session {session_id} not found")

            session = self.sessions[session_id]

            if session.status != ReplayStatus.PAUSED:
                logger.warning(f"Session {session_id} is not paused")
                return

            # Resume from current position
            await self.start_replay(session_id)

            logger.info(f"Resumed replay session {session_id}")

        except Exception as e:
            logger.error(f"Failed to resume replay session {session_id}: {e}")
            raise ReplayError(f"Failed to resume replay: {e}") from e

    async def stop_replay(self, session_id: str) -> None:
        """
        Stop a session and wait for its replay task to finish.

        Raises:
            ReplayError: If the session is unknown or stopping fails.
        """
        try:
            if session_id not in self.sessions:
                raise ReplayError(f"Session {session_id} not found")

            session = self.sessions[session_id]

            # Update session status
            session.status = ReplayStatus.STOPPED
            session.stopped_at = get_current_timestamp()

            # Cancel replay task and wait until it has actually unwound
            task = self.session_tasks.pop(session_id, None)
            if task is not None:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

            # Notify status callbacks
            await self._notify_status_callbacks(session_id, ReplayStatus.STOPPED)

            logger.info(f"Stopped replay session {session_id}")

        except Exception as e:
            logger.error(f"Failed to stop replay session {session_id}: {e}")
            raise ReplayError(f"Failed to stop replay: {e}") from e

    def get_replay_status(self, session_id: str) -> Optional[ReplaySession]:
        """Return the session object for ``session_id``, or None if unknown."""
        return self.sessions.get(session_id)

    def list_replay_sessions(self) -> List[ReplaySession]:
        """Return a list of all known replay sessions."""
        return list(self.sessions.values())

    def delete_replay_session(self, session_id: str) -> bool:
        """
        Delete a replay session, cancelling its replay task if one exists.

        Returns:
            True if the session existed and was removed, False otherwise.
        """
        try:
            session = self.sessions.get(session_id)
            if session is None:
                return False

            # Cancel the task directly. Scheduling stop_replay() as a
            # background task cannot work here: by the time it runs the
            # session is already gone, so it fails with "not found" and the
            # replay task keeps streaming.
            task = self.session_tasks.pop(session_id, None)
            if task is not None and not task.done():
                task.cancel()

            # The stream loop holds its own reference to the session object,
            # so flip the status too; that makes the loop exit on its own.
            if session.status == ReplayStatus.RUNNING:
                session.status = ReplayStatus.STOPPED
                session.stopped_at = get_current_timestamp()

            # Clean up
            del self.sessions[session_id]
            self.session_callbacks.pop(session_id, None)

            logger.info(f"Deleted replay session {session_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to delete replay session {session_id}: {e}")
            return False

    def set_replay_speed(self, session_id: str, speed: float) -> bool:
        """
        Change the playback multiplier of a session.

        Returns:
            True on success; False if the session is unknown or ``speed`` is
            not positive.
        """
        try:
            if session_id not in self.sessions:
                return False

            if speed <= 0:
                raise ValueError("Speed must be positive")

            session = self.sessions[session_id]
            session.speed = speed

            logger.info(f"Set replay speed to {speed}x for session {session_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to set replay speed for session {session_id}: {e}")
            return False

    def seek_replay(self, session_id: str, timestamp: datetime) -> bool:
        """
        Move a session's position to ``timestamp``.

        Only the session bookkeeping (``current_time`` and ``progress``) is
        updated; a stream that is already running keeps the query range it
        was started with.

        Returns:
            True on success; False if the session is unknown or the timestamp
            lies outside the session's time range.
        """
        try:
            if session_id not in self.sessions:
                return False

            session = self.sessions[session_id]

            # Validate timestamp is within session range
            if timestamp < session.start_time or timestamp > session.end_time:
                logger.warning(f"Seek timestamp {timestamp} outside session range")
                return False

            # Update current time
            session.current_time = timestamp

            # Recalculate progress
            total_duration = (session.end_time - session.start_time).total_seconds()
            elapsed_duration = (timestamp - session.start_time).total_seconds()
            session.progress = elapsed_duration / total_duration if total_duration > 0 else 0.0

            logger.info(f"Seeked to {timestamp} in session {session_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to seek in session {session_id}: {e}")
            return False

    def add_data_callback(self, session_id: str, callback: Callable) -> bool:
        """
        Register a callback that receives every replayed data object.

        Returns:
            True if registered, False if the session is unknown.
        """
        try:
            if session_id not in self.session_callbacks:
                return False

            self.session_callbacks[session_id]['data'].append(callback)
            logger.debug(f"Added data callback for session {session_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to add data callback for session {session_id}: {e}")
            return False

    def remove_data_callback(self, session_id: str, callback: Callable) -> bool:
        """
        Unregister a previously added data callback.

        Returns:
            True if the callback was found and removed, False otherwise.
        """
        try:
            if session_id not in self.session_callbacks:
                return False

            callbacks = self.session_callbacks[session_id]['data']
            if callback in callbacks:
                callbacks.remove(callback)
                logger.debug(f"Removed data callback for session {session_id}")
                return True

            return False

        except Exception as e:
            logger.error(f"Failed to remove data callback for session {session_id}: {e}")
            return False

    def add_status_callback(self, session_id: str, callback: Callable) -> bool:
        """
        Register a callback invoked as ``callback(session_id, status)`` on
        every status change.

        Returns:
            True if registered, False if the session is unknown.
        """
        try:
            if session_id not in self.session_callbacks:
                return False

            self.session_callbacks[session_id]['status'].append(callback)
            logger.debug(f"Added status callback for session {session_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to add status callback for session {session_id}: {e}")
            return False

    async def get_available_data_range(self, symbol: str,
                                       exchange: Optional[str] = None) -> Optional[Dict[str, datetime]]:
        """
        Look up the earliest and latest order book snapshot for a symbol.

        Args:
            symbol: Symbol to look up.
            exchange: Optional exchange to restrict the lookup to.

        Returns:
            ``{'start': ..., 'end': ...}`` or None when no data exists or the
            query fails.
        """
        try:
            # Query database for data range
            if exchange:
                query = """
                    SELECT
                        MIN(timestamp) as start_time,
                        MAX(timestamp) as end_time
                    FROM order_book_snapshots
                    WHERE symbol = $1 AND exchange = $2
                """
                result = await self.storage_manager.connection_pool.fetchrow(query, symbol, exchange)
            else:
                query = """
                    SELECT
                        MIN(timestamp) as start_time,
                        MAX(timestamp) as end_time
                    FROM order_book_snapshots
                    WHERE symbol = $1
                """
                result = await self.storage_manager.connection_pool.fetchrow(query, symbol)

            if result and result['start_time'] and result['end_time']:
                return {
                    'start': result['start_time'],
                    'end': result['end_time']
                }

            return None

        except Exception as e:
            logger.error(f"Failed to get data range for {symbol}: {e}")
            return None

    def validate_replay_request(self, start_time: datetime, end_time: datetime,
                                symbols: Optional[List[str]] = None,
                                exchanges: Optional[List[str]] = None) -> List[str]:
        """
        Validate replay request parameters.

        Returns:
            A list of human-readable error messages; empty when the request
            is valid.
        """
        errors = []

        # Validate time range
        if start_time >= end_time:
            errors.append("Start time must be before end time")

        # Check if time range is too large (more than 30 days).
        # Compare whole timedeltas: ``.days`` truncates, which would let a
        # range of 30 days 23 hours slip through.
        if (end_time - start_time) > self.MAX_REPLAY_RANGE:
            errors.append("Time range cannot exceed 30 days")

        # Check if start time is too far in the past (more than 1 year)
        if (get_current_timestamp() - start_time) > self.MAX_REPLAY_AGE:
            errors.append("Start time cannot be more than 1 year ago")

        # Validate symbols
        if symbols:
            for symbol in symbols:
                if not symbol or len(symbol) < 3:
                    errors.append(f"Invalid symbol: {symbol}")

        # Validate exchanges
        if exchanges:
            valid_exchanges = self.config.exchanges.exchanges
            for exchange in exchanges:
                if exchange not in valid_exchanges:
                    errors.append(f"Unsupported exchange: {exchange}")

        return errors

    @staticmethod
    def _build_filters(session: ReplaySession, range_start: datetime) -> tuple:
        """
        Build the shared WHERE clause and its positional parameters.

        Returns:
            ``(where_clause, params)`` where the clause uses ``$n``
            placeholders matching the order of ``params``.
        """
        conditions = ["timestamp >= $1", "timestamp <= $2"]
        params: List[Any] = [range_start, session.end_time]

        if session.symbols:
            params.append(session.symbols)
            conditions.append(f"symbol = ANY(${len(params)})")

        if session.exchanges:
            params.append(session.exchanges)
            conditions.append(f"exchange = ANY(${len(params)})")

        return " AND ".join(conditions), params

    async def _replay_task(self, session_id: str) -> None:
        """
        Main replay task that streams historical data.

        Never raises: cancellation is logged and swallowed, any other failure
        puts the session into the ERROR state.
        """
        try:
            session = self.sessions[session_id]

            # Calculate total events for progress tracking
            await self._calculate_total_events(session_id)

            # Stream data
            await self._stream_historical_data(session_id)

            # The stream loop also exits when the status changes underneath
            # it; in that case the session did not complete and the status
            # set elsewhere must not be overwritten.
            if session.status != ReplayStatus.RUNNING:
                return

            # Mark as completed
            session.status = ReplayStatus.COMPLETED
            session.completed_at = get_current_timestamp()
            session.progress = 1.0

            await self._notify_status_callbacks(session_id, ReplayStatus.COMPLETED)
            self.stats['sessions_completed'] += 1

            logger.info(f"Completed replay session {session_id}")

        except asyncio.CancelledError:
            logger.info(f"Replay session {session_id} was cancelled")
        except Exception as e:
            logger.error(f"Replay task failed for session {session_id}: {e}")
            await self._set_session_error(session_id, str(e))
            self.stats['sessions_failed'] += 1

    async def _calculate_total_events(self, session_id: str) -> None:
        """
        Count order book and trade rows in the session's full time range and
        store the sum in ``session.total_events`` (0 if counting fails).
        """
        # Looked up outside the try block so the error handler below can
        # always refer to ``session``.
        session = self.sessions[session_id]

        try:
            # Build query conditions
            where_clause, params = self._build_filters(session, session.start_time)

            # Count order book events
            orderbook_query = f"""
                SELECT COUNT(*) FROM order_book_snapshots
                WHERE {where_clause}
            """
            orderbook_count = await self.storage_manager.connection_pool.fetchval(
                orderbook_query, *params
            )

            # Count trade events
            trade_query = f"""
                SELECT COUNT(*) FROM trade_events
                WHERE {where_clause}
            """
            trade_count = await self.storage_manager.connection_pool.fetchval(
                trade_query, *params
            )

            session.total_events = (orderbook_count or 0) + (trade_count or 0)

            logger.debug(f"Session {session_id} has {session.total_events} total events")

        except Exception as e:
            logger.error(f"Failed to calculate total events for session {session_id}: {e}")
            session.total_events = 0

    async def _stream_historical_data(self, session_id: str) -> None:
        """
        Stream order book and trade rows to the data callbacks in timestamp
        order, sleeping between events according to the session speed.

        Streaming starts at ``session.current_time`` and ends when the data
        is exhausted or the session leaves the RUNNING state.
        """
        session = self.sessions[session_id]

        # Build query conditions
        where_clause, params = self._build_filters(session, session.current_time)

        # Query both order book and trade data, ordered by timestamp
        query = f"""
            (
                SELECT 'orderbook' as type, timestamp, symbol, exchange,
                       bids, asks, sequence_id, mid_price, spread, bid_volume, ask_volume,
                       NULL as price, NULL as size, NULL as side, NULL as trade_id
                FROM order_book_snapshots
                WHERE {where_clause}
            )
            UNION ALL
            (
                SELECT 'trade' as type, timestamp, symbol, exchange,
                       NULL as bids, NULL as asks, NULL as sequence_id,
                       NULL as mid_price, NULL as spread, NULL as bid_volume, NULL as ask_volume,
                       price, size, side, trade_id
                FROM trade_events
                WHERE {where_clause}
            )
            ORDER BY timestamp ASC
        """

        # Stream data in chunks
        chunk_size = self.STREAM_CHUNK_SIZE
        offset = 0
        last_timestamp = session.current_time

        while session.status == ReplayStatus.RUNNING:
            # Fetch chunk
            chunk_query = f"{query} LIMIT {chunk_size} OFFSET {offset}"
            rows = await self.storage_manager.connection_pool.fetch(chunk_query, *params)

            if not rows:
                break

            # Process each row
            for row in rows:
                if session.status != ReplayStatus.RUNNING:
                    break

                # Calculate delay based on replay speed. The speed is read
                # per event so set_replay_speed() takes effect mid-stream.
                if last_timestamp < row['timestamp']:
                    time_diff = (row['timestamp'] - last_timestamp).total_seconds()
                    delay = time_diff / session.speed

                    if delay > 0:
                        await asyncio.sleep(delay)

                # Create data object
                if row['type'] == 'orderbook':
                    data = await self._create_orderbook_from_row(row)
                else:
                    data = await self._create_trade_from_row(row)

                # Notify data callbacks
                await self._notify_data_callbacks(session_id, data)

                # Update session progress
                session.events_replayed += 1
                session.current_time = row['timestamp']

                if session.total_events > 0:
                    session.progress = session.events_replayed / session.total_events

                last_timestamp = row['timestamp']
                self.stats['total_events_replayed'] += 1

            offset += chunk_size

    @staticmethod
    def _decode_levels(raw: Any) -> list:
        """
        Return the list of price-level dicts stored in a bids/asks column.

        Accepts JSON text as well as an already-decoded value, since the
        column may arrive either way depending on driver configuration.
        """
        if not raw:
            return []
        if isinstance(raw, (str, bytes, bytearray)):
            return json.loads(raw)
        return raw

    async def _create_orderbook_from_row(self, row: Dict) -> OrderBookSnapshot:
        """Create OrderBookSnapshot from database row."""
        # Kept function-local, as in the original implementation.
        from ..models.core import PriceLevel

        # Parse bids and asks from JSON
        bids_data = self._decode_levels(row['bids'])
        asks_data = self._decode_levels(row['asks'])

        bids = [PriceLevel(price=b['price'], size=b['size'], count=b.get('count'))
                for b in bids_data]
        asks = [PriceLevel(price=a['price'], size=a['size'], count=a.get('count'))
                for a in asks_data]

        return OrderBookSnapshot(
            symbol=row['symbol'],
            exchange=row['exchange'],
            timestamp=row['timestamp'],
            bids=bids,
            asks=asks,
            sequence_id=row['sequence_id']
        )

    async def _create_trade_from_row(self, row: Dict) -> TradeEvent:
        """Create TradeEvent from database row."""
        return TradeEvent(
            symbol=row['symbol'],
            exchange=row['exchange'],
            timestamp=row['timestamp'],
            price=float(row['price']),
            size=float(row['size']),
            side=row['side'],
            trade_id=row['trade_id']
        )

    async def _notify_data_callbacks(self, session_id: str,
                                     data: Union[OrderBookSnapshot, TradeEvent]) -> None:
        """
        Invoke every data callback of a session with ``data``.

        Coroutine functions are awaited, plain callables are called directly;
        an exception in one callback is logged and does not stop the others.
        """
        if session_id in self.session_callbacks:
            callbacks = self.session_callbacks[session_id]['data']
            for callback in callbacks:
                try:
                    if asyncio.iscoroutinefunction(callback):
                        await callback(data)
                    else:
                        callback(data)
                except Exception as e:
                    logger.error(f"Data callback error for session {session_id}: {e}")

    async def _notify_status_callbacks(self, session_id: str, status: ReplayStatus) -> None:
        """
        Invoke every status callback of a session with ``(session_id, status)``.

        Coroutine functions are awaited, plain callables are called directly;
        an exception in one callback is logged and does not stop the others.
        """
        if session_id in self.session_callbacks:
            callbacks = self.session_callbacks[session_id]['status']
            for callback in callbacks:
                try:
                    if asyncio.iscoroutinefunction(callback):
                        await callback(session_id, status)
                    else:
                        callback(session_id, status)
                except Exception as e:
                    logger.error(f"Status callback error for session {session_id}: {e}")

    async def _set_session_error(self, session_id: str, error_message: str) -> None:
        """
        Put a session into the ERROR state and notify status callbacks.

        Does nothing when the session is unknown.
        """
        if session_id in self.sessions:
            session = self.sessions[session_id]
            session.status = ReplayStatus.ERROR
            session.error_message = error_message
            session.stopped_at = get_current_timestamp()

            await self._notify_status_callbacks(session_id, ReplayStatus.ERROR)

    def get_stats(self) -> Dict[str, Any]:
        """
        Get replay manager statistics.

        Returns:
            The cumulative counters from ``self.stats`` plus the number of
            running sessions, the total session count and a per-status
            breakdown keyed by ``ReplayStatus`` value.
        """
        active_sessions = sum(1 for s in self.sessions.values()
                              if s.status == ReplayStatus.RUNNING)

        return {
            **self.stats,
            'active_sessions': active_sessions,
            'total_sessions': len(self.sessions),
            'session_statuses': {
                status.value: sum(1 for s in self.sessions.values() if s.status == status)
                for status in ReplayStatus
            }
        }