#!/usr/bin/env python3
"""
Enhanced Real-Time Online Training System

This system implements effective online learning with:
- High-frequency data integration (COB, ticks, OHLCV)
- Proper reward engineering for profitable trading
- Experience replay with prioritization
- Continuous validation and adaptation
- Multi-timeframe feature engineering
- Real market microstructure analysis
"""

import numpy as np
import pandas as pd
import logging
import time
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from collections import deque
import random
import math

import torch
import torch.nn as nn
import torch.optim as optim

logger = logging.getLogger(__name__)


class EnhancedRealtimeTrainingSystem:
    """Enhanced real-time training system with proper online learning"""

    def __init__(self, orchestrator, data_provider, dashboard=None):
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.dashboard = dashboard

        # Training configuration
        self.training_config = {
            'dqn_training_interval': 5,     # Train DQN every 5 seconds
            'cnn_training_interval': 10,    # Train CNN every 10 seconds
            'cob_rl_training_interval': 1,  # Train COB RL every second (highest priority)
            'batch_size': 640,              # Larger batch size for stability (increased 10x)
            'memory_size': 100000,          # Larger memory for diversity (increased 10x)
            'validation_interval': 60,      # Validate every minute
            'adaptation_threshold': 0.1,    # Adapt if performance drops 10%
            'min_training_samples': 100     # Minimum samples before training
        }

        # Experience buffers
        self.experience_buffer = deque(maxlen=self.training_config['memory_size'])
        self.validation_buffer = deque(maxlen=1000)
        self.priority_buffer = deque(maxlen=2000)  # High-priority experiences

        # Performance tracking
        self.performance_history = {
            'dqn_losses': deque(maxlen=1000),
            'cnn_losses': deque(maxlen=1000),
            'cob_rl_losses': deque(maxlen=1000),  # Required by _perform_enhanced_cob_rl_training
            'prediction_accuracy': deque(maxlen=500),
            'trading_performance': deque(maxlen=200),
            'validation_scores': deque(maxlen=100)
        }

        # Feature engineering components
        self.feature_window = 50  # Price history window
        self.technical_indicators = {}
        self.market_microstructure = {}

        # Training state
        self.is_training = False
        self.training_iteration = 0
        self.last_training_times = {
            'dqn': 0.0,
            'cnn': 0.0,
            'cob_rl': 0.0,
            'validation': 0.0
        }

        # Model prediction tracking - NEW for dashboard visualization
        self.recent_dqn_predictions = {
            'ETH/USDT': deque(maxlen=100),
            'BTC/USDT': deque(maxlen=100)
        }
        self.recent_cnn_predictions = {
            'ETH/USDT': deque(maxlen=50),
            'BTC/USDT': deque(maxlen=50)
        }
        self.prediction_accuracy_history = {
            'ETH/USDT': deque(maxlen=200),
            'BTC/USDT': deque(maxlen=200)
        }

        # FIXED: Forward-looking prediction system
        self.pending_predictions = {
            'ETH/USDT': deque(maxlen=100),  # Predictions waiting for validation
            'BTC/USDT': deque(maxlen=100)
        }
        self.last_prediction_time = {
            'ETH/USDT': 0,
            'BTC/USDT': 0
        }
        self.prediction_intervals = {
            'dqn': 30,  # Make a DQN prediction every 30 seconds
            'cnn': 60   # Make a CNN prediction every 60 seconds
        }

        # Real-time data streams
        self.real_time_data = {
            'ticks': deque(maxlen=1000),
            'ohlcv_1m': deque(maxlen=200),
            'ohlcv_5m': deque(maxlen=100),
            'cob_snapshots': deque(maxlen=500),
            'market_events': deque(maxlen=300)
        }

        logger.info("Enhanced Real-time Training System initialized")

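    # Usage sketch (hypothetical wiring; `Orchestrator` and `DataProvider` here
    # are stand-ins for whatever concrete classes the host application provides,
    # not part of this module):
    #
    #     orchestrator = Orchestrator()      # must expose rl_agent / cnn_model / cob_rl_agent
    #     data_provider = DataProvider()     # must expose get_historical_data(symbol, tf, limit)
    #     system = EnhancedRealtimeTrainingSystem(orchestrator, data_provider)
    #     system.start_training()            # spawns daemon worker threads
    #     time.sleep(3600)                   # ... let it run ...
    #     system.stop_training()             # workers exit on their next loop check
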
    def start_training(self):
        """Start the enhanced real-time training system"""
        if self.is_training:
            logger.warning("Training system already running")
            return

        self.is_training = True

        # Start data collection thread
        data_thread = threading.Thread(target=self._data_collection_worker, daemon=True)
        data_thread.start()

        # Start training coordinator
        training_thread = threading.Thread(target=self._training_coordinator, daemon=True)
        training_thread.start()

        # Start validation worker
        validation_thread = threading.Thread(target=self._validation_worker, daemon=True)
        validation_thread.start()

        logger.info("Enhanced real-time training system started")

    def stop_training(self):
        """Stop the training system"""
        self.is_training = False
        logger.info("Enhanced real-time training system stopped")

    def _data_collection_worker(self):
        """Collect and preprocess real-time market data"""
        while self.is_training:
            try:
                # 1. Collect multi-timeframe data
                self._collect_ohlcv_data()

                # 2. Collect tick data (if available)
                self._collect_tick_data()

                # 3. Collect COB data (if available)
                self._collect_cob_data()

                # 4. Detect market events
                self._detect_market_events()

                # 5. Update technical indicators
                self._update_technical_indicators()

                # 6. Create training experiences
                self._create_training_experiences()

                time.sleep(1)  # Collect data every second

            except Exception as e:
                logger.error(f"Error in data collection worker: {e}")
                time.sleep(5)

    def _training_coordinator(self):
        """Coordinate all training activities with proper scheduling"""
        while self.is_training:
            try:
                current_time = time.time()
                self.training_iteration += 1

                # 1. FORWARD-LOOKING PREDICTIONS - generate real predictions for future validation
                self.generate_forward_looking_predictions()

                # 2. COB RL training (every 1 second - HIGHEST PRIORITY, since COB imbalance predicts moves)
                cob_interval = self.training_config.get('cob_rl_training_interval', 1)
                if (current_time - self.last_training_times.get('cob_rl', 0) > cob_interval
                        and len(self.real_time_data['cob_snapshots']) >= 5):
                    if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
                        self._perform_enhanced_cob_rl_training()
                        self.last_training_times['cob_rl'] = current_time

                # 3. DQN training (every 5 seconds, with enough data)
                if (current_time - self.last_training_times['dqn'] > self.training_config['dqn_training_interval']
                        and len(self.experience_buffer) >= self.training_config['min_training_samples']):
                    self._perform_enhanced_dqn_training()
                    self.last_training_times['dqn'] = current_time

                # 4. CNN training (every 10 seconds)
                if (current_time - self.last_training_times['cnn'] > self.training_config['cnn_training_interval']
                        and len(self.real_time_data['ohlcv_1m']) >= 20):
                    self._perform_enhanced_cnn_training()
                    self.last_training_times['cnn'] = current_time

                # 5. Validation (every minute)
                if current_time - self.last_training_times['validation'] > self.training_config['validation_interval']:
                    self._perform_validation()
                    self.last_training_times['validation'] = current_time

                # 6. Adaptive learning rate adjustment
                if self.training_iteration % 100 == 0:
                    self._adapt_learning_parameters()

                # Log progress every 30 iterations
                if self.training_iteration % 30 == 0:
                    self._log_training_progress()

                time.sleep(2)  # Training coordinator runs every 2 seconds

            except Exception as e:
                logger.error(f"Error in training coordinator: {e}")
                time.sleep(10)

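    # The gating pattern above is a simple elapsed-time scheduler: each task fires
    # when `now - last_run > interval`, so a slow loop iteration delays, but never
    # double-fires, a task. A minimal standalone sketch of the same idea
    # (illustrative only; `run_task` is a hypothetical dispatcher):
    #
    #     intervals = {'cob_rl': 1, 'dqn': 5, 'cnn': 10}
    #     last_run = {name: 0.0 for name in intervals}
    #     while running:
    #         now = time.time()
    #         for name, every in intervals.items():
    #             if now - last_run[name] > every:
    #                 run_task(name)
    #                 last_run[name] = now
    #         time.sleep(2)
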
    def _collect_ohlcv_data(self):
        """Collect multi-timeframe OHLCV data"""
        try:
            # 1m data
            df_1m = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=5)
            if df_1m is not None and not df_1m.empty:
                latest_bar = {
                    'timestamp': df_1m.index[-1],
                    'open': float(df_1m['open'].iloc[-1]),
                    'high': float(df_1m['high'].iloc[-1]),
                    'low': float(df_1m['low'].iloc[-1]),
                    'close': float(df_1m['close'].iloc[-1]),
                    'volume': float(df_1m['volume'].iloc[-1]),
                    'timeframe': '1m'
                }

                # Only add if new data
                if not self.real_time_data['ohlcv_1m'] or self.real_time_data['ohlcv_1m'][-1]['timestamp'] != latest_bar['timestamp']:
                    self.real_time_data['ohlcv_1m'].append(latest_bar)

            # 5m data (less frequent)
            if self.training_iteration % 5 == 0:
                df_5m = self.data_provider.get_historical_data('ETH/USDT', '5m', limit=3)
                if df_5m is not None and not df_5m.empty:
                    latest_bar_5m = {
                        'timestamp': df_5m.index[-1],
                        'open': float(df_5m['open'].iloc[-1]),
                        'high': float(df_5m['high'].iloc[-1]),
                        'low': float(df_5m['low'].iloc[-1]),
                        'close': float(df_5m['close'].iloc[-1]),
                        'volume': float(df_5m['volume'].iloc[-1]),
                        'timeframe': '5m'
                    }

                    if not self.real_time_data['ohlcv_5m'] or self.real_time_data['ohlcv_5m'][-1]['timestamp'] != latest_bar_5m['timestamp']:
                        self.real_time_data['ohlcv_5m'].append(latest_bar_5m)

        except Exception as e:
            logger.debug(f"Error collecting OHLCV data: {e}")

    def _collect_tick_data(self):
        """Collect real-time tick data from the dashboard"""
        try:
            if self.dashboard and hasattr(self.dashboard, 'tick_cache'):
                recent_ticks = self.dashboard.tick_cache[-10:]  # Last 10 ticks
                for tick in recent_ticks:
                    tick_data = {
                        'timestamp': tick.get('datetime', datetime.now()),
                        'price': tick.get('price', 0),
                        'volume': tick.get('volume', 0),
                        'symbol': tick.get('symbol', 'ETHUSDT')
                    }

                    # Only add new ticks
                    if not self.real_time_data['ticks'] or self.real_time_data['ticks'][-1]['timestamp'] != tick_data['timestamp']:
                        self.real_time_data['ticks'].append(tick_data)

        except Exception as e:
            logger.debug(f"Error collecting tick data: {e}")

    def _collect_cob_data(self):
        """Collect COB (Consolidated Order Book) data and aggregate it into time series matrices"""
        try:
            if self.dashboard and hasattr(self.dashboard, 'latest_cob_data'):
                for symbol in ['ETH/USDT', 'BTC/USDT']:
                    if symbol in self.dashboard.latest_cob_data:
                        cob_data = self.dashboard.latest_cob_data[symbol]

                        # Create raw tick snapshot (1D from API)
                        raw_snapshot = {
                            'timestamp': datetime.now(),
                            'symbol': symbol,
                            'stats': cob_data.get('stats', {}),
                            'levels': len(cob_data.get('bids', [])) + len(cob_data.get('asks', [])),
                            'imbalance': cob_data.get('stats', {}).get('imbalance', 0),
                            'spread_bps': cob_data.get('stats', {}).get('spread_bps', 0),
                            'current_price': cob_data.get('stats', {}).get('mid_price', 0),
                            'bid_liquidity': cob_data.get('stats', {}).get('bid_liquidity', 0),
                            'ask_liquidity': cob_data.get('stats', {}).get('ask_liquidity', 0),
                            'total_liquidity': cob_data.get('stats', {}).get('total_liquidity', 0),
                        }

                        # Add to raw tick collection
                        self.real_time_data['cob_snapshots'].append(raw_snapshot)

                        # Aggregate into 1-second averaged matrices
                        self._aggregate_cob_to_time_series(symbol, raw_snapshot)

        except Exception as e:
            logger.debug(f"Error collecting COB data: {e}")

    def _aggregate_cob_to_time_series(self, symbol: str, raw_snapshot: Dict):
        """
        Aggregate COB snapshots from 1D API data into 2D time series matrices.
        Maintains both raw tick data and 1-second averaged aggregations.
        """
        try:
            current_time = datetime.now()

            # Initialize aggregation buffers if needed
            if not hasattr(self, 'cob_tick_buffers'):
                self.cob_tick_buffers = {}
                self.cob_1s_aggregated = {}
                self.cob_aggregation_windows = {}

            if symbol not in self.cob_tick_buffers:
                self.cob_tick_buffers[symbol] = []
                self.cob_1s_aggregated[symbol] = []
                self.cob_aggregation_windows[symbol] = current_time.replace(microsecond=0)

            # Add raw tick to buffer
            tick_data = {
                'timestamp': current_time,
                'imbalance': raw_snapshot.get('imbalance', 0),
                'spread_bps': raw_snapshot.get('spread_bps', 0),
                'bid_liquidity': raw_snapshot.get('bid_liquidity', 0),
                'ask_liquidity': raw_snapshot.get('ask_liquidity', 0),
                'total_liquidity': raw_snapshot.get('total_liquidity', 0),
                'mid_price': raw_snapshot.get('current_price', 0),
                'levels_count': raw_snapshot.get('levels', 0)
            }

            self.cob_tick_buffers[symbol].append(tick_data)

            # Keep only the last 1000 ticks (about 3-5 minutes of data at 200ms intervals)
            if len(self.cob_tick_buffers[symbol]) > 1000:
                self.cob_tick_buffers[symbol] = self.cob_tick_buffers[symbol][-1000:]

            # Check whether the current 1-second window is complete
            window_start = self.cob_aggregation_windows[symbol]
            if (current_time - window_start).total_seconds() >= 1.0:
                # Get all ticks in this 1-second window
                window_ticks = [
                    tick for tick in self.cob_tick_buffers[symbol]
                    if window_start <= tick['timestamp'] < window_start + timedelta(seconds=1)
                ]

                if window_ticks:
                    # Create 1-second aggregated data
                    aggregated_data = self._create_1s_cob_aggregation(window_ticks, window_start)
                    self.cob_1s_aggregated[symbol].append(aggregated_data)

                    # Keep only the last 300 seconds (5 minutes of 1s data)
                    if len(self.cob_1s_aggregated[symbol]) > 300:
                        self.cob_1s_aggregated[symbol] = self.cob_1s_aggregated[symbol][-300:]

                # Move to the next 1-second window
                self.cob_aggregation_windows[symbol] = current_time.replace(microsecond=0)

                # Create 2D matrices for model training
                self._create_cob_training_matrices(symbol)

        except Exception as e:
            logger.debug(f"Error aggregating COB data for {symbol}: {e}")

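    # Worked example of the windowing above (illustrative numbers): with COB
    # snapshots arriving roughly every 200 ms, a single 1-second window holds
    # about 5 ticks, e.g. imbalances [0.10, 0.12, 0.08, 0.15, 0.11]. The window
    # collapses into one aggregated row with imbalance_mean = 0.112,
    # imbalance_min = 0.08, imbalance_max = 0.15, imbalance_final = 0.11, and
    # imbalance_momentum = 0.11 - 0.10 = 0.01; 300 such rows cover 5 minutes.
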
    def _create_1s_cob_aggregation(self, window_ticks: List[Dict], window_start: datetime) -> Dict:
        """Create 1-second aggregated COB data from raw ticks"""
        try:
            if not window_ticks:
                return {}

            # Statistical aggregations
            imbalances = [tick['imbalance'] for tick in window_ticks]
            spreads = [tick['spread_bps'] for tick in window_ticks]
            bid_liquidities = [tick['bid_liquidity'] for tick in window_ticks]
            ask_liquidities = [tick['ask_liquidity'] for tick in window_ticks]
            total_liquidities = [tick['total_liquidity'] for tick in window_ticks]
            mid_prices = [tick['mid_price'] for tick in window_ticks if tick['mid_price'] > 0]

            aggregated = {
                'timestamp': window_start,
                'tick_count': len(window_ticks),

                # Imbalance statistics
                'imbalance_mean': np.mean(imbalances) if imbalances else 0,
                'imbalance_std': np.std(imbalances) if len(imbalances) > 1 else 0,
                'imbalance_min': np.min(imbalances) if imbalances else 0,
                'imbalance_max': np.max(imbalances) if imbalances else 0,
                'imbalance_final': imbalances[-1] if imbalances else 0,

                # Spread statistics
                'spread_mean': np.mean(spreads) if spreads else 0,
                'spread_std': np.std(spreads) if len(spreads) > 1 else 0,
                'spread_min': np.min(spreads) if spreads else 0,
                'spread_max': np.max(spreads) if spreads else 0,
                'spread_final': spreads[-1] if spreads else 0,

                # Liquidity statistics
                'bid_liquidity_mean': np.mean(bid_liquidities) if bid_liquidities else 0,
                'ask_liquidity_mean': np.mean(ask_liquidities) if ask_liquidities else 0,
                'total_liquidity_mean': np.mean(total_liquidities) if total_liquidities else 0,
                'liquidity_volatility': np.std(total_liquidities) if len(total_liquidities) > 1 else 0,

                # Price statistics
                'price_mean': np.mean(mid_prices) if mid_prices else 0,
                'price_std': np.std(mid_prices) if len(mid_prices) > 1 else 0,
                'price_change': (mid_prices[-1] - mid_prices[0]) / mid_prices[0] if len(mid_prices) >= 2 and mid_prices[0] > 0 else 0,
                'price_final': mid_prices[-1] if mid_prices else 0,

                # Activity metrics
                'avg_levels': np.mean([tick['levels_count'] for tick in window_ticks]),
                'update_frequency': len(window_ticks),  # Updates per second

                # Derived metrics
                'imbalance_momentum': (imbalances[-1] - imbalances[0]) if len(imbalances) >= 2 else 0,
                'spread_momentum': (spreads[-1] - spreads[0]) if len(spreads) >= 2 else 0,
                'liquidity_momentum': (total_liquidities[-1] - total_liquidities[0]) / max(total_liquidities[0], 1) if len(total_liquidities) >= 2 else 0
            }

            return aggregated

        except Exception as e:
            logger.error(f"Error creating 1s COB aggregation: {e}")
            return {}

    def _create_cob_training_matrices(self, symbol: str):
        """
        Create 2D training matrices from COB time series data.
        Output: [time_steps, features] matrices for both raw ticks and 1s aggregated data.
        """
        try:
            if not hasattr(self, 'cob_training_matrices'):
                self.cob_training_matrices = {}

            if symbol not in self.cob_training_matrices:
                self.cob_training_matrices[symbol] = {
                    'raw_tick_matrix': None,
                    '1s_aggregated_matrix': None,
                    'combined_features': None
                }

            # Create raw tick matrix (last 60 ticks = ~12 seconds at 200ms intervals)
            if hasattr(self, 'cob_tick_buffers') and symbol in self.cob_tick_buffers:
                recent_ticks = self.cob_tick_buffers[symbol][-60:]
                if len(recent_ticks) >= 10:  # Minimum data required
                    tick_matrix = []
                    for tick in recent_ticks:
                        tick_features = [
                            tick.get('imbalance', 0),
                            tick.get('spread_bps', 0) / 100.0,  # Normalize
                            tick.get('bid_liquidity', 0) / 1000000.0,  # Normalize to millions
                            tick.get('ask_liquidity', 0) / 1000000.0,
                            tick.get('total_liquidity', 0) / 1000000.0,
                            tick.get('levels_count', 0) / 100.0,  # Normalize
                            tick.get('mid_price', 0) / 10000.0 if tick.get('mid_price', 0) > 0 else 0  # Normalize price
                        ]
                        tick_matrix.append(tick_features)

                    self.cob_training_matrices[symbol]['raw_tick_matrix'] = np.array(tick_matrix, dtype=np.float32)

            # Create 1s aggregated matrix (last 60 seconds)
            if hasattr(self, 'cob_1s_aggregated') and symbol in self.cob_1s_aggregated:
                recent_1s = self.cob_1s_aggregated[symbol][-60:]
                if len(recent_1s) >= 5:  # Minimum data required
                    aggregated_matrix = []
                    for agg_data in recent_1s:
                        agg_features = [
                            agg_data.get('imbalance_mean', 0),
                            agg_data.get('imbalance_std', 0),
                            agg_data.get('imbalance_momentum', 0),
                            agg_data.get('spread_mean', 0) / 100.0,
                            agg_data.get('spread_std', 0) / 100.0,
                            agg_data.get('spread_momentum', 0) / 100.0,
                            agg_data.get('bid_liquidity_mean', 0) / 1000000.0,
                            agg_data.get('ask_liquidity_mean', 0) / 1000000.0,
                            agg_data.get('total_liquidity_mean', 0) / 1000000.0,
                            agg_data.get('liquidity_volatility', 0) / 1000000.0,
                            agg_data.get('liquidity_momentum', 0),
                            agg_data.get('price_change', 0),
                            agg_data.get('price_std', 0) / agg_data.get('price_mean', 1) if agg_data.get('price_mean', 0) > 0 else 0,
                            agg_data.get('update_frequency', 0) / 10.0,  # Normalize to expected ~5 updates/sec
                            agg_data.get('avg_levels', 0) / 100.0
                        ]
                        aggregated_matrix.append(agg_features)

                    self.cob_training_matrices[symbol]['1s_aggregated_matrix'] = np.array(aggregated_matrix, dtype=np.float32)

            # Create combined feature matrix for comprehensive training
            self._create_combined_cob_features(symbol)

        except Exception as e:
            logger.error(f"Error creating COB training matrices for {symbol}: {e}")

    def _create_combined_cob_features(self, symbol: str):
        """
        Combine raw tick and 1s aggregated data into a comprehensive feature matrix.
        Creates the 2000-dimensional feature vector used by the COB RL model.
        """
        try:
            if symbol not in self.cob_training_matrices:
                return

            matrices = self.cob_training_matrices[symbol]
            combined_features = []

            # 1. Latest raw tick features (7 features from the most recent tick)
            if matrices['raw_tick_matrix'] is not None and len(matrices['raw_tick_matrix']) > 0:
                latest_tick = matrices['raw_tick_matrix'][-1]
                combined_features.extend(latest_tick.tolist())
            else:
                combined_features.extend([0.0] * 7)

            # 2. Raw tick time series statistics (49 features)
            if matrices['raw_tick_matrix'] is not None and len(matrices['raw_tick_matrix']) > 5:
                tick_matrix = matrices['raw_tick_matrix']
                # Statistical features across time for each dimension
                for feature_idx in range(tick_matrix.shape[1]):
                    feature_series = tick_matrix[:, feature_idx]
                    combined_features.extend([
                        np.mean(feature_series),
                        np.std(feature_series),
                        np.min(feature_series),
                        np.max(feature_series),
                        feature_series[-1] - feature_series[0] if len(feature_series) > 1 else 0,  # Total change
                        np.mean(np.diff(feature_series)) if len(feature_series) > 1 else 0,  # Average momentum
                        np.std(np.diff(feature_series)) if len(feature_series) > 2 else 0  # Momentum volatility
                    ])
            else:
                combined_features.extend([0.0] * (7 * 7))  # 7 features * 7 statistics = 49

            # 3. 1-second aggregated features (15 features from the most recent 1s window)
            if matrices['1s_aggregated_matrix'] is not None and len(matrices['1s_aggregated_matrix']) > 0:
                latest_1s = matrices['1s_aggregated_matrix'][-1]
                combined_features.extend(latest_1s.tolist())
            else:
                combined_features.extend([0.0] * 15)

            # 4. 1-second time series statistics (150 features)
            if matrices['1s_aggregated_matrix'] is not None and len(matrices['1s_aggregated_matrix']) > 3:
                agg_matrix = matrices['1s_aggregated_matrix']
                # Statistical features across time for each aggregated dimension
                for feature_idx in range(agg_matrix.shape[1]):
                    feature_series = agg_matrix[:, feature_idx]
                    combined_features.extend([
                        np.mean(feature_series),
                        np.std(feature_series),
                        np.min(feature_series),
                        np.max(feature_series),
                        feature_series[-1] - feature_series[0] if len(feature_series) > 1 else 0,  # Total change
                        np.mean(np.diff(feature_series)) if len(feature_series) > 1 else 0,  # Average momentum
                        np.std(np.diff(feature_series)) if len(feature_series) > 2 else 0,  # Momentum volatility
                        np.percentile(feature_series, 25),  # 25th percentile
                        np.percentile(feature_series, 75),  # 75th percentile
                        len([x for x in np.diff(feature_series) if x > 0]) / max(len(feature_series) - 1, 1) if len(feature_series) > 1 else 0.5  # Positive change ratio
                    ])
            else:
                combined_features.extend([0.0] * (15 * 10))  # 15 features * 10 statistics

            # 5. Cross-correlation features between raw ticks and 1s aggregated data (50 features)
            if (matrices['raw_tick_matrix'] is not None and matrices['1s_aggregated_matrix'] is not None and
                    len(matrices['raw_tick_matrix']) > 10 and len(matrices['1s_aggregated_matrix']) > 5):

                # Calculate correlations between aligned time periods
                cross_features = []
                try:
                    # Downsample raw ticks to match 1s periods for correlation
                    tick_downsampled = []
                    ticks_per_second = len(matrices['raw_tick_matrix']) // len(matrices['1s_aggregated_matrix'])
                    if ticks_per_second > 0:
                        for i in range(0, len(matrices['raw_tick_matrix']), ticks_per_second):
                            segment = matrices['raw_tick_matrix'][i:i+ticks_per_second]
                            if len(segment) > 0:
                                tick_downsampled.append(np.mean(segment, axis=0))

                    if len(tick_downsampled) >= len(matrices['1s_aggregated_matrix']):
                        tick_downsampled = tick_downsampled[:len(matrices['1s_aggregated_matrix'])]

                        # Calculate correlations between key features
                        for tick_idx in [0, 1, 2, 4]:  # Imbalance, spread, bid_liq, total_liq
                            for agg_idx in [0, 3, 8]:  # imbalance_mean, spread_mean, total_liq_mean
                                if len(tick_downsampled) > 2:
                                    tick_series = [t[tick_idx] for t in tick_downsampled]
                                    agg_series = matrices['1s_aggregated_matrix'][:, agg_idx]
                                    if len(agg_series) == len(tick_series):
                                        correlation = np.corrcoef(tick_series, agg_series)[0, 1]
                                        cross_features.append(correlation if not np.isnan(correlation) else 0.0)
                except Exception as corr_error:
                    logger.debug(f"Error calculating cross-correlations: {corr_error}")

                # Pad cross features to 50
                while len(cross_features) < 50:
                    cross_features.append(0.0)
                combined_features.extend(cross_features[:50])
            else:
                combined_features.extend([0.0] * 50)

            # 6. Time-based and contextual features (remaining features to reach 2000)
            remaining_features = 2000 - len(combined_features)
            if remaining_features > 0:
                # Add time and market context features
                current_time = datetime.now()
                context_features = [
                    np.sin(2 * np.pi * current_time.hour / 24),  # Hour cyclical
                    np.cos(2 * np.pi * current_time.hour / 24),
                    current_time.weekday() / 6.0,
                    current_time.minute / 59.0,
                    len(self.cob_tick_buffers.get(symbol, [])) / 1000.0,  # Tick buffer utilization
                    len(self.cob_1s_aggregated.get(symbol, [])) / 300.0,  # 1s buffer utilization
                ]

                # Pad to reach exactly 2000 features
                while len(context_features) < remaining_features:
                    context_features.append(0.0)
                combined_features.extend(context_features[:remaining_features])

            # Store combined features (exactly 2000 dimensions)
            matrices['combined_features'] = np.array(combined_features[:2000], dtype=np.float32)

            logger.debug(f"Created combined COB features for {symbol}: {len(combined_features)} dimensions")

        except Exception as e:
            logger.error(f"Error creating combined COB features for {symbol}: {e}")

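    # Dimension budget for the combined vector (as constructed above):
    #   latest raw tick features               7
    #   raw-tick time-series stats      7 x 7 = 49
    #   latest 1s aggregated features          15
    #   1s time-series stats          15 x 10 = 150
    #   cross-correlation features             50
    #   subtotal                               271
    #   time/context features + zero padding   1729
    #   total                                  2000
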
    def get_cob_training_matrix(self, symbol: str, matrix_type: str = 'combined') -> Optional[np.ndarray]:
        """
        Get the COB training matrix for the specified symbol and type.

        Args:
            symbol: Trading symbol
            matrix_type: 'raw_tick', '1s_aggregated', or 'combined'

        Returns:
            Training matrix, or None if not available
        """
        try:
            if not hasattr(self, 'cob_training_matrices') or symbol not in self.cob_training_matrices:
                return None

            return self.cob_training_matrices[symbol].get(f'{matrix_type}_matrix' if matrix_type != 'combined' else 'combined_features')

        except Exception as e:
            logger.error(f"Error getting COB training matrix: {e}")
            return None

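    # Example (hypothetical values): pulling the matrices for the COB RL model.
    #
    #     state = system.get_cob_training_matrix('ETH/USDT', 'combined')
    #     if state is not None:
    #         assert state.shape == (2000,)      # flat 2000-dim feature vector
    #     ticks = system.get_cob_training_matrix('ETH/USDT', 'raw_tick')
    #     # ticks, when available, is [time_steps, 7]: imbalance, spread,
    #     # bid/ask/total liquidity, level count, mid price (all normalized)
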
    def _detect_market_events(self):
        """Detect significant market events for priority training"""
        try:
            if len(self.real_time_data['ohlcv_1m']) < 2:
                return

            current_bar = self.real_time_data['ohlcv_1m'][-1]
            prev_bar = self.real_time_data['ohlcv_1m'][-2]

            # Price volatility spike
            price_change = abs((current_bar['close'] - prev_bar['close']) / prev_bar['close'])
            if price_change > 0.005:  # 0.5% price movement
                event = {
                    'timestamp': current_bar['timestamp'],
                    'type': 'volatility_spike',
                    'magnitude': price_change,
                    'price': current_bar['close']
                }
                self.real_time_data['market_events'].append(event)

            # Volume surge
            if len(self.real_time_data['ohlcv_1m']) >= 10:
                avg_volume = np.mean([bar['volume'] for bar in list(self.real_time_data['ohlcv_1m'])[-10:]])
                if current_bar['volume'] > avg_volume * 2:  # 2x average volume
                    event = {
                        'timestamp': current_bar['timestamp'],
                        'type': 'volume_surge',
                        'magnitude': current_bar['volume'] / avg_volume,
                        'price': current_bar['close']
                    }
                    self.real_time_data['market_events'].append(event)

        except Exception as e:
            logger.debug(f"Error detecting market events: {e}")

    def _update_technical_indicators(self):
        """Update technical indicators from real-time data"""
        try:
            if len(self.real_time_data['ohlcv_1m']) < 20:
                return

            # Get price and volume arrays
            prices = np.array([bar['close'] for bar in self.real_time_data['ohlcv_1m']])
            volumes = np.array([bar['volume'] for bar in self.real_time_data['ohlcv_1m']])
            highs = np.array([bar['high'] for bar in self.real_time_data['ohlcv_1m']])
            lows = np.array([bar['low'] for bar in self.real_time_data['ohlcv_1m']])

            # Update indicators
            self.technical_indicators = {
                'sma_10': np.mean(prices[-10:]),
                'sma_20': np.mean(prices[-20:]),
                'rsi': self._calculate_rsi(prices, 14),
                'volatility': np.std(prices[-20:]) / np.mean(prices[-20:]),
                'volume_sma': np.mean(volumes[-10:]),
                'price_momentum': (prices[-1] - prices[-5]) / prices[-5] if len(prices) >= 5 else 0,
                'atr': np.mean(highs[-14:] - lows[-14:]) if len(prices) >= 14 else 0
            }

        except Exception as e:
            logger.debug(f"Error updating technical indicators: {e}")

    def _create_training_experiences(self):
        """Create comprehensive training experiences"""
        try:
            if len(self.real_time_data['ohlcv_1m']) < 10:
                return

            current_time = time.time()
            current_bar = self.real_time_data['ohlcv_1m'][-1]

            # Create comprehensive state features
            state_features = self._build_comprehensive_state()

            # Create experience with proper reward calculation
            experience = {
                'timestamp': current_time,
                'state': state_features,
                'price': current_bar['close'],
                'technical_indicators': self.technical_indicators.copy(),
                'market_events': len([e for e in self.real_time_data['market_events'] if current_time - time.mktime(e['timestamp'].timetuple()) < 300]),
                'cob_features': self._extract_cob_features(),
                'multi_timeframe': self._get_multi_timeframe_context()
            }

            # Add to experience buffer
            self.experience_buffer.append(experience)

            # Add to priority buffer if a significant event occurred or any indicator is extreme
            if experience['market_events'] > 0 or any(abs(v) > 0.02 for v in self.technical_indicators.values()):
                self.priority_buffer.append(experience)

        except Exception as e:
            logger.debug(f"Error creating training experiences: {e}")

    def _build_comprehensive_state(self) -> np.ndarray:
        """Build a comprehensive state vector for RL training"""
        try:
            state_features = []

            # 1. Price features (normalized)
            if len(self.real_time_data['ohlcv_1m']) >= 10:
                recent_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-10:]]
                base_price = recent_prices[0]
                normalized_prices = [(p - base_price) / base_price for p in recent_prices]
                state_features.extend(normalized_prices)
            else:
                state_features.extend([0.0] * 10)

            # 2. Technical indicators
            for indicator_name in ['sma_10', 'sma_20', 'rsi', 'volatility', 'volume_sma', 'price_momentum', 'atr']:
                value = self.technical_indicators.get(indicator_name, 0)
                # Normalize indicators
                if indicator_name == 'rsi':
                    state_features.append(value / 100.0)  # RSI 0-100 -> 0-1
                elif indicator_name in ['volatility', 'price_momentum']:
                    state_features.append(np.tanh(value * 100))  # Bounded -1 to 1
                else:
                    state_features.append(value / 10000.0)  # Price-based normalization

            # 3. Volume features
            if len(self.real_time_data['ohlcv_1m']) >= 5:
                recent_volumes = [bar['volume'] for bar in list(self.real_time_data['ohlcv_1m'])[-5:]]
                avg_volume = np.mean(recent_volumes)
                volume_ratio = recent_volumes[-1] / avg_volume if avg_volume > 0 else 1.0
                state_features.append(np.tanh(volume_ratio - 1))  # Volume deviation
            else:
                state_features.append(0.0)

            # 4. Market microstructure (COB features)
            cob_features = self._extract_cob_features()
            state_features.extend(cob_features[:5])  # Top 5 COB features

            # 5. Time features
            now = datetime.now()
            state_features.append(np.sin(2 * np.pi * now.hour / 24))  # Hour of day (cyclical)
            state_features.append(np.cos(2 * np.pi * now.hour / 24))
            state_features.append(now.weekday() / 6.0)  # Day of week

            # Pad to a fixed size (100 features)
            while len(state_features) < 100:
                state_features.append(0.0)

            return np.array(state_features[:100])

        except Exception as e:
            logger.error(f"Error building state: {e}")
            return np.zeros(100)

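    # State layout (before zero padding to 100): 10 normalized prices
    # + 7 technical indicators + 1 volume-deviation feature + 5 COB features
    # + 3 time features = 26 informative dimensions; the remaining 74 are zeros,
    # reserved so the RL input width stays fixed as features are added.
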
    def _extract_cob_features(self) -> List[float]:
        """Extract features from COB data"""
        try:
            if not self.real_time_data['cob_snapshots']:
                return [0.0] * 10

            latest_cob = self.real_time_data['cob_snapshots'][-1]
            stats = latest_cob.get('stats', {})

            features = [
                stats.get('imbalance', 0),
                stats.get('spread_bps', 0) / 100.0,  # Normalize spread
                latest_cob.get('levels', 0) / 100.0,  # Normalize level count
                stats.get('bid_liquidity', 0) / 1000000.0,  # Normalize liquidity
                stats.get('ask_liquidity', 0) / 1000000.0,
            ]

            # Pad to 10 features
            while len(features) < 10:
                features.append(0.0)

            return features[:10]

        except Exception as e:
            logger.debug(f"Error extracting COB features: {e}")
            return [0.0] * 10

    def _get_multi_timeframe_context(self) -> Dict:
        """Get multi-timeframe market context"""
        try:
            context = {}

            # 1m trend
            if len(self.real_time_data['ohlcv_1m']) >= 5:
                recent_1m = list(self.real_time_data['ohlcv_1m'])[-5:]
                trend_1m = (recent_1m[-1]['close'] - recent_1m[0]['close']) / recent_1m[0]['close']
                context['trend_1m'] = trend_1m

            # 5m trend
            if len(self.real_time_data['ohlcv_5m']) >= 3:
                recent_5m = list(self.real_time_data['ohlcv_5m'])[-3:]
                trend_5m = (recent_5m[-1]['close'] - recent_5m[0]['close']) / recent_5m[0]['close']
                context['trend_5m'] = trend_5m

            return context

        except Exception as e:
            logger.debug(f"Error getting multi-timeframe context: {e}")
            return {}

    def _perform_enhanced_dqn_training(self):
        """Enhanced DQN training with comprehensive market awareness"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
                return

            # PRIORITIZE COB RL TRAINING - most mission critical
            if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
                self._perform_enhanced_cob_rl_training()

            # Regular DQN training continues here
            rl_agent = self.orchestrator.rl_agent

            # Get memory size for training checks
            memory_size = self._get_dqn_memory_size()

            if memory_size < self.training_config['min_training_samples']:
                logger.debug(f"Insufficient DQN samples: {memory_size}/{self.training_config['min_training_samples']}")
                return

            # Sample prioritized experiences
            experiences = self._sample_prioritized_experiences()

            if not experiences:
                return

            training_start = time.time()

            # Track training count and log intervals
            if not hasattr(self, 'dqn_training_count'):
                self.dqn_training_count = 0

            # Batch experiences for training
            batch_size = min(self.training_config['batch_size'], len(experiences))
            total_loss = 0
            training_iterations = 0

            for i in range(0, len(experiences), batch_size):
                batch = experiences[i:i + batch_size]

                # Prepare batch data
                states = []
                actions = []
                rewards = []
                next_states = []
                dones = []

                for exp in batch:
                    states.append(exp['state'])
                    actions.append(exp['action'])
                    rewards.append(exp['reward'])
                    next_states.append(exp['next_state'])
                    dones.append(exp['done'])

                # Convert to numpy arrays
                states = np.array(states)
                actions = np.array(actions)
                rewards = np.array(rewards, dtype=np.float32)
                next_states = np.array(next_states)
                dones = np.array(dones, dtype=bool)

                # Perform training step
                if hasattr(rl_agent, 'train_step'):
                    loss = rl_agent.train_step(states, actions, rewards, next_states, dones)
                    if loss is not None:
                        total_loss += loss
                        training_iterations += 1
                elif hasattr(rl_agent, 'replay'):
                    # Fall back to the replay method
                    loss = rl_agent.replay(batch_size=len(batch))
                    if loss is not None:
                        total_loss += loss
                        training_iterations += 1

            training_time = time.time() - training_start
            avg_loss = total_loss / training_iterations if training_iterations > 0 else 0

            self.dqn_training_count += 1

            # Log progress every 10 training sessions
            if self.dqn_training_count % 10 == 0:
                logger.info(f"DQN TRAINING: Session {self.dqn_training_count}, "
                            f"Memory={memory_size}, Batches={training_iterations}, "
                            f"Avg Loss={avg_loss:.4f}, Time={training_time:.2f}s")

        except Exception as e:
            logger.error(f"Error in enhanced DQN training: {e}")
            import traceback
            traceback.print_exc()

    def _perform_enhanced_cob_rl_training(self):
        """Enhanced COB RL training using comprehensive 2D matrix features - HIGHEST PRIORITY"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'cob_rl_agent') or not self.orchestrator.cob_rl_agent:
                return

            cob_rl_agent = self.orchestrator.cob_rl_agent

            # Check that COB training matrices are available
            if not hasattr(self, 'cob_training_matrices'):
                return

            training_updates = 0

            for symbol in ['ETH/USDT', 'BTC/USDT']:
                if symbol in self.cob_training_matrices:
                    # Get the comprehensive 2000-dimensional feature matrix
                    combined_features = self.get_cob_training_matrix(symbol, 'combined')
                    raw_tick_matrix = self.get_cob_training_matrix(symbol, 'raw_tick')
                    aggregated_matrix = self.get_cob_training_matrix(symbol, '1s_aggregated')

                    if combined_features is not None:
                        # Create enhanced COB training experience
                        current_price = self._get_current_price_from_data(symbol)
                        if current_price:
                            # Generate a COB-based action from imbalance signals
                            action = self._generate_cob_action_from_matrices(symbol, combined_features, raw_tick_matrix)

                            # Calculate reward based on COB prediction accuracy
                            reward = self._calculate_cob_reward(symbol, action, combined_features)

                            # Use the full 2000-dimensional vector as the COB RL state
                            state = combined_features

                            # Store experience in the COB RL agent
                            if hasattr(cob_rl_agent, 'store_experience'):
                                experience = {
                                    'state': state,
                                    'action': action,
                                    'reward': reward,
                                    'next_state': state,  # Will be updated with the next observation
                                    'done': False,
                                    'symbol': symbol,
                                    'timestamp': datetime.now(),
                                    'price': current_price,
                                    'cob_features': {
                                        'raw_tick_available': raw_tick_matrix is not None,
                                        'aggregated_available': aggregated_matrix is not None,
                                        'imbalance': combined_features[0] if len(combined_features) > 0 else 0,
                                        'spread': combined_features[1] if len(combined_features) > 1 else 0,
                                        'liquidity': combined_features[4] if len(combined_features) > 4 else 0
                                    }
                                }
                                cob_rl_agent.store_experience(experience)
                                training_updates += 1

                        # Perform COB RL training if there are enough experiences
                        if hasattr(cob_rl_agent, 'get_memory_size'):
                            memory_size = cob_rl_agent.get_memory_size()
                            if memory_size >= 100:  # Minimum experiences for training
                                if hasattr(cob_rl_agent, 'train'):
                                    # Train with a batch of COB experiences
                                    training_loss = cob_rl_agent.train(batch_size=32)

                                    if training_loss is not None:
                                        self.performance_history['cob_rl_losses'].append(training_loss)

                                        # Update the orchestrator with COB performance
                                        if hasattr(self.orchestrator, 'update_model_loss'):
                                            self.orchestrator.update_model_loss('cob_rl', training_loss)

                                        logger.info(f"COB RL TRAINING (PRIORITY): {symbol} - loss={training_loss:.6f}, memory={memory_size}, features_dim={len(combined_features)}")

            # Log COB training summary
            if training_updates > 0:
                logger.info(f"COB RL ENHANCED TRAINING: {training_updates} experiences stored across symbols using 2D matrix features")

        except Exception as e:
            logger.error(f"Error in enhanced COB RL training: {e}")

    def _generate_cob_action_from_matrices(self, symbol: str, combined_features: np.ndarray, raw_tick_matrix: Optional[np.ndarray]) -> int:
        """
        Generate a trading action based on COB matrix analysis.
        Uses both combined features and raw tick patterns to predict the optimal action.
        """
        try:
            if len(combined_features) < 10:
                return 1  # HOLD as fallback

            # Extract key COB signals from combined features
            imbalance = combined_features[0]        # Order book imbalance (most critical)
            spread = combined_features[1]           # Bid-ask spread
            bid_liquidity = combined_features[2]    # Bid-side liquidity
            ask_liquidity = combined_features[3]    # Ask-side liquidity
            total_liquidity = combined_features[4]  # Total liquidity

            # Analyze imbalance signal strength (primary predictor)
            action_score = 0.0

            # 1. Imbalance-based signal (60% weight)
            if imbalance > 0.1:  # Strong bid imbalance suggests upward pressure
                action_score += 0.6
            elif imbalance < -0.1:  # Strong ask imbalance suggests downward pressure
                action_score -= 0.6
            else:  # Balanced book suggests sideways movement
                action_score += 0.0

            # 2. Spread analysis (20% weight)
            if spread < 0.05:  # Tight spread suggests strong liquidity/stability
                action_score += 0.1 if imbalance > 0 else -0.1 if imbalance < 0 else 0
            elif spread > 0.15:  # Wide spread suggests uncertainty/volatility
                action_score *= 0.5  # Reduce confidence

            # 3. Liquidity depth analysis (20% weight)
            liquidity_ratio = bid_liquidity / max(ask_liquidity, 0.001)
            if liquidity_ratio > 1.2:  # More bid liquidity
                action_score += 0.2
            elif liquidity_ratio < 0.8:  # More ask liquidity
                action_score -= 0.2

            # 4. Raw tick momentum analysis (if available)
            if raw_tick_matrix is not None and len(raw_tick_matrix) > 5:
                # Analyze recent tick patterns
                recent_imbalances = raw_tick_matrix[-5:, 0]  # Last 5 imbalance values
                imbalance_trend = np.mean(np.diff(recent_imbalances)) if len(recent_imbalances) > 1 else 0

                if imbalance_trend > 0.02:  # Increasing imbalance momentum
                    action_score += 0.1
                elif imbalance_trend < -0.02:  # Decreasing imbalance momentum
                    action_score -= 0.1

            # Convert the action score to a discrete action
            if action_score > 0.3:
                return 2  # BUY - strong bullish COB signal
            elif action_score < -0.3:
                return 0  # SELL - strong bearish COB signal
            else:
                return 1  # HOLD - neutral or weak COB signal

        except Exception as e:
            logger.debug(f"Error generating COB action: {e}")
            return 1  # HOLD as fallback

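    # Worked example (illustrative numbers): imbalance = 0.15 (> 0.1, +0.6),
    # spread = 0.03 (< 0.05, tight, +0.1 since imbalance > 0), liquidity ratio
    # = 1.3 (> 1.2, +0.2), flat tick momentum (+0.0) -> action_score = 0.9 > 0.3,
    # so the function returns 2 (BUY). A wide spread (> 0.15) would instead halve
    # the accumulated score before the threshold test.
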
    def _calculate_cob_reward(self, symbol: str, action: int, combined_features: np.ndarray) -> float:
        """
        Calculate the reward for COB RL training based on prediction accuracy and market outcomes.
        """
        try:
            # Get recent price data to validate the COB prediction
            recent_prices = self._get_historical_price_sequence(symbol, 3)
            if len(recent_prices) < 2:
                return 0.0

            # Calculate short-term price movement
            price_change = (recent_prices[-1] - recent_prices[-2]) / recent_prices[-2]

            # Extract COB features for reward calculation
            imbalance = combined_features[0] if len(combined_features) > 0 else 0
            spread = combined_features[1] if len(combined_features) > 1 else 0

            # Base reward from action-outcome alignment
            base_reward = 0.0

            if action == 2:  # BUY action
                if price_change > 0.0005:  # Price moved up (0.05%+)
                    base_reward = 1.0  # Correct prediction
                elif price_change < -0.0005:  # Price moved down
                    base_reward = -1.0  # Incorrect prediction
                else:
                    base_reward = -0.1  # Neutral movement (slight penalty for an aggressive action)

            elif action == 0:  # SELL action
                if price_change < -0.0005:  # Price moved down
                    base_reward = 1.0  # Correct prediction
                elif price_change > 0.0005:  # Price moved up
                    base_reward = -1.0  # Incorrect prediction
                else:
                    base_reward = -0.1  # Neutral movement (slight penalty for an aggressive action)

            else:  # HOLD action (action == 1)
                if abs(price_change) < 0.0005:  # Neutral movement
                    base_reward = 0.5  # Correct prediction of low volatility
                else:
                    base_reward = -0.2  # Missed opportunity

            # Bonus/penalty based on COB signal strength
            signal_strength = abs(imbalance) + (1.0 / max(spread, 0.01))  # Strong imbalance + tight spread
            if signal_strength > 1.0:
                base_reward *= 1.2  # Bonus for acting on strong signals
            elif signal_strength < 0.3:
                base_reward *= 0.8  # Penalty for acting on weak signals

            # Clamp the reward to a reasonable range
            return max(-2.0, min(2.0, base_reward))

        except Exception as e:
            logger.debug(f"Error calculating COB reward: {e}")
            return 0.0

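    # Worked example (illustrative numbers): action = BUY (2) and the next tick
    # moves +0.08% (> +0.05%), so base_reward = 1.0. With imbalance = 0.4 and
    # spread = 0.5, signal_strength = 0.4 + 1/0.5 = 2.4 > 1.0, so the reward is
    # scaled to 1.2. A BUY into a -0.08% move would symmetrically score -1.2.
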
    def _collect_cob_training_experiences(self) -> List[Dict]:
        """Collect COB-specific training experiences from real market data"""
        try:
            experiences = []

            # Get recent COB snapshots with price outcomes
            if not self.real_time_data['cob_snapshots']:
                return experiences

            # Take the last 20 COB snapshots for training
            recent_cobs = list(self.real_time_data['cob_snapshots'])[-20:]

            for i, cob_snapshot in enumerate(recent_cobs):
                if i == len(recent_cobs) - 1:  # Skip the last one (no future price)
                    break

                current_price = cob_snapshot.get('current_price', 0)
                if current_price <= 0:
                    continue

                # Get the price change to the next snapshot (snapshots arrive roughly once per second)
                next_snapshot = recent_cobs[i + 1]
                next_price = next_snapshot.get('current_price', current_price)
                price_change = (next_price - current_price) / current_price

                # Extract comprehensive COB features
                cob_features = self._extract_comprehensive_cob_features(cob_snapshot)

                if len(cob_features) == 0:
                    continue

                experience = {
                    'timestamp': cob_snapshot.get('timestamp', datetime.now()),
                    'symbol': cob_snapshot.get('symbol', 'ETH/USDT'),
                    'cob_features': cob_features,
                    'current_price': current_price,
                    'next_price': next_price,
                    'price_change': price_change,
                    'imbalance': cob_snapshot.get('stats', {}).get('imbalance', 0),
                    'spread': cob_snapshot.get('stats', {}).get('spread_bps', 0),
                }

                experiences.append(experience)

            return experiences

        except Exception as e:
            logger.error(f"Error collecting COB training experiences: {e}")
            return []

    def _extract_comprehensive_cob_features(self, cob_snapshot: Dict) -> np.ndarray:
        """Extract comprehensive 2000-dimensional COB features for the massive RL model"""
        try:
            features = []

            # 1. Basic COB statistics (50 features)
            stats = cob_snapshot.get('stats', {})
            basic_features = [
                stats.get('imbalance', 0),
                stats.get('spread_bps', 0) / 100.0,  # Normalize
                stats.get('bid_liquidity', 0) / 1000000.0,  # Normalize to millions
                stats.get('ask_liquidity', 0) / 1000000.0,
                stats.get('total_liquidity', 0) / 1000000.0,
                cob_snapshot.get('levels', 0) / 100.0,  # Normalize level count
                cob_snapshot.get('current_price', 0) / 10000.0,  # Normalize price
            ]

            # Pad basic features to 50
            while len(basic_features) < 50:
                basic_features.append(0.0)
            features.extend(basic_features[:50])

            # 2. Price bucket features (500 features)
            price_buckets = cob_snapshot.get('price_buckets', {})
            bucket_features = []

            # Get sorted bucket keys and extract features; note that only the
            # first 125 buckets (4 features each) survive the 500-feature cap below
            bucket_keys = sorted(price_buckets.keys())[:500]
            for bucket_key in bucket_keys:
                bucket_data = price_buckets[bucket_key]
                bucket_features.extend([
                    bucket_data.get('bid_volume', 0) / 1000000.0,
                    bucket_data.get('ask_volume', 0) / 1000000.0,
                    bucket_data.get('imbalance', 0),
                    bucket_data.get('momentum', 0)
                ])

            # Pad bucket features to 500
            while len(bucket_features) < 500:
                bucket_features.append(0.0)
            features.extend(bucket_features[:500])

            # 3. Order book level features (1000 features)
            ob_levels = cob_snapshot.get('order_book_levels', [])
            level_features = []

            for level in ob_levels[:250]:  # Top 250 levels (250 * 4 = 1000 features)
                level_features.extend([
                    level.get('bid_price', 0) / 10000.0,
                    level.get('bid_volume', 0) / 1000000.0,
                    level.get('ask_price', 0) / 10000.0,
                    level.get('ask_volume', 0) / 1000000.0
                ])

            # Pad level features to 1000
            while len(level_features) < 1000:
                level_features.append(0.0)
            features.extend(level_features[:1000])

            # 4. Technical indicators (200 features)
            tech_features = []

            # Price momentum indicators
            price_history = self._get_historical_price_sequence(cob_snapshot.get('symbol', 'ETH/USDT'), 20)
            if len(price_history) >= 10:
                current_price = price_history[-1]
                prev_prices = price_history[-10:]

                # Price changes over different periods
                for i in [1, 2, 3, 5, 10]:
                    if len(prev_prices) > i:
                        change = (current_price - prev_prices[-i]) / prev_prices[-i]
                        tech_features.append(change)

                # Moving averages
                if len(prev_prices) >= 5:
                    ma5 = sum(prev_prices[-5:]) / 5
                    tech_features.append((current_price - ma5) / ma5)

                if len(prev_prices) >= 10:
                    ma10 = sum(prev_prices[-10:]) / 10
                    tech_features.append((current_price - ma10) / ma10)

                # Volatility measure
                if len(prev_prices) >= 5:
                    volatility = np.std(prev_prices[-5:]) / np.mean(prev_prices[-5:])
                    tech_features.append(volatility)

            # Pad technical features to 200
            while len(tech_features) < 200:
                tech_features.append(0.0)
            features.extend(tech_features[:200])

            # 5. Time-based features (50 features)
            timestamp = cob_snapshot.get('timestamp', datetime.now())
            if isinstance(timestamp, str):
                timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))

            time_features = [
                np.sin(2 * np.pi * timestamp.hour / 24),  # Hour of day (cyclical)
                np.cos(2 * np.pi * timestamp.hour / 24),
                timestamp.weekday() / 6.0,  # Day of week
                timestamp.minute / 59.0,  # Minute of hour
                (timestamp - datetime(2024, 1, 1)).days / 365.0,  # Days since reference
            ]

            # Pad time features to 50
            while len(time_features) < 50:
                time_features.append(0.0)
            features.extend(time_features[:50])

            # 6. Market context features (200 features)
            context_features = []

            # Recent market events and patterns
            recent_snapshots = list(self.real_time_data['cob_snapshots'])[-10:]
            if len(recent_snapshots) > 1:
                # Imbalance trend
                imbalances = [snap.get('stats', {}).get('imbalance', 0) for snap in recent_snapshots]
                imbalance_trend = np.mean(np.diff(imbalances)) if len(imbalances) > 1 else 0
                context_features.append(imbalance_trend)

                # Spread trend
                spreads = [snap.get('stats', {}).get('spread_bps', 0) for snap in recent_snapshots]
                spread_trend = np.mean(np.diff(spreads)) if len(spreads) > 1 else 0
                context_features.append(spread_trend)

                # Liquidity trend
                liquidities = [snap.get('stats', {}).get('total_liquidity', 0) for snap in recent_snapshots]
                liquidity_trend = np.mean(np.diff(liquidities)) if len(liquidities) > 1 else 0
                context_features.append(liquidity_trend / 1000000.0)

            # Pad context features to 200
            while len(context_features) < 200:
                context_features.append(0.0)
            features.extend(context_features[:200])

            # Ensure exactly 2000 features
            while len(features) < 2000:
                features.append(0.0)

            return np.array(features[:2000], dtype=np.float32)

        except Exception as e:
            logger.error(f"Error extracting comprehensive COB features: {e}")
            return np.zeros(2000, dtype=np.float32)

    def _perform_enhanced_cnn_training(self):
        """Perform enhanced CNN training with real market features"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
                return

            # Create training sequences
            sequences = self._create_cnn_training_sequences()

            if len(sequences) < 10:
                return

            training_losses = []

            # Train on sequences
            for sequence_batch in self._batch_sequences(sequences, 16):
                try:
                    # Extract features and targets
                    features = np.array([seq['features'] for seq in sequence_batch])
                    targets = np.array([seq['target'] for seq in sequence_batch])

                    # Perform actual PyTorch training
                    loss = self._perform_real_cnn_training(features, targets)
                    if loss is not None:
                        training_losses.append(loss)

                except Exception as e:
                    logger.debug(f"CNN batch training failed: {e}")

            # Update performance tracking
            if training_losses:
                avg_loss = np.mean(training_losses)
                self.performance_history['cnn_losses'].append(avg_loss)

                if hasattr(self.orchestrator, 'update_model_loss'):
                    self.orchestrator.update_model_loss('cnn', avg_loss)

                logger.info(f"CNN ENHANCED TRAINING: {len(sequences)} sequences, avg_loss={avg_loss:.6f}")

        except Exception as e:
            logger.error(f"Error in enhanced CNN training: {e}")

    def _create_cnn_training_sequences(self) -> List[Dict]:
        """Create training sequences for CNN price prediction"""
        try:
            sequences = []

            if len(self.real_time_data['ohlcv_1m']) < 20:
                return sequences

            bars = list(self.real_time_data['ohlcv_1m'])

            # Create sequences of length 15 to predict the next price
            for i in range(len(bars) - 15):
                sequence_bars = bars[i:i+15]
                target_bar = bars[i+15]

                # Create feature matrix (15 x features)
                features = []
                for bar in sequence_bars:
                    bar_features = [
                        bar['open'] / 10000,
                        bar['high'] / 10000,
                        bar['low'] / 10000,
                        bar['close'] / 10000,
                        bar['volume'] / 1000000,
                    ]
                    features.append(bar_features)

                # Pad features to a standard size (15 x 20)
                feature_matrix = np.zeros((15, 20))
                for j, feat in enumerate(features):
                    feature_matrix[j, :len(feat)] = feat

                # Target: price direction (0=down, 1=same, 2=up)
                price_change = (target_bar['close'] - sequence_bars[-1]['close']) / sequence_bars[-1]['close']
                if price_change > 0.001:
                    target = 2  # UP
                elif price_change < -0.001:
                    target = 0  # DOWN
                else:
                    target = 1  # SAME

                sequences.append({
                    'features': feature_matrix.flatten(),  # Flatten for the neural network
                    'target': target,
                    'price_change': price_change
                })

            return sequences

        except Exception as e:
            logger.error(f"Error creating CNN sequences: {e}")
            return []

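    # Labeling example: with the ±0.1% threshold above, a sequence whose 15th bar
    # closes at 3000.0 and whose target bar closes at 3004.5 has
    # price_change = +0.15% -> target = 2 (UP); a target close of 2998.5 (-0.05%)
    # falls inside the band -> target = 1 (SAME).
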
    def _batch_sequences(self, sequences: List[Dict], batch_size: int):
        """Yield successive batches of sequences for training"""
        for i in range(0, len(sequences), batch_size):
            yield sequences[i:i + batch_size]

    def _perform_real_cnn_training(self, features: np.ndarray, targets: np.ndarray) -> float:
        """Train the CNN model on real data with a full backpropagation step"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
                logger.debug("CNN model not available for training.")
                return 1.0

            model = self.orchestrator.cnn_model
            optimizer = self.orchestrator.cnn_optimizer  # Orchestrator owns the optimizer
            criterion = nn.CrossEntropyLoss()  # Price direction is a 3-class classification problem

            model.train()
            optimizer.zero_grad()

            # Convert numpy arrays to tensors on the model's device
            device = next(model.parameters()).device
            features_tensor = torch.from_numpy(features).float().to(device)
            targets_tensor = torch.from_numpy(targets).long().to(device)

            # _create_cnn_training_sequences flattens each (15 x 20) feature matrix,
            # so the model is fed 2D input: (batch_size, features_dimension). If the
            # CNN architecture instead expects image-like (batch, channels, height,
            # width) or 1D-conv (batch, channels, length) input, reshape here, e.g.
            # features_tensor.view(-1, 1, 15, 20).
            if len(features_tensor.shape) == 2:
                # (batch_size, features) - keep as is
                pass
            elif len(features_tensor.shape) == 1:
                # (features,) - add a batch dimension
                features_tensor = features_tensor.unsqueeze(0)
            else:
                # Flatten any higher-rank input to (batch_size, features)
                features_tensor = features_tensor.view(features_tensor.shape[0], -1)

            # Cap the feature dimension to avoid shape mismatches with the model's input layer
            if features_tensor.shape[1] > 1000:
                features_tensor = features_tensor[:, :1000]

            outputs = model(features_tensor)
            loss = criterion(outputs, targets_tensor)

            loss.backward()
            optimizer.step()

            return loss.item()

        except Exception as e:
            logger.error(f"Error in CNN training: {e}")
            return 1.0  # Default loss value on error

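    # Shape sketch for _perform_real_cnn_training (illustrative only; assumes the
    # flattened 15x20 features produced by _create_cnn_training_sequences):
    #
    #   features = np.random.rand(32, 300).astype(np.float32)  # (batch, 15*20)
    #   targets = np.random.randint(0, 3, size=32)             # 0=DOWN, 1=SAME, 2=UP
    #   loss = self._perform_real_cnn_training(features, targets)
    #   # returns a plain float suitable for performance_history['cnn_losses']
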
    def _perform_validation(self):
        """Perform validation to track model performance"""
        try:
            # Validate DQN performance
            dqn_score = self._validate_dqn_performance()

            # Validate CNN performance
            cnn_score = self._validate_cnn_performance()

            # Update validation history
            validation_result = {
                'timestamp': time.time(),
                'dqn_score': dqn_score,
                'cnn_score': cnn_score,
                'combined_score': (dqn_score + cnn_score) / 2
            }

            self.performance_history['validation_scores'].append(validation_result)

            logger.info(f"VALIDATION: DQN={dqn_score:.3f}, CNN={cnn_score:.3f}, Combined={validation_result['combined_score']:.3f}")

        except Exception as e:
            logger.error(f"Error in validation: {e}")

    def _validate_dqn_performance(self) -> float:
        """Score DQN performance from its recent loss trend"""
        try:
            if len(self.performance_history['dqn_losses']) < 10:
                return 0.5  # Neutral score until enough data accumulates

            # Fit a linear trend to the last 10 losses
            recent_losses = list(self.performance_history['dqn_losses'])[-10:]
            loss_trend = np.polyfit(range(len(recent_losses)), recent_losses, 1)[0]

            # A negative slope (improving loss) maps to a higher score;
            # tanh yields a raw score in (-0.5, 1.5), clamped to [0, 1] below
            score = 0.5 + np.tanh(-loss_trend * 1000)

            return max(0.0, min(1.0, score))

        except Exception as e:
            logger.debug(f"Error validating DQN: {e}")
            return 0.5

    def _validate_cnn_performance(self) -> float:
        """Score CNN performance from its recent loss trend"""
        try:
            if len(self.performance_history['cnn_losses']) < 10:
                return 0.5  # Neutral score until enough data accumulates

            # Fit a linear trend to the last 10 losses; a negative slope means improving
            recent_losses = list(self.performance_history['cnn_losses'])[-10:]
            loss_trend = np.polyfit(range(len(recent_losses)), recent_losses, 1)[0]

            score = 0.5 + np.tanh(-loss_trend * 100)

            return max(0.0, min(1.0, score))

        except Exception as e:
            logger.debug(f"Error validating CNN: {e}")
            return 0.5

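    # Worked example of the loss-trend scoring used by both validators
    # (illustrative numbers; the DQN validator scales the slope by 1000,
    # the CNN validator by 100):
    #
    #   losses = [0.90, 0.88, 0.86, 0.84, 0.82, 0.80, 0.78, 0.76, 0.74, 0.72]
    #   slope = np.polyfit(range(10), losses, 1)[0]   # -0.02 (improving)
    #   0.5 + np.tanh(0.02 * 1000)                    # ~1.5 -> clamped to 1.0
    #   # A flat loss curve gives slope ~0 -> score ~0.5 (neutral).
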
    def _adapt_learning_parameters(self):
        """Adapt learning parameters based on performance"""
        try:
            if len(self.performance_history['validation_scores']) < 5:
                return

            recent_scores = [v['combined_score'] for v in list(self.performance_history['validation_scores'])[-5:]]
            avg_score = np.mean(recent_scores)

            # Adapt training frequency based on performance
            if avg_score < 0.4:  # Poor performance
                self.training_config['dqn_training_interval'] = max(3, self.training_config['dqn_training_interval'] - 1)
                self.training_config['cnn_training_interval'] = max(5, self.training_config['cnn_training_interval'] - 2)
                logger.info("ADAPTATION: Increased training frequency due to poor performance")
            elif avg_score > 0.7:  # Good performance
                self.training_config['dqn_training_interval'] = min(10, self.training_config['dqn_training_interval'] + 1)
                self.training_config['cnn_training_interval'] = min(15, self.training_config['cnn_training_interval'] + 2)
                logger.info("ADAPTATION: Decreased training frequency due to good performance")

        except Exception as e:
            logger.debug(f"Error in parameter adaptation: {e}")

    def _log_training_progress(self):
        """Log comprehensive training progress"""
        try:
            stats = {
                'iteration': self.training_iteration,
                'experience_buffer': len(self.experience_buffer),
                'priority_buffer': len(self.priority_buffer),
                'dqn_memory': self._get_dqn_memory_size(),
                'data_streams': {
                    'ohlcv_1m': len(self.real_time_data['ohlcv_1m']),
                    'ticks': len(self.real_time_data['ticks']),
                    'cob_snapshots': len(self.real_time_data['cob_snapshots']),
                    'market_events': len(self.real_time_data['market_events'])
                }
            }

            if self.performance_history['dqn_losses']:
                stats['dqn_avg_loss'] = np.mean(list(self.performance_history['dqn_losses'])[-10:])

            if self.performance_history['cnn_losses']:
                stats['cnn_avg_loss'] = np.mean(list(self.performance_history['cnn_losses'])[-10:])

            if self.performance_history['validation_scores']:
                stats['validation_score'] = self.performance_history['validation_scores'][-1]['combined_score']

            logger.info(f"ENHANCED TRAINING PROGRESS: {stats}")

        except Exception as e:
            logger.debug(f"Error logging progress: {e}")

    def _validation_worker(self):
        """Background worker for continuous validation"""
        while self.is_training:
            try:
                time.sleep(30)  # Validate every 30 seconds

                # Quick performance check
                if len(self.performance_history['validation_scores']) >= 2:
                    recent_scores = [v['combined_score'] for v in list(self.performance_history['validation_scores'])[-2:]]
                    if recent_scores[-1] < recent_scores[-2] - 0.1:  # Performance dropped
                        logger.warning("VALIDATION: Performance drop detected - consider model adjustment")

            except Exception as e:
                logger.debug(f"Error in validation worker: {e}")
                time.sleep(60)

    def _calculate_rsi(self, prices, period=14):
        """Calculate the RSI indicator (simple-average variant, not Wilder smoothing)"""
        try:
            if len(prices) < period + 1:
                return 50.0

            deltas = np.diff(prices)
            gains = np.where(deltas > 0, deltas, 0)
            losses = np.where(deltas < 0, -deltas, 0)

            avg_gain = np.mean(gains[-period:])
            avg_loss = np.mean(losses[-period:])

            if avg_loss == 0:
                return 100.0

            rs = avg_gain / avg_loss
            rsi = 100 - (100 / (1 + rs))
            return float(rsi)
        except Exception:
            return 50.0

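    # Worked RSI example (illustrative numbers): with avg_gain = 1.0 and
    # avg_loss = 0.5 over the last 14 deltas,
    #
    #   rs = 1.0 / 0.5 = 2.0
    #   rsi = 100 - 100 / (1 + 2.0) = 66.67  (bullish-leaning momentum)
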
    def _get_dqn_memory_size(self) -> int:
        """Get the DQN agent's replay memory size"""
        try:
            if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
                    and self.orchestrator.rl_agent and hasattr(self.orchestrator.rl_agent, 'memory')):
                return len(self.orchestrator.rl_agent.memory)
            return 0
        except Exception:
            return 0

    def get_training_statistics(self) -> Dict[str, Any]:
        """Get comprehensive training statistics"""
        try:
            stats = {
                'is_training': self.is_training,
                'training_iteration': self.training_iteration,
                'experience_buffer_size': len(self.experience_buffer),
                'priority_buffer_size': len(self.priority_buffer),
                'data_collection_stats': {
                    'ohlcv_1m_bars': len(self.real_time_data['ohlcv_1m']),
                    'tick_data_points': len(self.real_time_data['ticks']),
                    'cob_snapshots': len(self.real_time_data['cob_snapshots']),
                    'market_events': len(self.real_time_data['market_events'])
                },
                'performance_history': {
                    'dqn_loss_count': len(self.performance_history['dqn_losses']),
                    'cnn_loss_count': len(self.performance_history['cnn_losses']),
                    'validation_count': len(self.performance_history['validation_scores'])
                },
                'prediction_stats': {
                    'dqn_predictions': {symbol: len(predictions) for symbol, predictions in self.recent_dqn_predictions.items()},
                    'cnn_predictions': {symbol: len(predictions) for symbol, predictions in self.recent_cnn_predictions.items()},
                    'accuracy_history': {symbol: len(history) for symbol, history in self.prediction_accuracy_history.items()}
                }
            }

            if self.performance_history['dqn_losses']:
                stats['dqn_recent_loss'] = list(self.performance_history['dqn_losses'])[-1]

            if self.performance_history['cnn_losses']:
                stats['cnn_recent_loss'] = list(self.performance_history['cnn_losses'])[-1]

            if self.performance_history['validation_scores']:
                stats['recent_validation_score'] = self.performance_history['validation_scores'][-1]['combined_score']

            return stats

        except Exception as e:
            logger.error(f"Error getting training statistics: {e}")
            return {'error': str(e)}

    def capture_dqn_prediction(self, symbol: str, state: np.ndarray, q_values: List[float], action: int, confidence: float, price: float):
        """Capture DQN prediction for dashboard visualization"""
        try:
            prediction = {
                'timestamp': datetime.now(),
                'symbol': symbol,
                'state': state.tolist() if hasattr(state, 'tolist') else state,
                'q_values': q_values,
                'action': action,  # 0=BUY, 1=SELL, 2=HOLD
                'confidence': confidence,
                'price': price
            }

            if symbol in self.recent_dqn_predictions:
                self.recent_dqn_predictions[symbol].append(prediction)

            logger.debug(f"DQN prediction captured: {symbol} action={action} confidence={confidence:.2f}")

        except Exception as e:
            logger.debug(f"Error capturing DQN prediction: {e}")

    def capture_cnn_prediction(self, symbol: str, current_price: float, predicted_price: float, direction: int, confidence: float, features: Optional[np.ndarray] = None):
        """Capture CNN prediction for dashboard visualization"""
        try:
            prediction = {
                'timestamp': datetime.now(),
                'symbol': symbol,
                'current_price': current_price,
                'predicted_price': predicted_price,
                'direction': direction,  # 0=DOWN, 1=SAME, 2=UP
                'confidence': confidence,
                'features': features.tolist() if features is not None and hasattr(features, 'tolist') else None
            }

            if symbol in self.recent_cnn_predictions:
                self.recent_cnn_predictions[symbol].append(prediction)

            logger.debug(f"CNN prediction captured: {symbol} direction={direction} confidence={confidence:.2f}")

        except Exception as e:
            logger.debug(f"Error capturing CNN prediction: {e}")

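    # Consumption sketch (illustrative; assumes a hypothetical `training_system`
    # instance of this class wired into the dashboard):
    #
    #   for p in list(training_system.recent_cnn_predictions['ETH/USDT'])[-5:]:
    #       print(p['timestamp'], p['direction'], f"{p['confidence']:.2f}")
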
    def validate_prediction_accuracy(self, symbol: str, prediction_type: str, predicted_action: int, actual_price_change: float, confidence: float):
        """Validate prediction accuracy and store results"""
        try:
            # Determine if the prediction was correct
            was_correct = False

            if prediction_type == 'DQN':
                # For DQN: BUY (0) should be followed by a price increase, SELL (1) by a decrease
                if predicted_action == 0 and actual_price_change > 0.001:  # BUY + price up
                    was_correct = True
                elif predicted_action == 1 and actual_price_change < -0.001:  # SELL + price down
                    was_correct = True
                elif predicted_action == 2 and abs(actual_price_change) <= 0.001:  # HOLD + no change
                    was_correct = True

            elif prediction_type == 'CNN':
                # For CNN: direction prediction accuracy
                if predicted_action == 2 and actual_price_change > 0.001:  # UP + price up
                    was_correct = True
                elif predicted_action == 0 and actual_price_change < -0.001:  # DOWN + price down
                    was_correct = True
                elif predicted_action == 1 and abs(actual_price_change) <= 0.001:  # SAME + no change
                    was_correct = True

            # Accuracy score rewards confident correct calls and penalizes confident misses
            accuracy_score = confidence if was_correct else (1.0 - confidence)

            accuracy_data = {
                'timestamp': datetime.now(),
                'symbol': symbol,
                'prediction_type': prediction_type,
                'correct': was_correct,
                'accuracy_score': accuracy_score,
                'confidence': confidence,
                'actual_price_change': actual_price_change,
                'predicted_action': predicted_action
            }

            if symbol in self.prediction_accuracy_history:
                self.prediction_accuracy_history[symbol].append(accuracy_data)

            logger.debug(f"Prediction accuracy validated: {symbol} {prediction_type} correct={was_correct} score={accuracy_score:.2f}")

        except Exception as e:
            logger.debug(f"Error validating prediction accuracy: {e}")

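    # Score semantics (worked example): a correct call at confidence 0.8 scores
    # 0.8; the same call when wrong scores 1 - 0.8 = 0.2. A coin-flip model at
    # confidence 0.5 scores 0.5 either way, so sustained scores above 0.5
    # indicate calibrated, better-than-chance predictions.
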
    def get_prediction_summary(self, symbol: str) -> Dict[str, Any]:
        """Get prediction summary for a symbol"""
        try:
            summary = {
                'symbol': symbol,
                'dqn_predictions': len(self.recent_dqn_predictions.get(symbol, [])),
                'cnn_predictions': len(self.recent_cnn_predictions.get(symbol, [])),
                'accuracy_history': len(self.prediction_accuracy_history.get(symbol, [])),
                'pending_predictions': len(self.pending_predictions.get(symbol, []))
            }

            # Calculate accuracy statistics
            if symbol in self.prediction_accuracy_history and self.prediction_accuracy_history[symbol]:
                accuracy_data = list(self.prediction_accuracy_history[symbol])

                total_predictions = len(accuracy_data)
                correct_predictions = sum(1 for acc in accuracy_data if acc['correct'])

                summary['total_predictions'] = total_predictions
                summary['correct_predictions'] = correct_predictions
                summary['accuracy_rate'] = correct_predictions / total_predictions if total_predictions > 0 else 0.0

                # Break accuracy down by prediction type
                dqn_accuracy_data = [acc for acc in accuracy_data if acc['prediction_type'] == 'DQN']
                cnn_accuracy_data = [acc for acc in accuracy_data if acc['prediction_type'] == 'CNN']

                if dqn_accuracy_data:
                    dqn_correct = sum(1 for acc in dqn_accuracy_data if acc['correct'])
                    summary['dqn_accuracy_rate'] = dqn_correct / len(dqn_accuracy_data)
                else:
                    summary['dqn_accuracy_rate'] = 0.0

                if cnn_accuracy_data:
                    cnn_correct = sum(1 for acc in cnn_accuracy_data if acc['correct'])
                    summary['cnn_accuracy_rate'] = cnn_correct / len(cnn_accuracy_data)
                else:
                    summary['cnn_accuracy_rate'] = 0.0

            return summary

        except Exception as e:
            logger.error(f"Error getting prediction summary: {e}")
            return {'error': str(e)}

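    # Example return value (illustrative numbers only):
    #
    #   {'symbol': 'ETH/USDT', 'dqn_predictions': 42, 'cnn_predictions': 17,
    #    'accuracy_history': 60, 'pending_predictions': 3,
    #    'total_predictions': 60, 'correct_predictions': 38,
    #    'accuracy_rate': 0.633, 'dqn_accuracy_rate': 0.65,
    #    'cnn_accuracy_rate': 0.60}
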
    def generate_forward_looking_predictions(self):
        """Generate forward-looking predictions based on current market data"""
        try:
            current_time = time.time()

            for symbol in ['ETH/USDT', 'BTC/USDT']:
                # Track DQN and CNN cadences under separate keys; a single shared
                # timestamp would let the 30s DQN timer reset the clock before the
                # 60s CNN window could ever elapse
                if current_time - self.last_prediction_time.get(f"{symbol}_dqn", 0) >= self.prediction_intervals['dqn']:
                    self._generate_forward_dqn_prediction(symbol, current_time)

                if current_time - self.last_prediction_time.get(f"{symbol}_cnn", 0) >= self.prediction_intervals['cnn']:
                    self._generate_forward_cnn_prediction(symbol, current_time)

                # Validate pending predictions whose target time has arrived
                self._validate_pending_predictions(symbol, current_time)

        except Exception as e:
            logger.error(f"Error generating forward-looking predictions: {e}")

    def _generate_forward_dqn_prediction(self, symbol: str, current_time: float):
        """Generate a DQN prediction for future price movement"""
        try:
            # Get current market state (historical data only - no lookahead)
            current_state = self._build_comprehensive_state()
            current_price = self._get_current_price_from_data(symbol)

            if current_price is None:
                return

            # Use the DQN model to predict an action (if available)
            if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
                    and self.orchestrator.rl_agent):

                action = self.orchestrator.rl_agent.act(current_state, explore=False)

                # Read Q-values from the policy network if it is exposed
                if hasattr(self.orchestrator.rl_agent, 'policy_net'):
                    with torch.no_grad():
                        state_tensor = torch.FloatTensor(current_state).unsqueeze(0).to(self.orchestrator.rl_agent.device)
                        q_values_tensor = self.orchestrator.rl_agent.policy_net(state_tensor)
                        if isinstance(q_values_tensor, tuple):
                            q_values = q_values_tensor[0].cpu().numpy()[0].tolist()
                        else:
                            q_values = q_values_tensor.cpu().numpy()[0].tolist()
                else:
                    q_values = [0.33, 0.33, 0.34]  # Default uniform distribution

                # Crude confidence proxy; assumes non-negative Q-values
                confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33

            else:
                # Fall back to a technical-analysis-based prediction
                action, q_values, confidence = self._technical_analysis_prediction(symbol)

            # Create a forward-looking prediction targeting 5 minutes ahead
            prediction_time = datetime.now()
            target_time = prediction_time + timedelta(minutes=5)

            prediction = {
                'id': f"dqn_{symbol}_{int(current_time)}",
                'type': 'DQN',
                'symbol': symbol,
                'prediction_time': prediction_time,
                'target_time': target_time,
                'current_price': current_price,
                'predicted_action': action,
                'q_values': q_values,
                'confidence': confidence,
                'state': current_state.tolist() if hasattr(current_state, 'tolist') else current_state,
                'validated': False
            }

            # Queue for validation once the target time arrives
            if symbol in self.pending_predictions:
                self.pending_predictions[symbol].append(prediction)

            # Surface on the dashboard only if confident enough
            if confidence > 0.4:
                display_prediction = {
                    'timestamp': prediction_time,
                    'price': current_price,
                    'action': action,
                    'confidence': confidence,
                    'q_values': q_values
                }
                if symbol in self.recent_dqn_predictions:
                    self.recent_dqn_predictions[symbol].append(display_prediction)

            self.last_prediction_time[f"{symbol}_dqn"] = current_time

            logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")

        except Exception as e:
            logger.error(f"Error generating forward DQN prediction: {e}")

    def _generate_forward_cnn_prediction(self, symbol: str, current_time: float):
        """Generate a CNN prediction for future price direction"""
        try:
            # Get current price and historical sequence (past data only)
            current_price = self._get_current_price_from_data(symbol)
            price_sequence = self._get_historical_price_sequence(symbol, periods=15)

            if current_price is None or len(price_sequence) < 15:
                return

            features = None  # Only set when the CNN path runs; the fallback has no feature vector

            # Use the CNN model to predict direction (if available)
            if (self.orchestrator and hasattr(self.orchestrator, 'cnn_model')
                    and self.orchestrator.cnn_model):

                # Prepare features for the CNN
                features = self._prepare_cnn_features(price_sequence)

                try:
                    # Get prediction from the CNN model
                    prediction_output = self.orchestrator.cnn_model.predict(features)
                    if hasattr(prediction_output, 'tolist'):
                        pred_probs = prediction_output.tolist()
                    else:
                        pred_probs = [0.33, 0.33, 0.34]  # Default

                    direction = int(np.argmax(pred_probs))  # 0=DOWN, 1=SAME, 2=UP
                    confidence = max(pred_probs)

                except Exception as e:
                    logger.debug(f"CNN model prediction failed: {e}")
                    direction, confidence = self._technical_direction_prediction(symbol)

            else:
                # Fall back to technical analysis
                direction, confidence = self._technical_direction_prediction(symbol)

            # Derive a predicted price from the direction and confidence
            price_change_percent = self._estimate_price_change(direction, confidence)
            predicted_price = current_price * (1 + price_change_percent)

            # Create a forward-looking prediction targeting 10 minutes ahead
            prediction_time = datetime.now()
            target_time = prediction_time + timedelta(minutes=10)

            prediction = {
                'id': f"cnn_{symbol}_{int(current_time)}",
                'type': 'CNN',
                'symbol': symbol,
                'prediction_time': prediction_time,
                'target_time': target_time,
                'current_price': current_price,
                'predicted_price': predicted_price,
                'direction': direction,
                'confidence': confidence,
                'features': features.tolist() if features is not None and hasattr(features, 'tolist') else None,
                'validated': False
            }

            # Queue for validation once the target time arrives
            if symbol in self.pending_predictions:
                self.pending_predictions[symbol].append(prediction)

            # Surface on the dashboard only if confident enough
            if confidence > 0.5:
                display_prediction = {
                    'timestamp': prediction_time,
                    'current_price': current_price,
                    'predicted_price': predicted_price,
                    'direction': direction,
                    'confidence': confidence
                }
                if symbol in self.recent_cnn_predictions:
                    self.recent_cnn_predictions[symbol].append(display_prediction)

            self.last_prediction_time[f"{symbol}_cnn"] = current_time

            logger.info(f"Forward CNN prediction: {symbol} direction={['DOWN','SAME','UP'][direction]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")

        except Exception as e:
            logger.error(f"Error generating forward CNN prediction: {e}")

    def _validate_pending_predictions(self, symbol: str, current_time: float):
        """Validate pending predictions once their target time arrives"""
        try:
            if symbol not in self.pending_predictions:
                return

            current_datetime = datetime.now()
            validated_predictions = []

            # Check each pending prediction
            for prediction in list(self.pending_predictions[symbol]):
                target_time = prediction['target_time']

                # If the target time has passed, validate the prediction
                if current_datetime >= target_time:
                    actual_price = self._get_current_price_from_data(symbol)

                    if actual_price is not None:
                        # Realized price change since the prediction was made
                        actual_change = (actual_price - prediction['current_price']) / prediction['current_price']

                        # Validate based on prediction type
                        if prediction['type'] == 'DQN':
                            was_correct = self._validate_dqn_prediction(prediction, actual_change)
                        else:  # CNN
                            was_correct = self._validate_cnn_prediction(prediction, actual_change)

                        # Store the accuracy result
                        accuracy_data = {
                            'timestamp': current_datetime,
                            'symbol': symbol,
                            'prediction_type': prediction['type'],
                            'correct': was_correct,
                            'accuracy_score': prediction['confidence'] if was_correct else (1.0 - prediction['confidence']),
                            'confidence': prediction['confidence'],
                            'actual_price_change': actual_change,
                            'predicted_action': prediction.get('predicted_action', prediction.get('direction', 0)),
                            'actual_price': actual_price
                        }

                        if symbol in self.prediction_accuracy_history:
                            self.prediction_accuracy_history[symbol].append(accuracy_data)

                        validated_predictions.append(prediction['id'])

                        logger.info(f"Validated {prediction['type']} prediction: {symbol} correct={was_correct} confidence={prediction['confidence']:.2f}")

            # Remove validated predictions from the pending queue
            if validated_predictions:
                self.pending_predictions[symbol] = deque([
                    p for p in self.pending_predictions[symbol]
                    if p['id'] not in validated_predictions
                ], maxlen=100)

        except Exception as e:
            logger.error(f"Error validating pending predictions: {e}")

    def _validate_dqn_prediction(self, prediction: Dict, actual_change: float) -> bool:
        """Validate a DQN action prediction against the realized price change"""
        predicted_action = prediction['predicted_action']
        threshold = 0.005  # 0.5% threshold for a significant movement

        if predicted_action == 0:  # BUY prediction
            return actual_change > threshold
        elif predicted_action == 1:  # SELL prediction
            return actual_change < -threshold
        else:  # HOLD prediction
            return abs(actual_change) <= threshold

    def _validate_cnn_prediction(self, prediction: Dict, actual_change: float) -> bool:
        """Validate a CNN direction prediction against the realized price change"""
        predicted_direction = prediction['direction']
        threshold = 0.002  # 0.2% threshold for direction

        if predicted_direction == 2:  # UP prediction
            return actual_change > threshold
        elif predicted_direction == 0:  # DOWN prediction
            return actual_change < -threshold
        else:  # SAME prediction
            return abs(actual_change) <= threshold

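    # Worked example of threshold validation (illustrative): a DQN BUY at
    # $3000 followed by a move to $3020 gives actual_change = 20/3000 = 0.0067,
    # which clears the 0.005 DQN threshold (correct). The same move also clears
    # the looser 0.002 CNN threshold for an UP call. A move to only $3004
    # (0.0013) stays inside both thresholds, so it would validate only a
    # HOLD (DQN) or SAME (CNN) call.
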
    def _get_current_price_from_data(self, symbol: str) -> Optional[float]:
        """Get the current price from real-time data streams

        Note: the 1m OHLCV buffer is shared rather than keyed by symbol, so the
        `symbol` argument is currently unused.
        """
        try:
            if len(self.real_time_data['ohlcv_1m']) > 0:
                return self.real_time_data['ohlcv_1m'][-1]['close']
            return None
        except Exception as e:
            logger.debug(f"Error getting current price: {e}")
            return None

    def _get_historical_price_sequence(self, symbol: str, periods: int = 15) -> List[float]:
        """Get historical close-price sequence for CNN features"""
        try:
            if len(self.real_time_data['ohlcv_1m']) >= periods:
                return [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-periods:]]
            return []
        except Exception as e:
            logger.debug(f"Error getting price sequence: {e}")
            return []

    def _technical_analysis_prediction(self, symbol: str) -> Tuple[int, List[float], float]:
        """Fallback technical analysis prediction for DQN"""
        try:
            # Simple momentum over the last 5 one-minute bars
            if len(self.real_time_data['ohlcv_1m']) >= 5:
                recent_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-5:]]
                momentum = (recent_prices[-1] - recent_prices[0]) / recent_prices[0]

                if momentum > 0.01:  # 1% upward momentum
                    return 0, [0.6, 0.2, 0.2], 0.6  # BUY
                elif momentum < -0.01:  # 1% downward momentum
                    return 1, [0.2, 0.6, 0.2], 0.6  # SELL
                else:
                    return 2, [0.2, 0.2, 0.6], 0.6  # HOLD

            return 2, [0.33, 0.33, 0.34], 0.33  # Default HOLD

        except Exception as e:
            logger.debug(f"Error in technical analysis prediction: {e}")
            return 2, [0.33, 0.33, 0.34], 0.33

    def _technical_direction_prediction(self, symbol: str) -> Tuple[int, float]:
        """Fallback technical analysis for CNN direction"""
        try:
            if len(self.real_time_data['ohlcv_1m']) >= 3:
                recent_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-3:]]
                short_momentum = (recent_prices[-1] - recent_prices[-2]) / recent_prices[-2]

                if short_momentum > 0.005:  # 0.5% short-term up
                    return 2, 0.65  # UP
                elif short_momentum < -0.005:  # 0.5% short-term down
                    return 0, 0.65  # DOWN
                else:
                    return 1, 0.55  # SAME

            return 1, 0.5  # Default SAME

        except Exception as e:
            logger.debug(f"Error in technical direction prediction: {e}")
            return 1, 0.5

    def _prepare_cnn_features(self, price_sequence: List[float]) -> np.ndarray:
        """Prepare features for CNN model"""
        try:
            if len(price_sequence) >= 15:
                # Normalize prices relative to the first price in the window
                base_price = price_sequence[0]
                normalized = [(p - base_price) / base_price for p in price_sequence]

                # Build a (15 x 20) feature matrix, then flatten it
                features = np.zeros((15, 20))
                for i, norm_price in enumerate(normalized):
                    features[i, 0] = norm_price  # Normalized price level
                    if i > 0:
                        features[i, 1] = normalized[i] - normalized[i - 1]  # Price change

                return features.flatten()

            return np.zeros(300)  # Default feature vector (15 * 20)

        except Exception as e:
            logger.debug(f"Error preparing CNN features: {e}")
            return np.zeros(300)

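    # Worked normalization example (illustrative, showing the first three closes
    # of a 15-bar window): for closes [3000, 3030, 3015] the base price is 3000,
    # so the normalized series is [0.0, 0.01, 0.005]; column 0 stores those
    # levels and column 1 stores the step-to-step deltas (0.01, then -0.005).
    # Only 2 of the 20 columns are populated; the rest stay zero-padded.
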
    def _estimate_price_change(self, direction: int, confidence: float) -> float:
        """Estimate price change percentage based on direction and confidence"""
        try:
            # Base change scaled by confidence (up to 1% at full confidence)
            base_change = 0.01 * confidence

            if direction == 2:  # UP
                return base_change
            elif direction == 0:  # DOWN
                return -base_change
            else:  # SAME
                return 0.0

        except Exception as e:
            logger.debug(f"Error estimating price change: {e}")
            return 0.0
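
    # Worked example (illustrative): direction=2 (UP) at confidence 0.65 yields
    # 0.01 * 0.65 = +0.0065, so a $3000 price maps to a predicted price of
    # 3000 * 1.0065 = $3019.50 in _generate_forward_cnn_prediction.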