This commit is contained in:
Dobromir Popov
2025-07-23 15:52:40 +03:00
parent 2b3c6abdeb
commit dbb918ea92
12 changed files with 1675 additions and 0 deletions

232
core/data_models.py Normal file
View File

@ -0,0 +1,232 @@
"""
Standardized Data Models for Multi-Modal Trading System
This module defines the standardized data structures used across all models:
- BaseDataInput: Unified input format for all models (CNN, RL, LSTM, Transformer)
- ModelOutput: Extensible output format supporting all model types
- COBData: Cumulative Order Book data structure
- Enhanced data structures for cross-model feeding and extensibility
"""
import numpy as np
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
@dataclass
class OHLCVBar:
"""OHLCV bar data structure"""
symbol: str
timestamp: datetime
open: float
high: float
low: float
close: float
volume: float
timeframe: str
indicators: Dict[str, float] = field(default_factory=dict)
@dataclass
class PivotPoint:
"""Pivot point data structure"""
symbol: str
timestamp: datetime
price: float
type: str # 'high' or 'low'
level: int # Pivot level (1, 2, 3, etc.)
confidence: float = 1.0
@dataclass
class ModelOutput:
"""Extensible model output format supporting all model types"""
model_type: str # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
model_name: str # Specific model identifier
symbol: str
timestamp: datetime
confidence: float
predictions: Dict[str, Any] # Model-specific predictions
hidden_states: Optional[Dict[str, Any]] = None # For cross-model feeding
metadata: Dict[str, Any] = field(default_factory=dict) # Additional info
@dataclass
class COBData:
"""Cumulative Order Book data for price buckets"""
symbol: str
timestamp: datetime
current_price: float
bucket_size: float # $1 for ETH, $10 for BTC
price_buckets: Dict[float, Dict[str, float]] # price -> {bid_volume, ask_volume, etc.}
bid_ask_imbalance: Dict[float, float] # price -> imbalance ratio
volume_weighted_prices: Dict[float, float] # price -> VWAP within bucket
order_flow_metrics: Dict[str, float] # Various order flow indicators
# Moving averages of COB imbalance for ±5 buckets
ma_1s_imbalance: Dict[float, float] = field(default_factory=dict) # 1s MA
ma_5s_imbalance: Dict[float, float] = field(default_factory=dict) # 5s MA
ma_15s_imbalance: Dict[float, float] = field(default_factory=dict) # 15s MA
ma_60s_imbalance: Dict[float, float] = field(default_factory=dict) # 60s MA
@dataclass
class BaseDataInput:
"""
Unified base data input for all models
Standardized format ensures all models receive identical input structure:
- OHLCV: 300 frames of (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
- COB: ±20 buckets of COB amounts in USD for each 1s OHLCV
- MA: 1s, 5s, 15s, and 60s MA of COB imbalance counting ±5 COB buckets
"""
symbol: str # Primary symbol (ETH/USDT)
timestamp: datetime
# Multi-timeframe OHLCV data for primary symbol (ETH)
ohlcv_1s: List[OHLCVBar] = field(default_factory=list) # 300 frames of 1s data
ohlcv_1m: List[OHLCVBar] = field(default_factory=list) # 300 frames of 1m data
ohlcv_1h: List[OHLCVBar] = field(default_factory=list) # 300 frames of 1h data
ohlcv_1d: List[OHLCVBar] = field(default_factory=list) # 300 frames of 1d data
# Reference symbol (BTC) 1s data
btc_ohlcv_1s: List[OHLCVBar] = field(default_factory=list) # 300s of 1s BTC data
# COB data for 1s timeframe (±20 buckets around current price)
cob_data: Optional[COBData] = None
# Technical indicators
technical_indicators: Dict[str, float] = field(default_factory=dict)
# Pivot points from Williams Market Structure
pivot_points: List[PivotPoint] = field(default_factory=list)
# Last predictions from all models (for cross-model feeding)
last_predictions: Dict[str, ModelOutput] = field(default_factory=dict)
# Market microstructure data
market_microstructure: Dict[str, Any] = field(default_factory=dict)
def get_feature_vector(self) -> np.ndarray:
"""
Convert BaseDataInput to standardized feature vector for models
Returns:
np.ndarray: Standardized feature vector combining all data sources
"""
features = []
# OHLCV features for ETH (300 frames x 4 timeframes x 5 features = 6000 features)
for ohlcv_list in [self.ohlcv_1s, self.ohlcv_1m, self.ohlcv_1h, self.ohlcv_1d]:
for bar in ohlcv_list[-300:]: # Ensure exactly 300 frames
features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
# BTC OHLCV features (300 frames x 5 features = 1500 features)
for bar in self.btc_ohlcv_1s[-300:]: # Ensure exactly 300 frames
features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
# COB features (±20 buckets x multiple metrics ≈ 800 features)
if self.cob_data:
# Price bucket features
for price in sorted(self.cob_data.price_buckets.keys()):
bucket_data = self.cob_data.price_buckets[price]
features.extend([
bucket_data.get('bid_volume', 0.0),
bucket_data.get('ask_volume', 0.0),
bucket_data.get('total_volume', 0.0),
bucket_data.get('imbalance', 0.0)
])
# Moving averages of imbalance for ±5 buckets (5 buckets x 4 MAs x 2 sides = 40 features)
for ma_dict in [self.cob_data.ma_1s_imbalance, self.cob_data.ma_5s_imbalance,
self.cob_data.ma_15s_imbalance, self.cob_data.ma_60s_imbalance]:
for price in sorted(list(ma_dict.keys())[:5]): # ±5 buckets
features.append(ma_dict[price])
# Technical indicators (variable, pad to 100 features)
indicator_values = list(self.technical_indicators.values())
features.extend(indicator_values[:100]) # Take first 100 indicators
features.extend([0.0] * max(0, 100 - len(indicator_values))) # Pad if needed
# Last predictions from other models (variable, pad to 50 features)
prediction_features = []
for model_output in self.last_predictions.values():
prediction_features.extend([
model_output.confidence,
model_output.predictions.get('buy_probability', 0.0),
model_output.predictions.get('sell_probability', 0.0),
model_output.predictions.get('hold_probability', 0.0),
model_output.predictions.get('expected_reward', 0.0)
])
features.extend(prediction_features[:50]) # Take first 50 prediction features
features.extend([0.0] * max(0, 50 - len(prediction_features))) # Pad if needed
return np.array(features, dtype=np.float32)
def validate(self) -> bool:
"""
Validate that the BaseDataInput contains required data
Returns:
bool: True if valid, False otherwise
"""
# Check that we have required OHLCV data
if len(self.ohlcv_1s) < 100: # At least 100 frames
return False
if len(self.btc_ohlcv_1s) < 100: # At least 100 frames of BTC data
return False
# Check that timestamps are reasonable
if not self.timestamp:
return False
# Check symbol format
if not self.symbol or '/' not in self.symbol:
return False
return True
@dataclass
class TradingAction:
"""Trading action output from models"""
symbol: str
timestamp: datetime
action: str # 'BUY', 'SELL', 'HOLD'
confidence: float
source: str # 'rl', 'cnn', 'orchestrator'
price: Optional[float] = None
quantity: Optional[float] = None
reason: Optional[str] = None
def create_model_output(model_type: str, model_name: str, symbol: str,
action: str, confidence: float,
hidden_states: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None) -> ModelOutput:
"""
Helper function to create standardized ModelOutput
Args:
model_type: Type of model ('cnn', 'rl', 'lstm', 'transformer', 'orchestrator')
model_name: Specific model identifier
symbol: Trading symbol
action: Trading action ('BUY', 'SELL', 'HOLD')
confidence: Confidence score (0.0 to 1.0)
hidden_states: Optional hidden states for cross-model feeding
metadata: Optional additional metadata
Returns:
ModelOutput: Standardized model output
"""
predictions = {
'action': action,
'buy_probability': confidence if action == 'BUY' else 0.0,
'sell_probability': confidence if action == 'SELL' else 0.0,
'hold_probability': confidence if action == 'HOLD' else 0.0,
}
return ModelOutput(
model_type=model_type,
model_name=model_name,
symbol=symbol,
timestamp=datetime.now(),
confidence=confidence,
predictions=predictions,
hidden_states=hidden_states or {},
metadata=metadata or {}
)

