wip
This commit is contained in:
232
core/data_models.py
Normal file
232
core/data_models.py
Normal file
@ -0,0 +1,232 @@
|
||||
"""
|
||||
Standardized Data Models for Multi-Modal Trading System
|
||||
|
||||
This module defines the standardized data structures used across all models:
|
||||
- BaseDataInput: Unified input format for all models (CNN, RL, LSTM, Transformer)
|
||||
- ModelOutput: Extensible output format supporting all model types
|
||||
- COBData: Cumulative Order Book data structure
|
||||
- Enhanced data structures for cross-model feeding and extensibility
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@dataclass
|
||||
class OHLCVBar:
|
||||
"""OHLCV bar data structure"""
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
open: float
|
||||
high: float
|
||||
low: float
|
||||
close: float
|
||||
volume: float
|
||||
timeframe: str
|
||||
indicators: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
@dataclass
|
||||
class PivotPoint:
|
||||
"""Pivot point data structure"""
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
price: float
|
||||
type: str # 'high' or 'low'
|
||||
level: int # Pivot level (1, 2, 3, etc.)
|
||||
confidence: float = 1.0
|
||||
|
||||
@dataclass
|
||||
class ModelOutput:
|
||||
"""Extensible model output format supporting all model types"""
|
||||
model_type: str # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
|
||||
model_name: str # Specific model identifier
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
confidence: float
|
||||
predictions: Dict[str, Any] # Model-specific predictions
|
||||
hidden_states: Optional[Dict[str, Any]] = None # For cross-model feeding
|
||||
metadata: Dict[str, Any] = field(default_factory=dict) # Additional info
|
||||
|
||||
@dataclass
|
||||
class COBData:
|
||||
"""Cumulative Order Book data for price buckets"""
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
current_price: float
|
||||
bucket_size: float # $1 for ETH, $10 for BTC
|
||||
price_buckets: Dict[float, Dict[str, float]] # price -> {bid_volume, ask_volume, etc.}
|
||||
bid_ask_imbalance: Dict[float, float] # price -> imbalance ratio
|
||||
volume_weighted_prices: Dict[float, float] # price -> VWAP within bucket
|
||||
order_flow_metrics: Dict[str, float] # Various order flow indicators
|
||||
|
||||
# Moving averages of COB imbalance for ±5 buckets
|
||||
ma_1s_imbalance: Dict[float, float] = field(default_factory=dict) # 1s MA
|
||||
ma_5s_imbalance: Dict[float, float] = field(default_factory=dict) # 5s MA
|
||||
ma_15s_imbalance: Dict[float, float] = field(default_factory=dict) # 15s MA
|
||||
ma_60s_imbalance: Dict[float, float] = field(default_factory=dict) # 60s MA
|
||||
|
||||
@dataclass
|
||||
class BaseDataInput:
|
||||
"""
|
||||
Unified base data input for all models
|
||||
|
||||
Standardized format ensures all models receive identical input structure:
|
||||
- OHLCV: 300 frames of (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
|
||||
- COB: ±20 buckets of COB amounts in USD for each 1s OHLCV
|
||||
- MA: 1s, 5s, 15s, and 60s MA of COB imbalance counting ±5 COB buckets
|
||||
"""
|
||||
symbol: str # Primary symbol (ETH/USDT)
|
||||
timestamp: datetime
|
||||
|
||||
# Multi-timeframe OHLCV data for primary symbol (ETH)
|
||||
ohlcv_1s: List[OHLCVBar] = field(default_factory=list) # 300 frames of 1s data
|
||||
ohlcv_1m: List[OHLCVBar] = field(default_factory=list) # 300 frames of 1m data
|
||||
ohlcv_1h: List[OHLCVBar] = field(default_factory=list) # 300 frames of 1h data
|
||||
ohlcv_1d: List[OHLCVBar] = field(default_factory=list) # 300 frames of 1d data
|
||||
|
||||
# Reference symbol (BTC) 1s data
|
||||
btc_ohlcv_1s: List[OHLCVBar] = field(default_factory=list) # 300s of 1s BTC data
|
||||
|
||||
# COB data for 1s timeframe (±20 buckets around current price)
|
||||
cob_data: Optional[COBData] = None
|
||||
|
||||
# Technical indicators
|
||||
technical_indicators: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
# Pivot points from Williams Market Structure
|
||||
pivot_points: List[PivotPoint] = field(default_factory=list)
|
||||
|
||||
# Last predictions from all models (for cross-model feeding)
|
||||
last_predictions: Dict[str, ModelOutput] = field(default_factory=dict)
|
||||
|
||||
# Market microstructure data
|
||||
market_microstructure: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def get_feature_vector(self) -> np.ndarray:
|
||||
"""
|
||||
Convert BaseDataInput to standardized feature vector for models
|
||||
|
||||
Returns:
|
||||
np.ndarray: Standardized feature vector combining all data sources
|
||||
"""
|
||||
features = []
|
||||
|
||||
# OHLCV features for ETH (300 frames x 4 timeframes x 5 features = 6000 features)
|
||||
for ohlcv_list in [self.ohlcv_1s, self.ohlcv_1m, self.ohlcv_1h, self.ohlcv_1d]:
|
||||
for bar in ohlcv_list[-300:]: # Ensure exactly 300 frames
|
||||
features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
|
||||
|
||||
# BTC OHLCV features (300 frames x 5 features = 1500 features)
|
||||
for bar in self.btc_ohlcv_1s[-300:]: # Ensure exactly 300 frames
|
||||
features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
|
||||
|
||||
# COB features (±20 buckets x multiple metrics ≈ 800 features)
|
||||
if self.cob_data:
|
||||
# Price bucket features
|
||||
for price in sorted(self.cob_data.price_buckets.keys()):
|
||||
bucket_data = self.cob_data.price_buckets[price]
|
||||
features.extend([
|
||||
bucket_data.get('bid_volume', 0.0),
|
||||
bucket_data.get('ask_volume', 0.0),
|
||||
bucket_data.get('total_volume', 0.0),
|
||||
bucket_data.get('imbalance', 0.0)
|
||||
])
|
||||
|
||||
# Moving averages of imbalance for ±5 buckets (5 buckets x 4 MAs x 2 sides = 40 features)
|
||||
for ma_dict in [self.cob_data.ma_1s_imbalance, self.cob_data.ma_5s_imbalance,
|
||||
self.cob_data.ma_15s_imbalance, self.cob_data.ma_60s_imbalance]:
|
||||
for price in sorted(list(ma_dict.keys())[:5]): # ±5 buckets
|
||||
features.append(ma_dict[price])
|
||||
|
||||
# Technical indicators (variable, pad to 100 features)
|
||||
indicator_values = list(self.technical_indicators.values())
|
||||
features.extend(indicator_values[:100]) # Take first 100 indicators
|
||||
features.extend([0.0] * max(0, 100 - len(indicator_values))) # Pad if needed
|
||||
|
||||
# Last predictions from other models (variable, pad to 50 features)
|
||||
prediction_features = []
|
||||
for model_output in self.last_predictions.values():
|
||||
prediction_features.extend([
|
||||
model_output.confidence,
|
||||
model_output.predictions.get('buy_probability', 0.0),
|
||||
model_output.predictions.get('sell_probability', 0.0),
|
||||
model_output.predictions.get('hold_probability', 0.0),
|
||||
model_output.predictions.get('expected_reward', 0.0)
|
||||
])
|
||||
features.extend(prediction_features[:50]) # Take first 50 prediction features
|
||||
features.extend([0.0] * max(0, 50 - len(prediction_features))) # Pad if needed
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""
|
||||
Validate that the BaseDataInput contains required data
|
||||
|
||||
Returns:
|
||||
bool: True if valid, False otherwise
|
||||
"""
|
||||
# Check that we have required OHLCV data
|
||||
if len(self.ohlcv_1s) < 100: # At least 100 frames
|
||||
return False
|
||||
if len(self.btc_ohlcv_1s) < 100: # At least 100 frames of BTC data
|
||||
return False
|
||||
|
||||
# Check that timestamps are reasonable
|
||||
if not self.timestamp:
|
||||
return False
|
||||
|
||||
# Check symbol format
|
||||
if not self.symbol or '/' not in self.symbol:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@dataclass
|
||||
class TradingAction:
|
||||
"""Trading action output from models"""
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
action: str # 'BUY', 'SELL', 'HOLD'
|
||||
confidence: float
|
||||
source: str # 'rl', 'cnn', 'orchestrator'
|
||||
price: Optional[float] = None
|
||||
quantity: Optional[float] = None
|
||||
reason: Optional[str] = None
|
||||
|
||||
def create_model_output(model_type: str, model_name: str, symbol: str,
|
||||
action: str, confidence: float,
|
||||
hidden_states: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None) -> ModelOutput:
|
||||
"""
|
||||
Helper function to create standardized ModelOutput
|
||||
|
||||
Args:
|
||||
model_type: Type of model ('cnn', 'rl', 'lstm', 'transformer', 'orchestrator')
|
||||
model_name: Specific model identifier
|
||||
symbol: Trading symbol
|
||||
action: Trading action ('BUY', 'SELL', 'HOLD')
|
||||
confidence: Confidence score (0.0 to 1.0)
|
||||
hidden_states: Optional hidden states for cross-model feeding
|
||||
metadata: Optional additional metadata
|
||||
|
||||
Returns:
|
||||
ModelOutput: Standardized model output
|
||||
"""
|
||||
predictions = {
|
||||
'action': action,
|
||||
'buy_probability': confidence if action == 'BUY' else 0.0,
|
||||
'sell_probability': confidence if action == 'SELL' else 0.0,
|
||||
'hold_probability': confidence if action == 'HOLD' else 0.0,
|
||||
}
|
||||
|
||||
return ModelOutput(
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
symbol=symbol,
|
||||
timestamp=datetime.now(),
|
||||
confidence=confidence,
|
||||
predictions=predictions,
|
||||
hidden_states=hidden_states or {},
|
||||
metadata=metadata or {}
|
||||
)
|
395
core/model_output_manager.py
Normal file
395
core/model_output_manager.py
Normal file
@ -0,0 +1,395 @@
|
||||
"""
|
||||
Model Output Manager
|
||||
|
||||
This module provides extensible model output storage and management for the multi-modal trading system.
|
||||
Supports CNN, RL, LSTM, Transformer, and future model types with cross-model feeding capabilities.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
import pickle
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
from collections import deque, defaultdict
|
||||
from threading import Lock
|
||||
from pathlib import Path
|
||||
|
||||
from .data_models import ModelOutput, create_model_output
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelOutputManager:
|
||||
"""
|
||||
Extensible model output storage and management system
|
||||
|
||||
Features:
|
||||
- Standardized ModelOutput storage for all model types
|
||||
- Cross-model feeding with hidden states
|
||||
- Historical output tracking
|
||||
- Metadata management
|
||||
- Persistence and recovery
|
||||
- Performance analytics
|
||||
"""
|
||||
|
||||
def __init__(self, cache_dir: str = "cache/model_outputs", max_history: int = 1000):
|
||||
"""
|
||||
Initialize the model output manager
|
||||
|
||||
Args:
|
||||
cache_dir: Directory for persistent storage
|
||||
max_history: Maximum number of outputs to keep in memory per model
|
||||
"""
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.max_history = max_history
|
||||
|
||||
# In-memory storage
|
||||
self.current_outputs: Dict[str, Dict[str, ModelOutput]] = defaultdict(dict) # {symbol: {model_name: ModelOutput}}
|
||||
self.output_history: Dict[str, Dict[str, deque]] = defaultdict(lambda: defaultdict(lambda: deque(maxlen=max_history))) # {symbol: {model_name: deque}}
|
||||
self.cross_model_states: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) # {symbol: {model_name: hidden_states}}
|
||||
|
||||
# Metadata tracking
|
||||
self.model_metadata: Dict[str, Dict[str, Any]] = defaultdict(dict) # {model_name: metadata}
|
||||
self.performance_stats: Dict[str, Dict[str, Any]] = defaultdict(lambda: defaultdict(dict)) # {symbol: {model_name: stats}}
|
||||
|
||||
# Thread safety
|
||||
self.storage_lock = Lock()
|
||||
|
||||
# Supported model types
|
||||
self.supported_model_types = {
|
||||
'cnn', 'rl', 'lstm', 'transformer', 'orchestrator',
|
||||
'ensemble', 'hybrid', 'custom' # Extensible for future types
|
||||
}
|
||||
|
||||
logger.info(f"ModelOutputManager initialized with cache dir: {self.cache_dir}")
|
||||
logger.info(f"Supported model types: {self.supported_model_types}")
|
||||
|
||||
def store_output(self, model_output: ModelOutput) -> bool:
|
||||
"""
|
||||
Store model output with full extensibility support
|
||||
|
||||
Args:
|
||||
model_output: ModelOutput from any model type
|
||||
|
||||
Returns:
|
||||
bool: True if stored successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
with self.storage_lock:
|
||||
symbol = model_output.symbol
|
||||
model_name = model_output.model_name
|
||||
model_type = model_output.model_type
|
||||
|
||||
# Validate model type (extensible)
|
||||
if model_type not in self.supported_model_types:
|
||||
logger.warning(f"Unknown model type '{model_type}' - adding to supported types")
|
||||
self.supported_model_types.add(model_type)
|
||||
|
||||
# Store current output
|
||||
self.current_outputs[symbol][model_name] = model_output
|
||||
|
||||
# Add to history
|
||||
self.output_history[symbol][model_name].append(model_output)
|
||||
|
||||
# Store cross-model states if available
|
||||
if model_output.hidden_states:
|
||||
self.cross_model_states[symbol][model_name] = model_output.hidden_states
|
||||
|
||||
# Update model metadata
|
||||
self._update_model_metadata(model_name, model_type, model_output.metadata)
|
||||
|
||||
# Update performance statistics
|
||||
self._update_performance_stats(symbol, model_name, model_output)
|
||||
|
||||
# Persist to disk (async to avoid blocking)
|
||||
self._persist_output_async(model_output)
|
||||
|
||||
logger.debug(f"Stored output from {model_name} ({model_type}) for {symbol}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing model output: {e}")
|
||||
return False
|
||||
|
||||
def get_current_output(self, symbol: str, model_name: str) -> Optional[ModelOutput]:
|
||||
"""
|
||||
Get the current (latest) output from a specific model
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
model_name: Name of the model
|
||||
|
||||
Returns:
|
||||
ModelOutput: Latest output from the model, or None if not available
|
||||
"""
|
||||
try:
|
||||
return self.current_outputs.get(symbol, {}).get(model_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current output for {model_name}: {e}")
|
||||
return None
|
||||
|
||||
def get_all_current_outputs(self, symbol: str) -> Dict[str, ModelOutput]:
|
||||
"""
|
||||
Get all current outputs for a symbol (for cross-model feeding)
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
Dict[str, ModelOutput]: Dictionary of current outputs by model name
|
||||
"""
|
||||
try:
|
||||
return dict(self.current_outputs.get(symbol, {}))
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all current outputs for {symbol}: {e}")
|
||||
return {}
|
||||
|
||||
def get_output_history(self, symbol: str, model_name: str, count: int = 10) -> List[ModelOutput]:
|
||||
"""
|
||||
Get historical outputs from a model
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
model_name: Name of the model
|
||||
count: Number of historical outputs to retrieve
|
||||
|
||||
Returns:
|
||||
List[ModelOutput]: List of historical outputs (most recent first)
|
||||
"""
|
||||
try:
|
||||
history = self.output_history.get(symbol, {}).get(model_name, deque())
|
||||
return list(history)[-count:][::-1] # Most recent first
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting output history for {model_name}: {e}")
|
||||
return []
|
||||
|
||||
def get_cross_model_states(self, symbol: str, requesting_model: str) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Get hidden states from other models for cross-model feeding
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
requesting_model: Name of the model requesting the states
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict[str, Any]]: Hidden states from other models
|
||||
"""
|
||||
try:
|
||||
all_states = self.cross_model_states.get(symbol, {})
|
||||
# Return states from all models except the requesting one
|
||||
return {model_name: states for model_name, states in all_states.items()
|
||||
if model_name != requesting_model}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting cross-model states for {requesting_model}: {e}")
|
||||
return {}
|
||||
|
||||
def get_model_types_active(self, symbol: str) -> List[str]:
|
||||
"""
|
||||
Get list of active model types for a symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
List[str]: List of active model types
|
||||
"""
|
||||
try:
|
||||
current_outputs = self.current_outputs.get(symbol, {})
|
||||
return [output.model_type for output in current_outputs.values()]
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting active model types for {symbol}: {e}")
|
||||
return []
|
||||
|
||||
def get_consensus_prediction(self, symbol: str, confidence_threshold: float = 0.5) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get consensus prediction from all active models
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
confidence_threshold: Minimum confidence threshold for inclusion
|
||||
|
||||
Returns:
|
||||
Dict containing consensus prediction or None
|
||||
"""
|
||||
try:
|
||||
current_outputs = self.current_outputs.get(symbol, {})
|
||||
if not current_outputs:
|
||||
return None
|
||||
|
||||
# Filter by confidence threshold
|
||||
high_confidence_outputs = [
|
||||
output for output in current_outputs.values()
|
||||
if output.confidence >= confidence_threshold
|
||||
]
|
||||
|
||||
if not high_confidence_outputs:
|
||||
return None
|
||||
|
||||
# Calculate consensus
|
||||
buy_votes = sum(1 for output in high_confidence_outputs
|
||||
if output.predictions.get('action') == 'BUY')
|
||||
sell_votes = sum(1 for output in high_confidence_outputs
|
||||
if output.predictions.get('action') == 'SELL')
|
||||
hold_votes = sum(1 for output in high_confidence_outputs
|
||||
if output.predictions.get('action') == 'HOLD')
|
||||
|
||||
total_votes = len(high_confidence_outputs)
|
||||
avg_confidence = sum(output.confidence for output in high_confidence_outputs) / total_votes
|
||||
|
||||
# Determine consensus action
|
||||
if buy_votes > sell_votes and buy_votes > hold_votes:
|
||||
consensus_action = 'BUY'
|
||||
elif sell_votes > buy_votes and sell_votes > hold_votes:
|
||||
consensus_action = 'SELL'
|
||||
else:
|
||||
consensus_action = 'HOLD'
|
||||
|
||||
return {
|
||||
'action': consensus_action,
|
||||
'confidence': avg_confidence,
|
||||
'votes': {'BUY': buy_votes, 'SELL': sell_votes, 'HOLD': hold_votes},
|
||||
'total_models': total_votes,
|
||||
'model_types': [output.model_type for output in high_confidence_outputs]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating consensus prediction for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _update_model_metadata(self, model_name: str, model_type: str, metadata: Dict[str, Any]):
|
||||
"""Update metadata for a model"""
|
||||
try:
|
||||
if model_name not in self.model_metadata:
|
||||
self.model_metadata[model_name] = {
|
||||
'model_type': model_type,
|
||||
'first_seen': datetime.now(),
|
||||
'total_predictions': 0,
|
||||
'custom_metadata': {}
|
||||
}
|
||||
|
||||
self.model_metadata[model_name]['total_predictions'] += 1
|
||||
self.model_metadata[model_name]['last_seen'] = datetime.now()
|
||||
|
||||
# Merge custom metadata
|
||||
if metadata:
|
||||
self.model_metadata[model_name]['custom_metadata'].update(metadata)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating model metadata: {e}")
|
||||
|
||||
def _update_performance_stats(self, symbol: str, model_name: str, model_output: ModelOutput):
|
||||
"""Update performance statistics for a model"""
|
||||
try:
|
||||
stats = self.performance_stats[symbol][model_name]
|
||||
|
||||
if 'prediction_count' not in stats:
|
||||
stats['prediction_count'] = 0
|
||||
stats['confidence_sum'] = 0.0
|
||||
stats['action_counts'] = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
|
||||
stats['first_prediction'] = model_output.timestamp
|
||||
|
||||
stats['prediction_count'] += 1
|
||||
stats['confidence_sum'] += model_output.confidence
|
||||
stats['avg_confidence'] = stats['confidence_sum'] / stats['prediction_count']
|
||||
stats['last_prediction'] = model_output.timestamp
|
||||
|
||||
action = model_output.predictions.get('action', 'HOLD')
|
||||
if action in stats['action_counts']:
|
||||
stats['action_counts'][action] += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating performance stats: {e}")
|
||||
|
||||
def _persist_output_async(self, model_output: ModelOutput):
|
||||
"""Persist model output to disk (simplified version)"""
|
||||
try:
|
||||
# Create filename based on model and timestamp
|
||||
timestamp_str = model_output.timestamp.strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{model_output.model_name}_{model_output.symbol.replace('/', '_')}_{timestamp_str}.json"
|
||||
filepath = self.cache_dir / filename
|
||||
|
||||
# Convert to JSON-serializable format
|
||||
output_dict = {
|
||||
'model_type': model_output.model_type,
|
||||
'model_name': model_output.model_name,
|
||||
'symbol': model_output.symbol,
|
||||
'timestamp': model_output.timestamp.isoformat(),
|
||||
'confidence': model_output.confidence,
|
||||
'predictions': model_output.predictions,
|
||||
'metadata': model_output.metadata
|
||||
}
|
||||
|
||||
# Save to file (in a real implementation, this would be async)
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(output_dict, f, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error persisting model output: {e}")
|
||||
|
||||
def get_performance_summary(self, symbol: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get performance summary for all models for a symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
Dict containing performance summary
|
||||
"""
|
||||
try:
|
||||
summary = {
|
||||
'symbol': symbol,
|
||||
'active_models': len(self.current_outputs.get(symbol, {})),
|
||||
'model_stats': {}
|
||||
}
|
||||
|
||||
for model_name, stats in self.performance_stats.get(symbol, {}).items():
|
||||
summary['model_stats'][model_name] = {
|
||||
'predictions': stats.get('prediction_count', 0),
|
||||
'avg_confidence': round(stats.get('avg_confidence', 0.0), 3),
|
||||
'action_distribution': stats.get('action_counts', {}),
|
||||
'model_type': self.model_metadata.get(model_name, {}).get('model_type', 'unknown')
|
||||
}
|
||||
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting performance summary: {e}")
|
||||
return {'symbol': symbol, 'error': str(e)}
|
||||
|
||||
def cleanup_old_outputs(self, max_age_hours: int = 24):
|
||||
"""
|
||||
Clean up old outputs to manage memory usage
|
||||
|
||||
Args:
|
||||
max_age_hours: Maximum age of outputs to keep in hours
|
||||
"""
|
||||
try:
|
||||
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
|
||||
|
||||
with self.storage_lock:
|
||||
for symbol in self.output_history:
|
||||
for model_name in self.output_history[symbol]:
|
||||
history = self.output_history[symbol][model_name]
|
||||
# Remove old outputs
|
||||
while history and history[0].timestamp < cutoff_time:
|
||||
history.popleft()
|
||||
|
||||
logger.info(f"Cleaned up outputs older than {max_age_hours} hours")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up old outputs: {e}")
|
||||
|
||||
def add_custom_model_type(self, model_type: str):
|
||||
"""
|
||||
Add support for a new custom model type
|
||||
|
||||
Args:
|
||||
model_type: Name of the new model type
|
||||
"""
|
||||
self.supported_model_types.add(model_type)
|
||||
logger.info(f"Added support for custom model type: {model_type}")
|
||||
|
||||
def get_supported_model_types(self) -> List[str]:
|
||||
"""Get list of all supported model types"""
|
||||
return list(self.supported_model_types)
|
453
core/standardized_data_provider.py
Normal file
453
core/standardized_data_provider.py
Normal file
@ -0,0 +1,453 @@
|
||||
"""
|
||||
Standardized Data Provider Extension
|
||||
|
||||
This module extends the existing DataProvider with standardized BaseDataInput functionality
|
||||
for all models in the multi-modal trading system.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from collections import deque
|
||||
from threading import Lock
|
||||
|
||||
from .data_provider import DataProvider
|
||||
from .data_models import BaseDataInput, OHLCVBar, COBData, ModelOutput, PivotPoint
|
||||
from .multi_exchange_cob_provider import MultiExchangeCOBProvider
|
||||
from .model_output_manager import ModelOutputManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class StandardizedDataProvider(DataProvider):
|
||||
"""
|
||||
Extended DataProvider with standardized BaseDataInput support
|
||||
|
||||
Provides unified data format for all models:
|
||||
- OHLCV: 300 frames of (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
|
||||
- COB: ±20 buckets of COB amounts in USD for each 1s OHLCV
|
||||
- MA: 1s, 5s, 15s, and 60s MA of COB imbalance counting ±5 COB buckets
|
||||
"""
|
||||
|
||||
def __init__(self, symbols: List[str] = None, timeframes: List[str] = None):
|
||||
"""Initialize the standardized data provider"""
|
||||
super().__init__(symbols, timeframes)
|
||||
|
||||
# Standardized data storage
|
||||
self.base_data_cache: Dict[str, BaseDataInput] = {} # {symbol: BaseDataInput}
|
||||
self.cob_data_cache: Dict[str, COBData] = {} # {symbol: COBData}
|
||||
|
||||
# Model output management with extensible storage
|
||||
self.model_output_manager = ModelOutputManager(
|
||||
cache_dir=str(self.cache_dir / "model_outputs"),
|
||||
max_history=1000
|
||||
)
|
||||
|
||||
# COB moving averages calculation
|
||||
self.cob_imbalance_history: Dict[str, deque] = {} # {symbol: deque of (timestamp, imbalance_data)}
|
||||
self.ma_calculation_lock = Lock()
|
||||
|
||||
# Initialize caches for each symbol
|
||||
for symbol in self.symbols:
|
||||
self.base_data_cache[symbol] = None
|
||||
self.cob_data_cache[symbol] = None
|
||||
self.cob_imbalance_history[symbol] = deque(maxlen=300) # 5 minutes of 1s data
|
||||
|
||||
# COB provider integration
|
||||
self.cob_provider: Optional[MultiExchangeCOBProvider] = None
|
||||
self._initialize_cob_provider()
|
||||
|
||||
logger.info("StandardizedDataProvider initialized with BaseDataInput support")
|
||||
|
||||
def _initialize_cob_provider(self):
|
||||
"""Initialize COB provider for order book data"""
|
||||
try:
|
||||
from .multi_exchange_cob_provider import MultiExchangeCOBProvider, ExchangeConfig, ExchangeType
|
||||
|
||||
# Configure exchanges (focusing on Binance for now)
|
||||
exchange_configs = {
|
||||
'binance': ExchangeConfig(
|
||||
exchange_type=ExchangeType.BINANCE,
|
||||
weight=1.0,
|
||||
enabled=True,
|
||||
websocket_url="wss://stream.binance.com:9443/ws/",
|
||||
symbols_mapping={symbol: symbol.replace('/', '').lower() for symbol in self.symbols}
|
||||
)
|
||||
}
|
||||
|
||||
self.cob_provider = MultiExchangeCOBProvider(self.symbols, exchange_configs)
|
||||
logger.info("COB provider initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize COB provider: {e}")
|
||||
self.cob_provider = None
|
||||
|
||||
def get_base_data_input(self, symbol: str, timestamp: Optional[datetime] = None) -> Optional[BaseDataInput]:
|
||||
"""
|
||||
Get standardized BaseDataInput for a symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'ETH/USDT')
|
||||
timestamp: Optional timestamp, defaults to current time
|
||||
|
||||
Returns:
|
||||
BaseDataInput: Standardized input data for models, or None if insufficient data
|
||||
"""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
|
||||
try:
|
||||
# Get OHLCV data for all timeframes
|
||||
ohlcv_1s = self._get_ohlcv_bars(symbol, '1s', 300)
|
||||
ohlcv_1m = self._get_ohlcv_bars(symbol, '1m', 300)
|
||||
ohlcv_1h = self._get_ohlcv_bars(symbol, '1h', 300)
|
||||
ohlcv_1d = self._get_ohlcv_bars(symbol, '1d', 300)
|
||||
|
||||
# Get BTC reference data
|
||||
btc_symbol = 'BTC/USDT'
|
||||
btc_ohlcv_1s = self._get_ohlcv_bars(btc_symbol, '1s', 300)
|
||||
|
||||
# Check if we have sufficient data
|
||||
if not all([ohlcv_1s, ohlcv_1m, ohlcv_1h, ohlcv_1d, btc_ohlcv_1s]):
|
||||
logger.warning(f"Insufficient OHLCV data for {symbol}")
|
||||
return None
|
||||
|
||||
if any(len(data) < 100 for data in [ohlcv_1s, ohlcv_1m, ohlcv_1h, ohlcv_1d, btc_ohlcv_1s]):
|
||||
logger.warning(f"Insufficient data frames for {symbol}")
|
||||
return None
|
||||
|
||||
# Get COB data
|
||||
cob_data = self._get_cob_data(symbol, timestamp)
|
||||
|
||||
# Get technical indicators
|
||||
technical_indicators = self._get_technical_indicators(symbol)
|
||||
|
||||
# Get pivot points
|
||||
pivot_points = self._get_pivot_points(symbol)
|
||||
|
||||
# Get last predictions from all models
|
||||
last_predictions = self.model_output_manager.get_all_current_outputs(symbol)
|
||||
|
||||
# Create BaseDataInput
|
||||
base_input = BaseDataInput(
|
||||
symbol=symbol,
|
||||
timestamp=timestamp,
|
||||
ohlcv_1s=ohlcv_1s,
|
||||
ohlcv_1m=ohlcv_1m,
|
||||
ohlcv_1h=ohlcv_1h,
|
||||
ohlcv_1d=ohlcv_1d,
|
||||
btc_ohlcv_1s=btc_ohlcv_1s,
|
||||
cob_data=cob_data,
|
||||
technical_indicators=technical_indicators,
|
||||
pivot_points=pivot_points,
|
||||
last_predictions=last_predictions
|
||||
)
|
||||
|
||||
# Validate the input
|
||||
if not base_input.validate():
|
||||
logger.warning(f"BaseDataInput validation failed for {symbol}")
|
||||
return None
|
||||
|
||||
# Cache the result
|
||||
self.base_data_cache[symbol] = base_input
|
||||
|
||||
return base_input
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating BaseDataInput for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _get_ohlcv_bars(self, symbol: str, timeframe: str, count: int) -> List[OHLCVBar]:
|
||||
"""
|
||||
Get OHLCV bars for a symbol and timeframe
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Timeframe ('1s', '1m', '1h', '1d')
|
||||
count: Number of bars to retrieve
|
||||
|
||||
Returns:
|
||||
List[OHLCVBar]: List of OHLCV bars
|
||||
"""
|
||||
try:
|
||||
# Get historical data from parent class
|
||||
df = self.get_historical_data(symbol, timeframe, count)
|
||||
if df is None or df.empty:
|
||||
return []
|
||||
|
||||
# Convert DataFrame to OHLCVBar objects
|
||||
bars = []
|
||||
for _, row in df.tail(count).iterrows():
|
||||
bar = OHLCVBar(
|
||||
symbol=symbol,
|
||||
timestamp=row.name if hasattr(row, 'name') else datetime.now(),
|
||||
open=float(row['open']),
|
||||
high=float(row['high']),
|
||||
low=float(row['low']),
|
||||
close=float(row['close']),
|
||||
volume=float(row['volume']),
|
||||
timeframe=timeframe,
|
||||
indicators={}
|
||||
)
|
||||
|
||||
# Add technical indicators if available
|
||||
for col in df.columns:
|
||||
if col not in ['open', 'high', 'low', 'close', 'volume']:
|
||||
bar.indicators[col] = float(row[col]) if not np.isnan(row[col]) else 0.0
|
||||
|
||||
bars.append(bar)
|
||||
|
||||
return bars
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting OHLCV bars for {symbol} {timeframe}: {e}")
|
||||
return []
|
||||
|
||||
def _get_cob_data(self, symbol: str, timestamp: datetime) -> Optional[COBData]:
|
||||
"""
|
||||
Get COB data for a symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timestamp: Current timestamp
|
||||
|
||||
Returns:
|
||||
COBData: COB data with price buckets and moving averages
|
||||
"""
|
||||
try:
|
||||
if not self.cob_provider:
|
||||
return None
|
||||
|
||||
# Get current price
|
||||
current_price = self.current_prices.get(symbol.replace('/', '').upper(), 0.0)
|
||||
if current_price <= 0:
|
||||
return None
|
||||
|
||||
# Determine bucket size based on symbol
|
||||
bucket_size = 1.0 if 'ETH' in symbol else 10.0 # $1 for ETH, $10 for BTC
|
||||
|
||||
# Calculate price range (±20 buckets)
|
||||
price_range = 20 * bucket_size
|
||||
min_price = current_price - price_range
|
||||
max_price = current_price + price_range
|
||||
|
||||
# Create price buckets
|
||||
price_buckets = {}
|
||||
bid_ask_imbalance = {}
|
||||
volume_weighted_prices = {}
|
||||
|
||||
# Generate mock COB data for now (will be replaced with real COB provider data)
|
||||
for i in range(-20, 21):
|
||||
price = current_price + (i * bucket_size)
|
||||
if price > 0:
|
||||
# Mock data - replace with real COB provider data
|
||||
bid_volume = max(0, 1000 - abs(i) * 50) # More volume near current price
|
||||
ask_volume = max(0, 1000 - abs(i) * 50)
|
||||
total_volume = bid_volume + ask_volume
|
||||
imbalance = (bid_volume - ask_volume) / max(total_volume, 1)
|
||||
|
||||
price_buckets[price] = {
|
||||
'bid_volume': bid_volume,
|
||||
'ask_volume': ask_volume,
|
||||
'total_volume': total_volume,
|
||||
'imbalance': imbalance
|
||||
}
|
||||
bid_ask_imbalance[price] = imbalance
|
||||
volume_weighted_prices[price] = price # Simplified VWAP
|
||||
|
||||
# Calculate moving averages of imbalance for ±5 buckets
|
||||
ma_data = self._calculate_cob_moving_averages(symbol, bid_ask_imbalance, timestamp)
|
||||
|
||||
cob_data = COBData(
|
||||
symbol=symbol,
|
||||
timestamp=timestamp,
|
||||
current_price=current_price,
|
||||
bucket_size=bucket_size,
|
||||
price_buckets=price_buckets,
|
||||
bid_ask_imbalance=bid_ask_imbalance,
|
||||
volume_weighted_prices=volume_weighted_prices,
|
||||
order_flow_metrics={},
|
||||
ma_1s_imbalance=ma_data.get('1s', {}),
|
||||
ma_5s_imbalance=ma_data.get('5s', {}),
|
||||
ma_15s_imbalance=ma_data.get('15s', {}),
|
||||
ma_60s_imbalance=ma_data.get('60s', {})
|
||||
)
|
||||
|
||||
# Cache the COB data
|
||||
self.cob_data_cache[symbol] = cob_data
|
||||
|
||||
return cob_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting COB data for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _calculate_cob_moving_averages(self, symbol: str, bid_ask_imbalance: Dict[float, float],
|
||||
timestamp: datetime) -> Dict[str, Dict[float, float]]:
|
||||
"""
|
||||
Calculate moving averages of COB imbalance for ±5 buckets
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
bid_ask_imbalance: Current bid/ask imbalance data
|
||||
timestamp: Current timestamp
|
||||
|
||||
Returns:
|
||||
Dict containing MA data for different timeframes
|
||||
"""
|
||||
try:
|
||||
with self.ma_calculation_lock:
|
||||
# Add current imbalance data to history
|
||||
self.cob_imbalance_history[symbol].append((timestamp, bid_ask_imbalance))
|
||||
|
||||
# Calculate MAs for different timeframes
|
||||
ma_results = {'1s': {}, '5s': {}, '15s': {}, '60s': {}}
|
||||
|
||||
# Get current price for ±5 bucket calculation
|
||||
current_price = self.current_prices.get(symbol.replace('/', '').upper(), 0.0)
|
||||
if current_price <= 0:
|
||||
return ma_results
|
||||
|
||||
bucket_size = 1.0 if 'ETH' in symbol else 10.0
|
||||
|
||||
# Calculate MAs for ±5 buckets around current price
|
||||
for i in range(-5, 6):
|
||||
price = current_price + (i * bucket_size)
|
||||
if price <= 0:
|
||||
continue
|
||||
|
||||
# Get historical imbalance data for this price bucket
|
||||
history = self.cob_imbalance_history[symbol]
|
||||
|
||||
# Calculate different MA periods
|
||||
for period, period_name in [(1, '1s'), (5, '5s'), (15, '15s'), (60, '60s')]:
|
||||
recent_data = []
|
||||
cutoff_time = timestamp - timedelta(seconds=period)
|
||||
|
||||
for hist_timestamp, hist_imbalance in history:
|
||||
if hist_timestamp >= cutoff_time and price in hist_imbalance:
|
||||
recent_data.append(hist_imbalance[price])
|
||||
|
||||
# Calculate moving average
|
||||
if recent_data:
|
||||
ma_results[period_name][price] = sum(recent_data) / len(recent_data)
|
||||
else:
|
||||
ma_results[period_name][price] = 0.0
|
||||
|
||||
return ma_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating COB moving averages for {symbol}: {e}")
|
||||
return {'1s': {}, '5s': {}, '15s': {}, '60s': {}}
|
||||
|
||||
def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
|
||||
"""Get technical indicators for a symbol"""
|
||||
try:
|
||||
# Get latest OHLCV data
|
||||
df = self.get_historical_data(symbol, '1h', 100) # Use 1h for indicators
|
||||
if df is None or df.empty:
|
||||
return {}
|
||||
|
||||
indicators = {}
|
||||
|
||||
# Add basic indicators if available in the dataframe
|
||||
latest_row = df.iloc[-1]
|
||||
for col in df.columns:
|
||||
if col not in ['open', 'high', 'low', 'close', 'volume']:
|
||||
indicators[col] = float(latest_row[col]) if not np.isnan(latest_row[col]) else 0.0
|
||||
|
||||
return indicators
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting technical indicators for {symbol}: {e}")
|
||||
return {}
|
||||
|
||||
def _get_pivot_points(self, symbol: str) -> List[PivotPoint]:
|
||||
"""Get pivot points for a symbol"""
|
||||
try:
|
||||
pivot_points = []
|
||||
|
||||
# Get pivot points from Williams Market Structure if available
|
||||
if symbol in self.williams_structure:
|
||||
williams = self.williams_structure[symbol]
|
||||
# This would need to be implemented based on the actual Williams structure
|
||||
# For now, return empty list
|
||||
pass
|
||||
|
||||
return pivot_points
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting pivot points for {symbol}: {e}")
|
||||
return []
|
||||
|
||||
def store_model_output(self, model_output: ModelOutput):
|
||||
"""
|
||||
Store model output for cross-model feeding using ModelOutputManager
|
||||
|
||||
Args:
|
||||
model_output: ModelOutput from any model
|
||||
"""
|
||||
try:
|
||||
success = self.model_output_manager.store_output(model_output)
|
||||
if success:
|
||||
logger.debug(f"Stored model output from {model_output.model_name} for {model_output.symbol}")
|
||||
else:
|
||||
logger.warning(f"Failed to store model output from {model_output.model_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing model output: {e}")
|
||||
|
||||
def get_model_outputs(self, symbol: str) -> Dict[str, ModelOutput]:
|
||||
"""
|
||||
Get all model outputs for a symbol using ModelOutputManager
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
Dict[str, ModelOutput]: Dictionary of model outputs by model name
|
||||
"""
|
||||
return self.model_output_manager.get_all_current_outputs(symbol)
|
||||
|
||||
def get_model_output_manager(self) -> ModelOutputManager:
|
||||
"""
|
||||
Get the model output manager for advanced operations
|
||||
|
||||
Returns:
|
||||
ModelOutputManager: The model output manager instance
|
||||
"""
|
||||
return self.model_output_manager
|
||||
|
||||
def start_real_time_processing(self):
|
||||
"""Start real-time processing for standardized data"""
|
||||
try:
|
||||
# Start parent class real-time processing
|
||||
if hasattr(super(), 'start_real_time_processing'):
|
||||
super().start_real_time_processing()
|
||||
|
||||
# Start COB provider if available
|
||||
if self.cob_provider:
|
||||
import asyncio
|
||||
asyncio.create_task(self.cob_provider.start_streaming())
|
||||
|
||||
logger.info("Started real-time processing for standardized data")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting real-time processing: {e}")
|
||||
|
||||
def stop_real_time_processing(self):
|
||||
"""Stop real-time processing"""
|
||||
try:
|
||||
# Stop COB provider if available
|
||||
if self.cob_provider:
|
||||
import asyncio
|
||||
asyncio.create_task(self.cob_provider.stop_streaming())
|
||||
|
||||
# Stop parent class processing
|
||||
if hasattr(super(), 'stop_real_time_processing'):
|
||||
super().stop_real_time_processing()
|
||||
|
||||
logger.info("Stopped real-time processing for standardized data")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping real-time processing: {e}")
|
@ -0,0 +1,17 @@
|
||||
{
|
||||
"model_type": "hybrid_ensemble",
|
||||
"model_name": "custom_ensemble_v1",
|
||||
"symbol": "ETH/USDT",
|
||||
"timestamp": "2025-07-23T15:38:26.810125",
|
||||
"confidence": 0.88,
|
||||
"predictions": {
|
||||
"action": "BUY",
|
||||
"buy_probability": 0.88,
|
||||
"sell_probability": 0.0,
|
||||
"hold_probability": 0.0
|
||||
},
|
||||
"metadata": {
|
||||
"ensemble_size": 5,
|
||||
"voting_method": "weighted"
|
||||
}
|
||||
}
|
@ -0,0 +1,18 @@
|
||||
{
|
||||
"model_type": "rl",
|
||||
"model_name": "dqn_agent_v2",
|
||||
"symbol": "ETH/USDT",
|
||||
"timestamp": "2025-07-23T15:38:26.804125",
|
||||
"confidence": 0.72,
|
||||
"predictions": {
|
||||
"action": "SELL",
|
||||
"buy_probability": 0.0,
|
||||
"sell_probability": 0.72,
|
||||
"hold_probability": 0.0
|
||||
},
|
||||
"metadata": {
|
||||
"model_version": "1.0",
|
||||
"training_iterations": 1000,
|
||||
"last_updated": "2025-07-23T15:38:26.804125"
|
||||
}
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
{
|
||||
"model_type": "cnn",
|
||||
"model_name": "enhanced_cnn_v1",
|
||||
"symbol": "ETH/USDT",
|
||||
"timestamp": "2025-07-23T15:38:26.812129",
|
||||
"confidence": 0.8,
|
||||
"predictions": {
|
||||
"action": "HOLD",
|
||||
"buy_probability": 0.0,
|
||||
"sell_probability": 0.0,
|
||||
"hold_probability": 0.8
|
||||
},
|
||||
"metadata": {}
|
||||
}
|
@ -0,0 +1,18 @@
|
||||
{
|
||||
"model_type": "lstm",
|
||||
"model_name": "lstm_predictor",
|
||||
"symbol": "ETH/USDT",
|
||||
"timestamp": "2025-07-23T15:38:26.805126",
|
||||
"confidence": 0.65,
|
||||
"predictions": {
|
||||
"action": "HOLD",
|
||||
"buy_probability": 0.0,
|
||||
"sell_probability": 0.0,
|
||||
"hold_probability": 0.65
|
||||
},
|
||||
"metadata": {
|
||||
"model_version": "1.0",
|
||||
"training_iterations": 1000,
|
||||
"last_updated": "2025-07-23T15:38:26.805126"
|
||||
}
|
||||
}
|
@ -0,0 +1,18 @@
|
||||
{
|
||||
"model_type": "orchestrator",
|
||||
"model_name": "main_orchestrator",
|
||||
"symbol": "ETH/USDT",
|
||||
"timestamp": "2025-07-23T15:38:26.806125",
|
||||
"confidence": 0.78,
|
||||
"predictions": {
|
||||
"action": "BUY",
|
||||
"buy_probability": 0.78,
|
||||
"sell_probability": 0.0,
|
||||
"hold_probability": 0.0
|
||||
},
|
||||
"metadata": {
|
||||
"model_version": "1.0",
|
||||
"training_iterations": 1000,
|
||||
"last_updated": "2025-07-23T15:38:26.806125"
|
||||
}
|
||||
}
|
@ -0,0 +1,18 @@
|
||||
{
|
||||
"model_type": "transformer",
|
||||
"model_name": "transformer_v1",
|
||||
"symbol": "ETH/USDT",
|
||||
"timestamp": "2025-07-23T15:38:26.806125",
|
||||
"confidence": 0.91,
|
||||
"predictions": {
|
||||
"action": "BUY",
|
||||
"buy_probability": 0.91,
|
||||
"sell_probability": 0.0,
|
||||
"hold_probability": 0.0
|
||||
},
|
||||
"metadata": {
|
||||
"model_version": "1.0",
|
||||
"training_iterations": 1000,
|
||||
"last_updated": "2025-07-23T15:38:26.806125"
|
||||
}
|
||||
}
|
182
test_integrated_standardized_provider.py
Normal file
182
test_integrated_standardized_provider.py
Normal file
@ -0,0 +1,182 @@
|
||||
"""
|
||||
Test script for integrated StandardizedDataProvider with ModelOutputManager
|
||||
|
||||
This script tests the complete standardized data provider with extensible model output storage
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from core.standardized_data_provider import StandardizedDataProvider
|
||||
from core.data_models import create_model_output
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_integrated_standardized_provider():
|
||||
"""Test the integrated StandardizedDataProvider with ModelOutputManager"""
|
||||
|
||||
print("Testing Integrated StandardizedDataProvider with ModelOutputManager...")
|
||||
|
||||
# Initialize the provider
|
||||
symbols = ['ETH/USDT', 'BTC/USDT']
|
||||
timeframes = ['1s', '1m', '1h', '1d']
|
||||
|
||||
provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)
|
||||
|
||||
print("✅ StandardizedDataProvider initialized with ModelOutputManager")
|
||||
|
||||
# Test 1: Store model outputs from different types
|
||||
print("\n1. Testing model output storage integration...")
|
||||
|
||||
# Create and store outputs from different model types
|
||||
model_outputs = [
|
||||
create_model_output('cnn', 'enhanced_cnn_v1', 'ETH/USDT', 'BUY', 0.85),
|
||||
create_model_output('rl', 'dqn_agent_v2', 'ETH/USDT', 'SELL', 0.72),
|
||||
create_model_output('transformer', 'transformer_v1', 'ETH/USDT', 'BUY', 0.91),
|
||||
create_model_output('orchestrator', 'main_orchestrator', 'ETH/USDT', 'BUY', 0.78)
|
||||
]
|
||||
|
||||
for output in model_outputs:
|
||||
provider.store_model_output(output)
|
||||
print(f"✅ Stored {output.model_type} output: {output.predictions['action']} ({output.confidence})")
|
||||
|
||||
# Test 2: Retrieve model outputs
|
||||
print("\n2. Testing model output retrieval...")
|
||||
|
||||
all_outputs = provider.get_model_outputs('ETH/USDT')
|
||||
print(f"✅ Retrieved {len(all_outputs)} model outputs for ETH/USDT")
|
||||
|
||||
for model_name, output in all_outputs.items():
|
||||
print(f" {model_name} ({output.model_type}): {output.predictions['action']} - {output.confidence}")
|
||||
|
||||
# Test 3: Test BaseDataInput with cross-model feeding
|
||||
print("\n3. Testing BaseDataInput with cross-model predictions...")
|
||||
|
||||
# Set mock current price for COB data
|
||||
provider.current_prices['ETHUSDT'] = 3000.0
|
||||
|
||||
base_input = provider.get_base_data_input('ETH/USDT')
|
||||
|
||||
if base_input:
|
||||
print("✅ BaseDataInput created with cross-model predictions!")
|
||||
print(f" Symbol: {base_input.symbol}")
|
||||
print(f" OHLCV frames: 1s={len(base_input.ohlcv_1s)}, 1m={len(base_input.ohlcv_1m)}, 1h={len(base_input.ohlcv_1h)}, 1d={len(base_input.ohlcv_1d)}")
|
||||
print(f" BTC frames: {len(base_input.btc_ohlcv_1s)}")
|
||||
print(f" COB data: {'Available' if base_input.cob_data else 'Not available'}")
|
||||
print(f" Last predictions: {len(base_input.last_predictions)} models")
|
||||
|
||||
# Show cross-model predictions
|
||||
for model_name, prediction in base_input.last_predictions.items():
|
||||
print(f" {model_name}: {prediction.predictions['action']} ({prediction.confidence})")
|
||||
|
||||
# Test feature vector creation
|
||||
try:
|
||||
feature_vector = base_input.get_feature_vector()
|
||||
print(f"✅ Feature vector created: shape {feature_vector.shape}")
|
||||
except Exception as e:
|
||||
print(f"❌ Feature vector creation failed: {e}")
|
||||
else:
|
||||
print("⚠️ BaseDataInput creation failed - this may be due to insufficient historical data")
|
||||
|
||||
# Test 4: Advanced ModelOutputManager features
|
||||
print("\n4. Testing advanced model output manager features...")
|
||||
|
||||
output_manager = provider.get_model_output_manager()
|
||||
|
||||
# Test consensus prediction
|
||||
consensus = output_manager.get_consensus_prediction('ETH/USDT', confidence_threshold=0.7)
|
||||
if consensus:
|
||||
print(f"✅ Consensus prediction: {consensus['action']} (confidence: {consensus['confidence']:.3f})")
|
||||
print(f" Votes: {consensus['votes']}")
|
||||
print(f" Contributing models: {consensus['model_types']}")
|
||||
else:
|
||||
print("⚠️ No consensus reached")
|
||||
|
||||
# Test cross-model states
|
||||
cross_states = output_manager.get_cross_model_states('ETH/USDT', 'dqn_agent_v2')
|
||||
print(f"✅ Cross-model states available for RL model: {len(cross_states)} models")
|
||||
|
||||
# Test performance summary
|
||||
performance = output_manager.get_performance_summary('ETH/USDT')
|
||||
print(f"✅ Performance summary: {performance['active_models']} active models")
|
||||
|
||||
# Test 5: Custom model type support
|
||||
print("\n5. Testing custom model type extensibility...")
|
||||
|
||||
# Add a custom model type
|
||||
output_manager.add_custom_model_type('hybrid_lstm_transformer')
|
||||
|
||||
# Create and store custom model output
|
||||
custom_output = create_model_output(
|
||||
model_type='hybrid_lstm_transformer',
|
||||
model_name='hybrid_model_v1',
|
||||
symbol='ETH/USDT',
|
||||
action='BUY',
|
||||
confidence=0.89,
|
||||
metadata={'hybrid_components': ['lstm', 'transformer'], 'ensemble_weight': 0.6}
|
||||
)
|
||||
|
||||
provider.store_model_output(custom_output)
|
||||
print("✅ Custom model type 'hybrid_lstm_transformer' stored successfully")
|
||||
|
||||
# Verify it's included in BaseDataInput
|
||||
updated_base_input = provider.get_base_data_input('ETH/USDT')
|
||||
if updated_base_input and 'hybrid_model_v1' in updated_base_input.last_predictions:
|
||||
print("✅ Custom model output included in BaseDataInput cross-model feeding")
|
||||
|
||||
print(f" Total supported model types: {len(output_manager.get_supported_model_types())}")
|
||||
|
||||
# Test 6: Historical tracking
|
||||
print("\n6. Testing historical output tracking...")
|
||||
|
||||
# Store a few more outputs to build history
|
||||
for i in range(3):
|
||||
historical_output = create_model_output(
|
||||
model_type='cnn',
|
||||
model_name='enhanced_cnn_v1',
|
||||
symbol='ETH/USDT',
|
||||
action='HOLD',
|
||||
confidence=0.6 + i * 0.05
|
||||
)
|
||||
provider.store_model_output(historical_output)
|
||||
|
||||
history = output_manager.get_output_history('ETH/USDT', 'enhanced_cnn_v1', count=5)
|
||||
print(f"✅ Historical tracking: {len(history)} outputs for enhanced_cnn_v1")
|
||||
|
||||
# Test 7: Real-time data integration readiness
|
||||
print("\n7. Testing real-time integration readiness...")
|
||||
|
||||
print("✅ Real-time processing methods available:")
|
||||
print(" - start_real_time_processing()")
|
||||
print(" - stop_real_time_processing()")
|
||||
print(" - COB provider integration ready")
|
||||
print(" - Model output persistence enabled")
|
||||
|
||||
print("\n✅ Integrated StandardizedDataProvider test completed successfully!")
|
||||
print("\n🎯 Key achievements:")
|
||||
print("✓ Standardized BaseDataInput format for all models")
|
||||
print("✓ Extensible ModelOutput storage (CNN, RL, LSTM, Transformer, Custom)")
|
||||
print("✓ Cross-model feeding with last predictions")
|
||||
print("✓ COB data integration with moving averages")
|
||||
print("✓ Consensus prediction calculation")
|
||||
print("✓ Historical output tracking")
|
||||
print("✓ Performance analytics")
|
||||
print("✓ Thread-safe operations")
|
||||
print("✓ Persistent storage capabilities")
|
||||
|
||||
print("\n🚀 Ready for model integration:")
|
||||
print("1. CNN models can use BaseDataInput and store ModelOutput")
|
||||
print("2. RL models can access CNN hidden states via cross-model feeding")
|
||||
print("3. Orchestrator can calculate consensus from all models")
|
||||
print("4. New model types can be added without code changes")
|
||||
print("5. All models receive identical standardized input format")
|
||||
|
||||
return provider
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_integrated_standardized_provider()
|
182
test_model_output_manager.py
Normal file
182
test_model_output_manager.py
Normal file
@ -0,0 +1,182 @@
|
||||
"""
|
||||
Test script for ModelOutputManager
|
||||
|
||||
This script tests the extensible model output storage functionality
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from core.model_output_manager import ModelOutputManager
|
||||
from core.data_models import create_model_output, ModelOutput
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_model_output_manager():
|
||||
"""Test the ModelOutputManager functionality"""
|
||||
|
||||
print("Testing ModelOutputManager...")
|
||||
|
||||
# Initialize the manager
|
||||
manager = ModelOutputManager(cache_dir="test_cache/model_outputs", max_history=100)
|
||||
|
||||
print(f"✅ ModelOutputManager initialized")
|
||||
print(f" Supported model types: {manager.get_supported_model_types()}")
|
||||
|
||||
# Test 1: Store outputs from different model types
|
||||
print("\n1. Testing model output storage...")
|
||||
|
||||
# Create outputs from different model types
|
||||
models_to_test = [
|
||||
('cnn', 'enhanced_cnn_v1', 'BUY', 0.85),
|
||||
('rl', 'dqn_agent_v2', 'SELL', 0.72),
|
||||
('lstm', 'lstm_predictor', 'HOLD', 0.65),
|
||||
('transformer', 'transformer_v1', 'BUY', 0.91),
|
||||
('orchestrator', 'main_orchestrator', 'BUY', 0.78)
|
||||
]
|
||||
|
||||
symbol = 'ETH/USDT'
|
||||
stored_outputs = []
|
||||
|
||||
for model_type, model_name, action, confidence in models_to_test:
|
||||
# Create model output with hidden states for cross-model feeding
|
||||
hidden_states = {
|
||||
'layer_1': [0.1, 0.2, 0.3],
|
||||
'layer_2': [0.4, 0.5, 0.6],
|
||||
'attention_weights': [0.7, 0.8, 0.9]
|
||||
} if model_type in ['cnn', 'transformer'] else None
|
||||
|
||||
metadata = {
|
||||
'model_version': '1.0',
|
||||
'training_iterations': 1000,
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
model_output = create_model_output(
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
symbol=symbol,
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
hidden_states=hidden_states,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Store the output
|
||||
success = manager.store_output(model_output)
|
||||
if success:
|
||||
print(f"✅ Stored {model_type} output: {action} ({confidence})")
|
||||
stored_outputs.append(model_output)
|
||||
else:
|
||||
print(f"❌ Failed to store {model_type} output")
|
||||
|
||||
# Test 2: Retrieve current outputs
|
||||
print("\n2. Testing output retrieval...")
|
||||
|
||||
all_current = manager.get_all_current_outputs(symbol)
|
||||
print(f"✅ Retrieved {len(all_current)} current outputs for {symbol}")
|
||||
|
||||
for model_name, output in all_current.items():
|
||||
print(f" {model_name} ({output.model_type}): {output.predictions['action']} - {output.confidence}")
|
||||
|
||||
# Test 3: Cross-model feeding
|
||||
print("\n3. Testing cross-model feeding...")
|
||||
|
||||
cross_model_states = manager.get_cross_model_states(symbol, 'dqn_agent_v2')
|
||||
print(f"✅ Retrieved cross-model states for RL model: {len(cross_model_states)} models")
|
||||
|
||||
for model_name, states in cross_model_states.items():
|
||||
if states:
|
||||
print(f" {model_name}: {len(states)} hidden state layers")
|
||||
|
||||
# Test 4: Consensus prediction
|
||||
print("\n4. Testing consensus prediction...")
|
||||
|
||||
consensus = manager.get_consensus_prediction(symbol, confidence_threshold=0.7)
|
||||
if consensus:
|
||||
print(f"✅ Consensus prediction: {consensus['action']} (confidence: {consensus['confidence']:.3f})")
|
||||
print(f" Votes: {consensus['votes']}")
|
||||
print(f" Models: {consensus['model_types']}")
|
||||
else:
|
||||
print("⚠️ No consensus reached (insufficient high-confidence predictions)")
|
||||
|
||||
# Test 5: Performance summary
|
||||
print("\n5. Testing performance tracking...")
|
||||
|
||||
performance = manager.get_performance_summary(symbol)
|
||||
print(f"✅ Performance summary for {symbol}:")
|
||||
print(f" Active models: {performance['active_models']}")
|
||||
|
||||
for model_name, stats in performance['model_stats'].items():
|
||||
print(f" {model_name} ({stats['model_type']}): {stats['predictions']} predictions, "
|
||||
f"avg confidence: {stats['avg_confidence']}")
|
||||
|
||||
# Test 6: Custom model type support
|
||||
print("\n6. Testing custom model type support...")
|
||||
|
||||
# Add a custom model type
|
||||
manager.add_custom_model_type('hybrid_ensemble')
|
||||
|
||||
# Create output with custom model type
|
||||
custom_output = create_model_output(
|
||||
model_type='hybrid_ensemble',
|
||||
model_name='custom_ensemble_v1',
|
||||
symbol=symbol,
|
||||
action='BUY',
|
||||
confidence=0.88,
|
||||
metadata={'ensemble_size': 5, 'voting_method': 'weighted'}
|
||||
)
|
||||
|
||||
success = manager.store_output(custom_output)
|
||||
if success:
|
||||
print("✅ Custom model type 'hybrid_ensemble' stored successfully")
|
||||
else:
|
||||
print("❌ Failed to store custom model type")
|
||||
|
||||
print(f" Updated supported types: {len(manager.get_supported_model_types())} types")
|
||||
|
||||
# Test 7: Historical outputs
|
||||
print("\n7. Testing historical output tracking...")
|
||||
|
||||
# Store a few more outputs to build history
|
||||
for i in range(3):
|
||||
historical_output = create_model_output(
|
||||
model_type='cnn',
|
||||
model_name='enhanced_cnn_v1',
|
||||
symbol=symbol,
|
||||
action='HOLD',
|
||||
confidence=0.6 + i * 0.1
|
||||
)
|
||||
manager.store_output(historical_output)
|
||||
|
||||
history = manager.get_output_history(symbol, 'enhanced_cnn_v1', count=5)
|
||||
print(f"✅ Retrieved {len(history)} historical outputs for enhanced_cnn_v1")
|
||||
|
||||
for i, output in enumerate(history):
|
||||
print(f" {i+1}. {output.predictions['action']} ({output.confidence}) at {output.timestamp}")
|
||||
|
||||
# Test 8: Active model types
|
||||
print("\n8. Testing active model type detection...")
|
||||
|
||||
active_types = manager.get_model_types_active(symbol)
|
||||
print(f"✅ Active model types for {symbol}: {active_types}")
|
||||
|
||||
print("\n✅ ModelOutputManager test completed successfully!")
|
||||
print("\nKey features verified:")
|
||||
print("✓ Extensible model type support (CNN, RL, LSTM, Transformer, Custom)")
|
||||
print("✓ Cross-model feeding with hidden states")
|
||||
print("✓ Historical output tracking")
|
||||
print("✓ Performance analytics")
|
||||
print("✓ Consensus prediction calculation")
|
||||
print("✓ Metadata management")
|
||||
print("✓ Thread-safe storage operations")
|
||||
|
||||
return manager
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_model_output_manager()
|
128
test_standardized_data_provider.py
Normal file
128
test_standardized_data_provider.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""
|
||||
Test script for StandardizedDataProvider
|
||||
|
||||
This script tests the standardized BaseDataInput functionality
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from core.standardized_data_provider import StandardizedDataProvider
|
||||
from core.data_models import create_model_output
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_standardized_data_provider():
|
||||
"""Test the StandardizedDataProvider functionality"""
|
||||
|
||||
print("Testing StandardizedDataProvider...")
|
||||
|
||||
# Initialize the provider
|
||||
symbols = ['ETH/USDT', 'BTC/USDT']
|
||||
timeframes = ['1s', '1m', '1h', '1d']
|
||||
|
||||
provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)
|
||||
|
||||
# Test getting BaseDataInput
|
||||
print("\n1. Testing BaseDataInput creation...")
|
||||
base_input = provider.get_base_data_input('ETH/USDT')
|
||||
|
||||
if base_input is None:
|
||||
print("❌ BaseDataInput is None - this is expected if no historical data is available")
|
||||
print(" The provider needs real market data to create BaseDataInput")
|
||||
|
||||
# Test with mock data
|
||||
print("\n2. Testing data structures...")
|
||||
|
||||
# Test ModelOutput creation
|
||||
model_output = create_model_output(
|
||||
model_type='cnn',
|
||||
model_name='test_cnn',
|
||||
symbol='ETH/USDT',
|
||||
action='BUY',
|
||||
confidence=0.75,
|
||||
metadata={'test': True}
|
||||
)
|
||||
|
||||
print(f"✅ Created ModelOutput: {model_output.model_type} - {model_output.predictions['action']} ({model_output.confidence})")
|
||||
|
||||
# Test storing model output
|
||||
provider.store_model_output(model_output)
|
||||
stored_outputs = provider.get_model_outputs('ETH/USDT')
|
||||
|
||||
if 'test_cnn' in stored_outputs:
|
||||
print("✅ Model output storage and retrieval working")
|
||||
else:
|
||||
print("❌ Model output storage failed")
|
||||
|
||||
else:
|
||||
print("✅ BaseDataInput created successfully!")
|
||||
print(f" Symbol: {base_input.symbol}")
|
||||
print(f" Timestamp: {base_input.timestamp}")
|
||||
print(f" OHLCV 1s frames: {len(base_input.ohlcv_1s)}")
|
||||
print(f" OHLCV 1m frames: {len(base_input.ohlcv_1m)}")
|
||||
print(f" OHLCV 1h frames: {len(base_input.ohlcv_1h)}")
|
||||
print(f" OHLCV 1d frames: {len(base_input.ohlcv_1d)}")
|
||||
print(f" BTC 1s frames: {len(base_input.btc_ohlcv_1s)}")
|
||||
print(f" COB data available: {base_input.cob_data is not None}")
|
||||
print(f" Technical indicators: {len(base_input.technical_indicators)}")
|
||||
print(f" Pivot points: {len(base_input.pivot_points)}")
|
||||
print(f" Last predictions: {len(base_input.last_predictions)}")
|
||||
|
||||
# Test feature vector creation
|
||||
try:
|
||||
feature_vector = base_input.get_feature_vector()
|
||||
print(f"✅ Feature vector created: shape {feature_vector.shape}")
|
||||
except Exception as e:
|
||||
print(f"❌ Feature vector creation failed: {e}")
|
||||
|
||||
# Test validation
|
||||
is_valid = base_input.validate()
|
||||
print(f"✅ BaseDataInput validation: {'PASSED' if is_valid else 'FAILED'}")
|
||||
|
||||
print("\n3. Testing data provider capabilities...")
|
||||
|
||||
# Test historical data fetching
|
||||
try:
|
||||
eth_data = provider.get_historical_data('ETH/USDT', '1h', 10)
|
||||
if eth_data is not None and not eth_data.empty:
|
||||
print(f"✅ Historical data available: {len(eth_data)} bars for ETH/USDT 1h")
|
||||
else:
|
||||
print("⚠️ No historical data available - this is normal if APIs are not accessible")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Historical data fetch error: {e}")
|
||||
|
||||
print("\n4. Testing COB data functionality...")
|
||||
|
||||
# Test COB data creation
|
||||
try:
|
||||
# Set a mock current price for testing
|
||||
provider.current_prices['ETHUSDT'] = 3000.0
|
||||
cob_data = provider._get_cob_data('ETH/USDT', datetime.now())
|
||||
|
||||
if cob_data:
|
||||
print(f"✅ COB data created successfully")
|
||||
print(f" Current price: ${cob_data.current_price}")
|
||||
print(f" Bucket size: ${cob_data.bucket_size}")
|
||||
print(f" Price buckets: {len(cob_data.price_buckets)}")
|
||||
print(f" MA 1s imbalance: {len(cob_data.ma_1s_imbalance)} buckets")
|
||||
print(f" MA 5s imbalance: {len(cob_data.ma_5s_imbalance)} buckets")
|
||||
else:
|
||||
print("⚠️ COB data creation returned None")
|
||||
except Exception as e:
|
||||
print(f"❌ COB data creation error: {e}")
|
||||
|
||||
print("\n✅ StandardizedDataProvider test completed!")
|
||||
print("\nNext steps:")
|
||||
print("1. Integrate with real market data APIs")
|
||||
print("2. Connect to actual COB provider")
|
||||
print("3. Test with live data streams")
|
||||
print("4. Integrate with model training pipelines")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_standardized_data_provider()
|
Reference in New Issue
Block a user