@dataclass
class OHLCVBar:
    """One OHLCV candle for a symbol on a given timeframe."""
    symbol: str
    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float
    timeframe: str
    # Optional per-bar indicator values keyed by indicator name.
    indicators: Dict[str, float] = field(default_factory=dict)


@dataclass
class PivotPoint:
    """Pivot point data structure."""
    symbol: str
    timestamp: datetime
    price: float
    type: str  # 'high' or 'low'
    level: int  # Pivot level (1, 2, 3, etc.)
    confidence: float = 1.0


@dataclass
class ModelOutput:
    """Extensible model output format supporting all model types."""
    model_type: str  # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
    model_name: str  # Specific model identifier
    symbol: str
    timestamp: datetime
    confidence: float
    predictions: Dict[str, Any]  # Model-specific predictions
    hidden_states: Optional[Dict[str, Any]] = None  # For cross-model feeding
    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional info


@dataclass
class COBData:
    """Cumulative Order Book data for price buckets."""
    symbol: str
    timestamp: datetime
    current_price: float
    bucket_size: float  # $1 for ETH, $10 for BTC
    price_buckets: Dict[float, Dict[str, float]]  # price -> {bid_volume, ask_volume, etc.}
    bid_ask_imbalance: Dict[float, float]  # price -> imbalance ratio
    volume_weighted_prices: Dict[float, float]  # price -> VWAP within bucket
    order_flow_metrics: Dict[str, float]  # Various order flow indicators

    # Moving averages of COB imbalance for ±5 buckets
    ma_1s_imbalance: Dict[float, float] = field(default_factory=dict)  # 1s MA
    ma_5s_imbalance: Dict[float, float] = field(default_factory=dict)  # 5s MA
    ma_15s_imbalance: Dict[float, float] = field(default_factory=dict)  # 15s MA
    ma_60s_imbalance: Dict[float, float] = field(default_factory=dict)  # 60s MA


# ---------------------------------------------------------------------------
# Fixed layout of the vector built by BaseDataInput.get_feature_vector().
# Every segment is padded/truncated to its nominal width so the vector length
# never depends on how much data happened to be available.
# ---------------------------------------------------------------------------
_OHLCV_FRAMES = 300   # bars kept per OHLCV series
_OHLCV_FIELDS = 5     # open, high, low, close, volume
_OHLCV_SERIES = 5     # ETH 1s/1m/1h/1d + BTC 1s
_COB_BUCKETS = 41     # ±20 buckets plus the bucket at the current price
_COB_BUCKET_KEYS = ('bid_volume', 'ask_volume', 'total_volume', 'imbalance')
_COB_MA_SERIES = 4    # 1s, 5s, 15s, 60s imbalance moving averages
_COB_MA_BUCKETS = 5   # buckets taken from each moving average
_INDICATOR_SLOTS = 100
_PREDICTION_SLOTS = 50

FEATURE_VECTOR_SIZE = (
    _OHLCV_SERIES * _OHLCV_FRAMES * _OHLCV_FIELDS
    + _COB_BUCKETS * len(_COB_BUCKET_KEYS)
    + _COB_MA_SERIES * _COB_MA_BUCKETS
    + _INDICATOR_SLOTS
    + _PREDICTION_SLOTS
)