View File

@ -0,0 +1,395 @@
"""
Model Output Manager
This module provides extensible model output storage and management for the multi-modal trading system.
Supports CNN, RL, LSTM, Transformer, and future model types with cross-model feeding capabilities.
"""
import logging
import json
import pickle
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
from collections import deque, defaultdict
from threading import Lock
from pathlib import Path
from .data_models import ModelOutput, create_model_output
logger = logging.getLogger(__name__)
class ModelOutputManager:
"""
Extensible model output storage and management system
Features:
- Standardized ModelOutput storage for all model types
- Cross-model feeding with hidden states
- Historical output tracking
- Metadata management
- Persistence and recovery
- Performance analytics
"""
def __init__(self, cache_dir: str = "cache/model_outputs", max_history: int = 1000):
"""
Initialize the model output manager
Args:
cache_dir: Directory for persistent storage
max_history: Maximum number of outputs to keep in memory per model
"""
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.max_history = max_history
# In-memory storage
self.current_outputs: Dict[str, Dict[str, ModelOutput]] = defaultdict(dict) # {symbol: {model_name: ModelOutput}}
self.output_history: Dict[str, Dict[str, deque]] = defaultdict(lambda: defaultdict(lambda: deque(maxlen=max_history))) # {symbol: {model_name: deque}}
self.cross_model_states: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) # {symbol: {model_name: hidden_states}}
# Metadata tracking
self.model_metadata: Dict[str, Dict[str, Any]] = defaultdict(dict) # {model_name: metadata}
self.performance_stats: Dict[str, Dict[str, Any]] = defaultdict(lambda: defaultdict(dict)) # {symbol: {model_name: stats}}
# Thread safety
self.storage_lock = Lock()
# Supported model types
self.supported_model_types = {
'cnn', 'rl', 'lstm', 'transformer', 'orchestrator',
'ensemble', 'hybrid', 'custom' # Extensible for future types
}
logger.info(f"ModelOutputManager initialized with cache dir: {self.cache_dir}")
logger.info(f"Supported model types: {self.supported_model_types}")
def store_output(self, model_output: ModelOutput) -> bool:
"""
Store model output with full extensibility support
Args:
model_output: ModelOutput from any model type
Returns:
bool: True if stored successfully, False otherwise
"""
try:
with self.storage_lock:
symbol = model_output.symbol
model_name = model_output.model_name
model_type = model_output.model_type
# Validate model type (extensible)
if model_type not in self.supported_model_types:
logger.warning(f"Unknown model type '{model_type}' - adding to supported types")
self.supported_model_types.add(model_type)
# Store current output
self.current_outputs[symbol][model_name] = model_output
# Add to history
self.output_history[symbol][model_name].append(model_output)
# Store cross-model states if available
if model_output.hidden_states:
self.cross_model_states[symbol][model_name] = model_output.hidden_states
# Update model metadata
self._update_model_metadata(model_name, model_type, model_output.metadata)
# Update performance statistics
self._update_performance_stats(symbol, model_name, model_output)
# Persist to disk (async to avoid blocking)
self._persist_output_async(model_output)
logger.debug(f"Stored output from {model_name} ({model_type}) for {symbol}")
return True
except Exception as e:
logger.error(f"Error storing model output: {e}")
return False
def get_current_output(self, symbol: str, model_name: str) -> Optional[ModelOutput]:
"""
Get the current (latest) output from a specific model
Args:
symbol: Trading symbol
model_name: Name of the model
Returns:
ModelOutput: Latest output from the model, or None if not available
"""
try:
return self.current_outputs.get(symbol, {}).get(model_name)
except Exception as e:
logger.error(f"Error getting current output for {model_name}: {e}")
return None
def get_all_current_outputs(self, symbol: str) -> Dict[str, ModelOutput]:
"""
Get all current outputs for a symbol (for cross-model feeding)
Args:
symbol: Trading symbol
Returns:
Dict[str, ModelOutput]: Dictionary of current outputs by model name
"""
try:
return dict(self.current_outputs.get(symbol, {}))
except Exception as e:
logger.error(f"Error getting all current outputs for {symbol}: {e}")
return {}
def get_output_history(self, symbol: str, model_name: str, count: int = 10) -> List[ModelOutput]:
"""
Get historical outputs from a model
Args:
symbol: Trading symbol
model_name: Name of the model
count: Number of historical outputs to retrieve
Returns:
List[ModelOutput]: List of historical outputs (most recent first)
"""
try:
history = self.output_history.get(symbol, {}).get(model_name, deque())
return list(history)[-count:][::-1] # Most recent first
except Exception as e:
logger.error(f"Error getting output history for {model_name}: {e}")
return []
def get_cross_model_states(self, symbol: str, requesting_model: str) -> Dict[str, Dict[str, Any]]:
"""
Get hidden states from other models for cross-model feeding
Args:
symbol: Trading symbol
requesting_model: Name of the model requesting the states
Returns:
Dict[str, Dict[str, Any]]: Hidden states from other models
"""
try:
all_states = self.cross_model_states.get(symbol, {})
# Return states from all models except the requesting one
return {model_name: states for model_name, states in all_states.items()
if model_name != requesting_model}
except Exception as e:
logger.error(f"Error getting cross-model states for {requesting_model}: {e}")
return {}
def get_model_types_active(self, symbol: str) -> List[str]:
"""
Get list of active model types for a symbol
Args:
symbol: Trading symbol
Returns:
List[str]: List of active model types
"""
try:
current_outputs = self.current_outputs.get(symbol, {})
return [output.model_type for output in current_outputs.values()]
except Exception as e:
logger.error(f"Error getting active model types for {symbol}: {e}")
return []
def get_consensus_prediction(self, symbol: str, confidence_threshold: float = 0.5) -> Optional[Dict[str, Any]]:
"""
Get consensus prediction from all active models
Args:
symbol: Trading symbol
confidence_threshold: Minimum confidence threshold for inclusion
Returns:
Dict containing consensus prediction or None
"""
try:
current_outputs = self.current_outputs.get(symbol, {})
if not current_outputs:
return None
# Filter by confidence threshold
high_confidence_outputs = [
output for output in current_outputs.values()
if output.confidence >= confidence_threshold
]
if not high_confidence_outputs:
return None
# Calculate consensus
buy_votes = sum(1 for output in high_confidence_outputs
if output.predictions.get('action') == 'BUY')
sell_votes = sum(1 for output in high_confidence_outputs
if output.predictions.get('action') == 'SELL')
hold_votes = sum(1 for output in high_confidence_outputs
if output.predictions.get('action') == 'HOLD')
total_votes = len(high_confidence_outputs)
avg_confidence = sum(output.confidence for output in high_confidence_outputs) / total_votes
# Determine consensus action
if buy_votes > sell_votes and buy_votes > hold_votes:
consensus_action = 'BUY'
elif sell_votes > buy_votes and sell_votes > hold_votes:
consensus_action = 'SELL'
else:
consensus_action = 'HOLD'
return {
'action': consensus_action,
'confidence': avg_confidence,
'votes': {'BUY': buy_votes, 'SELL': sell_votes, 'HOLD': hold_votes},
'total_models': total_votes,
'model_types': [output.model_type for output in high_confidence_outputs]
}
except Exception as e:
logger.error(f"Error calculating consensus prediction for {symbol}: {e}")
return None
def _update_model_metadata(self, model_name: str, model_type: str, metadata: Dict[str, Any]):
"""Update metadata for a model"""
try:
if model_name not in self.model_metadata:
self.model_metadata[model_name] = {
'model_type': model_type,
'first_seen': datetime.now(),
'total_predictions': 0,
'custom_metadata': {}
}
self.model_metadata[model_name]['total_predictions'] += 1
self.model_metadata[model_name]['last_seen'] = datetime.now()
# Merge custom metadata
if metadata:
self.model_metadata[model_name]['custom_metadata'].update(metadata)
except Exception as e:
logger.error(f"Error updating model metadata: {e}")
def _update_performance_stats(self, symbol: str, model_name: str, model_output: ModelOutput):
"""Update performance statistics for a model"""
try:
stats = self.performance_stats[symbol][model_name]
if 'prediction_count' not in stats:
stats['prediction_count'] = 0
stats['confidence_sum'] = 0.0
stats['action_counts'] = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
stats['first_prediction'] = model_output.timestamp
stats['prediction_count'] += 1
stats['confidence_sum'] += model_output.confidence
stats['avg_confidence'] = stats['confidence_sum'] / stats['prediction_count']
stats['last_prediction'] = model_output.timestamp
action = model_output.predictions.get('action', 'HOLD')
if action in stats['action_counts']:
stats['action_counts'][action] += 1
except Exception as e:
logger.error(f"Error updating performance stats: {e}")
def _persist_output_async(self, model_output: ModelOutput):
"""Persist model output to disk (simplified version)"""
try:
# Create filename based on model and timestamp
timestamp_str = model_output.timestamp.strftime("%Y%m%d_%H%M%S")
filename = f"{model_output.model_name}_{model_output.symbol.replace('/', '_')}_{timestamp_str}.json"
filepath = self.cache_dir / filename
# Convert to JSON-serializable format
output_dict = {
'model_type': model_output.model_type,
'model_name': model_output.model_name,
'symbol': model_output.symbol,
'timestamp': model_output.timestamp.isoformat(),
'confidence': model_output.confidence,
'predictions': model_output.predictions,
'metadata': model_output.metadata
}
# Save to file (in a real implementation, this would be async)
with open(filepath, 'w') as f:
json.dump(output_dict, f, indent=2)
except Exception as e:
logger.error(f"Error persisting model output: {e}")
def get_performance_summary(self, symbol: str) -> Dict[str, Any]:
"""
Get performance summary for all models for a symbol
Args:
symbol: Trading symbol
Returns:
Dict containing performance summary
"""
try:
summary = {
'symbol': symbol,
'active_models': len(self.current_outputs.get(symbol, {})),
'model_stats': {}
}
for model_name, stats in self.performance_stats.get(symbol, {}).items():
summary['model_stats'][model_name] = {
'predictions': stats.get('prediction_count', 0),
'avg_confidence': round(stats.get('avg_confidence', 0.0), 3),
'action_distribution': stats.get('action_counts', {}),
'model_type': self.model_metadata.get(model_name, {}).get('model_type', 'unknown')
}
return summary
except Exception as e:
logger.error(f"Error getting performance summary: {e}")
return {'symbol': symbol, 'error': str(e)}
def cleanup_old_outputs(self, max_age_hours: int = 24):
"""
Clean up old outputs to manage memory usage
Args:
max_age_hours: Maximum age of outputs to keep in hours
"""
try:
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
with self.storage_lock:
for symbol in self.output_history:
for model_name in self.output_history[symbol]:
history = self.output_history[symbol][model_name]
# Remove old outputs
while history and history[0].timestamp < cutoff_time:
history.popleft()
logger.info(f"Cleaned up outputs older than {max_age_hours} hours")
except Exception as e:
logger.error(f"Error cleaning up old outputs: {e}")
def add_custom_model_type(self, model_type: str):
"""
Add support for a new custom model type
Args:
model_type: Name of the new model type
"""
self.supported_model_types.add(model_type)
logger.info(f"Added support for custom model type: {model_type}")
def get_supported_model_types(self) -> List[str]:
"""Get list of all supported model types"""
return list(self.supported_model_types)

