detecting local extremes and training on them
This commit is contained in:
@ -22,10 +22,13 @@ from collections import deque
|
||||
import torch
|
||||
|
||||
from .config import get_config
|
||||
from .data_provider import DataProvider, RawTick, OHLCVBar
|
||||
from .data_provider import DataProvider, RawTick, OHLCVBar, MarketTick
|
||||
from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
|
||||
from .realtime_tick_processor import RealTimeTickProcessor, ProcessedTickFeatures, integrate_with_orchestrator
|
||||
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface
|
||||
from .extrema_trainer import ExtremaTrainer
|
||||
from .trading_action import TradingAction
|
||||
from .negative_case_trainer import NegativeCaseTrainer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -87,6 +90,28 @@ class PerfectMove:
|
||||
market_state_after: MarketState
|
||||
confidence_should_have_been: float
|
||||
|
||||
@dataclass
class TradeInfo:
    """Record describing a trade that is currently active.

    Attributes:
        symbol: Symbol the trade belongs to.
        side: Trade direction, either 'LONG' or 'SHORT'.
        entry_price: Price recorded when the trade was entered.
        entry_time: Time recorded when the trade was entered.
        size: Size of the trade.
        confidence: Confidence value stored with the trade.
        market_state: Free-form mapping stored with the trade
            (presumably the market state at entry -- confirm with callers).
    """

    symbol: str
    side: str
    entry_price: float
    entry_time: datetime
    size: float
    confidence: float
    market_state: Dict[str, Any]
|
||||
|
||||
@dataclass
class LearningCase:
    """One learning case used for DQN sensitivity training.

    Attributes:
        state_vector: State array associated with the case.
        action: Sensitivity level that was chosen.
        reward: Reward value stored for the case.
        next_state_vector: State array that follows ``state_vector``.
        done: Boolean flag stored with the case (presumably the usual
            episode-finished marker -- confirm with the training code).
        trade_info: The ``TradeInfo`` the case was derived from.
        outcome: Trade P&L expressed as a percentage.
    """

    state_vector: np.ndarray
    action: int
    reward: float
    next_state_vector: np.ndarray
    done: bool
    trade_info: TradeInfo
    outcome: float
