replay system
This commit is contained in:
8
COBY/replay/__init__.py
Normal file
8
COBY/replay/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
"""
|
||||
Historical data replay system for the COBY multi-exchange data aggregation system.
|
||||
Provides configurable playback of historical market data with session management.
|
||||
"""
|
||||
|
||||
from .replay_manager import HistoricalReplayManager
|
||||
|
||||
__all__ = ['HistoricalReplayManager']
|
665
COBY/replay/replay_manager.py
Normal file
665
COBY/replay/replay_manager.py
Normal file
@ -0,0 +1,665 @@
|
||||
"""
|
||||
Historical data replay manager implementation.
|
||||
Provides configurable playback of historical market data with session management.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Callable, Any, Union
|
||||
from dataclasses import replace
|
||||
|
||||
from ..interfaces.replay_manager import ReplayManager
|
||||
from ..models.core import ReplaySession, ReplayStatus, OrderBookSnapshot, TradeEvent
|
||||
from ..storage.storage_manager import StorageManager
|
||||
from ..utils.logging import get_logger, set_correlation_id
|
||||
from ..utils.exceptions import ReplayError, ValidationError
|
||||
from ..utils.timing import get_current_timestamp
|
||||
from ..config import Config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HistoricalReplayManager(ReplayManager):
|
||||
"""
|
||||
Implementation of historical data replay functionality.
|
||||
|
||||
Provides:
|
||||
- Session-based replay management
|
||||
- Configurable playback speeds
|
||||
- Real-time data streaming
|
||||
- Session controls (start/pause/stop/seek)
|
||||
- Data filtering by symbol and exchange
|
||||
"""
|
||||
|
||||
def __init__(self, storage_manager: StorageManager, config: Config):
|
||||
"""
|
||||
Initialize replay manager.
|
||||
|
||||
Args:
|
||||
storage_manager: Storage manager for data access
|
||||
config: System configuration
|
||||
"""
|
||||
self.storage_manager = storage_manager
|
||||
self.config = config
|
||||
|
||||
# Session management
|
||||
self.sessions: Dict[str, ReplaySession] = {}
|
||||
self.session_tasks: Dict[str, asyncio.Task] = {}
|
||||
self.session_callbacks: Dict[str, Dict[str, List[Callable]]] = {}
|
||||
|
||||
# Performance tracking
|
||||
self.stats = {
|
||||
'sessions_created': 0,
|
||||
'sessions_completed': 0,
|
||||
'sessions_failed': 0,
|
||||
'total_events_replayed': 0,
|
||||
'avg_replay_speed': 0.0
|
||||
}
|
||||
|
||||
logger.info("Historical replay manager initialized")
|
||||
|
||||
def create_replay_session(self, start_time: datetime, end_time: datetime,
|
||||
speed: float = 1.0, symbols: Optional[List[str]] = None,
|
||||
exchanges: Optional[List[str]] = None) -> str:
|
||||
"""Create a new replay session."""
|
||||
try:
|
||||
set_correlation_id()
|
||||
|
||||
# Validate parameters
|
||||
validation_errors = self.validate_replay_request(start_time, end_time, symbols, exchanges)
|
||||
if validation_errors:
|
||||
raise ValidationError(f"Invalid replay request: {', '.join(validation_errors)}")
|
||||
|
||||
# Generate session ID
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Create session
|
||||
session = ReplaySession(
|
||||
session_id=session_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
current_time=start_time,
|
||||
speed=speed,
|
||||
status=ReplayStatus.CREATED,
|
||||
symbols=symbols or [],
|
||||
exchanges=exchanges or [],
|
||||
created_at=get_current_timestamp(),
|
||||
events_replayed=0,
|
||||
total_events=0,
|
||||
progress=0.0
|
||||
)
|
||||
|
||||
# Store session
|
||||
self.sessions[session_id] = session
|
||||
self.session_callbacks[session_id] = {
|
||||
'data': [],
|
||||
'status': []
|
||||
}
|
||||
|
||||
self.stats['sessions_created'] += 1
|
||||
|
||||
logger.info(f"Created replay session {session_id} for {start_time} to {end_time}")
|
||||
return session_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create replay session: {e}")
|
||||
raise ReplayError(f"Session creation failed: {e}")
|
||||
|
||||
async def start_replay(self, session_id: str) -> None:
|
||||
"""Start replay session."""
|
||||
try:
|
||||
set_correlation_id()
|
||||
|
||||
if session_id not in self.sessions:
|
||||
raise ReplayError(f"Session {session_id} not found")
|
||||
|
||||
session = self.sessions[session_id]
|
||||
|
||||
if session.status == ReplayStatus.RUNNING:
|
||||
logger.warning(f"Session {session_id} is already running")
|
||||
return
|
||||
|
||||
# Update session status
|
||||
session.status = ReplayStatus.RUNNING
|
||||
session.started_at = get_current_timestamp()
|
||||
|
||||
# Notify status callbacks
|
||||
await self._notify_status_callbacks(session_id, ReplayStatus.RUNNING)
|
||||
|
||||
# Start replay task
|
||||
task = asyncio.create_task(self._replay_task(session_id))
|
||||
self.session_tasks[session_id] = task
|
||||
|
||||
logger.info(f"Started replay session {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start replay session {session_id}: {e}")
|
||||
await self._set_session_error(session_id, str(e))
|
||||
raise ReplayError(f"Failed to start replay: {e}")
|
||||
|
||||
async def pause_replay(self, session_id: str) -> None:
|
||||
"""Pause replay session."""
|
||||
try:
|
||||
if session_id not in self.sessions:
|
||||
raise ReplayError(f"Session {session_id} not found")
|
||||
|
||||
session = self.sessions[session_id]
|
||||
|
||||
if session.status != ReplayStatus.RUNNING:
|
||||
logger.warning(f"Session {session_id} is not running")
|
||||
return
|
||||
|
||||
# Update session status
|
||||
session.status = ReplayStatus.PAUSED
|
||||
session.paused_at = get_current_timestamp()
|
||||
|
||||
# Cancel replay task
|
||||
if session_id in self.session_tasks:
|
||||
self.session_tasks[session_id].cancel()
|
||||
del self.session_tasks[session_id]
|
||||
|
||||
# Notify status callbacks
|
||||
await self._notify_status_callbacks(session_id, ReplayStatus.PAUSED)
|
||||
|
||||
logger.info(f"Paused replay session {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to pause replay session {session_id}: {e}")
|
||||
raise ReplayError(f"Failed to pause replay: {e}")
|
||||
|
||||
async def resume_replay(self, session_id: str) -> None:
|
||||
"""Resume paused replay session."""
|
||||
try:
|
||||
if session_id not in self.sessions:
|
||||
raise ReplayError(f"Session {session_id} not found")
|
||||
|
||||
session = self.sessions[session_id]
|
||||
|
||||
if session.status != ReplayStatus.PAUSED:
|
||||
logger.warning(f"Session {session_id} is not paused")
|
||||
return
|
||||
|
||||
# Resume from current position
|
||||
await self.start_replay(session_id)
|
||||
|
||||
logger.info(f"Resumed replay session {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to resume replay session {session_id}: {e}")
|
||||
raise ReplayError(f"Failed to resume replay: {e}")
|
||||
|
||||
async def stop_replay(self, session_id: str) -> None:
|
||||
"""Stop replay session."""
|
||||
try:
|
||||
if session_id not in self.sessions:
|
||||
raise ReplayError(f"Session {session_id} not found")
|
||||
|
||||
session = self.sessions[session_id]
|
||||
|
||||
# Update session status
|
||||
session.status = ReplayStatus.STOPPED
|
||||
session.stopped_at = get_current_timestamp()
|
||||
|
||||
# Cancel replay task
|
||||
if session_id in self.session_tasks:
|
||||
self.session_tasks[session_id].cancel()
|
||||
try:
|
||||
await self.session_tasks[session_id]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
del self.session_tasks[session_id]
|
||||
|
||||
# Notify status callbacks
|
||||
await self._notify_status_callbacks(session_id, ReplayStatus.STOPPED)
|
||||
|
||||
logger.info(f"Stopped replay session {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop replay session {session_id}: {e}")
|
||||
raise ReplayError(f"Failed to stop replay: {e}")
|
||||
|
||||
def get_replay_status(self, session_id: str) -> Optional[ReplaySession]:
|
||||
"""Get replay session status."""
|
||||
return self.sessions.get(session_id)
|
||||
|
||||
def list_replay_sessions(self) -> List[ReplaySession]:
|
||||
"""List all replay sessions."""
|
||||
return list(self.sessions.values())
|
||||
|
||||
def delete_replay_session(self, session_id: str) -> bool:
|
||||
"""Delete replay session."""
|
||||
try:
|
||||
if session_id not in self.sessions:
|
||||
return False
|
||||
|
||||
# Stop session if running
|
||||
if self.sessions[session_id].status == ReplayStatus.RUNNING:
|
||||
asyncio.create_task(self.stop_replay(session_id))
|
||||
|
||||
# Clean up
|
||||
del self.sessions[session_id]
|
||||
if session_id in self.session_callbacks:
|
||||
del self.session_callbacks[session_id]
|
||||
|
||||
logger.info(f"Deleted replay session {session_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete replay session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def set_replay_speed(self, session_id: str, speed: float) -> bool:
|
||||
"""Change replay speed for active session."""
|
||||
try:
|
||||
if session_id not in self.sessions:
|
||||
return False
|
||||
|
||||
if speed <= 0:
|
||||
raise ValueError("Speed must be positive")
|
||||
|
||||
session = self.sessions[session_id]
|
||||
session.speed = speed
|
||||
|
||||
logger.info(f"Set replay speed to {speed}x for session {session_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set replay speed for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def seek_replay(self, session_id: str, timestamp: datetime) -> bool:
|
||||
"""Seek to specific timestamp in replay."""
|
||||
try:
|
||||
if session_id not in self.sessions:
|
||||
return False
|
||||
|
||||
session = self.sessions[session_id]
|
||||
|
||||
# Validate timestamp is within session range
|
||||
if timestamp < session.start_time or timestamp > session.end_time:
|
||||
logger.warning(f"Seek timestamp {timestamp} outside session range")
|
||||
return False
|
||||
|
||||
# Update current time
|
||||
session.current_time = timestamp
|
||||
|
||||
# Recalculate progress
|
||||
total_duration = (session.end_time - session.start_time).total_seconds()
|
||||
elapsed_duration = (timestamp - session.start_time).total_seconds()
|
||||
session.progress = elapsed_duration / total_duration if total_duration > 0 else 0.0
|
||||
|
||||
logger.info(f"Seeked to {timestamp} in session {session_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to seek in session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def add_data_callback(self, session_id: str, callback: Callable) -> bool:
|
||||
"""Add callback for replay data."""
|
||||
try:
|
||||
if session_id not in self.session_callbacks:
|
||||
return False
|
||||
|
||||
self.session_callbacks[session_id]['data'].append(callback)
|
||||
logger.debug(f"Added data callback for session {session_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add data callback for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def remove_data_callback(self, session_id: str, callback: Callable) -> bool:
|
||||
"""Remove data callback from replay session."""
|
||||
try:
|
||||
if session_id not in self.session_callbacks:
|
||||
return False
|
||||
|
||||
callbacks = self.session_callbacks[session_id]['data']
|
||||
if callback in callbacks:
|
||||
callbacks.remove(callback)
|
||||
logger.debug(f"Removed data callback for session {session_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove data callback for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
def add_status_callback(self, session_id: str, callback: Callable) -> bool:
|
||||
"""Add callback for replay status changes."""
|
||||
try:
|
||||
if session_id not in self.session_callbacks:
|
||||
return False
|
||||
|
||||
self.session_callbacks[session_id]['status'].append(callback)
|
||||
logger.debug(f"Added status callback for session {session_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add status callback for session {session_id}: {e}")
|
||||
return False
|
||||
|
||||
async def get_available_data_range(self, symbol: str,
|
||||
exchange: Optional[str] = None) -> Optional[Dict[str, datetime]]:
|
||||
"""Get available data time range for replay."""
|
||||
try:
|
||||
# Query database for data range
|
||||
if exchange:
|
||||
query = """
|
||||
SELECT
|
||||
MIN(timestamp) as start_time,
|
||||
MAX(timestamp) as end_time
|
||||
FROM order_book_snapshots
|
||||
WHERE symbol = $1 AND exchange = $2
|
||||
"""
|
||||
result = await self.storage_manager.connection_pool.fetchrow(query, symbol, exchange)
|
||||
else:
|
||||
query = """
|
||||
SELECT
|
||||
MIN(timestamp) as start_time,
|
||||
MAX(timestamp) as end_time
|
||||
FROM order_book_snapshots
|
||||
WHERE symbol = $1
|
||||
"""
|
||||
result = await self.storage_manager.connection_pool.fetchrow(query, symbol)
|
||||
|
||||
if result and result['start_time'] and result['end_time']:
|
||||
return {
|
||||
'start': result['start_time'],
|
||||
'end': result['end_time']
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get data range for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def validate_replay_request(self, start_time: datetime, end_time: datetime,
|
||||
symbols: Optional[List[str]] = None,
|
||||
exchanges: Optional[List[str]] = None) -> List[str]:
|
||||
"""Validate replay request parameters."""
|
||||
errors = []
|
||||
|
||||
# Validate time range
|
||||
if start_time >= end_time:
|
||||
errors.append("Start time must be before end time")
|
||||
|
||||
# Check if time range is too large (more than 30 days)
|
||||
if (end_time - start_time).days > 30:
|
||||
errors.append("Time range cannot exceed 30 days")
|
||||
|
||||
# Check if start time is too far in the past (more than 1 year)
|
||||
if (get_current_timestamp() - start_time).days > 365:
|
||||
errors.append("Start time cannot be more than 1 year ago")
|
||||
|
||||
# Validate symbols
|
||||
if symbols:
|
||||
for symbol in symbols:
|
||||
if not symbol or len(symbol) < 3:
|
||||
errors.append(f"Invalid symbol: {symbol}")
|
||||
|
||||
# Validate exchanges
|
||||
if exchanges:
|
||||
valid_exchanges = self.config.exchanges.exchanges
|
||||
for exchange in exchanges:
|
||||
if exchange not in valid_exchanges:
|
||||
errors.append(f"Unsupported exchange: {exchange}")
|
||||
|
||||
return errors
|
||||
|
||||
async def _replay_task(self, session_id: str) -> None:
|
||||
"""Main replay task that streams historical data."""
|
||||
try:
|
||||
session = self.sessions[session_id]
|
||||
|
||||
# Calculate total events for progress tracking
|
||||
await self._calculate_total_events(session_id)
|
||||
|
||||
# Stream data
|
||||
await self._stream_historical_data(session_id)
|
||||
|
||||
# Mark as completed
|
||||
session.status = ReplayStatus.COMPLETED
|
||||
session.completed_at = get_current_timestamp()
|
||||
session.progress = 1.0
|
||||
|
||||
await self._notify_status_callbacks(session_id, ReplayStatus.COMPLETED)
|
||||
self.stats['sessions_completed'] += 1
|
||||
|
||||
logger.info(f"Completed replay session {session_id}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Replay session {session_id} was cancelled")
|
||||
except Exception as e:
|
||||
logger.error(f"Replay task failed for session {session_id}: {e}")
|
||||
await self._set_session_error(session_id, str(e))
|
||||
self.stats['sessions_failed'] += 1
|
||||
|
||||
async def _calculate_total_events(self, session_id: str) -> None:
|
||||
"""Calculate total number of events for progress tracking."""
|
||||
try:
|
||||
session = self.sessions[session_id]
|
||||
|
||||
# Build query conditions
|
||||
conditions = ["timestamp >= $1", "timestamp <= $2"]
|
||||
params = [session.start_time, session.end_time]
|
||||
param_count = 2
|
||||
|
||||
if session.symbols:
|
||||
param_count += 1
|
||||
conditions.append(f"symbol = ANY(${param_count})")
|
||||
params.append(session.symbols)
|
||||
|
||||
if session.exchanges:
|
||||
param_count += 1
|
||||
conditions.append(f"exchange = ANY(${param_count})")
|
||||
params.append(session.exchanges)
|
||||
|
||||
where_clause = " AND ".join(conditions)
|
||||
|
||||
# Count order book events
|
||||
orderbook_query = f"""
|
||||
SELECT COUNT(*) FROM order_book_snapshots
|
||||
WHERE {where_clause}
|
||||
"""
|
||||
orderbook_count = await self.storage_manager.connection_pool.fetchval(
|
||||
orderbook_query, *params
|
||||
)
|
||||
|
||||
# Count trade events
|
||||
trade_query = f"""
|
||||
SELECT COUNT(*) FROM trade_events
|
||||
WHERE {where_clause}
|
||||
"""
|
||||
trade_count = await self.storage_manager.connection_pool.fetchval(
|
||||
trade_query, *params
|
||||
)
|
||||
|
||||
session.total_events = (orderbook_count or 0) + (trade_count or 0)
|
||||
|
||||
logger.debug(f"Session {session_id} has {session.total_events} total events")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to calculate total events for session {session_id}: {e}")
|
||||
session.total_events = 0
|
||||
|
||||
async def _stream_historical_data(self, session_id: str) -> None:
|
||||
"""Stream historical data for replay session."""
|
||||
session = self.sessions[session_id]
|
||||
|
||||
# Build query conditions
|
||||
conditions = ["timestamp >= $1", "timestamp <= $2"]
|
||||
params = [session.current_time, session.end_time]
|
||||
param_count = 2
|
||||
|
||||
if session.symbols:
|
||||
param_count += 1
|
||||
conditions.append(f"symbol = ANY(${param_count})")
|
||||
params.append(session.symbols)
|
||||
|
||||
if session.exchanges:
|
||||
param_count += 1
|
||||
conditions.append(f"exchange = ANY(${param_count})")
|
||||
params.append(session.exchanges)
|
||||
|
||||
where_clause = " AND ".join(conditions)
|
||||
|
||||
# Query both order book and trade data, ordered by timestamp
|
||||
query = f"""
|
||||
(
|
||||
SELECT 'orderbook' as type, timestamp, symbol, exchange,
|
||||
bids, asks, sequence_id, mid_price, spread, bid_volume, ask_volume,
|
||||
NULL as price, NULL as size, NULL as side, NULL as trade_id
|
||||
FROM order_book_snapshots
|
||||
WHERE {where_clause}
|
||||
)
|
||||
UNION ALL
|
||||
(
|
||||
SELECT 'trade' as type, timestamp, symbol, exchange,
|
||||
NULL as bids, NULL as asks, NULL as sequence_id,
|
||||
NULL as mid_price, NULL as spread, NULL as bid_volume, NULL as ask_volume,
|
||||
price, size, side, trade_id
|
||||
FROM trade_events
|
||||
WHERE {where_clause}
|
||||
)
|
||||
ORDER BY timestamp ASC
|
||||
"""
|
||||
|
||||
# Stream data in chunks
|
||||
chunk_size = 1000
|
||||
offset = 0
|
||||
last_timestamp = session.current_time
|
||||
|
||||
while session.status == ReplayStatus.RUNNING:
|
||||
# Fetch chunk
|
||||
chunk_query = f"{query} LIMIT {chunk_size} OFFSET {offset}"
|
||||
rows = await self.storage_manager.connection_pool.fetch(chunk_query, *params)
|
||||
|
||||
if not rows:
|
||||
break
|
||||
|
||||
# Process each row
|
||||
for row in rows:
|
||||
if session.status != ReplayStatus.RUNNING:
|
||||
break
|
||||
|
||||
# Calculate delay based on replay speed
|
||||
if last_timestamp < row['timestamp']:
|
||||
time_diff = (row['timestamp'] - last_timestamp).total_seconds()
|
||||
delay = time_diff / session.speed
|
||||
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Create data object
|
||||
if row['type'] == 'orderbook':
|
||||
data = await self._create_orderbook_from_row(row)
|
||||
else:
|
||||
data = await self._create_trade_from_row(row)
|
||||
|
||||
# Notify data callbacks
|
||||
await self._notify_data_callbacks(session_id, data)
|
||||
|
||||
# Update session progress
|
||||
session.events_replayed += 1
|
||||
session.current_time = row['timestamp']
|
||||
|
||||
if session.total_events > 0:
|
||||
session.progress = session.events_replayed / session.total_events
|
||||
|
||||
last_timestamp = row['timestamp']
|
||||
self.stats['total_events_replayed'] += 1
|
||||
|
||||
offset += chunk_size
|
||||
|
||||
async def _create_orderbook_from_row(self, row: Dict) -> OrderBookSnapshot:
|
||||
"""Create OrderBookSnapshot from database row."""
|
||||
import json
|
||||
from ..models.core import PriceLevel
|
||||
|
||||
# Parse bids and asks from JSON
|
||||
bids_data = json.loads(row['bids']) if row['bids'] else []
|
||||
asks_data = json.loads(row['asks']) if row['asks'] else []
|
||||
|
||||
bids = [PriceLevel(price=b['price'], size=b['size'], count=b.get('count'))
|
||||
for b in bids_data]
|
||||
asks = [PriceLevel(price=a['price'], size=a['size'], count=a.get('count'))
|
||||
for a in asks_data]
|
||||
|
||||
return OrderBookSnapshot(
|
||||
symbol=row['symbol'],
|
||||
exchange=row['exchange'],
|
||||
timestamp=row['timestamp'],
|
||||
bids=bids,
|
||||
asks=asks,
|
||||
sequence_id=row['sequence_id']
|
||||
)
|
||||
|
||||
async def _create_trade_from_row(self, row: Dict) -> TradeEvent:
|
||||
"""Create TradeEvent from database row."""
|
||||
return TradeEvent(
|
||||
symbol=row['symbol'],
|
||||
exchange=row['exchange'],
|
||||
timestamp=row['timestamp'],
|
||||
price=float(row['price']),
|
||||
size=float(row['size']),
|
||||
side=row['side'],
|
||||
trade_id=row['trade_id']
|
||||
)
|
||||
|
||||
async def _notify_data_callbacks(self, session_id: str,
|
||||
data: Union[OrderBookSnapshot, TradeEvent]) -> None:
|
||||
"""Notify all data callbacks for a session."""
|
||||
if session_id in self.session_callbacks:
|
||||
callbacks = self.session_callbacks[session_id]['data']
|
||||
for callback in callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(data)
|
||||
else:
|
||||
callback(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Data callback error for session {session_id}: {e}")
|
||||
|
||||
async def _notify_status_callbacks(self, session_id: str, status: ReplayStatus) -> None:
|
||||
"""Notify all status callbacks for a session."""
|
||||
if session_id in self.session_callbacks:
|
||||
callbacks = self.session_callbacks[session_id]['status']
|
||||
for callback in callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(session_id, status)
|
||||
else:
|
||||
callback(session_id, status)
|
||||
except Exception as e:
|
||||
logger.error(f"Status callback error for session {session_id}: {e}")
|
||||
|
||||
async def _set_session_error(self, session_id: str, error_message: str) -> None:
|
||||
"""Set session to error state."""
|
||||
if session_id in self.sessions:
|
||||
session = self.sessions[session_id]
|
||||
session.status = ReplayStatus.ERROR
|
||||
session.error_message = error_message
|
||||
session.stopped_at = get_current_timestamp()
|
||||
|
||||
await self._notify_status_callbacks(session_id, ReplayStatus.ERROR)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get replay manager statistics."""
|
||||
active_sessions = sum(1 for s in self.sessions.values()
|
||||
if s.status == ReplayStatus.RUNNING)
|
||||
|
||||
return {
|
||||
**self.stats,
|
||||
'active_sessions': active_sessions,
|
||||
'total_sessions': len(self.sessions),
|
||||
'session_statuses': {
|
||||
status.value: sum(1 for s in self.sessions.values() if s.status == status)
|
||||
for status in ReplayStatus
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user