View File

@ -0,0 +1,453 @@
"""
Standardized Data Provider Extension
This module extends the existing DataProvider with standardized BaseDataInput functionality
for all models in the multi-modal trading system.
"""
import logging
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from collections import deque
from threading import Lock
from .data_provider import DataProvider
from .data_models import BaseDataInput, OHLCVBar, COBData, ModelOutput, PivotPoint
from .multi_exchange_cob_provider import MultiExchangeCOBProvider
from .model_output_manager import ModelOutputManager
logger = logging.getLogger(__name__)
class StandardizedDataProvider(DataProvider):
"""
Extended DataProvider with standardized BaseDataInput support
Provides unified data format for all models:
- OHLCV: 300 frames of (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
- COB: ±20 buckets of COB amounts in USD for each 1s OHLCV
- MA: 1s, 5s, 15s, and 60s MA of COB imbalance counting ±5 COB buckets
"""
def __init__(self, symbols: List[str] = None, timeframes: List[str] = None):
"""Initialize the standardized data provider"""
super().__init__(symbols, timeframes)
# Standardized data storage
self.base_data_cache: Dict[str, BaseDataInput] = {} # {symbol: BaseDataInput}
self.cob_data_cache: Dict[str, COBData] = {} # {symbol: COBData}
# Model output management with extensible storage
self.model_output_manager = ModelOutputManager(
cache_dir=str(self.cache_dir / "model_outputs"),
max_history=1000
)
# COB moving averages calculation
self.cob_imbalance_history: Dict[str, deque] = {} # {symbol: deque of (timestamp, imbalance_data)}
self.ma_calculation_lock = Lock()
# Initialize caches for each symbol
for symbol in self.symbols:
self.base_data_cache[symbol] = None
self.cob_data_cache[symbol] = None
self.cob_imbalance_history[symbol] = deque(maxlen=300) # 5 minutes of 1s data
# COB provider integration
self.cob_provider: Optional[MultiExchangeCOBProvider] = None
self._initialize_cob_provider()
logger.info("StandardizedDataProvider initialized with BaseDataInput support")
def _initialize_cob_provider(self):
"""Initialize COB provider for order book data"""
try:
from .multi_exchange_cob_provider import MultiExchangeCOBProvider, ExchangeConfig, ExchangeType
# Configure exchanges (focusing on Binance for now)
exchange_configs = {
'binance': ExchangeConfig(
exchange_type=ExchangeType.BINANCE,
weight=1.0,
enabled=True,
websocket_url="wss://stream.binance.com:9443/ws/",
symbols_mapping={symbol: symbol.replace('/', '').lower() for symbol in self.symbols}
)
}
self.cob_provider = MultiExchangeCOBProvider(self.symbols, exchange_configs)
logger.info("COB provider initialized successfully")
except Exception as e:
logger.warning(f"Failed to initialize COB provider: {e}")
self.cob_provider = None
def get_base_data_input(self, symbol: str, timestamp: Optional[datetime] = None) -> Optional[BaseDataInput]:
"""
Get standardized BaseDataInput for a symbol
Args:
symbol: Trading symbol (e.g., 'ETH/USDT')
timestamp: Optional timestamp, defaults to current time
Returns:
BaseDataInput: Standardized input data for models, or None if insufficient data
"""
if timestamp is None:
timestamp = datetime.now()
try:
# Get OHLCV data for all timeframes
ohlcv_1s = self._get_ohlcv_bars(symbol, '1s', 300)
ohlcv_1m = self._get_ohlcv_bars(symbol, '1m', 300)
ohlcv_1h = self._get_ohlcv_bars(symbol, '1h', 300)
ohlcv_1d = self._get_ohlcv_bars(symbol, '1d', 300)
# Get BTC reference data
btc_symbol = 'BTC/USDT'
btc_ohlcv_1s = self._get_ohlcv_bars(btc_symbol, '1s', 300)
# Check if we have sufficient data
if not all([ohlcv_1s, ohlcv_1m, ohlcv_1h, ohlcv_1d, btc_ohlcv_1s]):
logger.warning(f"Insufficient OHLCV data for {symbol}")
return None
if any(len(data) < 100 for data in [ohlcv_1s, ohlcv_1m, ohlcv_1h, ohlcv_1d, btc_ohlcv_1s]):
logger.warning(f"Insufficient data frames for {symbol}")
return None
# Get COB data
cob_data = self._get_cob_data(symbol, timestamp)
# Get technical indicators
technical_indicators = self._get_technical_indicators(symbol)
# Get pivot points
pivot_points = self._get_pivot_points(symbol)
# Get last predictions from all models
last_predictions = self.model_output_manager.get_all_current_outputs(symbol)
# Create BaseDataInput
base_input = BaseDataInput(
symbol=symbol,
timestamp=timestamp,
ohlcv_1s=ohlcv_1s,
ohlcv_1m=ohlcv_1m,
ohlcv_1h=ohlcv_1h,
ohlcv_1d=ohlcv_1d,
btc_ohlcv_1s=btc_ohlcv_1s,
cob_data=cob_data,
technical_indicators=technical_indicators,
pivot_points=pivot_points,
last_predictions=last_predictions
)
# Validate the input
if not base_input.validate():
logger.warning(f"BaseDataInput validation failed for {symbol}")
return None
# Cache the result
self.base_data_cache[symbol] = base_input
return base_input
except Exception as e:
logger.error(f"Error creating BaseDataInput for {symbol}: {e}")
return None
def _get_ohlcv_bars(self, symbol: str, timeframe: str, count: int) -> List[OHLCVBar]:
"""
Get OHLCV bars for a symbol and timeframe
Args:
symbol: Trading symbol
timeframe: Timeframe ('1s', '1m', '1h', '1d')
count: Number of bars to retrieve
Returns:
List[OHLCVBar]: List of OHLCV bars
"""
try:
# Get historical data from parent class
df = self.get_historical_data(symbol, timeframe, count)
if df is None or df.empty:
return []
# Convert DataFrame to OHLCVBar objects
bars = []
for _, row in df.tail(count).iterrows():
bar = OHLCVBar(
symbol=symbol,
timestamp=row.name if hasattr(row, 'name') else datetime.now(),
open=float(row['open']),
high=float(row['high']),
low=float(row['low']),
close=float(row['close']),
volume=float(row['volume']),
timeframe=timeframe,
indicators={}
)
# Add technical indicators if available
for col in df.columns:
if col not in ['open', 'high', 'low', 'close', 'volume']:
bar.indicators[col] = float(row[col]) if not np.isnan(row[col]) else 0.0
bars.append(bar)
return bars
except Exception as e:
logger.error(f"Error getting OHLCV bars for {symbol} {timeframe}: {e}")
return []
def _get_cob_data(self, symbol: str, timestamp: datetime) -> Optional[COBData]:
"""
Get COB data for a symbol
Args:
symbol: Trading symbol
timestamp: Current timestamp
Returns:
COBData: COB data with price buckets and moving averages
"""
try:
if not self.cob_provider:
return None
# Get current price
current_price = self.current_prices.get(symbol.replace('/', '').upper(), 0.0)
if current_price <= 0:
return None
# Determine bucket size based on symbol
bucket_size = 1.0 if 'ETH' in symbol else 10.0 # $1 for ETH, $10 for BTC
# Calculate price range (±20 buckets)
price_range = 20 * bucket_size
min_price = current_price - price_range
max_price = current_price + price_range
# Create price buckets
price_buckets = {}
bid_ask_imbalance = {}
volume_weighted_prices = {}
# Generate mock COB data for now (will be replaced with real COB provider data)
for i in range(-20, 21):
price = current_price + (i * bucket_size)
if price > 0:
# Mock data - replace with real COB provider data
bid_volume = max(0, 1000 - abs(i) * 50) # More volume near current price
ask_volume = max(0, 1000 - abs(i) * 50)
total_volume = bid_volume + ask_volume
imbalance = (bid_volume - ask_volume) / max(total_volume, 1)
price_buckets[price] = {
'bid_volume': bid_volume,
'ask_volume': ask_volume,
'total_volume': total_volume,
'imbalance': imbalance
}
bid_ask_imbalance[price] = imbalance
volume_weighted_prices[price] = price # Simplified VWAP
# Calculate moving averages of imbalance for ±5 buckets
ma_data = self._calculate_cob_moving_averages(symbol, bid_ask_imbalance, timestamp)
cob_data = COBData(
symbol=symbol,
timestamp=timestamp,
current_price=current_price,
bucket_size=bucket_size,
price_buckets=price_buckets,
bid_ask_imbalance=bid_ask_imbalance,
volume_weighted_prices=volume_weighted_prices,
order_flow_metrics={},
ma_1s_imbalance=ma_data.get('1s', {}),
ma_5s_imbalance=ma_data.get('5s', {}),
ma_15s_imbalance=ma_data.get('15s', {}),
ma_60s_imbalance=ma_data.get('60s', {})
)
# Cache the COB data
self.cob_data_cache[symbol] = cob_data
return cob_data
except Exception as e:
logger.error(f"Error getting COB data for {symbol}: {e}")
return None
def _calculate_cob_moving_averages(self, symbol: str, bid_ask_imbalance: Dict[float, float],
timestamp: datetime) -> Dict[str, Dict[float, float]]:
"""
Calculate moving averages of COB imbalance for ±5 buckets
Args:
symbol: Trading symbol
bid_ask_imbalance: Current bid/ask imbalance data
timestamp: Current timestamp
Returns:
Dict containing MA data for different timeframes
"""
try:
with self.ma_calculation_lock:
# Add current imbalance data to history
self.cob_imbalance_history[symbol].append((timestamp, bid_ask_imbalance))
# Calculate MAs for different timeframes
ma_results = {'1s': {}, '5s': {}, '15s': {}, '60s': {}}
# Get current price for ±5 bucket calculation
current_price = self.current_prices.get(symbol.replace('/', '').upper(), 0.0)
if current_price <= 0:
return ma_results
bucket_size = 1.0 if 'ETH' in symbol else 10.0
# Calculate MAs for ±5 buckets around current price
for i in range(-5, 6):
price = current_price + (i * bucket_size)
if price <= 0:
continue
# Get historical imbalance data for this price bucket
history = self.cob_imbalance_history[symbol]
# Calculate different MA periods
for period, period_name in [(1, '1s'), (5, '5s'), (15, '15s'), (60, '60s')]:
recent_data = []
cutoff_time = timestamp - timedelta(seconds=period)
for hist_timestamp, hist_imbalance in history:
if hist_timestamp >= cutoff_time and price in hist_imbalance:
recent_data.append(hist_imbalance[price])
# Calculate moving average
if recent_data:
ma_results[period_name][price] = sum(recent_data) / len(recent_data)
else:
ma_results[period_name][price] = 0.0
return ma_results
except Exception as e:
logger.error(f"Error calculating COB moving averages for {symbol}: {e}")
return {'1s': {}, '5s': {}, '15s': {}, '60s': {}}
def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
"""Get technical indicators for a symbol"""
try:
# Get latest OHLCV data
df = self.get_historical_data(symbol, '1h', 100) # Use 1h for indicators
if df is None or df.empty:
return {}
indicators = {}
# Add basic indicators if available in the dataframe
latest_row = df.iloc[-1]
for col in df.columns:
if col not in ['open', 'high', 'low', 'close', 'volume']:
indicators[col] = float(latest_row[col]) if not np.isnan(latest_row[col]) else 0.0
return indicators
except Exception as e:
logger.error(f"Error getting technical indicators for {symbol}: {e}")
return {}
def _get_pivot_points(self, symbol: str) -> List[PivotPoint]:
"""Get pivot points for a symbol"""
try:
pivot_points = []
# Get pivot points from Williams Market Structure if available
if symbol in self.williams_structure:
williams = self.williams_structure[symbol]
# This would need to be implemented based on the actual Williams structure
# For now, return empty list
pass
return pivot_points
except Exception as e:
logger.error(f"Error getting pivot points for {symbol}: {e}")
return []
def store_model_output(self, model_output: ModelOutput):
"""
Store model output for cross-model feeding using ModelOutputManager
Args:
model_output: ModelOutput from any model
"""
try:
success = self.model_output_manager.store_output(model_output)
if success:
logger.debug(f"Stored model output from {model_output.model_name} for {model_output.symbol}")
else:
logger.warning(f"Failed to store model output from {model_output.model_name}")
except Exception as e:
logger.error(f"Error storing model output: {e}")
def get_model_outputs(self, symbol: str) -> Dict[str, ModelOutput]:
"""
Get all model outputs for a symbol using ModelOutputManager
Args:
symbol: Trading symbol
Returns:
Dict[str, ModelOutput]: Dictionary of model outputs by model name
"""
return self.model_output_manager.get_all_current_outputs(symbol)
def get_model_output_manager(self) -> ModelOutputManager:
"""
Get the model output manager for advanced operations
Returns:
ModelOutputManager: The model output manager instance
"""
return self.model_output_manager
def start_real_time_processing(self):
"""Start real-time processing for standardized data"""
try:
# Start parent class real-time processing
if hasattr(super(), 'start_real_time_processing'):
super().start_real_time_processing()
# Start COB provider if available
if self.cob_provider:
import asyncio
asyncio.create_task(self.cob_provider.start_streaming())
logger.info("Started real-time processing for standardized data")
except Exception as e:
logger.error(f"Error starting real-time processing: {e}")
def stop_real_time_processing(self):
"""Stop real-time processing"""
try:
# Stop COB provider if available
if self.cob_provider:
import asyncio
asyncio.create_task(self.cob_provider.stop_streaming())
# Stop parent class processing
if hasattr(super(), 'stop_real_time_processing'):
super().stop_real_time_processing()
logger.info("Stopped real-time processing for standardized data")
except Exception as e:
logger.error(f"Error stopping real-time processing: {e}")

