cleanup and removed dummy data

This commit is contained in:
Dobromir Popov
2025-07-26 23:35:14 +03:00
parent 3eb6335169
commit 87942d3807
14 changed files with 220 additions and 2465 deletions

View File

@ -1,190 +0,0 @@
"""
Simplified Data Cache System
Replaces complex FIFO queues with a simple current state cache.
Supports unordered updates and extensible data types.
"""
import threading
import time
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, field
from collections import defaultdict
import pandas as pd
import numpy as np
logger = logging.getLogger(__name__)
@dataclass
class DataCacheEntry:
"""Single cache entry with metadata"""
data: Any
timestamp: datetime
source: str = "unknown"
version: int = 1
class DataCache:
"""
Simplified data cache that stores only the latest data for each type.
Thread-safe and supports unordered updates from multiple sources.
"""
def __init__(self):
self.cache: Dict[str, Dict[str, DataCacheEntry]] = defaultdict(dict) # {data_type: {symbol: entry}}
self.locks: Dict[str, threading.RLock] = defaultdict(threading.RLock) # Per data_type locks
self.update_callbacks: Dict[str, List[Callable]] = defaultdict(list) # Update notifications
# Historical data storage (loaded once)
self.historical_data: Dict[str, Dict[str, pd.DataFrame]] = defaultdict(dict) # {symbol: {timeframe: df}}
self.historical_locks: Dict[str, threading.RLock] = defaultdict(threading.RLock)
logger.info("DataCache initialized with simplified architecture")
def update(self, data_type: str, symbol: str, data: Any, source: str = "unknown") -> bool:
"""
Update cache with latest data (thread-safe, unordered updates supported)
Args:
data_type: Type of data ('ohlcv_1s', 'technical_indicators', etc.)
symbol: Trading symbol
data: New data to store
source: Source of the update
Returns:
bool: True if updated successfully
"""
try:
with self.locks[data_type]:
# Create or update entry
old_entry = self.cache[data_type].get(symbol)
new_version = (old_entry.version + 1) if old_entry else 1
self.cache[data_type][symbol] = DataCacheEntry(
data=data,
timestamp=datetime.now(),
source=source,
version=new_version
)
# Notify callbacks
for callback in self.update_callbacks[data_type]:
try:
callback(symbol, data, source)
except Exception as e:
logger.error(f"Error in update callback: {e}")
return True
except Exception as e:
logger.error(f"Error updating cache {data_type}/{symbol}: {e}")
return False
def get(self, data_type: str, symbol: str) -> Optional[Any]:
"""Get latest data for a type/symbol"""
try:
with self.locks[data_type]:
entry = self.cache[data_type].get(symbol)
return entry.data if entry else None
except Exception as e:
logger.error(f"Error getting cache {data_type}/{symbol}: {e}")
return None
def get_with_metadata(self, data_type: str, symbol: str) -> Optional[DataCacheEntry]:
"""Get latest data with metadata"""
try:
with self.locks[data_type]:
return self.cache[data_type].get(symbol)
except Exception as e:
logger.error(f"Error getting cache metadata {data_type}/{symbol}: {e}")
return None
def get_all(self, data_type: str) -> Dict[str, Any]:
"""Get all data for a data type"""
try:
with self.locks[data_type]:
return {symbol: entry.data for symbol, entry in self.cache[data_type].items()}
except Exception as e:
logger.error(f"Error getting all cache data for {data_type}: {e}")
return {}
def has_data(self, data_type: str, symbol: str, max_age_seconds: int = None) -> bool:
"""Check if we have recent data"""
try:
with self.locks[data_type]:
entry = self.cache[data_type].get(symbol)
if not entry:
return False
if max_age_seconds:
age = (datetime.now() - entry.timestamp).total_seconds()
return age <= max_age_seconds
return True
except Exception as e:
logger.error(f"Error checking cache data {data_type}/{symbol}: {e}")
return False
def register_callback(self, data_type: str, callback: Callable[[str, Any, str], None]):
"""Register callback for data updates"""
self.update_callbacks[data_type].append(callback)
def get_status(self) -> Dict[str, Dict[str, Dict[str, Any]]]:
"""Get cache status for monitoring"""
status = {}
for data_type in self.cache:
with self.locks[data_type]:
status[data_type] = {}
for symbol, entry in self.cache[data_type].items():
age_seconds = (datetime.now() - entry.timestamp).total_seconds()
status[data_type][symbol] = {
'timestamp': entry.timestamp.isoformat(),
'age_seconds': age_seconds,
'source': entry.source,
'version': entry.version,
'has_data': entry.data is not None
}
return status
# Historical data management
def store_historical_data(self, symbol: str, timeframe: str, df: pd.DataFrame):
"""Store historical data (loaded once at startup)"""
try:
with self.historical_locks[symbol]:
self.historical_data[symbol][timeframe] = df.copy()
logger.info(f"Stored {len(df)} historical bars for {symbol} {timeframe}")
except Exception as e:
logger.error(f"Error storing historical data {symbol}/{timeframe}: {e}")
def get_historical_data(self, symbol: str, timeframe: str) -> Optional[pd.DataFrame]:
"""Get historical data"""
try:
with self.historical_locks[symbol]:
return self.historical_data[symbol].get(timeframe)
except Exception as e:
logger.error(f"Error getting historical data {symbol}/{timeframe}: {e}")
return None
def has_historical_data(self, symbol: str, timeframe: str, min_bars: int = 100) -> bool:
"""Check if we have sufficient historical data"""
try:
with self.historical_locks[symbol]:
df = self.historical_data[symbol].get(timeframe)
return df is not None and len(df) >= min_bars
except Exception:
return False
# Global cache instance
_data_cache_instance = None
def get_data_cache() -> DataCache:
"""Get the global data cache instance"""
global _data_cache_instance
if _data_cache_instance is None:
_data_cache_instance = DataCache()
return _data_cache_instance

View File

@ -114,42 +114,32 @@ class BaseDataInput:
FIXED_FEATURE_SIZE = 7850
features = []
# OHLCV features for ETH (300 frames x 4 timeframes x 5 features = 6000 features)
# OHLCV features for ETH (up to 300 frames x 4 timeframes x 5 features)
for ohlcv_list in [self.ohlcv_1s, self.ohlcv_1m, self.ohlcv_1h, self.ohlcv_1d]:
# Ensure exactly 300 frames by padding or truncating
# Use actual data only, up to 300 frames
ohlcv_frames = ohlcv_list[-300:] if len(ohlcv_list) >= 300 else ohlcv_list
# Pad with zeros if not enough data
while len(ohlcv_frames) < 300:
# Create a dummy OHLCV bar with zeros
dummy_bar = OHLCVBar(
symbol="ETH/USDT",
timestamp=datetime.now(),
open=0.0, high=0.0, low=0.0, close=0.0, volume=0.0,
timeframe="1s"
)
ohlcv_frames.insert(0, dummy_bar)
# Extract features from exactly 300 frames
# Extract features from actual frames
for bar in ohlcv_frames:
features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
# BTC OHLCV features (300 frames x 5 features = 1500 features)
# Pad with zeros only if we have some data but less than 300 frames
frames_needed = 300 - len(ohlcv_frames)
if frames_needed > 0:
features.extend([0.0] * (frames_needed * 5)) # 5 features per frame
# BTC OHLCV features (up to 300 frames x 5 features = 1500 features)
btc_frames = self.btc_ohlcv_1s[-300:] if len(self.btc_ohlcv_1s) >= 300 else self.btc_ohlcv_1s
# Pad BTC data if needed
while len(btc_frames) < 300:
dummy_bar = OHLCVBar(
symbol="BTC/USDT",
timestamp=datetime.now(),
open=0.0, high=0.0, low=0.0, close=0.0, volume=0.0,
timeframe="1s"
)
btc_frames.insert(0, dummy_bar)
# Extract features from actual BTC frames
for bar in btc_frames:
features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
# Pad with zeros only if we have some data but less than 300 frames
btc_frames_needed = 300 - len(btc_frames)
if btc_frames_needed > 0:
features.extend([0.0] * (btc_frames_needed * 5)) # 5 features per frame
# COB features (FIXED SIZE: 200 features)
cob_features = []
if self.cob_data:

View File

