Files
gogo2/COBY/replay/replay_manager.py
Dobromir Popov 1479ac1624 replay system
2025-08-04 22:46:11 +03:00

665 lines
26 KiB
Python

"""
Historical data replay manager implementation.
Provides configurable playback of historical market data with session management.
"""
import asyncio
import uuid
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Callable, Any, Union
from dataclasses import replace
from ..interfaces.replay_manager import ReplayManager
from ..models.core import ReplaySession, ReplayStatus, OrderBookSnapshot, TradeEvent
from ..storage.storage_manager import StorageManager
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ReplayError, ValidationError
from ..utils.timing import get_current_timestamp
from ..config import Config
logger = get_logger(__name__)
class HistoricalReplayManager(ReplayManager):
    """
    Implementation of historical data replay functionality.

    Provides:
    - Session-based replay management
    - Configurable playback speeds
    - Real-time data streaming
    - Session controls (start/pause/stop/seek)
    - Data filtering by symbol and exchange

    Sessions live in memory in ``self.sessions``. Each running session is
    driven by one asyncio task kept in ``self.session_tasks``; per-session
    data/status callbacks are kept in ``self.session_callbacks``.
    """

    # Rows fetched per paging query while streaming a session.
    STREAM_CHUNK_SIZE = 1000
    # Limits enforced by validate_replay_request().
    MAX_REPLAY_RANGE = timedelta(days=30)
    MAX_REPLAY_AGE = timedelta(days=365)

    def __init__(self, storage_manager: StorageManager, config: Config):
        """
        Initialize replay manager.

        Args:
            storage_manager: Storage manager for data access
            config: System configuration
        """
        self.storage_manager = storage_manager
        self.config = config

        # Session management
        self.sessions: Dict[str, ReplaySession] = {}
        self.session_tasks: Dict[str, asyncio.Task] = {}
        self.session_callbacks: Dict[str, Dict[str, List[Callable]]] = {}

        # Performance tracking
        # NOTE(review): 'avg_replay_speed' is initialised here but nothing in
        # this class updates it.
        self.stats = {
            'sessions_created': 0,
            'sessions_completed': 0,
            'sessions_failed': 0,
            'total_events_replayed': 0,
            'avg_replay_speed': 0.0
        }
        logger.info("Historical replay manager initialized")

    # ------------------------------------------------------------------
    # Session lifecycle
    # ------------------------------------------------------------------

    def create_replay_session(self, start_time: datetime, end_time: datetime,
                              speed: float = 1.0, symbols: Optional[List[str]] = None,
                              exchanges: Optional[List[str]] = None) -> str:
        """
        Create a new replay session in the CREATED state.

        Args:
            start_time: Start of the replayed time range
            end_time: End of the replayed time range
            speed: Playback speed multiplier; must be positive
            symbols: Optional symbol filter (empty/None means no filter)
            exchanges: Optional exchange filter (empty/None means no filter)

        Returns:
            The generated session id.

        Raises:
            ReplayError: If validation fails or the session cannot be created.
        """
        try:
            set_correlation_id()

            # Validate parameters
            validation_errors = list(
                self.validate_replay_request(start_time, end_time, symbols, exchanges)
            )
            # The streaming loop divides by the speed, and set_replay_speed()
            # already rejects non-positive values; enforce the same rule here.
            if speed <= 0:
                validation_errors.append("Speed must be positive")
            if validation_errors:
                raise ValidationError(f"Invalid replay request: {', '.join(validation_errors)}")

            # Generate session ID
            session_id = str(uuid.uuid4())

            # Create session
            session = ReplaySession(
                session_id=session_id,
                start_time=start_time,
                end_time=end_time,
                current_time=start_time,
                speed=speed,
                status=ReplayStatus.CREATED,
                symbols=symbols or [],
                exchanges=exchanges or [],
                created_at=get_current_timestamp(),
                events_replayed=0,
                total_events=0,
                progress=0.0
            )

            # Store session
            self.sessions[session_id] = session
            self.session_callbacks[session_id] = {
                'data': [],
                'status': []
            }
            self.stats['sessions_created'] += 1
            logger.info(f"Created replay session {session_id} for {start_time} to {end_time}")
            return session_id
        except Exception as e:
            logger.error(f"Failed to create replay session: {e}")
            raise ReplayError(f"Session creation failed: {e}") from e

    async def start_replay(self, session_id: str) -> None:
        """
        Start (or restart) streaming for a session.

        Marks the session RUNNING, notifies status callbacks and spawns the
        background replay task. A session that is already RUNNING is left
        untouched. ``started_at`` is overwritten on every call, including
        calls made through resume_replay().

        Raises:
            ReplayError: If the session does not exist or the start fails.
        """
        try:
            set_correlation_id()
            if session_id not in self.sessions:
                raise ReplayError(f"Session {session_id} not found")

            session = self.sessions[session_id]
            if session.status == ReplayStatus.RUNNING:
                logger.warning(f"Session {session_id} is already running")
                return

            # Update session status
            session.status = ReplayStatus.RUNNING
            session.started_at = get_current_timestamp()

            # Notify status callbacks
            await self._notify_status_callbacks(session_id, ReplayStatus.RUNNING)

            # Start replay task
            task = asyncio.create_task(self._replay_task(session_id))
            self.session_tasks[session_id] = task
            logger.info(f"Started replay session {session_id}")
        except Exception as e:
            logger.error(f"Failed to start replay session {session_id}: {e}")
            await self._set_session_error(session_id, str(e))
            raise ReplayError(f"Failed to start replay: {e}") from e

    async def pause_replay(self, session_id: str) -> None:
        """
        Pause a running session.

        Sets the status to PAUSED and cancels the replay task without waiting
        for it to finish. Sessions that are not RUNNING are left untouched.

        Raises:
            ReplayError: If the session does not exist or pausing fails.
        """
        try:
            if session_id not in self.sessions:
                raise ReplayError(f"Session {session_id} not found")

            session = self.sessions[session_id]
            if session.status != ReplayStatus.RUNNING:
                logger.warning(f"Session {session_id} is not running")
                return

            # Update session status
            session.status = ReplayStatus.PAUSED
            session.paused_at = get_current_timestamp()

            # Cancel replay task
            task = self.session_tasks.pop(session_id, None)
            if task is not None:
                task.cancel()

            # Notify status callbacks
            await self._notify_status_callbacks(session_id, ReplayStatus.PAUSED)
            logger.info(f"Paused replay session {session_id}")
        except Exception as e:
            logger.error(f"Failed to pause replay session {session_id}: {e}")
            raise ReplayError(f"Failed to pause replay: {e}") from e

    async def resume_replay(self, session_id: str) -> None:
        """
        Resume a paused session from ``session.current_time``.

        Sessions that are not PAUSED are left untouched.

        NOTE(review): the stream query uses ``timestamp >= current_time``, so
        the event(s) at ``current_time`` are delivered again after a resume.

        Raises:
            ReplayError: If the session does not exist or resuming fails.
        """
        try:
            if session_id not in self.sessions:
                raise ReplayError(f"Session {session_id} not found")

            session = self.sessions[session_id]
            if session.status != ReplayStatus.PAUSED:
                logger.warning(f"Session {session_id} is not paused")
                return

            # Resume from current position
            await self.start_replay(session_id)
            logger.info(f"Resumed replay session {session_id}")
        except Exception as e:
            logger.error(f"Failed to resume replay session {session_id}: {e}")
            raise ReplayError(f"Failed to resume replay: {e}") from e

    async def stop_replay(self, session_id: str) -> None:
        """
        Stop a session and wait for its replay task to finish.

        Raises:
            ReplayError: If the session does not exist or stopping fails.
        """
        try:
            if session_id not in self.sessions:
                raise ReplayError(f"Session {session_id} not found")

            session = self.sessions[session_id]

            # Update session status
            session.status = ReplayStatus.STOPPED
            session.stopped_at = get_current_timestamp()

            # Cancel replay task
            task = self.session_tasks.pop(session_id, None)
            if task is not None:
                task.cancel()
                # A task cannot await itself; this happens when stop_replay()
                # is invoked from one of the session's own callbacks.
                if task is not asyncio.current_task():
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass

            # Notify status callbacks
            await self._notify_status_callbacks(session_id, ReplayStatus.STOPPED)
            logger.info(f"Stopped replay session {session_id}")
        except Exception as e:
            logger.error(f"Failed to stop replay session {session_id}: {e}")
            raise ReplayError(f"Failed to stop replay: {e}") from e

    def get_replay_status(self, session_id: str) -> Optional[ReplaySession]:
        """Return the session object for ``session_id``, or None if unknown."""
        return self.sessions.get(session_id)

    def list_replay_sessions(self) -> List[ReplaySession]:
        """Return all known replay sessions."""
        return list(self.sessions.values())

    def delete_replay_session(self, session_id: str) -> bool:
        """
        Delete a session, cancelling its replay task if one is active.

        The task is cancelled synchronously. Scheduling stop_replay() and then
        removing the session would make that coroutine fail with "not found"
        and leave the replay task running. Status callbacks are not notified.

        Returns:
            True if the session existed and was removed, False otherwise.
        """
        try:
            session = self.sessions.get(session_id)
            if session is None:
                return False

            # Stop session if running
            task = self.session_tasks.get(session_id)
            if task is not None and not task.done():
                task.cancel()
            if session.status == ReplayStatus.RUNNING:
                # Also makes the streaming loop exit on its next status check.
                session.status = ReplayStatus.STOPPED
                session.stopped_at = get_current_timestamp()

            # Clean up
            self.session_tasks.pop(session_id, None)
            self.sessions.pop(session_id, None)
            self.session_callbacks.pop(session_id, None)
            logger.info(f"Deleted replay session {session_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to delete replay session {session_id}: {e}")
            return False

    # ------------------------------------------------------------------
    # Session controls
    # ------------------------------------------------------------------

    def set_replay_speed(self, session_id: str, speed: float) -> bool:
        """
        Change the playback speed of a session.

        The streaming loop reads ``session.speed`` for every event, so the
        new value applies from the next delay computation.

        Returns:
            True on success; False if the session is unknown or the speed is
            not positive.
        """
        try:
            if session_id not in self.sessions:
                return False
            if speed <= 0:
                raise ValueError("Speed must be positive")
            self.sessions[session_id].speed = speed
            logger.info(f"Set replay speed to {speed}x for session {session_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to set replay speed for session {session_id}: {e}")
            return False

    def seek_replay(self, session_id: str, timestamp: datetime) -> bool:
        """
        Move the session's current position to ``timestamp``.

        NOTE(review): an already-running stream built its query from the old
        position and overwrites ``current_time`` with every delivered row, so
        a seek only takes effect for a stream started afterwards.

        Returns:
            True on success; False if the session is unknown or the timestamp
            lies outside the session's time range.
        """
        try:
            if session_id not in self.sessions:
                return False
            session = self.sessions[session_id]

            # Validate timestamp is within session range
            if timestamp < session.start_time or timestamp > session.end_time:
                logger.warning(f"Seek timestamp {timestamp} outside session range")
                return False

            # Update current time
            session.current_time = timestamp

            # Recalculate progress as the elapsed fraction of the time range
            total_duration = (session.end_time - session.start_time).total_seconds()
            elapsed_duration = (timestamp - session.start_time).total_seconds()
            session.progress = elapsed_duration / total_duration if total_duration > 0 else 0.0
            logger.info(f"Seeked to {timestamp} in session {session_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to seek in session {session_id}: {e}")
            return False

    # ------------------------------------------------------------------
    # Callback registration
    # ------------------------------------------------------------------

    def add_data_callback(self, session_id: str, callback: Callable) -> bool:
        """Register ``callback(data)`` for replayed events; False if the session is unknown."""
        try:
            if session_id not in self.session_callbacks:
                return False
            self.session_callbacks[session_id]['data'].append(callback)
            logger.debug(f"Added data callback for session {session_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to add data callback for session {session_id}: {e}")
            return False

    def remove_data_callback(self, session_id: str, callback: Callable) -> bool:
        """Unregister a data callback; False if the session or callback is unknown."""
        try:
            if session_id not in self.session_callbacks:
                return False
            callbacks = self.session_callbacks[session_id]['data']
            if callback not in callbacks:
                return False
            callbacks.remove(callback)
            logger.debug(f"Removed data callback for session {session_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to remove data callback for session {session_id}: {e}")
            return False

    def add_status_callback(self, session_id: str, callback: Callable) -> bool:
        """Register ``callback(session_id, status)`` for status changes; False if the session is unknown."""
        try:
            if session_id not in self.session_callbacks:
                return False
            self.session_callbacks[session_id]['status'].append(callback)
            logger.debug(f"Added status callback for session {session_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to add status callback for session {session_id}: {e}")
            return False

    # ------------------------------------------------------------------
    # Queries and validation
    # ------------------------------------------------------------------

    async def get_available_data_range(self, symbol: str,
                                       exchange: Optional[str] = None) -> Optional[Dict[str, datetime]]:
        """
        Look up the earliest and latest order book snapshot for a symbol.

        Args:
            symbol: Symbol to look up
            exchange: Optional exchange to restrict the lookup to

        Returns:
            ``{'start': ..., 'end': ...}``, or None when no data exists or
            the query fails.
        """
        try:
            # Query database for data range
            conditions = ["symbol = $1"]
            params = [symbol]
            if exchange:
                conditions.append("exchange = $2")
                params.append(exchange)
            where_clause = " AND ".join(conditions)
            query = f"""
                SELECT
                    MIN(timestamp) as start_time,
                    MAX(timestamp) as end_time
                FROM order_book_snapshots
                WHERE {where_clause}
            """
            result = await self.storage_manager.connection_pool.fetchrow(query, *params)
            if result and result['start_time'] and result['end_time']:
                return {
                    'start': result['start_time'],
                    'end': result['end_time']
                }
            return None
        except Exception as e:
            logger.error(f"Failed to get data range for {symbol}: {e}")
            return None

    def validate_replay_request(self, start_time: datetime, end_time: datetime,
                                symbols: Optional[List[str]] = None,
                                exchanges: Optional[List[str]] = None) -> List[str]:
        """
        Validate replay request parameters.

        Returns:
            A list of human-readable error messages; empty when the request
            is valid.
        """
        errors = []

        # Validate time range
        if start_time >= end_time:
            errors.append("Start time must be before end time")

        # Compare full timedeltas: ``.days`` truncates, which would let a
        # range of 30 days and 23 hours slip past a 30-day limit.
        if end_time - start_time > self.MAX_REPLAY_RANGE:
            errors.append("Time range cannot exceed 30 days")

        # Check if start time is too far in the past (more than 1 year)
        if get_current_timestamp() - start_time > self.MAX_REPLAY_AGE:
            errors.append("Start time cannot be more than 1 year ago")

        # Validate symbols
        for symbol in symbols or []:
            if not symbol or len(symbol) < 3:
                errors.append(f"Invalid symbol: {symbol}")

        # Validate exchanges
        if exchanges:
            valid_exchanges = self.config.exchanges.exchanges
            for exchange in exchanges:
                if exchange not in valid_exchanges:
                    errors.append(f"Unsupported exchange: {exchange}")
        return errors

    # ------------------------------------------------------------------
    # Replay internals
    # ------------------------------------------------------------------

    async def _replay_task(self, session_id: str) -> None:
        """
        Background task body: count events, stream them, finalise the session.

        Cancellation is logged and swallowed. Any other exception puts the
        session into the ERROR state and increments ``sessions_failed``.
        """
        try:
            session = self.sessions[session_id]

            # Calculate total events for progress tracking
            await self._calculate_total_events(session_id)

            # Stream data
            await self._stream_historical_data(session_id)

            # The streaming loop also exits when another code path changed
            # the status; only a session that is still RUNNING has actually
            # reached the end of its data.
            if session.status != ReplayStatus.RUNNING:
                logger.info(
                    f"Replay session {session_id} ended with status {session.status}"
                )
                return

            # Mark as completed
            session.status = ReplayStatus.COMPLETED
            session.completed_at = get_current_timestamp()
            session.progress = 1.0
            await self._notify_status_callbacks(session_id, ReplayStatus.COMPLETED)
            self.stats['sessions_completed'] += 1
            logger.info(f"Completed replay session {session_id}")
        except asyncio.CancelledError:
            logger.info(f"Replay session {session_id} was cancelled")
        except Exception as e:
            logger.error(f"Replay task failed for session {session_id}: {e}")
            await self._set_session_error(session_id, str(e))
            self.stats['sessions_failed'] += 1
        finally:
            # Drop the registry entry once finished, but only if it still
            # refers to this task (a resume may have registered a newer one).
            if self.session_tasks.get(session_id) is asyncio.current_task():
                del self.session_tasks[session_id]

    @staticmethod
    def _build_filter(session: ReplaySession, range_start: datetime):
        """
        Build the shared WHERE clause and positional parameters for a session.

        Returns:
            ``(where_clause, params)`` covering ``range_start`` to
            ``session.end_time`` plus the optional symbol/exchange filters.
        """
        conditions = ["timestamp >= $1", "timestamp <= $2"]
        params: List[Any] = [range_start, session.end_time]
        for column, values in (("symbol", session.symbols), ("exchange", session.exchanges)):
            if values:
                params.append(values)
                conditions.append(f"{column} = ANY(${len(params)})")
        return " AND ".join(conditions), params

    async def _calculate_total_events(self, session_id: str) -> None:
        """
        Count order book and trade rows in the session's full time range.

        Stores the sum in ``session.total_events``; on failure the total is
        set to 0, which disables event-based progress updates.
        """
        # Resolve the session outside the try block so the error handler
        # below can always reference it.
        session = self.sessions.get(session_id)
        if session is None:
            logger.error(f"Failed to calculate total events for session {session_id}: session not found")
            return

        try:
            where_clause, params = self._build_filter(session, session.start_time)
            pool = self.storage_manager.connection_pool

            # Count order book events
            orderbook_query = f"""
                SELECT COUNT(*) FROM order_book_snapshots
                WHERE {where_clause}
            """
            orderbook_count = await pool.fetchval(orderbook_query, *params)

            # Count trade events
            trade_query = f"""
                SELECT COUNT(*) FROM trade_events
                WHERE {where_clause}
            """
            trade_count = await pool.fetchval(trade_query, *params)

            session.total_events = (orderbook_count or 0) + (trade_count or 0)
            logger.debug(f"Session {session_id} has {session.total_events} total events")
        except Exception as e:
            logger.error(f"Failed to calculate total events for session {session_id}: {e}")
            session.total_events = 0

    async def _stream_historical_data(self, session_id: str) -> None:
        """
        Stream order book and trade rows to the session's data callbacks.

        Rows are fetched in pages of ``STREAM_CHUNK_SIZE`` starting at
        ``session.current_time``. Between events the coroutine sleeps for the
        historical time gap divided by ``session.speed``. The loop ends when
        the data is exhausted or the session leaves the RUNNING state.
        """
        session = self.sessions[session_id]
        where_clause, params = self._build_filter(session, session.current_time)

        # Query both order book and trade data, ordered by timestamp.
        # The extra ORDER BY columns are tie-breakers: LIMIT/OFFSET paging is
        # only stable when the ordering is deterministic, otherwise rows that
        # share a timestamp may be skipped or repeated across pages.
        query = f"""
            (
                SELECT 'orderbook' as type, timestamp, symbol, exchange,
                       bids, asks, sequence_id, mid_price, spread, bid_volume, ask_volume,
                       NULL as price, NULL as size, NULL as side, NULL as trade_id
                FROM order_book_snapshots
                WHERE {where_clause}
            )
            UNION ALL
            (
                SELECT 'trade' as type, timestamp, symbol, exchange,
                       NULL as bids, NULL as asks, NULL as sequence_id,
                       NULL as mid_price, NULL as spread, NULL as bid_volume, NULL as ask_volume,
                       price, size, side, trade_id
                FROM trade_events
                WHERE {where_clause}
            )
            ORDER BY timestamp ASC, type ASC, exchange ASC, symbol ASC,
                     sequence_id ASC, trade_id ASC
        """

        # Stream data in chunks
        chunk_size = self.STREAM_CHUNK_SIZE
        offset = 0
        last_timestamp = session.current_time
        pool = self.storage_manager.connection_pool

        while session.status == ReplayStatus.RUNNING:
            # Fetch chunk
            chunk_query = f"{query} LIMIT {chunk_size} OFFSET {offset}"
            rows = await pool.fetch(chunk_query, *params)
            if not rows:
                break

            # Process each row
            for row in rows:
                if session.status != ReplayStatus.RUNNING:
                    break

                # Calculate delay based on replay speed. The speed is read on
                # every event because set_replay_speed() may change it.
                event_time = row['timestamp']
                if last_timestamp < event_time and session.speed > 0:
                    time_diff = (event_time - last_timestamp).total_seconds()
                    delay = time_diff / session.speed
                    if delay > 0:
                        await asyncio.sleep(delay)

                # Create data object
                if row['type'] == 'orderbook':
                    data = await self._create_orderbook_from_row(row)
                else:
                    data = await self._create_trade_from_row(row)

                # Notify data callbacks
                await self._notify_data_callbacks(session_id, data)

                # Update session progress
                session.events_replayed += 1
                session.current_time = event_time
                if session.total_events > 0:
                    # events_replayed accumulates across resumes and may
                    # exceed the counted total; keep progress within [0, 1].
                    session.progress = min(1.0, session.events_replayed / session.total_events)
                last_timestamp = event_time
                self.stats['total_events_replayed'] += 1

            # A short page means the result set is exhausted.
            if len(rows) < chunk_size:
                break
            offset += chunk_size

    async def _create_orderbook_from_row(self, row: Dict) -> OrderBookSnapshot:
        """
        Create OrderBookSnapshot from database row.

        ``bids``/``asks`` are decoded from JSON when they arrive as text or
        bytes; values that are already decoded are used as they are.
        """
        import json
        from ..models.core import PriceLevel

        def parse_levels(raw):
            # Empty/NULL column -> no levels.
            if not raw:
                return []
            entries = json.loads(raw) if isinstance(raw, (str, bytes, bytearray)) else raw
            return [
                PriceLevel(price=entry['price'], size=entry['size'], count=entry.get('count'))
                for entry in entries
            ]

        return OrderBookSnapshot(
            symbol=row['symbol'],
            exchange=row['exchange'],
            timestamp=row['timestamp'],
            bids=parse_levels(row['bids']),
            asks=parse_levels(row['asks']),
            sequence_id=row['sequence_id']
        )

    async def _create_trade_from_row(self, row: Dict) -> TradeEvent:
        """Create TradeEvent from database row (price and size coerced to float)."""
        return TradeEvent(
            symbol=row['symbol'],
            exchange=row['exchange'],
            timestamp=row['timestamp'],
            price=float(row['price']),
            size=float(row['size']),
            side=row['side'],
            trade_id=row['trade_id']
        )

    # ------------------------------------------------------------------
    # Callback dispatch
    # ------------------------------------------------------------------

    @staticmethod
    async def _invoke_callback(callback: Callable, *args: Any) -> None:
        """
        Call ``callback(*args)`` and await the result if it is awaitable.

        Checking the returned object rather than the function also covers
        callables that are not coroutine functions themselves but still
        return an awaitable (e.g. objects with ``async def __call__``).
        """
        outcome = callback(*args)
        if hasattr(outcome, "__await__"):
            await outcome

    async def _notify_data_callbacks(self, session_id: str,
                                     data: Union[OrderBookSnapshot, TradeEvent]) -> None:
        """Deliver ``data`` to every data callback of the session; callback errors are logged."""
        registry = self.session_callbacks.get(session_id)
        if registry is None:
            return
        # Iterate over a snapshot so a callback may unregister itself without
        # causing its neighbour to be skipped.
        for callback in list(registry['data']):
            try:
                await self._invoke_callback(callback, data)
            except Exception as e:
                logger.error(f"Data callback error for session {session_id}: {e}")

    async def _notify_status_callbacks(self, session_id: str, status: ReplayStatus) -> None:
        """Deliver ``(session_id, status)`` to every status callback; callback errors are logged."""
        registry = self.session_callbacks.get(session_id)
        if registry is None:
            return
        for callback in list(registry['status']):
            try:
                await self._invoke_callback(callback, session_id, status)
            except Exception as e:
                logger.error(f"Status callback error for session {session_id}: {e}")

    async def _set_session_error(self, session_id: str, error_message: str) -> None:
        """Put an existing session into the ERROR state and notify status callbacks."""
        session = self.sessions.get(session_id)
        if session is None:
            return
        session.status = ReplayStatus.ERROR
        session.error_message = error_message
        session.stopped_at = get_current_timestamp()
        await self._notify_status_callbacks(session_id, ReplayStatus.ERROR)

    def get_stats(self) -> Dict[str, Any]:
        """
        Get replay manager statistics.

        Returns:
            The cumulative counters from ``self.stats`` plus the number of
            RUNNING sessions, the total session count and a per-status
            breakdown keyed by ``ReplayStatus`` value.
        """
        # Tally statuses in a single pass over the sessions.
        tally: Dict[Any, int] = {}
        for session in self.sessions.values():
            tally[session.status] = tally.get(session.status, 0) + 1

        return {
            **self.stats,
            'active_sessions': tally.get(ReplayStatus.RUNNING, 0),
            'total_sessions': len(self.sessions),
            'session_statuses': {
                status.value: tally.get(status, 0)
                for status in ReplayStatus
            }
        }