|
||||
|
||||
class EnhancedTradingOrchestrator:
|
||||
"""
|
||||
Enhanced orchestrator with sophisticated multi-modal decision making
|
||||
@ -105,6 +130,16 @@ class EnhancedTradingOrchestrator:
|
||||
# Initialize real-time tick processor for ultra-low latency processing
|
||||
self.tick_processor = RealTimeTickProcessor(symbols=self.config.symbols)
|
||||
|
||||
# Initialize extrema trainer for local bottom/top detection and 200-candle context
|
||||
self.extrema_trainer = ExtremaTrainer(
|
||||
data_provider=self.data_provider,
|
||||
symbols=self.config.symbols,
|
||||
window_size=10 # 10-candle window for extrema detection
|
||||
)
|
||||
|
||||
# Initialize negative case trainer for intensive training on losing trades
|
||||
self.negative_case_trainer = NegativeCaseTrainer()
|
||||
|
||||
# Real-time tick features storage
|
||||
self.realtime_tick_features = {symbol: deque(maxlen=100) for symbol in self.config.symbols}
|
||||
|
||||
@ -151,6 +186,18 @@ class EnhancedTradingOrchestrator:
|
||||
self.retrospective_learning_active = False
|
||||
self.last_retrospective_analysis = datetime.now()
|
||||
|
||||
# Local extrema tracking for training on bottoms and tops
|
||||
self.local_extrema = {symbol: deque(maxlen=1000) for symbol in self.symbols}
|
||||
self.extrema_detection_window = 10 # Look for extrema in 10-candle windows
|
||||
self.extrema_training_queue = deque(maxlen=500) # Queue for extrema-based training
|
||||
self.last_extrema_check = {symbol: datetime.now() for symbol in self.symbols}
|
||||
|
||||
# 200-candle context data for models
|
||||
self.context_data_1m = {symbol: deque(maxlen=200) for symbol in self.symbols}
|
||||
self.context_features_1m = {symbol: None for symbol in self.symbols}
|
||||
self.context_update_frequency = 60 # Update context every 60 seconds
|
||||
self.last_context_update = {symbol: datetime.now() for symbol in self.symbols}
|
||||
|
||||
# RL feedback system
|
||||
self.rl_evaluation_queue = deque(maxlen=1000)
|
||||
self.environment_adaptation_rate = 0.01
|
||||
@ -182,6 +229,9 @@ class EnhancedTradingOrchestrator:
|
||||
# Current open positions tracking for closing logic
|
||||
self.open_positions = {} # symbol -> {'side': str, 'entry_price': float, 'timestamp': datetime}
|
||||
|
||||
# Initialize 200-candle context data
|
||||
self._initialize_context_data()
|
||||
|
||||
logger.info("Enhanced TradingOrchestrator initialized with Universal Data Format")
|
||||
logger.info(f"Symbols: {self.symbols}")
|
||||
logger.info(f"Timeframes: {self.timeframes}")
|
||||
@ -192,6 +242,8 @@ class EnhancedTradingOrchestrator:
|
||||
logger.info("Raw tick and OHLCV bar processing enabled for pattern detection")
|
||||
logger.info("Enhanced retrospective learning enabled for perfect opportunity detection")
|
||||
logger.info("DQN RL-based sensitivity learning enabled for adaptive thresholds")
|
||||
logger.info("Local extrema detection enabled for bottom/top training")
|
||||
logger.info("200-candle 1m context data initialized for enhanced model performance")
|
||||
|
||||
def _initialize_timeframe_weights(self) -> Dict[str, float]:
|
||||
"""Initialize weights for different timeframes"""
|
||||
@ -713,7 +765,7 @@ class EnhancedTradingOrchestrator:
|
||||
try:
|
||||
if symbol not in self.active_trades:
|
||||
return
|
||||
|
||||
|
||||
trade_info = self.active_trades[symbol]
|
||||
|
||||
# Calculate trade outcome
|
||||
@ -759,7 +811,7 @@ class EnhancedTradingOrchestrator:
|
||||
del self.active_trades[symbol]
|
||||
|
||||
logger.info(f"Closed trade for sensitivity learning: {symbol} {side} P&L: {pnl_pct*100:+.2f}% Duration: {duration:.0f}s")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error tracking trade closing for sensitivity learning: {e}")
|
||||
|
||||
@ -818,7 +870,7 @@ class EnhancedTradingOrchestrator:
|
||||
'price_change_4': price_changes[-4] if len(price_changes) > 3 else 0.0,
|
||||
'price_change_5': price_changes[-5] if len(price_changes) > 4 else 0.0
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting market state for sensitivity learning: {e}")
|
||||
return self._get_default_market_state()
|
||||
@ -969,7 +1021,7 @@ class EnhancedTradingOrchestrator:
|
||||
final_reward = np.clip(final_reward, -2.0, 2.0)
|
||||
|
||||
return float(final_reward)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating sensitivity reward: {e}")
|
||||
return 0.0
|
||||
@ -1045,7 +1097,7 @@ class EnhancedTradingOrchestrator:
|
||||
|
||||
# Update current sensitivity level based on recent performance
|
||||
self._update_current_sensitivity_level()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training sensitivity DQN: {e}")
|
||||
|
||||
@ -1131,6 +1183,374 @@ class EnhancedTradingOrchestrator:
|
||||
"""Get current closing threshold"""
|
||||
return self.confidence_threshold_close
|
||||
|
||||
def _initialize_context_data(self):
    """Seed the 1m context buffers with up to 200 historical candles per symbol.

    For every symbol in ``self.symbols`` this asks ``self.data_provider`` for
    200 '1m' candles, appends one dict per returned row to
    ``self.context_data_1m[symbol]`` and stores the output of
    ``_create_context_features`` in ``self.context_features_1m[symbol]``.
    A failure for one symbol is logged and does not stop the remaining ones.
    """
    candle_fields = ('timestamp', 'open', 'high', 'low', 'close', 'volume')

    try:
        logger.info("Initializing 200-candle 1m context data for enhanced model performance")

        for symbol in self.symbols:
            try:
                history = self.data_provider.get_historical_data(symbol, '1m', limit=200)

                # Nothing usable came back: warn and move on to the next symbol.
                if history is None or len(history) == 0:
                    logger.warning(f"No 1m context data available for {symbol}")
                    continue

                # Keep the raw candles, one plain dict per row.
                for _, candle_row in history.iterrows():
                    self.context_data_1m[symbol].append(
                        {field: candle_row[field] for field in candle_fields}
                    )

                # Feature matrix handed to the models.
                self.context_features_1m[symbol] = self._create_context_features(history)

                logger.info(f"Loaded {len(history)} 1m candles for {symbol} context")

            except Exception as e:
                logger.error(f"Error loading context data for {symbol}: {e}")

    except Exception as e:
        logger.error(f"Error initializing context data: {e}")
|
||||
|
||||
def _create_context_features(self, df: pd.DataFrame) -> Optional[np.ndarray]:
    """Build the normalized feature matrix for a frame of 1m candles.

    Args:
        df: Candle frame. OHLCV columns are always requested; indicator
            columns are requested per group when the group's marker column
            (``rsi_14``, ``macd`` or ``bb_upper``) is present.

    Returns:
        The values of the normalized frame, or ``None`` when ``df`` is
        ``None``, has fewer than 50 rows, normalization returns ``None``,
        or any error occurs (the error is logged).
    """
    try:
        if df is None or len(df) < 50:
            return None

        # (marker column, columns requested when the marker is present)
        indicator_groups = (
            ('rsi_14', ['rsi_14', 'sma_20', 'ema_12']),
            ('macd', ['macd', 'macd_signal']),
            ('bb_upper', ['bb_upper', 'bb_lower', 'bb_percent']),
        )

        requested = ['open', 'high', 'low', 'close', 'volume']
        for marker, group in indicator_groups:
            if marker in df.columns:
                requested += group

        # Only keep what the frame really contains.
        present = [name for name in requested if name in df.columns]
        selected = df[present].copy()

        normalized = self._normalize_context_features(selected)
        if normalized is None:
            return None
        return normalized.values

    except Exception as e:
        logger.error(f"Error creating context features: {e}")
        return None
|
||||
|
||||
def _normalize_context_features(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
    """Scale context features into comparable ranges for model input.

    Normalization rules:
        * ``open``/``high``/``low``/``close``/``sma_20``/``ema_12``/
          ``bb_upper``/``bb_lower`` are divided by the latest close.
        * ``volume`` is divided by its mean when the mean is positive.
        * ``rsi_14`` is divided by 100.
        * ``macd`` (and ``macd_signal`` when ``macd`` is present) are divided
          by the latest close.
        * ``bb_percent`` is clipped to ``[0, 1]``.
        * NaN and +/-inf are replaced by 0.

    Every division by the latest close is skipped when that close is not
    strictly positive. The MACD columns previously lacked this guard and
    could turn into inf, which ``fillna`` does not remove.

    Args:
        df: Frame holding any subset of the columns listed above.

    Returns:
        A normalized copy of ``df``. If anything raises, the error is logged
        and the original, un-normalized ``df`` is returned as a fallback.
    """
    try:
        df_norm = df.copy()

        # One reference price shared by every price-denominated column.
        latest_close = df_norm['close'].iloc[-1] if 'close' in df_norm.columns else None
        has_reference = latest_close is not None and latest_close > 0

        # Price normalization (relative to latest close)
        if has_reference:
            for col in ('open', 'high', 'low', 'close', 'sma_20', 'ema_12', 'bb_upper', 'bb_lower'):
                if col in df_norm.columns:
                    df_norm[col] = df_norm[col] / latest_close

        # Volume normalization
        if 'volume' in df_norm.columns:
            volume_mean = df_norm['volume'].mean()
            if volume_mean > 0:
                df_norm['volume'] = df_norm['volume'] / volume_mean

        # RSI normalization (0-100 to 0-1)
        if 'rsi_14' in df_norm.columns:
            df_norm['rsi_14'] = df_norm['rsi_14'] / 100.0

        # MACD normalization, guarded like the price columns so a zero or
        # negative close cannot produce inf or sign-flipped values.
        if has_reference and 'macd' in df_norm.columns:
            df_norm['macd'] = df_norm['macd'] / latest_close
            if 'macd_signal' in df_norm.columns:
                df_norm['macd_signal'] = df_norm['macd_signal'] / latest_close

        # BB percent is already a ratio; just bound it.
        if 'bb_percent' in df_norm.columns:
            df_norm['bb_percent'] = df_norm['bb_percent'].clip(lower=0, upper=1)

        # Models must never see NaN or inf.
        df_norm = df_norm.replace([np.inf, -np.inf], np.nan).fillna(0)

        return df_norm

    except Exception as e:
        logger.error(f"Error normalizing context features: {e}")
        return df
|
||||
|
||||
def update_context_data(self, symbol: Optional[str] = None):
    """Refresh the 1m context buffers for one symbol or for all symbols.

    A symbol is refreshed only when at least ``self.context_update_frequency``
    seconds have passed since ``self.last_context_update[symbol]``. On
    refresh, the 10 latest '1m' candles are requested from
    ``self.data_provider``; candles whose timestamp is newer than the last
    buffered one are appended to ``self.context_data_1m[symbol]``. The
    feature matrix is rebuilt once 50 candles are buffered, the update time
    is stamped, and ``_detect_local_extrema`` is run for the symbol.

    Errors are isolated per symbol, matching ``_initialize_context_data``:
    a failure for one symbol is logged and the remaining symbols are still
    refreshed. Previously a single exception aborted the whole loop.

    Args:
        symbol: Symbol to refresh; a falsy value refreshes ``self.symbols``.
    """
    try:
        symbols_to_update = [symbol] if symbol else self.symbols

        for sym in symbols_to_update:
            try:
                # Respect the per-symbol refresh interval.
                elapsed = (datetime.now() - self.last_context_update[sym]).total_seconds()
                if elapsed < self.context_update_frequency:
                    continue

                latest_data = self.data_provider.get_historical_data(sym, '1m', limit=10, refresh=True)
                if latest_data is None or len(latest_data) == 0:
                    continue

                buffered = self.context_data_1m[sym]
                for _, row in latest_data.iterrows():
                    # Only append candles strictly newer than the newest one
                    # already buffered; the refresh overlaps earlier fetches.
                    if buffered and not row['timestamp'] > buffered[-1]['timestamp']:
                        continue
                    buffered.append({
                        'timestamp': row['timestamp'],
                        'open': row['open'],
                        'high': row['high'],
                        'low': row['low'],
                        'close': row['close'],
                        'volume': row['volume'],
                    })

                # Update feature matrix
                if len(buffered) >= 50:
                    context_df = pd.DataFrame(list(buffered))
                    self.context_features_1m[sym] = self._create_context_features(context_df)

                self.last_context_update[sym] = datetime.now()

                # Check for local extrema in updated data
                self._detect_local_extrema(sym)

            except Exception as e:
                logger.error(f"Error updating context data for {sym}: {e}")

    except Exception as e:
        logger.error(f"Error updating context data: {e}")
|
||||
|
||||
def _detect_local_extrema(self, symbol: str):
    """Detect local bottoms and tops in the buffered 1m closes of ``symbol``.

    A candle is a local bottom (top) when its close is <= (>=) every close
    within ``self.extrema_detection_window`` candles on both sides. Each new
    extremum is appended to ``self.local_extrema[symbol]`` and
    ``self.extrema_training_queue``, logged, and passed to
    ``_create_extrema_perfect_move``.

    Fixes over the previous version:
        * At least ``2 * window + 1`` candles are required and the last
          ``3 * window`` are scanned. The old slice of exactly ``2 * window``
          candles made the candidate range empty, so nothing was ever found.
        * Candles already recorded for the symbol (matched by timestamp) are
          skipped, so overlapping scans do not create duplicates.
        * A window that is both "bottom" and "top" (all closes equal) is not
          reported; it used to be labelled a bottom.

    Args:
        symbol: Key into the per-symbol context and extrema containers.
    """
    try:
        window = self.extrema_detection_window
        candles = self.context_data_1m[symbol]

        # A candidate needs `window` candles on each side of itself.
        if window < 1 or len(candles) < 2 * window + 1:
            return

        # Scanning the last 3*window candles evaluates the `window` most
        # recent candles that have full context on both sides.
        recent_candles = list(candles)[-3 * window:]
        prices = [candle['close'] for candle in recent_candles]
        timestamps = [candle['timestamp'] for candle in recent_candles]

        already_recorded = {entry['timestamp'] for entry in self.local_extrema[symbol]}

        for i in range(window, len(prices) - window):
            current_time = timestamps[i]
            if current_time in already_recorded:
                continue

            current_price = prices[i]
            context_before = prices[i - window:i]
            context_after = prices[i + 1:i + window + 1]
            neighbours = context_before + context_after

            is_bottom = all(current_price <= price for price in neighbours)
            is_top = all(current_price >= price for price in neighbours)

            # Neither, or both (perfectly flat window): not an extremum.
            if is_bottom == is_top:
                continue

            extrema_type = 'bottom' if is_bottom else 'top'

            # Create training opportunity
            extrema_data = {
                'symbol': symbol,
                'timestamp': current_time,
                'price': current_price,
                'type': extrema_type,
                'context_before': context_before,
                'context_after': context_after,
                'optimal_action': 'BUY' if is_bottom else 'SELL',
                'confidence_level': self._calculate_extrema_confidence(prices, i, window),
                'market_context': self._get_extrema_market_context(symbol, current_time)
            }

            self.local_extrema[symbol].append(extrema_data)
            self.extrema_training_queue.append(extrema_data)
            already_recorded.add(current_time)

            logger.info(f"Local {extrema_type} detected for {symbol} at ${current_price:.2f} "
                        f"(confidence: {extrema_data['confidence_level']:.3f})")

            # Create perfect move for CNN training
            self._create_extrema_perfect_move(extrema_data)

        self.last_extrema_check[symbol] = datetime.now()

    except Exception as e:
        logger.error(f"Error detecting local extrema for {symbol}: {e}")
|
||||
|
||||
def _calculate_extrema_confidence(self, prices: List[float], extrema_index: int, window: int) -> float:
    """Score how pronounced the extremum at ``extrema_index`` is.

    The score is the distance between the extremum and the mean of its
    neighbours (up to ``window`` prices on each side), divided by the price
    range of that neighbourhood, then clamped to ``[0.3, 0.95]``.

    The previous formula measured the distance to the opposite extreme of
    the same neighbourhood. For a genuine window minimum or maximum that is
    always the full range, so every detected extremum scored exactly 0.95.
    Measuring against the neighbour mean keeps the same output range while
    separating sharp spikes (close to 0.95) from shallow turns (close to 0.3).

    Args:
        prices: Price series containing the extremum.
        extrema_index: Position of the extremum inside ``prices``.
        window: Number of neighbours considered on each side.

    Returns:
        Confidence in ``[0.3, 0.95]``; 0.5 when the neighbourhood is flat,
        has no neighbours, or an error occurs (the error is logged).
    """
    try:
        extrema_price = prices[extrema_index]

        start = max(0, extrema_index - window)
        neighbours = (list(prices[start:extrema_index])
                      + list(prices[extrema_index + 1:extrema_index + window + 1]))
        if not neighbours:
            return 0.5

        lowest = min(min(neighbours), extrema_price)
        highest = max(max(neighbours), extrema_price)
        price_range = highest - lowest

        if price_range == 0:
            return 0.5

        neighbour_mean = sum(neighbours) / len(neighbours)

        # How far the point sticks out from its surroundings, as a fraction
        # of the local range.
        if extrema_price == lowest:  # Bottom
            deviation = (neighbour_mean - extrema_price) / price_range
        else:  # Top
            deviation = (extrema_price - neighbour_mean) / price_range

        # Confidence based on how clear the extremum is.
        return float(min(0.95, max(0.3, deviation)))

    except Exception as e:
        logger.error(f"Error calculating extrema confidence: {e}")
        return 0.5
|
||||
|
||||
def _get_extrema_market_context(self, symbol: str, timestamp: datetime) -> Dict[str, Any]:
    """Summarize the latest 20 buffered 1m candles of ``symbol``.

    Args:
        symbol: Key into ``self.context_data_1m``.
        timestamp: Accepted for the caller's convenience; this method does
            not read it -- the summary always covers the newest candles.

    Returns:
        Dict with ``volatility`` (mean absolute relative close-to-close
        change), ``volume_spike`` (last volume above 1.5x the mean of the
        preceding ones), ``trend_strength`` (absolute relative change of the
        close over the last 10 candles) and ``rsi_level`` (always 50.0 here).
        With fewer than 20 candles, or on error (logged), the neutral
        defaults 0.0 / False / 0.0 / 50.0 are returned.
    """
    try:
        context = {
            'volatility': 0.0,
            'volume_spike': False,
            'trend_strength': 0.0,
            'rsi_level': 50.0
        }

        # Not enough history: hand back the neutral defaults.
        if len(self.context_data_1m[symbol]) < 20:
            return context

        latest = list(self.context_data_1m[symbol])[-20:]

        # Volatility from consecutive closes.
        closes = [candle['close'] for candle in latest]
        relative_moves = [abs(cur - prev) / prev for prev, cur in zip(closes, closes[1:])]
        context['volatility'] = np.mean(relative_moves) if relative_moves else 0.0

        # Volume spike: newest candle against the average of the earlier ones.
        volumes = [candle['volume'] for candle in latest]
        baseline_volume = np.mean(volumes[:-1]) if len(volumes) > 1 else volumes[0]
        context['volume_spike'] = volumes[-1] > baseline_volume * 1.5

        # Simple trend strength over the last 10 closes.
        if len(closes) >= 10:
            slope = (closes[-1] - closes[-10]) / closes[-10]
            context['trend_strength'] = abs(slope)

        return context

    except Exception as e:
        logger.error(f"Error getting extrema market context: {e}")
        return {'volatility': 0.0, 'volume_spike': False, 'trend_strength': 0.0, 'rsi_level': 50.0}
|
||||
|
||||
def _create_extrema_perfect_move(self, extrema_data: Dict[str, Any]):
    """Turn a detected extremum into a ``PerfectMove`` for CNN training.

    The outcome is the relative change from the extremum price to the last
    price in ``context_after``; its sign is flipped for tops so that a move
    in the expected direction is positive. The absolute value is stored on
    the ``PerfectMove``, which is appended to ``self.perfect_moves``, and
    ``self.retrospective_learning_active`` is switched on.

    Nothing happens when ``context_after`` is empty. Errors are logged.

    Args:
        extrema_data: Extremum record; reads ``context_after``, ``price``,
            ``type``, ``symbol``, ``timestamp``, ``optimal_action`` and
            ``confidence_level``.
    """
    try:
        following_prices = extrema_data['context_after']
        if len(following_prices) == 0:
            return

        # Relative move from the extremum to the end of the look-ahead window.
        reference_price = extrema_data['price']
        relative_move = (following_prices[-1] - reference_price) / reference_price

        # For bottoms a rise is good; for tops a fall is good.
        outcome = relative_move if extrema_data['type'] == 'bottom' else -relative_move

        self.perfect_moves.append(
            PerfectMove(
                symbol=extrema_data['symbol'],
                timeframe='1m',
                timestamp=extrema_data['timestamp'],
                optimal_action=extrema_data['optimal_action'],
                actual_outcome=abs(outcome),
                market_state_before=None,
                market_state_after=None,
                confidence_should_have_been=extrema_data['confidence_level'],
            )
        )
        self.retrospective_learning_active = True

        logger.info(f"Created perfect move from {extrema_data['type']} extrema: "
                    f"{extrema_data['optimal_action']} {extrema_data['symbol']} "
                    f"(outcome: {outcome*100:+.2f}%)")

    except Exception as e:
        logger.error(f"Error creating extrema perfect move: {e}")
|
||||
|
||||
def get_context_features_for_model(self, symbol: str) -> Optional[np.ndarray]:
    """Return the 1m context feature matrix for ``symbol``.

    The cached matrix in ``self.context_features_1m`` is returned when
    present. Otherwise, if at least 50 candles are buffered, the matrix is
    built with ``_create_context_features``, cached and returned.

    Args:
        symbol: Key into the per-symbol context containers.

    Returns:
        The feature matrix, or ``None`` when it is neither cached nor
        buildable, or when an error occurs (the error is logged).
    """
    try:
        cached = self.context_features_1m.get(symbol)
        if cached is not None:
            return cached

        # No cache entry: build from the buffered candles if there are enough.
        if len(self.context_data_1m[symbol]) < 50:
            return None

        candle_frame = pd.DataFrame(list(self.context_data_1m[symbol]))
        built = self._create_context_features(candle_frame)
        self.context_features_1m[symbol] = built
        return built

    except Exception as e:
        logger.error(f"Error getting context features for {symbol}: {e}")
        return None
|
||||
|
||||
def get_extrema_training_data(self, count: int = 50) -> List[Dict[str, Any]]:
    """Return the most recent entries of the extrema training queue.

    Args:
        count: Maximum number of entries to return, newest last. Zero or a
            negative value yields an empty list. Previously ``count=0``
            returned the whole queue because ``seq[-0:]`` is ``seq[0:]``,
            and a negative count dropped items from the front.

    Returns:
        Up to ``count`` entries from the end of
        ``self.extrema_training_queue``; an empty list on error (logged).
    """
    try:
        if count <= 0:
            return []
        return list(self.extrema_training_queue)[-count:]
    except Exception as e:
        logger.error(f"Error getting extrema training data: {e}")
        return []
|
||||
|
||||
def get_extrema_stats(self) -> Dict[str, Any]:
    """Collect counters describing extrema detection and context buffers.

    Returns:
        Dict with ``total_extrema_detected``, ``extrema_by_symbol``,
        ``training_queue_size``, ``last_extrema_check`` (ISO strings) and
        ``context_data_status`` (per symbol: candle count, whether features
        are cached, last update as ISO string). When the training queue is
        non-empty, ``recent_extrema`` adds bottom/top counts and the mean
        confidence of its last 20 entries. An empty dict is returned on
        error (the error is logged).
    """
    try:
        per_symbol_counts = {sym: len(found) for sym, found in self.local_extrema.items()}
        queue_size = len(self.extrema_training_queue)
        check_times = {sym: checked_at.isoformat()
                       for sym, checked_at in self.last_extrema_check.items()}

        context_status = {}
        for sym in self.symbols:
            context_status[sym] = {
                'candles_loaded': len(self.context_data_1m[sym]),
                'features_available': self.context_features_1m[sym] is not None,
                'last_update': self.last_context_update[sym].isoformat()
            }

        stats = {
            'total_extrema_detected': sum(per_symbol_counts.values()),
            'extrema_by_symbol': per_symbol_counts,
            'training_queue_size': queue_size,
            'last_extrema_check': check_times,
            'context_data_status': context_status
        }

        # Breakdown of the 20 newest queued extrema.
        newest = list(self.extrema_training_queue)[-20:]
        if newest:
            kinds = [entry['type'] for entry in newest]
            stats['recent_extrema'] = {
                'bottoms': kinds.count('bottom'),
                'tops': kinds.count('top'),
                'avg_confidence': np.mean([entry['confidence_level'] for entry in newest])
            }

        return stats

    except Exception as e:
        logger.error(f"Error getting extrema stats: {e}")
        return {}
|
||||
|
||||
def process_realtime_features(self, feature_dict: Dict[str, Any]):
|
||||
"""Process real-time tick features from the tick processor"""
|
||||
try:
|
||||
|
Reference in New Issue
Block a user