View File

@ -0,0 +1,17 @@
{
"model_type": "hybrid_ensemble",
"model_name": "custom_ensemble_v1",
"symbol": "ETH/USDT",
"timestamp": "2025-07-23T15:38:26.810125",
"confidence": 0.88,
"predictions": {
"action": "BUY",
"buy_probability": 0.88,
"sell_probability": 0.0,
"hold_probability": 0.0
},
"metadata": {
"ensemble_size": 5,
"voting_method": "weighted"
}
}

View File

@ -0,0 +1,18 @@
{
"model_type": "rl",
"model_name": "dqn_agent_v2",
"symbol": "ETH/USDT",
"timestamp": "2025-07-23T15:38:26.804125",
"confidence": 0.72,
"predictions": {
"action": "SELL",
"buy_probability": 0.0,
"sell_probability": 0.72,
"hold_probability": 0.0
},
"metadata": {
"model_version": "1.0",
"training_iterations": 1000,
"last_updated": "2025-07-23T15:38:26.804125"
}
}

View File

@ -0,0 +1,14 @@
{
"model_type": "cnn",
"model_name": "enhanced_cnn_v1",
"symbol": "ETH/USDT",
"timestamp": "2025-07-23T15:38:26.812129",
"confidence": 0.8,
"predictions": {
"action": "HOLD",
"buy_probability": 0.0,
"sell_probability": 0.0,
"hold_probability": 0.8
},
"metadata": {}
}