@ -224,6 +224,12 @@ class DataProvider:
self.cob_data_cache[binance_symbol] = deque(maxlen=300) # 5 minutes of COB data
self.training_data_cache[binance_symbol] = deque(maxlen=1000) # Training data buffer
# Pre-built OHLCV cache for instant BaseDataInput building (optimization from SimplifiedDataIntegration)
self._ohlcv_cache = {} # {symbol: {timeframe: List[OHLCVBar]}}
self._ohlcv_cache_lock = Lock()
self._last_cache_update = {} # {symbol: {timeframe: datetime}}
self._cache_refresh_interval = 5 # seconds
# Data collection threads
self.data_collection_active = False
@ -1387,6 +1393,175 @@ class DataProvider:
logger.error(f"Error applying pivot normalization for {symbol}: {e}")
return df
def build_base_data_input(self, symbol: str) -> Optional['BaseDataInput']:
"""
Build BaseDataInput from cached data (optimized for speed)
Args:
symbol: Trading symbol
Returns:
BaseDataInput with consistent data structure
"""
try:
from .data_models import BaseDataInput
# Get OHLCV data directly from optimized cache (no validation checks for speed)
ohlcv_1s_list = self._get_cached_ohlcv_bars(symbol, '1s', 300)
ohlcv_1m_list = self._get_cached_ohlcv_bars(symbol, '1m', 300)
ohlcv_1h_list = self._get_cached_ohlcv_bars(symbol, '1h', 300)
ohlcv_1d_list = self._get_cached_ohlcv_bars(symbol, '1d', 300)
# Get BTC reference data
btc_symbol = 'BTC/USDT'
btc_ohlcv_1s_list = self._get_cached_ohlcv_bars(btc_symbol, '1s', 300)
if not btc_ohlcv_1s_list:
# Use ETH data as fallback
btc_ohlcv_1s_list = ohlcv_1s_list
# Get cached data (fast lookups)
technical_indicators = self._get_latest_technical_indicators(symbol)
cob_data = self._get_latest_cob_data_object(symbol)
last_predictions = {} # TODO: Implement model prediction caching
# Build BaseDataInput (no validation for speed - assume data is good)
base_data = BaseDataInput(
symbol=symbol,
timestamp=datetime.now(),
ohlcv_1s=ohlcv_1s_list,
ohlcv_1m=ohlcv_1m_list,
ohlcv_1h=ohlcv_1h_list,
ohlcv_1d=ohlcv_1d_list,
btc_ohlcv_1s=btc_ohlcv_1s_list,
technical_indicators=technical_indicators,
cob_data=cob_data,
last_predictions=last_predictions
)
return base_data
except Exception as e:
logger.error(f"Error building BaseDataInput for {symbol}: {e}")
return None
def _get_cached_ohlcv_bars(self, symbol: str, timeframe: str, max_count: int) -> List['OHLCVBar']:
"""Get OHLCV data list from pre-built cache for instant access"""
try:
with self._ohlcv_cache_lock:
cache_key = f"{symbol}_{timeframe}"
# Check if we have fresh cached data (updated within last 5 seconds)
last_update = self._last_cache_update.get(cache_key)
if (last_update and
(datetime.now() - last_update).total_seconds() < self._cache_refresh_interval and
cache_key in self._ohlcv_cache):
cached_data = self._ohlcv_cache[cache_key]
return cached_data[-max_count:] if len(cached_data) >= max_count else cached_data
# Need to rebuild cache for this symbol/timeframe
data_list = self._build_ohlcv_bar_cache(symbol, timeframe, max_count)
# Cache the result
self._ohlcv_cache[cache_key] = data_list
self._last_cache_update[cache_key] = datetime.now()
return data_list[-max_count:] if len(data_list) >= max_count else data_list
except Exception as e:
logger.error(f"Error getting cached OHLCV bars for {symbol}/{timeframe}: {e}")
return []
def _build_ohlcv_bar_cache(self, symbol: str, timeframe: str, max_count: int) -> List['OHLCVBar']:
"""Build OHLCV bar cache from historical and current data"""
try:
from .data_models import OHLCVBar
data_list = []
# Get historical data first (this should be fast as it's already cached)
historical_df = self.get_historical_data(symbol, timeframe, limit=max_count)
if historical_df is not None and not historical_df.empty:
# Convert historical data to OHLCVBar objects
for idx, row in historical_df.tail(max_count).iterrows():
bar = OHLCVBar(
symbol=symbol,
timestamp=idx if hasattr(idx, 'to_pydatetime') else datetime.now(),
open=float(row['open']),
high=float(row['high']),
low=float(row['low']),
close=float(row['close']),
volume=float(row['volume']),
timeframe=timeframe
)
data_list.append(bar)
return data_list
except Exception as e:
logger.error(f"Error building OHLCV bar cache for {symbol}/{timeframe}: {e}")
return []
def _get_latest_technical_indicators(self, symbol: str) -> Dict[str, float]:
"""Get latest technical indicators for a symbol"""
try:
# Get latest data and calculate indicators
df = self.get_historical_data(symbol, '1h', limit=50)
if df is not None and not df.empty:
df_with_indicators = self._add_technical_indicators(df)
if not df_with_indicators.empty:
# Return the latest indicators as a dict
latest_row = df_with_indicators.iloc[-1]
indicators = {}
for col in df_with_indicators.columns:
if col not in ['open', 'high', 'low', 'close', 'volume', 'timestamp']:
indicators[col] = float(latest_row[col]) if pd.notna(latest_row[col]) else 0.0
return indicators
return {}
except Exception as e:
logger.error(f"Error getting technical indicators for {symbol}: {e}")
return {}
def _get_latest_cob_data_object(self, symbol: str) -> Optional['COBData']:
"""Get latest COB data as COBData object"""
try:
from .data_models import COBData
# Get latest COB data from cache
cob_data = self.get_latest_cob_data(symbol)
if cob_data and 'current_price' in cob_data:
return COBData(
symbol=symbol,
timestamp=datetime.now(),
current_price=cob_data['current_price'],
bucket_size=1.0 if 'ETH' in symbol else 10.0,
price_buckets=cob_data.get('price_buckets', {}),
bid_ask_imbalance=cob_data.get('bid_ask_imbalance', {}),
volume_weighted_prices=cob_data.get('volume_weighted_prices', {}),
order_flow_metrics=cob_data.get('order_flow_metrics', {}),
ma_1s_imbalance=cob_data.get('ma_1s_imbalance', {}),
ma_5s_imbalance=cob_data.get('ma_5s_imbalance', {}),
ma_15s_imbalance=cob_data.get('ma_15s_imbalance', {}),
ma_60s_imbalance=cob_data.get('ma_60s_imbalance', {})
)
return None
except Exception as e:
logger.error(f"Error getting COB data object for {symbol}: {e}")
return None
def invalidate_ohlcv_cache(self, symbol: str):
"""Invalidate OHLCV cache for a symbol when new data arrives"""
try:
with self._ohlcv_cache_lock:
# Remove cached data for all timeframes of this symbol
keys_to_remove = [key for key in self._ohlcv_cache.keys() if key.startswith(f"{symbol}_")]
for key in keys_to_remove:
if key in self._ohlcv_cache:
del self._ohlcv_cache[key]
if key in self._last_cache_update:
del self._last_cache_update[key]
except Exception as e:
logger.error(f"Error invalidating OHLCV cache for {symbol}: {e}")
def _add_basic_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
"""Add basic indicators for small datasets"""
try:

View File

