# gogo2/core/extrema_trainer.py
"""
Extrema Training Module - Reusable Local Bottom/Top Detection and Training
This module provides reusable functionality for:
1. Detecting local extrema (bottoms and tops) in price data
2. Creating training opportunities from extrema
3. Loading and managing 200-candle 1m context data
4. Generating features for model consumption
5. Training on not-so-perfect opportunities
Can be used across different dashboards and trading systems.
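Example (illustrative sketch; assumes a data provider exposing
get_historical_data(symbol, timeframe, limit, refresh) that returns an
OHLCV DataFrame, as used throughout this module):

    trainer = ExtremaTrainer(data_provider, ['ETH/USDT'], window_size=10)
    trainer.initialize_context_data()
    detected = trainer.run_batch_detection()
    batch = trainer.get_extrema_training_data(count=50, min_confidence=0.5)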
"""
import logging
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from collections import deque

logger = logging.getLogger(__name__)

@dataclass
class ExtremaPoint:
    """Represents a detected local extremum (bottom or top)"""
    symbol: str
    timestamp: datetime
    price: float
    extrema_type: str  # 'bottom' or 'top'
    confidence: float
    context_before: List[float]
    context_after: List[float]
    optimal_action: str  # 'BUY' or 'SELL'
    market_context: Dict[str, Any]
    outcome: Optional[float] = None  # Price change after the extremum


@dataclass
class ContextData:
    """200-candle 1m context data for enhanced model performance"""
    symbol: str
    candles: deque
    features: Optional[np.ndarray]
    last_update: datetime

class ExtremaTrainer:
    """Reusable extrema detection and training functionality"""

    def __init__(self, data_provider, symbols: List[str], window_size: int = 10):
        """
        Initialize the extrema trainer

        Args:
            data_provider: Data provider instance
            symbols: List of symbols to track
            window_size: Window size for extrema detection (default 10)
        """
        self.data_provider = data_provider
        self.symbols = symbols
        self.window_size = window_size

        # Extrema tracking
        self.detected_extrema = {symbol: deque(maxlen=1000) for symbol in symbols}
        self.extrema_training_queue = deque(maxlen=500)
        self.last_extrema_check = {symbol: datetime.now() for symbol in symbols}

        # 200-candle context data
        self.context_data = {symbol: ContextData(
            symbol=symbol,
            candles=deque(maxlen=200),
            features=None,
            last_update=datetime.now()
        ) for symbol in symbols}
        self.context_update_frequency = 60  # Update every 60 seconds

        # Training parameters
        self.min_confidence_threshold = 0.3  # Train on opportunities with at least 30% confidence
        self.max_confidence_threshold = 0.95  # Cap confidence at 95%

        logger.info(f"ExtremaTrainer initialized for symbols: {symbols}")
        logger.info(f"Window size: {window_size}, Context update frequency: {self.context_update_frequency}s")

    def initialize_context_data(self) -> Dict[str, bool]:
        """Initialize 200-candle 1m context data for all symbols"""
        results = {}
        try:
            logger.info("Initializing 200-candle 1m context data for enhanced model performance")
            for symbol in self.symbols:
                try:
                    # Load 200 candles of 1m data
                    context_data = self.data_provider.get_historical_data(symbol, '1m', limit=200)
                    if context_data is not None and len(context_data) > 0:
                        # Store raw data
                        for _, row in context_data.iterrows():
                            candle_data = {
                                'timestamp': row['timestamp'],
                                'open': row['open'],
                                'high': row['high'],
                                'low': row['low'],
                                'close': row['close'],
                                'volume': row['volume']
                            }
                            self.context_data[symbol].candles.append(candle_data)
                        # Create feature matrix for models
                        self.context_data[symbol].features = self._create_context_features(context_data)
                        self.context_data[symbol].last_update = datetime.now()
                        results[symbol] = True
                        logger.info(f"✅ Loaded {len(context_data)} 1m candles for {symbol} context")
                    else:
                        results[symbol] = False
                        logger.warning(f"❌ No 1m context data available for {symbol}")
                except Exception as e:
                    logger.error(f"❌ Error loading context data for {symbol}: {e}")
                    results[symbol] = False
        except Exception as e:
            logger.error(f"Error initializing context data: {e}")

        successful = sum(1 for success in results.values() if success)
        logger.info(f"Context data initialization: {successful}/{len(self.symbols)} symbols loaded")
        return results

    def update_context_data(self, symbol: Optional[str] = None) -> Dict[str, bool]:
        """Update 200-candle 1m context data for the given symbol, or for all symbols"""
        results = {}
        try:
            symbols_to_update = [symbol] if symbol else self.symbols
            for sym in symbols_to_update:
                try:
                    # Only refresh once per update interval
                    time_since_update = (datetime.now() - self.context_data[sym].last_update).total_seconds()
                    if time_since_update >= self.context_update_frequency:
                        # Get the latest 1m data
                        latest_data = self.data_provider.get_historical_data(sym, '1m', limit=10, refresh=True)
                        if latest_data is not None and len(latest_data) > 0:
                            # Append only candles newer than the latest one stored
                            for _, row in latest_data.iterrows():
                                candle_data = {
                                    'timestamp': row['timestamp'],
                                    'open': row['open'],
                                    'high': row['high'],
                                    'low': row['low'],
                                    'close': row['close'],
                                    'volume': row['volume']
                                }
                                if (not self.context_data[sym].candles or
                                        candle_data['timestamp'] > self.context_data[sym].candles[-1]['timestamp']):
                                    self.context_data[sym].candles.append(candle_data)
                            # Update the feature matrix
                            if len(self.context_data[sym].candles) >= 50:
                                context_df = pd.DataFrame(list(self.context_data[sym].candles))
                                self.context_data[sym].features = self._create_context_features(context_df)
                                self.context_data[sym].last_update = datetime.now()
                            # Check for local extrema in the updated data
                            self.detect_local_extrema(sym)
                            results[sym] = True
                        else:
                            results[sym] = False
                    else:
                        results[sym] = True  # No update needed
                except Exception as e:
                    logger.error(f"Error updating context data for {sym}: {e}")
                    results[sym] = False
        except Exception as e:
            logger.error(f"Error updating context data: {e}")
        return results

    def detect_local_extrema(self, symbol: str) -> List[ExtremaPoint]:
        """Detect local bottoms and tops for training opportunities"""
        detected = []
        try:
            if len(self.context_data[symbol].candles) < self.window_size * 3:
                return detected

            # Use all available price data for better extrema detection
            all_candles = list(self.context_data[symbol].candles)
            prices = [candle['close'] for candle in all_candles]
            timestamps = [candle['timestamp'] for candle in all_candles]

            # Look for extrema in the middle portion of the data (not at the edges)
            window = self.window_size
            start_idx = window
            end_idx = len(prices) - window
            for i in range(start_idx, end_idx):
                current_price = prices[i]
                current_time = timestamps[i]
                # Surrounding prices for comparison
                left_prices = prices[i - window:i]
                right_prices = prices[i + 1:i + window + 1]

                # Local bottom: price at or below everything around it,
                # and at least 0.2% below the highest price on the left
                is_bottom = (current_price <= min(left_prices) and
                             current_price <= min(right_prices) and
                             current_price < max(left_prices) * 0.998)
                # Local top: price at or above everything around it,
                # and at least 0.2% above the lowest price on the left
                is_top = (current_price >= max(left_prices) and
                          current_price >= max(right_prices) and
                          current_price > min(left_prices) * 1.002)

                if is_bottom or is_top:
                    extrema_type = 'bottom' if is_bottom else 'top'
                    # Confidence based on price deviation, volume and momentum
                    confidence = self._calculate_extrema_confidence(symbol, prices, i, window)
                    # Only process opportunities that meet the minimum threshold
                    if confidence >= self.min_confidence_threshold:
                        # Skip extrema too close to a previously detected one
                        if not self._is_too_close_to_existing_extrema(symbol, current_time, current_price):
                            extrema_point = ExtremaPoint(
                                symbol=symbol,
                                timestamp=current_time,
                                price=current_price,
                                extrema_type=extrema_type,
                                confidence=min(confidence, self.max_confidence_threshold),
                                context_before=left_prices,
                                context_after=right_prices,
                                optimal_action='BUY' if is_bottom else 'SELL',
                                market_context=self._get_extrema_market_context(symbol, current_time)
                            )
                            # Calculate the outcome using future data
                            if len(right_prices) > 0:
                                # Look further ahead for a more reliable outcome estimate
                                future_idx = min(i + window * 2, len(prices) - 1)
                                future_price = prices[future_idx]
                                price_change = (future_price - current_price) / current_price
                                # For bottoms a positive change is good; for tops, a negative one
                                if extrema_type == 'bottom':
                                    extrema_point.outcome = price_change
                                else:  # top
                                    extrema_point.outcome = -price_change

                            self.detected_extrema[symbol].append(extrema_point)
                            self.extrema_training_queue.append(extrema_point)
                            detected.append(extrema_point)
                            logger.info(f"Local {extrema_type} detected for {symbol} at ${current_price:.2f} "
                                        f"(confidence: {confidence:.3f}, outcome: {extrema_point.outcome:.4f})")

            self.last_extrema_check[symbol] = datetime.now()
        except Exception as e:
            logger.error(f"Error detecting local extrema for {symbol}: {e}")
        return detected

    def _is_too_close_to_existing_extrema(self, symbol: str, timestamp: datetime, price: float) -> bool:
        """Check whether this extremum is too close to an existing one"""
        try:
            if symbol not in self.detected_extrema:
                return False
            recent_extrema = list(self.detected_extrema[symbol])[-10:]  # Check the last 10 extrema
            for existing_extrema in recent_extrema:
                # Time proximity (within 30 minutes)
                time_diff = abs((timestamp - existing_extrema.timestamp).total_seconds())
                if time_diff < 1800:  # 30 minutes
                    # Price proximity (within 1%)
                    price_diff = abs(price - existing_extrema.price) / existing_extrema.price
                    if price_diff < 0.01:  # 1%
                        return True
            return False
        except Exception as e:
            logger.error(f"Error checking extrema proximity: {e}")
            return False

    def _calculate_extrema_confidence(self, symbol: str, prices: List[float], extrema_index: int, window: int) -> float:
        """Calculate the confidence level for a detected extremum"""
        try:
            extrema_price = prices[extrema_index]
            # Price deviation within the window around the extremum
            surrounding_prices = prices[max(0, extrema_index - window):extrema_index + window + 1]
            price_range = max(surrounding_prices) - min(surrounding_prices)
            if price_range == 0:
                return 0.5

            # How extreme the point is within its local range
            if extrema_price == min(surrounding_prices):  # Bottom
                deviation = (max(surrounding_prices) - extrema_price) / price_range
            else:  # Top
                deviation = (extrema_price - min(surrounding_prices)) / price_range

            # Additional confidence factors
            # 1. Volume confirmation: was volume elevated at the extremum?
            volume_factor = 1.0
            try:
                recent_candles = list(self.context_data[symbol].candles)
                if len(recent_candles) > extrema_index:
                    extrema_volume = recent_candles[extrema_index].get('volume', 0)
                    avg_volume = np.mean([c.get('volume', 0) for c in recent_candles[-20:]])
                    if avg_volume > 0:
                        volume_factor = min(1.2, extrema_volume / avg_volume)
            except Exception:
                pass

            # 2. Price momentum into the extremum
            momentum_factor = 1.0
            if extrema_index >= 3:
                price_momentum = abs(prices[extrema_index] - prices[extrema_index - 3]) / prices[extrema_index - 3]
                momentum_factor = min(1.1, 1.0 + price_momentum * 10)

            # Combine factors
            confidence = deviation * volume_factor * momentum_factor
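            # Worked example (illustrative numbers, not from live data):
            # deviation 0.8 with a mild volume spike (factor 1.1) and strong
            # approach momentum (factor 1.05) gives 0.8 * 1.1 * 1.05 = 0.924,
            # which the clamp below keeps inside [0.3, 0.95].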
            # Ensure confidence stays within the configured bounds
            confidence = min(self.max_confidence_threshold, max(self.min_confidence_threshold, confidence))
            return confidence
        except Exception as e:
            logger.error(f"Error calculating extrema confidence: {e}")
            return 0.5

    def _get_extrema_market_context(self, symbol: str, timestamp: datetime) -> Dict[str, Any]:
        """Get market context at the time of extrema detection"""
        try:
            context = {
                'volatility': 0.0,
                'volume_spike': False,
                'trend_strength': 0.0,
                'rsi_level': 50.0,
                'price_momentum': 0.0
            }
            if len(self.context_data[symbol].candles) >= 20:
                recent_candles = list(self.context_data[symbol].candles)[-20:]
                # Volatility: mean absolute 1-candle return
                prices = [c['close'] for c in recent_candles]
                price_changes = [abs(prices[i] - prices[i - 1]) / prices[i - 1] for i in range(1, len(prices))]
                context['volatility'] = np.mean(price_changes) if price_changes else 0.0
                # Volume spike: latest volume at least 1.5x the recent average
                volumes = [c['volume'] for c in recent_candles]
                avg_volume = np.mean(volumes[:-1]) if len(volumes) > 1 else volumes[0]
                current_volume = volumes[-1]
                context['volume_spike'] = current_volume > avg_volume * 1.5
                # Simple trend strength over the last 10 candles
                if len(prices) >= 10:
                    trend_slope = (prices[-1] - prices[-10]) / prices[-10]
                    context['trend_strength'] = abs(trend_slope)
                # Price momentum over the last 5 candles
                if len(prices) >= 5:
                    momentum = (prices[-1] - prices[-5]) / prices[-5]
                    context['price_momentum'] = momentum
            return context
        except Exception as e:
            logger.error(f"Error getting extrema market context: {e}")
            return {'volatility': 0.0, 'volume_spike': False, 'trend_strength': 0.0, 'rsi_level': 50.0, 'price_momentum': 0.0}

    def _create_context_features(self, df: pd.DataFrame) -> Optional[np.ndarray]:
        """Create a feature matrix from 1m context data for model consumption"""
        try:
            if df is None or len(df) < 50:
                return None

            # Select key features for context
            feature_columns = ['open', 'high', 'low', 'close', 'volume']
            # Add technical indicators if available
            if 'rsi_14' in df.columns:
                feature_columns.extend(['rsi_14', 'sma_20', 'ema_12'])
            if 'macd' in df.columns:
                feature_columns.extend(['macd', 'macd_signal'])
            if 'bb_upper' in df.columns:
                feature_columns.extend(['bb_upper', 'bb_lower', 'bb_percent'])

            # Extract whichever of those columns are present
            available_features = [col for col in feature_columns if col in df.columns]
            feature_data = df[available_features].copy()

            # Normalize features
            normalized_features = self._normalize_context_features(feature_data)
            return normalized_features.values if normalized_features is not None else None
        except Exception as e:
            logger.error(f"Error creating context features: {e}")
            return None

    def _normalize_context_features(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        """Normalize context features for model consumption"""
        try:
            df_norm = df.copy()

            # Price normalization (relative to the latest close)
            if 'close' in df_norm.columns:
                latest_close = df_norm['close'].iloc[-1]
                for col in ['open', 'high', 'low', 'close', 'sma_20', 'ema_12', 'bb_upper', 'bb_lower']:
                    if col in df_norm.columns and latest_close > 0:
                        df_norm[col] = df_norm[col] / latest_close

            # Volume normalization (relative to the mean volume)
            if 'volume' in df_norm.columns:
                volume_mean = df_norm['volume'].mean()
                if volume_mean > 0:
                    df_norm['volume'] = df_norm['volume'] / volume_mean

            # RSI normalization (0-100 to 0-1)
            if 'rsi_14' in df_norm.columns:
                df_norm['rsi_14'] = df_norm['rsi_14'] / 100.0

            # MACD normalization (relative to the latest unnormalized close)
            if 'macd' in df_norm.columns and 'close' in df.columns:
                latest_close = df['close'].iloc[-1]
                if latest_close > 0:
                    df_norm['macd'] = df_norm['macd'] / latest_close
                    if 'macd_signal' in df_norm.columns:
                        df_norm['macd_signal'] = df_norm['macd_signal'] / latest_close

            # BB percent is already normalized; just clip it
            if 'bb_percent' in df_norm.columns:
                df_norm['bb_percent'] = np.clip(df_norm['bb_percent'], 0, 1)

            # Fill NaN values
            df_norm = df_norm.fillna(0)
            return df_norm
        except Exception as e:
            logger.error(f"Error normalizing context features: {e}")
            return df

    def get_context_features_for_model(self, symbol: str) -> Optional[np.ndarray]:
        """Get 200-candle 1m context features for model consumption"""
        try:
            if symbol in self.context_data and self.context_data[symbol].features is not None:
                return self.context_data[symbol].features

            # If no cached features, create them from the current data
            if len(self.context_data[symbol].candles) >= 50:
                context_df = pd.DataFrame(list(self.context_data[symbol].candles))
                features = self._create_context_features(context_df)
                self.context_data[symbol].features = features
                return features
            return None
        except Exception as e:
            logger.error(f"Error getting context features for {symbol}: {e}")
            return None

    def get_extrema_training_data(self, count: int = 50, min_confidence: Optional[float] = None) -> List[ExtremaPoint]:
        """Get recent extrema training data for model training"""
        try:
            extrema_list = list(self.extrema_training_queue)
            # Filter by confidence if specified
            if min_confidence is not None:
                extrema_list = [e for e in extrema_list if e.confidence >= min_confidence]
            return extrema_list[-count:] if extrema_list else []
        except Exception as e:
            logger.error(f"Error getting extrema training data: {e}")
            return []

    def get_perfect_moves_for_cnn(self, count: int = 100) -> List[Dict[str, Any]]:
        """Get perfect moves formatted for CNN training"""
        try:
            extrema_data = self.get_extrema_training_data(count)
            perfect_moves = []
            for extrema in extrema_data:
                if extrema.outcome is not None:
                    perfect_move = {
                        'symbol': extrema.symbol,
                        'timeframe': '1m',
                        'timestamp': extrema.timestamp,
                        'optimal_action': extrema.optimal_action,
                        'actual_outcome': abs(extrema.outcome),
                        'confidence_should_have_been': extrema.confidence,
                        'market_context': extrema.market_context,
                        'extrema_type': extrema.extrema_type
                    }
                    perfect_moves.append(perfect_move)
            return perfect_moves
        except Exception as e:
            logger.error(f"Error getting perfect moves for CNN: {e}")
            return []

    def get_extrema_stats(self) -> Dict[str, Any]:
        """Get statistics about extrema detection and training"""
        try:
            stats = {
                'total_extrema_detected': sum(len(extrema) for extrema in self.detected_extrema.values()),
                'extrema_by_symbol': {symbol: len(extrema) for symbol, extrema in self.detected_extrema.items()},
                'training_queue_size': len(self.extrema_training_queue),
                'last_extrema_check': {symbol: check_time.isoformat()
                                       for symbol, check_time in self.last_extrema_check.items()},
                'context_data_status': {
                    symbol: {
                        'candles_loaded': len(self.context_data[symbol].candles),
                        'features_available': self.context_data[symbol].features is not None,
                        'last_update': self.context_data[symbol].last_update.isoformat()
                    }
                    for symbol in self.symbols
                },
                'window_size': self.window_size,
                'confidence_thresholds': {
                    'min': self.min_confidence_threshold,
                    'max': self.max_confidence_threshold
                }
            }

            # Breakdown of the most recent extrema
            recent_extrema = list(self.extrema_training_queue)[-20:]
            if recent_extrema:
                bottoms = len([e for e in recent_extrema if e.extrema_type == 'bottom'])
                tops = len([e for e in recent_extrema if e.extrema_type == 'top'])
                avg_confidence = np.mean([e.confidence for e in recent_extrema])
                # Average only over extrema that have a recorded outcome
                outcomes = [e.outcome for e in recent_extrema if e.outcome is not None]
                stats['recent_extrema'] = {
                    'bottoms': bottoms,
                    'tops': tops,
                    'avg_confidence': avg_confidence,
                    'avg_outcome': float(np.mean(outcomes)) if outcomes else 0.0
                }
            return stats
        except Exception as e:
            logger.error(f"Error getting extrema stats: {e}")
            return {}

    def run_batch_detection(self) -> Dict[str, List[ExtremaPoint]]:
        """Run extrema detection for all symbols"""
        results = {}
        try:
            for symbol in self.symbols:
                detected = self.detect_local_extrema(symbol)
                results[symbol] = detected
            total_detected = sum(len(extrema_list) for extrema_list in results.values())
            logger.info(f"Batch extrema detection completed: {total_detected} extrema detected across {len(self.symbols)} symbols")
        except Exception as e:
            logger.error(f"Error in batch extrema detection: {e}")
        return results
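

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustrative only): wires the trainer to a
    # synthetic data provider so the detection pipeline can be exercised
    # without a live feed. The stub below is an assumption, not part of the
    # real system; a deployment would pass the project's actual data provider.
    logging.basicConfig(level=logging.INFO)

    class _StubDataProvider:
        """Synthetic OHLCV source matching the get_historical_data() calls above."""

        def get_historical_data(self, symbol, timeframe, limit=200, refresh=False):
            rng = np.random.default_rng(42)
            closes = 2000.0 + np.cumsum(rng.normal(0.0, 5.0, limit))  # random-walk closes
            spread = np.abs(rng.normal(0.0, 2.0, limit))
            return pd.DataFrame({
                'timestamp': pd.date_range(end=datetime.now(), periods=limit, freq='1min'),
                'open': closes + rng.normal(0.0, 1.0, limit),
                'high': closes + spread,
                'low': closes - spread,
                'close': closes,
                'volume': np.abs(rng.normal(100.0, 20.0, limit)),
            })

    trainer = ExtremaTrainer(_StubDataProvider(), ['ETH/USDT'], window_size=10)
    trainer.initialize_context_data()
    trainer.run_batch_detection()
    print(trainer.get_extrema_stats())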