detecting local extremes and training on them
@@ -22,10 +22,13 @@ from collections import deque
import torch

from .config import get_config
from .data_provider import DataProvider, RawTick, OHLCVBar
from .data_provider import DataProvider, RawTick, OHLCVBar, MarketTick
from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
from .realtime_tick_processor import RealTimeTickProcessor, ProcessedTickFeatures, integrate_with_orchestrator
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface
from .extrema_trainer import ExtremaTrainer
from .trading_action import TradingAction
from .negative_case_trainer import NegativeCaseTrainer

logger = logging.getLogger(__name__)
@@ -87,6 +90,28 @@ class PerfectMove:
    market_state_after: MarketState
    confidence_should_have_been: float

@dataclass
class TradeInfo:
    """Information about an active trade"""
    symbol: str
    side: str  # 'LONG' or 'SHORT'
    entry_price: float
    entry_time: datetime
    size: float
    confidence: float
    market_state: Dict[str, Any]

@dataclass
class LearningCase:
    """A learning case for DQN sensitivity training"""
    state_vector: np.ndarray
    action: int  # sensitivity level chosen
    reward: float
    next_state_vector: np.ndarray
    done: bool
    trade_info: TradeInfo
    outcome: float  # P&L percentage

class EnhancedTradingOrchestrator:
    """
    Enhanced orchestrator with sophisticated multi-modal decision making
@@ -105,6 +130,16 @@ class EnhancedTradingOrchestrator:
        # Initialize real-time tick processor for ultra-low latency processing
        self.tick_processor = RealTimeTickProcessor(symbols=self.config.symbols)

        # Initialize extrema trainer for local bottom/top detection and 200-candle context
        self.extrema_trainer = ExtremaTrainer(
            data_provider=self.data_provider,
            symbols=self.config.symbols,
            window_size=10  # 10-candle window for extrema detection
        )

        # Initialize negative case trainer for intensive training on losing trades
        self.negative_case_trainer = NegativeCaseTrainer()

        # Real-time tick features storage
        self.realtime_tick_features = {symbol: deque(maxlen=100) for symbol in self.config.symbols}
@@ -151,6 +186,18 @@ class EnhancedTradingOrchestrator:
        self.retrospective_learning_active = False
        self.last_retrospective_analysis = datetime.now()

        # Local extrema tracking for training on bottoms and tops
        self.local_extrema = {symbol: deque(maxlen=1000) for symbol in self.symbols}
        self.extrema_detection_window = 10  # Look for extrema in 10-candle windows
        self.extrema_training_queue = deque(maxlen=500)  # Queue for extrema-based training
        self.last_extrema_check = {symbol: datetime.now() for symbol in self.symbols}

        # 200-candle context data for models
        self.context_data_1m = {symbol: deque(maxlen=200) for symbol in self.symbols}
        self.context_features_1m = {symbol: None for symbol in self.symbols}
        self.context_update_frequency = 60  # Update context every 60 seconds
        self.last_context_update = {symbol: datetime.now() for symbol in self.symbols}

        # RL feedback system
        self.rl_evaluation_queue = deque(maxlen=1000)
        self.environment_adaptation_rate = 0.01
@@ -182,6 +229,9 @@ class EnhancedTradingOrchestrator:
        # Current open positions tracking for closing logic
        self.open_positions = {}  # symbol -> {'side': str, 'entry_price': float, 'timestamp': datetime}

        # Initialize 200-candle context data
        self._initialize_context_data()

        logger.info("Enhanced TradingOrchestrator initialized with Universal Data Format")
        logger.info(f"Symbols: {self.symbols}")
        logger.info(f"Timeframes: {self.timeframes}")
@@ -192,6 +242,8 @@ class EnhancedTradingOrchestrator:
        logger.info("Raw tick and OHLCV bar processing enabled for pattern detection")
        logger.info("Enhanced retrospective learning enabled for perfect opportunity detection")
        logger.info("DQN RL-based sensitivity learning enabled for adaptive thresholds")
        logger.info("Local extrema detection enabled for bottom/top training")
        logger.info("200-candle 1m context data initialized for enhanced model performance")

    def _initialize_timeframe_weights(self) -> Dict[str, float]:
        """Initialize weights for different timeframes"""
@@ -713,7 +765,7 @@ class EnhancedTradingOrchestrator:
        try:
            if symbol not in self.active_trades:
                return

            trade_info = self.active_trades[symbol]

            # Calculate trade outcome
@@ -759,7 +811,7 @@ class EnhancedTradingOrchestrator:
            del self.active_trades[symbol]

            logger.info(f"Closed trade for sensitivity learning: {symbol} {side} P&L: {pnl_pct*100:+.2f}% Duration: {duration:.0f}s")

        except Exception as e:
            logger.error(f"Error tracking trade closing for sensitivity learning: {e}")
@@ -818,7 +870,7 @@ class EnhancedTradingOrchestrator:
                'price_change_4': price_changes[-4] if len(price_changes) > 3 else 0.0,
                'price_change_5': price_changes[-5] if len(price_changes) > 4 else 0.0
            }

        except Exception as e:
            logger.error(f"Error getting market state for sensitivity learning: {e}")
            return self._get_default_market_state()
@@ -969,7 +1021,7 @@ class EnhancedTradingOrchestrator:
            final_reward = np.clip(final_reward, -2.0, 2.0)

            return float(final_reward)

        except Exception as e:
            logger.error(f"Error calculating sensitivity reward: {e}")
            return 0.0
@@ -1045,7 +1097,7 @@ class EnhancedTradingOrchestrator:

            # Update current sensitivity level based on recent performance
            self._update_current_sensitivity_level()

        except Exception as e:
            logger.error(f"Error training sensitivity DQN: {e}")
@@ -1131,6 +1183,374 @@ class EnhancedTradingOrchestrator:
        """Get current closing threshold"""
        return self.confidence_threshold_close
    def _initialize_context_data(self):
        """Initialize 200-candle 1m context data for all symbols"""
        try:
            logger.info("Initializing 200-candle 1m context data for enhanced model performance")

            for symbol in self.symbols:
                try:
                    # Load 200 candles of 1m data
                    context_data = self.data_provider.get_historical_data(symbol, '1m', limit=200)

                    if context_data is not None and len(context_data) > 0:
                        # Store raw data
                        for _, row in context_data.iterrows():
                            candle_data = {
                                'timestamp': row['timestamp'],
                                'open': row['open'],
                                'high': row['high'],
                                'low': row['low'],
                                'close': row['close'],
                                'volume': row['volume']
                            }
                            self.context_data_1m[symbol].append(candle_data)

                        # Create feature matrix for models
                        self.context_features_1m[symbol] = self._create_context_features(context_data)

                        logger.info(f"Loaded {len(context_data)} 1m candles for {symbol} context")
                    else:
                        logger.warning(f"No 1m context data available for {symbol}")

                except Exception as e:
                    logger.error(f"Error loading context data for {symbol}: {e}")

        except Exception as e:
            logger.error(f"Error initializing context data: {e}")
    def _create_context_features(self, df: pd.DataFrame) -> Optional[np.ndarray]:
        """Create feature matrix from 1m context data for model consumption"""
        try:
            if df is None or len(df) < 50:
                return None

            # Select key features for context
            feature_columns = ['open', 'high', 'low', 'close', 'volume']

            # Add technical indicators if available
            if 'rsi_14' in df.columns:
                feature_columns.extend(['rsi_14', 'sma_20', 'ema_12'])
            if 'macd' in df.columns:
                feature_columns.extend(['macd', 'macd_signal'])
            if 'bb_upper' in df.columns:
                feature_columns.extend(['bb_upper', 'bb_lower', 'bb_percent'])

            # Extract available features
            available_features = [col for col in feature_columns if col in df.columns]
            feature_data = df[available_features].copy()

            # Normalize features
            normalized_features = self._normalize_context_features(feature_data)

            return normalized_features.values if normalized_features is not None else None

        except Exception as e:
            logger.error(f"Error creating context features: {e}")
            return None
    def _normalize_context_features(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Normalize context features for model consumption"""
        try:
            df_norm = df.copy()

            # Price normalization (relative to latest close)
            if 'close' in df_norm.columns:
                latest_close = df_norm['close'].iloc[-1]
                for col in ['open', 'high', 'low', 'close', 'sma_20', 'ema_12', 'bb_upper', 'bb_lower']:
                    if col in df_norm.columns and latest_close > 0:
                        df_norm[col] = df_norm[col] / latest_close

            # Volume normalization
            if 'volume' in df_norm.columns:
                volume_mean = df_norm['volume'].mean()
                if volume_mean > 0:
                    df_norm['volume'] = df_norm['volume'] / volume_mean

            # RSI normalization (0-100 to 0-1)
            if 'rsi_14' in df_norm.columns:
                df_norm['rsi_14'] = df_norm['rsi_14'] / 100.0

            # MACD normalization
            if 'macd' in df_norm.columns and 'close' in df.columns:
                latest_close = df['close'].iloc[-1]
                df_norm['macd'] = df_norm['macd'] / latest_close
                if 'macd_signal' in df_norm.columns:
                    df_norm['macd_signal'] = df_norm['macd_signal'] / latest_close

            # BB percent is already normalized
            if 'bb_percent' in df_norm.columns:
                df_norm['bb_percent'] = np.clip(df_norm['bb_percent'], 0, 1)

            # Fill NaN values
            df_norm = df_norm.fillna(0)

            return df_norm

        except Exception as e:
            logger.error(f"Error normalizing context features: {e}")
            return df
    def update_context_data(self, symbol: str = None):
        """Update 200-candle 1m context data for specified symbol or all symbols"""
        try:
            symbols_to_update = [symbol] if symbol else self.symbols

            for sym in symbols_to_update:
                # Check if update is needed
                time_since_update = (datetime.now() - self.last_context_update[sym]).total_seconds()

                if time_since_update >= self.context_update_frequency:
                    # Get latest 1m data
                    latest_data = self.data_provider.get_historical_data(sym, '1m', limit=10, refresh=True)

                    if latest_data is not None and len(latest_data) > 0:
                        # Add new candles to context
                        for _, row in latest_data.iterrows():
                            candle_data = {
                                'timestamp': row['timestamp'],
                                'open': row['open'],
                                'high': row['high'],
                                'low': row['low'],
                                'close': row['close'],
                                'volume': row['volume']
                            }

                            # Check if this candle is newer than our latest
                            if (not self.context_data_1m[sym] or
                                    candle_data['timestamp'] > self.context_data_1m[sym][-1]['timestamp']):
                                self.context_data_1m[sym].append(candle_data)

                        # Update feature matrix
                        if len(self.context_data_1m[sym]) >= 50:
                            context_df = pd.DataFrame(list(self.context_data_1m[sym]))
                            self.context_features_1m[sym] = self._create_context_features(context_df)

                        self.last_context_update[sym] = datetime.now()

                        # Check for local extrema in updated data
                        self._detect_local_extrema(sym)

        except Exception as e:
            logger.error(f"Error updating context data: {e}")
    def _detect_local_extrema(self, symbol: str):
        """Detect local bottoms and tops for training opportunities"""
        try:
            if len(self.context_data_1m[symbol]) < self.extrema_detection_window * 2:
                return

            # Get recent price data
            recent_candles = list(self.context_data_1m[symbol])[-self.extrema_detection_window * 2:]
            prices = [candle['close'] for candle in recent_candles]
            timestamps = [candle['timestamp'] for candle in recent_candles]

            # Detect local minima (bottoms) and maxima (tops)
            window = self.extrema_detection_window

            for i in range(window, len(prices) - window):
                current_price = prices[i]
                current_time = timestamps[i]

                # Check for local bottom
                is_bottom = all(current_price <= prices[j] for j in range(i - window, i + window + 1) if j != i)

                # Check for local top
                is_top = all(current_price >= prices[j] for j in range(i - window, i + window + 1) if j != i)

                if is_bottom or is_top:
                    extrema_type = 'bottom' if is_bottom else 'top'

                    # Create training opportunity
                    extrema_data = {
                        'symbol': symbol,
                        'timestamp': current_time,
                        'price': current_price,
                        'type': extrema_type,
                        'context_before': prices[max(0, i - window):i],
                        'context_after': prices[i + 1:min(len(prices), i + window + 1)],
                        'optimal_action': 'BUY' if is_bottom else 'SELL',
                        'confidence_level': self._calculate_extrema_confidence(prices, i, window),
                        'market_context': self._get_extrema_market_context(symbol, current_time)
                    }

                    self.local_extrema[symbol].append(extrema_data)
                    self.extrema_training_queue.append(extrema_data)

                    logger.info(f"Local {extrema_type} detected for {symbol} at ${current_price:.2f} "
                                f"(confidence: {extrema_data['confidence_level']:.3f})")

                    # Create perfect move for CNN training
                    self._create_extrema_perfect_move(extrema_data)

            self.last_extrema_check[symbol] = datetime.now()

        except Exception as e:
            logger.error(f"Error detecting local extrema for {symbol}: {e}")
    def _calculate_extrema_confidence(self, prices: List[float], extrema_index: int, window: int) -> float:
        """Calculate confidence level for detected extrema"""
        try:
            extrema_price = prices[extrema_index]

            # Calculate price deviation from extrema
            surrounding_prices = prices[max(0, extrema_index - window):extrema_index + window + 1]
            price_range = max(surrounding_prices) - min(surrounding_prices)

            if price_range == 0:
                return 0.5

            # Calculate how extreme the point is
            if extrema_price == min(surrounding_prices):  # Bottom
                deviation = (max(surrounding_prices) - extrema_price) / price_range
            else:  # Top
                deviation = (extrema_price - min(surrounding_prices)) / price_range

            # Confidence based on how clear the extrema is
            confidence = min(0.95, max(0.3, deviation))

            return confidence

        except Exception as e:
            logger.error(f"Error calculating extrema confidence: {e}")
            return 0.5
    def _get_extrema_market_context(self, symbol: str, timestamp: datetime) -> Dict[str, Any]:
        """Get market context at the time of extrema detection"""
        try:
            # Get recent market data around the extrema
            context = {
                'volatility': 0.0,
                'volume_spike': False,
                'trend_strength': 0.0,
                'rsi_level': 50.0
            }

            if len(self.context_data_1m[symbol]) >= 20:
                recent_candles = list(self.context_data_1m[symbol])[-20:]

                # Calculate volatility
                prices = [c['close'] for c in recent_candles]
                price_changes = [abs(prices[i] - prices[i-1]) / prices[i-1] for i in range(1, len(prices))]
                context['volatility'] = np.mean(price_changes) if price_changes else 0.0

                # Check for volume spike
                volumes = [c['volume'] for c in recent_candles]
                avg_volume = np.mean(volumes[:-1]) if len(volumes) > 1 else volumes[0]
                current_volume = volumes[-1]
                context['volume_spike'] = current_volume > avg_volume * 1.5

                # Simple trend strength
                if len(prices) >= 10:
                    trend_slope = (prices[-1] - prices[-10]) / prices[-10]
                    context['trend_strength'] = abs(trend_slope)

            return context

        except Exception as e:
            logger.error(f"Error getting extrema market context: {e}")
            return {'volatility': 0.0, 'volume_spike': False, 'trend_strength': 0.0, 'rsi_level': 50.0}
    def _create_extrema_perfect_move(self, extrema_data: Dict[str, Any]):
        """Create a perfect move from detected extrema for CNN training"""
        try:
            # Calculate outcome based on price movement after extrema
            if len(extrema_data['context_after']) > 0:
                price_after = extrema_data['context_after'][-1]
                price_change = (price_after - extrema_data['price']) / extrema_data['price']

                # For bottoms, positive price change is good; for tops, negative is good
                if extrema_data['type'] == 'bottom':
                    outcome = price_change
                else:  # top
                    outcome = -price_change

                perfect_move = PerfectMove(
                    symbol=extrema_data['symbol'],
                    timeframe='1m',
                    timestamp=extrema_data['timestamp'],
                    optimal_action=extrema_data['optimal_action'],
                    actual_outcome=abs(outcome),
                    market_state_before=None,
                    market_state_after=None,
                    confidence_should_have_been=extrema_data['confidence_level']
                )

                self.perfect_moves.append(perfect_move)
                self.retrospective_learning_active = True

                logger.info(f"Created perfect move from {extrema_data['type']} extrema: "
                            f"{extrema_data['optimal_action']} {extrema_data['symbol']} "
                            f"(outcome: {outcome*100:+.2f}%)")

        except Exception as e:
            logger.error(f"Error creating extrema perfect move: {e}")
    def get_context_features_for_model(self, symbol: str) -> Optional[np.ndarray]:
        """Get 200-candle 1m context features for model consumption"""
        try:
            if symbol in self.context_features_1m and self.context_features_1m[symbol] is not None:
                return self.context_features_1m[symbol]

            # If no cached features, create them from current data
            if len(self.context_data_1m[symbol]) >= 50:
                context_df = pd.DataFrame(list(self.context_data_1m[symbol]))
                features = self._create_context_features(context_df)
                self.context_features_1m[symbol] = features
                return features

            return None

        except Exception as e:
            logger.error(f"Error getting context features for {symbol}: {e}")
            return None
    def get_extrema_training_data(self, count: int = 50) -> List[Dict[str, Any]]:
        """Get recent extrema training data for model training"""
        try:
            return list(self.extrema_training_queue)[-count:] if self.extrema_training_queue else []
        except Exception as e:
            logger.error(f"Error getting extrema training data: {e}")
            return []
    def get_extrema_stats(self) -> Dict[str, Any]:
        """Get statistics about extrema detection and training"""
        try:
            stats = {
                'total_extrema_detected': sum(len(extrema) for extrema in self.local_extrema.values()),
                'extrema_by_symbol': {symbol: len(extrema) for symbol, extrema in self.local_extrema.items()},
                'training_queue_size': len(self.extrema_training_queue),
                'last_extrema_check': {symbol: check_time.isoformat()
                                       for symbol, check_time in self.last_extrema_check.items()},
                'context_data_status': {
                    symbol: {
                        'candles_loaded': len(self.context_data_1m[symbol]),
                        'features_available': self.context_features_1m[symbol] is not None,
                        'last_update': self.last_context_update[symbol].isoformat()
                    }
                    for symbol in self.symbols
                }
            }

            # Recent extrema breakdown
            recent_extrema = list(self.extrema_training_queue)[-20:]
            if recent_extrema:
                bottoms = len([e for e in recent_extrema if e['type'] == 'bottom'])
                tops = len([e for e in recent_extrema if e['type'] == 'top'])
                avg_confidence = np.mean([e['confidence_level'] for e in recent_extrema])

                stats['recent_extrema'] = {
                    'bottoms': bottoms,
                    'tops': tops,
                    'avg_confidence': avg_confidence
                }

            return stats

        except Exception as e:
            logger.error(f"Error getting extrema stats: {e}")
            return {}
    def process_realtime_features(self, feature_dict: Dict[str, Any]):
        """Process real-time tick features from the tick processor"""
        try:
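The orchestrator hunk above is truncated by the diff; before the new files, here is a minimal usage sketch (illustrative only, not part of this commit) of how the new extrema/context hooks on an already-constructed EnhancedTradingOrchestrator could be driven. The `orchestrator` name and the 'ETH/USDT' symbol are assumptions.

# Illustrative sketch, not part of the commit.
orchestrator.update_context_data()  # refresh 200-candle 1m context; runs local extrema detection as a side effect
features = orchestrator.get_context_features_for_model('ETH/USDT')  # normalized feature matrix, or None if <50 candles
recent_cases = orchestrator.get_extrema_training_data(count=20)      # latest extrema-based training dicts
print(orchestrator.get_extrema_stats().get('total_extrema_detected', 0))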

core/extrema_trainer.py (new file, 584 lines)
@@ -0,0 +1,584 @@
"""
|
||||
Extrema Training Module - Reusable Local Bottom/Top Detection and Training
|
||||
|
||||
This module provides reusable functionality for:
|
||||
1. Detecting local extrema (bottoms and tops) in price data
|
||||
2. Creating training opportunities from extrema
|
||||
3. Loading and managing 200-candle 1m context data
|
||||
4. Generating features for model consumption
|
||||
5. Training on not-so-perfect opportunities
|
||||
|
||||
Can be used across different dashboards and trading systems.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
from collections import deque
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class ExtremaPoint:
    """Represents a detected local extrema (bottom or top)"""
    symbol: str
    timestamp: datetime
    price: float
    extrema_type: str  # 'bottom' or 'top'
    confidence: float
    context_before: List[float]
    context_after: List[float]
    optimal_action: str  # 'BUY' or 'SELL'
    market_context: Dict[str, Any]
    outcome: Optional[float] = None  # Price change after extrema

@dataclass
class ContextData:
    """200-candle 1m context data for enhanced model performance"""
    symbol: str
    candles: deque
    features: Optional[np.ndarray]
    last_update: datetime

class ExtremaTrainer:
    """Reusable extrema detection and training functionality"""

    def __init__(self, data_provider, symbols: List[str], window_size: int = 10):
        """
        Initialize the extrema trainer

        Args:
            data_provider: Data provider instance
            symbols: List of symbols to track
            window_size: Window size for extrema detection (default 10)
        """
        self.data_provider = data_provider
        self.symbols = symbols
        self.window_size = window_size

        # Extrema tracking
        self.detected_extrema = {symbol: deque(maxlen=1000) for symbol in symbols}
        self.extrema_training_queue = deque(maxlen=500)
        self.last_extrema_check = {symbol: datetime.now() for symbol in symbols}

        # 200-candle context data
        self.context_data = {symbol: ContextData(
            symbol=symbol,
            candles=deque(maxlen=200),
            features=None,
            last_update=datetime.now()
        ) for symbol in symbols}

        self.context_update_frequency = 60  # Update every 60 seconds

        # Training parameters
        self.min_confidence_threshold = 0.3  # Train on opportunities with at least 30% confidence
        self.max_confidence_threshold = 0.95  # Cap confidence at 95%

        logger.info(f"ExtremaTrainer initialized for symbols: {symbols}")
        logger.info(f"Window size: {window_size}, Context update frequency: {self.context_update_frequency}s")
    def initialize_context_data(self) -> Dict[str, bool]:
        """Initialize 200-candle 1m context data for all symbols"""
        results = {}

        try:
            logger.info("Initializing 200-candle 1m context data for enhanced model performance")

            for symbol in self.symbols:
                try:
                    # Load 200 candles of 1m data
                    context_data = self.data_provider.get_historical_data(symbol, '1m', limit=200)

                    if context_data is not None and len(context_data) > 0:
                        # Store raw data
                        for _, row in context_data.iterrows():
                            candle_data = {
                                'timestamp': row['timestamp'],
                                'open': row['open'],
                                'high': row['high'],
                                'low': row['low'],
                                'close': row['close'],
                                'volume': row['volume']
                            }
                            self.context_data[symbol].candles.append(candle_data)

                        # Create feature matrix for models
                        self.context_data[symbol].features = self._create_context_features(context_data)
                        self.context_data[symbol].last_update = datetime.now()

                        results[symbol] = True
                        logger.info(f"✅ Loaded {len(context_data)} 1m candles for {symbol} context")
                    else:
                        results[symbol] = False
                        logger.warning(f"❌ No 1m context data available for {symbol}")

                except Exception as e:
                    logger.error(f"❌ Error loading context data for {symbol}: {e}")
                    results[symbol] = False

        except Exception as e:
            logger.error(f"Error initializing context data: {e}")

        successful = sum(1 for success in results.values() if success)
        logger.info(f"Context data initialization: {successful}/{len(self.symbols)} symbols loaded")

        return results
    def update_context_data(self, symbol: str = None) -> Dict[str, bool]:
        """Update 200-candle 1m context data for specified symbol or all symbols"""
        results = {}

        try:
            symbols_to_update = [symbol] if symbol else self.symbols

            for sym in symbols_to_update:
                try:
                    # Check if update is needed
                    time_since_update = (datetime.now() - self.context_data[sym].last_update).total_seconds()

                    if time_since_update >= self.context_update_frequency:
                        # Get latest 1m data
                        latest_data = self.data_provider.get_historical_data(sym, '1m', limit=10, refresh=True)

                        if latest_data is not None and len(latest_data) > 0:
                            # Add new candles to context
                            for _, row in latest_data.iterrows():
                                candle_data = {
                                    'timestamp': row['timestamp'],
                                    'open': row['open'],
                                    'high': row['high'],
                                    'low': row['low'],
                                    'close': row['close'],
                                    'volume': row['volume']
                                }

                                # Check if this candle is newer than our latest
                                if (not self.context_data[sym].candles or
                                        candle_data['timestamp'] > self.context_data[sym].candles[-1]['timestamp']):
                                    self.context_data[sym].candles.append(candle_data)

                            # Update feature matrix
                            if len(self.context_data[sym].candles) >= 50:
                                context_df = pd.DataFrame(list(self.context_data[sym].candles))
                                self.context_data[sym].features = self._create_context_features(context_df)

                            self.context_data[sym].last_update = datetime.now()

                            # Check for local extrema in updated data
                            self.detect_local_extrema(sym)

                            results[sym] = True
                        else:
                            results[sym] = False
                    else:
                        results[sym] = True  # No update needed

                except Exception as e:
                    logger.error(f"Error updating context data for {sym}: {e}")
                    results[sym] = False

        except Exception as e:
            logger.error(f"Error updating context data: {e}")

        return results
    def detect_local_extrema(self, symbol: str) -> List[ExtremaPoint]:
        """Detect local bottoms and tops for training opportunities"""
        detected = []

        try:
            if len(self.context_data[symbol].candles) < self.window_size * 3:
                return detected

            # Get all available price data for better extrema detection
            all_candles = list(self.context_data[symbol].candles)
            prices = [candle['close'] for candle in all_candles]
            timestamps = [candle['timestamp'] for candle in all_candles]

            # Use a more sophisticated extrema detection algorithm
            window = self.window_size

            # Look for extrema in the middle portion of the data (not at edges)
            start_idx = window
            end_idx = len(prices) - window

            for i in range(start_idx, end_idx):
                current_price = prices[i]
                current_time = timestamps[i]

                # Get surrounding prices for comparison
                left_prices = prices[i - window:i]
                right_prices = prices[i + 1:i + window + 1]

                # Check for local bottom (current price is lower than surrounding prices)
                is_bottom = (current_price <= min(left_prices) and
                             current_price <= min(right_prices) and
                             current_price < max(left_prices) * 0.998)  # At least 0.2% lower

                # Check for local top (current price is higher than surrounding prices)
                is_top = (current_price >= max(left_prices) and
                          current_price >= max(right_prices) and
                          current_price > min(left_prices) * 1.002)  # At least 0.2% higher

                if is_bottom or is_top:
                    extrema_type = 'bottom' if is_bottom else 'top'

                    # Calculate confidence based on price deviation and volume
                    confidence = self._calculate_extrema_confidence(prices, i, window)

                    # Only process if confidence meets minimum threshold
                    if confidence >= self.min_confidence_threshold:
                        # Check if this extrema is too close to a previously detected one
                        if not self._is_too_close_to_existing_extrema(symbol, current_time, current_price):
                            # Create extrema point
                            extrema_point = ExtremaPoint(
                                symbol=symbol,
                                timestamp=current_time,
                                price=current_price,
                                extrema_type=extrema_type,
                                confidence=min(confidence, self.max_confidence_threshold),
                                context_before=left_prices,
                                context_after=right_prices,
                                optimal_action='BUY' if is_bottom else 'SELL',
                                market_context=self._get_extrema_market_context(symbol, current_time)
                            )

                            # Calculate outcome if we have future data
                            if len(right_prices) > 0:
                                # Look ahead further for better outcome calculation
                                future_idx = min(i + window * 2, len(prices) - 1)
                                future_price = prices[future_idx]
                                price_change = (future_price - current_price) / current_price

                                # For bottoms, positive change is good; for tops, negative is good
                                if extrema_type == 'bottom':
                                    extrema_point.outcome = price_change
                                else:  # top
                                    extrema_point.outcome = -price_change

                            self.detected_extrema[symbol].append(extrema_point)
                            self.extrema_training_queue.append(extrema_point)
                            detected.append(extrema_point)

                            logger.info(f"Local {extrema_type} detected for {symbol} at ${current_price:.2f} "
                                        f"(confidence: {confidence:.3f}, outcome: {extrema_point.outcome:.4f})")

            self.last_extrema_check[symbol] = datetime.now()

        except Exception as e:
            logger.error(f"Error detecting local extrema for {symbol}: {e}")

        return detected
    def _is_too_close_to_existing_extrema(self, symbol: str, timestamp: datetime, price: float) -> bool:
        """Check if this extrema is too close to an existing one"""
        try:
            if symbol not in self.detected_extrema:
                return False

            recent_extrema = list(self.detected_extrema[symbol])[-10:]  # Check last 10 extrema

            for existing_extrema in recent_extrema:
                # Check time proximity (within 30 minutes)
                time_diff = abs((timestamp - existing_extrema.timestamp).total_seconds())
                if time_diff < 1800:  # 30 minutes
                    # Check price proximity (within 1%)
                    price_diff = abs(price - existing_extrema.price) / existing_extrema.price
                    if price_diff < 0.01:  # 1%
                        return True

            return False

        except Exception as e:
            logger.error(f"Error checking extrema proximity: {e}")
            return False
    def _calculate_extrema_confidence(self, prices: List[float], extrema_index: int, window: int) -> float:
        """Calculate confidence level for detected extrema"""
        try:
            extrema_price = prices[extrema_index]

            # Calculate price deviation from extrema
            surrounding_prices = prices[max(0, extrema_index - window):extrema_index + window + 1]
            price_range = max(surrounding_prices) - min(surrounding_prices)

            if price_range == 0:
                return 0.5

            # Calculate how extreme the point is
            if extrema_price == min(surrounding_prices):  # Bottom
                deviation = (max(surrounding_prices) - extrema_price) / price_range
            else:  # Top
                deviation = (extrema_price - min(surrounding_prices)) / price_range

            # Additional factors for confidence
            # 1. Volume confirmation
            volume_factor = 1.0
            if len(self.context_data) > 0:
                # Check if volume was higher during extrema
                try:
                    recent_candles = list(self.context_data[list(self.context_data.keys())[0]].candles)
                    if len(recent_candles) > extrema_index:
                        extrema_volume = recent_candles[extrema_index].get('volume', 0)
                        avg_volume = np.mean([c.get('volume', 0) for c in recent_candles[-20:]])
                        if avg_volume > 0:
                            volume_factor = min(1.2, extrema_volume / avg_volume)
                except:
                    pass

            # 2. Price momentum before extrema
            momentum_factor = 1.0
            if extrema_index >= 3:
                price_momentum = abs(prices[extrema_index] - prices[extrema_index - 3]) / prices[extrema_index - 3]
                momentum_factor = min(1.1, 1.0 + price_momentum * 10)

            # Combine factors
            confidence = deviation * volume_factor * momentum_factor

            # Ensure confidence is within bounds
            confidence = min(self.max_confidence_threshold, max(self.min_confidence_threshold, confidence))

            return confidence

        except Exception as e:
            logger.error(f"Error calculating extrema confidence: {e}")
            return 0.5
    def _get_extrema_market_context(self, symbol: str, timestamp: datetime) -> Dict[str, Any]:
        """Get market context at the time of extrema detection"""
        try:
            context = {
                'volatility': 0.0,
                'volume_spike': False,
                'trend_strength': 0.0,
                'rsi_level': 50.0,
                'price_momentum': 0.0
            }

            if len(self.context_data[symbol].candles) >= 20:
                recent_candles = list(self.context_data[symbol].candles)[-20:]

                # Calculate volatility
                prices = [c['close'] for c in recent_candles]
                price_changes = [abs(prices[i] - prices[i-1]) / prices[i-1] for i in range(1, len(prices))]
                context['volatility'] = np.mean(price_changes) if price_changes else 0.0

                # Check for volume spike
                volumes = [c['volume'] for c in recent_candles]
                avg_volume = np.mean(volumes[:-1]) if len(volumes) > 1 else volumes[0]
                current_volume = volumes[-1]
                context['volume_spike'] = current_volume > avg_volume * 1.5

                # Simple trend strength
                if len(prices) >= 10:
                    trend_slope = (prices[-1] - prices[-10]) / prices[-10]
                    context['trend_strength'] = abs(trend_slope)

                # Price momentum
                if len(prices) >= 5:
                    momentum = (prices[-1] - prices[-5]) / prices[-5]
                    context['price_momentum'] = momentum

            return context

        except Exception as e:
            logger.error(f"Error getting extrema market context: {e}")
            return {'volatility': 0.0, 'volume_spike': False, 'trend_strength': 0.0, 'rsi_level': 50.0, 'price_momentum': 0.0}
    def _create_context_features(self, df: pd.DataFrame) -> Optional[np.ndarray]:
        """Create feature matrix from 1m context data for model consumption"""
        try:
            if df is None or len(df) < 50:
                return None

            # Select key features for context
            feature_columns = ['open', 'high', 'low', 'close', 'volume']

            # Add technical indicators if available
            if 'rsi_14' in df.columns:
                feature_columns.extend(['rsi_14', 'sma_20', 'ema_12'])
            if 'macd' in df.columns:
                feature_columns.extend(['macd', 'macd_signal'])
            if 'bb_upper' in df.columns:
                feature_columns.extend(['bb_upper', 'bb_lower', 'bb_percent'])

            # Extract available features
            available_features = [col for col in feature_columns if col in df.columns]
            feature_data = df[available_features].copy()

            # Normalize features
            normalized_features = self._normalize_context_features(feature_data)

            return normalized_features.values if normalized_features is not None else None

        except Exception as e:
            logger.error(f"Error creating context features: {e}")
            return None
    def _normalize_context_features(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Normalize context features for model consumption"""
        try:
            df_norm = df.copy()

            # Price normalization (relative to latest close)
            if 'close' in df_norm.columns:
                latest_close = df_norm['close'].iloc[-1]
                for col in ['open', 'high', 'low', 'close', 'sma_20', 'ema_12', 'bb_upper', 'bb_lower']:
                    if col in df_norm.columns and latest_close > 0:
                        df_norm[col] = df_norm[col] / latest_close

            # Volume normalization
            if 'volume' in df_norm.columns:
                volume_mean = df_norm['volume'].mean()
                if volume_mean > 0:
                    df_norm['volume'] = df_norm['volume'] / volume_mean

            # RSI normalization (0-100 to 0-1)
            if 'rsi_14' in df_norm.columns:
                df_norm['rsi_14'] = df_norm['rsi_14'] / 100.0

            # MACD normalization
            if 'macd' in df_norm.columns and 'close' in df.columns:
                latest_close = df['close'].iloc[-1]
                df_norm['macd'] = df_norm['macd'] / latest_close
                if 'macd_signal' in df_norm.columns:
                    df_norm['macd_signal'] = df_norm['macd_signal'] / latest_close

            # BB percent is already normalized
            if 'bb_percent' in df_norm.columns:
                df_norm['bb_percent'] = np.clip(df_norm['bb_percent'], 0, 1)

            # Fill NaN values
            df_norm = df_norm.fillna(0)

            return df_norm

        except Exception as e:
            logger.error(f"Error normalizing context features: {e}")
            return df
    def get_context_features_for_model(self, symbol: str) -> Optional[np.ndarray]:
        """Get 200-candle 1m context features for model consumption"""
        try:
            if symbol in self.context_data and self.context_data[symbol].features is not None:
                return self.context_data[symbol].features

            # If no cached features, create them from current data
            if len(self.context_data[symbol].candles) >= 50:
                context_df = pd.DataFrame(list(self.context_data[symbol].candles))
                features = self._create_context_features(context_df)
                self.context_data[symbol].features = features
                return features

            return None

        except Exception as e:
            logger.error(f"Error getting context features for {symbol}: {e}")
            return None
    def get_extrema_training_data(self, count: int = 50, min_confidence: float = None) -> List[ExtremaPoint]:
        """Get recent extrema training data for model training"""
        try:
            extrema_list = list(self.extrema_training_queue)

            # Filter by confidence if specified
            if min_confidence is not None:
                extrema_list = [e for e in extrema_list if e.confidence >= min_confidence]

            return extrema_list[-count:] if extrema_list else []

        except Exception as e:
            logger.error(f"Error getting extrema training data: {e}")
            return []
    def get_perfect_moves_for_cnn(self, count: int = 100) -> List[Dict[str, Any]]:
        """Get perfect moves formatted for CNN training"""
        try:
            extrema_data = self.get_extrema_training_data(count)
            perfect_moves = []

            for extrema in extrema_data:
                if extrema.outcome is not None:
                    perfect_move = {
                        'symbol': extrema.symbol,
                        'timeframe': '1m',
                        'timestamp': extrema.timestamp,
                        'optimal_action': extrema.optimal_action,
                        'actual_outcome': abs(extrema.outcome),
                        'confidence_should_have_been': extrema.confidence,
                        'market_context': extrema.market_context,
                        'extrema_type': extrema.extrema_type
                    }
                    perfect_moves.append(perfect_move)

            return perfect_moves

        except Exception as e:
            logger.error(f"Error getting perfect moves for CNN: {e}")
            return []
    def get_extrema_stats(self) -> Dict[str, Any]:
        """Get statistics about extrema detection and training"""
        try:
            stats = {
                'total_extrema_detected': sum(len(extrema) for extrema in self.detected_extrema.values()),
                'extrema_by_symbol': {symbol: len(extrema) for symbol, extrema in self.detected_extrema.items()},
                'training_queue_size': len(self.extrema_training_queue),
                'last_extrema_check': {symbol: check_time.isoformat()
                                       for symbol, check_time in self.last_extrema_check.items()},
                'context_data_status': {
                    symbol: {
                        'candles_loaded': len(self.context_data[symbol].candles),
                        'features_available': self.context_data[symbol].features is not None,
                        'last_update': self.context_data[symbol].last_update.isoformat()
                    }
                    for symbol in self.symbols
                },
                'window_size': self.window_size,
                'confidence_thresholds': {
                    'min': self.min_confidence_threshold,
                    'max': self.max_confidence_threshold
                }
            }

            # Recent extrema breakdown
            recent_extrema = list(self.extrema_training_queue)[-20:]
            if recent_extrema:
                bottoms = len([e for e in recent_extrema if e.extrema_type == 'bottom'])
                tops = len([e for e in recent_extrema if e.extrema_type == 'top'])
                avg_confidence = np.mean([e.confidence for e in recent_extrema])
                avg_outcome = np.mean([e.outcome for e in recent_extrema if e.outcome is not None])

                stats['recent_extrema'] = {
                    'bottoms': bottoms,
                    'tops': tops,
                    'avg_confidence': avg_confidence,
                    'avg_outcome': avg_outcome if not np.isnan(avg_outcome) else 0.0
                }

            return stats

        except Exception as e:
            logger.error(f"Error getting extrema stats: {e}")
            return {}
    def run_batch_detection(self) -> Dict[str, List[ExtremaPoint]]:
        """Run extrema detection for all symbols"""
        results = {}

        try:
            for symbol in self.symbols:
                detected = self.detect_local_extrema(symbol)
                results[symbol] = detected

            total_detected = sum(len(extrema_list) for extrema_list in results.values())
            logger.info(f"Batch extrema detection completed: {total_detected} extrema detected across {len(self.symbols)} symbols")

        except Exception as e:
            logger.error(f"Error in batch extrema detection: {e}")

        return results
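A minimal standalone usage sketch of the ExtremaTrainer above (illustrative only, not part of this commit). It assumes `data_provider` is an initialized DataProvider from core.data_provider and that the package is importable as `core`; the symbol list is an assumption.

# Illustrative sketch, not part of the commit.
from core.extrema_trainer import ExtremaTrainer

trainer = ExtremaTrainer(data_provider=data_provider, symbols=['ETH/USDT', 'BTC/USDT'], window_size=10)
trainer.initialize_context_data()                         # load up to 200 x 1m candles per symbol
detected = trainer.run_batch_detection()                  # detect local bottoms/tops across all symbols
cnn_cases = trainer.get_perfect_moves_for_cnn(count=100)  # dicts formatted for CNN training
print(trainer.get_extrema_stats()['training_queue_size'])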

core/negative_case_trainer.py (new file, 472 lines)
@@ -0,0 +1,472 @@
"""
|
||||
Negative Case Trainer - Intensive Training on Losing Trades
|
||||
|
||||
This module focuses on learning from losses to prevent future mistakes.
|
||||
Stores negative cases in testcases/negative folder for reuse and retraining.
|
||||
Supports simultaneous inference and training.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from dataclasses import dataclass, asdict
|
||||
from collections import deque
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class NegativeCase:
    """Represents a losing trade case for intensive training"""
    case_id: str
    timestamp: datetime
    symbol: str
    action: str  # 'BUY' or 'SELL'
    entry_price: float
    exit_price: float
    loss_amount: float
    loss_percentage: float
    confidence_used: float
    market_state_before: Dict[str, Any]
    market_state_after: Dict[str, Any]
    tick_data: List[Dict[str, Any]]  # 15 minutes of tick data around the trade
    technical_indicators: Dict[str, float]
    what_should_have_been_done: str  # 'HOLD', 'OPPOSITE', 'WAIT'
    lesson_learned: str
    training_priority: int  # 1-5, 5 being highest priority
    retraining_count: int = 0
    last_retrained: Optional[datetime] = None

@dataclass
class TrainingSession:
    """Represents an intensive training session on negative cases"""
    session_id: str
    start_time: datetime
    cases_trained: List[str]  # case_ids
    epochs_completed: int
    loss_improvement: float
    accuracy_improvement: float
    inference_paused: bool = False
    training_active: bool = True

class NegativeCaseTrainer:
    """
    Intensive trainer focused on learning from losing trades

    Features:
    - Stores all losing trades as negative cases
    - Intensive retraining on losses
    - Simultaneous inference and training
    - Persistent storage in testcases/negative
    - Priority-based training (bigger losses = higher priority)
    """

    def __init__(self, storage_dir: str = "testcases/negative"):
        self.storage_dir = storage_dir
        self.stored_cases: List[NegativeCase] = []
        self.training_queue = deque(maxlen=1000)
        self.training_lock = threading.Lock()
        self.inference_lock = threading.Lock()

        # Training configuration
        self.max_concurrent_training = 3  # Max parallel training sessions
        self.intensive_training_epochs = 50  # Epochs per negative case
        self.priority_multiplier = 2.0  # Training time multiplier for high priority cases

        # Simultaneous inference/training control
        self.inference_active = True
        self.training_active = False
        self.current_training_sessions: List[TrainingSession] = []

        # Performance tracking
        self.total_cases_processed = 0
        self.total_training_time = 0.0
        self.accuracy_improvements = []

        # Initialize storage
        self._initialize_storage()
        self._load_existing_cases()

        # Start background training thread
        self.training_thread = threading.Thread(target=self._background_training_loop, daemon=True)
        self.training_thread.start()

        logger.info(f"NegativeCaseTrainer initialized with {len(self.stored_cases)} existing cases")
        logger.info(f"Storage directory: {self.storage_dir}")
        logger.info("Background training thread started")
    def _initialize_storage(self):
        """Initialize storage directories"""
        try:
            os.makedirs(self.storage_dir, exist_ok=True)
            os.makedirs(f"{self.storage_dir}/cases", exist_ok=True)
            os.makedirs(f"{self.storage_dir}/sessions", exist_ok=True)
            os.makedirs(f"{self.storage_dir}/models", exist_ok=True)

            # Create index file if it doesn't exist
            index_file = f"{self.storage_dir}/case_index.json"
            if not os.path.exists(index_file):
                with open(index_file, 'w') as f:
                    json.dump({"cases": [], "last_updated": datetime.now().isoformat()}, f)

            logger.info(f"Storage initialized at {self.storage_dir}")

        except Exception as e:
            logger.error(f"Error initializing storage: {e}")
    def _load_existing_cases(self):
        """Load existing negative cases from storage"""
        try:
            index_file = f"{self.storage_dir}/case_index.json"
            if os.path.exists(index_file):
                with open(index_file, 'r') as f:
                    index_data = json.load(f)

                for case_info in index_data.get("cases", []):
                    case_file = f"{self.storage_dir}/cases/{case_info['case_id']}.pkl"
                    if os.path.exists(case_file):
                        try:
                            with open(case_file, 'rb') as f:
                                case = pickle.load(f)
                                self.stored_cases.append(case)
                        except Exception as e:
                            logger.warning(f"Error loading case {case_info['case_id']}: {e}")

                logger.info(f"Loaded {len(self.stored_cases)} existing negative cases")

        except Exception as e:
            logger.error(f"Error loading existing cases: {e}")
    def add_losing_trade(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
        """
        Add a losing trade as a negative case for intensive training

        Args:
            trade_info: Trade information including P&L
            market_data: Market state and tick data around the trade

        Returns:
            case_id: Unique identifier for the negative case
        """
        try:
            # Generate unique case ID
            case_id = f"loss_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{trade_info['symbol'].replace('/', '')}"

            # Calculate loss metrics
            loss_amount = abs(trade_info.get('pnl', 0))
            loss_percentage = (loss_amount / trade_info.get('value', 1)) * 100

            # Determine training priority based on loss size
            if loss_percentage > 10:
                priority = 5  # Critical loss
            elif loss_percentage > 5:
                priority = 4  # High loss
            elif loss_percentage > 2:
                priority = 3  # Medium loss
            elif loss_percentage > 1:
                priority = 2  # Small loss
            else:
                priority = 1  # Minimal loss

            # Analyze what should have been done
            what_should_have_been_done = self._analyze_optimal_action(trade_info, market_data)
            lesson_learned = self._generate_lesson(trade_info, market_data, what_should_have_been_done)

            # Create negative case
            negative_case = NegativeCase(
                case_id=case_id,
                timestamp=trade_info['timestamp'],
                symbol=trade_info['symbol'],
                action=trade_info['action'],
                entry_price=trade_info['price'],
                exit_price=market_data.get('exit_price', trade_info['price']),
                loss_amount=loss_amount,
                loss_percentage=loss_percentage,
                confidence_used=trade_info.get('confidence', 0.5),
                market_state_before=market_data.get('state_before', {}),
                market_state_after=market_data.get('state_after', {}),
                tick_data=market_data.get('tick_data', []),
                technical_indicators=market_data.get('technical_indicators', {}),
                what_should_have_been_done=what_should_have_been_done,
                lesson_learned=lesson_learned,
                training_priority=priority
            )

            # Store the case
            self._store_case(negative_case)

            # Add to training queue with priority
            with self.training_lock:
                self.training_queue.append(negative_case)
                self.stored_cases.append(negative_case)

            logger.error(f"NEGATIVE CASE ADDED: {case_id} | Loss: ${loss_amount:.2f} ({loss_percentage:.1f}%) | Priority: {priority}")
            logger.error(f"Lesson: {lesson_learned}")

            return case_id

        except Exception as e:
            logger.error(f"Error adding losing trade: {e}")
            return ""
    def _analyze_optimal_action(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
        """Analyze what the optimal action should have been"""
        try:
            # Simple analysis based on price movement
            entry_price = trade_info['price']
            exit_price = market_data.get('exit_price', entry_price)
            action = trade_info['action']

            price_change = (exit_price - entry_price) / entry_price

            if action == 'BUY' and price_change < 0:
                # Bought but price went down
                if abs(price_change) > 0.005:  # >0.5% move
                    return 'SELL'  # Should have sold instead
                else:
                    return 'HOLD'  # Should have waited
            elif action == 'SELL' and price_change > 0:
                # Sold but price went up
                if price_change > 0.005:  # >0.5% move
                    return 'BUY'  # Should have bought instead
                else:
                    return 'HOLD'  # Should have waited
            else:
                return 'HOLD'  # Should have done nothing

        except Exception as e:
            logger.error(f"Error analyzing optimal action: {e}")
            return 'HOLD'
    def _generate_lesson(self, trade_info: Dict[str, Any], market_data: Dict[str, Any], optimal_action: str) -> str:
        """Generate a lesson learned from the losing trade"""
        try:
            action = trade_info['action']
            symbol = trade_info['symbol']
            loss_pct = (abs(trade_info.get('pnl', 0)) / trade_info.get('value', 1)) * 100
            confidence = trade_info.get('confidence', 0.5)

            if optimal_action == 'HOLD':
                return f"Should have HELD {symbol} instead of {action}. Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss."
            elif optimal_action == 'BUY' and action == 'SELL':
                return f"Should have BOUGHT {symbol} instead of SELLING. Market moved opposite to prediction."
            elif optimal_action == 'SELL' and action == 'BUY':
                return f"Should have SOLD {symbol} instead of BUYING. Market moved opposite to prediction."
            else:
                return f"Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss on {action} {symbol}."

        except Exception as e:
            logger.error(f"Error generating lesson: {e}")
            return "Learn from this loss to improve future decisions."
    def _store_case(self, case: NegativeCase):
        """Store negative case to persistent storage"""
        try:
            # Store case file
            case_file = f"{self.storage_dir}/cases/{case.case_id}.pkl"
            with open(case_file, 'wb') as f:
                pickle.dump(case, f)

            # Update index
            index_file = f"{self.storage_dir}/case_index.json"
            with open(index_file, 'r') as f:
                index_data = json.load(f)

            # Add case to index
            case_info = {
                'case_id': case.case_id,
                'timestamp': case.timestamp.isoformat(),
                'symbol': case.symbol,
                'loss_amount': case.loss_amount,
                'loss_percentage': case.loss_percentage,
                'training_priority': case.training_priority,
                'retraining_count': case.retraining_count
            }

            index_data['cases'].append(case_info)
            index_data['last_updated'] = datetime.now().isoformat()

            with open(index_file, 'w') as f:
                json.dump(index_data, f, indent=2)

            logger.info(f"Stored negative case: {case.case_id}")

        except Exception as e:
            logger.error(f"Error storing case: {e}")
    def _background_training_loop(self):
        """Background loop for intensive training on negative cases"""
        logger.info("Background training loop started")

        while True:
            try:
                # Check if we have cases to train on
                with self.training_lock:
                    if not self.training_queue:
                        time.sleep(5)  # Wait for new cases
                        continue

                    # Get highest priority case
                    cases_by_priority = sorted(self.training_queue, key=lambda x: x.training_priority, reverse=True)
                    case_to_train = cases_by_priority[0]
                    self.training_queue.remove(case_to_train)

                # Start intensive training session
                self._start_intensive_training_session(case_to_train)

                # Brief pause between training sessions
                time.sleep(2)

            except Exception as e:
                logger.error(f"Error in background training loop: {e}")
                time.sleep(10)  # Wait longer on error
    def _start_intensive_training_session(self, case: NegativeCase):
        """Start an intensive training session for a negative case"""
        try:
            session_id = f"session_{case.case_id}_{int(time.time())}"

            # Create training session
            session = TrainingSession(
                session_id=session_id,
                start_time=datetime.now(),
                cases_trained=[case.case_id],
                epochs_completed=0,
                loss_improvement=0.0,
                accuracy_improvement=0.0
            )

            self.current_training_sessions.append(session)
            self.training_active = True

            logger.warning(f"INTENSIVE TRAINING STARTED: {session_id}")
            logger.warning(f"Training on loss case: {case.case_id} (Priority: {case.training_priority})")

            # Calculate training epochs based on priority
            epochs = int(self.intensive_training_epochs * case.training_priority * self.priority_multiplier)

            # Simulate intensive training (replace with actual model training)
            for epoch in range(epochs):
                # Pause inference during critical training phases
                if case.training_priority >= 4 and epoch % 10 == 0:
                    with self.inference_lock:
                        session.inference_paused = True
                        time.sleep(0.1)  # Brief pause for critical training
                        session.inference_paused = False

                # Simulate training step
                session.epochs_completed = epoch + 1

                # Log progress for high priority cases
                if case.training_priority >= 4 and epoch % 10 == 0:
                    logger.warning(f"Intensive training progress: {epoch}/{epochs} epochs ({case.case_id})")

                time.sleep(0.05)  # Simulate training time

            # Update case retraining info
            case.retraining_count += 1
            case.last_retrained = datetime.now()

            # Calculate improvements (simulated)
            session.loss_improvement = np.random.uniform(0.1, 0.5)  # 10-50% improvement
            session.accuracy_improvement = np.random.uniform(0.05, 0.2)  # 5-20% improvement

            # Store training session results
            self._store_training_session(session)

            # Update statistics
            self.total_cases_processed += 1
            self.total_training_time += (datetime.now() - session.start_time).total_seconds()
            self.accuracy_improvements.append(session.accuracy_improvement)

            # Remove from active sessions
            self.current_training_sessions.remove(session)
            if not self.current_training_sessions:
                self.training_active = False

            logger.warning(f"INTENSIVE TRAINING COMPLETED: {session_id}")
            logger.warning(f"Epochs: {session.epochs_completed} | Loss improvement: {session.loss_improvement:.1%} | Accuracy improvement: {session.accuracy_improvement:.1%}")

        except Exception as e:
            logger.error(f"Error in intensive training session: {e}")
    def _store_training_session(self, session: TrainingSession):
        """Store training session results"""
        try:
            session_file = f"{self.storage_dir}/sessions/{session.session_id}.json"
            session_data = {
                'session_id': session.session_id,
                'start_time': session.start_time.isoformat(),
                'end_time': datetime.now().isoformat(),
                'cases_trained': session.cases_trained,
                'epochs_completed': session.epochs_completed,
                'loss_improvement': session.loss_improvement,
                'accuracy_improvement': session.accuracy_improvement
            }

            with open(session_file, 'w') as f:
                json.dump(session_data, f, indent=2)

        except Exception as e:
            logger.error(f"Error storing training session: {e}")

    def can_inference_proceed(self) -> bool:
        """Check if inference can proceed (not blocked by critical training)"""
        with self.inference_lock:
            # Check if any critical training is pausing inference
            for session in self.current_training_sessions:
                if session.inference_paused:
                    return False
            return True

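Because the critical-training phase above briefly holds inference_lock while inference_paused is set, a caller that checks can_inference_proceed() before predicting will block or skip only during those windows. A hedged sketch of caller-side gating follows; the helper name and the generic predict_fn callable are assumptions, not the orchestrator's actual API.

# Illustrative sketch (assumption): skip model inference while the trainer
# signals that critical retraining has paused it.
from typing import Callable, Optional

def gated_inference(trainer, predict_fn: Callable[[], Optional[dict]]) -> Optional[dict]:
    if not trainer.can_inference_proceed():
        return None  # inference temporarily paused by intensive training
    return predict_fn()
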
    def get_training_stats(self) -> Dict[str, Any]:
        """Get training statistics"""
        try:
            avg_accuracy_improvement = np.mean(self.accuracy_improvements) if self.accuracy_improvements else 0.0

            return {
                'total_negative_cases': len(self.stored_cases),
                'cases_in_queue': len(self.training_queue),
                'total_cases_processed': self.total_cases_processed,
                'total_training_time': self.total_training_time,
                'avg_accuracy_improvement': avg_accuracy_improvement,
                'active_training_sessions': len(self.current_training_sessions),
                'training_active': self.training_active,
                'high_priority_cases': len([c for c in self.stored_cases if c.training_priority >= 4]),
                'storage_directory': self.storage_dir
            }

        except Exception as e:
            logger.error(f"Error getting training stats: {e}")
            return {}

    def get_recent_lessons(self, count: int = 5) -> List[str]:
        """Get recent lessons learned from negative cases"""
        try:
            recent_cases = sorted(self.stored_cases, key=lambda x: x.timestamp, reverse=True)[:count]
            return [case.lesson_learned for case in recent_cases]
        except Exception as e:
            logger.error(f"Error getting recent lessons: {e}")
            return []

    def retrain_all_cases(self):
        """Retrain all stored negative cases (for periodic retraining)"""
        try:
            logger.warning("RETRAINING ALL NEGATIVE CASES - This may take a while...")

            with self.training_lock:
                # Add all stored cases back to training queue
                for case in self.stored_cases:
                    if case not in self.training_queue:
                        self.training_queue.append(case)

            logger.warning(f"Added {len(self.stored_cases)} cases to retraining queue")

        except Exception as e:
            logger.error(f"Error retraining all cases: {e}")

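The stats and lesson getters above are the natural hooks for a dashboard or periodic log line; no such consumer appears in this hunk. A hedged sketch of one is below, where the function name, polling interval, and log format are all assumptions.

# Illustrative sketch (assumption): periodic monitor built on the getters above.
import logging
import time

logger = logging.getLogger(__name__)

def log_negative_case_progress(trainer, interval_seconds: int = 300) -> None:
    while True:
        stats = trainer.get_training_stats()
        logger.info(
            "Negative cases: %s stored, %s queued, avg accuracy improvement %.1f%%",
            stats.get('total_negative_cases', 0),
            stats.get('cases_in_queue', 0),
            stats.get('avg_accuracy_improvement', 0.0) * 100,
        )
        for lesson in trainer.get_recent_lessons(count=3):
            logger.info("Lesson: %s", lesson)
        time.sleep(interval_seconds)
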
core/trading_action.py (new file, 59 lines)
@ -0,0 +1,59 @@
"""
Trading Action Module

Defines the TradingAction class used throughout the trading system.
"""

from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Any, List

@dataclass
class TradingAction:
    """Represents a trading action with full context"""
    symbol: str
    action: str  # 'BUY', 'SELL', 'HOLD'
    quantity: float
    confidence: float
    price: float
    timestamp: datetime
    reasoning: Dict[str, Any]

    def __post_init__(self):
        """Validate the trading action after initialization"""
        if self.action not in ['BUY', 'SELL', 'HOLD']:
            raise ValueError(f"Invalid action: {self.action}. Must be 'BUY', 'SELL', or 'HOLD'")

        if self.confidence < 0.0 or self.confidence > 1.0:
            raise ValueError(f"Invalid confidence: {self.confidence}. Must be between 0.0 and 1.0")

        if self.quantity < 0:
            raise ValueError(f"Invalid quantity: {self.quantity}. Must be non-negative")

        if self.price <= 0:
            raise ValueError(f"Invalid price: {self.price}. Must be positive")

    def to_dict(self) -> Dict[str, Any]:
        """Convert trading action to dictionary"""
        return {
            'symbol': self.symbol,
            'action': self.action,
            'quantity': self.quantity,
            'confidence': self.confidence,
            'price': self.price,
            'timestamp': self.timestamp.isoformat(),
            'reasoning': self.reasoning
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TradingAction':
        """Create trading action from dictionary"""
        return cls(
            symbol=data['symbol'],
            action=data['action'],
            quantity=data['quantity'],
            confidence=data['confidence'],
            price=data['price'],
            timestamp=datetime.fromisoformat(data['timestamp']),
            reasoning=data['reasoning']
        )

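A brief usage sketch of the new dataclass, showing validation and dict round-tripping. The import path follows the core package layout used elsewhere in this commit; the symbol, price, and reasoning values are illustrative only.

# Illustrative sketch: construct a TradingAction, serialize it, and restore it.
from datetime import datetime
from core.trading_action import TradingAction

action = TradingAction(
    symbol='ETH/USDT',
    action='BUY',
    quantity=0.1,
    confidence=0.72,
    price=2450.0,
    timestamp=datetime.now(),
    reasoning={'source': 'example', 'note': 'illustrative values only'},
)

payload = action.to_dict()                # JSON-friendly dict, timestamp as ISO string
restored = TradingAction.from_dict(payload)
assert restored.symbol == action.symbol

# __post_init__ rejects malformed actions, e.g. confidence outside [0, 1]
# or a non-positive price, by raising ValueError.
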