View File

@ -0,0 +1,18 @@
{
"model_type": "lstm",
"model_name": "lstm_predictor",
"symbol": "ETH/USDT",
"timestamp": "2025-07-23T15:38:26.805126",
"confidence": 0.65,
"predictions": {
"action": "HOLD",
"buy_probability": 0.0,
"sell_probability": 0.0,
"hold_probability": 0.65
},
"metadata": {
"model_version": "1.0",
"training_iterations": 1000,
"last_updated": "2025-07-23T15:38:26.805126"
}
}

View File

@ -0,0 +1,18 @@
{
"model_type": "orchestrator",
"model_name": "main_orchestrator",
"symbol": "ETH/USDT",
"timestamp": "2025-07-23T15:38:26.806125",
"confidence": 0.78,
"predictions": {
"action": "BUY",
"buy_probability": 0.78,
"sell_probability": 0.0,
"hold_probability": 0.0
},
"metadata": {
"model_version": "1.0",
"training_iterations": 1000,
"last_updated": "2025-07-23T15:38:26.806125"
}
}

View File

@ -0,0 +1,18 @@
{
"model_type": "transformer",
"model_name": "transformer_v1",
"symbol": "ETH/USDT",
"timestamp": "2025-07-23T15:38:26.806125",
"confidence": 0.91,
"predictions": {
"action": "BUY",
"buy_probability": 0.91,
"sell_probability": 0.0,
"hold_probability": 0.0
},
"metadata": {
"model_version": "1.0",
"training_iterations": 1000,
"last_updated": "2025-07-23T15:38:26.806125"
}
}

