494 lines
21 KiB
Python
494 lines
21 KiB
Python
#!/usr/bin/env python3
|
|
"""
Data Stream Monitor for Model Input Capture and Replay

Captures all model input data and streams it to the console as readable text.
The captured streams can be saved as snapshots and reused for training and replay.
"""
|
|
|
|
import logging
|
|
import json
|
|
import time
|
|
from datetime import datetime
|
|
from typing import Dict, List, Any, Optional
|
|
from collections import deque
|
|
import threading
|
|
import os
|
|
|
|
logger = logging.getLogger(__name__)


class DataStreamMonitor:
    """Monitors and streams all model input data for training and replay.

    A background thread (see ``start_streaming``) periodically samples the
    attached ``data_provider``, ``orchestrator`` and ``training_system`` and
    appends the results to bounded in-memory streams (``data_streams``).
    Every collaborator is optional: a collector whose source is missing, or
    lacks the method it needs, contributes nothing to that sample.
    """

    # Bar fields compared when deciding whether an OHLCV bar has changed.
    _OHLCV_FIELDS = ('open', 'high', 'low', 'close', 'volume')

    def __init__(self, orchestrator=None, data_provider=None, training_system=None):
        """Create a monitor; nothing is collected until ``start_streaming``.

        Args:
            orchestrator: Optional source of COB data, model states and predictions.
            data_provider: Optional source of OHLCV bars, ticks and indicators.
            training_system: Optional object exposing an ``experience_buffer``.
        """
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.training_system = training_system

        # Bounded buffers: once maxlen is reached the oldest entries are dropped.
        self.data_streams = {
            'ohlcv_1m': deque(maxlen=100),
            'ohlcv_5m': deque(maxlen=50),
            'ohlcv_15m': deque(maxlen=20),
            'ticks': deque(maxlen=200),
            'cob_raw': deque(maxlen=100),
            'cob_aggregated': deque(maxlen=50),
            'technical_indicators': deque(maxlen=100),
            'model_states': deque(maxlen=50),
            'predictions': deque(maxlen=100),
            'training_experiences': deque(maxlen=200)
        }

        # Streaming configuration
        self.stream_config = {
            'console_output': True,
            'compact_format': False,
            'include_timestamps': True,
            'filter_symbols': ['ETH/USDT'],  # Focus on primary symbols
            'sampling_rate': 1.0  # seconds between samples
        }

        self.is_streaming = False
        self.stream_thread = None
        self.last_sample_time = 0
        # Set by stop_streaming() so the worker wakes immediately instead of
        # finishing its current sleep.
        self._stop_event = threading.Event()

        logger.info("DataStreamMonitor initialized")

    def start_streaming(self):
        """Start the background sampling thread (no-op if already running)."""
        if self.is_streaming:
            logger.warning("Data streaming already active")
            return

        # Each worker gets its own stop event, so a previous worker that is
        # still finishing after a timed-out join cannot be revived by this start.
        self._stop_event = threading.Event()
        self.is_streaming = True
        self.stream_thread = threading.Thread(
            target=self._streaming_worker, args=(self._stop_event,), daemon=True
        )
        self.stream_thread.start()
        logger.info("Data streaming started")

    def stop_streaming(self):
        """Signal the worker to stop and wait up to 2 seconds for it to exit."""
        self.is_streaming = False
        self._stop_event.set()
        if self.stream_thread:
            self.stream_thread.join(timeout=2)
        logger.info("Data streaming stopped")

    def _streaming_worker(self, stop_event=None):
        """Collect and print a sample every ``sampling_rate`` seconds until stopped.

        Args:
            stop_event: Event that ends the loop when set. ``None`` means the
                loop is controlled by ``is_streaming`` alone.
        """
        if stop_event is None:
            stop_event = threading.Event()  # never set

        while self.is_streaming and not stop_event.is_set():
            try:
                current_time = time.time()
                if current_time - self.last_sample_time >= self.stream_config['sampling_rate']:
                    self._collect_data_sample()
                    self._output_data_sample()
                    self.last_sample_time = current_time

                # Check every 500ms; wait() returns early when a stop is requested.
                stop_event.wait(0.5)

            except Exception as e:
                logger.error(f"Error in streaming worker: {e}")
                stop_event.wait(2)

    def _collect_data_sample(self):
        """Collect one sample from every source, all stamped with the same time."""
        try:
            timestamp = datetime.now()
            collectors = (
                self._collect_ohlcv_data,            # 1. OHLCV bars
                self._collect_tick_data,             # 2. ticks
                self._collect_cob_data,              # 3. COB snapshots
                self._collect_technical_indicators,  # 4. indicators
                self._collect_model_states,          # 5. model states
                self._collect_predictions,           # 6. predictions
                self._collect_training_experiences,  # 7. training experiences
            )
            for collect in collectors:
                collect(timestamp)
        except Exception as e:
            logger.error(f"Error collecting data sample: {e}")

    @staticmethod
    def _latest_for_symbol(stream, symbol):
        """Return the newest entry in *stream* whose 'symbol' is *symbol*, else None."""
        return next((entry for entry in reversed(stream) if entry.get('symbol') == symbol), None)

    def _collect_ohlcv_data(self, timestamp: datetime):
        """Collect the latest OHLCV bar for each filtered symbol and timeframe.

        A bar is appended only when its OHLCV values differ from the last bar
        stored for the same symbol. The sample ``timestamp`` cannot be used for
        this check because it is new on every sample.
        """
        if not self.data_provider:
            return

        try:
            for symbol in self.stream_config['filter_symbols']:
                for timeframe in ('1m', '5m', '15m'):
                    df = self.data_provider.get_historical_data(symbol, timeframe, limit=5)
                    if df is None or df.empty:
                        continue

                    latest_bar = {
                        'timestamp': timestamp.isoformat(),
                        'symbol': symbol,
                        'timeframe': timeframe,
                    }
                    for field in self._OHLCV_FIELDS:
                        latest_bar[field] = float(df[field].iloc[-1])

                    stream = self.data_streams[f'ohlcv_{timeframe}']
                    previous = self._latest_for_symbol(stream, symbol)
                    if previous is None or any(
                        previous.get(field) != latest_bar[field] for field in self._OHLCV_FIELDS
                    ):
                        stream.append(latest_bar)

        except Exception as e:
            logger.debug(f"Error collecting OHLCV data: {e}")

    def _collect_tick_data(self, timestamp: datetime):
        """Collect recent ticks, skipping any trade_id already in the buffer.

        Every buffered trade_id is checked, not just the newest one.
        NOTE(review): get_recent_ticks presumably returns the newest ``limit``
        ticks, so consecutive calls overlap; confirm against the provider.
        """
        if not (self.data_provider and hasattr(self.data_provider, 'get_recent_ticks')):
            return

        try:
            recent_ticks = self.data_provider.get_recent_ticks(limit=10)
            if recent_ticks is None:
                return

            ticks_stream = self.data_streams['ticks']
            seen_ids = {tick.get('trade_id') for tick in ticks_stream}

            for tick in recent_ticks:
                tick_data = {
                    'timestamp': timestamp.isoformat(),
                    'symbol': tick.get('symbol', 'ETH/USDT'),
                    'price': float(tick.get('price', 0)),
                    'volume': float(tick.get('volume', 0)),
                    'side': tick.get('side', 'unknown'),
                    'trade_id': tick.get('trade_id', ''),
                    'is_buyer_maker': tick.get('is_buyer_maker', False)
                }
                if tick_data['trade_id'] in seen_ids:
                    continue
                ticks_stream.append(tick_data)
                seen_ids.add(tick_data['trade_id'])

        except Exception as e:
            logger.debug(f"Error collecting tick data: {e}")

    def _collect_cob_data(self, timestamp: datetime):
        """Collect COB (Consolidated Order Book) summaries from the orchestrator.

        Appends one 'cob_raw' entry per filtered symbol, plus a 'cob_aggregated'
        entry holding the first 5 bid/ask levels when both sides are non-empty.
        Missing or None 'stats'/'bids'/'asks' are treated as empty.
        """
        orchestrator = self.orchestrator
        if not orchestrator or not hasattr(orchestrator, 'latest_cob_data'):
            return

        try:
            latest_cob_data = orchestrator.latest_cob_data
            for symbol in self.stream_config['filter_symbols']:
                if symbol not in latest_cob_data:
                    continue
                cob_data = latest_cob_data[symbol]
                if cob_data is None:
                    continue

                stats = cob_data.get('stats') or {}
                bids = cob_data.get('bids')
                asks = cob_data.get('asks')
                bids = [] if bids is None else bids
                asks = [] if asks is None else asks

                raw_cob = {
                    'timestamp': timestamp.isoformat(),
                    'symbol': symbol,
                    'stats': stats,
                    'bids_count': len(bids),
                    'asks_count': len(asks),
                    'imbalance': stats.get('imbalance', 0),
                    'spread_bps': stats.get('spread_bps', 0),
                    'mid_price': stats.get('mid_price', 0)
                }
                self.data_streams['cob_raw'].append(raw_cob)

                # Top 5 bids and asks for aggregation
                if len(bids) > 0 and len(asks) > 0:
                    self.data_streams['cob_aggregated'].append({
                        'timestamp': timestamp.isoformat(),
                        'symbol': symbol,
                        'bids': bids[:5],
                        'asks': asks[:5],
                        'imbalance': raw_cob['imbalance'],
                        'spread_bps': raw_cob['spread_bps']
                    })

        except Exception as e:
            logger.debug(f"Error collecting COB data: {e}")

    def _collect_technical_indicators(self, timestamp: datetime):
        """Collect indicators from ``data_provider.calculate_technical_indicators``."""
        provider = self.data_provider
        if not (provider and hasattr(provider, 'calculate_technical_indicators')):
            return

        try:
            for symbol in self.stream_config['filter_symbols']:
                indicators = provider.calculate_technical_indicators(symbol)
                if indicators:
                    self.data_streams['technical_indicators'].append({
                        'timestamp': timestamp.isoformat(),
                        'symbol': symbol,
                        'indicators': indicators
                    })

        except Exception as e:
            logger.debug(f"Error collecting technical indicators: {e}")

    def _collect_model_states(self, timestamp: datetime):
        """Collect DQN, CNN and COB-RL agent state from the orchestrator.

        NOTE: states are keyed by model only, so with several filter symbols
        the last symbol's DQN/CNN state overwrites the earlier ones.
        """
        orchestrator = self.orchestrator
        if not orchestrator:
            return

        try:
            symbols = self.stream_config['filter_symbols']
            model_states = {}

            # DQN State
            if hasattr(orchestrator, 'build_comprehensive_rl_state'):
                for symbol in symbols:
                    rl_state = orchestrator.build_comprehensive_rl_state(symbol)
                    if rl_state:
                        model_states['dqn'] = {
                            'symbol': symbol,
                            'state_vector': rl_state.get('state_vector', []),
                            'features': rl_state.get('features', {}),
                            'metadata': rl_state.get('metadata', {})
                        }

            # CNN State
            cnn_model = getattr(orchestrator, 'cnn_model', None)
            if cnn_model and hasattr(cnn_model, 'get_state_features'):
                for symbol in symbols:
                    cnn_features = cnn_model.get_state_features(symbol)
                    if cnn_features:
                        model_states['cnn'] = {
                            'symbol': symbol,
                            'features': cnn_features
                        }

            # RL Agent State
            agent = getattr(orchestrator, 'cob_rl_agent', None)
            if agent:
                model_states['rl_agent'] = {
                    'epsilon': getattr(agent, 'epsilon', 0),
                    'total_steps': getattr(agent, 'total_steps', 0),
                    'current_reward': getattr(agent, 'current_reward', 0)
                }

            if model_states:
                self.data_streams['model_states'].append({
                    'timestamp': timestamp.isoformat(),
                    'models': model_states
                })

        except Exception as e:
            logger.debug(f"Error collecting model states: {e}")

    def _collect_predictions(self, timestamp: datetime):
        """Collect the orchestrator's recent predictions, grouped by model name."""
        orchestrator = self.orchestrator
        if not orchestrator:
            return

        try:
            predictions = {}

            if hasattr(orchestrator, 'get_recent_predictions'):
                recent_preds = orchestrator.get_recent_predictions(limit=5)
                if recent_preds is None:
                    recent_preds = []
                for pred in recent_preds:
                    model_name = pred.get('model_name', 'unknown')
                    predictions.setdefault(model_name, []).append({
                        'timestamp': pred.get('timestamp', timestamp.isoformat()),
                        'symbol': pred.get('symbol', 'ETH/USDT'),
                        'prediction': pred.get('prediction'),
                        'confidence': pred.get('confidence', 0),
                        'action': pred.get('action')
                    })

            if predictions:
                self.data_streams['predictions'].append({
                    'timestamp': timestamp.isoformat(),
                    'predictions': predictions
                })

        except Exception as e:
            logger.debug(f"Error collecting predictions: {e}")

    def _collect_training_experiences(self, timestamp: datetime):
        """Copy the last 10 entries of ``training_system.experience_buffer``.

        NOTE(review): the same experiences are re-appended on every sample until
        newer ones arrive; they carry no id here to dedup on.
        """
        if not self.training_system:
            return
        buffer = getattr(self.training_system, 'experience_buffer', None)
        if buffer is None:
            return

        try:
            for exp in list(buffer)[-10:]:
                self.data_streams['training_experiences'].append({
                    'timestamp': timestamp.isoformat(),
                    'state': exp.get('state', []),
                    'action': exp.get('action'),
                    'reward': exp.get('reward', 0),
                    'next_state': exp.get('next_state', []),
                    'done': exp.get('done', False),
                    'info': exp.get('info', {})
                })

        except Exception as e:
            logger.debug(f"Error collecting training experiences: {e}")

    def _output_data_sample(self):
        """Print the last 5 entries of each non-empty stream, if console output is on."""
        if not self.stream_config['console_output']:
            return

        try:
            sample_data = {
                stream_name: list(stream_data)[-5:]
                for stream_name, stream_data in self.data_streams.items()
                if stream_data
            }

            if sample_data:
                if self.stream_config['compact_format']:
                    self._output_compact_format(sample_data)
                else:
                    self._output_detailed_format(sample_data)

        except Exception as e:
            logger.error(f"Error outputting data sample: {e}")

    def _output_compact_format(self, sample_data: Dict):
        """Print a one-line JSON summary prefixed with ``DATA_STREAM:``.

        Counts are taken from *sample_data*, which ``_output_data_sample``
        limits to the last 5 entries per stream.
        """
        try:
            summary = {
                'timestamp': datetime.now().isoformat(),
                'ohlcv_count': len(sample_data.get('ohlcv_1m', [])),
                'ticks_count': len(sample_data.get('ticks', [])),
                'cob_count': len(sample_data.get('cob_raw', [])),
                'predictions_count': len(sample_data.get('predictions', [])),
                'experiences_count': len(sample_data.get('training_experiences', []))
            }

            # Add latest OHLCV if available
            if sample_data.get('ohlcv_1m'):
                latest_ohlcv = sample_data['ohlcv_1m'][-1]
                summary['price'] = latest_ohlcv['close']
                summary['volume'] = latest_ohlcv['volume']

            # Add latest COB if available
            if sample_data.get('cob_raw'):
                latest_cob = sample_data['cob_raw'][-1]
                summary['imbalance'] = latest_cob['imbalance']
                summary['spread_bps'] = latest_cob['spread_bps']

            print(f"DATA_STREAM: {json.dumps(summary, separators=(',', ':'))}")

        except Exception as e:
            logger.error(f"Error in compact output: {e}")

    def _output_detailed_format(self, sample_data: Dict):
        """Print the newest entry of each stream in a human-readable block."""
        try:
            separator = '=' * 80
            print(f"\n{separator}")
            print(f"DATA STREAM SAMPLE - {datetime.now().strftime('%H:%M:%S')}")
            print(separator)

            # OHLCV Data
            if sample_data.get('ohlcv_1m'):
                latest = sample_data['ohlcv_1m'][-1]
                print(f"OHLCV (1m): {latest['symbol']} | O:{latest['open']:.2f} H:{latest['high']:.2f} "
                      f"L:{latest['low']:.2f} C:{latest['close']:.2f} V:{latest['volume']:.1f}")

            # Tick Data
            if sample_data.get('ticks'):
                latest_tick = sample_data['ticks'][-1]
                print(f"TICK: {latest_tick['symbol']} | Price:{latest_tick['price']:.2f} "
                      f"Vol:{latest_tick['volume']:.4f} Side:{latest_tick['side']}")

            # COB Data
            if sample_data.get('cob_raw'):
                latest_cob = sample_data['cob_raw'][-1]
                print(f"COB: {latest_cob['symbol']} | Imbalance:{latest_cob['imbalance']:.3f} "
                      f"Spread:{latest_cob['spread_bps']:.1f}bps Mid:{latest_cob['mid_price']:.2f}")

            # Model States
            if sample_data.get('model_states'):
                models = sample_data['model_states'][-1].get('models', {})
                if 'dqn' in models:
                    state_vec = models['dqn'].get('state_vector', [])
                    # The empty-vector fallback must be outside the f-string;
                    # otherwise state_vec[0] is evaluated unconditionally.
                    price_str = f"Price:{state_vec[0] * 10000:.2f}" if state_vec else "No state"
                    print(f"DQN State: {len(state_vec)} features | {price_str}")

            # Predictions
            if sample_data.get('predictions'):
                latest_preds = sample_data['predictions'][-1]
                for model_name, preds in latest_preds.get('predictions', {}).items():
                    if not preds:
                        continue
                    latest_pred = preds[-1]
                    action = latest_pred.get('action')
                    if action is None:
                        action = 'N/A'
                    # Collected predictions may carry confidence=None.
                    conf = latest_pred.get('confidence') or 0
                    print(f"{str(model_name).upper()} Prediction: {action} (conf:{conf:.2f})")

            # Training Experiences
            if sample_data.get('training_experiences'):
                latest_exp = sample_data['training_experiences'][-1]
                reward = latest_exp.get('reward') or 0
                action = latest_exp.get('action')
                if action is None:
                    action = 'N/A'
                done = latest_exp.get('done', False)
                print(f"Training Exp: Action:{action} Reward:{reward:.4f} Done:{done}")

            print(separator)

        except Exception as e:
            logger.error(f"Error in detailed output: {e}")

    def get_stream_snapshot(self) -> Dict[str, List]:
        """Return a list copy of every data stream, keyed by stream name."""
        return {stream_name: list(stream_data) for stream_name, stream_data in self.data_streams.items()}

    def save_snapshot(self, filepath: str):
        """Write all streams plus a 'metadata' entry to *filepath* as JSON.

        The JSON goes to ``<filepath>.tmp`` first and is moved into place with
        ``os.replace``, so a failure mid-write never leaves a truncated snapshot
        at *filepath*. Errors are logged, not raised.
        """
        tmp_path = f"{filepath}.tmp"
        try:
            snapshot = self.get_stream_snapshot()
            snapshot['metadata'] = {
                'timestamp': datetime.now().isoformat(),
                'config': self.stream_config
            }

            # default=str stringifies any value json cannot serialize natively.
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(snapshot, f, indent=2, default=str)
            os.replace(tmp_path, filepath)

            logger.info(f"Data stream snapshot saved to {filepath}")

        except Exception as e:
            logger.error(f"Error saving snapshot: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    def load_snapshot(self, filepath: str):
        """Replace stream contents with those stored in a snapshot file.

        Only keys that name a known stream and hold a list are loaded, so
        'metadata' is ignored. Each stream keeps its maxlen, so only the newest
        entries survive when the file holds more. Errors are logged, not raised.
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                snapshot = json.load(f)

            for stream_name, data in snapshot.items():
                stream = self.data_streams.get(stream_name)
                if stream is None or not isinstance(data, list):
                    continue
                stream.clear()
                stream.extend(data)

            logger.info(f"Data stream snapshot loaded from {filepath}")

        except Exception as e:
            logger.error(f"Error loading snapshot: {e}")
|
|
|
|
|
|
# Process-wide singleton, created lazily by get_data_stream_monitor().
_data_stream_monitor = None


def get_data_stream_monitor(orchestrator=None, data_provider=None, training_system=None) -> DataStreamMonitor:
    """Return the shared DataStreamMonitor, creating it on first call.

    On later calls, each collaborator passed as a non-None value replaces the
    one held by the existing instance. Arguments left as None keep the current
    value. An info message is logged whenever at least one value is replaced.
    """
    global _data_stream_monitor

    if _data_stream_monitor is None:
        _data_stream_monitor = DataStreamMonitor(orchestrator, data_provider, training_system)
        return _data_stream_monitor

    candidates = (
        ('orchestrator', orchestrator),
        ('data_provider', data_provider),
        ('training_system', training_system),
    )
    replacements = [(attr, value) for attr, value in candidates if value is not None]
    if replacements:
        for attr, value in replacements:
            setattr(_data_stream_monitor, attr, value)
        logger.info("Updated existing DataStreamMonitor with new connections")

    return _data_stream_monitor
|
|
|