@ -179,12 +179,7 @@ class TradingOrchestrator:
self.fusion_decisions_count: int = 0
self.fusion_training_data: List[Any] = [] # Store training examples for decision model
# Simplified Data Integration - Replace complex FIFO queues with efficient cache
from core.simplified_data_integration import SimplifiedDataIntegration
self.data_integration = SimplifiedDataIntegration(
data_provider=self.data_provider,
symbols=[self.symbol] + self.ref_symbols
)
# Use data provider directly for BaseDataInput building (optimized)
# COB Integration - Real-time market microstructure data
self.cob_integration = None # Will be set to COBIntegration instance if available
@ -250,8 +245,7 @@ class TradingOrchestrator:
self.data_provider.start_centralized_data_collection()
logger.info("Centralized data collection started - all models and dashboard will receive data")
# Initialize simplified data integration
self._initialize_simplified_data_integration()
# Data provider is already initialized and optimized
# Log initial data status
logger.info("Simplified data integration initialized")
@ -897,31 +891,10 @@ class TradingOrchestrator:
try:
self.latest_cob_data[symbol] = cob_data
# Update data cache with COB data for BaseDataInput
if hasattr(self, 'data_integration') and self.data_integration:
# Convert cob_data to COBData format if needed
from .data_models import COBData
# Create COBData object from the raw cob_data
if 'price_buckets' in cob_data and 'current_price' in cob_data:
cob_data_obj = COBData(
symbol=symbol,
timestamp=datetime.now(),
current_price=cob_data['current_price'],
bucket_size=1.0 if 'ETH' in symbol else 10.0,
price_buckets=cob_data.get('price_buckets', {}),
bid_ask_imbalance=cob_data.get('bid_ask_imbalance', {}),
volume_weighted_prices=cob_data.get('volume_weighted_prices', {}),
order_flow_metrics=cob_data.get('order_flow_metrics', {}),
ma_1s_imbalance=cob_data.get('ma_1s_imbalance', {}),
ma_5s_imbalance=cob_data.get('ma_5s_imbalance', {}),
ma_15s_imbalance=cob_data.get('ma_15s_imbalance', {}),
ma_60s_imbalance=cob_data.get('ma_60s_imbalance', {})
)
# Update cache with COB data
self.data_integration.cache.update('cob_data', symbol, cob_data_obj, 'cob_integration')
logger.debug(f"Updated cache with COB data for {symbol}")
# Invalidate data provider cache when new COB data arrives
if hasattr(self.data_provider, 'invalidate_ohlcv_cache'):
self.data_provider.invalidate_ohlcv_cache(symbol)
logger.debug(f"Invalidated data provider cache for {symbol} due to COB update")
# Update dashboard
if self.dashboard and hasattr(self.dashboard, 'update_cob_data'):
@ -3722,54 +3695,34 @@ class TradingOrchestrator:
"""
return self.db_manager.get_best_checkpoint_metadata(model_name)
# === SIMPLIFIED DATA MANAGEMENT ===
def _initialize_simplified_data_integration(self):
"""Initialize the simplified data integration system"""
try:
# Start the data integration system
self.data_integration.start()
logger.info("Simplified data integration started successfully")
except Exception as e:
logger.error(f"Error starting simplified data integration: {e}")
# === DATA MANAGEMENT ===
def _log_data_status(self):
"""Log current data status"""
try:
status = self.data_integration.get_cache_status()
cache_status = status.get('cache_status', {})
logger.info("=== Data Cache Status ===")
for data_type, symbols_data in cache_status.items():
symbol_info = []
for symbol, info in symbols_data.items():
age = info.get('age_seconds', 0)
has_data = info.get('has_data', False)
if has_data and age < 300: # Recent data
symbol_info.append(f"{symbol}:✅")
else:
symbol_info.append(f"{symbol}:❌")
if symbol_info:
logger.info(f"{data_type}: {', '.join(symbol_info)}")
logger.info("=== Data Provider Status ===")
logger.info("Data provider is running and optimized for BaseDataInput building")
except Exception as e:
logger.error(f"Error logging data status: {e}")
def update_data_cache(self, data_type: str, symbol: str, data: Any, source: str = "orchestrator") -> bool:
"""
Update data cache with new data (simplified approach)
Update data cache through data provider
Args:
data_type: Type of data ('ohlcv_1s', 'technical_indicators', etc.)
symbol: Trading symbol
data: New data to store
source: Source of the data
data: Data to store
source: Source of the update
Returns:
bool: True if successful
bool: True if updated successfully
"""
try:
return self.data_integration.cache.update(data_type, symbol, data, source)
# Invalidate cache when new data arrives
if hasattr(self.data_provider, 'invalidate_ohlcv_cache'):
self.data_provider.invalidate_ohlcv_cache(symbol)
return True
except Exception as e:
logger.error(f"Error updating data cache {data_type}/{symbol}: {e}")
return False
@ -3929,7 +3882,7 @@ class TradingOrchestrator:
def build_base_data_input(self, symbol: str) -> Optional[Any]:
"""
Build BaseDataInput using simplified data integration (optimized for speed)
Build BaseDataInput using optimized data provider (should be instantaneous)
Args:
symbol: Trading symbol
@ -3938,8 +3891,8 @@ class TradingOrchestrator:
BaseDataInput with consistent data structure
"""
try:
# Use simplified data integration to build BaseDataInput (should be instantaneous)
return self.data_integration.build_base_data_input(symbol)
# Use data provider's optimized build_base_data_input method
return self.data_provider.build_base_data_input(symbol)
except Exception as e:
logger.error(f"Error building BaseDataInput for {symbol}: {e}")

View File

@ -1,284 +0,0 @@
"""
Simplified Data Integration for Orchestrator
Replaces complex FIFO queues with simple cache-based data access.
Integrates with SmartDataUpdater for efficient data management.
"""
import logging
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
import pandas as pd
from .data_cache import get_data_cache
from .smart_data_updater import SmartDataUpdater
from .data_models import BaseDataInput, OHLCVBar
logger = logging.getLogger(__name__)
class SimplifiedDataIntegration:
"""
Simplified data integration that replaces FIFO queues with efficient caching
"""
def __init__(self, data_provider, symbols: List[str]):
self.data_provider = data_provider
self.symbols = symbols
self.cache = get_data_cache()
# Initialize smart data updater
self.data_updater = SmartDataUpdater(data_provider, symbols)
# Pre-built OHLCV data cache for instant access
self._ohlcv_cache = {} # {symbol: {timeframe: List[OHLCVBar]}}
self._ohlcv_cache_lock = threading.RLock()
self._last_cache_update = {} # {symbol: {timeframe: datetime}}
# Register for tick data if available
self._setup_tick_integration()
logger.info(f"SimplifiedDataIntegration initialized for {symbols}")
def start(self):
"""Start the data integration system"""
self.data_updater.start()
logger.info("SimplifiedDataIntegration started")
def stop(self):
"""Stop the data integration system"""
self.data_updater.stop()
logger.info("SimplifiedDataIntegration stopped")
def _setup_tick_integration(self):
"""Setup integration with tick data sources"""
try:
# Register callbacks for tick data if available
if hasattr(self.data_provider, 'register_tick_callback'):
self.data_provider.register_tick_callback(self._on_tick_data)
# Register for WebSocket data if available
if hasattr(self.data_provider, 'register_websocket_callback'):
self.data_provider.register_websocket_callback(self._on_websocket_data)
except Exception as e:
logger.warning(f"Tick integration setup failed: {e}")
def _on_tick_data(self, symbol: str, price: float, volume: float, timestamp: datetime = None):
"""Handle incoming tick data"""
self.data_updater.add_tick(symbol, price, volume, timestamp)
# Invalidate OHLCV cache for this symbol
self._invalidate_ohlcv_cache(symbol)
def _on_websocket_data(self, symbol: str, data: Dict[str, Any]):
"""Handle WebSocket data updates"""
try:
# Extract price and volume from WebSocket data
if 'price' in data and 'volume' in data:
self.data_updater.add_tick(symbol, data['price'], data['volume'])
# Invalidate OHLCV cache for this symbol
self._invalidate_ohlcv_cache(symbol)
except Exception as e:
logger.error(f"Error processing WebSocket data: {e}")
def _invalidate_ohlcv_cache(self, symbol: str):
"""Invalidate OHLCV cache for a symbol when new data arrives"""
try:
with self._ohlcv_cache_lock:
# Remove cached data for all timeframes of this symbol
keys_to_remove = [key for key in self._ohlcv_cache.keys() if key.startswith(f"{symbol}_")]
for key in keys_to_remove:
if key in self._ohlcv_cache:
del self._ohlcv_cache[key]
if key in self._last_cache_update:
del self._last_cache_update[key]
except Exception as e:
logger.error(f"Error invalidating OHLCV cache for {symbol}: {e}")
def build_base_data_input(self, symbol: str) -> Optional[BaseDataInput]:
"""
Build BaseDataInput from cached data (optimized for speed)
Args:
symbol: Trading symbol
Returns:
BaseDataInput with consistent data structure
"""
try:
# Get OHLCV data directly from optimized cache (no validation checks for speed)
ohlcv_1s_list = self._get_ohlcv_data_list(symbol, '1s', 300)
ohlcv_1m_list = self._get_ohlcv_data_list(symbol, '1m', 300)
ohlcv_1h_list = self._get_ohlcv_data_list(symbol, '1h', 300)
ohlcv_1d_list = self._get_ohlcv_data_list(symbol, '1d', 300)
# Get BTC reference data
btc_symbol = 'BTC/USDT'
btc_ohlcv_1s_list = self._get_ohlcv_data_list(btc_symbol, '1s', 300)
if not btc_ohlcv_1s_list:
# Use ETH data as fallback
btc_ohlcv_1s_list = ohlcv_1s_list
# Get cached data (fast lookups)
technical_indicators = self.cache.get('technical_indicators', symbol) or {}
cob_data = self.cache.get('cob_data', symbol)
last_predictions = self._get_recent_predictions(symbol)
# Build BaseDataInput (no validation for speed - assume data is good)
base_data = BaseDataInput(
symbol=symbol,
timestamp=datetime.now(),
ohlcv_1s=ohlcv_1s_list,
ohlcv_1m=ohlcv_1m_list,
ohlcv_1h=ohlcv_1h_list,
ohlcv_1d=ohlcv_1d_list,
btc_ohlcv_1s=btc_ohlcv_1s_list,
technical_indicators=technical_indicators,
cob_data=cob_data,
last_predictions=last_predictions
)
return base_data
except Exception as e:
logger.error(f"Error building BaseDataInput for {symbol}: {e}")
return None
def _get_ohlcv_data_list(self, symbol: str, timeframe: str, max_count: int) -> List[OHLCVBar]:
"""Get OHLCV data list from pre-built cache for instant access"""
try:
with self._ohlcv_cache_lock:
cache_key = f"{symbol}_{timeframe}"
# Check if we have fresh cached data (updated within last 5 seconds)
last_update = self._last_cache_update.get(cache_key)
if (last_update and
(datetime.now() - last_update).total_seconds() < 5 and
cache_key in self._ohlcv_cache):
cached_data = self._ohlcv_cache[cache_key]
return cached_data[-max_count:] if len(cached_data) >= max_count else cached_data
# Need to rebuild cache for this symbol/timeframe
data_list = self._build_ohlcv_cache(symbol, timeframe, max_count)
# Cache the result
self._ohlcv_cache[cache_key] = data_list
self._last_cache_update[cache_key] = datetime.now()
return data_list[-max_count:] if len(data_list) >= max_count else data_list
except Exception as e:
logger.error(f"Error getting OHLCV data list for {symbol}/{timeframe}: {e}")
return self._create_dummy_data_list(symbol, timeframe, max_count)
def _build_ohlcv_cache(self, symbol: str, timeframe: str, max_count: int) -> List[OHLCVBar]:
"""Build OHLCV cache from historical and current data"""
try:
data_list = []
# Get historical data first (this should be fast as it's already cached)
historical_df = self.cache.get_historical_data(symbol, timeframe)
if historical_df is not None and not historical_df.empty:
# Convert historical data to OHLCVBar objects
for idx, row in historical_df.tail(max_count - 1).iterrows():
bar = OHLCVBar(
symbol=symbol,
timestamp=idx if hasattr(idx, 'to_pydatetime') else datetime.now(),
open=float(row['open']),
high=float(row['high']),
low=float(row['low']),
close=float(row['close']),
volume=float(row['volume']),
timeframe=timeframe
)
data_list.append(bar)
# Add current data from cache
current_ohlcv = self.cache.get(f'ohlcv_{timeframe}', symbol)
if current_ohlcv and isinstance(current_ohlcv, OHLCVBar):
data_list.append(current_ohlcv)
# Ensure we have the right amount of data (pad if necessary)
while len(data_list) < max_count:
data_list.extend(self._create_dummy_data_list(symbol, timeframe, max_count - len(data_list)))
return data_list
except Exception as e:
logger.error(f"Error building OHLCV cache for {symbol}/{timeframe}: {e}")
return self._create_dummy_data_list(symbol, timeframe, max_count)
def _try_historical_fallback(self, symbol: str, missing_timeframes: List[str]) -> bool:
"""Try to use historical data for missing timeframes"""
try:
for timeframe in missing_timeframes:
historical_df = self.cache.get_historical_data(symbol, timeframe)
if historical_df is not None and not historical_df.empty:
# Use latest historical data as current data
latest_row = historical_df.iloc[-1]
ohlcv_bar = OHLCVBar(
symbol=symbol,
timestamp=historical_df.index[-1] if hasattr(historical_df.index[-1], 'to_pydatetime') else datetime.now(),
open=float(latest_row['open']),
high=float(latest_row['high']),
low=float(latest_row['low']),
close=float(latest_row['close']),
volume=float(latest_row['volume']),
timeframe=timeframe
)
self.cache.update(f'ohlcv_{timeframe}', symbol, ohlcv_bar, 'historical_fallback')
logger.info(f"Used historical fallback for {symbol} {timeframe}")
return True
except Exception as e:
logger.error(f"Error in historical fallback: {e}")
return False
def _get_recent_predictions(self, symbol: str) -> Dict[str, Any]:
"""Get recent model predictions"""
try:
predictions = {}
# Get predictions from cache
for model_type in ['cnn', 'rl', 'extrema']:
prediction_data = self.cache.get(f'prediction_{model_type}', symbol)
if prediction_data:
predictions[model_type] = prediction_data
return predictions
except Exception as e:
logger.error(f"Error getting recent predictions for {symbol}: {e}")
return {}
def update_model_prediction(self, model_name: str, symbol: str, prediction_data: Any):
"""Update model prediction in cache"""
self.cache.update(f'prediction_{model_name}', symbol, prediction_data, model_name)
def get_current_price(self, symbol: str) -> Optional[float]:
"""Get current price for a symbol"""
return self.data_updater.get_current_price(symbol)
def get_cache_status(self) -> Dict[str, Any]:
"""Get cache status for monitoring"""
return {
'cache_status': self.cache.get_status(),
'updater_status': self.data_updater.get_status()
}
def has_sufficient_data(self, symbol: str) -> bool:
"""Check if we have sufficient data for model predictions"""
required_data = ['ohlcv_1s', 'ohlcv_1m', 'ohlcv_1h', 'ohlcv_1d']
for data_type in required_data:
if not self.cache.has_data(data_type, symbol, max_age_seconds=300):
# Check historical data as fallback
timeframe = data_type.split('_')[1]
if not self.cache.has_historical_data(symbol, timeframe, min_bars=50):
return False
return True

View File

@ -1,358 +0,0 @@
"""
Smart Data Updater
Efficiently manages data updates using:
1. Initial historical data load (once)
2. Live tick data from WebSocket
3. Periodic HTTP updates (1m every minute, 1h every hour)
4. Smart candle construction from ticks
"""
import threading
import time
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
import pandas as pd
import numpy as np
from collections import deque
from .data_cache import get_data_cache, DataCache
from .data_models import OHLCVBar
logger = logging.getLogger(__name__)
class SmartDataUpdater:
"""
Smart data updater that efficiently manages market data with minimal API calls
"""
def __init__(self, data_provider, symbols: List[str]):
self.data_provider = data_provider
self.symbols = symbols
self.cache = get_data_cache()
self.running = False
# Tick data for candle construction
self.tick_buffers: Dict[str, deque] = {symbol: deque(maxlen=1000) for symbol in symbols}
self.tick_locks: Dict[str, threading.Lock] = {symbol: threading.Lock() for symbol in symbols}
# Current candle construction
self.current_candles: Dict[str, Dict[str, Dict]] = {} # {symbol: {timeframe: candle_data}}
self.candle_locks: Dict[str, threading.Lock] = {symbol: threading.Lock() for symbol in symbols}
# Update timers
self.last_updates: Dict[str, Dict[str, datetime]] = {} # {symbol: {timeframe: last_update}}
# Update intervals (in seconds)
self.update_intervals = {
'1s': 10, # Update 1s candles every 10 seconds from ticks
'1m': 60, # Update 1m candles every minute via HTTP
'1h': 3600, # Update 1h candles every hour via HTTP
'1d': 86400 # Update 1d candles every day via HTTP
}
logger.info(f"SmartDataUpdater initialized for {len(symbols)} symbols")
def start(self):
"""Start the smart data updater"""
if self.running:
return
self.running = True
# Load initial historical data
self._load_initial_historical_data()
# Start update threads
self.update_thread = threading.Thread(target=self._update_worker, daemon=True)
self.update_thread.start()
# Start tick processing thread
self.tick_thread = threading.Thread(target=self._tick_processor, daemon=True)
self.tick_thread.start()
logger.info("SmartDataUpdater started")
def stop(self):
"""Stop the smart data updater"""
self.running = False
logger.info("SmartDataUpdater stopped")
def add_tick(self, symbol: str, price: float, volume: float, timestamp: datetime = None):
"""Add tick data for candle construction"""
if symbol not in self.tick_buffers:
return
tick_data = {
'price': price,
'volume': volume,
'timestamp': timestamp or datetime.now()
}
with self.tick_locks[symbol]:
self.tick_buffers[symbol].append(tick_data)
def _load_initial_historical_data(self):
"""Load initial historical data for all symbols and timeframes"""
logger.info("Loading initial historical data...")
timeframes = ['1s', '1m', '1h', '1d']
limits = {'1s': 300, '1m': 300, '1h': 300, '1d': 300}
for symbol in self.symbols:
self.last_updates[symbol] = {}
self.current_candles[symbol] = {}
for timeframe in timeframes:
try:
limit = limits.get(timeframe, 300)
# Get historical data
df = None
if hasattr(self.data_provider, 'get_historical_data'):
df = self.data_provider.get_historical_data(symbol, timeframe, limit=limit)
if df is not None and not df.empty:
# Store in cache
self.cache.store_historical_data(symbol, timeframe, df)
# Update current candle data from latest bar
latest_bar = df.iloc[-1]
self._update_current_candle_from_bar(symbol, timeframe, latest_bar)
# Update cache with latest OHLCV
ohlcv_bar = self._df_row_to_ohlcv_bar(symbol, timeframe, latest_bar, df.index[-1])
self.cache.update(f'ohlcv_{timeframe}', symbol, ohlcv_bar, 'historical')
self.last_updates[symbol][timeframe] = datetime.now()
logger.info(f"Loaded {len(df)} {timeframe} bars for {symbol}")
else:
logger.warning(f"No historical data for {symbol} {timeframe}")
except Exception as e:
logger.error(f"Error loading historical data for {symbol} {timeframe}: {e}")
# Calculate initial technical indicators
self._calculate_technical_indicators()
logger.info("Initial historical data loading completed")
def _update_worker(self):
"""Background worker for periodic data updates"""
while self.running:
try:
current_time = datetime.now()
for symbol in self.symbols:
for timeframe in ['1m', '1h', '1d']: # Skip 1s (built from ticks)
try:
# Check if it's time to update
last_update = self.last_updates[symbol].get(timeframe)
interval = self.update_intervals[timeframe]
if not last_update or (current_time - last_update).total_seconds() >= interval:
self._update_timeframe_data(symbol, timeframe)
self.last_updates[symbol][timeframe] = current_time
except Exception as e:
logger.error(f"Error updating {symbol} {timeframe}: {e}")
# Update technical indicators every minute
if current_time.second < 10: # Update in first 10 seconds of each minute
self._calculate_technical_indicators()
time.sleep(10) # Check every 10 seconds
except Exception as e:
logger.error(f"Error in update worker: {e}")
time.sleep(30)
def _tick_processor(self):
"""Process ticks to build 1s candles"""
while self.running:
try:
current_time = datetime.now()
for symbol in self.symbols:
# Check if it's time to update 1s candles
last_update = self.last_updates[symbol].get('1s')
if not last_update or (current_time - last_update).total_seconds() >= self.update_intervals['1s']:
self._build_1s_candle_from_ticks(symbol)
self.last_updates[symbol]['1s'] = current_time
time.sleep(5) # Process every 5 seconds
except Exception as e:
logger.error(f"Error in tick processor: {e}")
time.sleep(10)
def _update_timeframe_data(self, symbol: str, timeframe: str):
"""Update data for a specific timeframe via HTTP"""
try:
# Get latest data from API
df = None
if hasattr(self.data_provider, 'get_latest_candles'):
df = self.data_provider.get_latest_candles(symbol, timeframe, limit=1)
elif hasattr(self.data_provider, 'get_historical_data'):
df = self.data_provider.get_historical_data(symbol, timeframe, limit=1)
if df is not None and not df.empty:
latest_bar = df.iloc[-1]
# Update current candle
self._update_current_candle_from_bar(symbol, timeframe, latest_bar)
# Update cache
ohlcv_bar = self._df_row_to_ohlcv_bar(symbol, timeframe, latest_bar, df.index[-1])
self.cache.update(f'ohlcv_{timeframe}', symbol, ohlcv_bar, 'http_update')
logger.debug(f"Updated {symbol} {timeframe} via HTTP")
except Exception as e:
logger.error(f"Error updating {symbol} {timeframe} via HTTP: {e}")
def _build_1s_candle_from_ticks(self, symbol: str):
"""Build 1s candle from accumulated ticks"""
try:
with self.tick_locks[symbol]:
ticks = list(self.tick_buffers[symbol])
if not ticks:
return
# Get ticks from last 10 seconds
cutoff_time = datetime.now() - timedelta(seconds=10)
recent_ticks = [tick for tick in ticks if tick['timestamp'] >= cutoff_time]
if not recent_ticks:
return
# Build OHLCV from ticks
prices = [tick['price'] for tick in recent_ticks]
volumes = [tick['volume'] for tick in recent_ticks]
ohlcv_data = {
'open': prices[0],
'high': max(prices),
'low': min(prices),
'close': prices[-1],
'volume': sum(volumes)
}
# Update current candle
with self.candle_locks[symbol]:
self.current_candles[symbol]['1s'] = ohlcv_data
# Create OHLCV bar and update cache
ohlcv_bar = OHLCVBar(
symbol=symbol,
timestamp=recent_ticks[-1]['timestamp'],
open=ohlcv_data['open'],
high=ohlcv_data['high'],
low=ohlcv_data['low'],
close=ohlcv_data['close'],
volume=ohlcv_data['volume'],
timeframe='1s'
)
self.cache.update('ohlcv_1s', symbol, ohlcv_bar, 'tick_constructed')
logger.debug(f"Built 1s candle for {symbol} from {len(recent_ticks)} ticks")
except Exception as e:
logger.error(f"Error building 1s candle from ticks for {symbol}: {e}")
def _update_current_candle_from_bar(self, symbol: str, timeframe: str, bar_data):
"""Update current candle data from a bar"""
try:
with self.candle_locks[symbol]:
self.current_candles[symbol][timeframe] = {
'open': float(bar_data['open']),
'high': float(bar_data['high']),
'low': float(bar_data['low']),
'close': float(bar_data['close']),
'volume': float(bar_data['volume'])
}
except Exception as e:
logger.error(f"Error updating current candle for {symbol} {timeframe}: {e}")
def _df_row_to_ohlcv_bar(self, symbol: str, timeframe: str, row, timestamp) -> OHLCVBar:
"""Convert DataFrame row to OHLCVBar"""
return OHLCVBar(
symbol=symbol,
timestamp=timestamp if hasattr(timestamp, 'to_pydatetime') else datetime.now(),
open=float(row['open']),
high=float(row['high']),
low=float(row['low']),
close=float(row['close']),
volume=float(row['volume']),
timeframe=timeframe
)
def _calculate_technical_indicators(self):
"""Calculate technical indicators for all symbols"""
try:
for symbol in self.symbols:
# Use 1m historical data for indicators
df = self.cache.get_historical_data(symbol, '1m')
if df is None or len(df) < 20:
continue
indicators = {}
try:
import ta
# RSI
if len(df) >= 14:
indicators['rsi'] = ta.momentum.RSIIndicator(df['close']).rsi().iloc[-1]
# Moving averages
if len(df) >= 20:
indicators['sma_20'] = df['close'].rolling(20).mean().iloc[-1]
if len(df) >= 12:
indicators['ema_12'] = df['close'].ewm(span=12).mean().iloc[-1]
if len(df) >= 26:
indicators['ema_26'] = df['close'].ewm(span=26).mean().iloc[-1]
if 'ema_12' in indicators:
indicators['macd'] = indicators['ema_12'] - indicators['ema_26']
# Bollinger Bands
if len(df) >= 20:
bb_period = 20
bb_std = 2
sma = df['close'].rolling(bb_period).mean()
std = df['close'].rolling(bb_period).std()
indicators['bb_upper'] = (sma + (std * bb_std)).iloc[-1]
indicators['bb_lower'] = (sma - (std * bb_std)).iloc[-1]
indicators['bb_middle'] = sma.iloc[-1]
# Remove NaN values
indicators = {k: float(v) for k, v in indicators.items() if not pd.isna(v)}
if indicators:
self.cache.update('technical_indicators', symbol, indicators, 'calculated')
logger.debug(f"Calculated {len(indicators)} indicators for {symbol}")
except Exception as e:
logger.error(f"Error calculating indicators for {symbol}: {e}")
except Exception as e:
logger.error(f"Error in technical indicators calculation: {e}")
def get_current_price(self, symbol: str) -> Optional[float]:
"""Get current price from latest 1s candle"""
ohlcv_1s = self.cache.get('ohlcv_1s', symbol)
if ohlcv_1s:
return ohlcv_1s.close
return None
def get_status(self) -> Dict[str, Any]:
"""Get updater status"""
status = {
'running': self.running,
'symbols': self.symbols,
'last_updates': self.last_updates,
'tick_buffer_sizes': {symbol: len(buffer) for symbol, buffer in self.tick_buffers.items()},
'cache_status': self.cache.get_status()
}
return status

View File

@ -1,221 +0,0 @@
#!/usr/bin/env python3
"""
Test COB Data Integration
This script tests that COB data is properly flowing through to BaseDataInput
and being used in the CNN model predictions.
"""
import sys
import os
import time
import logging
from datetime import datetime
# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from core.orchestrator import TradingOrchestrator
from core.config import get_config
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_cob_data_flow():
"""Test that COB data flows through to BaseDataInput"""
logger.info("=== Testing COB Data Integration ===")
try:
# Initialize orchestrator
config = get_config()
orchestrator = TradingOrchestrator(
symbol="ETH/USDT",
config=config
)
logger.info("✅ Orchestrator initialized")
# Check if COB integration is available
if orchestrator.cob_integration:
logger.info("✅ COB integration is available")
else:
logger.warning("⚠️ COB integration is not available")
# Wait a bit for COB data to potentially arrive
logger.info("Waiting for COB data...")
time.sleep(5)
# Test building BaseDataInput
symbol = "ETH/USDT"
base_data = orchestrator.build_base_data_input(symbol)
if base_data:
logger.info("✅ BaseDataInput created successfully")
# Check if COB data is present
if base_data.cob_data:
logger.info("✅ COB data is present in BaseDataInput")
logger.info(f" COB current price: {base_data.cob_data.current_price}")
logger.info(f" COB bucket size: {base_data.cob_data.bucket_size}")
logger.info(f" COB price buckets: {len(base_data.cob_data.price_buckets)} buckets")
logger.info(f" COB bid/ask imbalance: {len(base_data.cob_data.bid_ask_imbalance)} entries")
# Test feature vector generation
features = base_data.get_feature_vector()
logger.info(f"✅ Feature vector generated: {len(features)} features")
# Check if COB features are non-zero (indicating real data)
# COB features are at positions 7500-7700 (after OHLCV and BTC data)
cob_features = features[7500:7700] # 200 COB features
non_zero_cob = sum(1 for f in cob_features if f != 0.0)
if non_zero_cob > 0:
logger.info(f"✅ COB features contain real data: {non_zero_cob}/200 non-zero features")
else:
logger.warning("⚠️ COB features are all zeros (no real COB data)")
else:
logger.warning("⚠️ COB data is None in BaseDataInput")
# Check if there's COB data in the cache
if hasattr(orchestrator, 'data_integration'):
cached_cob = orchestrator.data_integration.cache.get('cob_data', symbol)
if cached_cob:
logger.info("✅ COB data found in cache but not in BaseDataInput")
else:
logger.warning("⚠️ No COB data in cache either")
# Test CNN prediction with the BaseDataInput
if orchestrator.cnn_adapter:
logger.info("Testing CNN prediction with BaseDataInput...")
try:
prediction = orchestrator.cnn_adapter.predict(base_data)
if prediction:
logger.info("✅ CNN prediction successful")
logger.info(f" Action: {prediction.predictions['action']}")
logger.info(f" Confidence: {prediction.confidence:.3f}")
logger.info(f" Pivot price: {prediction.predictions.get('pivot_price', 'N/A')}")
else:
logger.warning("⚠️ CNN prediction returned None")
except Exception as e:
logger.error(f"❌ CNN prediction failed: {e}")
else:
logger.warning("⚠️ CNN adapter not available")
else:
logger.error("❌ Failed to create BaseDataInput")
# Check orchestrator's latest COB data
if hasattr(orchestrator, 'latest_cob_data') and orchestrator.latest_cob_data:
logger.info(f"✅ Orchestrator has COB data for symbols: {list(orchestrator.latest_cob_data.keys())}")
for sym, cob_data in orchestrator.latest_cob_data.items():
logger.info(f" {sym}: {len(cob_data)} COB data fields")
else:
logger.warning("⚠️ No COB data in orchestrator.latest_cob_data")
return base_data is not None and (base_data.cob_data is not None if base_data else False)
except Exception as e:
logger.error(f"❌ Test failed: {e}")
import traceback
traceback.print_exc()
return False
def test_cob_cache_updates():
"""Test that COB data updates are properly cached"""
logger.info("=== Testing COB Cache Updates ===")
try:
# Initialize orchestrator
config = get_config()
orchestrator = TradingOrchestrator(
symbol="ETH/USDT",
config=config
)
# Check initial cache state
symbol = "ETH/USDT"
initial_cob = orchestrator.data_integration.cache.get('cob_data', symbol)
logger.info(f"Initial COB data in cache: {initial_cob is not None}")
# Simulate COB data update
from core.data_models import COBData
mock_cob_data = {
'current_price': 3000.0,
'price_buckets': {
2999.0: {'bid_volume': 100.0, 'ask_volume': 80.0, 'total_volume': 180.0, 'imbalance': 0.11},
3000.0: {'bid_volume': 150.0, 'ask_volume': 120.0, 'total_volume': 270.0, 'imbalance': 0.11},
3001.0: {'bid_volume': 90.0, 'ask_volume': 110.0, 'total_volume': 200.0, 'imbalance': -0.10}
},
'bid_ask_imbalance': {2999.0: 0.11, 3000.0: 0.11, 3001.0: -0.10},
'volume_weighted_prices': {2999.0: 2999.5, 3000.0: 3000.2, 3001.0: 3000.8},
'order_flow_metrics': {'total_volume': 650.0, 'avg_imbalance': 0.04},
'ma_1s_imbalance': {3000.0: 0.05},
'ma_5s_imbalance': {3000.0: 0.03}
}
# Trigger COB data update through callback
logger.info("Simulating COB data update...")
orchestrator._on_cob_dashboard_data(symbol, mock_cob_data)
# Check if cache was updated
updated_cob = orchestrator.data_integration.cache.get('cob_data', symbol)
if updated_cob:
logger.info("✅ COB data successfully updated in cache")
logger.info(f" Current price: {updated_cob.current_price}")
logger.info(f" Price buckets: {len(updated_cob.price_buckets)}")
else:
logger.warning("⚠️ COB data not found in cache after update")
# Test BaseDataInput with updated COB data
base_data = orchestrator.build_base_data_input(symbol)
if base_data and base_data.cob_data:
logger.info("✅ BaseDataInput now contains COB data")
# Test feature vector with real COB data
features = base_data.get_feature_vector()
cob_features = features[7500:7700] # 200 COB features
non_zero_cob = sum(1 for f in cob_features if f != 0.0)
logger.info(f"✅ COB features with real data: {non_zero_cob}/200 non-zero")
else:
logger.warning("⚠️ BaseDataInput still doesn't have COB data")
return updated_cob is not None
except Exception as e:
logger.error(f"❌ Cache update test failed: {e}")
return False
def main():
"""Run all COB integration tests"""
logger.info("Starting COB Data Integration Tests")
# Test 1: COB data flow
test1_passed = test_cob_data_flow()
# Test 2: COB cache updates
test2_passed = test_cob_cache_updates()
# Summary
logger.info("=== Test Summary ===")
logger.info(f"COB Data Flow: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
logger.info(f"COB Cache Updates: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
if test1_passed and test2_passed:
logger.info("🎉 All tests passed! COB data integration is working.")
logger.info("The system now:")
logger.info(" - Properly integrates COB data into BaseDataInput")
logger.info(" - Updates cache when COB data arrives")
logger.info(" - Includes COB features in CNN model input")
else:
logger.error("❌ Some tests failed. COB integration needs attention.")
if not test1_passed:
logger.error(" - COB data is not flowing to BaseDataInput")
if not test2_passed:
logger.error(" - COB cache updates are not working")
if __name__ == "__main__":
main()

View File

@ -1,527 +0,0 @@
#!/usr/bin/env python3
"""
Complete Training System Integration Test
This script demonstrates the full training system integration including:
- Comprehensive training data collection with validation
- CNN training pipeline with profitable episode replay
- RL training pipeline with profit-weighted experience replay
- Integration with existing DataProvider and models
- Real-time outcome validation and profitability tracking
"""
import asyncio
import logging
import numpy as np
import pandas as pd
import time
from datetime import datetime, timedelta
from pathlib import Path
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Import the complete training system
from core.training_data_collector import TrainingDataCollector
from core.cnn_training_pipeline import CNNPivotPredictor, CNNTrainer
from core.rl_training_pipeline import RLTradingAgent, RLTrainer
from core.enhanced_training_integration import EnhancedTrainingIntegration, EnhancedTrainingConfig
from core.data_provider import DataProvider
def create_mock_data_provider():
"""Create a mock data provider for testing"""
class MockDataProvider:
def __init__(self):
self.symbols = ['ETH/USDT', 'BTC/USDT']
self.timeframes = ['1s', '1m', '5m', '15m', '1h', '1d']
def get_historical_data(self, symbol, timeframe, limit=300, refresh=False):
"""Generate mock OHLCV data"""
dates = pd.date_range(start='2024-01-01', periods=limit, freq='1min')
# Generate realistic price data
base_price = 3000.0 if 'ETH' in symbol else 50000.0
price_data = []
current_price = base_price
for i in range(limit):
change = np.random.normal(0, 0.002)
current_price *= (1 + change)
price_data.append({
'timestamp': dates[i],
'open': current_price,
'high': current_price * (1 + abs(np.random.normal(0, 0.001))),
'low': current_price * (1 - abs(np.random.normal(0, 0.001))),
'close': current_price * (1 + np.random.normal(0, 0.0005)),
'volume': np.random.uniform(100, 1000),
'rsi_14': np.random.uniform(30, 70),
'macd': np.random.normal(0, 0.5),
'sma_20': current_price * (1 + np.random.normal(0, 0.01))
})
current_price = price_data[-1]['close']
df = pd.DataFrame(price_data)
df.set_index('timestamp', inplace=True)
return df
return MockDataProvider()
def test_training_data_collection():
"""Test the comprehensive training data collection system"""
logger.info("=== Testing Training Data Collection ===")
collector = TrainingDataCollector(
storage_dir="test_complete_training/data_collection",
max_episodes_per_symbol=1000
)
collector.start_collection()
# Simulate data collection for multiple episodes
for i in range(20):
symbol = 'ETHUSDT'
# Create sample data
ohlcv_data = {}
for timeframe in ['1s', '1m', '5m', '15m', '1h']:
dates = pd.date_range(start='2024-01-01', periods=300, freq='1min')
base_price = 3000.0 + i * 10 # Vary price over episodes
price_data = []
current_price = base_price
for j in range(300):
change = np.random.normal(0, 0.002)
current_price *= (1 + change)
price_data.append({
'timestamp': dates[j],
'open': current_price,
'high': current_price * (1 + abs(np.random.normal(0, 0.001))),
'low': current_price * (1 - abs(np.random.normal(0, 0.001))),
'close': current_price * (1 + np.random.normal(0, 0.0005)),
'volume': np.random.uniform(100, 1000)
})
current_price = price_data[-1]['close']
df = pd.DataFrame(price_data)
df.set_index('timestamp', inplace=True)
ohlcv_data[timeframe] = df
# Create other data
tick_data = [
{
'timestamp': datetime.now() - timedelta(seconds=j),
'price': base_price + np.random.normal(0, 5),
'volume': np.random.uniform(0.1, 10.0),
'side': 'buy' if np.random.random() > 0.5 else 'sell',
'trade_id': f'trade_{i}_{j}'
}
for j in range(100)
]
cob_data = {
'timestamp': datetime.now(),
'cob_features': np.random.randn(120).tolist(),
'spread': np.random.uniform(0.5, 2.0)
}
technical_indicators = {
'rsi_14': np.random.uniform(30, 70),
'macd': np.random.normal(0, 0.5),
'sma_20': base_price * (1 + np.random.normal(0, 0.01)),
'ema_12': base_price * (1 + np.random.normal(0, 0.01))
}
pivot_points = [
{
'timestamp': datetime.now() - timedelta(minutes=30),
'price': base_price + np.random.normal(0, 20),
'type': 'high' if np.random.random() > 0.5 else 'low'
}
]
# Create features
cnn_features = np.random.randn(2000).astype(np.float32)
rl_state = np.random.randn(2000).astype(np.float32)
orchestrator_context = {
'market_session': 'european',
'volatility_regime': 'medium',
'trend_direction': 'uptrend'
}
# Collect training data
episode_id = collector.collect_training_data(
symbol=symbol,
ohlcv_data=ohlcv_data,
tick_data=tick_data,
cob_data=cob_data,
technical_indicators=technical_indicators,
pivot_points=pivot_points,
cnn_features=cnn_features,
rl_state=rl_state,
orchestrator_context=orchestrator_context
)
logger.info(f"Created episode {i+1}: {episode_id}")
time.sleep(0.1)
# Get statistics
stats = collector.get_collection_statistics()
logger.info(f"Collection statistics: {stats}")
# Validate data integrity
validation = collector.validate_data_integrity()
logger.info(f"Data integrity: {validation}")
collector.stop_collection()
return collector
def test_cnn_training_pipeline():
"""Test the CNN training pipeline with profitable episode replay"""
logger.info("=== Testing CNN Training Pipeline ===")
# Initialize CNN model and trainer
model = CNNPivotPredictor(
input_channels=10,
sequence_length=300,
hidden_dim=256,
num_pivot_classes=3
)
trainer = CNNTrainer(
model=model,
device='cpu',
learning_rate=0.001,
storage_dir="test_complete_training/cnn_training"
)
# Create sample training episodes with outcomes
from core.training_data_collector import TrainingEpisode, ModelInputPackage, TrainingOutcome
episodes = []
for i in range(100):
# Create input package
input_package = ModelInputPackage(
timestamp=datetime.now() - timedelta(minutes=i),
symbol='ETHUSDT',
ohlcv_data={}, # Simplified for testing
tick_data=[],
cob_data={},
technical_indicators={'rsi': 50.0 + i},
pivot_points=[],
cnn_features=np.random.randn(2000).astype(np.float32),
rl_state=np.random.randn(2000).astype(np.float32),
orchestrator_context={}
)
# Create outcome with varying profitability
is_profitable = np.random.random() > 0.3 # 70% profitable
profitability_score = np.random.uniform(0.7, 1.0) if is_profitable else np.random.uniform(0.0, 0.3)
outcome = TrainingOutcome(
input_package_hash=input_package.data_hash,
timestamp=input_package.timestamp,
symbol='ETHUSDT',
price_change_1m=np.random.normal(0, 0.01),
price_change_5m=np.random.normal(0, 0.02),
price_change_15m=np.random.normal(0, 0.03),
price_change_1h=np.random.normal(0, 0.05),
max_profit_potential=abs(np.random.normal(0, 0.02)),
max_loss_potential=abs(np.random.normal(0, 0.015)),
optimal_entry_price=3000.0,
optimal_exit_price=3000.0 + np.random.normal(0, 10),
optimal_holding_time=timedelta(minutes=np.random.randint(5, 60)),
is_profitable=is_profitable,
profitability_score=profitability_score,
risk_reward_ratio=np.random.uniform(1.0, 3.0),
is_rapid_change=np.random.random() > 0.8,
change_velocity=np.random.uniform(0.1, 2.0),
volatility_spike=np.random.random() > 0.9,
outcome_validated=True
)
# Create episode
episode = TrainingEpisode(
episode_id=f"cnn_test_episode_{i}",
input_package=input_package,
model_predictions={},
actual_outcome=outcome,
episode_type='high_profit' if profitability_score > 0.8 else 'normal'
)
episodes.append(episode)
# Test training on all episodes
logger.info("Training on all episodes...")
results = trainer._train_on_episodes(episodes, training_mode='test_batch')
logger.info(f"Training results: {results}")
# Test training on profitable episodes only
logger.info("Training on profitable episodes only...")
profitable_results = trainer.train_on_profitable_episodes(
symbol='ETHUSDT',
min_profitability=0.7,
max_episodes=50
)
logger.info(f"Profitable training results: {profitable_results}")
# Get training statistics
stats = trainer.get_training_statistics()
logger.info(f"CNN training statistics: {stats}")
return trainer
def test_rl_training_pipeline():
"""Test the RL training pipeline with profit-weighted experience replay"""
logger.info("=== Testing RL Training Pipeline ===")
# Initialize RL agent and trainer
agent = RLTradingAgent(state_dim=2000, action_dim=3, hidden_dim=512)
trainer = RLTrainer(
agent=agent,
device='cpu',
storage_dir="test_complete_training/rl_training"
)
# Add sample experiences with varying profitability
logger.info("Adding sample experiences...")
experience_ids = []
for i in range(200):
state = np.random.randn(2000).astype(np.float32)
action = np.random.randint(0, 3) # SELL, HOLD, BUY
reward = np.random.normal(0, 0.1)
next_state = np.random.randn(2000).astype(np.float32)
done = np.random.random() > 0.9
market_context = {
'symbol': 'ETHUSDT',
'episode_id': f'rl_episode_{i}',
'timestamp': datetime.now() - timedelta(minutes=i),
'market_session': 'european',
'volatility_regime': 'medium'
}
cnn_predictions = {
'pivot_logits': np.random.randn(3).tolist(),
'confidence': np.random.uniform(0.3, 0.9)
}
experience_id = trainer.add_experience(
state=state,
action=action,
reward=reward,
next_state=next_state,
done=done,
market_context=market_context,
cnn_predictions=cnn_predictions,
confidence_score=np.random.uniform(0.3, 0.9)
)
if experience_id:
experience_ids.append(experience_id)
# Simulate outcome validation for some experiences
if np.random.random() > 0.5: # 50% get outcomes
actual_profit = np.random.normal(0, 0.02)
optimal_action = np.random.randint(0, 3)
trainer.experience_buffer.update_experience_outcomes(
experience_id, actual_profit, optimal_action
)
logger.info(f"Added {len(experience_ids)} experiences")
# Test training on experiences
logger.info("Training on experiences...")
results = trainer.train_on_experiences(batch_size=32, num_batches=20)
logger.info(f"RL training results: {results}")
# Test training on profitable experiences only
logger.info("Training on profitable experiences only...")
profitable_results = trainer.train_on_profitable_experiences(
min_profitability=0.01,
max_experiences=100,
batch_size=32
)
logger.info(f"Profitable RL training results: {profitable_results}")
# Get training statistics
stats = trainer.get_training_statistics()
logger.info(f"RL training statistics: {stats}")
# Get buffer statistics
buffer_stats = trainer.experience_buffer.get_buffer_statistics()
logger.info(f"Experience buffer statistics: {buffer_stats}")
return trainer
def test_enhanced_integration():
"""Test the complete enhanced training integration"""
logger.info("=== Testing Enhanced Training Integration ===")
# Create mock data provider
data_provider = create_mock_data_provider()
# Create enhanced training configuration
config = EnhancedTrainingConfig(
collection_interval=0.5, # Faster for testing
min_data_completeness=0.7,
min_episodes_for_cnn_training=10, # Lower for testing
min_experiences_for_rl_training=20, # Lower for testing
training_frequency_minutes=1, # Faster for testing
min_profitability_for_replay=0.05,
use_existing_cob_rl_model=False, # Don't use for testing
enable_cross_model_learning=True,
enable_background_validation=True
)
# Initialize enhanced integration
integration = EnhancedTrainingIntegration(
data_provider=data_provider,
config=config
)
# Start integration
logger.info("Starting enhanced training integration...")
integration.start_enhanced_integration()
# Let it run for a short time
logger.info("Running integration for 30 seconds...")
time.sleep(30)
# Get statistics
stats = integration.get_integration_statistics()
logger.info(f"Integration statistics: {stats}")
# Test manual training trigger
logger.info("Testing manual training trigger...")
manual_results = integration.trigger_manual_training(training_type='all')
logger.info(f"Manual training results: {manual_results}")
# Stop integration
logger.info("Stopping enhanced training integration...")
integration.stop_enhanced_integration()
return integration
def test_complete_system():
"""Test the complete training system integration"""
logger.info("=== Testing Complete Training System ===")
try:
# Test individual components
logger.info("Testing individual components...")
collector = test_training_data_collection()
cnn_trainer = test_cnn_training_pipeline()
rl_trainer = test_rl_training_pipeline()
logger.info("✅ Individual components tested successfully!")
# Test complete integration
logger.info("Testing complete integration...")
integration = test_enhanced_integration()
logger.info("✅ Complete integration tested successfully!")
# Generate comprehensive report
logger.info("\n" + "="*80)
logger.info("COMPREHENSIVE TRAINING SYSTEM TEST REPORT")
logger.info("="*80)
# Data collection report
collection_stats = collector.get_collection_statistics()
logger.info(f"\n📊 DATA COLLECTION:")
logger.info(f" • Total episodes: {collection_stats.get('total_episodes', 0)}")
logger.info(f" • Profitable episodes: {collection_stats.get('profitable_episodes', 0)}")
logger.info(f" • Rapid change episodes: {collection_stats.get('rapid_change_episodes', 0)}")
logger.info(f" • Data completeness avg: {collection_stats.get('data_completeness_avg', 0):.3f}")
# CNN training report
cnn_stats = cnn_trainer.get_training_statistics()
logger.info(f"\n🧠 CNN TRAINING:")
logger.info(f" • Total sessions: {cnn_stats.get('total_sessions', 0)}")
logger.info(f" • Total steps: {cnn_stats.get('total_steps', 0)}")
logger.info(f" • Replay sessions: {cnn_stats.get('replay_sessions', 0)}")
# RL training report
rl_stats = rl_trainer.get_training_statistics()
logger.info(f"\n🤖 RL TRAINING:")
logger.info(f" • Total sessions: {rl_stats.get('total_sessions', 0)}")
logger.info(f" • Total experiences: {rl_stats.get('total_experiences', 0)}")
logger.info(f" • Average reward: {rl_stats.get('average_reward', 0):.4f}")
# Integration report
integration_stats = integration.get_integration_statistics()
logger.info(f"\n🔗 INTEGRATION:")
logger.info(f" • Total data packages: {integration_stats.get('total_data_packages', 0)}")
logger.info(f" • CNN training sessions: {integration_stats.get('cnn_training_sessions', 0)}")
logger.info(f" • RL training sessions: {integration_stats.get('rl_training_sessions', 0)}")
logger.info(f" • Overall profitability rate: {integration_stats.get('overall_profitability_rate', 0):.3f}")
logger.info("\n🎯 SYSTEM CAPABILITIES DEMONSTRATED:")
logger.info(" ✓ Comprehensive training data collection with validation")
logger.info(" ✓ CNN training with profitable episode replay")
logger.info(" ✓ RL training with profit-weighted experience replay")
logger.info(" ✓ Real-time outcome validation and profitability tracking")
logger.info(" ✓ Integrated training coordination across all models")
logger.info(" ✓ Gradient and backpropagation data storage for replay")
logger.info(" ✓ Rapid price change detection for premium training examples")
logger.info(" ✓ Data integrity validation and completeness checking")
logger.info("\n🚀 READY FOR PRODUCTION INTEGRATION:")
logger.info(" 1. Connect to your existing DataProvider")
logger.info(" 2. Integrate with your CNN and RL models")
logger.info(" 3. Connect to your Orchestrator and TradingExecutor")
logger.info(" 4. Enable real-time outcome validation")
logger.info(" 5. Deploy with monitoring and alerting")
return True
except Exception as e:
logger.error(f"❌ Complete system test failed: {e}")
import traceback
logger.error(traceback.format_exc())
return False
def main():
"""Main test function"""
logger.info("=" * 100)
logger.info("COMPREHENSIVE TRAINING SYSTEM INTEGRATION TEST")
logger.info("=" * 100)
start_time = time.time()
try:
# Run complete system test
success = test_complete_system()
end_time = time.time()
duration = end_time - start_time
logger.info("=" * 100)
if success:
logger.info("🎉 ALL TESTS PASSED! TRAINING SYSTEM READY FOR PRODUCTION!")
else:
logger.info("❌ SOME TESTS FAILED - CHECK LOGS FOR DETAILS")
logger.info(f"Total test duration: {duration:.2f} seconds")
logger.info("=" * 100)
except Exception as e:
logger.error(f"❌ Test execution failed: {e}")
import traceback
logger.error(traceback.format_exc())
if __name__ == "__main__":
main()

View File

@ -1,187 +0,0 @@
#!/usr/bin/env python3
"""
Test Fixed Input Size
Verify that the CNN model now receives consistent input dimensions
"""
import numpy as np
from datetime import datetime
from core.data_models import BaseDataInput, OHLCVBar
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
def create_test_data(with_cob=True, with_indicators=True):
"""Create test BaseDataInput with varying data completeness"""
# Create basic OHLCV data
ohlcv_bars = []
for i in range(100): # Less than 300 to test padding
bar = OHLCVBar(
symbol="ETH/USDT",
timestamp=datetime.now(),
open=100.0 + i,
high=101.0 + i,
low=99.0 + i,
close=100.5 + i,
volume=1000 + i,
timeframe="1s"
)
ohlcv_bars.append(bar)
# Create test data
base_data = BaseDataInput(
symbol="ETH/USDT",
timestamp=datetime.now(),
ohlcv_1s=ohlcv_bars,
ohlcv_1m=ohlcv_bars[:50], # Even less data
ohlcv_1h=ohlcv_bars[:20],
ohlcv_1d=ohlcv_bars[:10],
btc_ohlcv_1s=ohlcv_bars[:80], # Incomplete BTC data
technical_indicators={'rsi': 50.0, 'macd': 0.1} if with_indicators else {},
last_predictions={}
)
# Add COB data if requested (simplified for testing)
if with_cob:
# Create a simple mock COB data object
class MockCOBData:
def __init__(self):
self.price_buckets = {
2500.0: {'bid_volume': 100, 'ask_volume': 90, 'total_volume': 190, 'imbalance': 0.05},
2501.0: {'bid_volume': 80, 'ask_volume': 120, 'total_volume': 200, 'imbalance': -0.2}
}
self.ma_1s_imbalance = {2500.0: 0.1, 2501.0: -0.1}
self.ma_5s_imbalance = {2500.0: 0.05, 2501.0: -0.05}
base_data.cob_data = MockCOBData()
return base_data
def test_consistent_feature_size():
"""Test that feature vectors are always the same size"""
print("=== Testing Consistent Feature Size ===")
# Test different data scenarios
scenarios = [
("Full data", True, True),
("No COB data", False, True),
("No indicators", True, False),
("Minimal data", False, False)
]
feature_sizes = []
for name, with_cob, with_indicators in scenarios:
base_data = create_test_data(with_cob, with_indicators)
features = base_data.get_feature_vector()
print(f"{name}: {len(features)} features")
feature_sizes.append(len(features))
# Check if all sizes are the same
if len(set(feature_sizes)) == 1:
print(f"✅ All feature vectors have consistent size: {feature_sizes[0]}")
return feature_sizes[0]
else:
print(f"❌ Inconsistent feature sizes: {feature_sizes}")
return None
def test_cnn_adapter():
"""Test that CNN adapter works with fixed input size"""
print("\n=== Testing CNN Adapter ===")
try:
# Create CNN adapter
adapter = EnhancedCNNAdapter()
print(f"CNN model initialized with feature_dim: {adapter.model.feature_dim}")
# Test with different data scenarios
scenarios = [
("Full data", True, True),
("No COB data", False, True),
("Minimal data", False, False)
]
for name, with_cob, with_indicators in scenarios:
try:
base_data = create_test_data(with_cob, with_indicators)
# Make prediction
result = adapter.predict(base_data)
print(f"{name}: Prediction successful - {result.action} (conf={result.confidence:.3f})")
except Exception as e:
print(f"{name}: Prediction failed - {e}")
return True
except Exception as e:
print(f"❌ CNN adapter initialization failed: {e}")
return False
def test_no_network_rebuilding():
"""Test that network doesn't rebuild during runtime"""
print("\n=== Testing No Network Rebuilding ===")
try:
adapter = EnhancedCNNAdapter()
original_feature_dim = adapter.model.feature_dim
print(f"Original feature_dim: {original_feature_dim}")
# Make multiple predictions with different data
for i in range(5):
base_data = create_test_data(with_cob=(i % 2 == 0), with_indicators=(i % 3 == 0))
try:
result = adapter.predict(base_data)
current_feature_dim = adapter.model.feature_dim
if current_feature_dim != original_feature_dim:
print(f"❌ Network was rebuilt! Original: {original_feature_dim}, Current: {current_feature_dim}")
return False
print(f"✅ Prediction {i+1}: No rebuilding, feature_dim stable at {current_feature_dim}")
except Exception as e:
print(f"❌ Prediction {i+1} failed: {e}")
return False
print("✅ Network architecture remained stable throughout all predictions")
return True
except Exception as e:
print(f"❌ Test failed: {e}")
return False
def main():
"""Run all tests"""
print("=== Fixed Input Size Test Suite ===\n")
# Test 1: Consistent feature size
fixed_size = test_consistent_feature_size()
if fixed_size:
# Test 2: CNN adapter works
adapter_works = test_cnn_adapter()
if adapter_works:
# Test 3: No network rebuilding
no_rebuilding = test_no_network_rebuilding()
if no_rebuilding:
print("\n✅ ALL TESTS PASSED!")
print("✅ Feature vectors have consistent size")
print("✅ CNN adapter works with fixed input")
print("✅ No runtime network rebuilding")
print(f"✅ Fixed feature size: {fixed_size}")
else:
print("\n❌ Network rebuilding test failed")
else:
print("\n❌ CNN adapter test failed")
else:
print("\n❌ Feature size consistency test failed")
if __name__ == "__main__":
main()

