"""
|
|
Extrema Training Module - Reusable Local Bottom/Top Detection and Training
|
|
|
|
This module provides reusable functionality for:
|
|
1. Detecting local extrema (bottoms and tops) in price data
|
|
2. Creating training opportunities from extrema
|
|
3. Loading and managing 200-candle 1m context data
|
|
4. Generating features for model consumption
|
|
5. Training on not-so-perfect opportunities
|
|
|
|
Can be used across different dashboards and trading systems.
|
|
"""

import logging
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from collections import deque

logger = logging.getLogger(__name__)


@dataclass
class ExtremaPoint:
    """Represents a detected local extremum (bottom or top)"""
    symbol: str
    timestamp: datetime
    price: float
    extrema_type: str  # 'bottom' or 'top'
    confidence: float
    context_before: List[float]
    context_after: List[float]
    optimal_action: str  # 'BUY' or 'SELL'
    market_context: Dict[str, Any]
    outcome: Optional[float] = None  # Price change after the extremum


@dataclass
class ContextData:
    """200-candle 1m context data for enhanced model performance"""
    symbol: str
    candles: deque
    features: Optional[np.ndarray]
    last_update: datetime


class ExtremaTrainer:
    """Reusable extrema detection and training functionality"""

    def __init__(self, data_provider, symbols: List[str], window_size: int = 10):
        """
        Initialize the extrema trainer

        Args:
            data_provider: Data provider exposing get_historical_data(symbol, timeframe, limit, refresh)
            symbols: List of symbols to track
            window_size: Window size for extrema detection (default 10)
        """
        self.data_provider = data_provider
        self.symbols = symbols
        self.window_size = window_size

        # Extrema tracking
        self.detected_extrema = {symbol: deque(maxlen=1000) for symbol in symbols}
        self.extrema_training_queue = deque(maxlen=500)
        self.last_extrema_check = {symbol: datetime.now() for symbol in symbols}

        # 200-candle context data
        self.context_data = {symbol: ContextData(
            symbol=symbol,
            candles=deque(maxlen=200),
            features=None,
            last_update=datetime.now()
        ) for symbol in symbols}

        self.context_update_frequency = 60  # Update every 60 seconds

        # Training parameters
        self.min_confidence_threshold = 0.3  # Train on opportunities with at least 30% confidence
        self.max_confidence_threshold = 0.95  # Cap confidence at 95%

        logger.info(f"ExtremaTrainer initialized for symbols: {symbols}")
        logger.info(f"Window size: {window_size}, context update frequency: {self.context_update_frequency}s")

    def initialize_context_data(self) -> Dict[str, bool]:
        """Initialize 200-candle 1m context data for all symbols"""
        results = {}

        try:
            logger.info("Initializing 200-candle 1m context data for enhanced model performance")

            for symbol in self.symbols:
                try:
                    # Load 200 candles of 1m data
                    context_data = self.data_provider.get_historical_data(symbol, '1m', limit=200)

                    if context_data is not None and len(context_data) > 0:
                        # Store raw data
                        for _, row in context_data.iterrows():
                            candle_data = {
                                'timestamp': row['timestamp'],
                                'open': row['open'],
                                'high': row['high'],
                                'low': row['low'],
                                'close': row['close'],
                                'volume': row['volume']
                            }
                            self.context_data[symbol].candles.append(candle_data)

                        # Create feature matrix for models
                        self.context_data[symbol].features = self._create_context_features(context_data)
                        self.context_data[symbol].last_update = datetime.now()

                        results[symbol] = True
                        logger.info(f"✅ Loaded {len(context_data)} 1m candles for {symbol} context")
                    else:
                        results[symbol] = False
                        logger.warning(f"❌ No 1m context data available for {symbol}")

                except Exception as e:
                    logger.error(f"❌ Error loading context data for {symbol}: {e}")
                    results[symbol] = False

        except Exception as e:
            logger.error(f"Error initializing context data: {e}")

        successful = sum(1 for success in results.values() if success)
        logger.info(f"Context data initialization: {successful}/{len(self.symbols)} symbols loaded")

        return results

    def update_context_data(self, symbol: Optional[str] = None) -> Dict[str, bool]:
        """Update 200-candle 1m context data for the given symbol, or all symbols"""
        results = {}

        try:
            symbols_to_update = [symbol] if symbol else self.symbols

            for sym in symbols_to_update:
                try:
                    # Check if an update is needed
                    time_since_update = (datetime.now() - self.context_data[sym].last_update).total_seconds()

                    if time_since_update >= self.context_update_frequency:
                        # Get the latest 1m data
                        latest_data = self.data_provider.get_historical_data(sym, '1m', limit=10, refresh=True)

                        if latest_data is not None and len(latest_data) > 0:
                            # Add new candles to the context
                            for _, row in latest_data.iterrows():
                                candle_data = {
                                    'timestamp': row['timestamp'],
                                    'open': row['open'],
                                    'high': row['high'],
                                    'low': row['low'],
                                    'close': row['close'],
                                    'volume': row['volume']
                                }

                                # Only append candles newer than our latest stored one
                                if (not self.context_data[sym].candles or
                                        candle_data['timestamp'] > self.context_data[sym].candles[-1]['timestamp']):
                                    self.context_data[sym].candles.append(candle_data)

                            # Update the feature matrix
                            if len(self.context_data[sym].candles) >= 50:
                                context_df = pd.DataFrame(list(self.context_data[sym].candles))
                                self.context_data[sym].features = self._create_context_features(context_df)

                            self.context_data[sym].last_update = datetime.now()

                            # Check for local extrema in the updated data
                            self.detect_local_extrema(sym)

                            results[sym] = True
                        else:
                            results[sym] = False
                    else:
                        results[sym] = True  # No update needed

                except Exception as e:
                    logger.error(f"Error updating context data for {sym}: {e}")
                    results[sym] = False

        except Exception as e:
            logger.error(f"Error updating context data: {e}")

        return results

    def detect_local_extrema(self, symbol: str) -> List[ExtremaPoint]:
        """Detect local bottoms and tops for training opportunities"""
        detected = []

        try:
            if len(self.context_data[symbol].candles) < self.window_size * 3:
                return detected

            # Use all available price data for better extrema detection
            all_candles = list(self.context_data[symbol].candles)
            prices = [candle['close'] for candle in all_candles]
            timestamps = [candle['timestamp'] for candle in all_candles]

            # Compare each candidate point against a symmetric window of neighbors
            window = self.window_size

            # Look for extrema in the middle portion of the data only, since points
            # near the edges lack a full window of neighbors on one side
            start_idx = window
            end_idx = len(prices) - window

            for i in range(start_idx, end_idx):
                current_price = prices[i]
                current_time = timestamps[i]

                # Get surrounding prices for comparison
                left_prices = prices[i - window:i]
                right_prices = prices[i + 1:i + window + 1]
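
                # Worked example (illustrative numbers only): with window = 2 and
                # prices = [5.00, 4.00, 3.00, 4.00, 5.00], at i = 2 we get
                # left_prices = [5.00, 4.00] and right_prices = [4.00, 5.00].
                # 3.00 is <= both minima and below max(left_prices) * 0.998
                # (= 4.99), so the bottom test below fires.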

                # Check for a local bottom (current price at or below all surrounding prices)
                is_bottom = (current_price <= min(left_prices) and
                             current_price <= min(right_prices) and
                             current_price < max(left_prices) * 0.998)  # At least 0.2% below the left-side high

                # Check for a local top (current price at or above all surrounding prices)
                is_top = (current_price >= max(left_prices) and
                          current_price >= max(right_prices) and
                          current_price > min(left_prices) * 1.002)  # At least 0.2% above the left-side low

                if is_bottom or is_top:
                    extrema_type = 'bottom' if is_bottom else 'top'

                    # Calculate confidence based on price deviation, volume, and momentum
                    confidence = self._calculate_extrema_confidence(symbol, prices, i, window)

                    # Only process if confidence meets the minimum threshold
                    if confidence >= self.min_confidence_threshold:
                        # Skip if this extremum is too close to a previously detected one
                        if not self._is_too_close_to_existing_extrema(symbol, current_time, current_price):
                            # Create the extrema point
                            extrema_point = ExtremaPoint(
                                symbol=symbol,
                                timestamp=current_time,
                                price=current_price,
                                extrema_type=extrema_type,
                                confidence=min(confidence, self.max_confidence_threshold),
                                context_before=left_prices,
                                context_after=right_prices,
                                optimal_action='BUY' if is_bottom else 'SELL',
                                market_context=self._get_extrema_market_context(symbol, current_time)
                            )

                            # Calculate the outcome if we have future data
                            if len(right_prices) > 0:
                                # Look ahead further for a better outcome estimate
                                future_idx = min(i + window * 2, len(prices) - 1)
                                future_price = prices[future_idx]
                                price_change = (future_price - current_price) / current_price

                                # For bottoms, a positive change is good; for tops, a negative one is
                                if extrema_type == 'bottom':
                                    extrema_point.outcome = price_change
                                else:  # top
                                    extrema_point.outcome = -price_change

                            self.detected_extrema[symbol].append(extrema_point)
                            self.extrema_training_queue.append(extrema_point)
                            detected.append(extrema_point)

                            logger.info(f"Local {extrema_type} detected for {symbol} at ${current_price:.2f} "
                                        f"(confidence: {confidence:.3f}, outcome: {extrema_point.outcome:.4f})")

            self.last_extrema_check[symbol] = datetime.now()

        except Exception as e:
            logger.error(f"Error detecting local extrema for {symbol}: {e}")

        return detected

    def _is_too_close_to_existing_extrema(self, symbol: str, timestamp: datetime, price: float) -> bool:
        """Check if this extremum is too close to an existing one"""
        try:
            if symbol not in self.detected_extrema:
                return False

            recent_extrema = list(self.detected_extrema[symbol])[-10:]  # Check the last 10 extrema

            for existing_extrema in recent_extrema:
                # Check time proximity (within 30 minutes)
                time_diff = abs((timestamp - existing_extrema.timestamp).total_seconds())
                if time_diff < 1800:  # 30 minutes
                    # Check price proximity (within 1%)
                    price_diff = abs(price - existing_extrema.price) / existing_extrema.price
                    if price_diff < 0.01:  # 1%
                        return True

            return False

        except Exception as e:
            logger.error(f"Error checking extrema proximity: {e}")
            return False

    def _calculate_extrema_confidence(self, symbol: str, prices: List[float], extrema_index: int, window: int) -> float:
        """Calculate a confidence level for a detected extremum"""
        try:
            extrema_price = prices[extrema_index]

            # Calculate the price deviation around the extremum
            surrounding_prices = prices[max(0, extrema_index - window):extrema_index + window + 1]
            price_range = max(surrounding_prices) - min(surrounding_prices)

            if price_range == 0:
                return 0.5

            # Calculate how extreme the point is relative to its surroundings
            if extrema_price == min(surrounding_prices):  # Bottom
                deviation = (max(surrounding_prices) - extrema_price) / price_range
            else:  # Top
                deviation = (extrema_price - min(surrounding_prices)) / price_range

            # Additional factors for confidence
            # 1. Volume confirmation: was volume elevated at the extremum?
            #    (Look up the candles for the symbol under analysis, not an arbitrary one.)
            volume_factor = 1.0
            if symbol in self.context_data:
                try:
                    recent_candles = list(self.context_data[symbol].candles)
                    if len(recent_candles) > extrema_index:
                        extrema_volume = recent_candles[extrema_index].get('volume', 0)
                        avg_volume = np.mean([c.get('volume', 0) for c in recent_candles[-20:]])
                        if avg_volume > 0:
                            volume_factor = min(1.2, extrema_volume / avg_volume)
                except Exception:
                    pass

            # 2. Price momentum leading into the extremum
            momentum_factor = 1.0
            if extrema_index >= 3:
                price_momentum = abs(prices[extrema_index] - prices[extrema_index - 3]) / prices[extrema_index - 3]
                momentum_factor = min(1.1, 1.0 + price_momentum * 10)

            # Combine the factors
            confidence = deviation * volume_factor * momentum_factor
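
            # Worked example (illustrative numbers only): deviation = 0.70 with
            # volume_factor = 1.2 and momentum_factor = 1.1 combine to
            # 0.70 * 1.2 * 1.1 = 0.924, which the clamp below leaves unchanged
            # since it already falls inside [0.3, 0.95].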

            # Clamp confidence into the configured bounds
            confidence = min(self.max_confidence_threshold, max(self.min_confidence_threshold, confidence))

            return confidence

        except Exception as e:
            logger.error(f"Error calculating extrema confidence: {e}")
            return 0.5

    def _get_extrema_market_context(self, symbol: str, timestamp: datetime) -> Dict[str, Any]:
        """Get market context at the time of extrema detection"""
        try:
            context = {
                'volatility': 0.0,
                'volume_spike': False,
                'trend_strength': 0.0,
                'rsi_level': 50.0,  # Neutral default; RSI is not computed here
                'price_momentum': 0.0
            }

            if len(self.context_data[symbol].candles) >= 20:
                recent_candles = list(self.context_data[symbol].candles)[-20:]

                # Volatility: mean absolute one-candle return
                prices = [c['close'] for c in recent_candles]
                price_changes = [abs(prices[i] - prices[i-1]) / prices[i-1] for i in range(1, len(prices))]
                context['volatility'] = np.mean(price_changes) if price_changes else 0.0

                # Volume spike: last candle more than 1.5x the recent average
                volumes = [c['volume'] for c in recent_candles]
                avg_volume = np.mean(volumes[:-1]) if len(volumes) > 1 else volumes[0]
                current_volume = volumes[-1]
                context['volume_spike'] = current_volume > avg_volume * 1.5

                # Simple trend strength: absolute 10-candle slope
                if len(prices) >= 10:
                    trend_slope = (prices[-1] - prices[-10]) / prices[-10]
                    context['trend_strength'] = abs(trend_slope)

                # Price momentum over the last 5 candles
                if len(prices) >= 5:
                    momentum = (prices[-1] - prices[-5]) / prices[-5]
                    context['price_momentum'] = momentum

            return context

        except Exception as e:
            logger.error(f"Error getting extrema market context: {e}")
            return {'volatility': 0.0, 'volume_spike': False, 'trend_strength': 0.0, 'rsi_level': 50.0, 'price_momentum': 0.0}

    def _create_context_features(self, df: pd.DataFrame) -> Optional[np.ndarray]:
        """Create a feature matrix from 1m context data for model consumption"""
        try:
            if df is None or len(df) < 50:
                return None

            # Select key features for context
            feature_columns = ['open', 'high', 'low', 'close', 'volume']

            # Add technical indicators if available
            if 'rsi_14' in df.columns:
                feature_columns.extend(['rsi_14', 'sma_20', 'ema_12'])
            if 'macd' in df.columns:
                feature_columns.extend(['macd', 'macd_signal'])
            if 'bb_upper' in df.columns:
                feature_columns.extend(['bb_upper', 'bb_lower', 'bb_percent'])

            # Extract the features that are actually present
            available_features = [col for col in feature_columns if col in df.columns]
            feature_data = df[available_features].copy()

            # Normalize features
            normalized_features = self._normalize_context_features(feature_data)

            return normalized_features.values if normalized_features is not None else None

        except Exception as e:
            logger.error(f"Error creating context features: {e}")
            return None

    def _normalize_context_features(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Normalize context features for model consumption"""
        try:
            df_norm = df.copy()

            # Price normalization (relative to the latest close)
            if 'close' in df_norm.columns:
                latest_close = df_norm['close'].iloc[-1]
                for col in ['open', 'high', 'low', 'close', 'sma_20', 'ema_12', 'bb_upper', 'bb_lower']:
                    if col in df_norm.columns and latest_close > 0:
                        df_norm[col] = df_norm[col] / latest_close

            # Volume normalization (relative to the mean volume)
            if 'volume' in df_norm.columns:
                volume_mean = df_norm['volume'].mean()
                if volume_mean > 0:
                    df_norm['volume'] = df_norm['volume'] / volume_mean

            # RSI normalization (0-100 to 0-1)
            if 'rsi_14' in df_norm.columns:
                df_norm['rsi_14'] = df_norm['rsi_14'] / 100.0

            # MACD normalization (relative to the latest unnormalized close)
            if 'macd' in df_norm.columns and 'close' in df.columns:
                latest_close = df['close'].iloc[-1]
                if latest_close > 0:
                    df_norm['macd'] = df_norm['macd'] / latest_close
                    if 'macd_signal' in df_norm.columns:
                        df_norm['macd_signal'] = df_norm['macd_signal'] / latest_close

            # BB percent is already normalized; just clip to [0, 1]
            if 'bb_percent' in df_norm.columns:
                df_norm['bb_percent'] = np.clip(df_norm['bb_percent'], 0, 1)

            # Fill NaN values
            df_norm = df_norm.fillna(0)

            return df_norm

        except Exception as e:
            logger.error(f"Error normalizing context features: {e}")
            return df

    def get_context_features_for_model(self, symbol: str) -> Optional[np.ndarray]:
        """Get 200-candle 1m context features for model consumption"""
        try:
            if symbol not in self.context_data:
                return None

            if self.context_data[symbol].features is not None:
                return self.context_data[symbol].features

            # If there are no cached features, create them from the current data
            if len(self.context_data[symbol].candles) >= 50:
                context_df = pd.DataFrame(list(self.context_data[symbol].candles))
                features = self._create_context_features(context_df)
                self.context_data[symbol].features = features
                return features

            return None

        except Exception as e:
            logger.error(f"Error getting context features for {symbol}: {e}")
            return None
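
    # Illustrative consumer pattern (hypothetical; `model_input` is not defined
    # anywhere in this module):
    #   features = trainer.get_context_features_for_model('ETH/USDT')
    #   if features is not None:
    #       model_input = features[-50:]  # most recent candles, one row each
    # The matrix has up to 200 rows (one per candle) and one column per feature
    # selected in _create_context_features().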

    def get_extrema_training_data(self, count: int = 50, min_confidence: Optional[float] = None) -> List[ExtremaPoint]:
        """Get recent extrema training data for model training"""
        try:
            extrema_list = list(self.extrema_training_queue)

            # Filter by confidence if specified
            if min_confidence is not None:
                extrema_list = [e for e in extrema_list if e.confidence >= min_confidence]

            return extrema_list[-count:] if extrema_list else []

        except Exception as e:
            logger.error(f"Error getting extrema training data: {e}")
            return []

    def get_perfect_moves_for_cnn(self, count: int = 100) -> List[Dict[str, Any]]:
        """Get perfect moves formatted for CNN training"""
        try:
            extrema_data = self.get_extrema_training_data(count)
            perfect_moves = []

            for extrema in extrema_data:
                if extrema.outcome is not None:
                    perfect_move = {
                        'symbol': extrema.symbol,
                        'timeframe': '1m',
                        'timestamp': extrema.timestamp,
                        'optimal_action': extrema.optimal_action,
                        'actual_outcome': abs(extrema.outcome),
                        'confidence_should_have_been': extrema.confidence,
                        'market_context': extrema.market_context,
                        'extrema_type': extrema.extrema_type
                    }
                    perfect_moves.append(perfect_move)

            return perfect_moves

        except Exception as e:
            logger.error(f"Error getting perfect moves for CNN: {e}")
            return []

    def get_extrema_stats(self) -> Dict[str, Any]:
        """Get statistics about extrema detection and training"""
        try:
            stats = {
                'total_extrema_detected': sum(len(extrema) for extrema in self.detected_extrema.values()),
                'extrema_by_symbol': {symbol: len(extrema) for symbol, extrema in self.detected_extrema.items()},
                'training_queue_size': len(self.extrema_training_queue),
                'last_extrema_check': {symbol: check_time.isoformat()
                                       for symbol, check_time in self.last_extrema_check.items()},
                'context_data_status': {
                    symbol: {
                        'candles_loaded': len(self.context_data[symbol].candles),
                        'features_available': self.context_data[symbol].features is not None,
                        'last_update': self.context_data[symbol].last_update.isoformat()
                    }
                    for symbol in self.symbols
                },
                'window_size': self.window_size,
                'confidence_thresholds': {
                    'min': self.min_confidence_threshold,
                    'max': self.max_confidence_threshold
                }
            }

            # Recent extrema breakdown
            recent_extrema = list(self.extrema_training_queue)[-20:]
            if recent_extrema:
                bottoms = len([e for e in recent_extrema if e.extrema_type == 'bottom'])
                tops = len([e for e in recent_extrema if e.extrema_type == 'top'])
                avg_confidence = np.mean([e.confidence for e in recent_extrema])
                outcomes = [e.outcome for e in recent_extrema if e.outcome is not None]
                avg_outcome = np.mean(outcomes) if outcomes else 0.0  # Avoid np.mean on an empty list

                stats['recent_extrema'] = {
                    'bottoms': bottoms,
                    'tops': tops,
                    'avg_confidence': avg_confidence,
                    'avg_outcome': avg_outcome
                }

            return stats

        except Exception as e:
            logger.error(f"Error getting extrema stats: {e}")
            return {}

    def run_batch_detection(self) -> Dict[str, List[ExtremaPoint]]:
        """Run extrema detection for all symbols"""
        results = {}

        try:
            for symbol in self.symbols:
                detected = self.detect_local_extrema(symbol)
                results[symbol] = detected

            total_detected = sum(len(extrema_list) for extrema_list in results.values())
            logger.info(f"Batch extrema detection completed: {total_detected} extrema detected across {len(self.symbols)} symbols")

        except Exception as e:
            logger.error(f"Error in batch extrema detection: {e}")

        return results
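

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The synthetic provider below is a
# stand-in, not the project's real data provider; it merely matches the
# get_historical_data(symbol, timeframe, limit, refresh) call signature used
# above and returns a DataFrame with the expected OHLCV columns.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    class _SyntheticDataProvider:
        """Serves random-walk 1m candles for demonstration purposes."""

        def get_historical_data(self, symbol, timeframe, limit=200, refresh=False):
            rng = np.random.default_rng(42)
            closes = 1800.0 + np.cumsum(rng.normal(0.0, 2.0, size=limit))
            return pd.DataFrame({
                'timestamp': pd.date_range(end=datetime.now(), periods=limit, freq='1min'),
                'open': closes + rng.normal(0.0, 0.5, size=limit),
                'high': closes + np.abs(rng.normal(0.0, 1.0, size=limit)),
                'low': closes - np.abs(rng.normal(0.0, 1.0, size=limit)),
                'close': closes,
                'volume': rng.uniform(10.0, 100.0, size=limit),
            })

    trainer = ExtremaTrainer(_SyntheticDataProvider(), ['ETH/USDT'], window_size=10)
    trainer.initialize_context_data()
    batch = trainer.run_batch_detection()
    print(f"Detected {sum(len(v) for v in batch.values())} extrema")
    print(trainer.get_extrema_stats())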