def _as_float(value: Any) -> float:
    """Coerce *value* to float, mapping None/non-numeric values to 0.0."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _fit(values: List[float], size: int, pad_front: bool = False) -> List[float]:
    """Return *values* truncated or zero-padded to exactly *size* items.

    With ``pad_front`` the newest (last) items are kept and zeros are inserted
    in front; otherwise the first items are kept and zeros are appended.
    """
    kept = values[-size:] if pad_front else values[:size]
    padding = [0.0] * (size - len(kept))
    return padding + kept if pad_front else kept + padding


def _bar_features(bars: List[OHLCVBar]) -> List[float]:
    """Flatten the newest ``_OHLCV_FRAMES`` bars into a fixed-width segment."""
    flat: List[float] = []
    for bar in bars[-_OHLCV_FRAMES:]:
        flat.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
    # Left-pad so the most recent bar always sits at the end of the segment.
    return _fit(flat, _OHLCV_FRAMES * _OHLCV_FIELDS, pad_front=True)


def _cob_features(cob: Optional[COBData]) -> List[float]:
    """Build the fixed-width COB segment (bucket metrics, then imbalance MAs)."""
    bucket_width = _COB_BUCKETS * len(_COB_BUCKET_KEYS)
    ma_width = _COB_MA_SERIES * _COB_MA_BUCKETS
    if cob is None:
        return [0.0] * (bucket_width + ma_width)

    prices = sorted(cob.price_buckets)
    if len(prices) > _COB_BUCKETS:
        # Too many buckets: keep the ones closest to the current price.
        nearest = sorted(prices, key=lambda p: abs(p - cob.current_price))
        prices = sorted(nearest[:_COB_BUCKETS])

    bucket_values: List[float] = []
    for price in prices:
        bucket = cob.price_buckets[price]
        bucket_values.extend(bucket.get(key, 0.0) for key in _COB_BUCKET_KEYS)
    features = _fit(bucket_values, bucket_width)

    for ma_dict in (cob.ma_1s_imbalance, cob.ma_5s_imbalance,
                    cob.ma_15s_imbalance, cob.ma_60s_imbalance):
        # Sort BEFORE slicing so the selection does not depend on dict
        # insertion order.
        # NOTE(review): only the five lowest-priced buckets of each MA are
        # used although the MAs are described as covering ±5 buckets; the
        # width is kept so existing model input sizes stay valid.
        values = [ma_dict[price] for price in sorted(ma_dict)[:_COB_MA_BUCKETS]]
        features.extend(_fit(values, _COB_MA_BUCKETS))
    return features


@dataclass
class BaseDataInput:
    """
    Unified base data input for all models

    Standardized format ensures all models receive identical input structure:
    - OHLCV: 300 frames of (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
    - COB: ±20 buckets of COB amounts in USD for each 1s OHLCV
    - MA: 1s, 5s, 15s, and 60s MA of COB imbalance counting ±5 COB buckets
    """
    symbol: str  # Primary symbol (ETH/USDT)
    timestamp: datetime

    # Multi-timeframe OHLCV data for primary symbol (ETH)
    ohlcv_1s: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1s data
    ohlcv_1m: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1m data
    ohlcv_1h: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1h data
    ohlcv_1d: List[OHLCVBar] = field(default_factory=list)  # 300 frames of 1d data

    # Reference symbol (BTC) 1s data
    btc_ohlcv_1s: List[OHLCVBar] = field(default_factory=list)  # 300s of 1s BTC data

    # COB data for 1s timeframe (±20 buckets around current price)
    cob_data: Optional[COBData] = None

    # Technical indicators
    technical_indicators: Dict[str, float] = field(default_factory=dict)

    # Pivot points from Williams Market Structure
    pivot_points: List[PivotPoint] = field(default_factory=list)

    # Last predictions from all models (for cross-model feeding)
    last_predictions: Dict[str, ModelOutput] = field(default_factory=dict)

    # Market microstructure data
    market_microstructure: Dict[str, Any] = field(default_factory=dict)

    def get_feature_vector(self) -> np.ndarray:
        """
        Convert BaseDataInput to a fixed-length feature vector for models.

        The vector always has ``FEATURE_VECTOR_SIZE`` entries, laid out as:

        1. ETH OHLCV, 1s/1m/1h/1d: 4 x 300 bars x 5 fields (left zero-padded)
        2. BTC 1s OHLCV: 300 bars x 5 fields (left zero-padded)
        3. COB buckets: 41 buckets x 4 metrics, ascending price (zero-padded)
        4. COB imbalance MAs: 4 series x 5 buckets (zero-padded)
        5. Technical indicators: first 100 values (zero-padded)
        6. Last model predictions: 5 values per model, first 50 (zero-padded)

        Missing data (short histories, no COB snapshot) is represented by
        zeros instead of shrinking the vector.

        Returns:
            np.ndarray: float32 vector of shape ``(FEATURE_VECTOR_SIZE,)``
        """
        features: List[float] = []

        # OHLCV features: four ETH timeframes followed by the BTC reference.
        for bars in (self.ohlcv_1s, self.ohlcv_1m, self.ohlcv_1h,
                     self.ohlcv_1d, self.btc_ohlcv_1s):
            features.extend(_bar_features(bars))

        # COB bucket metrics and imbalance moving averages.
        features.extend(_cob_features(self.cob_data))

        # Technical indicators, in dict insertion order.
        indicator_values = list(self.technical_indicators.values())
        features.extend(_fit(indicator_values, _INDICATOR_SLOTS))

        # Last predictions from other models. `predictions` is Dict[str, Any],
        # so values are coerced defensively (None/non-numeric -> 0.0).
        prediction_features: List[float] = []
        for model_output in self.last_predictions.values():
            predictions = model_output.predictions
            prediction_features.extend([
                _as_float(model_output.confidence),
                _as_float(predictions.get('buy_probability', 0.0)),
                _as_float(predictions.get('sell_probability', 0.0)),
                _as_float(predictions.get('hold_probability', 0.0)),
                _as_float(predictions.get('expected_reward', 0.0)),
            ])
        features.extend(_fit(prediction_features, _PREDICTION_SLOTS))

        return np.array(features, dtype=np.float32)

    def validate(self) -> bool:
        """
        Validate that the BaseDataInput contains required data

        Requires at least 100 bars of 1s data for both the primary symbol and
        BTC, a truthy timestamp, and a symbol containing '/'.

        Returns:
            bool: True if valid, False otherwise
        """
        # Check that we have required OHLCV data
        if len(self.ohlcv_1s) < 100:  # At least 100 frames
            return False
        if len(self.btc_ohlcv_1s) < 100:  # At least 100 frames of BTC data
            return False

        # Check that timestamps are reasonable
        if not self.timestamp:
            return False

        # Check symbol format
        if not self.symbol or '/' not in self.symbol:
            return False

        return True


@dataclass
class TradingAction:
    """Trading action output from models."""
    symbol: str
    timestamp: datetime
    action: str  # 'BUY', 'SELL', 'HOLD'
    confidence: float
    source: str  # 'rl', 'cnn', 'orchestrator'
    price: Optional[float] = None
    quantity: Optional[float] = None
    reason: Optional[str] = None


def create_model_output(model_type: str, model_name: str, symbol: str,
                        action: str, confidence: float,
                        hidden_states: Optional[Dict[str, Any]] = None,
                        metadata: Optional[Dict[str, Any]] = None) -> ModelOutput:
    """
    Helper function to create standardized ModelOutput

    The chosen action receives ``confidence`` as its probability and the other
    two actions receive 0.0. The timestamp is ``datetime.now()``.

    Args:
        model_type: Type of model ('cnn', 'rl', 'lstm', 'transformer', 'orchestrator')
        model_name: Specific model identifier
        symbol: Trading symbol
        action: Trading action ('BUY', 'SELL', 'HOLD')
        confidence: Confidence score (0.0 to 1.0)
        hidden_states: Optional hidden states for cross-model feeding
        metadata: Optional additional metadata

    Returns:
        ModelOutput: Standardized model output (``hidden_states`` and
        ``metadata`` default to empty dicts)
    """
    predictions = {
        'action': action,
        'buy_probability': confidence if action == 'BUY' else 0.0,
        'sell_probability': confidence if action == 'SELL' else 0.0,
        'hold_probability': confidence if action == 'HOLD' else 0.0,
    }

    return ModelOutput(
        model_type=model_type,
        model_name=model_name,
        symbol=symbol,
        timestamp=datetime.now(),
        confidence=confidence,
        predictions=predictions,
        hidden_states=hidden_states or {},
        metadata=metadata or {}
    )
class ModelOutputManager:
    """
    Extensible model output storage and management system

    Features:
    - Standardized ModelOutput storage for all model types
    - Cross-model feeding with hidden states
    - Historical output tracking
    - Metadata management
    - Persistence and recovery
    - Performance analytics

    Writes go through ``storage_lock``; the read accessors return snapshots
    or plain values and never raise (they log and return an empty result).
    """

    def __init__(self, cache_dir: str = "cache/model_outputs", max_history: int = 1000):
        """
        Initialize the model output manager

        Creates ``cache_dir`` (including parents) if it does not exist.

        Args:
            cache_dir: Directory for persistent storage
            max_history: Maximum number of outputs to keep in memory per model
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.max_history = max_history

        # In-memory storage
        self.current_outputs: "Dict[str, Dict[str, ModelOutput]]" = defaultdict(dict)  # {symbol: {model_name: ModelOutput}}
        self.output_history: "Dict[str, Dict[str, deque]]" = defaultdict(
            lambda: defaultdict(lambda: deque(maxlen=max_history))
        )  # {symbol: {model_name: deque}}
        self.cross_model_states: "Dict[str, Dict[str, Dict[str, Any]]]" = defaultdict(
            lambda: defaultdict(dict)
        )  # {symbol: {model_name: hidden_states}}

        # Metadata tracking
        self.model_metadata: "Dict[str, Dict[str, Any]]" = defaultdict(dict)  # {model_name: metadata}
        self.performance_stats: "Dict[str, Dict[str, Any]]" = defaultdict(
            lambda: defaultdict(dict)
        )  # {symbol: {model_name: stats}}

        # Thread safety
        self.storage_lock = Lock()

        # Supported model types
        self.supported_model_types = {
            'cnn', 'rl', 'lstm', 'transformer', 'orchestrator',
            'ensemble', 'hybrid', 'custom'  # Extensible for future types
        }

        logger.info(f"ModelOutputManager initialized with cache dir: {self.cache_dir}")
        logger.info(f"Supported model types: {self.supported_model_types}")

    def store_output(self, model_output: "ModelOutput") -> bool:
        """
        Store model output with full extensibility support

        Updates the current output, history, cross-model hidden states,
        model metadata and performance statistics under ``storage_lock``,
        then persists the output to disk after the lock is released so disk
        I/O never blocks other writers.

        Args:
            model_output: ModelOutput from any model type

        Returns:
            bool: True if stored successfully, False otherwise
        """
        try:
            with self.storage_lock:
                symbol = model_output.symbol
                model_name = model_output.model_name
                model_type = model_output.model_type

                # Validate model type (extensible)
                if model_type not in self.supported_model_types:
                    logger.warning(f"Unknown model type '{model_type}' - adding to supported types")
                    self.supported_model_types.add(model_type)

                # Store current output
                self.current_outputs[symbol][model_name] = model_output

                # Add to history
                self.output_history[symbol][model_name].append(model_output)

                # Store cross-model states if available
                if model_output.hidden_states:
                    self.cross_model_states[symbol][model_name] = model_output.hidden_states

                # Update model metadata
                self._update_model_metadata(model_name, model_type, model_output.metadata)

                # Update performance statistics
                self._update_performance_stats(symbol, model_name, model_output)

            # Persist to disk outside the lock: it only reads `model_output`
            # and `cache_dir`, and file I/O should not stall other writers.
            self._persist_output_async(model_output)

            logger.debug(f"Stored output from {model_name} ({model_type}) for {symbol}")
            return True

        except Exception as e:
            logger.error(f"Error storing model output: {e}")
            return False

    def get_current_output(self, symbol: str, model_name: str) -> "Optional[ModelOutput]":
        """
        Get the current (latest) output from a specific model

        Args:
            symbol: Trading symbol
            model_name: Name of the model

        Returns:
            ModelOutput: Latest output from the model, or None if not available
        """
        try:
            return self.current_outputs.get(symbol, {}).get(model_name)
        except Exception as e:
            logger.error(f"Error getting current output for {model_name}: {e}")
            return None

    def get_all_current_outputs(self, symbol: str) -> "Dict[str, ModelOutput]":
        """
        Get all current outputs for a symbol (for cross-model feeding)

        Args:
            symbol: Trading symbol

        Returns:
            Dict[str, ModelOutput]: Copy of the current outputs by model name
        """
        try:
            return dict(self.current_outputs.get(symbol, {}))
        except Exception as e:
            logger.error(f"Error getting all current outputs for {symbol}: {e}")
            return {}

    def get_output_history(self, symbol: str, model_name: str, count: int = 10) -> "List[ModelOutput]":
        """
        Get historical outputs from a model

        Args:
            symbol: Trading symbol
            model_name: Name of the model
            count: Number of historical outputs to retrieve; values <= 0
                yield an empty list

        Returns:
            List[ModelOutput]: List of historical outputs (most recent first)
        """
        try:
            # Guard explicitly: `seq[-0:]` is the WHOLE sequence, so count=0
            # would otherwise return the full history.
            if count <= 0:
                return []
            history = self.output_history.get(symbol, {}).get(model_name, deque())
            return list(history)[-count:][::-1]  # Most recent first
        except Exception as e:
            logger.error(f"Error getting output history for {model_name}: {e}")
            return []

    def get_cross_model_states(self, symbol: str, requesting_model: str) -> Dict[str, Dict[str, Any]]:
        """
        Get hidden states from other models for cross-model feeding

        Args:
            symbol: Trading symbol
            requesting_model: Name of the model requesting the states

        Returns:
            Dict[str, Dict[str, Any]]: Hidden states from other models
        """
        try:
            all_states = self.cross_model_states.get(symbol, {})
            # Return states from all models except the requesting one
            return {model_name: states for model_name, states in all_states.items()
                    if model_name != requesting_model}
        except Exception as e:
            logger.error(f"Error getting cross-model states for {requesting_model}: {e}")
            return {}

    def get_model_types_active(self, symbol: str) -> List[str]:
        """
        Get list of active model types for a symbol

        Args:
            symbol: Trading symbol

        Returns:
            List[str]: Model type of every model with a current output (one
            entry per model, so types may repeat)
        """
        try:
            current_outputs = self.current_outputs.get(symbol, {})
            return [output.model_type for output in current_outputs.values()]
        except Exception as e:
            logger.error(f"Error getting active model types for {symbol}: {e}")
            return []

    def get_consensus_prediction(self, symbol: str, confidence_threshold: float = 0.5) -> Optional[Dict[str, Any]]:
        """
        Get consensus prediction from all active models

        Each model whose confidence is at least ``confidence_threshold`` casts
        one vote for its ``predictions['action']``. BUY or SELL wins only with
        a strict plurality over both other actions; anything else (including
        ties) resolves to HOLD.

        Args:
            symbol: Trading symbol
            confidence_threshold: Minimum confidence threshold for inclusion

        Returns:
            Dict containing consensus prediction or None
        """
        try:
            current_outputs = self.current_outputs.get(symbol, {})
            if not current_outputs:
                return None

            # Filter by confidence threshold
            high_confidence_outputs = [
                output for output in current_outputs.values()
                if output.confidence >= confidence_threshold
            ]

            if not high_confidence_outputs:
                return None

            # Tally votes; actions other than BUY/SELL/HOLD are not counted
            # as votes but the model still counts toward `total_models`.
            votes = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
            for output in high_confidence_outputs:
                action = output.predictions.get('action')
                if action in votes:
                    votes[action] += 1

            total_votes = len(high_confidence_outputs)
            avg_confidence = sum(output.confidence for output in high_confidence_outputs) / total_votes

            # Determine consensus action
            if votes['BUY'] > votes['SELL'] and votes['BUY'] > votes['HOLD']:
                consensus_action = 'BUY'
            elif votes['SELL'] > votes['BUY'] and votes['SELL'] > votes['HOLD']:
                consensus_action = 'SELL'
            else:
                consensus_action = 'HOLD'

            return {
                'action': consensus_action,
                'confidence': avg_confidence,
                'votes': votes,
                'total_models': total_votes,
                'model_types': [output.model_type for output in high_confidence_outputs]
            }

        except Exception as e:
            logger.error(f"Error calculating consensus prediction for {symbol}: {e}")
            return None

    def _update_model_metadata(self, model_name: str, model_type: str, metadata: Dict[str, Any]):
        """Update metadata for a model (caller holds ``storage_lock``)."""
        try:
            if model_name not in self.model_metadata:
                self.model_metadata[model_name] = {
                    'model_type': model_type,
                    'first_seen': datetime.now(),
                    'total_predictions': 0,
                    'custom_metadata': {}
                }

            self.model_metadata[model_name]['total_predictions'] += 1
            self.model_metadata[model_name]['last_seen'] = datetime.now()

            # Merge custom metadata
            if metadata:
                self.model_metadata[model_name]['custom_metadata'].update(metadata)

        except Exception as e:
            logger.error(f"Error updating model metadata: {e}")

    def _update_performance_stats(self, symbol: str, model_name: str, model_output: "ModelOutput"):
        """Update running prediction/confidence/action counters for a model."""
        try:
            stats = self.performance_stats[symbol][model_name]

            if 'prediction_count' not in stats:
                stats['prediction_count'] = 0
                stats['confidence_sum'] = 0.0
                stats['action_counts'] = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
                stats['first_prediction'] = model_output.timestamp

            stats['prediction_count'] += 1
            stats['confidence_sum'] += model_output.confidence
            stats['avg_confidence'] = stats['confidence_sum'] / stats['prediction_count']
            stats['last_prediction'] = model_output.timestamp

            action = model_output.predictions.get('action', 'HOLD')
            if action in stats['action_counts']:
                stats['action_counts'][action] += 1

        except Exception as e:
            logger.error(f"Error updating performance stats: {e}")

    @staticmethod
    def _safe_filename_part(text: str) -> str:
        """Replace every character that is not alphanumeric, '-', '_' or '.' with '_'."""
        return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(text))

    @staticmethod
    def _json_default(value: Any) -> Any:
        """Serialize common non-JSON values found in predictions/metadata.

        Handles datetimes (ISO-8601), set/frozenset/deque (as lists) and any
        object exposing ``tolist()`` (e.g. NumPy arrays and scalars). Anything
        else raises TypeError so unknown types are not silently stringified.
        """
        if isinstance(value, datetime):
            return value.isoformat()
        if isinstance(value, (set, frozenset, deque)):
            return list(value)
        to_list = getattr(value, "tolist", None)
        if callable(to_list):
            return to_list()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    def _persist_output_async(self, model_output: "ModelOutput"):
        """Persist model output to disk (simplified, synchronous version).

        One JSON file is written per output into ``cache_dir``. The filename
        carries a microsecond-resolution timestamp so several outputs from the
        same model within one second do not overwrite each other, and the
        model name and symbol are sanitized so they cannot introduce path
        separators. Hidden states are not persisted. Failures are logged and
        swallowed (persistence is best-effort).
        """
        try:
            # Create filename based on model and timestamp
            timestamp_str = model_output.timestamp.strftime("%Y%m%d_%H%M%S_%f")
            model_part = self._safe_filename_part(model_output.model_name)
            symbol_part = self._safe_filename_part(model_output.symbol)
            filepath = self.cache_dir / f"{model_part}_{symbol_part}_{timestamp_str}.json"

            # Convert to JSON-serializable format
            output_dict = {
                'model_type': model_output.model_type,
                'model_name': model_output.model_name,
                'symbol': model_output.symbol,
                'timestamp': model_output.timestamp.isoformat(),
                'confidence': model_output.confidence,
                'predictions': model_output.predictions,
                'metadata': model_output.metadata
            }

            # Serialize fully BEFORE touching the file so a serialization
            # error cannot leave a truncated JSON file behind.
            payload = json.dumps(output_dict, indent=2, default=self._json_default)
            with open(filepath, 'w') as f:
                f.write(payload)

        except Exception as e:
            logger.error(f"Error persisting model output: {e}")

    def get_performance_summary(self, symbol: str) -> Dict[str, Any]:
        """
        Get performance summary for all models for a symbol

        Args:
            symbol: Trading symbol

        Returns:
            Dict containing performance summary; on failure a dict with the
            symbol and an ``error`` message
        """
        try:
            summary = {
                'symbol': symbol,
                'active_models': len(self.current_outputs.get(symbol, {})),
                'model_stats': {}
            }

            for model_name, stats in self.performance_stats.get(symbol, {}).items():
                summary['model_stats'][model_name] = {
                    'predictions': stats.get('prediction_count', 0),
                    'avg_confidence': round(stats.get('avg_confidence', 0.0), 3),
                    'action_distribution': stats.get('action_counts', {}),
                    'model_type': self.model_metadata.get(model_name, {}).get('model_type', 'unknown')
                }

            return summary

        except Exception as e:
            logger.error(f"Error getting performance summary: {e}")
            return {'symbol': symbol, 'error': str(e)}

    def cleanup_old_outputs(self, max_age_hours: int = 24):
        """
        Clean up old outputs to manage memory usage

        Only the in-memory history deques are trimmed; current outputs and
        files on disk are left untouched.

        Args:
            max_age_hours: Maximum age of outputs to keep in hours
        """
        try:
            # NOTE(review): the cutoff is a naive local datetime; outputs
            # with timezone-aware timestamps would make the comparison below
            # raise (caught and logged) - confirm timestamps are naive.
            cutoff_time = datetime.now() - timedelta(hours=max_age_hours)

            with self.storage_lock:
                for symbol in self.output_history:
                    for model_name in self.output_history[symbol]:
                        history = self.output_history[symbol][model_name]
                        # Remove old outputs (oldest entries are on the left)
                        while history and history[0].timestamp < cutoff_time:
                            history.popleft()

            logger.info(f"Cleaned up outputs older than {max_age_hours} hours")

        except Exception as e:
            logger.error(f"Error cleaning up old outputs: {e}")

    def add_custom_model_type(self, model_type: str):
        """
        Add support for a new custom model type

        Args:
            model_type: Name of the new model type
        """
        self.supported_model_types.add(model_type)
        logger.info(f"Added support for custom model type: {model_type}")

    def get_supported_model_types(self) -> List[str]:
        """Get list of all supported model types"""
        return list(self.supported_model_types)
{symbol: BaseDataInput} + self.cob_data_cache: Dict[str, COBData] = {} # {symbol: COBData} + + # Model output management with extensible storage + self.model_output_manager = ModelOutputManager( + cache_dir=str(self.cache_dir / "model_outputs"), + max_history=1000 + ) + + # COB moving averages calculation + self.cob_imbalance_history: Dict[str, deque] = {} # {symbol: deque of (timestamp, imbalance_data)} + self.ma_calculation_lock = Lock() + + # Initialize caches for each symbol + for symbol in self.symbols: + self.base_data_cache[symbol] = None + self.cob_data_cache[symbol] = None + self.cob_imbalance_history[symbol] = deque(maxlen=300) # 5 minutes of 1s data + + # COB provider integration + self.cob_provider: Optional[MultiExchangeCOBProvider] = None + self._initialize_cob_provider() + + logger.info("StandardizedDataProvider initialized with BaseDataInput support") + + def _initialize_cob_provider(self): + """Initialize COB provider for order book data""" + try: + from .multi_exchange_cob_provider import MultiExchangeCOBProvider, ExchangeConfig, ExchangeType + + # Configure exchanges (focusing on Binance for now) + exchange_configs = { + 'binance': ExchangeConfig( + exchange_type=ExchangeType.BINANCE, + weight=1.0, + enabled=True, + websocket_url="wss://stream.binance.com:9443/ws/", + symbols_mapping={symbol: symbol.replace('/', '').lower() for symbol in self.symbols} + ) + } + + self.cob_provider = MultiExchangeCOBProvider(self.symbols, exchange_configs) + logger.info("COB provider initialized successfully") + + except Exception as e: + logger.warning(f"Failed to initialize COB provider: {e}") + self.cob_provider = None + + def get_base_data_input(self, symbol: str, timestamp: Optional[datetime] = None) -> Optional[BaseDataInput]: + """ + Get standardized BaseDataInput for a symbol + + Args: + symbol: Trading symbol (e.g., 'ETH/USDT') + timestamp: Optional timestamp, defaults to current time + + Returns: + BaseDataInput: Standardized input data for models, or 
None if insufficient data + """ + if timestamp is None: + timestamp = datetime.now() + + try: + # Get OHLCV data for all timeframes + ohlcv_1s = self._get_ohlcv_bars(symbol, '1s', 300) + ohlcv_1m = self._get_ohlcv_bars(symbol, '1m', 300) + ohlcv_1h = self._get_ohlcv_bars(symbol, '1h', 300) + ohlcv_1d = self._get_ohlcv_bars(symbol, '1d', 300) + + # Get BTC reference data + btc_symbol = 'BTC/USDT' + btc_ohlcv_1s = self._get_ohlcv_bars(btc_symbol, '1s', 300) + + # Check if we have sufficient data + if not all([ohlcv_1s, ohlcv_1m, ohlcv_1h, ohlcv_1d, btc_ohlcv_1s]): + logger.warning(f"Insufficient OHLCV data for {symbol}") + return None + + if any(len(data) < 100 for data in [ohlcv_1s, ohlcv_1m, ohlcv_1h, ohlcv_1d, btc_ohlcv_1s]): + logger.warning(f"Insufficient data frames for {symbol}") + return None + + # Get COB data + cob_data = self._get_cob_data(symbol, timestamp) + + # Get technical indicators + technical_indicators = self._get_technical_indicators(symbol) + + # Get pivot points + pivot_points = self._get_pivot_points(symbol) + + # Get last predictions from all models + last_predictions = self.model_output_manager.get_all_current_outputs(symbol) + + # Create BaseDataInput + base_input = BaseDataInput( + symbol=symbol, + timestamp=timestamp, + ohlcv_1s=ohlcv_1s, + ohlcv_1m=ohlcv_1m, + ohlcv_1h=ohlcv_1h, + ohlcv_1d=ohlcv_1d, + btc_ohlcv_1s=btc_ohlcv_1s, + cob_data=cob_data, + technical_indicators=technical_indicators, + pivot_points=pivot_points, + last_predictions=last_predictions + ) + + # Validate the input + if not base_input.validate(): + logger.warning(f"BaseDataInput validation failed for {symbol}") + return None + + # Cache the result + self.base_data_cache[symbol] = base_input + + return base_input + + except Exception as e: + logger.error(f"Error creating BaseDataInput for {symbol}: {e}") + return None + + def _get_ohlcv_bars(self, symbol: str, timeframe: str, count: int) -> List[OHLCVBar]: + """ + Get OHLCV bars for a symbol and timeframe + + Args: + 
symbol: Trading symbol + timeframe: Timeframe ('1s', '1m', '1h', '1d') + count: Number of bars to retrieve + + Returns: + List[OHLCVBar]: List of OHLCV bars + """ + try: + # Get historical data from parent class + df = self.get_historical_data(symbol, timeframe, count) + if df is None or df.empty: + return [] + + # Convert DataFrame to OHLCVBar objects + bars = [] + for _, row in df.tail(count).iterrows(): + bar = OHLCVBar( + symbol=symbol, + timestamp=row.name if hasattr(row, 'name') else datetime.now(), + open=float(row['open']), + high=float(row['high']), + low=float(row['low']), + close=float(row['close']), + volume=float(row['volume']), + timeframe=timeframe, + indicators={} + ) + + # Add technical indicators if available + for col in df.columns: + if col not in ['open', 'high', 'low', 'close', 'volume']: + bar.indicators[col] = float(row[col]) if not np.isnan(row[col]) else 0.0 + + bars.append(bar) + + return bars + + except Exception as e: + logger.error(f"Error getting OHLCV bars for {symbol} {timeframe}: {e}") + return [] + + def _get_cob_data(self, symbol: str, timestamp: datetime) -> Optional[COBData]: + """ + Get COB data for a symbol + + Args: + symbol: Trading symbol + timestamp: Current timestamp + + Returns: + COBData: COB data with price buckets and moving averages + """ + try: + if not self.cob_provider: + return None + + # Get current price + current_price = self.current_prices.get(symbol.replace('/', '').upper(), 0.0) + if current_price <= 0: + return None + + # Determine bucket size based on symbol + bucket_size = 1.0 if 'ETH' in symbol else 10.0 # $1 for ETH, $10 for BTC + + # Calculate price range (±20 buckets) + price_range = 20 * bucket_size + min_price = current_price - price_range + max_price = current_price + price_range + + # Create price buckets + price_buckets = {} + bid_ask_imbalance = {} + volume_weighted_prices = {} + + # Generate mock COB data for now (will be replaced with real COB provider data) + for i in range(-20, 21): + 
price = current_price + (i * bucket_size) + if price > 0: + # Mock data - replace with real COB provider data + bid_volume = max(0, 1000 - abs(i) * 50) # More volume near current price + ask_volume = max(0, 1000 - abs(i) * 50) + total_volume = bid_volume + ask_volume + imbalance = (bid_volume - ask_volume) / max(total_volume, 1) + + price_buckets[price] = { + 'bid_volume': bid_volume, + 'ask_volume': ask_volume, + 'total_volume': total_volume, + 'imbalance': imbalance + } + bid_ask_imbalance[price] = imbalance + volume_weighted_prices[price] = price # Simplified VWAP + + # Calculate moving averages of imbalance for ±5 buckets + ma_data = self._calculate_cob_moving_averages(symbol, bid_ask_imbalance, timestamp) + + cob_data = COBData( + symbol=symbol, + timestamp=timestamp, + current_price=current_price, + bucket_size=bucket_size, + price_buckets=price_buckets, + bid_ask_imbalance=bid_ask_imbalance, + volume_weighted_prices=volume_weighted_prices, + order_flow_metrics={}, + ma_1s_imbalance=ma_data.get('1s', {}), + ma_5s_imbalance=ma_data.get('5s', {}), + ma_15s_imbalance=ma_data.get('15s', {}), + ma_60s_imbalance=ma_data.get('60s', {}) + ) + + # Cache the COB data + self.cob_data_cache[symbol] = cob_data + + return cob_data + + except Exception as e: + logger.error(f"Error getting COB data for {symbol}: {e}") + return None + + def _calculate_cob_moving_averages(self, symbol: str, bid_ask_imbalance: Dict[float, float], + timestamp: datetime) -> Dict[str, Dict[float, float]]: + """ + Calculate moving averages of COB imbalance for ±5 buckets + + Args: + symbol: Trading symbol + bid_ask_imbalance: Current bid/ask imbalance data + timestamp: Current timestamp + + Returns: + Dict containing MA data for different timeframes + """ + try: + with self.ma_calculation_lock: + # Add current imbalance data to history + self.cob_imbalance_history[symbol].append((timestamp, bid_ask_imbalance)) + + # Calculate MAs for different timeframes + ma_results = {'1s': {}, '5s': {}, '15s': 
{}, '60s': {}} + + # Get current price for ±5 bucket calculation + current_price = self.current_prices.get(symbol.replace('/', '').upper(), 0.0) + if current_price <= 0: + return ma_results + + bucket_size = 1.0 if 'ETH' in symbol else 10.0 + + # Calculate MAs for ±5 buckets around current price + for i in range(-5, 6): + price = current_price + (i * bucket_size) + if price <= 0: + continue + + # Get historical imbalance data for this price bucket + history = self.cob_imbalance_history[symbol] + + # Calculate different MA periods + for period, period_name in [(1, '1s'), (5, '5s'), (15, '15s'), (60, '60s')]: + recent_data = [] + cutoff_time = timestamp - timedelta(seconds=period) + + for hist_timestamp, hist_imbalance in history: + if hist_timestamp >= cutoff_time and price in hist_imbalance: + recent_data.append(hist_imbalance[price]) + + # Calculate moving average + if recent_data: + ma_results[period_name][price] = sum(recent_data) / len(recent_data) + else: + ma_results[period_name][price] = 0.0 + + return ma_results + + except Exception as e: + logger.error(f"Error calculating COB moving averages for {symbol}: {e}") + return {'1s': {}, '5s': {}, '15s': {}, '60s': {}} + + def _get_technical_indicators(self, symbol: str) -> Dict[str, float]: + """Get technical indicators for a symbol""" + try: + # Get latest OHLCV data + df = self.get_historical_data(symbol, '1h', 100) # Use 1h for indicators + if df is None or df.empty: + return {} + + indicators = {} + + # Add basic indicators if available in the dataframe + latest_row = df.iloc[-1] + for col in df.columns: + if col not in ['open', 'high', 'low', 'close', 'volume']: + indicators[col] = float(latest_row[col]) if not np.isnan(latest_row[col]) else 0.0 + + return indicators + + except Exception as e: + logger.error(f"Error getting technical indicators for {symbol}: {e}") + return {} + + def _get_pivot_points(self, symbol: str) -> List[PivotPoint]: + """Get pivot points for a symbol""" + try: + pivot_points = [] + 
+ # Get pivot points from Williams Market Structure if available + if symbol in self.williams_structure: + williams = self.williams_structure[symbol] + # This would need to be implemented based on the actual Williams structure + # For now, return empty list + pass + + return pivot_points + + except Exception as e: + logger.error(f"Error getting pivot points for {symbol}: {e}") + return [] + + def store_model_output(self, model_output: ModelOutput): + """ + Store model output for cross-model feeding using ModelOutputManager + + Args: + model_output: ModelOutput from any model + """ + try: + success = self.model_output_manager.store_output(model_output) + if success: + logger.debug(f"Stored model output from {model_output.model_name} for {model_output.symbol}") + else: + logger.warning(f"Failed to store model output from {model_output.model_name}") + + except Exception as e: + logger.error(f"Error storing model output: {e}") + + def get_model_outputs(self, symbol: str) -> Dict[str, ModelOutput]: + """ + Get all model outputs for a symbol using ModelOutputManager + + Args: + symbol: Trading symbol + + Returns: + Dict[str, ModelOutput]: Dictionary of model outputs by model name + """ + return self.model_output_manager.get_all_current_outputs(symbol) + + def get_model_output_manager(self) -> ModelOutputManager: + """ + Get the model output manager for advanced operations + + Returns: + ModelOutputManager: The model output manager instance + """ + return self.model_output_manager + + def start_real_time_processing(self): + """Start real-time processing for standardized data""" + try: + # Start parent class real-time processing + if hasattr(super(), 'start_real_time_processing'): + super().start_real_time_processing() + + # Start COB provider if available + if self.cob_provider: + import asyncio + asyncio.create_task(self.cob_provider.start_streaming()) + + logger.info("Started real-time processing for standardized data") + + except Exception as e: + logger.error(f"Error 
starting real-time processing: {e}") + + def stop_real_time_processing(self): + """Stop real-time processing""" + try: + # Stop COB provider if available + if self.cob_provider: + import asyncio + asyncio.create_task(self.cob_provider.stop_streaming()) + + # Stop parent class processing + if hasattr(super(), 'stop_real_time_processing'): + super().stop_real_time_processing() + + logger.info("Stopped real-time processing for standardized data") + + except Exception as e: + logger.error(f"Error stopping real-time processing: {e}") \ No newline at end of file diff --git a/test_cache/model_outputs/custom_ensemble_v1_ETH_USDT_20250723_153826.json b/test_cache/model_outputs/custom_ensemble_v1_ETH_USDT_20250723_153826.json new file mode 100644 index 0000000..3f9101e --- /dev/null +++ b/test_cache/model_outputs/custom_ensemble_v1_ETH_USDT_20250723_153826.json @@ -0,0 +1,17 @@ +{ + "model_type": "hybrid_ensemble", + "model_name": "custom_ensemble_v1", + "symbol": "ETH/USDT", + "timestamp": "2025-07-23T15:38:26.810125", + "confidence": 0.88, + "predictions": { + "action": "BUY", + "buy_probability": 0.88, + "sell_probability": 0.0, + "hold_probability": 0.0 + }, + "metadata": { + "ensemble_size": 5, + "voting_method": "weighted" + } +} \ No newline at end of file diff --git a/test_cache/model_outputs/dqn_agent_v2_ETH_USDT_20250723_153826.json b/test_cache/model_outputs/dqn_agent_v2_ETH_USDT_20250723_153826.json new file mode 100644 index 0000000..f7d4683 --- /dev/null +++ b/test_cache/model_outputs/dqn_agent_v2_ETH_USDT_20250723_153826.json @@ -0,0 +1,18 @@ +{ + "model_type": "rl", + "model_name": "dqn_agent_v2", + "symbol": "ETH/USDT", + "timestamp": "2025-07-23T15:38:26.804125", + "confidence": 0.72, + "predictions": { + "action": "SELL", + "buy_probability": 0.0, + "sell_probability": 0.72, + "hold_probability": 0.0 + }, + "metadata": { + "model_version": "1.0", + "training_iterations": 1000, + "last_updated": "2025-07-23T15:38:26.804125" + } +} \ No newline at end of 
file diff --git a/test_cache/model_outputs/enhanced_cnn_v1_ETH_USDT_20250723_153826.json b/test_cache/model_outputs/enhanced_cnn_v1_ETH_USDT_20250723_153826.json new file mode 100644 index 0000000..7238735 --- /dev/null +++ b/test_cache/model_outputs/enhanced_cnn_v1_ETH_USDT_20250723_153826.json @@ -0,0 +1,14 @@ +{ + "model_type": "cnn", + "model_name": "enhanced_cnn_v1", + "symbol": "ETH/USDT", + "timestamp": "2025-07-23T15:38:26.812129", + "confidence": 0.8, + "predictions": { + "action": "HOLD", + "buy_probability": 0.0, + "sell_probability": 0.0, + "hold_probability": 0.8 + }, + "metadata": {} +} \ No newline at end of file diff --git a/test_cache/model_outputs/lstm_predictor_ETH_USDT_20250723_153826.json b/test_cache/model_outputs/lstm_predictor_ETH_USDT_20250723_153826.json new file mode 100644 index 0000000..c75d394 --- /dev/null +++ b/test_cache/model_outputs/lstm_predictor_ETH_USDT_20250723_153826.json @@ -0,0 +1,18 @@ +{ + "model_type": "lstm", + "model_name": "lstm_predictor", + "symbol": "ETH/USDT", + "timestamp": "2025-07-23T15:38:26.805126", + "confidence": 0.65, + "predictions": { + "action": "HOLD", + "buy_probability": 0.0, + "sell_probability": 0.0, + "hold_probability": 0.65 + }, + "metadata": { + "model_version": "1.0", + "training_iterations": 1000, + "last_updated": "2025-07-23T15:38:26.805126" + } +} \ No newline at end of file diff --git a/test_cache/model_outputs/main_orchestrator_ETH_USDT_20250723_153826.json b/test_cache/model_outputs/main_orchestrator_ETH_USDT_20250723_153826.json new file mode 100644 index 0000000..19bba76 --- /dev/null +++ b/test_cache/model_outputs/main_orchestrator_ETH_USDT_20250723_153826.json @@ -0,0 +1,18 @@ +{ + "model_type": "orchestrator", + "model_name": "main_orchestrator", + "symbol": "ETH/USDT", + "timestamp": "2025-07-23T15:38:26.806125", + "confidence": 0.78, + "predictions": { + "action": "BUY", + "buy_probability": 0.78, + "sell_probability": 0.0, + "hold_probability": 0.0 + }, + "metadata": { + 
"model_version": "1.0", + "training_iterations": 1000, + "last_updated": "2025-07-23T15:38:26.806125" + } +} \ No newline at end of file diff --git a/test_cache/model_outputs/transformer_v1_ETH_USDT_20250723_153826.json b/test_cache/model_outputs/transformer_v1_ETH_USDT_20250723_153826.json new file mode 100644 index 0000000..32ec68d --- /dev/null +++ b/test_cache/model_outputs/transformer_v1_ETH_USDT_20250723_153826.json @@ -0,0 +1,18 @@ +{ + "model_type": "transformer", + "model_name": "transformer_v1", + "symbol": "ETH/USDT", + "timestamp": "2025-07-23T15:38:26.806125", + "confidence": 0.91, + "predictions": { + "action": "BUY", + "buy_probability": 0.91, + "sell_probability": 0.0, + "hold_probability": 0.0 + }, + "metadata": { + "model_version": "1.0", + "training_iterations": 1000, + "last_updated": "2025-07-23T15:38:26.806125" + } +} \ No newline at end of file diff --git a/test_integrated_standardized_provider.py b/test_integrated_standardized_provider.py new file mode 100644 index 0000000..17fbb33 --- /dev/null +++ b/test_integrated_standardized_provider.py @@ -0,0 +1,182 @@ +""" +Test script for integrated StandardizedDataProvider with ModelOutputManager + +This script tests the complete standardized data provider with extensible model output storage +""" + +import sys +import os +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +import logging +from datetime import datetime +from core.standardized_data_provider import StandardizedDataProvider +from core.data_models import create_model_output + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def test_integrated_standardized_provider(): + """Test the integrated StandardizedDataProvider with ModelOutputManager""" + + print("Testing Integrated StandardizedDataProvider with ModelOutputManager...") + + # Initialize the provider + symbols = ['ETH/USDT', 'BTC/USDT'] + timeframes = ['1s', '1m', '1h', '1d'] + + provider = 
StandardizedDataProvider(symbols=symbols, timeframes=timeframes) + + print("✅ StandardizedDataProvider initialized with ModelOutputManager") + + # Test 1: Store model outputs from different types + print("\n1. Testing model output storage integration...") + + # Create and store outputs from different model types + model_outputs = [ + create_model_output('cnn', 'enhanced_cnn_v1', 'ETH/USDT', 'BUY', 0.85), + create_model_output('rl', 'dqn_agent_v2', 'ETH/USDT', 'SELL', 0.72), + create_model_output('transformer', 'transformer_v1', 'ETH/USDT', 'BUY', 0.91), + create_model_output('orchestrator', 'main_orchestrator', 'ETH/USDT', 'BUY', 0.78) + ] + + for output in model_outputs: + provider.store_model_output(output) + print(f"✅ Stored {output.model_type} output: {output.predictions['action']} ({output.confidence})") + + # Test 2: Retrieve model outputs + print("\n2. Testing model output retrieval...") + + all_outputs = provider.get_model_outputs('ETH/USDT') + print(f"✅ Retrieved {len(all_outputs)} model outputs for ETH/USDT") + + for model_name, output in all_outputs.items(): + print(f" {model_name} ({output.model_type}): {output.predictions['action']} - {output.confidence}") + + # Test 3: Test BaseDataInput with cross-model feeding + print("\n3. 
Testing BaseDataInput with cross-model predictions...") + + # Set mock current price for COB data + provider.current_prices['ETHUSDT'] = 3000.0 + + base_input = provider.get_base_data_input('ETH/USDT') + + if base_input: + print("✅ BaseDataInput created with cross-model predictions!") + print(f" Symbol: {base_input.symbol}") + print(f" OHLCV frames: 1s={len(base_input.ohlcv_1s)}, 1m={len(base_input.ohlcv_1m)}, 1h={len(base_input.ohlcv_1h)}, 1d={len(base_input.ohlcv_1d)}") + print(f" BTC frames: {len(base_input.btc_ohlcv_1s)}") + print(f" COB data: {'Available' if base_input.cob_data else 'Not available'}") + print(f" Last predictions: {len(base_input.last_predictions)} models") + + # Show cross-model predictions + for model_name, prediction in base_input.last_predictions.items(): + print(f" {model_name}: {prediction.predictions['action']} ({prediction.confidence})") + + # Test feature vector creation + try: + feature_vector = base_input.get_feature_vector() + print(f"✅ Feature vector created: shape {feature_vector.shape}") + except Exception as e: + print(f"❌ Feature vector creation failed: {e}") + else: + print("⚠️ BaseDataInput creation failed - this may be due to insufficient historical data") + + # Test 4: Advanced ModelOutputManager features + print("\n4. 
Testing advanced model output manager features...") + + output_manager = provider.get_model_output_manager() + + # Test consensus prediction + consensus = output_manager.get_consensus_prediction('ETH/USDT', confidence_threshold=0.7) + if consensus: + print(f"✅ Consensus prediction: {consensus['action']} (confidence: {consensus['confidence']:.3f})") + print(f" Votes: {consensus['votes']}") + print(f" Contributing models: {consensus['model_types']}") + else: + print("⚠️ No consensus reached") + + # Test cross-model states + cross_states = output_manager.get_cross_model_states('ETH/USDT', 'dqn_agent_v2') + print(f"✅ Cross-model states available for RL model: {len(cross_states)} models") + + # Test performance summary + performance = output_manager.get_performance_summary('ETH/USDT') + print(f"✅ Performance summary: {performance['active_models']} active models") + + # Test 5: Custom model type support + print("\n5. Testing custom model type extensibility...") + + # Add a custom model type + output_manager.add_custom_model_type('hybrid_lstm_transformer') + + # Create and store custom model output + custom_output = create_model_output( + model_type='hybrid_lstm_transformer', + model_name='hybrid_model_v1', + symbol='ETH/USDT', + action='BUY', + confidence=0.89, + metadata={'hybrid_components': ['lstm', 'transformer'], 'ensemble_weight': 0.6} + ) + + provider.store_model_output(custom_output) + print("✅ Custom model type 'hybrid_lstm_transformer' stored successfully") + + # Verify it's included in BaseDataInput + updated_base_input = provider.get_base_data_input('ETH/USDT') + if updated_base_input and 'hybrid_model_v1' in updated_base_input.last_predictions: + print("✅ Custom model output included in BaseDataInput cross-model feeding") + + print(f" Total supported model types: {len(output_manager.get_supported_model_types())}") + + # Test 6: Historical tracking + print("\n6. 
Testing historical output tracking...") + + # Store a few more outputs to build history + for i in range(3): + historical_output = create_model_output( + model_type='cnn', + model_name='enhanced_cnn_v1', + symbol='ETH/USDT', + action='HOLD', + confidence=0.6 + i * 0.05 + ) + provider.store_model_output(historical_output) + + history = output_manager.get_output_history('ETH/USDT', 'enhanced_cnn_v1', count=5) + print(f"✅ Historical tracking: {len(history)} outputs for enhanced_cnn_v1") + + # Test 7: Real-time data integration readiness + print("\n7. Testing real-time integration readiness...") + + print("✅ Real-time processing methods available:") + print(" - start_real_time_processing()") + print(" - stop_real_time_processing()") + print(" - COB provider integration ready") + print(" - Model output persistence enabled") + + print("\n✅ Integrated StandardizedDataProvider test completed successfully!") + print("\n🎯 Key achievements:") + print("✓ Standardized BaseDataInput format for all models") + print("✓ Extensible ModelOutput storage (CNN, RL, LSTM, Transformer, Custom)") + print("✓ Cross-model feeding with last predictions") + print("✓ COB data integration with moving averages") + print("✓ Consensus prediction calculation") + print("✓ Historical output tracking") + print("✓ Performance analytics") + print("✓ Thread-safe operations") + print("✓ Persistent storage capabilities") + + print("\n🚀 Ready for model integration:") + print("1. CNN models can use BaseDataInput and store ModelOutput") + print("2. RL models can access CNN hidden states via cross-model feeding") + print("3. Orchestrator can calculate consensus from all models") + print("4. New model types can be added without code changes") + print("5. 
All models receive identical standardized input format") + + return provider + +if __name__ == "__main__": + test_integrated_standardized_provider() \ No newline at end of file diff --git a/test_model_output_manager.py b/test_model_output_manager.py new file mode 100644 index 0000000..85f6048 --- /dev/null +++ b/test_model_output_manager.py @@ -0,0 +1,182 @@ +""" +Test script for ModelOutputManager + +This script tests the extensible model output storage functionality +""" + +import sys +import os +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +import logging +from datetime import datetime +from core.model_output_manager import ModelOutputManager +from core.data_models import create_model_output, ModelOutput + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def test_model_output_manager(): + """Test the ModelOutputManager functionality""" + + print("Testing ModelOutputManager...") + + # Initialize the manager + manager = ModelOutputManager(cache_dir="test_cache/model_outputs", max_history=100) + + print(f"✅ ModelOutputManager initialized") + print(f" Supported model types: {manager.get_supported_model_types()}") + + # Test 1: Store outputs from different model types + print("\n1. 
Testing model output storage...") + + # Create outputs from different model types + models_to_test = [ + ('cnn', 'enhanced_cnn_v1', 'BUY', 0.85), + ('rl', 'dqn_agent_v2', 'SELL', 0.72), + ('lstm', 'lstm_predictor', 'HOLD', 0.65), + ('transformer', 'transformer_v1', 'BUY', 0.91), + ('orchestrator', 'main_orchestrator', 'BUY', 0.78) + ] + + symbol = 'ETH/USDT' + stored_outputs = [] + + for model_type, model_name, action, confidence in models_to_test: + # Create model output with hidden states for cross-model feeding + hidden_states = { + 'layer_1': [0.1, 0.2, 0.3], + 'layer_2': [0.4, 0.5, 0.6], + 'attention_weights': [0.7, 0.8, 0.9] + } if model_type in ['cnn', 'transformer'] else None + + metadata = { + 'model_version': '1.0', + 'training_iterations': 1000, + 'last_updated': datetime.now().isoformat() + } + + model_output = create_model_output( + model_type=model_type, + model_name=model_name, + symbol=symbol, + action=action, + confidence=confidence, + hidden_states=hidden_states, + metadata=metadata + ) + + # Store the output + success = manager.store_output(model_output) + if success: + print(f"✅ Stored {model_type} output: {action} ({confidence})") + stored_outputs.append(model_output) + else: + print(f"❌ Failed to store {model_type} output") + + # Test 2: Retrieve current outputs + print("\n2. Testing output retrieval...") + + all_current = manager.get_all_current_outputs(symbol) + print(f"✅ Retrieved {len(all_current)} current outputs for {symbol}") + + for model_name, output in all_current.items(): + print(f" {model_name} ({output.model_type}): {output.predictions['action']} - {output.confidence}") + + # Test 3: Cross-model feeding + print("\n3. 
Testing cross-model feeding...") + + cross_model_states = manager.get_cross_model_states(symbol, 'dqn_agent_v2') + print(f"✅ Retrieved cross-model states for RL model: {len(cross_model_states)} models") + + for model_name, states in cross_model_states.items(): + if states: + print(f" {model_name}: {len(states)} hidden state layers") + + # Test 4: Consensus prediction + print("\n4. Testing consensus prediction...") + + consensus = manager.get_consensus_prediction(symbol, confidence_threshold=0.7) + if consensus: + print(f"✅ Consensus prediction: {consensus['action']} (confidence: {consensus['confidence']:.3f})") + print(f" Votes: {consensus['votes']}") + print(f" Models: {consensus['model_types']}") + else: + print("⚠️ No consensus reached (insufficient high-confidence predictions)") + + # Test 5: Performance summary + print("\n5. Testing performance tracking...") + + performance = manager.get_performance_summary(symbol) + print(f"✅ Performance summary for {symbol}:") + print(f" Active models: {performance['active_models']}") + + for model_name, stats in performance['model_stats'].items(): + print(f" {model_name} ({stats['model_type']}): {stats['predictions']} predictions, " + f"avg confidence: {stats['avg_confidence']}") + + # Test 6: Custom model type support + print("\n6. 
Testing custom model type support...") + + # Add a custom model type + manager.add_custom_model_type('hybrid_ensemble') + + # Create output with custom model type + custom_output = create_model_output( + model_type='hybrid_ensemble', + model_name='custom_ensemble_v1', + symbol=symbol, + action='BUY', + confidence=0.88, + metadata={'ensemble_size': 5, 'voting_method': 'weighted'} + ) + + success = manager.store_output(custom_output) + if success: + print("✅ Custom model type 'hybrid_ensemble' stored successfully") + else: + print("❌ Failed to store custom model type") + + print(f" Updated supported types: {len(manager.get_supported_model_types())} types") + + # Test 7: Historical outputs + print("\n7. Testing historical output tracking...") + + # Store a few more outputs to build history + for i in range(3): + historical_output = create_model_output( + model_type='cnn', + model_name='enhanced_cnn_v1', + symbol=symbol, + action='HOLD', + confidence=0.6 + i * 0.1 + ) + manager.store_output(historical_output) + + history = manager.get_output_history(symbol, 'enhanced_cnn_v1', count=5) + print(f"✅ Retrieved {len(history)} historical outputs for enhanced_cnn_v1") + + for i, output in enumerate(history): + print(f" {i+1}. {output.predictions['action']} ({output.confidence}) at {output.timestamp}") + + # Test 8: Active model types + print("\n8. 
Testing active model type detection...") + + active_types = manager.get_model_types_active(symbol) + print(f"✅ Active model types for {symbol}: {active_types}") + + print("\n✅ ModelOutputManager test completed successfully!") + print("\nKey features verified:") + print("✓ Extensible model type support (CNN, RL, LSTM, Transformer, Custom)") + print("✓ Cross-model feeding with hidden states") + print("✓ Historical output tracking") + print("✓ Performance analytics") + print("✓ Consensus prediction calculation") + print("✓ Metadata management") + print("✓ Thread-safe storage operations") + + return manager + +if __name__ == "__main__": + test_model_output_manager() \ No newline at end of file diff --git a/test_standardized_data_provider.py b/test_standardized_data_provider.py new file mode 100644 index 0000000..9952634 --- /dev/null +++ b/test_standardized_data_provider.py @@ -0,0 +1,128 @@ +""" +Test script for StandardizedDataProvider + +This script tests the standardized BaseDataInput functionality +""" + +import sys +import os +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +import logging +from datetime import datetime +from core.standardized_data_provider import StandardizedDataProvider +from core.data_models import create_model_output + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def test_standardized_data_provider(): + """Test the StandardizedDataProvider functionality""" + + print("Testing StandardizedDataProvider...") + + # Initialize the provider + symbols = ['ETH/USDT', 'BTC/USDT'] + timeframes = ['1s', '1m', '1h', '1d'] + + provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes) + + # Test getting BaseDataInput + print("\n1. 
Testing BaseDataInput creation...") + base_input = provider.get_base_data_input('ETH/USDT') + + if base_input is None: + print("❌ BaseDataInput is None - this is expected if no historical data is available") + print(" The provider needs real market data to create BaseDataInput") + + # Test with mock data + print("\n2. Testing data structures...") + + # Test ModelOutput creation + model_output = create_model_output( + model_type='cnn', + model_name='test_cnn', + symbol='ETH/USDT', + action='BUY', + confidence=0.75, + metadata={'test': True} + ) + + print(f"✅ Created ModelOutput: {model_output.model_type} - {model_output.predictions['action']} ({model_output.confidence})") + + # Test storing model output + provider.store_model_output(model_output) + stored_outputs = provider.get_model_outputs('ETH/USDT') + + if 'test_cnn' in stored_outputs: + print("✅ Model output storage and retrieval working") + else: + print("❌ Model output storage failed") + + else: + print("✅ BaseDataInput created successfully!") + print(f" Symbol: {base_input.symbol}") + print(f" Timestamp: {base_input.timestamp}") + print(f" OHLCV 1s frames: {len(base_input.ohlcv_1s)}") + print(f" OHLCV 1m frames: {len(base_input.ohlcv_1m)}") + print(f" OHLCV 1h frames: {len(base_input.ohlcv_1h)}") + print(f" OHLCV 1d frames: {len(base_input.ohlcv_1d)}") + print(f" BTC 1s frames: {len(base_input.btc_ohlcv_1s)}") + print(f" COB data available: {base_input.cob_data is not None}") + print(f" Technical indicators: {len(base_input.technical_indicators)}") + print(f" Pivot points: {len(base_input.pivot_points)}") + print(f" Last predictions: {len(base_input.last_predictions)}") + + # Test feature vector creation + try: + feature_vector = base_input.get_feature_vector() + print(f"✅ Feature vector created: shape {feature_vector.shape}") + except Exception as e: + print(f"❌ Feature vector creation failed: {e}") + + # Test validation + is_valid = base_input.validate() + print(f"✅ BaseDataInput validation: {'PASSED' if 
is_valid else 'FAILED'}") + + print("\n3. Testing data provider capabilities...") + + # Test historical data fetching + try: + eth_data = provider.get_historical_data('ETH/USDT', '1h', 10) + if eth_data is not None and not eth_data.empty: + print(f"✅ Historical data available: {len(eth_data)} bars for ETH/USDT 1h") + else: + print("⚠️ No historical data available - this is normal if APIs are not accessible") + except Exception as e: + print(f"⚠️ Historical data fetch error: {e}") + + print("\n4. Testing COB data functionality...") + + # Test COB data creation + try: + # Set a mock current price for testing + provider.current_prices['ETHUSDT'] = 3000.0 + cob_data = provider._get_cob_data('ETH/USDT', datetime.now()) + + if cob_data: + print(f"✅ COB data created successfully") + print(f" Current price: ${cob_data.current_price}") + print(f" Bucket size: ${cob_data.bucket_size}") + print(f" Price buckets: {len(cob_data.price_buckets)}") + print(f" MA 1s imbalance: {len(cob_data.ma_1s_imbalance)} buckets") + print(f" MA 5s imbalance: {len(cob_data.ma_5s_imbalance)} buckets") + else: + print("⚠️ COB data creation returned None") + except Exception as e: + print(f"❌ COB data creation error: {e}") + + print("\n✅ StandardizedDataProvider test completed!") + print("\nNext steps:") + print("1. Integrate with real market data APIs") + print("2. Connect to actual COB provider") + print("3. Test with live data streams") + print("4. Integrate with model training pipelines") + +if __name__ == "__main__": + test_standardized_data_provider() \ No newline at end of file