training wip
@@ -21,7 +21,6 @@ from typing import Dict, Any, Optional, Tuple

# Import checkpoint management
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.training_integration import get_training_integration

# Configure logging
logger = logging.getLogger(__name__)

@@ -522,7 +521,7 @@ class CNNModelTrainer:

# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
self.training_integration = get_training_integration() if enable_checkpoints else None
self.training_integration = None  # Removed dependency on utils.training_integration
self.epoch_count = 0
self.best_val_accuracy = 0.0
self.best_val_loss = float('inf')

@@ -16,7 +16,6 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(

# Import checkpoint management
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.training_integration import get_training_integration

# Configure logger
logger = logging.getLogger(__name__)

@@ -44,7 +43,7 @@ class DQNAgent:

# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
self.training_integration = get_training_integration() if enable_checkpoints else None
self.training_integration = None  # Removed dependency on utils.training_integration
self.episode_count = 0
self.best_reward = float('-inf')
self.reward_history = deque(maxlen=100)
@@ -1802,604 +1802,177 @@ class DataProvider:

logger.debug(f"Applied pivot-based normalization for {symbol}")

else:
# Fallback to traditional normalization when pivot bounds not available
logger.debug("Using traditional normalization (no pivot bounds available)")

for col in df_norm.columns:
if col in ['open', 'high', 'low', 'close', 'sma_10', 'sma_20', 'sma_50',
'ema_12', 'ema_26', 'ema_50', 'bb_upper', 'bb_lower', 'bb_middle',
'keltner_upper', 'keltner_lower', 'keltner_middle', 'psar', 'vwap']:
# Price-based indicators: normalize by close price
if 'close' in df_norm.columns:
base_price = df_norm['close'].iloc[-1]  # Use latest close as reference
if base_price > 0:
df_norm[col] = df_norm[col] / base_price

elif col == 'volume':
# Volume: normalize by its own rolling mean
volume_mean = df_norm[col].rolling(window=min(20, len(df_norm))).mean().iloc[-1]
if volume_mean > 0:
df_norm[col] = df_norm[col] / volume_mean

# Normalize indicators that have standard ranges (regardless of pivot bounds)
for col in df_norm.columns:
if col in ['rsi_14', 'rsi_7', 'rsi_21']:
# RSI: already 0-100, normalize to 0-1
df_norm[col] = df_norm[col] / 100.0

elif col in ['stoch_k', 'stoch_d']:
# Stochastic: already 0-100, normalize to 0-1
df_norm[col] = df_norm[col] / 100.0

elif col == 'williams_r':
# Williams %R: -100 to 0, normalize to 0-1
df_norm[col] = (df_norm[col] + 100) / 100.0

elif col in ['macd', 'macd_signal', 'macd_histogram']:
# MACD: normalize by ATR or close price
if 'atr' in df_norm.columns and df_norm['atr'].iloc[-1] > 0:
df_norm[col] = df_norm[col] / df_norm['atr'].iloc[-1]
elif 'close' in df_norm.columns and df_norm['close'].iloc[-1] > 0:
df_norm[col] = df_norm[col] / df_norm['close'].iloc[-1]

elif col in ['bb_width', 'bb_percent', 'price_position', 'trend_strength',
'momentum_composite', 'volatility_regime', 'pivot_price_position',
'pivot_support_distance', 'pivot_resistance_distance']:
# Already normalized indicators: ensure 0-1 range
df_norm[col] = np.clip(df_norm[col], 0, 1)

elif col in ['atr', 'true_range']:
# Volatility indicators: normalize by close price or pivot range
if symbol and symbol in self.pivot_bounds:
bounds = self.pivot_bounds[symbol]
df_norm[col] = df_norm[col] / bounds.get_price_range()
elif 'close' in df_norm.columns and df_norm['close'].iloc[-1] > 0:
df_norm[col] = df_norm[col] / df_norm['close'].iloc[-1]

elif col not in ['timestamp', 'near_pivot_support', 'near_pivot_resistance']:
# Other indicators: z-score normalization
col_mean = df_norm[col].rolling(window=min(20, len(df_norm))).mean().iloc[-1]
col_std = df_norm[col].rolling(window=min(20, len(df_norm))).std().iloc[-1]
if col_std > 0:
df_norm[col] = (df_norm[col] - col_mean) / col_std
else:
df_norm[col] = 0

# Replace inf/-inf with 0
df_norm = df_norm.replace([np.inf, -np.inf], 0)
# Use symbol-grouped normalization with consistent ranges
df_norm = self._apply_symbol_grouped_normalization(df_norm, symbol)

# Fill any remaining NaN values
df_norm = df_norm.fillna(0.0)

return df_norm

except Exception as e:
logger.error(f"Error normalizing features for {symbol}: {e}")
return df.fillna(0.0) if df is not None else None

def _apply_symbol_grouped_normalization(self, df: pd.DataFrame, symbol: str) -> pd.DataFrame:
"""Apply symbol-grouped normalization with consistent ranges across timeframes"""
try:
df_norm = df.copy()

# Get symbol-specific price ranges for consistent normalization
symbol_price_ranges = {
'ETH/USDT': {'min': 1000, 'max': 5000},   # ETH price range
'BTC/USDT': {'min': 90000, 'max': 120000} # BTC price range
}

if symbol in symbol_price_ranges:
price_range = symbol_price_ranges[symbol]
range_size = price_range['max'] - price_range['min']

# Normalize price columns to [0, 1] range specific to symbol
price_cols = ['open', 'high', 'low', 'close']
for col in price_cols:
if col in df_norm.columns:
df_norm[col] = (df_norm[col] - price_range['min']) / range_size
df_norm[col] = np.clip(df_norm[col], 0, 1)  # Ensure [0,1] range

# Normalize volume to [0, 1] using log scale
if 'volume' in df_norm.columns:
df_norm['volume'] = np.log1p(df_norm['volume'])
vol_max = df_norm['volume'].max()
if vol_max > 0:
df_norm['volume'] = df_norm['volume'] / vol_max

logger.debug(f"Applied symbol-grouped normalization for {symbol}")

# Fill any NaN values
df_norm = df_norm.fillna(0)

return df_norm

except Exception as e:
logger.error(f"Error normalizing features: {e}")
logger.error(f"Error in symbol-grouped normalization for {symbol}: {e}")
return df
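Note (not part of the diff): a minimal standalone sketch of the scalings used above, with hypothetical input values, to make the normalization targets concrete:

import numpy as np

# Hypothetical ETH close of 3000 inside the assumed 1000-5000 range -> 0.5
close_norm = (3000 - 1000) / (5000 - 1000)

# RSI 65 -> 0.65; Williams %R of -20 -> (-20 + 100) / 100 = 0.8
rsi_norm = 65 / 100.0
williams_norm = (-20 + 100) / 100.0

# Volume: log1p then divide by the window max, keeping values in [0, 1]
vols = np.array([120.0, 80.0, 250.0])
vols_norm = np.log1p(vols) / np.log1p(vols).max()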
def get_multi_symbol_feature_matrix(self, symbols: List[str] = None,
timeframes: List[str] = None,
window_size: int = 20) -> Optional[np.ndarray]:
"""
Get feature matrix for multiple symbols and timeframes

Returns:
np.ndarray: Shape (n_symbols, n_timeframes, window_size, n_features)
"""

def get_historical_data_for_inference(self, symbol: str, timeframe: str, limit: int = 300) -> Optional[pd.DataFrame]:
"""Get normalized historical data specifically for model inference"""
try:
if symbols is None:
symbols = self.symbols
if timeframes is None:
timeframes = self.timeframes
# Get raw historical data
raw_df = self.get_historical_data(symbol, timeframe, limit)

symbol_matrices = []
if raw_df is None or raw_df.empty:
return None

for symbol in symbols:
symbol_matrix = self.get_feature_matrix(symbol, timeframes, window_size)
if symbol_matrix is not None:
symbol_matrices.append(symbol_matrix)
# Apply normalization for inference
normalized_df = self._normalize_features(raw_df, symbol)

logger.debug(f"Retrieved normalized historical data for inference: {symbol} {timeframe} ({len(normalized_df)} records)")
return normalized_df

except Exception as e:
logger.error(f"Error getting normalized historical data for inference: {symbol} {timeframe}: {e}")
return None

def get_multi_symbol_features_for_inference(self, symbols_timeframes: List[Tuple[str, str]], limit: int = 300) -> Dict[str, Dict[str, pd.DataFrame]]:
"""Get normalized multi-symbol feature matrices for model inference"""
try:
logger.info("Preparing normalized multi-symbol features for model inference...")

symbol_features = {}

for symbol, timeframe in symbols_timeframes:
if symbol not in symbol_features:
symbol_features[symbol] = {}

# Get normalized data for inference
normalized_df = self.get_historical_data_for_inference(symbol, timeframe, limit)

if normalized_df is not None and not normalized_df.empty:
symbol_features[symbol][timeframe] = normalized_df
logger.debug(f"Prepared normalized features for {symbol} {timeframe}: {len(normalized_df)} records")
else:
logger.warning(f"Could not create feature matrix for {symbol}")
logger.warning(f"No normalized data available for {symbol} {timeframe}")
symbol_features[symbol][timeframe] = None

if symbol_matrices:
# Stack all symbol matrices
multi_symbol_matrix = np.stack(symbol_matrices, axis=0)
logger.info(f"Created multi-symbol feature matrix: {multi_symbol_matrix.shape}")
return multi_symbol_matrix

return None
return symbol_features

except Exception as e:
logger.error(f"Error creating multi-symbol feature matrix: {e}")
return None

def health_check(self) -> Dict[str, Any]:
"""Get health status of the data provider"""
status = {
'streaming': self.is_streaming,
'symbols': len(self.symbols),
'timeframes': len(self.timeframes),
'current_prices': len(self.current_prices),
'websocket_tasks': len(self.websocket_tasks),
'historical_data_loaded': {}
}

# Check historical data availability
for symbol in self.symbols:
status['historical_data_loaded'][symbol] = {}
for tf in self.timeframes:
has_data = (symbol in self.historical_data and
tf in self.historical_data[symbol] and
not self.historical_data[symbol][tf].empty)
status['historical_data_loaded'][symbol][tf] = has_data

return status

def subscribe_to_ticks(self, callback: Callable[[MarketTick], None],
symbols: List[str] = None,
subscriber_name: str = None) -> str:
"""Subscribe to real-time tick data updates"""
subscriber_id = str(uuid.uuid4())[:8]
subscriber_name = subscriber_name or f"subscriber_{subscriber_id}"

# Convert symbols to Binance format
if symbols:
binance_symbols = [s.replace('/', '').upper() for s in symbols]
else:
binance_symbols = [s.replace('/', '').upper() for s in self.symbols]

subscriber = DataSubscriber(
subscriber_id=subscriber_id,
callback=callback,
symbols=binance_symbols,
subscriber_name=subscriber_name
)

with self.subscriber_lock:
self.subscribers[subscriber_id] = subscriber

logger.info(f"New tick subscriber registered: {subscriber_name} ({subscriber_id}) for symbols: {binance_symbols}")

# Send recent tick data to new subscriber
self._send_recent_ticks_to_subscriber(subscriber)

return subscriber_id
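Note (not part of the diff): a minimal sketch of how a consumer might register a tick callback through the subscription API above; the MarketTick fields and the DataProvider constructor arguments are assumptions for illustration only:

def on_tick(tick):
    # assumes MarketTick exposes symbol and price attributes
    print(tick.symbol, tick.price)

provider = DataProvider()  # constructor arguments assumed
sub_id = provider.subscribe_to_ticks(on_tick, symbols=['ETH/USDT'], subscriber_name='console_logger')
# ... later, stop receiving updates
provider.unsubscribe_from_ticks(sub_id)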
def unsubscribe_from_ticks(self, subscriber_id: str):
"""Unsubscribe from tick data updates"""
with self.subscriber_lock:
if subscriber_id in self.subscribers:
subscriber_name = self.subscribers[subscriber_id].subscriber_name
self.subscribers[subscriber_id].active = False
del self.subscribers[subscriber_id]
logger.info(f"Subscriber {subscriber_name} ({subscriber_id}) unsubscribed")

def _send_recent_ticks_to_subscriber(self, subscriber: DataSubscriber):
"""Send recent tick data to a new subscriber"""
logger.error(f"Error preparing multi-symbol features for inference: {e}")
return {}

def get_cnn_features_for_inference(self, symbol: str, timeframe: str = '1m', window_size: int = 60) -> Optional[np.ndarray]:
"""Get normalized CNN features for a specific symbol and timeframe"""
try:
for symbol in subscriber.symbols:
if symbol in self.tick_buffers:
# Send last 50 ticks to get subscriber up to speed
recent_ticks = list(self.tick_buffers[symbol])[-50:]
for tick in recent_ticks:
try:
subscriber.callback(tick)
except Exception as e:
logger.warning(f"Error sending recent tick to subscriber {subscriber.subscriber_id}: {e}")
except Exception as e:
logger.error(f"Error sending recent ticks: {e}")

def _distribute_tick(self, tick: MarketTick):
"""Distribute tick to all relevant subscribers"""
distributed_count = 0

with self.subscriber_lock:
subscribers_to_remove = []
# Get normalized data
df = self.get_historical_data_for_inference(symbol, timeframe, limit=300)

for subscriber_id, subscriber in self.subscribers.items():
if not subscriber.active:
subscribers_to_remove.append(subscriber_id)
continue
if df is None or df.empty:
return None

# Extract recent window for CNN
recent_data = df.tail(window_size)

# Extract OHLCV features
features = recent_data[['open', 'high', 'low', 'close', 'volume']].values

logger.debug(f"Extracted CNN features for {symbol} {timeframe}: {features.shape}")
return features.flatten()

except Exception as e:
logger.error(f"Error extracting CNN features for {symbol} {timeframe}: {e}")
return None

def get_dqn_state_for_inference(self, symbols_timeframes: List[Tuple[str, str]], target_size: int = 100) -> Optional[np.ndarray]:
"""Get normalized DQN state vector combining multiple symbols and timeframes"""
try:
state_components = []

for symbol, timeframe in symbols_timeframes:
df = self.get_historical_data_for_inference(symbol, timeframe, limit=50)

if tick.symbol in subscriber.symbols:
try:
# Call subscriber callback in a thread to avoid blocking
def call_callback():
try:
subscriber.callback(tick)
subscriber.tick_count += 1
subscriber.last_update = datetime.now()
except Exception as e:
logger.warning(f"Error in subscriber {subscriber_id} callback: {e}")
subscriber.active = False

# Use thread to avoid blocking the main data processing
Thread(target=call_callback, daemon=True).start()
distributed_count += 1

except Exception as e:
logger.warning(f"Error distributing tick to subscriber {subscriber_id}: {e}")
subscriber.active = False
if df is not None and not df.empty:
# Extract key features for state
latest = df.iloc[-1]
state_features = [
latest['close'],   # Current price (normalized)
latest['volume'],  # Current volume (normalized)
df['close'].pct_change().iloc[-1] if len(df) > 1 else 0,  # Price change
]
state_components.extend(state_features)

# Remove inactive subscribers
for subscriber_id in subscribers_to_remove:
if subscriber_id in self.subscribers:
del self.subscribers[subscriber_id]

self.distribution_stats['total_ticks_distributed'] += distributed_count

def _validate_tick_data(self, symbol: str, price: float, volume: float) -> bool:
"""Validate incoming tick data for quality"""
try:
# Basic validation
if price <= 0 or volume < 0:
return False

# Price change validation
last_price = self.last_prices.get(symbol, 0)
if last_price > 0:
price_change_pct = abs(price - last_price) / last_price
if price_change_pct > self.price_change_threshold:
logger.warning(f"Large price change for {symbol}: {price_change_pct:.2%}")
# Don't reject, just warn - could be legitimate

return True

except Exception as e:
logger.error(f"Error validating tick data: {e}")
return False

def get_recent_ticks(self, symbol: str, count: int = 100) -> List[MarketTick]:
"""Get recent ticks for a symbol"""
binance_symbol = symbol.replace('/', '').upper()
if binance_symbol in self.tick_buffers:
return list(self.tick_buffers[binance_symbol])[-count:]
return []

def subscribe_to_raw_ticks(self, callback: Callable[[RawTick], None]) -> str:
"""Subscribe to raw tick data with timing information"""
subscriber_id = str(uuid.uuid4())
self.raw_tick_callbacks.append(callback)
logger.info(f"Raw tick subscriber added: {subscriber_id}")
return subscriber_id

def subscribe_to_ohlcv_bars(self, callback: Callable[[OHLCVBar], None]) -> str:
"""Subscribe to 1s OHLCV bars calculated from ticks"""
subscriber_id = str(uuid.uuid4())
self.ohlcv_bar_callbacks.append(callback)
logger.info(f"OHLCV bar subscriber added: {subscriber_id}")
return subscriber_id

def get_raw_tick_features(self, symbol: str, window_size: int = 50) -> Optional[np.ndarray]:
"""Get raw tick features for model consumption"""
return self.tick_aggregator.get_tick_features_for_model(symbol, window_size)

def get_ohlcv_features(self, symbol: str, window_size: int = 60) -> Optional[np.ndarray]:
"""Get 1s OHLCV features for model consumption"""
return self.tick_aggregator.get_ohlcv_features_for_model(symbol, window_size)

def get_detected_patterns(self, symbol: str, count: int = 50) -> List:
"""Get recently detected tick patterns"""
return self.tick_aggregator.get_detected_patterns(symbol, count)

def get_tick_aggregator_stats(self) -> Dict[str, Any]:
"""Get tick aggregator statistics"""
return self.tick_aggregator.get_statistics()

def get_subscriber_stats(self) -> Dict[str, Any]:
"""Get subscriber and distribution statistics"""
with self.subscriber_lock:
active_subscribers = len([s for s in self.subscribers.values() if s.active])
subscriber_stats = {
sid: {
'name': s.subscriber_name,
'active': s.active,
'symbols': s.symbols,
'tick_count': s.tick_count,
'last_update': s.last_update.isoformat() if s.last_update else None
}
for sid, s in self.subscribers.items()
}

# Get tick aggregator stats
aggregator_stats = self.get_tick_aggregator_stats()

return {
'active_subscribers': active_subscribers,
'total_subscribers': len(self.subscribers),
'raw_tick_callbacks': len(self.raw_tick_callbacks),
'ohlcv_bar_callbacks': len(self.ohlcv_bar_callbacks),
'subscriber_details': subscriber_stats,
'distribution_stats': self.distribution_stats.copy(),
'buffer_sizes': {symbol: len(buffer) for symbol, buffer in self.tick_buffers.items()},
'tick_aggregator': aggregator_stats
}

def update_bom_cache(self, symbol: str, bom_features: List[float], cob_integration=None):
"""
Update BOM cache with latest features for a symbol

Args:
symbol: Trading symbol (e.g., 'ETH/USDT')
bom_features: List of BOM features (should be 120 features)
cob_integration: Optional COB integration instance for real BOM data
"""
try:
current_time = datetime.now()

# Ensure we have exactly 120 features
if len(bom_features) != self.bom_feature_count:
if len(bom_features) > self.bom_feature_count:
bom_features = bom_features[:self.bom_feature_count]
if state_components:
# Pad or truncate to expected DQN state size
if len(state_components) < target_size:
state_components.extend([0] * (target_size - len(state_components)))
else:
bom_features.extend([0.0] * (self.bom_feature_count - len(bom_features)))

# Convert to numpy array for efficient storage
bom_array = np.array(bom_features, dtype=np.float32)

# Add timestamp and features to cache
with self.data_lock:
self.bom_data_cache[symbol].append((current_time, bom_array))

logger.debug(f"Updated BOM cache for {symbol}: {len(self.bom_data_cache[symbol])} timestamps cached")

except Exception as e:
logger.error(f"Error updating BOM cache for {symbol}: {e}")

def get_bom_matrix_for_cnn(self, symbol: str, sequence_length: int = 50) -> Optional[np.ndarray]:
"""
Get BOM matrix for CNN input from cached 1s data

Args:
symbol: Trading symbol (e.g., 'ETH/USDT')
sequence_length: Required sequence length (default 50)

Returns:
np.ndarray: BOM matrix of shape (sequence_length, 120) or None if insufficient data
"""
try:
with self.data_lock:
if symbol not in self.bom_data_cache or len(self.bom_data_cache[symbol]) == 0:
logger.warning(f"No BOM data cached for {symbol}")
return None
state_components = state_components[:target_size]

# Get recent data
cached_data = list(self.bom_data_cache[symbol])

if len(cached_data) < sequence_length:
logger.warning(f"Insufficient BOM data for {symbol}: {len(cached_data)} < {sequence_length}")
# Pad with zeros if we don't have enough data
bom_matrix = np.zeros((sequence_length, self.bom_feature_count), dtype=np.float32)

# Fill available data at the end
for i, (timestamp, features) in enumerate(cached_data):
if i < sequence_length:
bom_matrix[sequence_length - len(cached_data) + i] = features

return bom_matrix

# Take the most recent sequence_length samples
recent_data = cached_data[-sequence_length:]

# Create matrix
bom_matrix = np.zeros((sequence_length, self.bom_feature_count), dtype=np.float32)
for i, (timestamp, features) in enumerate(recent_data):
bom_matrix[i] = features

logger.debug(f"Retrieved BOM matrix for {symbol}: shape={bom_matrix.shape}")
return bom_matrix

except Exception as e:
logger.error(f"Error getting BOM matrix for {symbol}: {e}")
return None
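Note (not part of the diff): the zero-padding branch in get_bom_matrix_for_cnn fills the tail of a fixed-size matrix with whatever history exists; a small sketch with hypothetical sizes:

import numpy as np

seq_len, n_feat = 50, 120            # matches the defaults above
cached = [np.random.rand(n_feat).astype(np.float32) for _ in range(30)]  # only 30 rows cached

bom = np.zeros((seq_len, n_feat), dtype=np.float32)
bom[seq_len - len(cached):] = np.stack(cached)   # rows 0-19 stay zero, rows 20-49 hold the data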
def get_real_bom_features(self, symbol: str) -> Optional[List[float]]:
"""
Get REAL BOM features from actual market data ONLY

NO SYNTHETIC DATA - Returns None if real data is not available
"""
try:
# Try to get real COB data from integration
if hasattr(self, 'cob_integration') and self.cob_integration:
return self._extract_real_bom_features(symbol, self.cob_integration)
state_vector = np.array(state_components, dtype=np.float32)
logger.debug(f"Created DQN state vector: {len(state_vector)} dimensions")
return state_vector

# No real data available - return None instead of synthetic
logger.warning(f"No real BOM data available for {symbol} - waiting for real market data")
return None

except Exception as e:
logger.error(f"Error getting real BOM features for {symbol}: {e}")
logger.error(f"Error creating DQN state for inference: {e}")
return None

def start_bom_cache_updates(self, cob_integration=None):
"""
Start background updates of BOM cache every second

Args:
cob_integration: Optional COB integration instance for real data
"""

def get_transformer_sequences_for_inference(self, symbols_timeframes: List[Tuple[str, str]], seq_length: int = 150) -> List[np.ndarray]:
"""Get normalized sequences for transformer inference"""
try:
def update_loop():
while self.is_streaming:
try:
for symbol in self.symbols:
if cob_integration:
# Try to get real BOM features from COB integration
try:
bom_features = self._extract_real_bom_features(symbol, cob_integration)
if bom_features:
self.update_bom_cache(symbol, bom_features, cob_integration)
else:
# NO SYNTHETIC FALLBACK - Wait for real data
logger.warning(f"No real BOM features available for {symbol} - waiting for real data")
except Exception as e:
logger.warning(f"Error getting real BOM features for {symbol}: {e}")
logger.warning(f"Waiting for real data instead of using synthetic")
else:
# NO SYNTHETIC FEATURES - Wait for real COB integration
logger.warning(f"No COB integration available for {symbol} - waiting for real data")

time.sleep(1.0)  # Update every second

except Exception as e:
logger.error(f"Error in BOM cache update loop: {e}")
time.sleep(5.0)  # Wait longer on error
sequences = []

# Start background thread
bom_thread = Thread(target=update_loop, daemon=True)
bom_thread.start()
for symbol, timeframe in symbols_timeframes:
df = self.get_historical_data_for_inference(symbol, timeframe, limit=300)

if df is not None and not df.empty:
# Use last seq_length points as sequence
sequence = df.tail(seq_length)[['open', 'high', 'low', 'close', 'volume']].values
sequences.append(sequence)
logger.debug(f"Created transformer sequence for {symbol} {timeframe}: {sequence.shape}")

logger.info("Started BOM cache updates (1s resolution)")
return sequences

except Exception as e:
logger.error(f"Error starting BOM cache updates: {e}")

def _extract_real_bom_features(self, symbol: str, cob_integration) -> Optional[List[float]]:
"""Extract real BOM features from COB integration"""
try:
features = []

# Get consolidated order book
if hasattr(cob_integration, 'get_consolidated_orderbook'):
cob_snapshot = cob_integration.get_consolidated_orderbook(symbol)
if cob_snapshot:
# Extract order book features (40 features)
features.extend(self._extract_orderbook_features(cob_snapshot))
else:
features.extend([0.0] * 40)
else:
features.extend([0.0] * 40)

# Get volume profile features (30 features)
if hasattr(cob_integration, 'get_session_volume_profile'):
volume_profile = cob_integration.get_session_volume_profile(symbol)
if volume_profile:
features.extend(self._extract_volume_profile_features(volume_profile))
else:
features.extend([0.0] * 30)
else:
features.extend([0.0] * 30)

# Add flow and microstructure features (50 features)
features.extend(self._extract_flow_microstructure_features(symbol, cob_integration))

# Ensure exactly 120 features
if len(features) > 120:
features = features[:120]
elif len(features) < 120:
features.extend([0.0] * (120 - len(features)))

return features

except Exception as e:
logger.warning(f"Error extracting real BOM features for {symbol}: {e}")
return None

def _extract_orderbook_features(self, cob_snapshot) -> List[float]:
"""Extract order book features from COB snapshot"""
features = []

try:
# Top 10 bid levels
for i in range(10):
if i < len(cob_snapshot.consolidated_bids):
level = cob_snapshot.consolidated_bids[i]
price_offset = (level.price - cob_snapshot.volume_weighted_mid) / cob_snapshot.volume_weighted_mid
volume_normalized = level.total_volume_usd / 1000000
features.extend([price_offset, volume_normalized])
else:
features.extend([0.0, 0.0])

# Top 10 ask levels
for i in range(10):
if i < len(cob_snapshot.consolidated_asks):
level = cob_snapshot.consolidated_asks[i]
price_offset = (level.price - cob_snapshot.volume_weighted_mid) / cob_snapshot.volume_weighted_mid
volume_normalized = level.total_volume_usd / 1000000
features.extend([price_offset, volume_normalized])
else:
features.extend([0.0, 0.0])

except Exception as e:
logger.warning(f"Error extracting order book features: {e}")
features = [0.0] * 40

return features[:40]

def _extract_volume_profile_features(self, volume_profile) -> List[float]:
"""Extract volume profile features"""
features = []

try:
if 'data' in volume_profile:
svp_data = volume_profile['data']
top_levels = sorted(svp_data, key=lambda x: x.get('total_volume', 0), reverse=True)[:10]

for level in top_levels:
buy_percent = level.get('buy_percent', 50.0) / 100.0
sell_percent = level.get('sell_percent', 50.0) / 100.0
total_volume = level.get('total_volume', 0.0) / 1000000
features.extend([buy_percent, sell_percent, total_volume])

# Pad to 30 features
while len(features) < 30:
features.extend([0.5, 0.5, 0.0])

except Exception as e:
logger.warning(f"Error extracting volume profile features: {e}")
features = [0.0] * 30

return features[:30]

def _extract_flow_microstructure_features(self, symbol: str, cob_integration) -> List[float]:
"""Extract flow and microstructure features"""
try:
# For now, return synthetic features since full implementation would be complex
# NO SYNTHETIC DATA - Return None if no real microstructure data
logger.warning(f"No real microstructure data available for {symbol}")
return None
except:
return [0.0] * 50

def _handle_rate_limit(self, url: str):
"""Handle rate limiting with exponential backoff"""
current_time = time.time()

# Check if we need to wait
if url in self.last_request_time:
time_since_last = current_time - self.last_request_time[url]
if time_since_last < self.request_interval:
sleep_time = self.request_interval - time_since_last
logger.info(f"Rate limiting: sleeping {sleep_time:.2f}s")
time.sleep(sleep_time)

self.last_request_time[url] = time.time()

def _make_request_with_retry(self, url: str, params: dict = None):
"""Make HTTP request with retry logic for 451 errors"""
for attempt in range(self.max_retries):
try:
self._handle_rate_limit(url)
response = requests.get(url, params=params, timeout=30)

if response.status_code == 451:
logger.warning(f"Rate limit hit (451), attempt {attempt + 1}/{self.max_retries}")
if attempt < self.max_retries - 1:
sleep_time = self.retry_delay * (2 ** attempt)  # Exponential backoff
logger.info(f"Waiting {sleep_time}s before retry...")
time.sleep(sleep_time)
continue
else:
logger.error("Max retries reached, using cached data")
return None

response.raise_for_status()
return response

except Exception as e:
logger.error(f"Request failed (attempt {attempt + 1}): {e}")
if attempt < self.max_retries - 1:
time.sleep(5 * (attempt + 1))

return None
logger.error(f"Error creating transformer sequences for inference: {e}")
return []
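Note (not part of the diff): a hedged usage sketch of the retry helper above; the endpoint and parameters are illustrative, and self.max_retries / self.retry_delay are assumed to be set in the provider's __init__:

url = "https://api.binance.com/api/v3/klines"          # illustrative endpoint
params = {"symbol": "ETHUSDT", "interval": "1m", "limit": 300}
response = self._make_request_with_retry(url, params=params)
if response is not None:
    candles = response.json()
else:
    candles = None  # caller falls back to cached data, as the log message suggests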
@@ -25,7 +25,6 @@ import json

# Import checkpoint management
import torch
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.training_integration import get_training_integration

logger = logging.getLogger(__name__)

@@ -73,7 +72,7 @@ class ExtremaTrainer:

# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
self.training_integration = get_training_integration() if enable_checkpoints else None
self.training_integration = None  # Removed dependency on utils.training_integration
self.training_session_count = 0
self.best_detection_accuracy = 0.0
self.checkpoint_frequency = 50  # Save checkpoint every 50 training sessions

@@ -22,7 +22,6 @@ import pandas as pd

# Import checkpoint management
import torch
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.training_integration import get_training_integration

logger = logging.getLogger(__name__)

@@ -84,7 +83,7 @@ class NegativeCaseTrainer:

# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
self.training_integration = get_training_integration() if enable_checkpoints else None
self.training_integration = None  # Removed dependency on utils.training_integration
self.training_session_count = 0
self.best_loss_reduction = 0.0
self.checkpoint_frequency = 25  # Save checkpoint every 25 training sessions
@@ -1757,16 +1757,27 @@ class TradingOrchestrator:

logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled")
self.training_enabled = False
return
# Initialize unified training manager
from utils.training_integration import get_unified_training_manager
self.training_manager = get_unified_training_manager(
orchestrator=self,
data_provider=self.data_provider,
dashboard=None
)
self.training_manager.initialize()
# Keep backward-compatible attribute
self.enhanced_training_system = getattr(self.training_manager, 'training_system', None)

# Initialize enhanced training system directly (no external training_integration module needed)
try:
from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem

self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
orchestrator=self,
data_provider=self.data_provider,
dashboard=None
)

logger.info("✅ Enhanced training system initialized successfully")

# Auto-start training by default
logger.info("🚀 Auto-starting enhanced real-time training...")
self.start_enhanced_training()

except ImportError as e:
logger.error(f"Failed to import EnhancedRealtimeTrainingSystem: {e}")
self.training_enabled = False
return

logger.info("Enhanced real-time training system initialized")
logger.info(" - Real-time model training: ENABLED")

@@ -2365,8 +2376,8 @@ class TradingOrchestrator:

logger.info("Initializing ExtremaTrainer with historical context...")
self.extrema_trainer.initialize_context_data()

# CRITICAL: Initialize ALL models with historical data
self._initialize_models_with_historical_data(loaded_data)
# CRITICAL: Initialize ALL models with historical data (using data provider's normalized methods)
self._initialize_models_with_historical_data(symbols_timeframes)

logger.info(f"🎯 Historical data loading complete: {total_candles} total candles loaded")
logger.info(f"📊 Available datasets: {list(loaded_data.keys())}")

@@ -2374,144 +2385,60 @@ class TradingOrchestrator:

except Exception as e:
logger.error(f"Error in historical data loading: {e}")

def _initialize_models_with_historical_data(self, loaded_data: Dict[str, Any]):
"""Initialize all NN models with historical data and multi-symbol support"""
def _initialize_models_with_historical_data(self, symbols_timeframes: List[Tuple[str, str]]):
"""Initialize all NN models with historical data using data provider's normalized methods"""
try:
logger.info("Initializing models with historical data and multi-symbol support...")
logger.info("Initializing models with normalized historical data from data provider...")

# Prepare multi-symbol feature matrices
symbol_features = self._prepare_multi_symbol_features(loaded_data)
# Use data provider's multi-symbol feature preparation
symbol_features = self.data_provider.get_multi_symbol_features_for_inference(symbols_timeframes, limit=300)

# Initialize CNN with multi-symbol data
if hasattr(self, 'cnn_model') and self.cnn_model:
logger.info("Initializing CNN with multi-symbol historical features...")
self._initialize_cnn_with_data(symbol_features)
self._initialize_cnn_with_provider_data()

# Initialize DQN with multi-symbol states
if hasattr(self, 'rl_agent') and self.rl_agent:
logger.info("Initializing DQN with multi-symbol state vectors...")
self._initialize_dqn_with_data(symbol_features)
self._initialize_dqn_with_provider_data(symbols_timeframes)

# Initialize Transformer with sequence data
if hasattr(self, 'transformer_model') and self.transformer_model:
logger.info("Initializing Transformer with multi-symbol sequences...")
self._initialize_transformer_with_data(symbol_features)
self._initialize_transformer_with_provider_data(symbols_timeframes)

# Initialize Decision Fusion with comprehensive features
if hasattr(self, 'decision_fusion') and self.decision_fusion:
logger.info("Initializing Decision Fusion with multi-symbol features...")
self._initialize_decision_with_data(symbol_features)
self._initialize_decision_with_provider_data(symbol_features)

logger.info("✅ All models initialized with historical multi-symbol data")
logger.info("✅ All models initialized with data provider's normalized historical data")

except Exception as e:
logger.error(f"Error initializing models with historical data: {e}")
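Note (not part of the diff): the refactored initializer now leans entirely on the data provider; a minimal sketch of the call sequence it relies on (method names taken from the DataProvider changes above, the symbol/timeframe list is illustrative):

symbols_timeframes = [('ETH/USDT', '1m'), ('ETH/USDT', '1h'), ('ETH/USDT', '1d'), ('BTC/USDT', '1m')]
features = self.data_provider.get_multi_symbol_features_for_inference(symbols_timeframes, limit=300)
state = self.data_provider.get_dqn_state_for_inference(symbols_timeframes, target_size=100)
sequences = self.data_provider.get_transformer_sequences_for_inference(symbols_timeframes, seq_length=150)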
def _prepare_multi_symbol_features(self, loaded_data: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare normalized multi-symbol feature matrices"""
try:
symbol_features = {
'ETH/USDT': {'1m': None, '1h': None, '1d': None},
'BTC/USDT': {'1m': None}
}

# Process each symbol's data with symbol-specific normalization
for data_key, df in loaded_data.items():
if df is None or df.empty:
continue

# Extract symbol and timeframe
if '_1m' in data_key:
symbol = data_key.replace('_1m', '')
timeframe = '1m'
elif '_1h' in data_key:
symbol = data_key.replace('_1h', '')
timeframe = '1h'
elif '_1d' in data_key:
symbol = data_key.replace('_1d', '')
timeframe = '1d'
else:
continue

# Apply symbol-grouped normalization
normalized_df = self._apply_symbol_grouped_normalization(df, symbol)

if normalized_df is not None:
symbol_features[symbol][timeframe] = normalized_df
logger.debug(f"Prepared normalized features for {symbol} {timeframe}")

return symbol_features

except Exception as e:
logger.error(f"Error preparing multi-symbol features: {e}")
return {}

def _apply_symbol_grouped_normalization(self, df: pd.DataFrame, symbol: str) -> pd.DataFrame:
"""Apply symbol-grouped normalization with consistent ranges across timeframes"""
try:
df_norm = df.copy()

# Get symbol-specific price ranges for consistent normalization
symbol_price_ranges = {
'ETH/USDT': {'min': 1000, 'max': 5000},   # ETH price range
'BTC/USDT': {'min': 90000, 'max': 120000} # BTC price range
}

if symbol in symbol_price_ranges:
price_range = symbol_price_ranges[symbol]
range_size = price_range['max'] - price_range['min']

# Normalize price columns to [0, 1] range specific to symbol
price_cols = ['open', 'high', 'low', 'close']
for col in price_cols:
if col in df_norm.columns:
df_norm[col] = (df_norm[col] - price_range['min']) / range_size
df_norm[col] = np.clip(df_norm[col], 0, 1)  # Ensure [0,1] range

# Normalize volume to [0, 1] using log scale
if 'volume' in df_norm.columns:
df_norm['volume'] = np.log1p(df_norm['volume'])
vol_max = df_norm['volume'].max()
if vol_max > 0:
df_norm['volume'] = df_norm['volume'] / vol_max

logger.debug(f"Applied symbol-grouped normalization for {symbol}")

# Fill any NaN values
df_norm = df_norm.fillna(0)

return df_norm

except Exception as e:
logger.error(f"Error in symbol-grouped normalization for {symbol}: {e}")
return df

def _initialize_cnn_with_data(self, symbol_features: Dict[str, Any]):
"""Initialize CNN with multi-symbol feature matrix"""
def _initialize_cnn_with_provider_data(self):
"""Initialize CNN using data provider's normalized feature extraction"""
try:
# Create combined feature matrix: [ETH_1m, ETH_1h, ETH_1d, BTC_1m]
combined_features = []

# ETH features (1m, 1h, 1d)
for timeframe in ['1m', '1h', '1d']:
eth_data = symbol_features.get('ETH/USDT', {}).get(timeframe)
if eth_data is not None and not eth_data.empty:
# Use last 60 candles for CNN input
recent_data = eth_data.tail(60)
features = recent_data[['open', 'high', 'low', 'close', 'volume']].values
combined_features.append(features.flatten())
features = self.data_provider.get_cnn_features_for_inference('ETH/USDT', timeframe, window_size=60)
if features is not None:
combined_features.append(features)

# BTC features (1m)
btc_data = symbol_features.get('BTC/USDT', {}).get('1m')
if btc_data is not None and not btc_data.empty:
recent_data = btc_data.tail(60)
features = recent_data[['open', 'high', 'low', 'close', 'volume']].values
combined_features.append(features.flatten())
btc_features = self.data_provider.get_cnn_features_for_inference('BTC/USDT', '1m', window_size=60)
if btc_features is not None:
combined_features.append(btc_features)

if combined_features:
# Concatenate all features
full_features = np.concatenate(combined_features)
logger.info(f"CNN initialized with {len(full_features)} multi-symbol features")
logger.info(f"CNN initialized with {len(full_features)} multi-symbol normalized features")

# Store for model access
if not hasattr(self, 'model_historical_features'):

@@ -2519,39 +2446,16 @@ class TradingOrchestrator:

self.model_historical_features['cnn'] = full_features

except Exception as e:
logger.error(f"Error initializing CNN with historical data: {e}")
logger.error(f"Error initializing CNN with provider data: {e}")

def _initialize_dqn_with_data(self, symbol_features: Dict[str, Any]):
"""Initialize DQN with multi-symbol state vectors"""
def _initialize_dqn_with_provider_data(self, symbols_timeframes: List[Tuple[str, str]]):
"""Initialize DQN using data provider's normalized state vector creation"""
try:
# Create comprehensive state vector combining all symbols and timeframes
state_components = []
# Use data provider's DQN state creation
state_vector = self.data_provider.get_dqn_state_for_inference(symbols_timeframes, target_size=100)

for symbol in ['ETH/USDT', 'BTC/USDT']:
timeframes = ['1m', '1h', '1d'] if symbol == 'ETH/USDT' else ['1m']

for timeframe in timeframes:
data = symbol_features.get(symbol, {}).get(timeframe)
if data is not None and not data.empty:
# Extract key features for state
latest = data.iloc[-1]
state_features = [
latest['close'],   # Current price
latest['volume'],  # Current volume
data['close'].pct_change().iloc[-1] if len(data) > 1 else 0,  # Price change
]
state_components.extend(state_features)

if state_components:
# Pad or truncate to expected DQN state size
target_size = 100  # DQN expects 100-dimensional state
if len(state_components) < target_size:
state_components.extend([0] * (target_size - len(state_components)))
else:
state_components = state_components[:target_size]

state_vector = np.array(state_components, dtype=np.float32)
logger.info(f"DQN initialized with {len(state_vector)} dimensional multi-symbol state")
if state_vector is not None:
logger.info(f"DQN initialized with {len(state_vector)} dimensional normalized multi-symbol state")

# Store for model access
if not hasattr(self, 'model_historical_features'):

@@ -2559,30 +2463,16 @@ class TradingOrchestrator:

self.model_historical_features['dqn'] = state_vector

except Exception as e:
logger.error(f"Error initializing DQN with historical data: {e}")
logger.error(f"Error initializing DQN with provider data: {e}")

def _initialize_transformer_with_data(self, symbol_features: Dict[str, Any]):
"""Initialize Transformer with multi-symbol sequence data"""
def _initialize_transformer_with_provider_data(self, symbols_timeframes: List[Tuple[str, str]]):
"""Initialize Transformer using data provider's normalized sequence creation"""
try:
# Prepare sequence data for transformer
sequences = []

# ETH sequences
for timeframe in ['1m', '1h', '1d']:
eth_data = symbol_features.get('ETH/USDT', {}).get(timeframe)
if eth_data is not None and not eth_data.empty:
# Use last 150 points as sequence
sequence = eth_data.tail(150)[['open', 'high', 'low', 'close', 'volume']].values
sequences.append(sequence)

# BTC sequence
btc_data = symbol_features.get('BTC/USDT', {}).get('1m')
if btc_data is not None and not btc_data.empty:
sequence = btc_data.tail(150)[['open', 'high', 'low', 'close', 'volume']].values
sequences.append(sequence)
# Use data provider's transformer sequence creation
sequences = self.data_provider.get_transformer_sequences_for_inference(symbols_timeframes, seq_length=150)

if sequences:
logger.info(f"Transformer initialized with {len(sequences)} multi-symbol sequences")
logger.info(f"Transformer initialized with {len(sequences)} normalized multi-symbol sequences")

# Store for model access
if not hasattr(self, 'model_historical_features'):

@@ -2590,10 +2480,10 @@ class TradingOrchestrator:

self.model_historical_features['transformer'] = sequences

except Exception as e:
logger.error(f"Error initializing Transformer with historical data: {e}")
logger.error(f"Error initializing Transformer with provider data: {e}")

def _initialize_decision_with_data(self, symbol_features: Dict[str, Any]):
"""Initialize Decision Fusion with comprehensive multi-symbol features"""
def _initialize_decision_with_provider_data(self, symbol_features: Dict[str, Dict[str, pd.DataFrame]]):
"""Initialize Decision Fusion using data provider's feature aggregation"""
try:
# Aggregate all available features for decision fusion
all_features = {}

@@ -2611,7 +2501,7 @@ class TradingOrchestrator:

}

if all_features:
logger.info(f"Decision Fusion initialized with {len(all_features)} symbol-timeframe combinations")
logger.info(f"Decision Fusion initialized with {len(all_features)} normalized symbol-timeframe combinations")

# Store for model access
if not hasattr(self, 'model_historical_features'):

@@ -2619,7 +2509,7 @@ class TradingOrchestrator:

self.model_historical_features['decision'] = all_features

except Exception as e:
logger.error(f"Error initializing Decision Fusion with historical data: {e}")
logger.error(f"Error initializing Decision Fusion with provider data: {e}")

def get_ohlcv_data(self, symbol: str, timeframe: str, limit: int = 300) -> List:
"""Get OHLCV data for a symbol with specified timeframe and limit."""
@@ -926,33 +926,44 @@ class CleanTradingDashboard:

return html.P(f"Error: {str(e)}", className="text-danger")

@self.app.callback(
Output('training-status', 'children'),
[Output('training-status', 'children'),
Output('training-status', 'className')],
[Input('start-training-btn', 'n_clicks'),
Input('stop-training-btn', 'n_clicks')],
prevent_initial_call=True
Input('stop-training-btn', 'n_clicks'),
Input('interval-component', 'n_intervals')],  # Auto-update on interval
prevent_initial_call=False  # Allow initial call to set status
)
def control_training(start_clicks, stop_clicks):
def control_training(start_clicks, stop_clicks, n_intervals):
try:
from utils.training_integration import get_unified_training_manager
manager = get_unified_training_manager(
orchestrator=self.orchestrator,
data_provider=self.data_provider,
dashboard=self
)
# Use orchestrator's enhanced training system directly
if not hasattr(self.orchestrator, 'enhanced_training_system') or not self.orchestrator.enhanced_training_system:
return "Not Available", "badge bg-danger small"

ctx = dash.callback_context
if not ctx.triggered:
raise PreventUpdate
trigger_id = ctx.triggered[0]['prop_id'].split('.')[0]
if trigger_id == 'start-training-btn':
ok = manager.start()
return 'Running' if ok else 'Error'
elif trigger_id == 'stop-training-btn':
ok = manager.stop()
return 'Stopped' if ok else 'Error'
return 'Idle'

# Check if this is triggered by button clicks
if ctx.triggered:
trigger_id = ctx.triggered[0]['prop_id'].split('.')[0]
if trigger_id == 'start-training-btn':
self.orchestrator.start_enhanced_training()
return 'Running', 'badge bg-success small'
elif trigger_id == 'stop-training-btn':
self.orchestrator.stop_enhanced_training()
return 'Stopped', 'badge bg-warning small'

# Auto-update: Check actual training status
if hasattr(self.orchestrator.enhanced_training_system, 'is_training'):
if self.orchestrator.enhanced_training_system.is_training:
return 'Running', 'badge bg-success small'
else:
return 'Idle', 'badge bg-secondary small'
else:
# Default to Running since training auto-starts
return 'Running', 'badge bg-success small'

except Exception as e:
logger.error(f"Training control error: {e}")
return 'Error'
logger.error(f"Training status error: {e}")
return 'Error', 'badge bg-danger small'
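Note (not part of the diff): the callback above distinguishes button clicks from interval ticks via Dash's callback context; a minimal sketch of that pattern:

import dash

ctx = dash.callback_context
if ctx.triggered:
    # e.g. 'start-training-btn' when the start button fired this callback
    trigger_id = ctx.triggered[0]['prop_id'].split('.')[0]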
@self.app.callback(
[Output('eth-cob-content', 'children'),

@@ -173,7 +173,7 @@ class DashboardLayoutManager:

], className="d-flex align-items-center mb-1"),
html.Div([
html.Span("Training:", className="small me-1"),
html.Span(id="training-status", children="Idle", className="badge bg-secondary small")
html.Span(id="training-status", children="Starting...", className="badge bg-primary small")
])
], className="mb-2"),