View File

@ -57,8 +57,7 @@ def test_integrated_standardized_provider():
# Test 3: Test BaseDataInput with cross-model feeding
print("\n3. Testing BaseDataInput with cross-model predictions...")
# Set mock current price for COB data
provider.current_prices['ETHUSDT'] = 3000.0
# Use real current prices only - no mock data
base_input = provider.get_base_data_input('ETH/USDT')

View File

@ -1,261 +0,0 @@
"""
Test script for StandardizedCNN
This script tests the standardized CNN model with BaseDataInput format
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import logging
import torch
from datetime import datetime
from core.standardized_data_provider import StandardizedDataProvider
from NN.models.standardized_cnn import StandardizedCNN
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_standardized_cnn():
"""Test the StandardizedCNN with BaseDataInput"""
print("Testing StandardizedCNN with BaseDataInput...")
# Initialize data provider
symbols = ['ETH/USDT', 'BTC/USDT']
provider = StandardizedDataProvider(symbols=symbols)
# Initialize CNN model
cnn_model = StandardizedCNN(
model_name="test_standardized_cnn_v1",
confidence_threshold=0.6
)
print("✅ StandardizedCNN initialized")
print(f" Model info: {cnn_model.get_model_info()}")
# Test 1: Get BaseDataInput
print("\n1. Testing BaseDataInput creation...")
# Set mock current price for COB data
provider.current_prices['ETHUSDT'] = 3000.0
provider.current_prices['BTCUSDT'] = 50000.0
base_input = provider.get_base_data_input('ETH/USDT')
if base_input is None:
print("⚠️ BaseDataInput is None - creating mock data for testing")
# Create mock BaseDataInput for testing
from core.data_models import BaseDataInput, OHLCVBar, COBData
# Create mock OHLCV data
mock_ohlcv = []
for i in range(300):
bar = OHLCVBar(
symbol='ETH/USDT',
timestamp=datetime.now(),
open=3000.0 + i,
high=3010.0 + i,
low=2990.0 + i,
close=3005.0 + i,
volume=1000.0,
timeframe='1s'
)
mock_ohlcv.append(bar)
# Create mock COB data
mock_cob = COBData(
symbol='ETH/USDT',
timestamp=datetime.now(),
current_price=3000.0,
bucket_size=1.0,
price_buckets={3000.0 + i: {'bid_volume': 100, 'ask_volume': 100, 'total_volume': 200, 'imbalance': 0.0} for i in range(-20, 21)},
bid_ask_imbalance={3000.0 + i: 0.0 for i in range(-20, 21)},
volume_weighted_prices={3000.0 + i: 3000.0 + i for i in range(-20, 21)},
order_flow_metrics={}
)
base_input = BaseDataInput(
symbol='ETH/USDT',
timestamp=datetime.now(),
ohlcv_1s=mock_ohlcv,
ohlcv_1m=mock_ohlcv,
ohlcv_1h=mock_ohlcv,
ohlcv_1d=mock_ohlcv,
btc_ohlcv_1s=mock_ohlcv,
cob_data=mock_cob
)
print(f"✅ BaseDataInput available: {base_input.symbol}")
print(f" Feature vector shape: {base_input.get_feature_vector().shape}")
print(f" Validation: {'PASSED' if base_input.validate() else 'FAILED'}")
# Test 2: CNN Inference
print("\n2. Testing CNN inference with BaseDataInput...")
try:
model_output = cnn_model.predict_from_base_input(base_input)
print("✅ CNN inference successful!")
print(f" Model: {model_output.model_name} ({model_output.model_type})")
print(f" Action: {model_output.predictions['action']}")
print(f" Confidence: {model_output.confidence:.3f}")
print(f" Probabilities: BUY={model_output.predictions['buy_probability']:.3f}, "
f"SELL={model_output.predictions['sell_probability']:.3f}, "
f"HOLD={model_output.predictions['hold_probability']:.3f}")
print(f" Hidden states: {len(model_output.hidden_states)} layers")
print(f" Metadata: {len(model_output.metadata)} fields")
# Test hidden states for cross-model feeding
if model_output.hidden_states:
print(" Hidden state layers:")
for key, value in model_output.hidden_states.items():
if isinstance(value, list):
print(f" {key}: {len(value)} features")
else:
print(f" {key}: {type(value)}")
except Exception as e:
print(f"❌ CNN inference failed: {e}")
import traceback
traceback.print_exc()
# Test 3: Integration with StandardizedDataProvider
print("\n3. Testing integration with StandardizedDataProvider...")
try:
# Store the model output in the provider
provider.store_model_output(model_output)
# Retrieve it back
stored_outputs = provider.get_model_outputs('ETH/USDT')
if cnn_model.model_name in stored_outputs:
print("✅ Model output storage and retrieval successful!")
stored_output = stored_outputs[cnn_model.model_name]
print(f" Stored action: {stored_output.predictions['action']}")
print(f" Stored confidence: {stored_output.confidence:.3f}")
else:
print("❌ Model output storage failed")
# Test cross-model feeding
updated_base_input = provider.get_base_data_input('ETH/USDT')
if updated_base_input and cnn_model.model_name in updated_base_input.last_predictions:
print("✅ Cross-model feeding working!")
print(f" CNN prediction available in BaseDataInput for other models")
else:
print("⚠️ Cross-model feeding not working as expected")
except Exception as e:
print(f"❌ Integration test failed: {e}")
# Test 4: Training capabilities
print("\n4. Testing training capabilities...")
try:
# Create mock training data
training_inputs = [base_input] * 5 # Small batch
training_targets = ['BUY', 'SELL', 'HOLD', 'BUY', 'HOLD']
# Create optimizer
optimizer = torch.optim.Adam(cnn_model.parameters(), lr=0.001)
# Perform training step
loss = cnn_model.train_step(training_inputs, training_targets, optimizer)
print(f"✅ Training step successful!")
print(f" Training loss: {loss:.4f}")
# Test evaluation
eval_metrics = cnn_model.evaluate(training_inputs, training_targets)
print(f" Evaluation metrics: {eval_metrics}")
except Exception as e:
print(f"❌ Training test failed: {e}")
import traceback
traceback.print_exc()
# Test 5: Checkpoint management
print("\n5. Testing checkpoint management...")
try:
# Save checkpoint
checkpoint_path = "test_cache/cnn_checkpoint.pth"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
metadata = {
'training_loss': loss if 'loss' in locals() else 0.5,
'accuracy': eval_metrics.get('accuracy', 0.0) if 'eval_metrics' in locals() else 0.0,
'test_run': True
}
cnn_model.save_checkpoint(checkpoint_path, metadata)
print("✅ Checkpoint saved successfully!")
# Create new model and load checkpoint
new_cnn = StandardizedCNN(model_name="loaded_cnn_v1")
success = new_cnn.load_checkpoint(checkpoint_path)
if success:
print("✅ Checkpoint loaded successfully!")
print(f" Loaded model info: {new_cnn.get_model_info()}")
else:
print("❌ Checkpoint loading failed")
except Exception as e:
print(f"❌ Checkpoint test failed: {e}")
# Test 6: Performance and compatibility
print("\n6. Testing performance and compatibility...")
try:
# Test inference speed
import time
start_time = time.time()
for _ in range(10):
_ = cnn_model.predict_from_base_input(base_input)
end_time = time.time()
avg_inference_time = (end_time - start_time) / 10 * 1000 # ms
print(f"✅ Performance test completed!")
print(f" Average inference time: {avg_inference_time:.2f} ms")
# Test memory usage
if torch.cuda.is_available():
memory_used = torch.cuda.memory_allocated() / 1024 / 1024 # MB
print(f" GPU memory used: {memory_used:.2f} MB")
# Test model size
param_count = sum(p.numel() for p in cnn_model.parameters())
model_size_mb = param_count * 4 / 1024 / 1024 # Assuming float32
print(f" Model parameters: {param_count:,}")
print(f" Estimated model size: {model_size_mb:.2f} MB")
except Exception as e:
print(f"❌ Performance test failed: {e}")
print("\n✅ StandardizedCNN test completed!")
print("\n🎯 Key achievements:")
print("✓ Accepts standardized BaseDataInput format")
print("✓ Processes COB+OHLCV data (300 frames multi-timeframe)")
print("✓ Outputs BUY/SELL/HOLD with confidence scores")
print("✓ Provides hidden states for cross-model feeding")
print("✓ Integrates with ModelOutputManager")
print("✓ Supports training and evaluation")
print("✓ Checkpoint management for persistence")
print("✓ Real-time inference capabilities")
print("\n🚀 Ready for integration:")
print("1. Can be used by orchestrator for decision making")
print("2. Hidden states available for RL model cross-feeding")
print("3. Outputs stored in standardized ModelOutput format")
print("4. Compatible with checkpoint management system")
print("5. Optimized for real-time trading inference")
return cnn_model
if __name__ == "__main__":
test_standardized_cnn()

View File

@ -36,7 +36,7 @@ def test_standardized_data_provider():
print("❌ BaseDataInput is None - this is expected if no historical data is available")
print(" The provider needs real market data to create BaseDataInput")
# Test with mock data
# Test with real data only
print("\n2. Testing data structures...")
# Test ModelOutput creation

View File

@ -1,337 +0,0 @@
#!/usr/bin/env python3
"""
Test Trading System Fixes
This script tests the fixes for the trading system by simulating trades
and verifying that the issues are resolved.
Usage:
python test_trading_fixes.py
"""
import os
import sys
import logging
import time
from pathlib import Path
from datetime import datetime
import json
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler('logs/test_fixes.log')
]
)
logger = logging.getLogger(__name__)
class MockPosition:
"""Mock position for testing"""
def __init__(self, symbol, side, size, entry_price):
self.symbol = symbol
self.side = side
self.size = size
self.entry_price = entry_price
self.fees = 0.0
class MockTradingExecutor:
"""Mock trading executor for testing fixes"""
def __init__(self):
self.positions = {}
self.current_prices = {}
self.simulation_mode = True
def get_current_price(self, symbol):
"""Get current price for a symbol"""
# Simulate price movement
if symbol not in self.current_prices:
self.current_prices[symbol] = 3600.0
else:
# Add some random movement
import random
self.current_prices[symbol] += random.uniform(-10, 10)
return self.current_prices[symbol]
def execute_action(self, decision):
"""Execute a trading action"""
logger.info(f"Executing {decision.action} for {decision.symbol} at ${decision.price:.2f}")
# Simulate execution
if decision.action in ['BUY', 'LONG']:
self.positions[decision.symbol] = MockPosition(
decision.symbol, 'LONG', decision.size, decision.price
)
elif decision.action in ['SELL', 'SHORT']:
self.positions[decision.symbol] = MockPosition(
decision.symbol, 'SHORT', decision.size, decision.price
)
return True
def close_position(self, symbol, price=None):
"""Close a position"""
if symbol not in self.positions:
return False
if price is None:
price = self.get_current_price(symbol)
position = self.positions[symbol]
# Calculate P&L
if position.side == 'LONG':
pnl = (price - position.entry_price) * position.size
else: # SHORT
pnl = (position.entry_price - price) * position.size
logger.info(f"Closing {position.side} position for {symbol} at ${price:.2f}, P&L: ${pnl:.2f}")
# Remove position
del self.positions[symbol]
return True
class MockDecision:
"""Mock trading decision for testing"""
def __init__(self, symbol, action, price=None, size=10.0, confidence=0.8):
self.symbol = symbol
self.action = action
self.price = price
self.size = size
self.confidence = confidence
self.timestamp = datetime.now()
self.executed = False
self.blocked = False
self.blocked_reason = None
def test_price_caching_fix():
"""Test the price caching fix"""
logger.info("Testing price caching fix...")
# Create mock trading executor
executor = MockTradingExecutor()
# Import and apply fixes
try:
from core.trading_executor_fix import TradingExecutorFix
TradingExecutorFix.apply_fixes(executor)
# Test price caching
symbol = 'ETH/USDT'
# Get initial price
price1 = executor.get_current_price(symbol)
logger.info(f"Initial price: ${price1:.2f}")
# Get price again immediately (should be cached)
price2 = executor.get_current_price(symbol)
logger.info(f"Immediate second price: ${price2:.2f}")
# Wait for cache to expire
logger.info("Waiting for cache to expire (6 seconds)...")
time.sleep(6)
# Get price after cache expiry (should be different)
price3 = executor.get_current_price(symbol)
logger.info(f"Price after cache expiry: ${price3:.2f}")
# Check if prices are different
if price1 == price2:
logger.info("✅ Immediate price check uses cache as expected")
else:
logger.warning("❌ Immediate price check did not use cache")
if price1 != price3:
logger.info("✅ Price cache expiry working correctly")
else:
logger.warning("❌ Price cache expiry not working")
return True
except Exception as e:
logger.error(f"Error testing price caching fix: {e}")
import traceback
logger.error(traceback.format_exc())
return False
def test_duplicate_entry_prevention():
"""Test the duplicate entry prevention fix"""
logger.info("Testing duplicate entry prevention...")
# Create mock trading executor
executor = MockTradingExecutor()
# Import and apply fixes
try:
from core.trading_executor_fix import TradingExecutorFix
TradingExecutorFix.apply_fixes(executor)
# Test duplicate entry prevention
symbol = 'ETH/USDT'
# Create first decision
decision1 = MockDecision(symbol, 'SHORT')
decision1.price = executor.get_current_price(symbol)
# Execute first decision
result1 = executor.execute_action(decision1)
logger.info(f"First execution result: {result1}")
# Manually set recent entries to simulate a successful trade
if not hasattr(executor, 'recent_entries'):
executor.recent_entries = {}
executor.recent_entries[symbol] = {
'price': decision1.price,
'timestamp': time.time(),
'action': decision1.action
}
# Create second decision with same action
decision2 = MockDecision(symbol, 'SHORT')
decision2.price = decision1.price # Use same price to trigger duplicate detection
# Execute second decision immediately (should be blocked)
result2 = executor.execute_action(decision2)
logger.info(f"Second execution result: {result2}")
logger.info(f"Second decision blocked: {getattr(decision2, 'blocked', False)}")
logger.info(f"Block reason: {getattr(decision2, 'blocked_reason', 'None')}")
# Check if second decision was blocked by trade cooldown
# This is also acceptable as it prevents duplicate entries
if getattr(decision2, 'blocked', False):
logger.info("✅ Trade prevention working correctly (via cooldown)")
return True
else:
logger.warning("❌ Trade prevention not working correctly")
return False
except Exception as e:
logger.error(f"Error testing duplicate entry prevention: {e}")
import traceback
logger.error(traceback.format_exc())
return False
def test_pnl_calculation_fix():
"""Test the P&L calculation fix"""
logger.info("Testing P&L calculation fix...")
# Create mock trading executor
executor = MockTradingExecutor()
# Import and apply fixes
try:
from core.trading_executor_fix import TradingExecutorFix
TradingExecutorFix.apply_fixes(executor)
# Test P&L calculation
symbol = 'ETH/USDT'
# Create a position
entry_price = 3600.0
size = 10.0
executor.positions[symbol] = MockPosition(symbol, 'SHORT', size, entry_price)
# Set exit price
exit_price = 3550.0
# Calculate P&L using fixed method
pnl_result = executor._calculate_pnl(executor.positions[symbol], exit_price)
# Calculate expected P&L
expected_pnl = (entry_price - exit_price) * size
logger.info(f"Entry price: ${entry_price:.2f}")
logger.info(f"Exit price: ${exit_price:.2f}")
logger.info(f"Size: {size}")
logger.info(f"Calculated P&L: ${pnl_result['gross_pnl']:.2f}")
logger.info(f"Expected P&L: ${expected_pnl:.2f}")
# Check if P&L calculation is correct
if abs(pnl_result['gross_pnl'] - expected_pnl) < 0.01:
logger.info("✅ P&L calculation fix working correctly")
return True
else:
logger.warning("❌ P&L calculation fix not working correctly")
return False
except Exception as e:
logger.error(f"Error testing P&L calculation fix: {e}")
import traceback
logger.error(traceback.format_exc())
return False
def run_all_tests():
"""Run all tests"""
logger.info("=" * 70)
logger.info("TESTING TRADING SYSTEM FIXES")
logger.info("=" * 70)
# Create logs directory if it doesn't exist
os.makedirs('logs', exist_ok=True)
# Run tests
tests = [
("Price Caching Fix", test_price_caching_fix),
("Duplicate Entry Prevention", test_duplicate_entry_prevention),
("P&L Calculation Fix", test_pnl_calculation_fix)
]
results = {}
for test_name, test_func in tests:
logger.info(f"\n{'-'*30}")
logger.info(f"Running test: {test_name}")
logger.info(f"{'-'*30}")
try:
result = test_func()
results[test_name] = result
except Exception as e:
logger.error(f"Test {test_name} failed with error: {e}")
results[test_name] = False
# Print summary
logger.info("\n" + "=" * 70)
logger.info("TEST RESULTS SUMMARY")
logger.info("=" * 70)
all_passed = True
for test_name, result in results.items():
status = "✅ PASSED" if result else "❌ FAILED"
logger.info(f"{test_name}: {status}")
if not result:
all_passed = False
logger.info("=" * 70)
logger.info(f"OVERALL RESULT: {'✅ ALL TESTS PASSED' if all_passed else '❌ SOME TESTS FAILED'}")
logger.info("=" * 70)
# Save results to file
with open('logs/test_results.json', 'w') as f:
json.dump({
'timestamp': datetime.now().isoformat(),
'results': {k: "PASSED" if v else "FAILED" for k, v in results.items()},
'all_passed': all_passed
}, f, indent=2)
return all_passed
if __name__ == "__main__":
success = run_all_tests()
if success:
print("\nAll tests passed!")
sys.exit(0)
else:
print("\nSome tests failed. Check logs for details.")
sys.exit(1)

View File

@ -283,13 +283,16 @@ class TrainingSystemValidator:
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
logger.info(" ✓ RL Agent loaded")
# Test prediction capability
# Test prediction capability with real data
if hasattr(self.orchestrator.rl_agent, 'predict'):
# Create dummy state for testing
dummy_state = np.random.random(1000) # Simplified test state
try:
prediction = self.orchestrator.rl_agent.predict(dummy_state)
logger.info(" ✓ RL Agent can make predictions")
# Use real state from orchestrator instead of dummy data
real_state = self.orchestrator._get_rl_state('ETH/USDT')
if real_state is not None:
prediction = self.orchestrator.rl_agent.predict(real_state)
logger.info(" ✓ RL Agent can make predictions with real data")
else:
logger.warning(" ⚠ No real state available for RL prediction test")
except Exception as e:
logger.warning(f" ⚠ RL Agent prediction failed: {e}")
else: