fix models loading/saving issue
data_stream_monitor.py (new file, 484 lines added)
@@ -0,0 +1,484 @@
#!/usr/bin/env python3
"""
Data Stream Monitor for Model Input Capture and Replay

Captures and streams all model input data in a console-friendly text format.
Suitable for snapshots, training, and replay functionality.
"""

import logging
import json
import time
from datetime import datetime
from typing import Dict, List, Any, Optional
from collections import deque
import threading
import os

logger = logging.getLogger(__name__)


class DataStreamMonitor:
    """Monitors and streams all model input data for training and replay"""

    def __init__(self, orchestrator=None, data_provider=None, training_system=None):
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.training_system = training_system

        # Data buffers for streaming
        self.data_streams = {
            'ohlcv_1m': deque(maxlen=100),
            'ohlcv_5m': deque(maxlen=50),
            'ohlcv_15m': deque(maxlen=20),
            'ticks': deque(maxlen=200),
            'cob_raw': deque(maxlen=100),
            'cob_aggregated': deque(maxlen=50),
            'technical_indicators': deque(maxlen=100),
            'model_states': deque(maxlen=50),
            'predictions': deque(maxlen=100),
            'training_experiences': deque(maxlen=200)
        }
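        # These are bounded ring buffers: once a stream reaches its maxlen,
        # appending silently evicts the oldest entry, so memory use stays
        # fixed no matter how long the monitor runs.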

        # Streaming configuration
        self.stream_config = {
            'console_output': True,
            'compact_format': False,
            'include_timestamps': True,
            'filter_symbols': ['ETH/USDT'],  # Focus on primary symbols
            'sampling_rate': 1.0  # seconds between samples
        }

        self.is_streaming = False
        self.stream_thread = None
        self.last_sample_time = 0

        logger.info("DataStreamMonitor initialized")

    def start_streaming(self):
        """Start the data streaming thread"""
        if self.is_streaming:
            logger.warning("Data streaming already active")
            return

        self.is_streaming = True
        self.stream_thread = threading.Thread(target=self._streaming_worker, daemon=True)
        self.stream_thread.start()
        logger.info("Data streaming started")

    def stop_streaming(self):
        """Stop the data streaming"""
        self.is_streaming = False
        if self.stream_thread:
            self.stream_thread.join(timeout=2)
        logger.info("Data streaming stopped")

    def _streaming_worker(self):
        """Main streaming worker that collects and outputs data"""
        while self.is_streaming:
            try:
                current_time = time.time()
                if current_time - self.last_sample_time >= self.stream_config['sampling_rate']:
                    self._collect_data_sample()
                    self._output_data_sample()
                    self.last_sample_time = current_time

                time.sleep(0.5)  # Check every 500ms

            except Exception as e:
                logger.error(f"Error in streaming worker: {e}")
                time.sleep(2)
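
    # The worker wakes every 0.5 s but only emits a sample once
    # stream_config['sampling_rate'] seconds have elapsed, so the sampling
    # cadence can be retuned at runtime without restarting the thread.
    # Because the thread is a daemon it cannot block interpreter shutdown;
    # stop_streaming() remains the clean exit (join with a 2 s timeout).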

    def _collect_data_sample(self):
        """Collect one sample of all data streams"""
        try:
            timestamp = datetime.now()

            # 1. OHLCV Data Collection
            self._collect_ohlcv_data(timestamp)

            # 2. Tick Data Collection
            self._collect_tick_data(timestamp)

            # 3. COB Data Collection
            self._collect_cob_data(timestamp)

            # 4. Technical Indicators
            self._collect_technical_indicators(timestamp)

            # 5. Model States
            self._collect_model_states(timestamp)

            # 6. Predictions
            self._collect_predictions(timestamp)

            # 7. Training Experiences
            self._collect_training_experiences(timestamp)

        except Exception as e:
            logger.error(f"Error collecting data sample: {e}")

    def _collect_ohlcv_data(self, timestamp: datetime):
        """Collect OHLCV data for all timeframes"""
        try:
            for symbol in self.stream_config['filter_symbols']:
                for timeframe in ['1m', '5m', '15m']:
                    if self.data_provider:
                        df = self.data_provider.get_historical_data(symbol, timeframe, limit=5)
                        if df is not None and not df.empty:
                            latest_bar = {
                                'timestamp': timestamp.isoformat(),
                                # Dedup key: the bar's own time. The sample
                                # timestamp above changes on every call, so it
                                # cannot detect repeated bars. Assumes the
                                # provider returns a time-indexed DataFrame.
                                'bar_time': str(df.index[-1]),
                                'symbol': symbol,
                                'timeframe': timeframe,
                                'open': float(df['open'].iloc[-1]),
                                'high': float(df['high'].iloc[-1]),
                                'low': float(df['low'].iloc[-1]),
                                'close': float(df['close'].iloc[-1]),
                                'volume': float(df['volume'].iloc[-1])
                            }

                            stream_key = f'ohlcv_{timeframe}'
                            if len(self.data_streams[stream_key]) == 0 or \
                               self.data_streams[stream_key][-1]['bar_time'] != latest_bar['bar_time']:
                                self.data_streams[stream_key].append(latest_bar)

        except Exception as e:
            logger.debug(f"Error collecting OHLCV data: {e}")

    def _collect_tick_data(self, timestamp: datetime):
        """Collect real-time tick data"""
        try:
            if self.data_provider and hasattr(self.data_provider, 'get_recent_ticks'):
                recent_ticks = self.data_provider.get_recent_ticks(limit=10)
                for tick in recent_ticks:
                    tick_data = {
                        'timestamp': timestamp.isoformat(),
                        'symbol': tick.get('symbol', 'ETH/USDT'),
                        'price': float(tick.get('price', 0)),
                        'volume': float(tick.get('volume', 0)),
                        'side': tick.get('side', 'unknown'),
                        'trade_id': tick.get('trade_id', ''),
                        'is_buyer_maker': tick.get('is_buyer_maker', False)
                    }

                    # Only add if different from last tick
                    if len(self.data_streams['ticks']) == 0 or \
                       self.data_streams['ticks'][-1]['trade_id'] != tick_data['trade_id']:
                        self.data_streams['ticks'].append(tick_data)

        except Exception as e:
            logger.debug(f"Error collecting tick data: {e}")

    def _collect_cob_data(self, timestamp: datetime):
        """Collect COB (Consolidated Order Book) data"""
        try:
            # Raw COB snapshots
            if self.orchestrator and hasattr(self.orchestrator, 'latest_cob_data'):
                for symbol in self.stream_config['filter_symbols']:
                    if symbol in self.orchestrator.latest_cob_data:
                        cob_data = self.orchestrator.latest_cob_data[symbol]

                        raw_cob = {
                            'timestamp': timestamp.isoformat(),
                            'symbol': symbol,
                            'stats': cob_data.get('stats', {}),
                            'bids_count': len(cob_data.get('bids', [])),
                            'asks_count': len(cob_data.get('asks', [])),
                            'imbalance': cob_data.get('stats', {}).get('imbalance', 0),
                            'spread_bps': cob_data.get('stats', {}).get('spread_bps', 0),
                            'mid_price': cob_data.get('stats', {}).get('mid_price', 0)
                        }

                        self.data_streams['cob_raw'].append(raw_cob)

                        # Top 5 bids and asks for aggregation
                        if cob_data.get('bids') and cob_data.get('asks'):
                            aggregated_cob = {
                                'timestamp': timestamp.isoformat(),
                                'symbol': symbol,
                                'bids': cob_data['bids'][:5],  # Top 5 bids
                                'asks': cob_data['asks'][:5],  # Top 5 asks
                                'imbalance': raw_cob['imbalance'],
                                'spread_bps': raw_cob['spread_bps']
                            }
                            self.data_streams['cob_aggregated'].append(aggregated_cob)

        except Exception as e:
            logger.debug(f"Error collecting COB data: {e}")

    def _collect_technical_indicators(self, timestamp: datetime):
        """Collect technical indicators"""
        try:
            if self.data_provider and hasattr(self.data_provider, 'calculate_technical_indicators'):
                for symbol in self.stream_config['filter_symbols']:
                    indicators = self.data_provider.calculate_technical_indicators(symbol)

                    if indicators:
                        indicator_data = {
                            'timestamp': timestamp.isoformat(),
                            'symbol': symbol,
                            'indicators': indicators
                        }
                        self.data_streams['technical_indicators'].append(indicator_data)

        except Exception as e:
            logger.debug(f"Error collecting technical indicators: {e}")

    def _collect_model_states(self, timestamp: datetime):
        """Collect current model states for each model"""
        try:
            if not self.orchestrator:
                return

            model_states = {}

            # DQN State
            if hasattr(self.orchestrator, 'build_comprehensive_rl_state'):
                for symbol in self.stream_config['filter_symbols']:
                    rl_state = self.orchestrator.build_comprehensive_rl_state(symbol)
                    if rl_state:
                        model_states['dqn'] = {
                            'symbol': symbol,
                            'state_vector': rl_state.get('state_vector', []),
                            'features': rl_state.get('features', {}),
                            'metadata': rl_state.get('metadata', {})
                        }

            # CNN State
            if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                for symbol in self.stream_config['filter_symbols']:
                    if hasattr(self.orchestrator.cnn_model, 'get_state_features'):
                        cnn_features = self.orchestrator.cnn_model.get_state_features(symbol)
                        if cnn_features:
                            model_states['cnn'] = {
                                'symbol': symbol,
                                'features': cnn_features
                            }

            # RL Agent State
            if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
                rl_state_data = {
                    'epsilon': getattr(self.orchestrator.cob_rl_agent, 'epsilon', 0),
                    'total_steps': getattr(self.orchestrator.cob_rl_agent, 'total_steps', 0),
                    'current_reward': getattr(self.orchestrator.cob_rl_agent, 'current_reward', 0)
                }
                model_states['rl_agent'] = rl_state_data

            if model_states:
                state_sample = {
                    'timestamp': timestamp.isoformat(),
                    'models': model_states
                }
                self.data_streams['model_states'].append(state_sample)

        except Exception as e:
            logger.debug(f"Error collecting model states: {e}")

    def _collect_predictions(self, timestamp: datetime):
        """Collect recent predictions from all models"""
        try:
            if not self.orchestrator:
                return

            predictions = {}

            # Get predictions from orchestrator
            if hasattr(self.orchestrator, 'get_recent_predictions'):
                recent_preds = self.orchestrator.get_recent_predictions(limit=5)
                for pred in recent_preds:
                    model_name = pred.get('model_name', 'unknown')
                    if model_name not in predictions:
                        predictions[model_name] = []
                    predictions[model_name].append({
                        'timestamp': pred.get('timestamp', timestamp.isoformat()),
                        'symbol': pred.get('symbol', 'ETH/USDT'),
                        'prediction': pred.get('prediction'),
                        'confidence': pred.get('confidence', 0),
                        'action': pred.get('action')
                    })

            if predictions:
                prediction_sample = {
                    'timestamp': timestamp.isoformat(),
                    'predictions': predictions
                }
                self.data_streams['predictions'].append(prediction_sample)

        except Exception as e:
            logger.debug(f"Error collecting predictions: {e}")

    def _collect_training_experiences(self, timestamp: datetime):
        """Collect training experiences from the training system"""
        try:
            if self.training_system and hasattr(self.training_system, 'experience_buffer'):
                # Get recent experiences
                recent_experiences = list(self.training_system.experience_buffer)[-10:]  # Last 10

                for exp in recent_experiences:
                    experience_data = {
                        'timestamp': timestamp.isoformat(),
                        'state': exp.get('state', []),
                        'action': exp.get('action'),
                        'reward': exp.get('reward', 0),
                        'next_state': exp.get('next_state', []),
                        'done': exp.get('done', False),
                        'info': exp.get('info', {})
                    }
                    self.data_streams['training_experiences'].append(experience_data)

        except Exception as e:
            logger.debug(f"Error collecting training experiences: {e}")

    def _output_data_sample(self):
        """Output the current data sample to console"""
        if not self.stream_config['console_output']:
            return

        try:
            # Get latest data from each stream
            sample_data = {}
            for stream_name, stream_data in self.data_streams.items():
                if stream_data:
                    sample_data[stream_name] = list(stream_data)[-5:]  # Last 5 entries

            if sample_data:
                if self.stream_config['compact_format']:
                    self._output_compact_format(sample_data)
                else:
                    self._output_detailed_format(sample_data)

        except Exception as e:
            logger.error(f"Error outputting data sample: {e}")

    def _output_compact_format(self, sample_data: Dict):
        """Output data in compact JSON format"""
        try:
            # Create compact summary
            summary = {
                'timestamp': datetime.now().isoformat(),
                'ohlcv_count': len(sample_data.get('ohlcv_1m', [])),
                'ticks_count': len(sample_data.get('ticks', [])),
                'cob_count': len(sample_data.get('cob_raw', [])),
                'predictions_count': len(sample_data.get('predictions', [])),
                'experiences_count': len(sample_data.get('training_experiences', []))
            }

            # Add latest OHLCV if available
            if sample_data.get('ohlcv_1m'):
                latest_ohlcv = sample_data['ohlcv_1m'][-1]
                summary['price'] = latest_ohlcv['close']
                summary['volume'] = latest_ohlcv['volume']

            # Add latest COB if available
            if sample_data.get('cob_raw'):
                latest_cob = sample_data['cob_raw'][-1]
                summary['imbalance'] = latest_cob['imbalance']
                summary['spread_bps'] = latest_cob['spread_bps']

            print(f"DATA_STREAM: {json.dumps(summary, separators=(',', ':'))}")

        except Exception as e:
            logger.error(f"Error in compact output: {e}")

    def _output_detailed_format(self, sample_data: Dict):
        """Output data in detailed human-readable format"""
        try:
            print(f"\n{'='*80}")
            print(f"DATA STREAM SAMPLE - {datetime.now().strftime('%H:%M:%S')}")
            print(f"{'='*80}")

            # OHLCV Data
            if sample_data.get('ohlcv_1m'):
                latest = sample_data['ohlcv_1m'][-1]
                print(f"OHLCV (1m): {latest['symbol']} | O:{latest['open']:.2f} H:{latest['high']:.2f} L:{latest['low']:.2f} C:{latest['close']:.2f} V:{latest['volume']:.1f}")

            # Tick Data
            if sample_data.get('ticks'):
                latest_tick = sample_data['ticks'][-1]
                print(f"TICK: {latest_tick['symbol']} | Price:{latest_tick['price']:.2f} Vol:{latest_tick['volume']:.4f} Side:{latest_tick['side']}")

            # COB Data
            if sample_data.get('cob_raw'):
                latest_cob = sample_data['cob_raw'][-1]
                print(f"COB: {latest_cob['symbol']} | Imbalance:{latest_cob['imbalance']:.3f} Spread:{latest_cob['spread_bps']:.1f}bps Mid:{latest_cob['mid_price']:.2f}")

            # Model States
            if sample_data.get('model_states'):
                latest_state = sample_data['model_states'][-1]
                models = latest_state.get('models', {})
                if 'dqn' in models:
                    dqn_state = models['dqn']
                    state_vec = dqn_state.get('state_vector', [])
                    # Print a summary only when a state vector is present
                    if state_vec:
                        print(f"DQN State: {len(state_vec)} features | Price:{state_vec[0]*10000:.2f}")
                    else:
                        print("DQN State: No state")

            # Predictions
            if sample_data.get('predictions'):
                latest_preds = sample_data['predictions'][-1]
                for model_name, preds in latest_preds.get('predictions', {}).items():
                    if preds:
                        latest_pred = preds[-1]
                        action = latest_pred.get('action', 'N/A')
                        conf = latest_pred.get('confidence', 0)
                        print(f"{model_name.upper()} Prediction: {action} (conf:{conf:.2f})")

            # Training Experiences
            if sample_data.get('training_experiences'):
                latest_exp = sample_data['training_experiences'][-1]
                reward = latest_exp.get('reward', 0)
                action = latest_exp.get('action', 'N/A')
                done = latest_exp.get('done', False)
                print(f"Training Exp: Action:{action} Reward:{reward:.4f} Done:{done}")

            print(f"{'='*80}")

        except Exception as e:
            logger.error(f"Error in detailed output: {e}")

    def get_stream_snapshot(self) -> Dict[str, List]:
        """Get a complete snapshot of all data streams"""
        return {stream_name: list(stream_data) for stream_name, stream_data in self.data_streams.items()}

    def save_snapshot(self, filepath: str):
        """Save current data streams to file"""
        try:
            snapshot = self.get_stream_snapshot()
            snapshot['metadata'] = {
                'timestamp': datetime.now().isoformat(),
                'config': self.stream_config
            }

            with open(filepath, 'w') as f:
                json.dump(snapshot, f, indent=2, default=str)

            logger.info(f"Data stream snapshot saved to {filepath}")

        except Exception as e:
            logger.error(f"Error saving snapshot: {e}")

    def load_snapshot(self, filepath: str):
        """Load data streams from file"""
        try:
            with open(filepath, 'r') as f:
                snapshot = json.load(f)

            for stream_name, data in snapshot.items():
                if stream_name in self.data_streams and stream_name != 'metadata':
                    self.data_streams[stream_name].clear()
                    self.data_streams[stream_name].extend(data)

            logger.info(f"Data stream snapshot loaded from {filepath}")

        except Exception as e:
            logger.error(f"Error loading snapshot: {e}")


# Global instance for easy access
_data_stream_monitor = None


def get_data_stream_monitor(orchestrator=None, data_provider=None, training_system=None) -> DataStreamMonitor:
    """Get or create the global data stream monitor instance"""
    global _data_stream_monitor
    if _data_stream_monitor is None:
        _data_stream_monitor = DataStreamMonitor(orchestrator, data_provider, training_system)
    return _data_stream_monitor
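
# Minimal usage sketch (assumes the host app supplies an orchestrator and a
# data provider exposing the interfaces probed above; names are illustrative):
#
#     monitor = get_data_stream_monitor(orchestrator=orch, data_provider=dp)
#     monitor.start_streaming()               # background daemon thread
#     time.sleep(10)                          # let a few samples accumulate
#     monitor.save_snapshot('stream_snapshot.json')
#     monitor.stop_streaming()
#
# A later session can call load_snapshot('stream_snapshot.json') to replay
# the captured streams through the same monitor instance.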