cleanup and removed dummy data
This commit is contained in:
@ -1,190 +0,0 @@
|
|||||||
"""
|
|
||||||
Simplified Data Cache System
|
|
||||||
|
|
||||||
Replaces complex FIFO queues with a simple current state cache.
|
|
||||||
Supports unordered updates and extensible data types.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Dict, List, Optional, Any, Callable
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from collections import defaultdict
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DataCacheEntry:
|
|
||||||
"""Single cache entry with metadata"""
|
|
||||||
data: Any
|
|
||||||
timestamp: datetime
|
|
||||||
source: str = "unknown"
|
|
||||||
version: int = 1
|
|
||||||
|
|
||||||
class DataCache:
|
|
||||||
"""
|
|
||||||
Simplified data cache that stores only the latest data for each type.
|
|
||||||
Thread-safe and supports unordered updates from multiple sources.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.cache: Dict[str, Dict[str, DataCacheEntry]] = defaultdict(dict) # {data_type: {symbol: entry}}
|
|
||||||
self.locks: Dict[str, threading.RLock] = defaultdict(threading.RLock) # Per data_type locks
|
|
||||||
self.update_callbacks: Dict[str, List[Callable]] = defaultdict(list) # Update notifications
|
|
||||||
|
|
||||||
# Historical data storage (loaded once)
|
|
||||||
self.historical_data: Dict[str, Dict[str, pd.DataFrame]] = defaultdict(dict) # {symbol: {timeframe: df}}
|
|
||||||
self.historical_locks: Dict[str, threading.RLock] = defaultdict(threading.RLock)
|
|
||||||
|
|
||||||
logger.info("DataCache initialized with simplified architecture")
|
|
||||||
|
|
||||||
def update(self, data_type: str, symbol: str, data: Any, source: str = "unknown") -> bool:
|
|
||||||
"""
|
|
||||||
Update cache with latest data (thread-safe, unordered updates supported)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data_type: Type of data ('ohlcv_1s', 'technical_indicators', etc.)
|
|
||||||
symbol: Trading symbol
|
|
||||||
data: New data to store
|
|
||||||
source: Source of the update
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if updated successfully
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with self.locks[data_type]:
|
|
||||||
# Create or update entry
|
|
||||||
old_entry = self.cache[data_type].get(symbol)
|
|
||||||
new_version = (old_entry.version + 1) if old_entry else 1
|
|
||||||
|
|
||||||
self.cache[data_type][symbol] = DataCacheEntry(
|
|
||||||
data=data,
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
source=source,
|
|
||||||
version=new_version
|
|
||||||
)
|
|
||||||
|
|
||||||
# Notify callbacks
|
|
||||||
for callback in self.update_callbacks[data_type]:
|
|
||||||
try:
|
|
||||||
callback(symbol, data, source)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in update callback: {e}")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error updating cache {data_type}/{symbol}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get(self, data_type: str, symbol: str) -> Optional[Any]:
|
|
||||||
"""Get latest data for a type/symbol"""
|
|
||||||
try:
|
|
||||||
with self.locks[data_type]:
|
|
||||||
entry = self.cache[data_type].get(symbol)
|
|
||||||
return entry.data if entry else None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting cache {data_type}/{symbol}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_with_metadata(self, data_type: str, symbol: str) -> Optional[DataCacheEntry]:
|
|
||||||
"""Get latest data with metadata"""
|
|
||||||
try:
|
|
||||||
with self.locks[data_type]:
|
|
||||||
return self.cache[data_type].get(symbol)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting cache metadata {data_type}/{symbol}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_all(self, data_type: str) -> Dict[str, Any]:
|
|
||||||
"""Get all data for a data type"""
|
|
||||||
try:
|
|
||||||
with self.locks[data_type]:
|
|
||||||
return {symbol: entry.data for symbol, entry in self.cache[data_type].items()}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting all cache data for {data_type}: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def has_data(self, data_type: str, symbol: str, max_age_seconds: int = None) -> bool:
|
|
||||||
"""Check if we have recent data"""
|
|
||||||
try:
|
|
||||||
with self.locks[data_type]:
|
|
||||||
entry = self.cache[data_type].get(symbol)
|
|
||||||
if not entry:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if max_age_seconds:
|
|
||||||
age = (datetime.now() - entry.timestamp).total_seconds()
|
|
||||||
return age <= max_age_seconds
|
|
||||||
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error checking cache data {data_type}/{symbol}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def register_callback(self, data_type: str, callback: Callable[[str, Any, str], None]):
|
|
||||||
"""Register callback for data updates"""
|
|
||||||
self.update_callbacks[data_type].append(callback)
|
|
||||||
|
|
||||||
def get_status(self) -> Dict[str, Dict[str, Dict[str, Any]]]:
|
|
||||||
"""Get cache status for monitoring"""
|
|
||||||
status = {}
|
|
||||||
|
|
||||||
for data_type in self.cache:
|
|
||||||
with self.locks[data_type]:
|
|
||||||
status[data_type] = {}
|
|
||||||
for symbol, entry in self.cache[data_type].items():
|
|
||||||
age_seconds = (datetime.now() - entry.timestamp).total_seconds()
|
|
||||||
status[data_type][symbol] = {
|
|
||||||
'timestamp': entry.timestamp.isoformat(),
|
|
||||||
'age_seconds': age_seconds,
|
|
||||||
'source': entry.source,
|
|
||||||
'version': entry.version,
|
|
||||||
'has_data': entry.data is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
return status
|
|
||||||
|
|
||||||
# Historical data management
|
|
||||||
def store_historical_data(self, symbol: str, timeframe: str, df: pd.DataFrame):
|
|
||||||
"""Store historical data (loaded once at startup)"""
|
|
||||||
try:
|
|
||||||
with self.historical_locks[symbol]:
|
|
||||||
self.historical_data[symbol][timeframe] = df.copy()
|
|
||||||
logger.info(f"Stored {len(df)} historical bars for {symbol} {timeframe}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error storing historical data {symbol}/{timeframe}: {e}")
|
|
||||||
|
|
||||||
def get_historical_data(self, symbol: str, timeframe: str) -> Optional[pd.DataFrame]:
|
|
||||||
"""Get historical data"""
|
|
||||||
try:
|
|
||||||
with self.historical_locks[symbol]:
|
|
||||||
return self.historical_data[symbol].get(timeframe)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting historical data {symbol}/{timeframe}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def has_historical_data(self, symbol: str, timeframe: str, min_bars: int = 100) -> bool:
|
|
||||||
"""Check if we have sufficient historical data"""
|
|
||||||
try:
|
|
||||||
with self.historical_locks[symbol]:
|
|
||||||
df = self.historical_data[symbol].get(timeframe)
|
|
||||||
return df is not None and len(df) >= min_bars
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Global cache instance
|
|
||||||
_data_cache_instance = None
|
|
||||||
|
|
||||||
def get_data_cache() -> DataCache:
|
|
||||||
"""Get the global data cache instance"""
|
|
||||||
global _data_cache_instance
|
|
||||||
|
|
||||||
if _data_cache_instance is None:
|
|
||||||
_data_cache_instance = DataCache()
|
|
||||||
|
|
||||||
return _data_cache_instance
|
|
@ -114,42 +114,32 @@ class BaseDataInput:
|
|||||||
FIXED_FEATURE_SIZE = 7850
|
FIXED_FEATURE_SIZE = 7850
|
||||||
features = []
|
features = []
|
||||||
|
|
||||||
# OHLCV features for ETH (300 frames x 4 timeframes x 5 features = 6000 features)
|
# OHLCV features for ETH (up to 300 frames x 4 timeframes x 5 features)
|
||||||
for ohlcv_list in [self.ohlcv_1s, self.ohlcv_1m, self.ohlcv_1h, self.ohlcv_1d]:
|
for ohlcv_list in [self.ohlcv_1s, self.ohlcv_1m, self.ohlcv_1h, self.ohlcv_1d]:
|
||||||
# Ensure exactly 300 frames by padding or truncating
|
# Use actual data only, up to 300 frames
|
||||||
ohlcv_frames = ohlcv_list[-300:] if len(ohlcv_list) >= 300 else ohlcv_list
|
ohlcv_frames = ohlcv_list[-300:] if len(ohlcv_list) >= 300 else ohlcv_list
|
||||||
|
|
||||||
# Pad with zeros if not enough data
|
# Extract features from actual frames
|
||||||
while len(ohlcv_frames) < 300:
|
|
||||||
# Create a dummy OHLCV bar with zeros
|
|
||||||
dummy_bar = OHLCVBar(
|
|
||||||
symbol="ETH/USDT",
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
open=0.0, high=0.0, low=0.0, close=0.0, volume=0.0,
|
|
||||||
timeframe="1s"
|
|
||||||
)
|
|
||||||
ohlcv_frames.insert(0, dummy_bar)
|
|
||||||
|
|
||||||
# Extract features from exactly 300 frames
|
|
||||||
for bar in ohlcv_frames:
|
for bar in ohlcv_frames:
|
||||||
features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
|
features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
|
||||||
|
|
||||||
# BTC OHLCV features (300 frames x 5 features = 1500 features)
|
# Pad with zeros only if we have some data but less than 300 frames
|
||||||
|
frames_needed = 300 - len(ohlcv_frames)
|
||||||
|
if frames_needed > 0:
|
||||||
|
features.extend([0.0] * (frames_needed * 5)) # 5 features per frame
|
||||||
|
|
||||||
|
# BTC OHLCV features (up to 300 frames x 5 features = 1500 features)
|
||||||
btc_frames = self.btc_ohlcv_1s[-300:] if len(self.btc_ohlcv_1s) >= 300 else self.btc_ohlcv_1s
|
btc_frames = self.btc_ohlcv_1s[-300:] if len(self.btc_ohlcv_1s) >= 300 else self.btc_ohlcv_1s
|
||||||
|
|
||||||
# Pad BTC data if needed
|
# Extract features from actual BTC frames
|
||||||
while len(btc_frames) < 300:
|
|
||||||
dummy_bar = OHLCVBar(
|
|
||||||
symbol="BTC/USDT",
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
open=0.0, high=0.0, low=0.0, close=0.0, volume=0.0,
|
|
||||||
timeframe="1s"
|
|
||||||
)
|
|
||||||
btc_frames.insert(0, dummy_bar)
|
|
||||||
|
|
||||||
for bar in btc_frames:
|
for bar in btc_frames:
|
||||||
features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
|
features.extend([bar.open, bar.high, bar.low, bar.close, bar.volume])
|
||||||
|
|
||||||
|
# Pad with zeros only if we have some data but less than 300 frames
|
||||||
|
btc_frames_needed = 300 - len(btc_frames)
|
||||||
|
if btc_frames_needed > 0:
|
||||||
|
features.extend([0.0] * (btc_frames_needed * 5)) # 5 features per frame
|
||||||
|
|
||||||
# COB features (FIXED SIZE: 200 features)
|
# COB features (FIXED SIZE: 200 features)
|
||||||
cob_features = []
|
cob_features = []
|
||||||
if self.cob_data:
|
if self.cob_data:
|
||||||
|
@ -224,6 +224,12 @@ class DataProvider:
|
|||||||
self.cob_data_cache[binance_symbol] = deque(maxlen=300) # 5 minutes of COB data
|
self.cob_data_cache[binance_symbol] = deque(maxlen=300) # 5 minutes of COB data
|
||||||
self.training_data_cache[binance_symbol] = deque(maxlen=1000) # Training data buffer
|
self.training_data_cache[binance_symbol] = deque(maxlen=1000) # Training data buffer
|
||||||
|
|
||||||
|
# Pre-built OHLCV cache for instant BaseDataInput building (optimization from SimplifiedDataIntegration)
|
||||||
|
self._ohlcv_cache = {} # {symbol: {timeframe: List[OHLCVBar]}}
|
||||||
|
self._ohlcv_cache_lock = Lock()
|
||||||
|
self._last_cache_update = {} # {symbol: {timeframe: datetime}}
|
||||||
|
self._cache_refresh_interval = 5 # seconds
|
||||||
|
|
||||||
# Data collection threads
|
# Data collection threads
|
||||||
self.data_collection_active = False
|
self.data_collection_active = False
|
||||||
|
|
||||||
@ -1387,6 +1393,175 @@ class DataProvider:
|
|||||||
logger.error(f"Error applying pivot normalization for {symbol}: {e}")
|
logger.error(f"Error applying pivot normalization for {symbol}: {e}")
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
def build_base_data_input(self, symbol: str) -> Optional['BaseDataInput']:
|
||||||
|
"""
|
||||||
|
Build BaseDataInput from cached data (optimized for speed)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseDataInput with consistent data structure
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from .data_models import BaseDataInput
|
||||||
|
|
||||||
|
# Get OHLCV data directly from optimized cache (no validation checks for speed)
|
||||||
|
ohlcv_1s_list = self._get_cached_ohlcv_bars(symbol, '1s', 300)
|
||||||
|
ohlcv_1m_list = self._get_cached_ohlcv_bars(symbol, '1m', 300)
|
||||||
|
ohlcv_1h_list = self._get_cached_ohlcv_bars(symbol, '1h', 300)
|
||||||
|
ohlcv_1d_list = self._get_cached_ohlcv_bars(symbol, '1d', 300)
|
||||||
|
|
||||||
|
# Get BTC reference data
|
||||||
|
btc_symbol = 'BTC/USDT'
|
||||||
|
btc_ohlcv_1s_list = self._get_cached_ohlcv_bars(btc_symbol, '1s', 300)
|
||||||
|
if not btc_ohlcv_1s_list:
|
||||||
|
# Use ETH data as fallback
|
||||||
|
btc_ohlcv_1s_list = ohlcv_1s_list
|
||||||
|
|
||||||
|
# Get cached data (fast lookups)
|
||||||
|
technical_indicators = self._get_latest_technical_indicators(symbol)
|
||||||
|
cob_data = self._get_latest_cob_data_object(symbol)
|
||||||
|
last_predictions = {} # TODO: Implement model prediction caching
|
||||||
|
|
||||||
|
# Build BaseDataInput (no validation for speed - assume data is good)
|
||||||
|
base_data = BaseDataInput(
|
||||||
|
symbol=symbol,
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
ohlcv_1s=ohlcv_1s_list,
|
||||||
|
ohlcv_1m=ohlcv_1m_list,
|
||||||
|
ohlcv_1h=ohlcv_1h_list,
|
||||||
|
ohlcv_1d=ohlcv_1d_list,
|
||||||
|
btc_ohlcv_1s=btc_ohlcv_1s_list,
|
||||||
|
technical_indicators=technical_indicators,
|
||||||
|
cob_data=cob_data,
|
||||||
|
last_predictions=last_predictions
|
||||||
|
)
|
||||||
|
|
||||||
|
return base_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error building BaseDataInput for {symbol}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_cached_ohlcv_bars(self, symbol: str, timeframe: str, max_count: int) -> List['OHLCVBar']:
|
||||||
|
"""Get OHLCV data list from pre-built cache for instant access"""
|
||||||
|
try:
|
||||||
|
with self._ohlcv_cache_lock:
|
||||||
|
cache_key = f"{symbol}_{timeframe}"
|
||||||
|
|
||||||
|
# Check if we have fresh cached data (updated within last 5 seconds)
|
||||||
|
last_update = self._last_cache_update.get(cache_key)
|
||||||
|
if (last_update and
|
||||||
|
(datetime.now() - last_update).total_seconds() < self._cache_refresh_interval and
|
||||||
|
cache_key in self._ohlcv_cache):
|
||||||
|
|
||||||
|
cached_data = self._ohlcv_cache[cache_key]
|
||||||
|
return cached_data[-max_count:] if len(cached_data) >= max_count else cached_data
|
||||||
|
|
||||||
|
# Need to rebuild cache for this symbol/timeframe
|
||||||
|
data_list = self._build_ohlcv_bar_cache(symbol, timeframe, max_count)
|
||||||
|
|
||||||
|
# Cache the result
|
||||||
|
self._ohlcv_cache[cache_key] = data_list
|
||||||
|
self._last_cache_update[cache_key] = datetime.now()
|
||||||
|
|
||||||
|
return data_list[-max_count:] if len(data_list) >= max_count else data_list
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting cached OHLCV bars for {symbol}/{timeframe}: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _build_ohlcv_bar_cache(self, symbol: str, timeframe: str, max_count: int) -> List['OHLCVBar']:
|
||||||
|
"""Build OHLCV bar cache from historical and current data"""
|
||||||
|
try:
|
||||||
|
from .data_models import OHLCVBar
|
||||||
|
data_list = []
|
||||||
|
|
||||||
|
# Get historical data first (this should be fast as it's already cached)
|
||||||
|
historical_df = self.get_historical_data(symbol, timeframe, limit=max_count)
|
||||||
|
if historical_df is not None and not historical_df.empty:
|
||||||
|
# Convert historical data to OHLCVBar objects
|
||||||
|
for idx, row in historical_df.tail(max_count).iterrows():
|
||||||
|
bar = OHLCVBar(
|
||||||
|
symbol=symbol,
|
||||||
|
timestamp=idx if hasattr(idx, 'to_pydatetime') else datetime.now(),
|
||||||
|
open=float(row['open']),
|
||||||
|
high=float(row['high']),
|
||||||
|
low=float(row['low']),
|
||||||
|
close=float(row['close']),
|
||||||
|
volume=float(row['volume']),
|
||||||
|
timeframe=timeframe
|
||||||
|
)
|
||||||
|
data_list.append(bar)
|
||||||
|
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error building OHLCV bar cache for {symbol}/{timeframe}: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _get_latest_technical_indicators(self, symbol: str) -> Dict[str, float]:
|
||||||
|
"""Get latest technical indicators for a symbol"""
|
||||||
|
try:
|
||||||
|
# Get latest data and calculate indicators
|
||||||
|
df = self.get_historical_data(symbol, '1h', limit=50)
|
||||||
|
if df is not None and not df.empty:
|
||||||
|
df_with_indicators = self._add_technical_indicators(df)
|
||||||
|
if not df_with_indicators.empty:
|
||||||
|
# Return the latest indicators as a dict
|
||||||
|
latest_row = df_with_indicators.iloc[-1]
|
||||||
|
indicators = {}
|
||||||
|
for col in df_with_indicators.columns:
|
||||||
|
if col not in ['open', 'high', 'low', 'close', 'volume', 'timestamp']:
|
||||||
|
indicators[col] = float(latest_row[col]) if pd.notna(latest_row[col]) else 0.0
|
||||||
|
return indicators
|
||||||
|
return {}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting technical indicators for {symbol}: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _get_latest_cob_data_object(self, symbol: str) -> Optional['COBData']:
|
||||||
|
"""Get latest COB data as COBData object"""
|
||||||
|
try:
|
||||||
|
from .data_models import COBData
|
||||||
|
|
||||||
|
# Get latest COB data from cache
|
||||||
|
cob_data = self.get_latest_cob_data(symbol)
|
||||||
|
if cob_data and 'current_price' in cob_data:
|
||||||
|
return COBData(
|
||||||
|
symbol=symbol,
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
current_price=cob_data['current_price'],
|
||||||
|
bucket_size=1.0 if 'ETH' in symbol else 10.0,
|
||||||
|
price_buckets=cob_data.get('price_buckets', {}),
|
||||||
|
bid_ask_imbalance=cob_data.get('bid_ask_imbalance', {}),
|
||||||
|
volume_weighted_prices=cob_data.get('volume_weighted_prices', {}),
|
||||||
|
order_flow_metrics=cob_data.get('order_flow_metrics', {}),
|
||||||
|
ma_1s_imbalance=cob_data.get('ma_1s_imbalance', {}),
|
||||||
|
ma_5s_imbalance=cob_data.get('ma_5s_imbalance', {}),
|
||||||
|
ma_15s_imbalance=cob_data.get('ma_15s_imbalance', {}),
|
||||||
|
ma_60s_imbalance=cob_data.get('ma_60s_imbalance', {})
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting COB data object for {symbol}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def invalidate_ohlcv_cache(self, symbol: str):
|
||||||
|
"""Invalidate OHLCV cache for a symbol when new data arrives"""
|
||||||
|
try:
|
||||||
|
with self._ohlcv_cache_lock:
|
||||||
|
# Remove cached data for all timeframes of this symbol
|
||||||
|
keys_to_remove = [key for key in self._ohlcv_cache.keys() if key.startswith(f"{symbol}_")]
|
||||||
|
for key in keys_to_remove:
|
||||||
|
if key in self._ohlcv_cache:
|
||||||
|
del self._ohlcv_cache[key]
|
||||||
|
if key in self._last_cache_update:
|
||||||
|
del self._last_cache_update[key]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error invalidating OHLCV cache for {symbol}: {e}")
|
||||||
|
|
||||||
def _add_basic_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
|
def _add_basic_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||||
"""Add basic indicators for small datasets"""
|
"""Add basic indicators for small datasets"""
|
||||||
try:
|
try:
|
||||||
|
@ -179,12 +179,7 @@ class TradingOrchestrator:
|
|||||||
self.fusion_decisions_count: int = 0
|
self.fusion_decisions_count: int = 0
|
||||||
self.fusion_training_data: List[Any] = [] # Store training examples for decision model
|
self.fusion_training_data: List[Any] = [] # Store training examples for decision model
|
||||||
|
|
||||||
# Simplified Data Integration - Replace complex FIFO queues with efficient cache
|
# Use data provider directly for BaseDataInput building (optimized)
|
||||||
from core.simplified_data_integration import SimplifiedDataIntegration
|
|
||||||
self.data_integration = SimplifiedDataIntegration(
|
|
||||||
data_provider=self.data_provider,
|
|
||||||
symbols=[self.symbol] + self.ref_symbols
|
|
||||||
)
|
|
||||||
|
|
||||||
# COB Integration - Real-time market microstructure data
|
# COB Integration - Real-time market microstructure data
|
||||||
self.cob_integration = None # Will be set to COBIntegration instance if available
|
self.cob_integration = None # Will be set to COBIntegration instance if available
|
||||||
@ -250,8 +245,7 @@ class TradingOrchestrator:
|
|||||||
self.data_provider.start_centralized_data_collection()
|
self.data_provider.start_centralized_data_collection()
|
||||||
logger.info("Centralized data collection started - all models and dashboard will receive data")
|
logger.info("Centralized data collection started - all models and dashboard will receive data")
|
||||||
|
|
||||||
# Initialize simplified data integration
|
# Data provider is already initialized and optimized
|
||||||
self._initialize_simplified_data_integration()
|
|
||||||
|
|
||||||
# Log initial data status
|
# Log initial data status
|
||||||
logger.info("Simplified data integration initialized")
|
logger.info("Simplified data integration initialized")
|
||||||
@ -897,31 +891,10 @@ class TradingOrchestrator:
|
|||||||
try:
|
try:
|
||||||
self.latest_cob_data[symbol] = cob_data
|
self.latest_cob_data[symbol] = cob_data
|
||||||
|
|
||||||
# Update data cache with COB data for BaseDataInput
|
# Invalidate data provider cache when new COB data arrives
|
||||||
if hasattr(self, 'data_integration') and self.data_integration:
|
if hasattr(self.data_provider, 'invalidate_ohlcv_cache'):
|
||||||
# Convert cob_data to COBData format if needed
|
self.data_provider.invalidate_ohlcv_cache(symbol)
|
||||||
from .data_models import COBData
|
logger.debug(f"Invalidated data provider cache for {symbol} due to COB update")
|
||||||
|
|
||||||
# Create COBData object from the raw cob_data
|
|
||||||
if 'price_buckets' in cob_data and 'current_price' in cob_data:
|
|
||||||
cob_data_obj = COBData(
|
|
||||||
symbol=symbol,
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
current_price=cob_data['current_price'],
|
|
||||||
bucket_size=1.0 if 'ETH' in symbol else 10.0,
|
|
||||||
price_buckets=cob_data.get('price_buckets', {}),
|
|
||||||
bid_ask_imbalance=cob_data.get('bid_ask_imbalance', {}),
|
|
||||||
volume_weighted_prices=cob_data.get('volume_weighted_prices', {}),
|
|
||||||
order_flow_metrics=cob_data.get('order_flow_metrics', {}),
|
|
||||||
ma_1s_imbalance=cob_data.get('ma_1s_imbalance', {}),
|
|
||||||
ma_5s_imbalance=cob_data.get('ma_5s_imbalance', {}),
|
|
||||||
ma_15s_imbalance=cob_data.get('ma_15s_imbalance', {}),
|
|
||||||
ma_60s_imbalance=cob_data.get('ma_60s_imbalance', {})
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update cache with COB data
|
|
||||||
self.data_integration.cache.update('cob_data', symbol, cob_data_obj, 'cob_integration')
|
|
||||||
logger.debug(f"Updated cache with COB data for {symbol}")
|
|
||||||
|
|
||||||
# Update dashboard
|
# Update dashboard
|
||||||
if self.dashboard and hasattr(self.dashboard, 'update_cob_data'):
|
if self.dashboard and hasattr(self.dashboard, 'update_cob_data'):
|
||||||
@ -3722,54 +3695,34 @@ class TradingOrchestrator:
|
|||||||
"""
|
"""
|
||||||
return self.db_manager.get_best_checkpoint_metadata(model_name)
|
return self.db_manager.get_best_checkpoint_metadata(model_name)
|
||||||
|
|
||||||
# === SIMPLIFIED DATA MANAGEMENT ===
|
# === DATA MANAGEMENT ===
|
||||||
|
|
||||||
def _initialize_simplified_data_integration(self):
|
|
||||||
"""Initialize the simplified data integration system"""
|
|
||||||
try:
|
|
||||||
# Start the data integration system
|
|
||||||
self.data_integration.start()
|
|
||||||
logger.info("Simplified data integration started successfully")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error starting simplified data integration: {e}")
|
|
||||||
|
|
||||||
def _log_data_status(self):
|
def _log_data_status(self):
|
||||||
"""Log current data status"""
|
"""Log current data status"""
|
||||||
try:
|
try:
|
||||||
status = self.data_integration.get_cache_status()
|
logger.info("=== Data Provider Status ===")
|
||||||
cache_status = status.get('cache_status', {})
|
logger.info("Data provider is running and optimized for BaseDataInput building")
|
||||||
|
|
||||||
logger.info("=== Data Cache Status ===")
|
|
||||||
for data_type, symbols_data in cache_status.items():
|
|
||||||
symbol_info = []
|
|
||||||
for symbol, info in symbols_data.items():
|
|
||||||
age = info.get('age_seconds', 0)
|
|
||||||
has_data = info.get('has_data', False)
|
|
||||||
if has_data and age < 300: # Recent data
|
|
||||||
symbol_info.append(f"{symbol}:✅")
|
|
||||||
else:
|
|
||||||
symbol_info.append(f"{symbol}:❌")
|
|
||||||
|
|
||||||
if symbol_info:
|
|
||||||
logger.info(f"{data_type}: {', '.join(symbol_info)}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error logging data status: {e}")
|
logger.error(f"Error logging data status: {e}")
|
||||||
|
|
||||||
def update_data_cache(self, data_type: str, symbol: str, data: Any, source: str = "orchestrator") -> bool:
|
def update_data_cache(self, data_type: str, symbol: str, data: Any, source: str = "orchestrator") -> bool:
|
||||||
"""
|
"""
|
||||||
Update data cache with new data (simplified approach)
|
Update data cache through data provider
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_type: Type of data ('ohlcv_1s', 'technical_indicators', etc.)
|
data_type: Type of data ('ohlcv_1s', 'technical_indicators', etc.)
|
||||||
symbol: Trading symbol
|
symbol: Trading symbol
|
||||||
data: New data to store
|
data: Data to store
|
||||||
source: Source of the data
|
source: Source of the update
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if successful
|
bool: True if updated successfully
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return self.data_integration.cache.update(data_type, symbol, data, source)
|
# Invalidate cache when new data arrives
|
||||||
|
if hasattr(self.data_provider, 'invalidate_ohlcv_cache'):
|
||||||
|
self.data_provider.invalidate_ohlcv_cache(symbol)
|
||||||
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error updating data cache {data_type}/{symbol}: {e}")
|
logger.error(f"Error updating data cache {data_type}/{symbol}: {e}")
|
||||||
return False
|
return False
|
||||||
@ -3929,7 +3882,7 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
def build_base_data_input(self, symbol: str) -> Optional[Any]:
|
def build_base_data_input(self, symbol: str) -> Optional[Any]:
|
||||||
"""
|
"""
|
||||||
Build BaseDataInput using simplified data integration (optimized for speed)
|
Build BaseDataInput using optimized data provider (should be instantaneous)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
symbol: Trading symbol
|
symbol: Trading symbol
|
||||||
@ -3938,8 +3891,8 @@ class TradingOrchestrator:
|
|||||||
BaseDataInput with consistent data structure
|
BaseDataInput with consistent data structure
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Use simplified data integration to build BaseDataInput (should be instantaneous)
|
# Use data provider's optimized build_base_data_input method
|
||||||
return self.data_integration.build_base_data_input(symbol)
|
return self.data_provider.build_base_data_input(symbol)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error building BaseDataInput for {symbol}: {e}")
|
logger.error(f"Error building BaseDataInput for {symbol}: {e}")
|
||||||
|
@ -1,284 +0,0 @@
|
|||||||
"""
|
|
||||||
Simplified Data Integration for Orchestrator
|
|
||||||
|
|
||||||
Replaces complex FIFO queues with simple cache-based data access.
|
|
||||||
Integrates with SmartDataUpdater for efficient data management.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import threading
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Dict, List, Optional, Any
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
from .data_cache import get_data_cache
|
|
||||||
from .smart_data_updater import SmartDataUpdater
|
|
||||||
from .data_models import BaseDataInput, OHLCVBar
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class SimplifiedDataIntegration:
|
|
||||||
"""
|
|
||||||
Simplified data integration that replaces FIFO queues with efficient caching
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, data_provider, symbols: List[str]):
|
|
||||||
self.data_provider = data_provider
|
|
||||||
self.symbols = symbols
|
|
||||||
self.cache = get_data_cache()
|
|
||||||
|
|
||||||
# Initialize smart data updater
|
|
||||||
self.data_updater = SmartDataUpdater(data_provider, symbols)
|
|
||||||
|
|
||||||
# Pre-built OHLCV data cache for instant access
|
|
||||||
self._ohlcv_cache = {} # {symbol: {timeframe: List[OHLCVBar]}}
|
|
||||||
self._ohlcv_cache_lock = threading.RLock()
|
|
||||||
self._last_cache_update = {} # {symbol: {timeframe: datetime}}
|
|
||||||
|
|
||||||
# Register for tick data if available
|
|
||||||
self._setup_tick_integration()
|
|
||||||
|
|
||||||
logger.info(f"SimplifiedDataIntegration initialized for {symbols}")
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
"""Start the data integration system"""
|
|
||||||
self.data_updater.start()
|
|
||||||
logger.info("SimplifiedDataIntegration started")
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
"""Stop the data integration system"""
|
|
||||||
self.data_updater.stop()
|
|
||||||
logger.info("SimplifiedDataIntegration stopped")
|
|
||||||
|
|
||||||
def _setup_tick_integration(self):
|
|
||||||
"""Setup integration with tick data sources"""
|
|
||||||
try:
|
|
||||||
# Register callbacks for tick data if available
|
|
||||||
if hasattr(self.data_provider, 'register_tick_callback'):
|
|
||||||
self.data_provider.register_tick_callback(self._on_tick_data)
|
|
||||||
|
|
||||||
# Register for WebSocket data if available
|
|
||||||
if hasattr(self.data_provider, 'register_websocket_callback'):
|
|
||||||
self.data_provider.register_websocket_callback(self._on_websocket_data)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Tick integration setup failed: {e}")
|
|
||||||
|
|
||||||
def _on_tick_data(self, symbol: str, price: float, volume: float, timestamp: datetime = None):
|
|
||||||
"""Handle incoming tick data"""
|
|
||||||
self.data_updater.add_tick(symbol, price, volume, timestamp)
|
|
||||||
# Invalidate OHLCV cache for this symbol
|
|
||||||
self._invalidate_ohlcv_cache(symbol)
|
|
||||||
|
|
||||||
def _on_websocket_data(self, symbol: str, data: Dict[str, Any]):
|
|
||||||
"""Handle WebSocket data updates"""
|
|
||||||
try:
|
|
||||||
# Extract price and volume from WebSocket data
|
|
||||||
if 'price' in data and 'volume' in data:
|
|
||||||
self.data_updater.add_tick(symbol, data['price'], data['volume'])
|
|
||||||
# Invalidate OHLCV cache for this symbol
|
|
||||||
self._invalidate_ohlcv_cache(symbol)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing WebSocket data: {e}")
|
|
||||||
|
|
||||||
def _invalidate_ohlcv_cache(self, symbol: str):
|
|
||||||
"""Invalidate OHLCV cache for a symbol when new data arrives"""
|
|
||||||
try:
|
|
||||||
with self._ohlcv_cache_lock:
|
|
||||||
# Remove cached data for all timeframes of this symbol
|
|
||||||
keys_to_remove = [key for key in self._ohlcv_cache.keys() if key.startswith(f"{symbol}_")]
|
|
||||||
for key in keys_to_remove:
|
|
||||||
if key in self._ohlcv_cache:
|
|
||||||
del self._ohlcv_cache[key]
|
|
||||||
if key in self._last_cache_update:
|
|
||||||
del self._last_cache_update[key]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error invalidating OHLCV cache for {symbol}: {e}")
|
|
||||||
|
|
||||||
def build_base_data_input(self, symbol: str) -> Optional[BaseDataInput]:
|
|
||||||
"""
|
|
||||||
Build BaseDataInput from cached data (optimized for speed)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
symbol: Trading symbol
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BaseDataInput with consistent data structure
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Get OHLCV data directly from optimized cache (no validation checks for speed)
|
|
||||||
ohlcv_1s_list = self._get_ohlcv_data_list(symbol, '1s', 300)
|
|
||||||
ohlcv_1m_list = self._get_ohlcv_data_list(symbol, '1m', 300)
|
|
||||||
ohlcv_1h_list = self._get_ohlcv_data_list(symbol, '1h', 300)
|
|
||||||
ohlcv_1d_list = self._get_ohlcv_data_list(symbol, '1d', 300)
|
|
||||||
|
|
||||||
# Get BTC reference data
|
|
||||||
btc_symbol = 'BTC/USDT'
|
|
||||||
btc_ohlcv_1s_list = self._get_ohlcv_data_list(btc_symbol, '1s', 300)
|
|
||||||
if not btc_ohlcv_1s_list:
|
|
||||||
# Use ETH data as fallback
|
|
||||||
btc_ohlcv_1s_list = ohlcv_1s_list
|
|
||||||
|
|
||||||
# Get cached data (fast lookups)
|
|
||||||
technical_indicators = self.cache.get('technical_indicators', symbol) or {}
|
|
||||||
cob_data = self.cache.get('cob_data', symbol)
|
|
||||||
last_predictions = self._get_recent_predictions(symbol)
|
|
||||||
|
|
||||||
# Build BaseDataInput (no validation for speed - assume data is good)
|
|
||||||
base_data = BaseDataInput(
|
|
||||||
symbol=symbol,
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
ohlcv_1s=ohlcv_1s_list,
|
|
||||||
ohlcv_1m=ohlcv_1m_list,
|
|
||||||
ohlcv_1h=ohlcv_1h_list,
|
|
||||||
ohlcv_1d=ohlcv_1d_list,
|
|
||||||
btc_ohlcv_1s=btc_ohlcv_1s_list,
|
|
||||||
technical_indicators=technical_indicators,
|
|
||||||
cob_data=cob_data,
|
|
||||||
last_predictions=last_predictions
|
|
||||||
)
|
|
||||||
|
|
||||||
return base_data
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error building BaseDataInput for {symbol}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _get_ohlcv_data_list(self, symbol: str, timeframe: str, max_count: int) -> List[OHLCVBar]:
|
|
||||||
"""Get OHLCV data list from pre-built cache for instant access"""
|
|
||||||
try:
|
|
||||||
with self._ohlcv_cache_lock:
|
|
||||||
cache_key = f"{symbol}_{timeframe}"
|
|
||||||
|
|
||||||
# Check if we have fresh cached data (updated within last 5 seconds)
|
|
||||||
last_update = self._last_cache_update.get(cache_key)
|
|
||||||
if (last_update and
|
|
||||||
(datetime.now() - last_update).total_seconds() < 5 and
|
|
||||||
cache_key in self._ohlcv_cache):
|
|
||||||
|
|
||||||
cached_data = self._ohlcv_cache[cache_key]
|
|
||||||
return cached_data[-max_count:] if len(cached_data) >= max_count else cached_data
|
|
||||||
|
|
||||||
# Need to rebuild cache for this symbol/timeframe
|
|
||||||
data_list = self._build_ohlcv_cache(symbol, timeframe, max_count)
|
|
||||||
|
|
||||||
# Cache the result
|
|
||||||
self._ohlcv_cache[cache_key] = data_list
|
|
||||||
self._last_cache_update[cache_key] = datetime.now()
|
|
||||||
|
|
||||||
return data_list[-max_count:] if len(data_list) >= max_count else data_list
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting OHLCV data list for {symbol}/{timeframe}: {e}")
|
|
||||||
return self._create_dummy_data_list(symbol, timeframe, max_count)
|
|
||||||
|
|
||||||
def _build_ohlcv_cache(self, symbol: str, timeframe: str, max_count: int) -> List[OHLCVBar]:
|
|
||||||
"""Build OHLCV cache from historical and current data"""
|
|
||||||
try:
|
|
||||||
data_list = []
|
|
||||||
|
|
||||||
# Get historical data first (this should be fast as it's already cached)
|
|
||||||
historical_df = self.cache.get_historical_data(symbol, timeframe)
|
|
||||||
if historical_df is not None and not historical_df.empty:
|
|
||||||
# Convert historical data to OHLCVBar objects
|
|
||||||
for idx, row in historical_df.tail(max_count - 1).iterrows():
|
|
||||||
bar = OHLCVBar(
|
|
||||||
symbol=symbol,
|
|
||||||
timestamp=idx if hasattr(idx, 'to_pydatetime') else datetime.now(),
|
|
||||||
open=float(row['open']),
|
|
||||||
high=float(row['high']),
|
|
||||||
low=float(row['low']),
|
|
||||||
close=float(row['close']),
|
|
||||||
volume=float(row['volume']),
|
|
||||||
timeframe=timeframe
|
|
||||||
)
|
|
||||||
data_list.append(bar)
|
|
||||||
|
|
||||||
# Add current data from cache
|
|
||||||
current_ohlcv = self.cache.get(f'ohlcv_{timeframe}', symbol)
|
|
||||||
if current_ohlcv and isinstance(current_ohlcv, OHLCVBar):
|
|
||||||
data_list.append(current_ohlcv)
|
|
||||||
|
|
||||||
# Ensure we have the right amount of data (pad if necessary)
|
|
||||||
while len(data_list) < max_count:
|
|
||||||
data_list.extend(self._create_dummy_data_list(symbol, timeframe, max_count - len(data_list)))
|
|
||||||
|
|
||||||
return data_list
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error building OHLCV cache for {symbol}/{timeframe}: {e}")
|
|
||||||
return self._create_dummy_data_list(symbol, timeframe, max_count)
|
|
||||||
|
|
||||||
|
|
||||||
def _try_historical_fallback(self, symbol: str, missing_timeframes: List[str]) -> bool:
|
|
||||||
"""Try to use historical data for missing timeframes"""
|
|
||||||
try:
|
|
||||||
for timeframe in missing_timeframes:
|
|
||||||
historical_df = self.cache.get_historical_data(symbol, timeframe)
|
|
||||||
if historical_df is not None and not historical_df.empty:
|
|
||||||
# Use latest historical data as current data
|
|
||||||
latest_row = historical_df.iloc[-1]
|
|
||||||
ohlcv_bar = OHLCVBar(
|
|
||||||
symbol=symbol,
|
|
||||||
timestamp=historical_df.index[-1] if hasattr(historical_df.index[-1], 'to_pydatetime') else datetime.now(),
|
|
||||||
open=float(latest_row['open']),
|
|
||||||
high=float(latest_row['high']),
|
|
||||||
low=float(latest_row['low']),
|
|
||||||
close=float(latest_row['close']),
|
|
||||||
volume=float(latest_row['volume']),
|
|
||||||
timeframe=timeframe
|
|
||||||
)
|
|
||||||
|
|
||||||
self.cache.update(f'ohlcv_{timeframe}', symbol, ohlcv_bar, 'historical_fallback')
|
|
||||||
logger.info(f"Used historical fallback for {symbol} {timeframe}")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in historical fallback: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _get_recent_predictions(self, symbol: str) -> Dict[str, Any]:
|
|
||||||
"""Get recent model predictions"""
|
|
||||||
try:
|
|
||||||
predictions = {}
|
|
||||||
|
|
||||||
# Get predictions from cache
|
|
||||||
for model_type in ['cnn', 'rl', 'extrema']:
|
|
||||||
prediction_data = self.cache.get(f'prediction_{model_type}', symbol)
|
|
||||||
if prediction_data:
|
|
||||||
predictions[model_type] = prediction_data
|
|
||||||
|
|
||||||
return predictions
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting recent predictions for {symbol}: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def update_model_prediction(self, model_name: str, symbol: str, prediction_data: Any):
|
|
||||||
"""Update model prediction in cache"""
|
|
||||||
self.cache.update(f'prediction_{model_name}', symbol, prediction_data, model_name)
|
|
||||||
|
|
||||||
def get_current_price(self, symbol: str) -> Optional[float]:
|
|
||||||
"""Get current price for a symbol"""
|
|
||||||
return self.data_updater.get_current_price(symbol)
|
|
||||||
|
|
||||||
def get_cache_status(self) -> Dict[str, Any]:
|
|
||||||
"""Get cache status for monitoring"""
|
|
||||||
return {
|
|
||||||
'cache_status': self.cache.get_status(),
|
|
||||||
'updater_status': self.data_updater.get_status()
|
|
||||||
}
|
|
||||||
|
|
||||||
def has_sufficient_data(self, symbol: str) -> bool:
|
|
||||||
"""Check if we have sufficient data for model predictions"""
|
|
||||||
required_data = ['ohlcv_1s', 'ohlcv_1m', 'ohlcv_1h', 'ohlcv_1d']
|
|
||||||
|
|
||||||
for data_type in required_data:
|
|
||||||
if not self.cache.has_data(data_type, symbol, max_age_seconds=300):
|
|
||||||
# Check historical data as fallback
|
|
||||||
timeframe = data_type.split('_')[1]
|
|
||||||
if not self.cache.has_historical_data(symbol, timeframe, min_bars=50):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
@ -1,358 +0,0 @@
|
|||||||
"""
|
|
||||||
Smart Data Updater
|
|
||||||
|
|
||||||
Efficiently manages data updates using:
|
|
||||||
1. Initial historical data load (once)
|
|
||||||
2. Live tick data from WebSocket
|
|
||||||
3. Periodic HTTP updates (1m every minute, 1h every hour)
|
|
||||||
4. Smart candle construction from ticks
|
|
||||||
"""
|
|
||||||
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Dict, List, Optional, Any
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
from .data_cache import get_data_cache, DataCache
|
|
||||||
from .data_models import OHLCVBar
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class SmartDataUpdater:
|
|
||||||
"""
|
|
||||||
Smart data updater that efficiently manages market data with minimal API calls
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, data_provider, symbols: List[str]):
|
|
||||||
self.data_provider = data_provider
|
|
||||||
self.symbols = symbols
|
|
||||||
self.cache = get_data_cache()
|
|
||||||
self.running = False
|
|
||||||
|
|
||||||
# Tick data for candle construction
|
|
||||||
self.tick_buffers: Dict[str, deque] = {symbol: deque(maxlen=1000) for symbol in symbols}
|
|
||||||
self.tick_locks: Dict[str, threading.Lock] = {symbol: threading.Lock() for symbol in symbols}
|
|
||||||
|
|
||||||
# Current candle construction
|
|
||||||
self.current_candles: Dict[str, Dict[str, Dict]] = {} # {symbol: {timeframe: candle_data}}
|
|
||||||
self.candle_locks: Dict[str, threading.Lock] = {symbol: threading.Lock() for symbol in symbols}
|
|
||||||
|
|
||||||
# Update timers
|
|
||||||
self.last_updates: Dict[str, Dict[str, datetime]] = {} # {symbol: {timeframe: last_update}}
|
|
||||||
|
|
||||||
# Update intervals (in seconds)
|
|
||||||
self.update_intervals = {
|
|
||||||
'1s': 10, # Update 1s candles every 10 seconds from ticks
|
|
||||||
'1m': 60, # Update 1m candles every minute via HTTP
|
|
||||||
'1h': 3600, # Update 1h candles every hour via HTTP
|
|
||||||
'1d': 86400 # Update 1d candles every day via HTTP
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"SmartDataUpdater initialized for {len(symbols)} symbols")
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
"""Start the smart data updater"""
|
|
||||||
if self.running:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.running = True
|
|
||||||
|
|
||||||
# Load initial historical data
|
|
||||||
self._load_initial_historical_data()
|
|
||||||
|
|
||||||
# Start update threads
|
|
||||||
self.update_thread = threading.Thread(target=self._update_worker, daemon=True)
|
|
||||||
self.update_thread.start()
|
|
||||||
|
|
||||||
# Start tick processing thread
|
|
||||||
self.tick_thread = threading.Thread(target=self._tick_processor, daemon=True)
|
|
||||||
self.tick_thread.start()
|
|
||||||
|
|
||||||
logger.info("SmartDataUpdater started")
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
"""Stop the smart data updater"""
|
|
||||||
self.running = False
|
|
||||||
logger.info("SmartDataUpdater stopped")
|
|
||||||
|
|
||||||
def add_tick(self, symbol: str, price: float, volume: float, timestamp: datetime = None):
|
|
||||||
"""Add tick data for candle construction"""
|
|
||||||
if symbol not in self.tick_buffers:
|
|
||||||
return
|
|
||||||
|
|
||||||
tick_data = {
|
|
||||||
'price': price,
|
|
||||||
'volume': volume,
|
|
||||||
'timestamp': timestamp or datetime.now()
|
|
||||||
}
|
|
||||||
|
|
||||||
with self.tick_locks[symbol]:
|
|
||||||
self.tick_buffers[symbol].append(tick_data)
|
|
||||||
|
|
||||||
def _load_initial_historical_data(self):
|
|
||||||
"""Load initial historical data for all symbols and timeframes"""
|
|
||||||
logger.info("Loading initial historical data...")
|
|
||||||
|
|
||||||
timeframes = ['1s', '1m', '1h', '1d']
|
|
||||||
limits = {'1s': 300, '1m': 300, '1h': 300, '1d': 300}
|
|
||||||
|
|
||||||
for symbol in self.symbols:
|
|
||||||
self.last_updates[symbol] = {}
|
|
||||||
self.current_candles[symbol] = {}
|
|
||||||
|
|
||||||
for timeframe in timeframes:
|
|
||||||
try:
|
|
||||||
limit = limits.get(timeframe, 300)
|
|
||||||
|
|
||||||
# Get historical data
|
|
||||||
df = None
|
|
||||||
if hasattr(self.data_provider, 'get_historical_data'):
|
|
||||||
df = self.data_provider.get_historical_data(symbol, timeframe, limit=limit)
|
|
||||||
|
|
||||||
if df is not None and not df.empty:
|
|
||||||
# Store in cache
|
|
||||||
self.cache.store_historical_data(symbol, timeframe, df)
|
|
||||||
|
|
||||||
# Update current candle data from latest bar
|
|
||||||
latest_bar = df.iloc[-1]
|
|
||||||
self._update_current_candle_from_bar(symbol, timeframe, latest_bar)
|
|
||||||
|
|
||||||
# Update cache with latest OHLCV
|
|
||||||
ohlcv_bar = self._df_row_to_ohlcv_bar(symbol, timeframe, latest_bar, df.index[-1])
|
|
||||||
self.cache.update(f'ohlcv_{timeframe}', symbol, ohlcv_bar, 'historical')
|
|
||||||
|
|
||||||
self.last_updates[symbol][timeframe] = datetime.now()
|
|
||||||
logger.info(f"Loaded {len(df)} {timeframe} bars for {symbol}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"No historical data for {symbol} {timeframe}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading historical data for {symbol} {timeframe}: {e}")
|
|
||||||
|
|
||||||
# Calculate initial technical indicators
|
|
||||||
self._calculate_technical_indicators()
|
|
||||||
|
|
||||||
logger.info("Initial historical data loading completed")
|
|
||||||
|
|
||||||
def _update_worker(self):
|
|
||||||
"""Background worker for periodic data updates"""
|
|
||||||
while self.running:
|
|
||||||
try:
|
|
||||||
current_time = datetime.now()
|
|
||||||
|
|
||||||
for symbol in self.symbols:
|
|
||||||
for timeframe in ['1m', '1h', '1d']: # Skip 1s (built from ticks)
|
|
||||||
try:
|
|
||||||
# Check if it's time to update
|
|
||||||
last_update = self.last_updates[symbol].get(timeframe)
|
|
||||||
interval = self.update_intervals[timeframe]
|
|
||||||
|
|
||||||
if not last_update or (current_time - last_update).total_seconds() >= interval:
|
|
||||||
self._update_timeframe_data(symbol, timeframe)
|
|
||||||
self.last_updates[symbol][timeframe] = current_time
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error updating {symbol} {timeframe}: {e}")
|
|
||||||
|
|
||||||
# Update technical indicators every minute
|
|
||||||
if current_time.second < 10: # Update in first 10 seconds of each minute
|
|
||||||
self._calculate_technical_indicators()
|
|
||||||
|
|
||||||
time.sleep(10) # Check every 10 seconds
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in update worker: {e}")
|
|
||||||
time.sleep(30)
|
|
||||||
|
|
||||||
def _tick_processor(self):
|
|
||||||
"""Process ticks to build 1s candles"""
|
|
||||||
while self.running:
|
|
||||||
try:
|
|
||||||
current_time = datetime.now()
|
|
||||||
|
|
||||||
for symbol in self.symbols:
|
|
||||||
# Check if it's time to update 1s candles
|
|
||||||
last_update = self.last_updates[symbol].get('1s')
|
|
||||||
if not last_update or (current_time - last_update).total_seconds() >= self.update_intervals['1s']:
|
|
||||||
self._build_1s_candle_from_ticks(symbol)
|
|
||||||
self.last_updates[symbol]['1s'] = current_time
|
|
||||||
|
|
||||||
time.sleep(5) # Process every 5 seconds
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in tick processor: {e}")
|
|
||||||
time.sleep(10)
|
|
||||||
|
|
||||||
def _update_timeframe_data(self, symbol: str, timeframe: str):
|
|
||||||
"""Update data for a specific timeframe via HTTP"""
|
|
||||||
try:
|
|
||||||
# Get latest data from API
|
|
||||||
df = None
|
|
||||||
if hasattr(self.data_provider, 'get_latest_candles'):
|
|
||||||
df = self.data_provider.get_latest_candles(symbol, timeframe, limit=1)
|
|
||||||
elif hasattr(self.data_provider, 'get_historical_data'):
|
|
||||||
df = self.data_provider.get_historical_data(symbol, timeframe, limit=1)
|
|
||||||
|
|
||||||
if df is not None and not df.empty:
|
|
||||||
latest_bar = df.iloc[-1]
|
|
||||||
|
|
||||||
# Update current candle
|
|
||||||
self._update_current_candle_from_bar(symbol, timeframe, latest_bar)
|
|
||||||
|
|
||||||
# Update cache
|
|
||||||
ohlcv_bar = self._df_row_to_ohlcv_bar(symbol, timeframe, latest_bar, df.index[-1])
|
|
||||||
self.cache.update(f'ohlcv_{timeframe}', symbol, ohlcv_bar, 'http_update')
|
|
||||||
|
|
||||||
logger.debug(f"Updated {symbol} {timeframe} via HTTP")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error updating {symbol} {timeframe} via HTTP: {e}")
|
|
||||||
|
|
||||||
def _build_1s_candle_from_ticks(self, symbol: str):
|
|
||||||
"""Build 1s candle from accumulated ticks"""
|
|
||||||
try:
|
|
||||||
with self.tick_locks[symbol]:
|
|
||||||
ticks = list(self.tick_buffers[symbol])
|
|
||||||
|
|
||||||
if not ticks:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get ticks from last 10 seconds
|
|
||||||
cutoff_time = datetime.now() - timedelta(seconds=10)
|
|
||||||
recent_ticks = [tick for tick in ticks if tick['timestamp'] >= cutoff_time]
|
|
||||||
|
|
||||||
if not recent_ticks:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Build OHLCV from ticks
|
|
||||||
prices = [tick['price'] for tick in recent_ticks]
|
|
||||||
volumes = [tick['volume'] for tick in recent_ticks]
|
|
||||||
|
|
||||||
ohlcv_data = {
|
|
||||||
'open': prices[0],
|
|
||||||
'high': max(prices),
|
|
||||||
'low': min(prices),
|
|
||||||
'close': prices[-1],
|
|
||||||
'volume': sum(volumes)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Update current candle
|
|
||||||
with self.candle_locks[symbol]:
|
|
||||||
self.current_candles[symbol]['1s'] = ohlcv_data
|
|
||||||
|
|
||||||
# Create OHLCV bar and update cache
|
|
||||||
ohlcv_bar = OHLCVBar(
|
|
||||||
symbol=symbol,
|
|
||||||
timestamp=recent_ticks[-1]['timestamp'],
|
|
||||||
open=ohlcv_data['open'],
|
|
||||||
high=ohlcv_data['high'],
|
|
||||||
low=ohlcv_data['low'],
|
|
||||||
close=ohlcv_data['close'],
|
|
||||||
volume=ohlcv_data['volume'],
|
|
||||||
timeframe='1s'
|
|
||||||
)
|
|
||||||
|
|
||||||
self.cache.update('ohlcv_1s', symbol, ohlcv_bar, 'tick_constructed')
|
|
||||||
logger.debug(f"Built 1s candle for {symbol} from {len(recent_ticks)} ticks")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error building 1s candle from ticks for {symbol}: {e}")
|
|
||||||
|
|
||||||
def _update_current_candle_from_bar(self, symbol: str, timeframe: str, bar_data):
|
|
||||||
"""Update current candle data from a bar"""
|
|
||||||
try:
|
|
||||||
with self.candle_locks[symbol]:
|
|
||||||
self.current_candles[symbol][timeframe] = {
|
|
||||||
'open': float(bar_data['open']),
|
|
||||||
'high': float(bar_data['high']),
|
|
||||||
'low': float(bar_data['low']),
|
|
||||||
'close': float(bar_data['close']),
|
|
||||||
'volume': float(bar_data['volume'])
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error updating current candle for {symbol} {timeframe}: {e}")
|
|
||||||
|
|
||||||
def _df_row_to_ohlcv_bar(self, symbol: str, timeframe: str, row, timestamp) -> OHLCVBar:
|
|
||||||
"""Convert DataFrame row to OHLCVBar"""
|
|
||||||
return OHLCVBar(
|
|
||||||
symbol=symbol,
|
|
||||||
timestamp=timestamp if hasattr(timestamp, 'to_pydatetime') else datetime.now(),
|
|
||||||
open=float(row['open']),
|
|
||||||
high=float(row['high']),
|
|
||||||
low=float(row['low']),
|
|
||||||
close=float(row['close']),
|
|
||||||
volume=float(row['volume']),
|
|
||||||
timeframe=timeframe
|
|
||||||
)
|
|
||||||
|
|
||||||
def _calculate_technical_indicators(self):
|
|
||||||
"""Calculate technical indicators for all symbols"""
|
|
||||||
try:
|
|
||||||
for symbol in self.symbols:
|
|
||||||
# Use 1m historical data for indicators
|
|
||||||
df = self.cache.get_historical_data(symbol, '1m')
|
|
||||||
if df is None or len(df) < 20:
|
|
||||||
continue
|
|
||||||
|
|
||||||
indicators = {}
|
|
||||||
try:
|
|
||||||
import ta
|
|
||||||
|
|
||||||
# RSI
|
|
||||||
if len(df) >= 14:
|
|
||||||
indicators['rsi'] = ta.momentum.RSIIndicator(df['close']).rsi().iloc[-1]
|
|
||||||
|
|
||||||
# Moving averages
|
|
||||||
if len(df) >= 20:
|
|
||||||
indicators['sma_20'] = df['close'].rolling(20).mean().iloc[-1]
|
|
||||||
if len(df) >= 12:
|
|
||||||
indicators['ema_12'] = df['close'].ewm(span=12).mean().iloc[-1]
|
|
||||||
if len(df) >= 26:
|
|
||||||
indicators['ema_26'] = df['close'].ewm(span=26).mean().iloc[-1]
|
|
||||||
if 'ema_12' in indicators:
|
|
||||||
indicators['macd'] = indicators['ema_12'] - indicators['ema_26']
|
|
||||||
|
|
||||||
# Bollinger Bands
|
|
||||||
if len(df) >= 20:
|
|
||||||
bb_period = 20
|
|
||||||
bb_std = 2
|
|
||||||
sma = df['close'].rolling(bb_period).mean()
|
|
||||||
std = df['close'].rolling(bb_period).std()
|
|
||||||
indicators['bb_upper'] = (sma + (std * bb_std)).iloc[-1]
|
|
||||||
indicators['bb_lower'] = (sma - (std * bb_std)).iloc[-1]
|
|
||||||
indicators['bb_middle'] = sma.iloc[-1]
|
|
||||||
|
|
||||||
# Remove NaN values
|
|
||||||
indicators = {k: float(v) for k, v in indicators.items() if not pd.isna(v)}
|
|
||||||
|
|
||||||
if indicators:
|
|
||||||
self.cache.update('technical_indicators', symbol, indicators, 'calculated')
|
|
||||||
logger.debug(f"Calculated {len(indicators)} indicators for {symbol}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error calculating indicators for {symbol}: {e}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in technical indicators calculation: {e}")
|
|
||||||
|
|
||||||
def get_current_price(self, symbol: str) -> Optional[float]:
|
|
||||||
"""Get current price from latest 1s candle"""
|
|
||||||
ohlcv_1s = self.cache.get('ohlcv_1s', symbol)
|
|
||||||
if ohlcv_1s:
|
|
||||||
return ohlcv_1s.close
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_status(self) -> Dict[str, Any]:
|
|
||||||
"""Get updater status"""
|
|
||||||
status = {
|
|
||||||
'running': self.running,
|
|
||||||
'symbols': self.symbols,
|
|
||||||
'last_updates': self.last_updates,
|
|
||||||
'tick_buffer_sizes': {symbol: len(buffer) for symbol, buffer in self.tick_buffers.items()},
|
|
||||||
'cache_status': self.cache.get_status()
|
|
||||||
}
|
|
||||||
return status
|
|
@ -1,221 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test COB Data Integration
|
|
||||||
|
|
||||||
This script tests that COB data is properly flowing through to BaseDataInput
|
|
||||||
and being used in the CNN model predictions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
# Add project root to path
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
from core.orchestrator import TradingOrchestrator
|
|
||||||
from core.config import get_config
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def test_cob_data_flow():
|
|
||||||
"""Test that COB data flows through to BaseDataInput"""
|
|
||||||
|
|
||||||
logger.info("=== Testing COB Data Integration ===")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Initialize orchestrator
|
|
||||||
config = get_config()
|
|
||||||
orchestrator = TradingOrchestrator(
|
|
||||||
symbol="ETH/USDT",
|
|
||||||
config=config
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("✅ Orchestrator initialized")
|
|
||||||
|
|
||||||
# Check if COB integration is available
|
|
||||||
if orchestrator.cob_integration:
|
|
||||||
logger.info("✅ COB integration is available")
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ COB integration is not available")
|
|
||||||
|
|
||||||
# Wait a bit for COB data to potentially arrive
|
|
||||||
logger.info("Waiting for COB data...")
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
# Test building BaseDataInput
|
|
||||||
symbol = "ETH/USDT"
|
|
||||||
base_data = orchestrator.build_base_data_input(symbol)
|
|
||||||
|
|
||||||
if base_data:
|
|
||||||
logger.info("✅ BaseDataInput created successfully")
|
|
||||||
|
|
||||||
# Check if COB data is present
|
|
||||||
if base_data.cob_data:
|
|
||||||
logger.info("✅ COB data is present in BaseDataInput")
|
|
||||||
logger.info(f" COB current price: {base_data.cob_data.current_price}")
|
|
||||||
logger.info(f" COB bucket size: {base_data.cob_data.bucket_size}")
|
|
||||||
logger.info(f" COB price buckets: {len(base_data.cob_data.price_buckets)} buckets")
|
|
||||||
logger.info(f" COB bid/ask imbalance: {len(base_data.cob_data.bid_ask_imbalance)} entries")
|
|
||||||
|
|
||||||
# Test feature vector generation
|
|
||||||
features = base_data.get_feature_vector()
|
|
||||||
logger.info(f"✅ Feature vector generated: {len(features)} features")
|
|
||||||
|
|
||||||
# Check if COB features are non-zero (indicating real data)
|
|
||||||
# COB features are at positions 7500-7700 (after OHLCV and BTC data)
|
|
||||||
cob_features = features[7500:7700] # 200 COB features
|
|
||||||
non_zero_cob = sum(1 for f in cob_features if f != 0.0)
|
|
||||||
|
|
||||||
if non_zero_cob > 0:
|
|
||||||
logger.info(f"✅ COB features contain real data: {non_zero_cob}/200 non-zero features")
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ COB features are all zeros (no real COB data)")
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ COB data is None in BaseDataInput")
|
|
||||||
|
|
||||||
# Check if there's COB data in the cache
|
|
||||||
if hasattr(orchestrator, 'data_integration'):
|
|
||||||
cached_cob = orchestrator.data_integration.cache.get('cob_data', symbol)
|
|
||||||
if cached_cob:
|
|
||||||
logger.info("✅ COB data found in cache but not in BaseDataInput")
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ No COB data in cache either")
|
|
||||||
|
|
||||||
# Test CNN prediction with the BaseDataInput
|
|
||||||
if orchestrator.cnn_adapter:
|
|
||||||
logger.info("Testing CNN prediction with BaseDataInput...")
|
|
||||||
try:
|
|
||||||
prediction = orchestrator.cnn_adapter.predict(base_data)
|
|
||||||
if prediction:
|
|
||||||
logger.info("✅ CNN prediction successful")
|
|
||||||
logger.info(f" Action: {prediction.predictions['action']}")
|
|
||||||
logger.info(f" Confidence: {prediction.confidence:.3f}")
|
|
||||||
logger.info(f" Pivot price: {prediction.predictions.get('pivot_price', 'N/A')}")
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ CNN prediction returned None")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ CNN prediction failed: {e}")
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ CNN adapter not available")
|
|
||||||
else:
|
|
||||||
logger.error("❌ Failed to create BaseDataInput")
|
|
||||||
|
|
||||||
# Check orchestrator's latest COB data
|
|
||||||
if hasattr(orchestrator, 'latest_cob_data') and orchestrator.latest_cob_data:
|
|
||||||
logger.info(f"✅ Orchestrator has COB data for symbols: {list(orchestrator.latest_cob_data.keys())}")
|
|
||||||
for sym, cob_data in orchestrator.latest_cob_data.items():
|
|
||||||
logger.info(f" {sym}: {len(cob_data)} COB data fields")
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ No COB data in orchestrator.latest_cob_data")
|
|
||||||
|
|
||||||
return base_data is not None and (base_data.cob_data is not None if base_data else False)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ Test failed: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_cob_cache_updates():
|
|
||||||
"""Test that COB data updates are properly cached"""
|
|
||||||
|
|
||||||
logger.info("=== Testing COB Cache Updates ===")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Initialize orchestrator
|
|
||||||
config = get_config()
|
|
||||||
orchestrator = TradingOrchestrator(
|
|
||||||
symbol="ETH/USDT",
|
|
||||||
config=config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check initial cache state
|
|
||||||
symbol = "ETH/USDT"
|
|
||||||
initial_cob = orchestrator.data_integration.cache.get('cob_data', symbol)
|
|
||||||
logger.info(f"Initial COB data in cache: {initial_cob is not None}")
|
|
||||||
|
|
||||||
# Simulate COB data update
|
|
||||||
from core.data_models import COBData
|
|
||||||
|
|
||||||
mock_cob_data = {
|
|
||||||
'current_price': 3000.0,
|
|
||||||
'price_buckets': {
|
|
||||||
2999.0: {'bid_volume': 100.0, 'ask_volume': 80.0, 'total_volume': 180.0, 'imbalance': 0.11},
|
|
||||||
3000.0: {'bid_volume': 150.0, 'ask_volume': 120.0, 'total_volume': 270.0, 'imbalance': 0.11},
|
|
||||||
3001.0: {'bid_volume': 90.0, 'ask_volume': 110.0, 'total_volume': 200.0, 'imbalance': -0.10}
|
|
||||||
},
|
|
||||||
'bid_ask_imbalance': {2999.0: 0.11, 3000.0: 0.11, 3001.0: -0.10},
|
|
||||||
'volume_weighted_prices': {2999.0: 2999.5, 3000.0: 3000.2, 3001.0: 3000.8},
|
|
||||||
'order_flow_metrics': {'total_volume': 650.0, 'avg_imbalance': 0.04},
|
|
||||||
'ma_1s_imbalance': {3000.0: 0.05},
|
|
||||||
'ma_5s_imbalance': {3000.0: 0.03}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Trigger COB data update through callback
|
|
||||||
logger.info("Simulating COB data update...")
|
|
||||||
orchestrator._on_cob_dashboard_data(symbol, mock_cob_data)
|
|
||||||
|
|
||||||
# Check if cache was updated
|
|
||||||
updated_cob = orchestrator.data_integration.cache.get('cob_data', symbol)
|
|
||||||
if updated_cob:
|
|
||||||
logger.info("✅ COB data successfully updated in cache")
|
|
||||||
logger.info(f" Current price: {updated_cob.current_price}")
|
|
||||||
logger.info(f" Price buckets: {len(updated_cob.price_buckets)}")
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ COB data not found in cache after update")
|
|
||||||
|
|
||||||
# Test BaseDataInput with updated COB data
|
|
||||||
base_data = orchestrator.build_base_data_input(symbol)
|
|
||||||
if base_data and base_data.cob_data:
|
|
||||||
logger.info("✅ BaseDataInput now contains COB data")
|
|
||||||
|
|
||||||
# Test feature vector with real COB data
|
|
||||||
features = base_data.get_feature_vector()
|
|
||||||
cob_features = features[7500:7700] # 200 COB features
|
|
||||||
non_zero_cob = sum(1 for f in cob_features if f != 0.0)
|
|
||||||
logger.info(f"✅ COB features with real data: {non_zero_cob}/200 non-zero")
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ BaseDataInput still doesn't have COB data")
|
|
||||||
|
|
||||||
return updated_cob is not None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ Cache update test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run all COB integration tests"""
|
|
||||||
|
|
||||||
logger.info("Starting COB Data Integration Tests")
|
|
||||||
|
|
||||||
# Test 1: COB data flow
|
|
||||||
test1_passed = test_cob_data_flow()
|
|
||||||
|
|
||||||
# Test 2: COB cache updates
|
|
||||||
test2_passed = test_cob_cache_updates()
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
logger.info("=== Test Summary ===")
|
|
||||||
logger.info(f"COB Data Flow: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
|
|
||||||
logger.info(f"COB Cache Updates: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
|
|
||||||
|
|
||||||
if test1_passed and test2_passed:
|
|
||||||
logger.info("🎉 All tests passed! COB data integration is working.")
|
|
||||||
logger.info("The system now:")
|
|
||||||
logger.info(" - Properly integrates COB data into BaseDataInput")
|
|
||||||
logger.info(" - Updates cache when COB data arrives")
|
|
||||||
logger.info(" - Includes COB features in CNN model input")
|
|
||||||
else:
|
|
||||||
logger.error("❌ Some tests failed. COB integration needs attention.")
|
|
||||||
if not test1_passed:
|
|
||||||
logger.error(" - COB data is not flowing to BaseDataInput")
|
|
||||||
if not test2_passed:
|
|
||||||
logger.error(" - COB cache updates are not working")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -1,527 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Complete Training System Integration Test
|
|
||||||
|
|
||||||
This script demonstrates the full training system integration including:
|
|
||||||
- Comprehensive training data collection with validation
|
|
||||||
- CNN training pipeline with profitable episode replay
|
|
||||||
- RL training pipeline with profit-weighted experience replay
|
|
||||||
- Integration with existing DataProvider and models
|
|
||||||
- Real-time outcome validation and profitability tracking
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import time
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Setup logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Import the complete training system
|
|
||||||
from core.training_data_collector import TrainingDataCollector
|
|
||||||
from core.cnn_training_pipeline import CNNPivotPredictor, CNNTrainer
|
|
||||||
from core.rl_training_pipeline import RLTradingAgent, RLTrainer
|
|
||||||
from core.enhanced_training_integration import EnhancedTrainingIntegration, EnhancedTrainingConfig
|
|
||||||
from core.data_provider import DataProvider
|
|
||||||
|
|
||||||
def create_mock_data_provider():
|
|
||||||
"""Create a mock data provider for testing"""
|
|
||||||
class MockDataProvider:
|
|
||||||
def __init__(self):
|
|
||||||
self.symbols = ['ETH/USDT', 'BTC/USDT']
|
|
||||||
self.timeframes = ['1s', '1m', '5m', '15m', '1h', '1d']
|
|
||||||
|
|
||||||
def get_historical_data(self, symbol, timeframe, limit=300, refresh=False):
|
|
||||||
"""Generate mock OHLCV data"""
|
|
||||||
dates = pd.date_range(start='2024-01-01', periods=limit, freq='1min')
|
|
||||||
|
|
||||||
# Generate realistic price data
|
|
||||||
base_price = 3000.0 if 'ETH' in symbol else 50000.0
|
|
||||||
price_data = []
|
|
||||||
current_price = base_price
|
|
||||||
|
|
||||||
for i in range(limit):
|
|
||||||
change = np.random.normal(0, 0.002)
|
|
||||||
current_price *= (1 + change)
|
|
||||||
|
|
||||||
price_data.append({
|
|
||||||
'timestamp': dates[i],
|
|
||||||
'open': current_price,
|
|
||||||
'high': current_price * (1 + abs(np.random.normal(0, 0.001))),
|
|
||||||
'low': current_price * (1 - abs(np.random.normal(0, 0.001))),
|
|
||||||
'close': current_price * (1 + np.random.normal(0, 0.0005)),
|
|
||||||
'volume': np.random.uniform(100, 1000),
|
|
||||||
'rsi_14': np.random.uniform(30, 70),
|
|
||||||
'macd': np.random.normal(0, 0.5),
|
|
||||||
'sma_20': current_price * (1 + np.random.normal(0, 0.01))
|
|
||||||
})
|
|
||||||
|
|
||||||
current_price = price_data[-1]['close']
|
|
||||||
|
|
||||||
df = pd.DataFrame(price_data)
|
|
||||||
df.set_index('timestamp', inplace=True)
|
|
||||||
return df
|
|
||||||
|
|
||||||
return MockDataProvider()
|
|
||||||
|
|
||||||
def test_training_data_collection():
|
|
||||||
"""Test the comprehensive training data collection system"""
|
|
||||||
logger.info("=== Testing Training Data Collection ===")
|
|
||||||
|
|
||||||
collector = TrainingDataCollector(
|
|
||||||
storage_dir="test_complete_training/data_collection",
|
|
||||||
max_episodes_per_symbol=1000
|
|
||||||
)
|
|
||||||
|
|
||||||
collector.start_collection()
|
|
||||||
|
|
||||||
# Simulate data collection for multiple episodes
|
|
||||||
for i in range(20):
|
|
||||||
symbol = 'ETHUSDT'
|
|
||||||
|
|
||||||
# Create sample data
|
|
||||||
ohlcv_data = {}
|
|
||||||
for timeframe in ['1s', '1m', '5m', '15m', '1h']:
|
|
||||||
dates = pd.date_range(start='2024-01-01', periods=300, freq='1min')
|
|
||||||
base_price = 3000.0 + i * 10 # Vary price over episodes
|
|
||||||
|
|
||||||
price_data = []
|
|
||||||
current_price = base_price
|
|
||||||
|
|
||||||
for j in range(300):
|
|
||||||
change = np.random.normal(0, 0.002)
|
|
||||||
current_price *= (1 + change)
|
|
||||||
|
|
||||||
price_data.append({
|
|
||||||
'timestamp': dates[j],
|
|
||||||
'open': current_price,
|
|
||||||
'high': current_price * (1 + abs(np.random.normal(0, 0.001))),
|
|
||||||
'low': current_price * (1 - abs(np.random.normal(0, 0.001))),
|
|
||||||
'close': current_price * (1 + np.random.normal(0, 0.0005)),
|
|
||||||
'volume': np.random.uniform(100, 1000)
|
|
||||||
})
|
|
||||||
|
|
||||||
current_price = price_data[-1]['close']
|
|
||||||
|
|
||||||
df = pd.DataFrame(price_data)
|
|
||||||
df.set_index('timestamp', inplace=True)
|
|
||||||
ohlcv_data[timeframe] = df
|
|
||||||
|
|
||||||
# Create other data
|
|
||||||
tick_data = [
|
|
||||||
{
|
|
||||||
'timestamp': datetime.now() - timedelta(seconds=j),
|
|
||||||
'price': base_price + np.random.normal(0, 5),
|
|
||||||
'volume': np.random.uniform(0.1, 10.0),
|
|
||||||
'side': 'buy' if np.random.random() > 0.5 else 'sell',
|
|
||||||
'trade_id': f'trade_{i}_{j}'
|
|
||||||
}
|
|
||||||
for j in range(100)
|
|
||||||
]
|
|
||||||
|
|
||||||
cob_data = {
|
|
||||||
'timestamp': datetime.now(),
|
|
||||||
'cob_features': np.random.randn(120).tolist(),
|
|
||||||
'spread': np.random.uniform(0.5, 2.0)
|
|
||||||
}
|
|
||||||
|
|
||||||
technical_indicators = {
|
|
||||||
'rsi_14': np.random.uniform(30, 70),
|
|
||||||
'macd': np.random.normal(0, 0.5),
|
|
||||||
'sma_20': base_price * (1 + np.random.normal(0, 0.01)),
|
|
||||||
'ema_12': base_price * (1 + np.random.normal(0, 0.01))
|
|
||||||
}
|
|
||||||
|
|
||||||
pivot_points = [
|
|
||||||
{
|
|
||||||
'timestamp': datetime.now() - timedelta(minutes=30),
|
|
||||||
'price': base_price + np.random.normal(0, 20),
|
|
||||||
'type': 'high' if np.random.random() > 0.5 else 'low'
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
# Create features
|
|
||||||
cnn_features = np.random.randn(2000).astype(np.float32)
|
|
||||||
rl_state = np.random.randn(2000).astype(np.float32)
|
|
||||||
|
|
||||||
orchestrator_context = {
|
|
||||||
'market_session': 'european',
|
|
||||||
'volatility_regime': 'medium',
|
|
||||||
'trend_direction': 'uptrend'
|
|
||||||
}
|
|
||||||
|
|
||||||
# Collect training data
|
|
||||||
episode_id = collector.collect_training_data(
|
|
||||||
symbol=symbol,
|
|
||||||
ohlcv_data=ohlcv_data,
|
|
||||||
tick_data=tick_data,
|
|
||||||
cob_data=cob_data,
|
|
||||||
technical_indicators=technical_indicators,
|
|
||||||
pivot_points=pivot_points,
|
|
||||||
cnn_features=cnn_features,
|
|
||||||
rl_state=rl_state,
|
|
||||||
orchestrator_context=orchestrator_context
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Created episode {i+1}: {episode_id}")
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
# Get statistics
|
|
||||||
stats = collector.get_collection_statistics()
|
|
||||||
logger.info(f"Collection statistics: {stats}")
|
|
||||||
|
|
||||||
# Validate data integrity
|
|
||||||
validation = collector.validate_data_integrity()
|
|
||||||
logger.info(f"Data integrity: {validation}")
|
|
||||||
|
|
||||||
collector.stop_collection()
|
|
||||||
return collector
|
|
||||||
|
|
||||||
def test_cnn_training_pipeline():
|
|
||||||
"""Test the CNN training pipeline with profitable episode replay"""
|
|
||||||
logger.info("=== Testing CNN Training Pipeline ===")
|
|
||||||
|
|
||||||
# Initialize CNN model and trainer
|
|
||||||
model = CNNPivotPredictor(
|
|
||||||
input_channels=10,
|
|
||||||
sequence_length=300,
|
|
||||||
hidden_dim=256,
|
|
||||||
num_pivot_classes=3
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer = CNNTrainer(
|
|
||||||
model=model,
|
|
||||||
device='cpu',
|
|
||||||
learning_rate=0.001,
|
|
||||||
storage_dir="test_complete_training/cnn_training"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create sample training episodes with outcomes
|
|
||||||
from core.training_data_collector import TrainingEpisode, ModelInputPackage, TrainingOutcome
|
|
||||||
|
|
||||||
episodes = []
|
|
||||||
for i in range(100):
|
|
||||||
# Create input package
|
|
||||||
input_package = ModelInputPackage(
|
|
||||||
timestamp=datetime.now() - timedelta(minutes=i),
|
|
||||||
symbol='ETHUSDT',
|
|
||||||
ohlcv_data={}, # Simplified for testing
|
|
||||||
tick_data=[],
|
|
||||||
cob_data={},
|
|
||||||
technical_indicators={'rsi': 50.0 + i},
|
|
||||||
pivot_points=[],
|
|
||||||
cnn_features=np.random.randn(2000).astype(np.float32),
|
|
||||||
rl_state=np.random.randn(2000).astype(np.float32),
|
|
||||||
orchestrator_context={}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create outcome with varying profitability
|
|
||||||
is_profitable = np.random.random() > 0.3 # 70% profitable
|
|
||||||
profitability_score = np.random.uniform(0.7, 1.0) if is_profitable else np.random.uniform(0.0, 0.3)
|
|
||||||
|
|
||||||
outcome = TrainingOutcome(
|
|
||||||
input_package_hash=input_package.data_hash,
|
|
||||||
timestamp=input_package.timestamp,
|
|
||||||
symbol='ETHUSDT',
|
|
||||||
price_change_1m=np.random.normal(0, 0.01),
|
|
||||||
price_change_5m=np.random.normal(0, 0.02),
|
|
||||||
price_change_15m=np.random.normal(0, 0.03),
|
|
||||||
price_change_1h=np.random.normal(0, 0.05),
|
|
||||||
max_profit_potential=abs(np.random.normal(0, 0.02)),
|
|
||||||
max_loss_potential=abs(np.random.normal(0, 0.015)),
|
|
||||||
optimal_entry_price=3000.0,
|
|
||||||
optimal_exit_price=3000.0 + np.random.normal(0, 10),
|
|
||||||
optimal_holding_time=timedelta(minutes=np.random.randint(5, 60)),
|
|
||||||
is_profitable=is_profitable,
|
|
||||||
profitability_score=profitability_score,
|
|
||||||
risk_reward_ratio=np.random.uniform(1.0, 3.0),
|
|
||||||
is_rapid_change=np.random.random() > 0.8,
|
|
||||||
change_velocity=np.random.uniform(0.1, 2.0),
|
|
||||||
volatility_spike=np.random.random() > 0.9,
|
|
||||||
outcome_validated=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create episode
|
|
||||||
episode = TrainingEpisode(
|
|
||||||
episode_id=f"cnn_test_episode_{i}",
|
|
||||||
input_package=input_package,
|
|
||||||
model_predictions={},
|
|
||||||
actual_outcome=outcome,
|
|
||||||
episode_type='high_profit' if profitability_score > 0.8 else 'normal'
|
|
||||||
)
|
|
||||||
|
|
||||||
episodes.append(episode)
|
|
||||||
|
|
||||||
# Test training on all episodes
|
|
||||||
logger.info("Training on all episodes...")
|
|
||||||
results = trainer._train_on_episodes(episodes, training_mode='test_batch')
|
|
||||||
logger.info(f"Training results: {results}")
|
|
||||||
|
|
||||||
# Test training on profitable episodes only
|
|
||||||
logger.info("Training on profitable episodes only...")
|
|
||||||
profitable_results = trainer.train_on_profitable_episodes(
|
|
||||||
symbol='ETHUSDT',
|
|
||||||
min_profitability=0.7,
|
|
||||||
max_episodes=50
|
|
||||||
)
|
|
||||||
logger.info(f"Profitable training results: {profitable_results}")
|
|
||||||
|
|
||||||
# Get training statistics
|
|
||||||
stats = trainer.get_training_statistics()
|
|
||||||
logger.info(f"CNN training statistics: {stats}")
|
|
||||||
|
|
||||||
return trainer
|
|
||||||
|
|
||||||
def test_rl_training_pipeline():
|
|
||||||
"""Test the RL training pipeline with profit-weighted experience replay"""
|
|
||||||
logger.info("=== Testing RL Training Pipeline ===")
|
|
||||||
|
|
||||||
# Initialize RL agent and trainer
|
|
||||||
agent = RLTradingAgent(state_dim=2000, action_dim=3, hidden_dim=512)
|
|
||||||
trainer = RLTrainer(
|
|
||||||
agent=agent,
|
|
||||||
device='cpu',
|
|
||||||
storage_dir="test_complete_training/rl_training"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add sample experiences with varying profitability
|
|
||||||
logger.info("Adding sample experiences...")
|
|
||||||
experience_ids = []
|
|
||||||
|
|
||||||
for i in range(200):
|
|
||||||
state = np.random.randn(2000).astype(np.float32)
|
|
||||||
action = np.random.randint(0, 3) # SELL, HOLD, BUY
|
|
||||||
reward = np.random.normal(0, 0.1)
|
|
||||||
next_state = np.random.randn(2000).astype(np.float32)
|
|
||||||
done = np.random.random() > 0.9
|
|
||||||
|
|
||||||
market_context = {
|
|
||||||
'symbol': 'ETHUSDT',
|
|
||||||
'episode_id': f'rl_episode_{i}',
|
|
||||||
'timestamp': datetime.now() - timedelta(minutes=i),
|
|
||||||
'market_session': 'european',
|
|
||||||
'volatility_regime': 'medium'
|
|
||||||
}
|
|
||||||
|
|
||||||
cnn_predictions = {
|
|
||||||
'pivot_logits': np.random.randn(3).tolist(),
|
|
||||||
'confidence': np.random.uniform(0.3, 0.9)
|
|
||||||
}
|
|
||||||
|
|
||||||
experience_id = trainer.add_experience(
|
|
||||||
state=state,
|
|
||||||
action=action,
|
|
||||||
reward=reward,
|
|
||||||
next_state=next_state,
|
|
||||||
done=done,
|
|
||||||
market_context=market_context,
|
|
||||||
cnn_predictions=cnn_predictions,
|
|
||||||
confidence_score=np.random.uniform(0.3, 0.9)
|
|
||||||
)
|
|
||||||
|
|
||||||
if experience_id:
|
|
||||||
experience_ids.append(experience_id)
|
|
||||||
|
|
||||||
# Simulate outcome validation for some experiences
|
|
||||||
if np.random.random() > 0.5: # 50% get outcomes
|
|
||||||
actual_profit = np.random.normal(0, 0.02)
|
|
||||||
optimal_action = np.random.randint(0, 3)
|
|
||||||
|
|
||||||
trainer.experience_buffer.update_experience_outcomes(
|
|
||||||
experience_id, actual_profit, optimal_action
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Added {len(experience_ids)} experiences")
|
|
||||||
|
|
||||||
# Test training on experiences
|
|
||||||
logger.info("Training on experiences...")
|
|
||||||
results = trainer.train_on_experiences(batch_size=32, num_batches=20)
|
|
||||||
logger.info(f"RL training results: {results}")
|
|
||||||
|
|
||||||
# Test training on profitable experiences only
|
|
||||||
logger.info("Training on profitable experiences only...")
|
|
||||||
profitable_results = trainer.train_on_profitable_experiences(
|
|
||||||
min_profitability=0.01,
|
|
||||||
max_experiences=100,
|
|
||||||
batch_size=32
|
|
||||||
)
|
|
||||||
logger.info(f"Profitable RL training results: {profitable_results}")
|
|
||||||
|
|
||||||
# Get training statistics
|
|
||||||
stats = trainer.get_training_statistics()
|
|
||||||
logger.info(f"RL training statistics: {stats}")
|
|
||||||
|
|
||||||
# Get buffer statistics
|
|
||||||
buffer_stats = trainer.experience_buffer.get_buffer_statistics()
|
|
||||||
logger.info(f"Experience buffer statistics: {buffer_stats}")
|
|
||||||
|
|
||||||
return trainer
|
|
||||||
|
|
||||||
def test_enhanced_integration():
|
|
||||||
"""Test the complete enhanced training integration"""
|
|
||||||
logger.info("=== Testing Enhanced Training Integration ===")
|
|
||||||
|
|
||||||
# Create mock data provider
|
|
||||||
data_provider = create_mock_data_provider()
|
|
||||||
|
|
||||||
# Create enhanced training configuration
|
|
||||||
config = EnhancedTrainingConfig(
|
|
||||||
collection_interval=0.5, # Faster for testing
|
|
||||||
min_data_completeness=0.7,
|
|
||||||
min_episodes_for_cnn_training=10, # Lower for testing
|
|
||||||
min_experiences_for_rl_training=20, # Lower for testing
|
|
||||||
training_frequency_minutes=1, # Faster for testing
|
|
||||||
min_profitability_for_replay=0.05,
|
|
||||||
use_existing_cob_rl_model=False, # Don't use for testing
|
|
||||||
enable_cross_model_learning=True,
|
|
||||||
enable_background_validation=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize enhanced integration
|
|
||||||
integration = EnhancedTrainingIntegration(
|
|
||||||
data_provider=data_provider,
|
|
||||||
config=config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Start integration
|
|
||||||
logger.info("Starting enhanced training integration...")
|
|
||||||
integration.start_enhanced_integration()
|
|
||||||
|
|
||||||
# Let it run for a short time
|
|
||||||
logger.info("Running integration for 30 seconds...")
|
|
||||||
time.sleep(30)
|
|
||||||
|
|
||||||
# Get statistics
|
|
||||||
stats = integration.get_integration_statistics()
|
|
||||||
logger.info(f"Integration statistics: {stats}")
|
|
||||||
|
|
||||||
# Test manual training trigger
|
|
||||||
logger.info("Testing manual training trigger...")
|
|
||||||
manual_results = integration.trigger_manual_training(training_type='all')
|
|
||||||
logger.info(f"Manual training results: {manual_results}")
|
|
||||||
|
|
||||||
# Stop integration
|
|
||||||
logger.info("Stopping enhanced training integration...")
|
|
||||||
integration.stop_enhanced_integration()
|
|
||||||
|
|
||||||
return integration
|
|
||||||
|
|
||||||
def test_complete_system():
|
|
||||||
"""Test the complete training system integration"""
|
|
||||||
logger.info("=== Testing Complete Training System ===")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Test individual components
|
|
||||||
logger.info("Testing individual components...")
|
|
||||||
|
|
||||||
collector = test_training_data_collection()
|
|
||||||
cnn_trainer = test_cnn_training_pipeline()
|
|
||||||
rl_trainer = test_rl_training_pipeline()
|
|
||||||
|
|
||||||
logger.info("✅ Individual components tested successfully!")
|
|
||||||
|
|
||||||
# Test complete integration
|
|
||||||
logger.info("Testing complete integration...")
|
|
||||||
integration = test_enhanced_integration()
|
|
||||||
|
|
||||||
logger.info("✅ Complete integration tested successfully!")
|
|
||||||
|
|
||||||
# Generate comprehensive report
|
|
||||||
logger.info("\n" + "="*80)
|
|
||||||
logger.info("COMPREHENSIVE TRAINING SYSTEM TEST REPORT")
|
|
||||||
logger.info("="*80)
|
|
||||||
|
|
||||||
# Data collection report
|
|
||||||
collection_stats = collector.get_collection_statistics()
|
|
||||||
logger.info(f"\n📊 DATA COLLECTION:")
|
|
||||||
logger.info(f" • Total episodes: {collection_stats.get('total_episodes', 0)}")
|
|
||||||
logger.info(f" • Profitable episodes: {collection_stats.get('profitable_episodes', 0)}")
|
|
||||||
logger.info(f" • Rapid change episodes: {collection_stats.get('rapid_change_episodes', 0)}")
|
|
||||||
logger.info(f" • Data completeness avg: {collection_stats.get('data_completeness_avg', 0):.3f}")
|
|
||||||
|
|
||||||
# CNN training report
|
|
||||||
cnn_stats = cnn_trainer.get_training_statistics()
|
|
||||||
logger.info(f"\n🧠 CNN TRAINING:")
|
|
||||||
logger.info(f" • Total sessions: {cnn_stats.get('total_sessions', 0)}")
|
|
||||||
logger.info(f" • Total steps: {cnn_stats.get('total_steps', 0)}")
|
|
||||||
logger.info(f" • Replay sessions: {cnn_stats.get('replay_sessions', 0)}")
|
|
||||||
|
|
||||||
# RL training report
|
|
||||||
rl_stats = rl_trainer.get_training_statistics()
|
|
||||||
logger.info(f"\n🤖 RL TRAINING:")
|
|
||||||
logger.info(f" • Total sessions: {rl_stats.get('total_sessions', 0)}")
|
|
||||||
logger.info(f" • Total experiences: {rl_stats.get('total_experiences', 0)}")
|
|
||||||
logger.info(f" • Average reward: {rl_stats.get('average_reward', 0):.4f}")
|
|
||||||
|
|
||||||
# Integration report
|
|
||||||
integration_stats = integration.get_integration_statistics()
|
|
||||||
logger.info(f"\n🔗 INTEGRATION:")
|
|
||||||
logger.info(f" • Total data packages: {integration_stats.get('total_data_packages', 0)}")
|
|
||||||
logger.info(f" • CNN training sessions: {integration_stats.get('cnn_training_sessions', 0)}")
|
|
||||||
logger.info(f" • RL training sessions: {integration_stats.get('rl_training_sessions', 0)}")
|
|
||||||
logger.info(f" • Overall profitability rate: {integration_stats.get('overall_profitability_rate', 0):.3f}")
|
|
||||||
|
|
||||||
logger.info("\n🎯 SYSTEM CAPABILITIES DEMONSTRATED:")
|
|
||||||
logger.info(" ✓ Comprehensive training data collection with validation")
|
|
||||||
logger.info(" ✓ CNN training with profitable episode replay")
|
|
||||||
logger.info(" ✓ RL training with profit-weighted experience replay")
|
|
||||||
logger.info(" ✓ Real-time outcome validation and profitability tracking")
|
|
||||||
logger.info(" ✓ Integrated training coordination across all models")
|
|
||||||
logger.info(" ✓ Gradient and backpropagation data storage for replay")
|
|
||||||
logger.info(" ✓ Rapid price change detection for premium training examples")
|
|
||||||
logger.info(" ✓ Data integrity validation and completeness checking")
|
|
||||||
|
|
||||||
logger.info("\n🚀 READY FOR PRODUCTION INTEGRATION:")
|
|
||||||
logger.info(" 1. Connect to your existing DataProvider")
|
|
||||||
logger.info(" 2. Integrate with your CNN and RL models")
|
|
||||||
logger.info(" 3. Connect to your Orchestrator and TradingExecutor")
|
|
||||||
logger.info(" 4. Enable real-time outcome validation")
|
|
||||||
logger.info(" 5. Deploy with monitoring and alerting")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ Complete system test failed: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Main test function"""
|
|
||||||
logger.info("=" * 100)
|
|
||||||
logger.info("COMPREHENSIVE TRAINING SYSTEM INTEGRATION TEST")
|
|
||||||
logger.info("=" * 100)
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Run complete system test
|
|
||||||
success = test_complete_system()
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
duration = end_time - start_time
|
|
||||||
|
|
||||||
logger.info("=" * 100)
|
|
||||||
if success:
|
|
||||||
logger.info("🎉 ALL TESTS PASSED! TRAINING SYSTEM READY FOR PRODUCTION!")
|
|
||||||
else:
|
|
||||||
logger.info("❌ SOME TESTS FAILED - CHECK LOGS FOR DETAILS")
|
|
||||||
|
|
||||||
logger.info(f"Total test duration: {duration:.2f} seconds")
|
|
||||||
logger.info("=" * 100)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ Test execution failed: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -1,187 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test Fixed Input Size
|
|
||||||
|
|
||||||
Verify that the CNN model now receives consistent input dimensions
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from datetime import datetime
|
|
||||||
from core.data_models import BaseDataInput, OHLCVBar
|
|
||||||
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
|
||||||
|
|
||||||
def create_test_data(with_cob=True, with_indicators=True):
|
|
||||||
"""Create test BaseDataInput with varying data completeness"""
|
|
||||||
|
|
||||||
# Create basic OHLCV data
|
|
||||||
ohlcv_bars = []
|
|
||||||
for i in range(100): # Less than 300 to test padding
|
|
||||||
bar = OHLCVBar(
|
|
||||||
symbol="ETH/USDT",
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
open=100.0 + i,
|
|
||||||
high=101.0 + i,
|
|
||||||
low=99.0 + i,
|
|
||||||
close=100.5 + i,
|
|
||||||
volume=1000 + i,
|
|
||||||
timeframe="1s"
|
|
||||||
)
|
|
||||||
ohlcv_bars.append(bar)
|
|
||||||
|
|
||||||
# Create test data
|
|
||||||
base_data = BaseDataInput(
|
|
||||||
symbol="ETH/USDT",
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
ohlcv_1s=ohlcv_bars,
|
|
||||||
ohlcv_1m=ohlcv_bars[:50], # Even less data
|
|
||||||
ohlcv_1h=ohlcv_bars[:20],
|
|
||||||
ohlcv_1d=ohlcv_bars[:10],
|
|
||||||
btc_ohlcv_1s=ohlcv_bars[:80], # Incomplete BTC data
|
|
||||||
technical_indicators={'rsi': 50.0, 'macd': 0.1} if with_indicators else {},
|
|
||||||
last_predictions={}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add COB data if requested (simplified for testing)
|
|
||||||
if with_cob:
|
|
||||||
# Create a simple mock COB data object
|
|
||||||
class MockCOBData:
|
|
||||||
def __init__(self):
|
|
||||||
self.price_buckets = {
|
|
||||||
2500.0: {'bid_volume': 100, 'ask_volume': 90, 'total_volume': 190, 'imbalance': 0.05},
|
|
||||||
2501.0: {'bid_volume': 80, 'ask_volume': 120, 'total_volume': 200, 'imbalance': -0.2}
|
|
||||||
}
|
|
||||||
self.ma_1s_imbalance = {2500.0: 0.1, 2501.0: -0.1}
|
|
||||||
self.ma_5s_imbalance = {2500.0: 0.05, 2501.0: -0.05}
|
|
||||||
|
|
||||||
base_data.cob_data = MockCOBData()
|
|
||||||
|
|
||||||
return base_data
|
|
||||||
|
|
||||||
def test_consistent_feature_size():
|
|
||||||
"""Test that feature vectors are always the same size"""
|
|
||||||
print("=== Testing Consistent Feature Size ===")
|
|
||||||
|
|
||||||
# Test different data scenarios
|
|
||||||
scenarios = [
|
|
||||||
("Full data", True, True),
|
|
||||||
("No COB data", False, True),
|
|
||||||
("No indicators", True, False),
|
|
||||||
("Minimal data", False, False)
|
|
||||||
]
|
|
||||||
|
|
||||||
feature_sizes = []
|
|
||||||
|
|
||||||
for name, with_cob, with_indicators in scenarios:
|
|
||||||
base_data = create_test_data(with_cob, with_indicators)
|
|
||||||
features = base_data.get_feature_vector()
|
|
||||||
|
|
||||||
print(f"{name}: {len(features)} features")
|
|
||||||
feature_sizes.append(len(features))
|
|
||||||
|
|
||||||
# Check if all sizes are the same
|
|
||||||
if len(set(feature_sizes)) == 1:
|
|
||||||
print(f"✅ All feature vectors have consistent size: {feature_sizes[0]}")
|
|
||||||
return feature_sizes[0]
|
|
||||||
else:
|
|
||||||
print(f"❌ Inconsistent feature sizes: {feature_sizes}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def test_cnn_adapter():
|
|
||||||
"""Test that CNN adapter works with fixed input size"""
|
|
||||||
print("\n=== Testing CNN Adapter ===")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create CNN adapter
|
|
||||||
adapter = EnhancedCNNAdapter()
|
|
||||||
print(f"CNN model initialized with feature_dim: {adapter.model.feature_dim}")
|
|
||||||
|
|
||||||
# Test with different data scenarios
|
|
||||||
scenarios = [
|
|
||||||
("Full data", True, True),
|
|
||||||
("No COB data", False, True),
|
|
||||||
("Minimal data", False, False)
|
|
||||||
]
|
|
||||||
|
|
||||||
for name, with_cob, with_indicators in scenarios:
|
|
||||||
try:
|
|
||||||
base_data = create_test_data(with_cob, with_indicators)
|
|
||||||
|
|
||||||
# Make prediction
|
|
||||||
result = adapter.predict(base_data)
|
|
||||||
|
|
||||||
print(f"✅ {name}: Prediction successful - {result.action} (conf={result.confidence:.3f})")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ {name}: Prediction failed - {e}")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ CNN adapter initialization failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_no_network_rebuilding():
|
|
||||||
"""Test that network doesn't rebuild during runtime"""
|
|
||||||
print("\n=== Testing No Network Rebuilding ===")
|
|
||||||
|
|
||||||
try:
|
|
||||||
adapter = EnhancedCNNAdapter()
|
|
||||||
original_feature_dim = adapter.model.feature_dim
|
|
||||||
|
|
||||||
print(f"Original feature_dim: {original_feature_dim}")
|
|
||||||
|
|
||||||
# Make multiple predictions with different data
|
|
||||||
for i in range(5):
|
|
||||||
base_data = create_test_data(with_cob=(i % 2 == 0), with_indicators=(i % 3 == 0))
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = adapter.predict(base_data)
|
|
||||||
current_feature_dim = adapter.model.feature_dim
|
|
||||||
|
|
||||||
if current_feature_dim != original_feature_dim:
|
|
||||||
print(f"❌ Network was rebuilt! Original: {original_feature_dim}, Current: {current_feature_dim}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"✅ Prediction {i+1}: No rebuilding, feature_dim stable at {current_feature_dim}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Prediction {i+1} failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print("✅ Network architecture remained stable throughout all predictions")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run all tests"""
|
|
||||||
print("=== Fixed Input Size Test Suite ===\n")
|
|
||||||
|
|
||||||
# Test 1: Consistent feature size
|
|
||||||
fixed_size = test_consistent_feature_size()
|
|
||||||
|
|
||||||
if fixed_size:
|
|
||||||
# Test 2: CNN adapter works
|
|
||||||
adapter_works = test_cnn_adapter()
|
|
||||||
|
|
||||||
if adapter_works:
|
|
||||||
# Test 3: No network rebuilding
|
|
||||||
no_rebuilding = test_no_network_rebuilding()
|
|
||||||
|
|
||||||
if no_rebuilding:
|
|
||||||
print("\n✅ ALL TESTS PASSED!")
|
|
||||||
print("✅ Feature vectors have consistent size")
|
|
||||||
print("✅ CNN adapter works with fixed input")
|
|
||||||
print("✅ No runtime network rebuilding")
|
|
||||||
print(f"✅ Fixed feature size: {fixed_size}")
|
|
||||||
else:
|
|
||||||
print("\n❌ Network rebuilding test failed")
|
|
||||||
else:
|
|
||||||
print("\n❌ CNN adapter test failed")
|
|
||||||
else:
|
|
||||||
print("\n❌ Feature size consistency test failed")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -57,8 +57,7 @@ def test_integrated_standardized_provider():
|
|||||||
# Test 3: Test BaseDataInput with cross-model feeding
|
# Test 3: Test BaseDataInput with cross-model feeding
|
||||||
print("\n3. Testing BaseDataInput with cross-model predictions...")
|
print("\n3. Testing BaseDataInput with cross-model predictions...")
|
||||||
|
|
||||||
# Set mock current price for COB data
|
# Use real current prices only - no mock data
|
||||||
provider.current_prices['ETHUSDT'] = 3000.0
|
|
||||||
|
|
||||||
base_input = provider.get_base_data_input('ETH/USDT')
|
base_input = provider.get_base_data_input('ETH/USDT')
|
||||||
|
|
||||||
|
@ -1,261 +0,0 @@
|
|||||||
"""
|
|
||||||
Test script for StandardizedCNN
|
|
||||||
|
|
||||||
This script tests the standardized CNN model with BaseDataInput format
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import torch
|
|
||||||
from datetime import datetime
|
|
||||||
from core.standardized_data_provider import StandardizedDataProvider
|
|
||||||
from NN.models.standardized_cnn import StandardizedCNN
|
|
||||||
|
|
||||||
# Set up logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def test_standardized_cnn():
|
|
||||||
"""Test the StandardizedCNN with BaseDataInput"""
|
|
||||||
|
|
||||||
print("Testing StandardizedCNN with BaseDataInput...")
|
|
||||||
|
|
||||||
# Initialize data provider
|
|
||||||
symbols = ['ETH/USDT', 'BTC/USDT']
|
|
||||||
provider = StandardizedDataProvider(symbols=symbols)
|
|
||||||
|
|
||||||
# Initialize CNN model
|
|
||||||
cnn_model = StandardizedCNN(
|
|
||||||
model_name="test_standardized_cnn_v1",
|
|
||||||
confidence_threshold=0.6
|
|
||||||
)
|
|
||||||
|
|
||||||
print("✅ StandardizedCNN initialized")
|
|
||||||
print(f" Model info: {cnn_model.get_model_info()}")
|
|
||||||
|
|
||||||
# Test 1: Get BaseDataInput
|
|
||||||
print("\n1. Testing BaseDataInput creation...")
|
|
||||||
|
|
||||||
# Set mock current price for COB data
|
|
||||||
provider.current_prices['ETHUSDT'] = 3000.0
|
|
||||||
provider.current_prices['BTCUSDT'] = 50000.0
|
|
||||||
|
|
||||||
base_input = provider.get_base_data_input('ETH/USDT')
|
|
||||||
|
|
||||||
if base_input is None:
|
|
||||||
print("⚠️ BaseDataInput is None - creating mock data for testing")
|
|
||||||
# Create mock BaseDataInput for testing
|
|
||||||
from core.data_models import BaseDataInput, OHLCVBar, COBData
|
|
||||||
|
|
||||||
# Create mock OHLCV data
|
|
||||||
mock_ohlcv = []
|
|
||||||
for i in range(300):
|
|
||||||
bar = OHLCVBar(
|
|
||||||
symbol='ETH/USDT',
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
open=3000.0 + i,
|
|
||||||
high=3010.0 + i,
|
|
||||||
low=2990.0 + i,
|
|
||||||
close=3005.0 + i,
|
|
||||||
volume=1000.0,
|
|
||||||
timeframe='1s'
|
|
||||||
)
|
|
||||||
mock_ohlcv.append(bar)
|
|
||||||
|
|
||||||
# Create mock COB data
|
|
||||||
mock_cob = COBData(
|
|
||||||
symbol='ETH/USDT',
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
current_price=3000.0,
|
|
||||||
bucket_size=1.0,
|
|
||||||
price_buckets={3000.0 + i: {'bid_volume': 100, 'ask_volume': 100, 'total_volume': 200, 'imbalance': 0.0} for i in range(-20, 21)},
|
|
||||||
bid_ask_imbalance={3000.0 + i: 0.0 for i in range(-20, 21)},
|
|
||||||
volume_weighted_prices={3000.0 + i: 3000.0 + i for i in range(-20, 21)},
|
|
||||||
order_flow_metrics={}
|
|
||||||
)
|
|
||||||
|
|
||||||
base_input = BaseDataInput(
|
|
||||||
symbol='ETH/USDT',
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
ohlcv_1s=mock_ohlcv,
|
|
||||||
ohlcv_1m=mock_ohlcv,
|
|
||||||
ohlcv_1h=mock_ohlcv,
|
|
||||||
ohlcv_1d=mock_ohlcv,
|
|
||||||
btc_ohlcv_1s=mock_ohlcv,
|
|
||||||
cob_data=mock_cob
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"✅ BaseDataInput available: {base_input.symbol}")
|
|
||||||
print(f" Feature vector shape: {base_input.get_feature_vector().shape}")
|
|
||||||
print(f" Validation: {'PASSED' if base_input.validate() else 'FAILED'}")
|
|
||||||
|
|
||||||
# Test 2: CNN Inference
|
|
||||||
print("\n2. Testing CNN inference with BaseDataInput...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
model_output = cnn_model.predict_from_base_input(base_input)
|
|
||||||
|
|
||||||
print("✅ CNN inference successful!")
|
|
||||||
print(f" Model: {model_output.model_name} ({model_output.model_type})")
|
|
||||||
print(f" Action: {model_output.predictions['action']}")
|
|
||||||
print(f" Confidence: {model_output.confidence:.3f}")
|
|
||||||
print(f" Probabilities: BUY={model_output.predictions['buy_probability']:.3f}, "
|
|
||||||
f"SELL={model_output.predictions['sell_probability']:.3f}, "
|
|
||||||
f"HOLD={model_output.predictions['hold_probability']:.3f}")
|
|
||||||
print(f" Hidden states: {len(model_output.hidden_states)} layers")
|
|
||||||
print(f" Metadata: {len(model_output.metadata)} fields")
|
|
||||||
|
|
||||||
# Test hidden states for cross-model feeding
|
|
||||||
if model_output.hidden_states:
|
|
||||||
print(" Hidden state layers:")
|
|
||||||
for key, value in model_output.hidden_states.items():
|
|
||||||
if isinstance(value, list):
|
|
||||||
print(f" {key}: {len(value)} features")
|
|
||||||
else:
|
|
||||||
print(f" {key}: {type(value)}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ CNN inference failed: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
# Test 3: Integration with StandardizedDataProvider
|
|
||||||
print("\n3. Testing integration with StandardizedDataProvider...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Store the model output in the provider
|
|
||||||
provider.store_model_output(model_output)
|
|
||||||
|
|
||||||
# Retrieve it back
|
|
||||||
stored_outputs = provider.get_model_outputs('ETH/USDT')
|
|
||||||
|
|
||||||
if cnn_model.model_name in stored_outputs:
|
|
||||||
print("✅ Model output storage and retrieval successful!")
|
|
||||||
stored_output = stored_outputs[cnn_model.model_name]
|
|
||||||
print(f" Stored action: {stored_output.predictions['action']}")
|
|
||||||
print(f" Stored confidence: {stored_output.confidence:.3f}")
|
|
||||||
else:
|
|
||||||
print("❌ Model output storage failed")
|
|
||||||
|
|
||||||
# Test cross-model feeding
|
|
||||||
updated_base_input = provider.get_base_data_input('ETH/USDT')
|
|
||||||
if updated_base_input and cnn_model.model_name in updated_base_input.last_predictions:
|
|
||||||
print("✅ Cross-model feeding working!")
|
|
||||||
print(f" CNN prediction available in BaseDataInput for other models")
|
|
||||||
else:
|
|
||||||
print("⚠️ Cross-model feeding not working as expected")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Integration test failed: {e}")
|
|
||||||
|
|
||||||
# Test 4: Training capabilities
|
|
||||||
print("\n4. Testing training capabilities...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create mock training data
|
|
||||||
training_inputs = [base_input] * 5 # Small batch
|
|
||||||
training_targets = ['BUY', 'SELL', 'HOLD', 'BUY', 'HOLD']
|
|
||||||
|
|
||||||
# Create optimizer
|
|
||||||
optimizer = torch.optim.Adam(cnn_model.parameters(), lr=0.001)
|
|
||||||
|
|
||||||
# Perform training step
|
|
||||||
loss = cnn_model.train_step(training_inputs, training_targets, optimizer)
|
|
||||||
|
|
||||||
print(f"✅ Training step successful!")
|
|
||||||
print(f" Training loss: {loss:.4f}")
|
|
||||||
|
|
||||||
# Test evaluation
|
|
||||||
eval_metrics = cnn_model.evaluate(training_inputs, training_targets)
|
|
||||||
print(f" Evaluation metrics: {eval_metrics}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Training test failed: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
# Test 5: Checkpoint management
|
|
||||||
print("\n5. Testing checkpoint management...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Save checkpoint
|
|
||||||
checkpoint_path = "test_cache/cnn_checkpoint.pth"
|
|
||||||
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
|
|
||||||
|
|
||||||
metadata = {
|
|
||||||
'training_loss': loss if 'loss' in locals() else 0.5,
|
|
||||||
'accuracy': eval_metrics.get('accuracy', 0.0) if 'eval_metrics' in locals() else 0.0,
|
|
||||||
'test_run': True
|
|
||||||
}
|
|
||||||
|
|
||||||
cnn_model.save_checkpoint(checkpoint_path, metadata)
|
|
||||||
print("✅ Checkpoint saved successfully!")
|
|
||||||
|
|
||||||
# Create new model and load checkpoint
|
|
||||||
new_cnn = StandardizedCNN(model_name="loaded_cnn_v1")
|
|
||||||
success = new_cnn.load_checkpoint(checkpoint_path)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
print("✅ Checkpoint loaded successfully!")
|
|
||||||
print(f" Loaded model info: {new_cnn.get_model_info()}")
|
|
||||||
else:
|
|
||||||
print("❌ Checkpoint loading failed")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Checkpoint test failed: {e}")
|
|
||||||
|
|
||||||
# Test 6: Performance and compatibility
|
|
||||||
print("\n6. Testing performance and compatibility...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Test inference speed
|
|
||||||
import time
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
for _ in range(10):
|
|
||||||
_ = cnn_model.predict_from_base_input(base_input)
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
avg_inference_time = (end_time - start_time) / 10 * 1000 # ms
|
|
||||||
print(f"✅ Performance test completed!")
|
|
||||||
print(f" Average inference time: {avg_inference_time:.2f} ms")
|
|
||||||
|
|
||||||
# Test memory usage
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
memory_used = torch.cuda.memory_allocated() / 1024 / 1024 # MB
|
|
||||||
print(f" GPU memory used: {memory_used:.2f} MB")
|
|
||||||
|
|
||||||
# Test model size
|
|
||||||
param_count = sum(p.numel() for p in cnn_model.parameters())
|
|
||||||
model_size_mb = param_count * 4 / 1024 / 1024 # Assuming float32
|
|
||||||
print(f" Model parameters: {param_count:,}")
|
|
||||||
print(f" Estimated model size: {model_size_mb:.2f} MB")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Performance test failed: {e}")
|
|
||||||
|
|
||||||
print("\n✅ StandardizedCNN test completed!")
|
|
||||||
print("\n🎯 Key achievements:")
|
|
||||||
print("✓ Accepts standardized BaseDataInput format")
|
|
||||||
print("✓ Processes COB+OHLCV data (300 frames multi-timeframe)")
|
|
||||||
print("✓ Outputs BUY/SELL/HOLD with confidence scores")
|
|
||||||
print("✓ Provides hidden states for cross-model feeding")
|
|
||||||
print("✓ Integrates with ModelOutputManager")
|
|
||||||
print("✓ Supports training and evaluation")
|
|
||||||
print("✓ Checkpoint management for persistence")
|
|
||||||
print("✓ Real-time inference capabilities")
|
|
||||||
|
|
||||||
print("\n🚀 Ready for integration:")
|
|
||||||
print("1. Can be used by orchestrator for decision making")
|
|
||||||
print("2. Hidden states available for RL model cross-feeding")
|
|
||||||
print("3. Outputs stored in standardized ModelOutput format")
|
|
||||||
print("4. Compatible with checkpoint management system")
|
|
||||||
print("5. Optimized for real-time trading inference")
|
|
||||||
|
|
||||||
return cnn_model
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_standardized_cnn()
|
|
@ -36,7 +36,7 @@ def test_standardized_data_provider():
|
|||||||
print("❌ BaseDataInput is None - this is expected if no historical data is available")
|
print("❌ BaseDataInput is None - this is expected if no historical data is available")
|
||||||
print(" The provider needs real market data to create BaseDataInput")
|
print(" The provider needs real market data to create BaseDataInput")
|
||||||
|
|
||||||
# Test with mock data
|
# Test with real data only
|
||||||
print("\n2. Testing data structures...")
|
print("\n2. Testing data structures...")
|
||||||
|
|
||||||
# Test ModelOutput creation
|
# Test ModelOutput creation
|
||||||
|
@ -1,337 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Test Trading System Fixes
|
|
||||||
|
|
||||||
This script tests the fixes for the trading system by simulating trades
|
|
||||||
and verifying that the issues are resolved.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python test_trading_fixes.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from datetime import datetime
|
|
||||||
import json
|
|
||||||
|
|
||||||
# Add project root to path
|
|
||||||
project_root = Path(__file__).parent
|
|
||||||
sys.path.insert(0, str(project_root))
|
|
||||||
|
|
||||||
# Setup logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|
||||||
handlers=[
|
|
||||||
logging.StreamHandler(),
|
|
||||||
logging.FileHandler('logs/test_fixes.log')
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class MockPosition:
|
|
||||||
"""Mock position for testing"""
|
|
||||||
def __init__(self, symbol, side, size, entry_price):
|
|
||||||
self.symbol = symbol
|
|
||||||
self.side = side
|
|
||||||
self.size = size
|
|
||||||
self.entry_price = entry_price
|
|
||||||
self.fees = 0.0
|
|
||||||
|
|
||||||
class MockTradingExecutor:
|
|
||||||
"""Mock trading executor for testing fixes"""
|
|
||||||
def __init__(self):
|
|
||||||
self.positions = {}
|
|
||||||
self.current_prices = {}
|
|
||||||
self.simulation_mode = True
|
|
||||||
|
|
||||||
def get_current_price(self, symbol):
|
|
||||||
"""Get current price for a symbol"""
|
|
||||||
# Simulate price movement
|
|
||||||
if symbol not in self.current_prices:
|
|
||||||
self.current_prices[symbol] = 3600.0
|
|
||||||
else:
|
|
||||||
# Add some random movement
|
|
||||||
import random
|
|
||||||
self.current_prices[symbol] += random.uniform(-10, 10)
|
|
||||||
|
|
||||||
return self.current_prices[symbol]
|
|
||||||
|
|
||||||
def execute_action(self, decision):
|
|
||||||
"""Execute a trading action"""
|
|
||||||
logger.info(f"Executing {decision.action} for {decision.symbol} at ${decision.price:.2f}")
|
|
||||||
|
|
||||||
# Simulate execution
|
|
||||||
if decision.action in ['BUY', 'LONG']:
|
|
||||||
self.positions[decision.symbol] = MockPosition(
|
|
||||||
decision.symbol, 'LONG', decision.size, decision.price
|
|
||||||
)
|
|
||||||
elif decision.action in ['SELL', 'SHORT']:
|
|
||||||
self.positions[decision.symbol] = MockPosition(
|
|
||||||
decision.symbol, 'SHORT', decision.size, decision.price
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def close_position(self, symbol, price=None):
|
|
||||||
"""Close a position"""
|
|
||||||
if symbol not in self.positions:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if price is None:
|
|
||||||
price = self.get_current_price(symbol)
|
|
||||||
|
|
||||||
position = self.positions[symbol]
|
|
||||||
|
|
||||||
# Calculate P&L
|
|
||||||
if position.side == 'LONG':
|
|
||||||
pnl = (price - position.entry_price) * position.size
|
|
||||||
else: # SHORT
|
|
||||||
pnl = (position.entry_price - price) * position.size
|
|
||||||
|
|
||||||
logger.info(f"Closing {position.side} position for {symbol} at ${price:.2f}, P&L: ${pnl:.2f}")
|
|
||||||
|
|
||||||
# Remove position
|
|
||||||
del self.positions[symbol]
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
class MockDecision:
|
|
||||||
"""Mock trading decision for testing"""
|
|
||||||
def __init__(self, symbol, action, price=None, size=10.0, confidence=0.8):
|
|
||||||
self.symbol = symbol
|
|
||||||
self.action = action
|
|
||||||
self.price = price
|
|
||||||
self.size = size
|
|
||||||
self.confidence = confidence
|
|
||||||
self.timestamp = datetime.now()
|
|
||||||
self.executed = False
|
|
||||||
self.blocked = False
|
|
||||||
self.blocked_reason = None
|
|
||||||
|
|
||||||
def test_price_caching_fix():
|
|
||||||
"""Test the price caching fix"""
|
|
||||||
logger.info("Testing price caching fix...")
|
|
||||||
|
|
||||||
# Create mock trading executor
|
|
||||||
executor = MockTradingExecutor()
|
|
||||||
|
|
||||||
# Import and apply fixes
|
|
||||||
try:
|
|
||||||
from core.trading_executor_fix import TradingExecutorFix
|
|
||||||
TradingExecutorFix.apply_fixes(executor)
|
|
||||||
|
|
||||||
# Test price caching
|
|
||||||
symbol = 'ETH/USDT'
|
|
||||||
|
|
||||||
# Get initial price
|
|
||||||
price1 = executor.get_current_price(symbol)
|
|
||||||
logger.info(f"Initial price: ${price1:.2f}")
|
|
||||||
|
|
||||||
# Get price again immediately (should be cached)
|
|
||||||
price2 = executor.get_current_price(symbol)
|
|
||||||
logger.info(f"Immediate second price: ${price2:.2f}")
|
|
||||||
|
|
||||||
# Wait for cache to expire
|
|
||||||
logger.info("Waiting for cache to expire (6 seconds)...")
|
|
||||||
time.sleep(6)
|
|
||||||
|
|
||||||
# Get price after cache expiry (should be different)
|
|
||||||
price3 = executor.get_current_price(symbol)
|
|
||||||
logger.info(f"Price after cache expiry: ${price3:.2f}")
|
|
||||||
|
|
||||||
# Check if prices are different
|
|
||||||
if price1 == price2:
|
|
||||||
logger.info("✅ Immediate price check uses cache as expected")
|
|
||||||
else:
|
|
||||||
logger.warning("❌ Immediate price check did not use cache")
|
|
||||||
|
|
||||||
if price1 != price3:
|
|
||||||
logger.info("✅ Price cache expiry working correctly")
|
|
||||||
else:
|
|
||||||
logger.warning("❌ Price cache expiry not working")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error testing price caching fix: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_duplicate_entry_prevention():
|
|
||||||
"""Test the duplicate entry prevention fix"""
|
|
||||||
logger.info("Testing duplicate entry prevention...")
|
|
||||||
|
|
||||||
# Create mock trading executor
|
|
||||||
executor = MockTradingExecutor()
|
|
||||||
|
|
||||||
# Import and apply fixes
|
|
||||||
try:
|
|
||||||
from core.trading_executor_fix import TradingExecutorFix
|
|
||||||
TradingExecutorFix.apply_fixes(executor)
|
|
||||||
|
|
||||||
# Test duplicate entry prevention
|
|
||||||
symbol = 'ETH/USDT'
|
|
||||||
|
|
||||||
# Create first decision
|
|
||||||
decision1 = MockDecision(symbol, 'SHORT')
|
|
||||||
decision1.price = executor.get_current_price(symbol)
|
|
||||||
|
|
||||||
# Execute first decision
|
|
||||||
result1 = executor.execute_action(decision1)
|
|
||||||
logger.info(f"First execution result: {result1}")
|
|
||||||
|
|
||||||
# Manually set recent entries to simulate a successful trade
|
|
||||||
if not hasattr(executor, 'recent_entries'):
|
|
||||||
executor.recent_entries = {}
|
|
||||||
|
|
||||||
executor.recent_entries[symbol] = {
|
|
||||||
'price': decision1.price,
|
|
||||||
'timestamp': time.time(),
|
|
||||||
'action': decision1.action
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create second decision with same action
|
|
||||||
decision2 = MockDecision(symbol, 'SHORT')
|
|
||||||
decision2.price = decision1.price # Use same price to trigger duplicate detection
|
|
||||||
|
|
||||||
# Execute second decision immediately (should be blocked)
|
|
||||||
result2 = executor.execute_action(decision2)
|
|
||||||
logger.info(f"Second execution result: {result2}")
|
|
||||||
logger.info(f"Second decision blocked: {getattr(decision2, 'blocked', False)}")
|
|
||||||
logger.info(f"Block reason: {getattr(decision2, 'blocked_reason', 'None')}")
|
|
||||||
|
|
||||||
# Check if second decision was blocked by trade cooldown
|
|
||||||
# This is also acceptable as it prevents duplicate entries
|
|
||||||
if getattr(decision2, 'blocked', False):
|
|
||||||
logger.info("✅ Trade prevention working correctly (via cooldown)")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.warning("❌ Trade prevention not working correctly")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error testing duplicate entry prevention: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_pnl_calculation_fix():
|
|
||||||
"""Test the P&L calculation fix"""
|
|
||||||
logger.info("Testing P&L calculation fix...")
|
|
||||||
|
|
||||||
# Create mock trading executor
|
|
||||||
executor = MockTradingExecutor()
|
|
||||||
|
|
||||||
# Import and apply fixes
|
|
||||||
try:
|
|
||||||
from core.trading_executor_fix import TradingExecutorFix
|
|
||||||
TradingExecutorFix.apply_fixes(executor)
|
|
||||||
|
|
||||||
# Test P&L calculation
|
|
||||||
symbol = 'ETH/USDT'
|
|
||||||
|
|
||||||
# Create a position
|
|
||||||
entry_price = 3600.0
|
|
||||||
size = 10.0
|
|
||||||
executor.positions[symbol] = MockPosition(symbol, 'SHORT', size, entry_price)
|
|
||||||
|
|
||||||
# Set exit price
|
|
||||||
exit_price = 3550.0
|
|
||||||
|
|
||||||
# Calculate P&L using fixed method
|
|
||||||
pnl_result = executor._calculate_pnl(executor.positions[symbol], exit_price)
|
|
||||||
|
|
||||||
# Calculate expected P&L
|
|
||||||
expected_pnl = (entry_price - exit_price) * size
|
|
||||||
|
|
||||||
logger.info(f"Entry price: ${entry_price:.2f}")
|
|
||||||
logger.info(f"Exit price: ${exit_price:.2f}")
|
|
||||||
logger.info(f"Size: {size}")
|
|
||||||
logger.info(f"Calculated P&L: ${pnl_result['gross_pnl']:.2f}")
|
|
||||||
logger.info(f"Expected P&L: ${expected_pnl:.2f}")
|
|
||||||
|
|
||||||
# Check if P&L calculation is correct
|
|
||||||
if abs(pnl_result['gross_pnl'] - expected_pnl) < 0.01:
|
|
||||||
logger.info("✅ P&L calculation fix working correctly")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.warning("❌ P&L calculation fix not working correctly")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error testing P&L calculation fix: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
def run_all_tests():
|
|
||||||
"""Run all tests"""
|
|
||||||
logger.info("=" * 70)
|
|
||||||
logger.info("TESTING TRADING SYSTEM FIXES")
|
|
||||||
logger.info("=" * 70)
|
|
||||||
|
|
||||||
# Create logs directory if it doesn't exist
|
|
||||||
os.makedirs('logs', exist_ok=True)
|
|
||||||
|
|
||||||
# Run tests
|
|
||||||
tests = [
|
|
||||||
("Price Caching Fix", test_price_caching_fix),
|
|
||||||
("Duplicate Entry Prevention", test_duplicate_entry_prevention),
|
|
||||||
("P&L Calculation Fix", test_pnl_calculation_fix)
|
|
||||||
]
|
|
||||||
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
for test_name, test_func in tests:
|
|
||||||
logger.info(f"\n{'-'*30}")
|
|
||||||
logger.info(f"Running test: {test_name}")
|
|
||||||
logger.info(f"{'-'*30}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = test_func()
|
|
||||||
results[test_name] = result
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Test {test_name} failed with error: {e}")
|
|
||||||
results[test_name] = False
|
|
||||||
|
|
||||||
# Print summary
|
|
||||||
logger.info("\n" + "=" * 70)
|
|
||||||
logger.info("TEST RESULTS SUMMARY")
|
|
||||||
logger.info("=" * 70)
|
|
||||||
|
|
||||||
all_passed = True
|
|
||||||
for test_name, result in results.items():
|
|
||||||
status = "✅ PASSED" if result else "❌ FAILED"
|
|
||||||
logger.info(f"{test_name}: {status}")
|
|
||||||
if not result:
|
|
||||||
all_passed = False
|
|
||||||
|
|
||||||
logger.info("=" * 70)
|
|
||||||
logger.info(f"OVERALL RESULT: {'✅ ALL TESTS PASSED' if all_passed else '❌ SOME TESTS FAILED'}")
|
|
||||||
logger.info("=" * 70)
|
|
||||||
|
|
||||||
# Save results to file
|
|
||||||
with open('logs/test_results.json', 'w') as f:
|
|
||||||
json.dump({
|
|
||||||
'timestamp': datetime.now().isoformat(),
|
|
||||||
'results': {k: "PASSED" if v else "FAILED" for k, v in results.items()},
|
|
||||||
'all_passed': all_passed
|
|
||||||
}, f, indent=2)
|
|
||||||
|
|
||||||
return all_passed
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
success = run_all_tests()
|
|
||||||
|
|
||||||
if success:
|
|
||||||
print("\nAll tests passed!")
|
|
||||||
sys.exit(0)
|
|
||||||
else:
|
|
||||||
print("\nSome tests failed. Check logs for details.")
|
|
||||||
sys.exit(1)
|
|
@ -283,13 +283,16 @@ class TrainingSystemValidator:
|
|||||||
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||||
logger.info(" ✓ RL Agent loaded")
|
logger.info(" ✓ RL Agent loaded")
|
||||||
|
|
||||||
# Test prediction capability
|
# Test prediction capability with real data
|
||||||
if hasattr(self.orchestrator.rl_agent, 'predict'):
|
if hasattr(self.orchestrator.rl_agent, 'predict'):
|
||||||
# Create dummy state for testing
|
|
||||||
dummy_state = np.random.random(1000) # Simplified test state
|
|
||||||
try:
|
try:
|
||||||
prediction = self.orchestrator.rl_agent.predict(dummy_state)
|
# Use real state from orchestrator instead of dummy data
|
||||||
logger.info(" ✓ RL Agent can make predictions")
|
real_state = self.orchestrator._get_rl_state('ETH/USDT')
|
||||||
|
if real_state is not None:
|
||||||
|
prediction = self.orchestrator.rl_agent.predict(real_state)
|
||||||
|
logger.info(" ✓ RL Agent can make predictions with real data")
|
||||||
|
else:
|
||||||
|
logger.warning(" ⚠ No real state available for RL prediction test")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f" ⚠ RL Agent prediction failed: {e}")
|
logger.warning(f" ⚠ RL Agent prediction failed: {e}")
|
||||||
else:
|
else:
|
||||||
|
Reference in New Issue
Block a user