training wip

Dobromir Popov
2025-09-02 19:25:13 +03:00
parent 6dcb82c184
commit 226a6aa047
8 changed files with 241 additions and 771 deletions

View File

@@ -21,7 +21,6 @@ from typing import Dict, Any, Optional, Tuple
# Import checkpoint management
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
-from utils.training_integration import get_training_integration
# Configure logging
logger = logging.getLogger(__name__)
@@ -522,7 +521,7 @@ class CNNModelTrainer:
# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
-self.training_integration = get_training_integration() if enable_checkpoints else None
+self.training_integration = None # Removed dependency on utils.training_integration
self.epoch_count = 0
self.best_val_accuracy = 0.0
self.best_val_loss = float('inf')
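With the training_integration shim removed, checkpointing goes through utils.checkpoint_manager directly. A minimal sketch of what a save call could look like at the end of a validation epoch; the keyword names (model, model_name, model_type, performance_metrics) are assumptions for illustration, not taken from this diff:

from utils.checkpoint_manager import save_checkpoint

def maybe_save_cnn_checkpoint(trainer, val_accuracy: float, val_loss: float):
    # Hypothetical helper: persist the best CNN checkpoint without the removed integration layer.
    trainer.epoch_count += 1
    if val_accuracy > trainer.best_val_accuracy:
        trainer.best_val_accuracy = val_accuracy
        save_checkpoint(
            model=trainer.model,                      # assumed: the underlying torch module
            model_name=trainer.model_name,
            model_type="cnn",
            performance_metrics={"val_accuracy": val_accuracy, "val_loss": val_loss},
        )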

View File

@@ -16,7 +16,6 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(
# Import checkpoint management
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
-from utils.training_integration import get_training_integration
# Configure logger
logger = logging.getLogger(__name__)
@@ -44,7 +43,7 @@ class DQNAgent:
# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
-self.training_integration = get_training_integration() if enable_checkpoints else None
+self.training_integration = None # Removed dependency on utils.training_integration
self.episode_count = 0
self.best_reward = float('-inf')
self.reward_history = deque(maxlen=100)

View File

@@ -1802,604 +1802,177 @@ class DataProvider:
logger.debug(f"Applied pivot-based normalization for {symbol}")
else:
# Fallback to traditional normalization when pivot bounds not available
logger.debug("Using traditional normalization (no pivot bounds available)")
for col in df_norm.columns:
if col in ['open', 'high', 'low', 'close', 'sma_10', 'sma_20', 'sma_50',
'ema_12', 'ema_26', 'ema_50', 'bb_upper', 'bb_lower', 'bb_middle',
'keltner_upper', 'keltner_lower', 'keltner_middle', 'psar', 'vwap']:
# Price-based indicators: normalize by close price
if 'close' in df_norm.columns:
base_price = df_norm['close'].iloc[-1] # Use latest close as reference
if base_price > 0:
df_norm[col] = df_norm[col] / base_price
elif col == 'volume':
# Volume: normalize by its own rolling mean
volume_mean = df_norm[col].rolling(window=min(20, len(df_norm))).mean().iloc[-1]
if volume_mean > 0:
df_norm[col] = df_norm[col] / volume_mean
# Normalize indicators that have standard ranges (regardless of pivot bounds)
for col in df_norm.columns:
if col in ['rsi_14', 'rsi_7', 'rsi_21']:
# RSI: already 0-100, normalize to 0-1
df_norm[col] = df_norm[col] / 100.0
elif col in ['stoch_k', 'stoch_d']:
# Stochastic: already 0-100, normalize to 0-1
df_norm[col] = df_norm[col] / 100.0
elif col == 'williams_r':
# Williams %R: -100 to 0, normalize to 0-1
df_norm[col] = (df_norm[col] + 100) / 100.0
elif col in ['macd', 'macd_signal', 'macd_histogram']:
# MACD: normalize by ATR or close price
if 'atr' in df_norm.columns and df_norm['atr'].iloc[-1] > 0:
df_norm[col] = df_norm[col] / df_norm['atr'].iloc[-1]
elif 'close' in df_norm.columns and df_norm['close'].iloc[-1] > 0:
df_norm[col] = df_norm[col] / df_norm['close'].iloc[-1]
elif col in ['bb_width', 'bb_percent', 'price_position', 'trend_strength',
'momentum_composite', 'volatility_regime', 'pivot_price_position',
'pivot_support_distance', 'pivot_resistance_distance']:
# Already normalized indicators: ensure 0-1 range
df_norm[col] = np.clip(df_norm[col], 0, 1)
elif col in ['atr', 'true_range']:
# Volatility indicators: normalize by close price or pivot range
if symbol and symbol in self.pivot_bounds:
bounds = self.pivot_bounds[symbol]
df_norm[col] = df_norm[col] / bounds.get_price_range()
elif 'close' in df_norm.columns and df_norm['close'].iloc[-1] > 0:
df_norm[col] = df_norm[col] / df_norm['close'].iloc[-1]
elif col not in ['timestamp', 'near_pivot_support', 'near_pivot_resistance']:
# Other indicators: z-score normalization
col_mean = df_norm[col].rolling(window=min(20, len(df_norm))).mean().iloc[-1]
col_std = df_norm[col].rolling(window=min(20, len(df_norm))).std().iloc[-1]
if col_std > 0:
df_norm[col] = (df_norm[col] - col_mean) / col_std
else:
df_norm[col] = 0
# Replace inf/-inf with 0
df_norm = df_norm.replace([np.inf, -np.inf], 0)
# Use symbol-grouped normalization with consistent ranges
df_norm = self._apply_symbol_grouped_normalization(df_norm, symbol)
# Fill any remaining NaN values
df_norm = df_norm.fillna(0.0)
return df_norm
except Exception as e:
logger.error(f"Error normalizing features for {symbol}: {e}")
return df.fillna(0.0) if df is not None else None
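As a quick sanity check of the fixed-range rules above, a small standalone sketch with toy values (not from the codebase):

import pandas as pd

df = pd.DataFrame({"rsi_14": [70.0], "williams_r": [-20.0], "macd": [12.0], "atr": [30.0]})
df["rsi_14"] = df["rsi_14"] / 100.0                   # 0-100  -> 0-1, gives 0.70
df["williams_r"] = (df["williams_r"] + 100) / 100.0   # -100-0 -> 0-1, gives 0.80
df["macd"] = df["macd"] / df["atr"].iloc[-1]          # MACD scaled by ATR, gives 0.40
print(df)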
def _apply_symbol_grouped_normalization(self, df: pd.DataFrame, symbol: str) -> pd.DataFrame:
"""Apply symbol-grouped normalization with consistent ranges across timeframes"""
try:
df_norm = df.copy()
# Get symbol-specific price ranges for consistent normalization
symbol_price_ranges = {
'ETH/USDT': {'min': 1000, 'max': 5000}, # ETH price range
'BTC/USDT': {'min': 90000, 'max': 120000} # BTC price range
}
if symbol in symbol_price_ranges:
price_range = symbol_price_ranges[symbol]
range_size = price_range['max'] - price_range['min']
# Normalize price columns to [0, 1] range specific to symbol
price_cols = ['open', 'high', 'low', 'close']
for col in price_cols:
if col in df_norm.columns:
df_norm[col] = (df_norm[col] - price_range['min']) / range_size
df_norm[col] = np.clip(df_norm[col], 0, 1) # Ensure [0,1] range
# Normalize volume to [0, 1] using log scale
if 'volume' in df_norm.columns:
df_norm['volume'] = np.log1p(df_norm['volume'])
vol_max = df_norm['volume'].max()
if vol_max > 0:
df_norm['volume'] = df_norm['volume'] / vol_max
logger.debug(f"Applied symbol-grouped normalization for {symbol}")
# Fill any NaN values
df_norm = df_norm.fillna(0)
return df_norm
except Exception as e:
logger.error(f"Error normalizing features: {e}")
logger.error(f"Error in symbol-grouped normalization for {symbol}: {e}")
return df
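For example, under the ETH/USDT range above, a close of 3000 maps to (3000 - 1000) / (5000 - 1000) = 0.5, and prices outside the assumed band are clipped to the [0, 1] edges:

import numpy as np

eth_min, eth_max = 1000, 5000                               # assumed ETH/USDT range from the table above
close = 3000.0
norm = np.clip((close - eth_min) / (eth_max - eth_min), 0, 1)   # -> 0.5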
-def get_multi_symbol_feature_matrix(self, symbols: List[str] = None,
-timeframes: List[str] = None,
-window_size: int = 20) -> Optional[np.ndarray]:
-"""
-Get feature matrix for multiple symbols and timeframes
-Returns:
-np.ndarray: Shape (n_symbols, n_timeframes, window_size, n_features)
-"""
+def get_historical_data_for_inference(self, symbol: str, timeframe: str, limit: int = 300) -> Optional[pd.DataFrame]:
+"""Get normalized historical data specifically for model inference"""
try:
-if symbols is None:
-symbols = self.symbols
-if timeframes is None:
-timeframes = self.timeframes
+# Get raw historical data
+raw_df = self.get_historical_data(symbol, timeframe, limit)
-symbol_matrices = []
+if raw_df is None or raw_df.empty:
+return None
-for symbol in symbols:
-symbol_matrix = self.get_feature_matrix(symbol, timeframes, window_size)
-if symbol_matrix is not None:
-symbol_matrices.append(symbol_matrix)
+# Apply normalization for inference
+normalized_df = self._normalize_features(raw_df, symbol)
+logger.debug(f"Retrieved normalized historical data for inference: {symbol} {timeframe} ({len(normalized_df)} records)")
+return normalized_df
+except Exception as e:
+logger.error(f"Error getting normalized historical data for inference: {symbol} {timeframe}: {e}")
+return None
+def get_multi_symbol_features_for_inference(self, symbols_timeframes: List[Tuple[str, str]], limit: int = 300) -> Dict[str, Dict[str, pd.DataFrame]]:
+"""Get normalized multi-symbol feature matrices for model inference"""
+try:
+logger.info("Preparing normalized multi-symbol features for model inference...")
+symbol_features = {}
+for symbol, timeframe in symbols_timeframes:
+if symbol not in symbol_features:
+symbol_features[symbol] = {}
+# Get normalized data for inference
+normalized_df = self.get_historical_data_for_inference(symbol, timeframe, limit)
+if normalized_df is not None and not normalized_df.empty:
+symbol_features[symbol][timeframe] = normalized_df
+logger.debug(f"Prepared normalized features for {symbol} {timeframe}: {len(normalized_df)} records")
else:
-logger.warning(f"Could not create feature matrix for {symbol}")
+logger.warning(f"No normalized data available for {symbol} {timeframe}")
+symbol_features[symbol][timeframe] = None
-if symbol_matrices:
-# Stack all symbol matrices
-multi_symbol_matrix = np.stack(symbol_matrices, axis=0)
-logger.info(f"Created multi-symbol feature matrix: {multi_symbol_matrix.shape}")
-return multi_symbol_matrix
+return symbol_features
+except Exception as e:
+logger.error(f"Error preparing multi-symbol features for inference: {e}")
+return {}
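For reference, a minimal usage sketch of the new inference path. It assumes an already-constructed DataProvider instance named provider; the symbol/timeframe pairs are illustrative:

symbols_timeframes = [("ETH/USDT", "1m"), ("ETH/USDT", "1h"),
                      ("ETH/USDT", "1d"), ("BTC/USDT", "1m")]
features = provider.get_multi_symbol_features_for_inference(symbols_timeframes, limit=300)
eth_1m = features.get("ETH/USDT", {}).get("1m")
if eth_1m is not None:
    print(eth_1m[["open", "high", "low", "close", "volume"]].tail())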
def get_cnn_features_for_inference(self, symbol: str, timeframe: str = '1m', window_size: int = 60) -> Optional[np.ndarray]:
"""Get normalized CNN features for a specific symbol and timeframe"""
try:
# Get normalized data
df = self.get_historical_data_for_inference(symbol, timeframe, limit=300)
if df is None or df.empty:
return None
# Extract recent window for CNN
recent_data = df.tail(window_size)
# Extract OHLCV features
features = recent_data[['open', 'high', 'low', 'close', 'volume']].values
logger.debug(f"Extracted CNN features for {symbol} {timeframe}: {features.shape}")
return features.flatten()
except Exception as e:
logger.error(f"Error extracting CNN features for {symbol} {timeframe}: {e}")
return None
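The CNN helper returns a flattened window, so with window_size=60 and five OHLCV columns the vector has 300 entries (assuming at least 60 normalized rows are available; provider is again an assumed DataProvider instance):

features = provider.get_cnn_features_for_inference("ETH/USDT", "1m", window_size=60)
if features is not None:
    assert features.shape == (60 * 5,)   # 60 bars x OHLCV, flattened to 300 values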
def get_dqn_state_for_inference(self, symbols_timeframes: List[Tuple[str, str]], target_size: int = 100) -> Optional[np.ndarray]:
"""Get normalized DQN state vector combining multiple symbols and timeframes"""
try:
state_components = []
for symbol, timeframe in symbols_timeframes:
df = self.get_historical_data_for_inference(symbol, timeframe, limit=50)
if df is not None and not df.empty:
# Extract key features for state
latest = df.iloc[-1]
state_features = [
latest['close'], # Current price (normalized)
latest['volume'], # Current volume (normalized)
df['close'].pct_change().iloc[-1] if len(df) > 1 else 0, # Price change
]
state_components.extend(state_features)
if state_components:
# Pad or truncate to expected DQN state size
if len(state_components) < target_size:
state_components.extend([0] * (target_size - len(state_components)))
else:
state_components = state_components[:target_size]
state_vector = np.array(state_components, dtype=np.float32)
logger.debug(f"Created DQN state vector: {len(state_vector)} dimensions")
return state_vector
return None
except Exception as e:
logger.error(f"Error creating multi-symbol feature matrix: {e}")
logger.error(f"Error creating DQN state for inference: {e}")
return None
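Each (symbol, timeframe) pair contributes three values here (normalized close, normalized volume, last percent change), and the result is zero-padded to the fixed DQN state size. A toy sketch of that padding with four pairs:

import numpy as np

components = [0.5, 0.7, 0.001] * 4                    # 4 pairs x 3 features = 12 values
components += [0.0] * (100 - len(components))         # pad up to the 100-dim state
state = np.array(components, dtype=np.float32)
assert state.shape == (100,)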
def health_check(self) -> Dict[str, Any]:
"""Get health status of the data provider"""
status = {
'streaming': self.is_streaming,
'symbols': len(self.symbols),
'timeframes': len(self.timeframes),
'current_prices': len(self.current_prices),
'websocket_tasks': len(self.websocket_tasks),
'historical_data_loaded': {}
}
# Check historical data availability
for symbol in self.symbols:
status['historical_data_loaded'][symbol] = {}
for tf in self.timeframes:
has_data = (symbol in self.historical_data and
tf in self.historical_data[symbol] and
not self.historical_data[symbol][tf].empty)
status['historical_data_loaded'][symbol][tf] = has_data
return status
def subscribe_to_ticks(self, callback: Callable[[MarketTick], None],
symbols: List[str] = None,
subscriber_name: str = None) -> str:
"""Subscribe to real-time tick data updates"""
subscriber_id = str(uuid.uuid4())[:8]
subscriber_name = subscriber_name or f"subscriber_{subscriber_id}"
# Convert symbols to Binance format
if symbols:
binance_symbols = [s.replace('/', '').upper() for s in symbols]
else:
binance_symbols = [s.replace('/', '').upper() for s in self.symbols]
subscriber = DataSubscriber(
subscriber_id=subscriber_id,
callback=callback,
symbols=binance_symbols,
subscriber_name=subscriber_name
)
with self.subscriber_lock:
self.subscribers[subscriber_id] = subscriber
logger.info(f"New tick subscriber registered: {subscriber_name} ({subscriber_id}) for symbols: {binance_symbols}")
# Send recent tick data to new subscriber
self._send_recent_ticks_to_subscriber(subscriber)
return subscriber_id
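A hedged sketch of a tick subscription; the import path and the MarketTick attribute names are assumptions for illustration, and provider is an already-constructed DataProvider:

from core.data_provider import MarketTick   # assumed import path

def on_tick(tick: MarketTick) -> None:
    # Keep callbacks fast: they are invoked from the provider's distribution threads.
    print(tick.symbol, tick.price)           # attribute names assumed

sub_id = provider.subscribe_to_ticks(on_tick, symbols=["ETH/USDT"], subscriber_name="console_logger")
# ... later
provider.unsubscribe_from_ticks(sub_id)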
def unsubscribe_from_ticks(self, subscriber_id: str):
"""Unsubscribe from tick data updates"""
with self.subscriber_lock:
if subscriber_id in self.subscribers:
subscriber_name = self.subscribers[subscriber_id].subscriber_name
self.subscribers[subscriber_id].active = False
del self.subscribers[subscriber_id]
logger.info(f"Subscriber {subscriber_name} ({subscriber_id}) unsubscribed")
def _send_recent_ticks_to_subscriber(self, subscriber: DataSubscriber):
"""Send recent tick data to a new subscriber"""
+def get_transformer_sequences_for_inference(self, symbols_timeframes: List[Tuple[str, str]], seq_length: int = 150) -> List[np.ndarray]:
+"""Get normalized sequences for transformer inference"""
try:
for symbol in subscriber.symbols:
if symbol in self.tick_buffers:
# Send last 50 ticks to get subscriber up to speed
recent_ticks = list(self.tick_buffers[symbol])[-50:]
for tick in recent_ticks:
try:
subscriber.callback(tick)
except Exception as e:
logger.warning(f"Error sending recent tick to subscriber {subscriber.subscriber_id}: {e}")
except Exception as e:
logger.error(f"Error sending recent ticks: {e}")
+sequences = []
def _distribute_tick(self, tick: MarketTick):
"""Distribute tick to all relevant subscribers"""
distributed_count = 0
+for symbol, timeframe in symbols_timeframes:
+df = self.get_historical_data_for_inference(symbol, timeframe, limit=300)
with self.subscriber_lock:
subscribers_to_remove = []
+if df is not None and not df.empty:
+# Use last seq_length points as sequence
+sequence = df.tail(seq_length)[['open', 'high', 'low', 'close', 'volume']].values
+sequences.append(sequence)
+logger.debug(f"Created transformer sequence for {symbol} {timeframe}: {sequence.shape}")
for subscriber_id, subscriber in self.subscribers.items():
if not subscriber.active:
subscribers_to_remove.append(subscriber_id)
continue
if tick.symbol in subscriber.symbols:
try:
# Call subscriber callback in a thread to avoid blocking
def call_callback():
try:
subscriber.callback(tick)
subscriber.tick_count += 1
subscriber.last_update = datetime.now()
except Exception as e:
logger.warning(f"Error in subscriber {subscriber_id} callback: {e}")
subscriber.active = False
# Use thread to avoid blocking the main data processing
Thread(target=call_callback, daemon=True).start()
distributed_count += 1
+return sequences
except Exception as e:
logger.warning(f"Error distributing tick to subscriber {subscriber_id}: {e}")
subscriber.active = False
# Remove inactive subscribers
for subscriber_id in subscribers_to_remove:
if subscriber_id in self.subscribers:
del self.subscribers[subscriber_id]
self.distribution_stats['total_ticks_distributed'] += distributed_count
def _validate_tick_data(self, symbol: str, price: float, volume: float) -> bool:
"""Validate incoming tick data for quality"""
try:
# Basic validation
if price <= 0 or volume < 0:
return False
# Price change validation
last_price = self.last_prices.get(symbol, 0)
if last_price > 0:
price_change_pct = abs(price - last_price) / last_price
if price_change_pct > self.price_change_threshold:
logger.warning(f"Large price change for {symbol}: {price_change_pct:.2%}")
# Don't reject, just warn - could be legitimate
return True
except Exception as e:
logger.error(f"Error validating tick data: {e}")
return False
def get_recent_ticks(self, symbol: str, count: int = 100) -> List[MarketTick]:
"""Get recent ticks for a symbol"""
binance_symbol = symbol.replace('/', '').upper()
if binance_symbol in self.tick_buffers:
return list(self.tick_buffers[binance_symbol])[-count:]
logger.error(f"Error creating transformer sequences for inference: {e}")
return []
def subscribe_to_raw_ticks(self, callback: Callable[[RawTick], None]) -> str:
"""Subscribe to raw tick data with timing information"""
subscriber_id = str(uuid.uuid4())
self.raw_tick_callbacks.append(callback)
logger.info(f"Raw tick subscriber added: {subscriber_id}")
return subscriber_id
def subscribe_to_ohlcv_bars(self, callback: Callable[[OHLCVBar], None]) -> str:
"""Subscribe to 1s OHLCV bars calculated from ticks"""
subscriber_id = str(uuid.uuid4())
self.ohlcv_bar_callbacks.append(callback)
logger.info(f"OHLCV bar subscriber added: {subscriber_id}")
return subscriber_id
def get_raw_tick_features(self, symbol: str, window_size: int = 50) -> Optional[np.ndarray]:
"""Get raw tick features for model consumption"""
return self.tick_aggregator.get_tick_features_for_model(symbol, window_size)
def get_ohlcv_features(self, symbol: str, window_size: int = 60) -> Optional[np.ndarray]:
"""Get 1s OHLCV features for model consumption"""
return self.tick_aggregator.get_ohlcv_features_for_model(symbol, window_size)
def get_detected_patterns(self, symbol: str, count: int = 50) -> List:
"""Get recently detected tick patterns"""
return self.tick_aggregator.get_detected_patterns(symbol, count)
def get_tick_aggregator_stats(self) -> Dict[str, Any]:
"""Get tick aggregator statistics"""
return self.tick_aggregator.get_statistics()
def get_subscriber_stats(self) -> Dict[str, Any]:
"""Get subscriber and distribution statistics"""
with self.subscriber_lock:
active_subscribers = len([s for s in self.subscribers.values() if s.active])
subscriber_stats = {
sid: {
'name': s.subscriber_name,
'active': s.active,
'symbols': s.symbols,
'tick_count': s.tick_count,
'last_update': s.last_update.isoformat() if s.last_update else None
}
for sid, s in self.subscribers.items()
}
# Get tick aggregator stats
aggregator_stats = self.get_tick_aggregator_stats()
return {
'active_subscribers': active_subscribers,
'total_subscribers': len(self.subscribers),
'raw_tick_callbacks': len(self.raw_tick_callbacks),
'ohlcv_bar_callbacks': len(self.ohlcv_bar_callbacks),
'subscriber_details': subscriber_stats,
'distribution_stats': self.distribution_stats.copy(),
'buffer_sizes': {symbol: len(buffer) for symbol, buffer in self.tick_buffers.items()},
'tick_aggregator': aggregator_stats
}
def update_bom_cache(self, symbol: str, bom_features: List[float], cob_integration=None):
"""
Update BOM cache with latest features for a symbol
Args:
symbol: Trading symbol (e.g., 'ETH/USDT')
bom_features: List of BOM features (should be 120 features)
cob_integration: Optional COB integration instance for real BOM data
"""
try:
current_time = datetime.now()
# Ensure we have exactly 120 features
if len(bom_features) != self.bom_feature_count:
if len(bom_features) > self.bom_feature_count:
bom_features = bom_features[:self.bom_feature_count]
else:
bom_features.extend([0.0] * (self.bom_feature_count - len(bom_features)))
# Convert to numpy array for efficient storage
bom_array = np.array(bom_features, dtype=np.float32)
# Add timestamp and features to cache
with self.data_lock:
self.bom_data_cache[symbol].append((current_time, bom_array))
logger.debug(f"Updated BOM cache for {symbol}: {len(self.bom_data_cache[symbol])} timestamps cached")
except Exception as e:
logger.error(f"Error updating BOM cache for {symbol}: {e}")
def get_bom_matrix_for_cnn(self, symbol: str, sequence_length: int = 50) -> Optional[np.ndarray]:
"""
Get BOM matrix for CNN input from cached 1s data
Args:
symbol: Trading symbol (e.g., 'ETH/USDT')
sequence_length: Required sequence length (default 50)
Returns:
np.ndarray: BOM matrix of shape (sequence_length, 120) or None if insufficient data
"""
try:
with self.data_lock:
if symbol not in self.bom_data_cache or len(self.bom_data_cache[symbol]) == 0:
logger.warning(f"No BOM data cached for {symbol}")
return None
# Get recent data
cached_data = list(self.bom_data_cache[symbol])
if len(cached_data) < sequence_length:
logger.warning(f"Insufficient BOM data for {symbol}: {len(cached_data)} < {sequence_length}")
# Pad with zeros if we don't have enough data
bom_matrix = np.zeros((sequence_length, self.bom_feature_count), dtype=np.float32)
# Fill available data at the end
for i, (timestamp, features) in enumerate(cached_data):
if i < sequence_length:
bom_matrix[sequence_length - len(cached_data) + i] = features
return bom_matrix
# Take the most recent sequence_length samples
recent_data = cached_data[-sequence_length:]
# Create matrix
bom_matrix = np.zeros((sequence_length, self.bom_feature_count), dtype=np.float32)
for i, (timestamp, features) in enumerate(recent_data):
bom_matrix[i] = features
logger.debug(f"Retrieved BOM matrix for {symbol}: shape={bom_matrix.shape}")
return bom_matrix
except Exception as e:
logger.error(f"Error getting BOM matrix for {symbol}: {e}")
return None
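A usage sketch of the BOM matrix accessor (provider is an assumed DataProvider instance):

bom = provider.get_bom_matrix_for_cnn("ETH/USDT", sequence_length=50)
if bom is not None:
    assert bom.shape == (50, 120)   # 50 one-second snapshots x 120 BOM features
    # When fewer than 50 snapshots are cached, earlier rows stay zero and the
    # available snapshots are right-aligned, per the padding logic above.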
def get_real_bom_features(self, symbol: str) -> Optional[List[float]]:
"""
Get REAL BOM features from actual market data ONLY
NO SYNTHETIC DATA - Returns None if real data is not available
"""
try:
# Try to get real COB data from integration
if hasattr(self, 'cob_integration') and self.cob_integration:
return self._extract_real_bom_features(symbol, self.cob_integration)
# No real data available - return None instead of synthetic
logger.warning(f"No real BOM data available for {symbol} - waiting for real market data")
return None
except Exception as e:
logger.error(f"Error getting real BOM features for {symbol}: {e}")
return None
def start_bom_cache_updates(self, cob_integration=None):
"""
Start background updates of BOM cache every second
Args:
cob_integration: Optional COB integration instance for real data
"""
try:
def update_loop():
while self.is_streaming:
try:
for symbol in self.symbols:
if cob_integration:
# Try to get real BOM features from COB integration
try:
bom_features = self._extract_real_bom_features(symbol, cob_integration)
if bom_features:
self.update_bom_cache(symbol, bom_features, cob_integration)
else:
# NO SYNTHETIC FALLBACK - Wait for real data
logger.warning(f"No real BOM features available for {symbol} - waiting for real data")
except Exception as e:
logger.warning(f"Error getting real BOM features for {symbol}: {e}")
logger.warning(f"Waiting for real data instead of using synthetic")
else:
# NO SYNTHETIC FEATURES - Wait for real COB integration
logger.warning(f"No COB integration available for {symbol} - waiting for real data")
time.sleep(1.0) # Update every second
except Exception as e:
logger.error(f"Error in BOM cache update loop: {e}")
time.sleep(5.0) # Wait longer on error
# Start background thread
bom_thread = Thread(target=update_loop, daemon=True)
bom_thread.start()
logger.info("Started BOM cache updates (1s resolution)")
except Exception as e:
logger.error(f"Error starting BOM cache updates: {e}")
def _extract_real_bom_features(self, symbol: str, cob_integration) -> Optional[List[float]]:
"""Extract real BOM features from COB integration"""
try:
features = []
# Get consolidated order book
if hasattr(cob_integration, 'get_consolidated_orderbook'):
cob_snapshot = cob_integration.get_consolidated_orderbook(symbol)
if cob_snapshot:
# Extract order book features (40 features)
features.extend(self._extract_orderbook_features(cob_snapshot))
else:
features.extend([0.0] * 40)
else:
features.extend([0.0] * 40)
# Get volume profile features (30 features)
if hasattr(cob_integration, 'get_session_volume_profile'):
volume_profile = cob_integration.get_session_volume_profile(symbol)
if volume_profile:
features.extend(self._extract_volume_profile_features(volume_profile))
else:
features.extend([0.0] * 30)
else:
features.extend([0.0] * 30)
# Add flow and microstructure features (50 features)
features.extend(self._extract_flow_microstructure_features(symbol, cob_integration))
# Ensure exactly 120 features
if len(features) > 120:
features = features[:120]
elif len(features) < 120:
features.extend([0.0] * (120 - len(features)))
return features
except Exception as e:
logger.warning(f"Error extracting real BOM features for {symbol}: {e}")
return None
def _extract_orderbook_features(self, cob_snapshot) -> List[float]:
"""Extract order book features from COB snapshot"""
features = []
try:
# Top 10 bid levels
for i in range(10):
if i < len(cob_snapshot.consolidated_bids):
level = cob_snapshot.consolidated_bids[i]
price_offset = (level.price - cob_snapshot.volume_weighted_mid) / cob_snapshot.volume_weighted_mid
volume_normalized = level.total_volume_usd / 1000000
features.extend([price_offset, volume_normalized])
else:
features.extend([0.0, 0.0])
# Top 10 ask levels
for i in range(10):
if i < len(cob_snapshot.consolidated_asks):
level = cob_snapshot.consolidated_asks[i]
price_offset = (level.price - cob_snapshot.volume_weighted_mid) / cob_snapshot.volume_weighted_mid
volume_normalized = level.total_volume_usd / 1000000
features.extend([price_offset, volume_normalized])
else:
features.extend([0.0, 0.0])
except Exception as e:
logger.warning(f"Error extracting order book features: {e}")
features = [0.0] * 40
return features[:40]
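A worked example of the two per-level features above, with toy numbers:

price, mid, volume_usd = 2999.0, 3000.0, 250_000.0
price_offset = (price - mid) / mid             # -> about -0.00033 (roughly 0.03% below mid)
volume_normalized = volume_usd / 1_000_000     # -> 0.25 (volume scaled to millions of USD)
# 10 bid levels + 10 ask levels, 2 features each, give the 40 order-book features.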
def _extract_volume_profile_features(self, volume_profile) -> List[float]:
"""Extract volume profile features"""
features = []
try:
if 'data' in volume_profile:
svp_data = volume_profile['data']
top_levels = sorted(svp_data, key=lambda x: x.get('total_volume', 0), reverse=True)[:10]
for level in top_levels:
buy_percent = level.get('buy_percent', 50.0) / 100.0
sell_percent = level.get('sell_percent', 50.0) / 100.0
total_volume = level.get('total_volume', 0.0) / 1000000
features.extend([buy_percent, sell_percent, total_volume])
# Pad to 30 features
while len(features) < 30:
features.extend([0.5, 0.5, 0.0])
except Exception as e:
logger.warning(f"Error extracting volume profile features: {e}")
features = [0.0] * 30
return features[:30]
def _extract_flow_microstructure_features(self, symbol: str, cob_integration) -> List[float]:
"""Extract flow and microstructure features"""
try:
# NO SYNTHETIC DATA - return None when no real microstructure data is available
logger.warning(f"No real microstructure data available for {symbol}")
return None
except Exception:
return [0.0] * 50
def _handle_rate_limit(self, url: str):
"""Handle rate limiting with exponential backoff"""
current_time = time.time()
# Check if we need to wait
if url in self.last_request_time:
time_since_last = current_time - self.last_request_time[url]
if time_since_last < self.request_interval:
sleep_time = self.request_interval - time_since_last
logger.info(f"Rate limiting: sleeping {sleep_time:.2f}s")
time.sleep(sleep_time)
self.last_request_time[url] = time.time()
def _make_request_with_retry(self, url: str, params: dict = None):
"""Make HTTP request with retry logic for 451 errors"""
for attempt in range(self.max_retries):
try:
self._handle_rate_limit(url)
response = requests.get(url, params=params, timeout=30)
if response.status_code == 451:
logger.warning(f"Rate limit hit (451), attempt {attempt + 1}/{self.max_retries}")
if attempt < self.max_retries - 1:
sleep_time = self.retry_delay * (2 ** attempt) # Exponential backoff
logger.info(f"Waiting {sleep_time}s before retry...")
time.sleep(sleep_time)
continue
else:
logger.error("Max retries reached, using cached data")
return None
response.raise_for_status()
return response
except Exception as e:
logger.error(f"Request failed (attempt {attempt + 1}): {e}")
if attempt < self.max_retries - 1:
time.sleep(5 * (attempt + 1))
return None
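The backoff schedule doubles the wait on each failed attempt. With assumed defaults of retry_delay = 5 and max_retries = 3 (not shown in this diff), the waits before giving up are:

retry_delay, max_retries = 5, 3
waits = [retry_delay * (2 ** attempt) for attempt in range(max_retries - 1)]
print(waits)   # [5, 10] seconds, then the third attempt fails without another retry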

View File

@@ -25,7 +25,6 @@ import json
# Import checkpoint management
import torch
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
-from utils.training_integration import get_training_integration
logger = logging.getLogger(__name__)
@@ -73,7 +72,7 @@ class ExtremaTrainer:
# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
-self.training_integration = get_training_integration() if enable_checkpoints else None
+self.training_integration = None # Removed dependency on utils.training_integration
self.training_session_count = 0
self.best_detection_accuracy = 0.0
self.checkpoint_frequency = 50 # Save checkpoint every 50 training sessions

View File

@@ -22,7 +22,6 @@ import pandas as pd
# Import checkpoint management
import torch
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
-from utils.training_integration import get_training_integration
logger = logging.getLogger(__name__)
@@ -84,7 +83,7 @@ class NegativeCaseTrainer:
# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
-self.training_integration = get_training_integration() if enable_checkpoints else None
+self.training_integration = None # Removed dependency on utils.training_integration
self.training_session_count = 0
self.best_loss_reduction = 0.0
self.checkpoint_frequency = 25 # Save checkpoint every 25 training sessions

View File

@@ -1757,16 +1757,27 @@ class TradingOrchestrator:
logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled")
self.training_enabled = False
return
-# Initialize unified training manager
-from utils.training_integration import get_unified_training_manager
-self.training_manager = get_unified_training_manager(
+# Initialize enhanced training system directly (no external training_integration module needed)
+try:
+from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem
+self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
orchestrator=self,
data_provider=self.data_provider,
dashboard=None
)
-self.training_manager.initialize()
-# Keep backward-compatible attribute
-self.enhanced_training_system = getattr(self.training_manager, 'training_system', None)
+logger.info("✅ Enhanced training system initialized successfully")
+# Auto-start training by default
+logger.info("🚀 Auto-starting enhanced real-time training...")
+self.start_enhanced_training()
+except ImportError as e:
+logger.error(f"Failed to import EnhancedRealtimeTrainingSystem: {e}")
+self.training_enabled = False
+return
logger.info("Enhanced real-time training system initialized")
logger.info(" - Real-time model training: ENABLED")
@@ -2365,8 +2376,8 @@ class TradingOrchestrator:
logger.info("Initializing ExtremaTrainer with historical context...")
self.extrema_trainer.initialize_context_data()
-# CRITICAL: Initialize ALL models with historical data
-self._initialize_models_with_historical_data(loaded_data)
+# CRITICAL: Initialize ALL models with historical data (using data provider's normalized methods)
+self._initialize_models_with_historical_data(symbols_timeframes)
logger.info(f"🎯 Historical data loading complete: {total_candles} total candles loaded")
logger.info(f"📊 Available datasets: {list(loaded_data.keys())}")
@@ -2374,144 +2385,60 @@ class TradingOrchestrator:
except Exception as e:
logger.error(f"Error in historical data loading: {e}")
def _initialize_models_with_historical_data(self, loaded_data: Dict[str, Any]):
"""Initialize all NN models with historical data and multi-symbol support"""
def _initialize_models_with_historical_data(self, symbols_timeframes: List[Tuple[str, str]]):
"""Initialize all NN models with historical data using data provider's normalized methods"""
try:
logger.info("Initializing models with historical data and multi-symbol support...")
logger.info("Initializing models with normalized historical data from data provider...")
# Prepare multi-symbol feature matrices
symbol_features = self._prepare_multi_symbol_features(loaded_data)
# Use data provider's multi-symbol feature preparation
symbol_features = self.data_provider.get_multi_symbol_features_for_inference(symbols_timeframes, limit=300)
# Initialize CNN with multi-symbol data
if hasattr(self, 'cnn_model') and self.cnn_model:
logger.info("Initializing CNN with multi-symbol historical features...")
self._initialize_cnn_with_data(symbol_features)
self._initialize_cnn_with_provider_data()
# Initialize DQN with multi-symbol states
if hasattr(self, 'rl_agent') and self.rl_agent:
logger.info("Initializing DQN with multi-symbol state vectors...")
self._initialize_dqn_with_data(symbol_features)
self._initialize_dqn_with_provider_data(symbols_timeframes)
# Initialize Transformer with sequence data
if hasattr(self, 'transformer_model') and self.transformer_model:
logger.info("Initializing Transformer with multi-symbol sequences...")
self._initialize_transformer_with_data(symbol_features)
self._initialize_transformer_with_provider_data(symbols_timeframes)
# Initialize Decision Fusion with comprehensive features
if hasattr(self, 'decision_fusion') and self.decision_fusion:
logger.info("Initializing Decision Fusion with multi-symbol features...")
self._initialize_decision_with_data(symbol_features)
self._initialize_decision_with_provider_data(symbol_features)
logger.info("✅ All models initialized with historical multi-symbol data")
logger.info("✅ All models initialized with data provider's normalized historical data")
except Exception as e:
logger.error(f"Error initializing models with historical data: {e}")
def _prepare_multi_symbol_features(self, loaded_data: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare normalized multi-symbol feature matrices"""
try:
symbol_features = {
'ETH/USDT': {'1m': None, '1h': None, '1d': None},
'BTC/USDT': {'1m': None}
}
# Process each symbol's data with symbol-specific normalization
for data_key, df in loaded_data.items():
if df is None or df.empty:
continue
# Extract symbol and timeframe
if '_1m' in data_key:
symbol = data_key.replace('_1m', '')
timeframe = '1m'
elif '_1h' in data_key:
symbol = data_key.replace('_1h', '')
timeframe = '1h'
elif '_1d' in data_key:
symbol = data_key.replace('_1d', '')
timeframe = '1d'
else:
continue
# Apply symbol-grouped normalization
normalized_df = self._apply_symbol_grouped_normalization(df, symbol)
if normalized_df is not None:
symbol_features[symbol][timeframe] = normalized_df
logger.debug(f"Prepared normalized features for {symbol} {timeframe}")
return symbol_features
except Exception as e:
logger.error(f"Error preparing multi-symbol features: {e}")
return {}
def _apply_symbol_grouped_normalization(self, df: pd.DataFrame, symbol: str) -> pd.DataFrame:
"""Apply symbol-grouped normalization with consistent ranges across timeframes"""
try:
df_norm = df.copy()
# Get symbol-specific price ranges for consistent normalization
symbol_price_ranges = {
'ETH/USDT': {'min': 1000, 'max': 5000}, # ETH price range
'BTC/USDT': {'min': 90000, 'max': 120000} # BTC price range
}
if symbol in symbol_price_ranges:
price_range = symbol_price_ranges[symbol]
range_size = price_range['max'] - price_range['min']
# Normalize price columns to [0, 1] range specific to symbol
price_cols = ['open', 'high', 'low', 'close']
for col in price_cols:
if col in df_norm.columns:
df_norm[col] = (df_norm[col] - price_range['min']) / range_size
df_norm[col] = np.clip(df_norm[col], 0, 1) # Ensure [0,1] range
# Normalize volume to [0, 1] using log scale
if 'volume' in df_norm.columns:
df_norm['volume'] = np.log1p(df_norm['volume'])
vol_max = df_norm['volume'].max()
if vol_max > 0:
df_norm['volume'] = df_norm['volume'] / vol_max
logger.debug(f"Applied symbol-grouped normalization for {symbol}")
# Fill any NaN values
df_norm = df_norm.fillna(0)
return df_norm
except Exception as e:
logger.error(f"Error in symbol-grouped normalization for {symbol}: {e}")
return df
def _initialize_cnn_with_data(self, symbol_features: Dict[str, Any]):
"""Initialize CNN with multi-symbol feature matrix"""
def _initialize_cnn_with_provider_data(self):
"""Initialize CNN using data provider's normalized feature extraction"""
try:
# Create combined feature matrix: [ETH_1m, ETH_1h, ETH_1d, BTC_1m]
combined_features = []
# ETH features (1m, 1h, 1d)
for timeframe in ['1m', '1h', '1d']:
eth_data = symbol_features.get('ETH/USDT', {}).get(timeframe)
if eth_data is not None and not eth_data.empty:
# Use last 60 candles for CNN input
recent_data = eth_data.tail(60)
features = recent_data[['open', 'high', 'low', 'close', 'volume']].values
combined_features.append(features.flatten())
features = self.data_provider.get_cnn_features_for_inference('ETH/USDT', timeframe, window_size=60)
if features is not None:
combined_features.append(features)
# BTC features (1m)
btc_data = symbol_features.get('BTC/USDT', {}).get('1m')
if btc_data is not None and not btc_data.empty:
recent_data = btc_data.tail(60)
features = recent_data[['open', 'high', 'low', 'close', 'volume']].values
combined_features.append(features.flatten())
btc_features = self.data_provider.get_cnn_features_for_inference('BTC/USDT', '1m', window_size=60)
if btc_features is not None:
combined_features.append(btc_features)
if combined_features:
# Concatenate all features
full_features = np.concatenate(combined_features)
logger.info(f"CNN initialized with {len(full_features)} multi-symbol features")
logger.info(f"CNN initialized with {len(full_features)} multi-symbol normalized features")
# Store for model access
if not hasattr(self, 'model_historical_features'):
@@ -2519,39 +2446,16 @@ class TradingOrchestrator:
self.model_historical_features['cnn'] = full_features
except Exception as e:
logger.error(f"Error initializing CNN with historical data: {e}")
logger.error(f"Error initializing CNN with provider data: {e}")
def _initialize_dqn_with_data(self, symbol_features: Dict[str, Any]):
"""Initialize DQN with multi-symbol state vectors"""
def _initialize_dqn_with_provider_data(self, symbols_timeframes: List[Tuple[str, str]]):
"""Initialize DQN using data provider's normalized state vector creation"""
try:
# Create comprehensive state vector combining all symbols and timeframes
state_components = []
# Use data provider's DQN state creation
state_vector = self.data_provider.get_dqn_state_for_inference(symbols_timeframes, target_size=100)
for symbol in ['ETH/USDT', 'BTC/USDT']:
timeframes = ['1m', '1h', '1d'] if symbol == 'ETH/USDT' else ['1m']
for timeframe in timeframes:
data = symbol_features.get(symbol, {}).get(timeframe)
if data is not None and not data.empty:
# Extract key features for state
latest = data.iloc[-1]
state_features = [
latest['close'], # Current price
latest['volume'], # Current volume
data['close'].pct_change().iloc[-1] if len(data) > 1 else 0, # Price change
]
state_components.extend(state_features)
if state_components:
# Pad or truncate to expected DQN state size
target_size = 100 # DQN expects 100-dimensional state
if len(state_components) < target_size:
state_components.extend([0] * (target_size - len(state_components)))
else:
state_components = state_components[:target_size]
state_vector = np.array(state_components, dtype=np.float32)
logger.info(f"DQN initialized with {len(state_vector)} dimensional multi-symbol state")
if state_vector is not None:
logger.info(f"DQN initialized with {len(state_vector)} dimensional normalized multi-symbol state")
# Store for model access
if not hasattr(self, 'model_historical_features'):
@@ -2559,30 +2463,16 @@ class TradingOrchestrator:
self.model_historical_features['dqn'] = state_vector
except Exception as e:
logger.error(f"Error initializing DQN with historical data: {e}")
logger.error(f"Error initializing DQN with provider data: {e}")
def _initialize_transformer_with_data(self, symbol_features: Dict[str, Any]):
"""Initialize Transformer with multi-symbol sequence data"""
def _initialize_transformer_with_provider_data(self, symbols_timeframes: List[Tuple[str, str]]):
"""Initialize Transformer using data provider's normalized sequence creation"""
try:
# Prepare sequence data for transformer
sequences = []
# ETH sequences
for timeframe in ['1m', '1h', '1d']:
eth_data = symbol_features.get('ETH/USDT', {}).get(timeframe)
if eth_data is not None and not eth_data.empty:
# Use last 150 points as sequence
sequence = eth_data.tail(150)[['open', 'high', 'low', 'close', 'volume']].values
sequences.append(sequence)
# BTC sequence
btc_data = symbol_features.get('BTC/USDT', {}).get('1m')
if btc_data is not None and not btc_data.empty:
sequence = btc_data.tail(150)[['open', 'high', 'low', 'close', 'volume']].values
sequences.append(sequence)
# Use data provider's transformer sequence creation
sequences = self.data_provider.get_transformer_sequences_for_inference(symbols_timeframes, seq_length=150)
if sequences:
logger.info(f"Transformer initialized with {len(sequences)} multi-symbol sequences")
logger.info(f"Transformer initialized with {len(sequences)} normalized multi-symbol sequences")
# Store for model access
if not hasattr(self, 'model_historical_features'):
@@ -2590,10 +2480,10 @@ class TradingOrchestrator:
self.model_historical_features['transformer'] = sequences
except Exception as e:
logger.error(f"Error initializing Transformer with historical data: {e}")
logger.error(f"Error initializing Transformer with provider data: {e}")
def _initialize_decision_with_data(self, symbol_features: Dict[str, Any]):
"""Initialize Decision Fusion with comprehensive multi-symbol features"""
def _initialize_decision_with_provider_data(self, symbol_features: Dict[str, Dict[str, pd.DataFrame]]):
"""Initialize Decision Fusion using data provider's feature aggregation"""
try:
# Aggregate all available features for decision fusion
all_features = {}
@@ -2611,7 +2501,7 @@ class TradingOrchestrator:
}
if all_features:
logger.info(f"Decision Fusion initialized with {len(all_features)} symbol-timeframe combinations")
logger.info(f"Decision Fusion initialized with {len(all_features)} normalized symbol-timeframe combinations")
# Store for model access
if not hasattr(self, 'model_historical_features'):
@@ -2619,7 +2509,7 @@ class TradingOrchestrator:
self.model_historical_features['decision'] = all_features
except Exception as e:
logger.error(f"Error initializing Decision Fusion with historical data: {e}")
logger.error(f"Error initializing Decision Fusion with provider data: {e}")
def get_ohlcv_data(self, symbol: str, timeframe: str, limit: int = 300) -> List:
"""Get OHLCV data for a symbol with specified timeframe and limit."""

View File

@@ -926,33 +926,44 @@ class CleanTradingDashboard:
return html.P(f"Error: {str(e)}", className="text-danger")
@self.app.callback(
-Output('training-status', 'children'),
+[Output('training-status', 'children'),
+Output('training-status', 'className')],
[Input('start-training-btn', 'n_clicks'),
-Input('stop-training-btn', 'n_clicks')],
-prevent_initial_call=True
+Input('stop-training-btn', 'n_clicks'),
+Input('interval-component', 'n_intervals')], # Auto-update on interval
+prevent_initial_call=False # Allow initial call to set status
)
-def control_training(start_clicks, stop_clicks):
+def control_training(start_clicks, stop_clicks, n_intervals):
try:
-from utils.training_integration import get_unified_training_manager
-manager = get_unified_training_manager(
-orchestrator=self.orchestrator,
-data_provider=self.data_provider,
-dashboard=self
-)
+# Use orchestrator's enhanced training system directly
+if not hasattr(self.orchestrator, 'enhanced_training_system') or not self.orchestrator.enhanced_training_system:
+return "Not Available", "badge bg-danger small"
ctx = dash.callback_context
-if not ctx.triggered:
-raise PreventUpdate
+# Check if this is triggered by button clicks
+if ctx.triggered:
trigger_id = ctx.triggered[0]['prop_id'].split('.')[0]
if trigger_id == 'start-training-btn':
-ok = manager.start()
-return 'Running' if ok else 'Error'
+self.orchestrator.start_enhanced_training()
+return 'Running', 'badge bg-success small'
elif trigger_id == 'stop-training-btn':
-ok = manager.stop()
-return 'Stopped' if ok else 'Error'
-return 'Idle'
+self.orchestrator.stop_enhanced_training()
+return 'Stopped', 'badge bg-warning small'
+# Auto-update: Check actual training status
+if hasattr(self.orchestrator.enhanced_training_system, 'is_training'):
+if self.orchestrator.enhanced_training_system.is_training:
+return 'Running', 'badge bg-success small'
+else:
+return 'Idle', 'badge bg-secondary small'
+else:
+# Default to Running since training auto-starts
+return 'Running', 'badge bg-success small'
except Exception as e:
logger.error(f"Training control error: {e}")
return 'Error'
logger.error(f"Training status error: {e}")
return 'Error', 'badge bg-danger small'
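The callback now also fires on 'interval-component', so the layout needs a matching dcc.Interval. A minimal sketch, assuming a 5-second refresh (the actual interval value is not shown in this diff):

from dash import dcc

# Assumed to exist somewhere in the layout; the callback above polls it to refresh the badge.
interval = dcc.Interval(id="interval-component", interval=5_000, n_intervals=0)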
@self.app.callback(
[Output('eth-cob-content', 'children'),

View File

@@ -173,7 +173,7 @@ class DashboardLayoutManager:
], className="d-flex align-items-center mb-1"),
html.Div([
html.Span("Training:", className="small me-1"),
html.Span(id="training-status", children="Idle", className="badge bg-secondary small")
html.Span(id="training-status", children="Starting...", className="badge bg-primary small")
])
], className="mb-2"),