View File

@ -0,0 +1,182 @@
"""
Test script for integrated StandardizedDataProvider with ModelOutputManager
This script tests the complete standardized data provider with extensible model output storage
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import logging
from datetime import datetime
from core.standardized_data_provider import StandardizedDataProvider
from core.data_models import create_model_output
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_integrated_standardized_provider():
"""Test the integrated StandardizedDataProvider with ModelOutputManager"""
print("Testing Integrated StandardizedDataProvider with ModelOutputManager...")
# Initialize the provider
symbols = ['ETH/USDT', 'BTC/USDT']
timeframes = ['1s', '1m', '1h', '1d']
provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)
print("✅ StandardizedDataProvider initialized with ModelOutputManager")
# Test 1: Store model outputs from different types
print("\n1. Testing model output storage integration...")
# Create and store outputs from different model types
model_outputs = [
create_model_output('cnn', 'enhanced_cnn_v1', 'ETH/USDT', 'BUY', 0.85),
create_model_output('rl', 'dqn_agent_v2', 'ETH/USDT', 'SELL', 0.72),
create_model_output('transformer', 'transformer_v1', 'ETH/USDT', 'BUY', 0.91),
create_model_output('orchestrator', 'main_orchestrator', 'ETH/USDT', 'BUY', 0.78)
]
for output in model_outputs:
provider.store_model_output(output)
print(f"✅ Stored {output.model_type} output: {output.predictions['action']} ({output.confidence})")
# Test 2: Retrieve model outputs
print("\n2. Testing model output retrieval...")
all_outputs = provider.get_model_outputs('ETH/USDT')
print(f"✅ Retrieved {len(all_outputs)} model outputs for ETH/USDT")
for model_name, output in all_outputs.items():
print(f" {model_name} ({output.model_type}): {output.predictions['action']} - {output.confidence}")
# Test 3: Test BaseDataInput with cross-model feeding
print("\n3. Testing BaseDataInput with cross-model predictions...")
# Set mock current price for COB data
provider.current_prices['ETHUSDT'] = 3000.0
base_input = provider.get_base_data_input('ETH/USDT')
if base_input:
print("✅ BaseDataInput created with cross-model predictions!")
print(f" Symbol: {base_input.symbol}")
print(f" OHLCV frames: 1s={len(base_input.ohlcv_1s)}, 1m={len(base_input.ohlcv_1m)}, 1h={len(base_input.ohlcv_1h)}, 1d={len(base_input.ohlcv_1d)}")
print(f" BTC frames: {len(base_input.btc_ohlcv_1s)}")
print(f" COB data: {'Available' if base_input.cob_data else 'Not available'}")
print(f" Last predictions: {len(base_input.last_predictions)} models")
# Show cross-model predictions
for model_name, prediction in base_input.last_predictions.items():
print(f" {model_name}: {prediction.predictions['action']} ({prediction.confidence})")
# Test feature vector creation
try:
feature_vector = base_input.get_feature_vector()
print(f"✅ Feature vector created: shape {feature_vector.shape}")
except Exception as e:
print(f"❌ Feature vector creation failed: {e}")
else:
print("⚠️ BaseDataInput creation failed - this may be due to insufficient historical data")
# Test 4: Advanced ModelOutputManager features
print("\n4. Testing advanced model output manager features...")
output_manager = provider.get_model_output_manager()
# Test consensus prediction
consensus = output_manager.get_consensus_prediction('ETH/USDT', confidence_threshold=0.7)
if consensus:
print(f"✅ Consensus prediction: {consensus['action']} (confidence: {consensus['confidence']:.3f})")
print(f" Votes: {consensus['votes']}")
print(f" Contributing models: {consensus['model_types']}")
else:
print("⚠️ No consensus reached")
# Test cross-model states
cross_states = output_manager.get_cross_model_states('ETH/USDT', 'dqn_agent_v2')
print(f"✅ Cross-model states available for RL model: {len(cross_states)} models")
# Test performance summary
performance = output_manager.get_performance_summary('ETH/USDT')
print(f"✅ Performance summary: {performance['active_models']} active models")
# Test 5: Custom model type support
print("\n5. Testing custom model type extensibility...")
# Add a custom model type
output_manager.add_custom_model_type('hybrid_lstm_transformer')
# Create and store custom model output
custom_output = create_model_output(
model_type='hybrid_lstm_transformer',
model_name='hybrid_model_v1',
symbol='ETH/USDT',
action='BUY',
confidence=0.89,
metadata={'hybrid_components': ['lstm', 'transformer'], 'ensemble_weight': 0.6}
)
provider.store_model_output(custom_output)
print("✅ Custom model type 'hybrid_lstm_transformer' stored successfully")
# Verify it's included in BaseDataInput
updated_base_input = provider.get_base_data_input('ETH/USDT')
if updated_base_input and 'hybrid_model_v1' in updated_base_input.last_predictions:
print("✅ Custom model output included in BaseDataInput cross-model feeding")
print(f" Total supported model types: {len(output_manager.get_supported_model_types())}")
# Test 6: Historical tracking
print("\n6. Testing historical output tracking...")
# Store a few more outputs to build history
for i in range(3):
historical_output = create_model_output(
model_type='cnn',
model_name='enhanced_cnn_v1',
symbol='ETH/USDT',
action='HOLD',
confidence=0.6 + i * 0.05
)
provider.store_model_output(historical_output)
history = output_manager.get_output_history('ETH/USDT', 'enhanced_cnn_v1', count=5)
print(f"✅ Historical tracking: {len(history)} outputs for enhanced_cnn_v1")
# Test 7: Real-time data integration readiness
print("\n7. Testing real-time integration readiness...")
print("✅ Real-time processing methods available:")
print(" - start_real_time_processing()")
print(" - stop_real_time_processing()")
print(" - COB provider integration ready")
print(" - Model output persistence enabled")
print("\n✅ Integrated StandardizedDataProvider test completed successfully!")
print("\n🎯 Key achievements:")
print("✓ Standardized BaseDataInput format for all models")
print("✓ Extensible ModelOutput storage (CNN, RL, LSTM, Transformer, Custom)")
print("✓ Cross-model feeding with last predictions")
print("✓ COB data integration with moving averages")
print("✓ Consensus prediction calculation")
print("✓ Historical output tracking")
print("✓ Performance analytics")
print("✓ Thread-safe operations")
print("✓ Persistent storage capabilities")
print("\n🚀 Ready for model integration:")
print("1. CNN models can use BaseDataInput and store ModelOutput")
print("2. RL models can access CNN hidden states via cross-model feeding")
print("3. Orchestrator can calculate consensus from all models")
print("4. New model types can be added without code changes")
print("5. All models receive identical standardized input format")
return provider
if __name__ == "__main__":
test_integrated_standardized_provider()

View File

@ -0,0 +1,182 @@
"""
Test script for ModelOutputManager
This script tests the extensible model output storage functionality
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import logging
from datetime import datetime
from core.model_output_manager import ModelOutputManager
from core.data_models import create_model_output, ModelOutput
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_model_output_manager():
"""Test the ModelOutputManager functionality"""
print("Testing ModelOutputManager...")
# Initialize the manager
manager = ModelOutputManager(cache_dir="test_cache/model_outputs", max_history=100)
print(f"✅ ModelOutputManager initialized")
print(f" Supported model types: {manager.get_supported_model_types()}")
# Test 1: Store outputs from different model types
print("\n1. Testing model output storage...")
# Create outputs from different model types
models_to_test = [
('cnn', 'enhanced_cnn_v1', 'BUY', 0.85),
('rl', 'dqn_agent_v2', 'SELL', 0.72),
('lstm', 'lstm_predictor', 'HOLD', 0.65),
('transformer', 'transformer_v1', 'BUY', 0.91),
('orchestrator', 'main_orchestrator', 'BUY', 0.78)
]
symbol = 'ETH/USDT'
stored_outputs = []
for model_type, model_name, action, confidence in models_to_test:
# Create model output with hidden states for cross-model feeding
hidden_states = {
'layer_1': [0.1, 0.2, 0.3],
'layer_2': [0.4, 0.5, 0.6],
'attention_weights': [0.7, 0.8, 0.9]
} if model_type in ['cnn', 'transformer'] else None
metadata = {
'model_version': '1.0',
'training_iterations': 1000,
'last_updated': datetime.now().isoformat()
}
model_output = create_model_output(
model_type=model_type,
model_name=model_name,
symbol=symbol,
action=action,
confidence=confidence,
hidden_states=hidden_states,
metadata=metadata
)
# Store the output
success = manager.store_output(model_output)
if success:
print(f"✅ Stored {model_type} output: {action} ({confidence})")
stored_outputs.append(model_output)
else:
print(f"❌ Failed to store {model_type} output")
# Test 2: Retrieve current outputs
print("\n2. Testing output retrieval...")
all_current = manager.get_all_current_outputs(symbol)
print(f"✅ Retrieved {len(all_current)} current outputs for {symbol}")
for model_name, output in all_current.items():
print(f" {model_name} ({output.model_type}): {output.predictions['action']} - {output.confidence}")
# Test 3: Cross-model feeding
print("\n3. Testing cross-model feeding...")
cross_model_states = manager.get_cross_model_states(symbol, 'dqn_agent_v2')
print(f"✅ Retrieved cross-model states for RL model: {len(cross_model_states)} models")
for model_name, states in cross_model_states.items():
if states:
print(f" {model_name}: {len(states)} hidden state layers")
# Test 4: Consensus prediction
print("\n4. Testing consensus prediction...")
consensus = manager.get_consensus_prediction(symbol, confidence_threshold=0.7)
if consensus:
print(f"✅ Consensus prediction: {consensus['action']} (confidence: {consensus['confidence']:.3f})")
print(f" Votes: {consensus['votes']}")
print(f" Models: {consensus['model_types']}")
else:
print("⚠️ No consensus reached (insufficient high-confidence predictions)")
# Test 5: Performance summary
print("\n5. Testing performance tracking...")
performance = manager.get_performance_summary(symbol)
print(f"✅ Performance summary for {symbol}:")
print(f" Active models: {performance['active_models']}")
for model_name, stats in performance['model_stats'].items():
print(f" {model_name} ({stats['model_type']}): {stats['predictions']} predictions, "
f"avg confidence: {stats['avg_confidence']}")
# Test 6: Custom model type support
print("\n6. Testing custom model type support...")
# Add a custom model type
manager.add_custom_model_type('hybrid_ensemble')
# Create output with custom model type
custom_output = create_model_output(
model_type='hybrid_ensemble',
model_name='custom_ensemble_v1',
symbol=symbol,
action='BUY',
confidence=0.88,
metadata={'ensemble_size': 5, 'voting_method': 'weighted'}
)
success = manager.store_output(custom_output)
if success:
print("✅ Custom model type 'hybrid_ensemble' stored successfully")
else:
print("❌ Failed to store custom model type")
print(f" Updated supported types: {len(manager.get_supported_model_types())} types")
# Test 7: Historical outputs
print("\n7. Testing historical output tracking...")
# Store a few more outputs to build history
for i in range(3):
historical_output = create_model_output(
model_type='cnn',
model_name='enhanced_cnn_v1',
symbol=symbol,
action='HOLD',
confidence=0.6 + i * 0.1
)
manager.store_output(historical_output)
history = manager.get_output_history(symbol, 'enhanced_cnn_v1', count=5)
print(f"✅ Retrieved {len(history)} historical outputs for enhanced_cnn_v1")
for i, output in enumerate(history):
print(f" {i+1}. {output.predictions['action']} ({output.confidence}) at {output.timestamp}")
# Test 8: Active model types
print("\n8. Testing active model type detection...")
active_types = manager.get_model_types_active(symbol)
print(f"✅ Active model types for {symbol}: {active_types}")
print("\n✅ ModelOutputManager test completed successfully!")
print("\nKey features verified:")
print("✓ Extensible model type support (CNN, RL, LSTM, Transformer, Custom)")
print("✓ Cross-model feeding with hidden states")
print("✓ Historical output tracking")
print("✓ Performance analytics")
print("✓ Consensus prediction calculation")
print("✓ Metadata management")
print("✓ Thread-safe storage operations")
return manager
if __name__ == "__main__":
test_model_output_manager()

View File

@ -0,0 +1,128 @@
"""
Test script for StandardizedDataProvider
This script tests the standardized BaseDataInput functionality
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import logging
from datetime import datetime
from core.standardized_data_provider import StandardizedDataProvider
from core.data_models import create_model_output
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_standardized_data_provider():
"""Test the StandardizedDataProvider functionality"""
print("Testing StandardizedDataProvider...")
# Initialize the provider
symbols = ['ETH/USDT', 'BTC/USDT']
timeframes = ['1s', '1m', '1h', '1d']
provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)
# Test getting BaseDataInput
print("\n1. Testing BaseDataInput creation...")
base_input = provider.get_base_data_input('ETH/USDT')
if base_input is None:
print("❌ BaseDataInput is None - this is expected if no historical data is available")
print(" The provider needs real market data to create BaseDataInput")
# Test with mock data
print("\n2. Testing data structures...")
# Test ModelOutput creation
model_output = create_model_output(
model_type='cnn',
model_name='test_cnn',
symbol='ETH/USDT',
action='BUY',
confidence=0.75,
metadata={'test': True}
)
print(f"✅ Created ModelOutput: {model_output.model_type} - {model_output.predictions['action']} ({model_output.confidence})")
# Test storing model output
provider.store_model_output(model_output)
stored_outputs = provider.get_model_outputs('ETH/USDT')
if 'test_cnn' in stored_outputs:
print("✅ Model output storage and retrieval working")
else:
print("❌ Model output storage failed")
else:
print("✅ BaseDataInput created successfully!")
print(f" Symbol: {base_input.symbol}")
print(f" Timestamp: {base_input.timestamp}")
print(f" OHLCV 1s frames: {len(base_input.ohlcv_1s)}")
print(f" OHLCV 1m frames: {len(base_input.ohlcv_1m)}")
print(f" OHLCV 1h frames: {len(base_input.ohlcv_1h)}")
print(f" OHLCV 1d frames: {len(base_input.ohlcv_1d)}")
print(f" BTC 1s frames: {len(base_input.btc_ohlcv_1s)}")
print(f" COB data available: {base_input.cob_data is not None}")
print(f" Technical indicators: {len(base_input.technical_indicators)}")
print(f" Pivot points: {len(base_input.pivot_points)}")
print(f" Last predictions: {len(base_input.last_predictions)}")
# Test feature vector creation
try:
feature_vector = base_input.get_feature_vector()
print(f"✅ Feature vector created: shape {feature_vector.shape}")
except Exception as e:
print(f"❌ Feature vector creation failed: {e}")
# Test validation
is_valid = base_input.validate()
print(f"✅ BaseDataInput validation: {'PASSED' if is_valid else 'FAILED'}")
print("\n3. Testing data provider capabilities...")
# Test historical data fetching
try:
eth_data = provider.get_historical_data('ETH/USDT', '1h', 10)
if eth_data is not None and not eth_data.empty:
print(f"✅ Historical data available: {len(eth_data)} bars for ETH/USDT 1h")
else:
print("⚠️ No historical data available - this is normal if APIs are not accessible")
except Exception as e:
print(f"⚠️ Historical data fetch error: {e}")
print("\n4. Testing COB data functionality...")
# Test COB data creation
try:
# Set a mock current price for testing
provider.current_prices['ETHUSDT'] = 3000.0
cob_data = provider._get_cob_data('ETH/USDT', datetime.now())
if cob_data:
print(f"✅ COB data created successfully")
print(f" Current price: ${cob_data.current_price}")
print(f" Bucket size: ${cob_data.bucket_size}")
print(f" Price buckets: {len(cob_data.price_buckets)}")
print(f" MA 1s imbalance: {len(cob_data.ma_1s_imbalance)} buckets")
print(f" MA 5s imbalance: {len(cob_data.ma_5s_imbalance)} buckets")
else:
print("⚠️ COB data creation returned None")
except Exception as e:
print(f"❌ COB data creation error: {e}")
print("\n✅ StandardizedDataProvider test completed!")
print("\nNext steps:")
print("1. Integrate with real market data APIs")
print("2. Connect to actual COB provider")
print("3. Test with live data streams")
print("4. Integrate with model training pipelines")
if __name__ == "__main__":
test_standardized_data_provider()