Detecting local extrema and training on them

Dobromir Popov
2025-05-27 02:36:20 +03:00
parent 2ba0406b9f
commit cc20b6194a
14 changed files with 3415 additions and 91 deletions


@@ -22,10 +22,13 @@ from collections import deque
import torch
from .config import get_config
from .data_provider import DataProvider, RawTick, OHLCVBar
from .data_provider import DataProvider, RawTick, OHLCVBar, MarketTick
from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
from .realtime_tick_processor import RealTimeTickProcessor, ProcessedTickFeatures, integrate_with_orchestrator
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface
from .extrema_trainer import ExtremaTrainer
from .trading_action import TradingAction
from .negative_case_trainer import NegativeCaseTrainer
logger = logging.getLogger(__name__)
@@ -87,6 +90,28 @@ class PerfectMove:
market_state_after: MarketState
confidence_should_have_been: float
@dataclass
class TradeInfo:
"""Information about an active trade"""
symbol: str
side: str # 'LONG' or 'SHORT'
entry_price: float
entry_time: datetime
size: float
confidence: float
market_state: Dict[str, Any]
@dataclass
class LearningCase:
"""A learning case for DQN sensitivity training"""
state_vector: np.ndarray
action: int # sensitivity level chosen
reward: float
next_state_vector: np.ndarray
done: bool
trade_info: TradeInfo
outcome: float # P&L percentage
class EnhancedTradingOrchestrator:
"""
Enhanced orchestrator with sophisticated multi-modal decision making
@@ -105,6 +130,16 @@ class EnhancedTradingOrchestrator:
# Initialize real-time tick processor for ultra-low latency processing
self.tick_processor = RealTimeTickProcessor(symbols=self.config.symbols)
# Initialize extrema trainer for local bottom/top detection and 200-candle context
self.extrema_trainer = ExtremaTrainer(
data_provider=self.data_provider,
symbols=self.config.symbols,
window_size=10 # 10-candle window for extrema detection
)
# Initialize negative case trainer for intensive training on losing trades
self.negative_case_trainer = NegativeCaseTrainer()
# Real-time tick features storage
self.realtime_tick_features = {symbol: deque(maxlen=100) for symbol in self.config.symbols}
@@ -151,6 +186,18 @@ class EnhancedTradingOrchestrator:
self.retrospective_learning_active = False
self.last_retrospective_analysis = datetime.now()
# Local extrema tracking for training on bottoms and tops
self.local_extrema = {symbol: deque(maxlen=1000) for symbol in self.symbols}
self.extrema_detection_window = 10 # Look for extrema in 10-candle windows
self.extrema_training_queue = deque(maxlen=500) # Queue for extrema-based training
self.last_extrema_check = {symbol: datetime.now() for symbol in self.symbols}
# 200-candle context data for models
self.context_data_1m = {symbol: deque(maxlen=200) for symbol in self.symbols}
self.context_features_1m = {symbol: None for symbol in self.symbols}
self.context_update_frequency = 60 # Update context every 60 seconds
self.last_context_update = {symbol: datetime.now() for symbol in self.symbols}
# RL feedback system
self.rl_evaluation_queue = deque(maxlen=1000)
self.environment_adaptation_rate = 0.01
@@ -182,6 +229,9 @@ class EnhancedTradingOrchestrator:
# Current open positions tracking for closing logic
self.open_positions = {} # symbol -> {'side': str, 'entry_price': float, 'timestamp': datetime}
# Initialize 200-candle context data
self._initialize_context_data()
logger.info("Enhanced TradingOrchestrator initialized with Universal Data Format")
logger.info(f"Symbols: {self.symbols}")
logger.info(f"Timeframes: {self.timeframes}")
@@ -192,6 +242,8 @@ class EnhancedTradingOrchestrator:
logger.info("Raw tick and OHLCV bar processing enabled for pattern detection")
logger.info("Enhanced retrospective learning enabled for perfect opportunity detection")
logger.info("DQN RL-based sensitivity learning enabled for adaptive thresholds")
logger.info("Local extrema detection enabled for bottom/top training")
logger.info("200-candle 1m context data initialized for enhanced model performance")
def _initialize_timeframe_weights(self) -> Dict[str, float]:
"""Initialize weights for different timeframes"""
@@ -713,7 +765,7 @@ class EnhancedTradingOrchestrator:
try:
if symbol not in self.active_trades:
return
trade_info = self.active_trades[symbol]
# Calculate trade outcome
@@ -759,7 +811,7 @@ class EnhancedTradingOrchestrator:
del self.active_trades[symbol]
logger.info(f"Closed trade for sensitivity learning: {symbol} {side} P&L: {pnl_pct*100:+.2f}% Duration: {duration:.0f}s")
except Exception as e:
logger.error(f"Error tracking trade closing for sensitivity learning: {e}")
@@ -818,7 +870,7 @@ class EnhancedTradingOrchestrator:
'price_change_4': price_changes[-4] if len(price_changes) > 3 else 0.0,
'price_change_5': price_changes[-5] if len(price_changes) > 4 else 0.0
}
except Exception as e:
logger.error(f"Error getting market state for sensitivity learning: {e}")
return self._get_default_market_state()
@@ -969,7 +1021,7 @@ class EnhancedTradingOrchestrator:
final_reward = np.clip(final_reward, -2.0, 2.0)
return float(final_reward)
except Exception as e:
logger.error(f"Error calculating sensitivity reward: {e}")
return 0.0
@@ -1045,7 +1097,7 @@ class EnhancedTradingOrchestrator:
# Update current sensitivity level based on recent performance
self._update_current_sensitivity_level()
except Exception as e:
logger.error(f"Error training sensitivity DQN: {e}")
@@ -1131,6 +1183,374 @@ class EnhancedTradingOrchestrator:
"""Get current closing threshold"""
return self.confidence_threshold_close
def _initialize_context_data(self):
"""Initialize 200-candle 1m context data for all symbols"""
try:
logger.info("Initializing 200-candle 1m context data for enhanced model performance")
for symbol in self.symbols:
try:
# Load 200 candles of 1m data
context_data = self.data_provider.get_historical_data(symbol, '1m', limit=200)
if context_data is not None and len(context_data) > 0:
# Store raw data
for _, row in context_data.iterrows():
candle_data = {
'timestamp': row['timestamp'],
'open': row['open'],
'high': row['high'],
'low': row['low'],
'close': row['close'],
'volume': row['volume']
}
self.context_data_1m[symbol].append(candle_data)
# Create feature matrix for models
self.context_features_1m[symbol] = self._create_context_features(context_data)
logger.info(f"Loaded {len(context_data)} 1m candles for {symbol} context")
else:
logger.warning(f"No 1m context data available for {symbol}")
except Exception as e:
logger.error(f"Error loading context data for {symbol}: {e}")
except Exception as e:
logger.error(f"Error initializing context data: {e}")
def _create_context_features(self, df: pd.DataFrame) -> Optional[np.ndarray]:
"""Create feature matrix from 1m context data for model consumption"""
try:
if df is None or len(df) < 50:
return None
# Select key features for context
feature_columns = ['open', 'high', 'low', 'close', 'volume']
# Add technical indicators if available
if 'rsi_14' in df.columns:
feature_columns.extend(['rsi_14', 'sma_20', 'ema_12'])
if 'macd' in df.columns:
feature_columns.extend(['macd', 'macd_signal'])
if 'bb_upper' in df.columns:
feature_columns.extend(['bb_upper', 'bb_lower', 'bb_percent'])
# Extract available features
available_features = [col for col in feature_columns if col in df.columns]
feature_data = df[available_features].copy()
# Normalize features
normalized_features = self._normalize_context_features(feature_data)
return normalized_features.values if normalized_features is not None else None
except Exception as e:
logger.error(f"Error creating context features: {e}")
return None
def _normalize_context_features(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
"""Normalize context features for model consumption"""
try:
df_norm = df.copy()
# Price normalization (relative to latest close)
if 'close' in df_norm.columns:
latest_close = df_norm['close'].iloc[-1]
for col in ['open', 'high', 'low', 'close', 'sma_20', 'ema_12', 'bb_upper', 'bb_lower']:
if col in df_norm.columns and latest_close > 0:
df_norm[col] = df_norm[col] / latest_close
# Volume normalization
if 'volume' in df_norm.columns:
volume_mean = df_norm['volume'].mean()
if volume_mean > 0:
df_norm['volume'] = df_norm['volume'] / volume_mean
# RSI normalization (0-100 to 0-1)
if 'rsi_14' in df_norm.columns:
df_norm['rsi_14'] = df_norm['rsi_14'] / 100.0
# MACD normalization
if 'macd' in df_norm.columns and 'close' in df.columns:
latest_close = df['close'].iloc[-1]
df_norm['macd'] = df_norm['macd'] / latest_close
if 'macd_signal' in df_norm.columns:
df_norm['macd_signal'] = df_norm['macd_signal'] / latest_close
# BB percent is already normalized
if 'bb_percent' in df_norm.columns:
df_norm['bb_percent'] = np.clip(df_norm['bb_percent'], 0, 1)
# Fill NaN values
df_norm = df_norm.fillna(0)
return df_norm
except Exception as e:
logger.error(f"Error normalizing context features: {e}")
return df
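To make the normalization rules above concrete, here is a small illustrative sketch (not part of this commit): prices are divided by the latest close, volume by its mean, and RSI by 100, exactly as `_normalize_context_features` does. The toy values are hypothetical.

# Illustrative only -- mirrors the normalization scheme above on a toy frame
import pandas as pd

toy = pd.DataFrame({
    'open':   [100.0, 101.0, 102.0],
    'close':  [101.0, 102.0, 104.0],
    'volume': [10.0, 20.0, 30.0],
    'rsi_14': [45.0, 55.0, 65.0],
})

latest_close = toy['close'].iloc[-1]                            # 104.0
toy[['open', 'close']] = toy[['open', 'close']] / latest_close  # relative to latest close
toy['volume'] = toy['volume'] / toy['volume'].mean()            # relative to mean volume
toy['rsi_14'] = toy['rsi_14'] / 100.0                           # 0-100 scale -> 0-1
print(toy.round(3))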
def update_context_data(self, symbol: str = None):
"""Update 200-candle 1m context data for specified symbol or all symbols"""
try:
symbols_to_update = [symbol] if symbol else self.symbols
for sym in symbols_to_update:
# Check if update is needed
time_since_update = (datetime.now() - self.last_context_update[sym]).total_seconds()
if time_since_update >= self.context_update_frequency:
# Get latest 1m data
latest_data = self.data_provider.get_historical_data(sym, '1m', limit=10, refresh=True)
if latest_data is not None and len(latest_data) > 0:
# Add new candles to context
for _, row in latest_data.iterrows():
candle_data = {
'timestamp': row['timestamp'],
'open': row['open'],
'high': row['high'],
'low': row['low'],
'close': row['close'],
'volume': row['volume']
}
# Check if this candle is newer than our latest
if (not self.context_data_1m[sym] or
candle_data['timestamp'] > self.context_data_1m[sym][-1]['timestamp']):
self.context_data_1m[sym].append(candle_data)
# Update feature matrix
if len(self.context_data_1m[sym]) >= 50:
context_df = pd.DataFrame(list(self.context_data_1m[sym]))
self.context_features_1m[sym] = self._create_context_features(context_df)
self.last_context_update[sym] = datetime.now()
# Check for local extrema in updated data
self._detect_local_extrema(sym)
except Exception as e:
logger.error(f"Error updating context data: {e}")
def _detect_local_extrema(self, symbol: str):
"""Detect local bottoms and tops for training opportunities"""
try:
if len(self.context_data_1m[symbol]) < self.extrema_detection_window * 2:
return
# Get recent price data
recent_candles = list(self.context_data_1m[symbol])[-self.extrema_detection_window * 2:]
prices = [candle['close'] for candle in recent_candles]
timestamps = [candle['timestamp'] for candle in recent_candles]
# Detect local minima (bottoms) and maxima (tops)
window = self.extrema_detection_window
for i in range(window, len(prices) - window):
current_price = prices[i]
current_time = timestamps[i]
# Check for local bottom
is_bottom = all(current_price <= prices[j] for j in range(i - window, i + window + 1) if j != i)
# Check for local top
is_top = all(current_price >= prices[j] for j in range(i - window, i + window + 1) if j != i)
if is_bottom or is_top:
extrema_type = 'bottom' if is_bottom else 'top'
# Create training opportunity
extrema_data = {
'symbol': symbol,
'timestamp': current_time,
'price': current_price,
'type': extrema_type,
'context_before': prices[max(0, i - window):i],
'context_after': prices[i + 1:min(len(prices), i + window + 1)],
'optimal_action': 'BUY' if is_bottom else 'SELL',
'confidence_level': self._calculate_extrema_confidence(prices, i, window),
'market_context': self._get_extrema_market_context(symbol, current_time)
}
self.local_extrema[symbol].append(extrema_data)
self.extrema_training_queue.append(extrema_data)
logger.info(f"Local {extrema_type} detected for {symbol} at ${current_price:.2f} "
f"(confidence: {extrema_data['confidence_level']:.3f})")
# Create perfect move for CNN training
self._create_extrema_perfect_move(extrema_data)
self.last_extrema_check[symbol] = datetime.now()
except Exception as e:
logger.error(f"Error detecting local extrema for {symbol}: {e}")
def _calculate_extrema_confidence(self, prices: List[float], extrema_index: int, window: int) -> float:
"""Calculate confidence level for detected extrema"""
try:
extrema_price = prices[extrema_index]
# Calculate price deviation from extrema
surrounding_prices = prices[max(0, extrema_index - window):extrema_index + window + 1]
price_range = max(surrounding_prices) - min(surrounding_prices)
if price_range == 0:
return 0.5
# Calculate how extreme the point is
if extrema_price == min(surrounding_prices): # Bottom
deviation = (max(surrounding_prices) - extrema_price) / price_range
else: # Top
deviation = (extrema_price - min(surrounding_prices)) / price_range
# Confidence based on how clear the extrema is
confidence = min(0.95, max(0.3, deviation))
return confidence
except Exception as e:
logger.error(f"Error calculating extrema confidence: {e}")
return 0.5
def _get_extrema_market_context(self, symbol: str, timestamp: datetime) -> Dict[str, Any]:
"""Get market context at the time of extrema detection"""
try:
# Get recent market data around the extrema
context = {
'volatility': 0.0,
'volume_spike': False,
'trend_strength': 0.0,
'rsi_level': 50.0
}
if len(self.context_data_1m[symbol]) >= 20:
recent_candles = list(self.context_data_1m[symbol])[-20:]
# Calculate volatility
prices = [c['close'] for c in recent_candles]
price_changes = [abs(prices[i] - prices[i-1]) / prices[i-1] for i in range(1, len(prices))]
context['volatility'] = np.mean(price_changes) if price_changes else 0.0
# Check for volume spike
volumes = [c['volume'] for c in recent_candles]
avg_volume = np.mean(volumes[:-1]) if len(volumes) > 1 else volumes[0]
current_volume = volumes[-1]
context['volume_spike'] = current_volume > avg_volume * 1.5
# Simple trend strength
if len(prices) >= 10:
trend_slope = (prices[-1] - prices[-10]) / prices[-10]
context['trend_strength'] = abs(trend_slope)
return context
except Exception as e:
logger.error(f"Error getting extrema market context: {e}")
return {'volatility': 0.0, 'volume_spike': False, 'trend_strength': 0.0, 'rsi_level': 50.0}
def _create_extrema_perfect_move(self, extrema_data: Dict[str, Any]):
"""Create a perfect move from detected extrema for CNN training"""
try:
# Calculate outcome based on price movement after extrema
if len(extrema_data['context_after']) > 0:
price_after = extrema_data['context_after'][-1]
price_change = (price_after - extrema_data['price']) / extrema_data['price']
# For bottoms, positive price change is good; for tops, negative is good
if extrema_data['type'] == 'bottom':
outcome = price_change
else: # top
outcome = -price_change
perfect_move = PerfectMove(
symbol=extrema_data['symbol'],
timeframe='1m',
timestamp=extrema_data['timestamp'],
optimal_action=extrema_data['optimal_action'],
actual_outcome=abs(outcome),
market_state_before=None,
market_state_after=None,
confidence_should_have_been=extrema_data['confidence_level']
)
self.perfect_moves.append(perfect_move)
self.retrospective_learning_active = True
logger.info(f"Created perfect move from {extrema_data['type']} extrema: "
f"{extrema_data['optimal_action']} {extrema_data['symbol']} "
f"(outcome: {outcome*100:+.2f}%)")
except Exception as e:
logger.error(f"Error creating extrema perfect move: {e}")
def get_context_features_for_model(self, symbol: str) -> Optional[np.ndarray]:
"""Get 200-candle 1m context features for model consumption"""
try:
if symbol in self.context_features_1m and self.context_features_1m[symbol] is not None:
return self.context_features_1m[symbol]
# If no cached features, create them from current data
if len(self.context_data_1m[symbol]) >= 50:
context_df = pd.DataFrame(list(self.context_data_1m[symbol]))
features = self._create_context_features(context_df)
self.context_features_1m[symbol] = features
return features
return None
except Exception as e:
logger.error(f"Error getting context features for {symbol}: {e}")
return None
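A hypothetical consumer of `get_context_features_for_model()` might pad or trim the returned matrix to a fixed 200-row window before feeding a model; the sketch below is an assumption about downstream usage, not code from this commit, and the symbol is a placeholder.

# Hypothetical downstream use of the context feature matrix -- not part of this commit
import numpy as np

def to_fixed_window(features: np.ndarray, length: int = 200) -> np.ndarray:
    """Pad with leading zeros or trim so the matrix always has `length` rows."""
    if features.shape[0] >= length:
        return features[-length:]
    pad = np.zeros((length - features.shape[0], features.shape[1]), dtype=features.dtype)
    return np.vstack([pad, features])

# context = orchestrator.get_context_features_for_model('ETH/USDT')  # placeholder symbol
# if context is not None:
#     model_input = to_fixed_window(context)[None, ...]  # add a batch dimension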
def get_extrema_training_data(self, count: int = 50) -> List[Dict[str, Any]]:
"""Get recent extrema training data for model training"""
try:
return list(self.extrema_training_queue)[-count:] if self.extrema_training_queue else []
except Exception as e:
logger.error(f"Error getting extrema training data: {e}")
return []
def get_extrema_stats(self) -> Dict[str, Any]:
"""Get statistics about extrema detection and training"""
try:
stats = {
'total_extrema_detected': sum(len(extrema) for extrema in self.local_extrema.values()),
'extrema_by_symbol': {symbol: len(extrema) for symbol, extrema in self.local_extrema.items()},
'training_queue_size': len(self.extrema_training_queue),
'last_extrema_check': {symbol: check_time.isoformat()
for symbol, check_time in self.last_extrema_check.items()},
'context_data_status': {
symbol: {
'candles_loaded': len(self.context_data_1m[symbol]),
'features_available': self.context_features_1m[symbol] is not None,
'last_update': self.last_context_update[symbol].isoformat()
}
for symbol in self.symbols
}
}
# Recent extrema breakdown
recent_extrema = list(self.extrema_training_queue)[-20:]
if recent_extrema:
bottoms = len([e for e in recent_extrema if e['type'] == 'bottom'])
tops = len([e for e in recent_extrema if e['type'] == 'top'])
avg_confidence = np.mean([e['confidence_level'] for e in recent_extrema])
stats['recent_extrema'] = {
'bottoms': bottoms,
'tops': tops,
'avg_confidence': avg_confidence
}
return stats
except Exception as e:
logger.error(f"Error getting extrema stats: {e}")
return {}
def process_realtime_features(self, feature_dict: Dict[str, Any]):
"""Process real-time tick features from the tick processor"""
try:

core/extrema_trainer.py (new file, 584 lines)

@@ -0,0 +1,584 @@
"""
Extrema Training Module - Reusable Local Bottom/Top Detection and Training
This module provides reusable functionality for:
1. Detecting local extrema (bottoms and tops) in price data
2. Creating training opportunities from extrema
3. Loading and managing 200-candle 1m context data
4. Generating features for model consumption
5. Training on not-so-perfect opportunities
Can be used across different dashboards and trading systems.
"""
import logging
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from collections import deque
logger = logging.getLogger(__name__)
@dataclass
class ExtremaPoint:
"""Represents a detected local extrema (bottom or top)"""
symbol: str
timestamp: datetime
price: float
extrema_type: str # 'bottom' or 'top'
confidence: float
context_before: List[float]
context_after: List[float]
optimal_action: str # 'BUY' or 'SELL'
market_context: Dict[str, Any]
outcome: Optional[float] = None # Price change after extrema
@dataclass
class ContextData:
"""200-candle 1m context data for enhanced model performance"""
symbol: str
candles: deque
features: Optional[np.ndarray]
last_update: datetime
class ExtremaTrainer:
"""Reusable extrema detection and training functionality"""
def __init__(self, data_provider, symbols: List[str], window_size: int = 10):
"""
Initialize the extrema trainer
Args:
data_provider: Data provider instance
symbols: List of symbols to track
window_size: Window size for extrema detection (default 10)
"""
self.data_provider = data_provider
self.symbols = symbols
self.window_size = window_size
# Extrema tracking
self.detected_extrema = {symbol: deque(maxlen=1000) for symbol in symbols}
self.extrema_training_queue = deque(maxlen=500)
self.last_extrema_check = {symbol: datetime.now() for symbol in symbols}
# 200-candle context data
self.context_data = {symbol: ContextData(
symbol=symbol,
candles=deque(maxlen=200),
features=None,
last_update=datetime.now()
) for symbol in symbols}
self.context_update_frequency = 60 # Update every 60 seconds
# Training parameters
self.min_confidence_threshold = 0.3 # Train on opportunities with at least 30% confidence
self.max_confidence_threshold = 0.95 # Cap confidence at 95%
logger.info(f"ExtremaTrainer initialized for symbols: {symbols}")
logger.info(f"Window size: {window_size}, Context update frequency: {self.context_update_frequency}s")
def initialize_context_data(self) -> Dict[str, bool]:
"""Initialize 200-candle 1m context data for all symbols"""
results = {}
try:
logger.info("Initializing 200-candle 1m context data for enhanced model performance")
for symbol in self.symbols:
try:
# Load 200 candles of 1m data
context_data = self.data_provider.get_historical_data(symbol, '1m', limit=200)
if context_data is not None and len(context_data) > 0:
# Store raw data
for _, row in context_data.iterrows():
candle_data = {
'timestamp': row['timestamp'],
'open': row['open'],
'high': row['high'],
'low': row['low'],
'close': row['close'],
'volume': row['volume']
}
self.context_data[symbol].candles.append(candle_data)
# Create feature matrix for models
self.context_data[symbol].features = self._create_context_features(context_data)
self.context_data[symbol].last_update = datetime.now()
results[symbol] = True
logger.info(f"✅ Loaded {len(context_data)} 1m candles for {symbol} context")
else:
results[symbol] = False
logger.warning(f"❌ No 1m context data available for {symbol}")
except Exception as e:
logger.error(f"❌ Error loading context data for {symbol}: {e}")
results[symbol] = False
except Exception as e:
logger.error(f"Error initializing context data: {e}")
successful = sum(1 for success in results.values() if success)
logger.info(f"Context data initialization: {successful}/{len(self.symbols)} symbols loaded")
return results
def update_context_data(self, symbol: str = None) -> Dict[str, bool]:
"""Update 200-candle 1m context data for specified symbol or all symbols"""
results = {}
try:
symbols_to_update = [symbol] if symbol else self.symbols
for sym in symbols_to_update:
try:
# Check if update is needed
time_since_update = (datetime.now() - self.context_data[sym].last_update).total_seconds()
if time_since_update >= self.context_update_frequency:
# Get latest 1m data
latest_data = self.data_provider.get_historical_data(sym, '1m', limit=10, refresh=True)
if latest_data is not None and len(latest_data) > 0:
# Add new candles to context
for _, row in latest_data.iterrows():
candle_data = {
'timestamp': row['timestamp'],
'open': row['open'],
'high': row['high'],
'low': row['low'],
'close': row['close'],
'volume': row['volume']
}
# Check if this candle is newer than our latest
if (not self.context_data[sym].candles or
candle_data['timestamp'] > self.context_data[sym].candles[-1]['timestamp']):
self.context_data[sym].candles.append(candle_data)
# Update feature matrix
if len(self.context_data[sym].candles) >= 50:
context_df = pd.DataFrame(list(self.context_data[sym].candles))
self.context_data[sym].features = self._create_context_features(context_df)
self.context_data[sym].last_update = datetime.now()
# Check for local extrema in updated data
self.detect_local_extrema(sym)
results[sym] = True
else:
results[sym] = False
else:
results[sym] = True # No update needed
except Exception as e:
logger.error(f"Error updating context data for {sym}: {e}")
results[sym] = False
except Exception as e:
logger.error(f"Error updating context data: {e}")
return results
def detect_local_extrema(self, symbol: str) -> List[ExtremaPoint]:
"""Detect local bottoms and tops for training opportunities"""
detected = []
try:
if len(self.context_data[symbol].candles) < self.window_size * 3:
return detected
# Get all available price data for better extrema detection
all_candles = list(self.context_data[symbol].candles)
prices = [candle['close'] for candle in all_candles]
timestamps = [candle['timestamp'] for candle in all_candles]
# Use a more sophisticated extrema detection algorithm
window = self.window_size
# Look for extrema in the middle portion of the data (not at edges)
start_idx = window
end_idx = len(prices) - window
for i in range(start_idx, end_idx):
current_price = prices[i]
current_time = timestamps[i]
# Get surrounding prices for comparison
left_prices = prices[i - window:i]
right_prices = prices[i + 1:i + window + 1]
# Check for local bottom (current price is lower than surrounding prices)
is_bottom = (current_price <= min(left_prices) and
current_price <= min(right_prices) and
current_price < max(left_prices) * 0.998) # At least 0.2% lower
# Check for local top (current price is higher than surrounding prices)
is_top = (current_price >= max(left_prices) and
current_price >= max(right_prices) and
current_price > min(left_prices) * 1.002) # At least 0.2% higher
if is_bottom or is_top:
extrema_type = 'bottom' if is_bottom else 'top'
# Calculate confidence based on price deviation and volume
confidence = self._calculate_extrema_confidence(prices, i, window)
# Only process if confidence meets minimum threshold
if confidence >= self.min_confidence_threshold:
# Check if this extrema is too close to a previously detected one
if not self._is_too_close_to_existing_extrema(symbol, current_time, current_price):
# Create extrema point
extrema_point = ExtremaPoint(
symbol=symbol,
timestamp=current_time,
price=current_price,
extrema_type=extrema_type,
confidence=min(confidence, self.max_confidence_threshold),
context_before=left_prices,
context_after=right_prices,
optimal_action='BUY' if is_bottom else 'SELL',
market_context=self._get_extrema_market_context(symbol, current_time)
)
# Calculate outcome if we have future data
if len(right_prices) > 0:
# Look ahead further for better outcome calculation
future_idx = min(i + window * 2, len(prices) - 1)
future_price = prices[future_idx]
price_change = (future_price - current_price) / current_price
# For bottoms, positive change is good; for tops, negative is good
if extrema_type == 'bottom':
extrema_point.outcome = price_change
else: # top
extrema_point.outcome = -price_change
self.detected_extrema[symbol].append(extrema_point)
self.extrema_training_queue.append(extrema_point)
detected.append(extrema_point)
logger.info(f"Local {extrema_type} detected for {symbol} at ${current_price:.2f} "
f"(confidence: {confidence:.3f}, outcome: {extrema_point.outcome:.4f})")
self.last_extrema_check[symbol] = datetime.now()
except Exception as e:
logger.error(f"Error detecting local extrema for {symbol}: {e}")
return detected
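A minimal sketch of the windowed extrema test used above, run on a synthetic price list (values hypothetical): a point is a bottom when it is at or below every neighbor in the window on both sides and at least 0.2% below the left-window maximum, and symmetrically for a top.

# Standalone illustration of the windowed bottom/top test -- not part of this commit
prices = [100.0, 99.5, 99.0, 98.2, 98.6, 99.4, 100.1, 100.8, 100.3, 99.9]
window = 3

for i in range(window, len(prices) - window):
    left, right = prices[i - window:i], prices[i + 1:i + window + 1]
    p = prices[i]
    is_bottom = p <= min(left) and p <= min(right) and p < max(left) * 0.998  # >= 0.2% below
    is_top = p >= max(left) and p >= max(right) and p > min(left) * 1.002     # >= 0.2% above
    if is_bottom or is_top:
        print(i, p, 'bottom' if is_bottom else 'top')   # prints: 3 98.2 bottom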
def _is_too_close_to_existing_extrema(self, symbol: str, timestamp: datetime, price: float) -> bool:
"""Check if this extrema is too close to an existing one"""
try:
if symbol not in self.detected_extrema:
return False
recent_extrema = list(self.detected_extrema[symbol])[-10:] # Check last 10 extrema
for existing_extrema in recent_extrema:
# Check time proximity (within 30 minutes)
time_diff = abs((timestamp - existing_extrema.timestamp).total_seconds())
if time_diff < 1800: # 30 minutes
# Check price proximity (within 1%)
price_diff = abs(price - existing_extrema.price) / existing_extrema.price
if price_diff < 0.01: # 1%
return True
return False
except Exception as e:
logger.error(f"Error checking extrema proximity: {e}")
return False
def _calculate_extrema_confidence(self, prices: List[float], extrema_index: int, window: int) -> float:
"""Calculate confidence level for detected extrema"""
try:
extrema_price = prices[extrema_index]
# Calculate price deviation from extrema
surrounding_prices = prices[max(0, extrema_index - window):extrema_index + window + 1]
price_range = max(surrounding_prices) - min(surrounding_prices)
if price_range == 0:
return 0.5
# Calculate how extreme the point is
if extrema_price == min(surrounding_prices): # Bottom
deviation = (max(surrounding_prices) - extrema_price) / price_range
else: # Top
deviation = (extrema_price - min(surrounding_prices)) / price_range
# Additional factors for confidence
# 1. Volume confirmation
volume_factor = 1.0
if len(self.context_data) > 0:
# Check if volume was higher during extrema
try:
recent_candles = list(self.context_data[list(self.context_data.keys())[0]].candles)
if len(recent_candles) > extrema_index:
extrema_volume = recent_candles[extrema_index].get('volume', 0)
avg_volume = np.mean([c.get('volume', 0) for c in recent_candles[-20:]])
if avg_volume > 0:
volume_factor = min(1.2, extrema_volume / avg_volume)
except Exception:
pass
# 2. Price momentum before extrema
momentum_factor = 1.0
if extrema_index >= 3:
price_momentum = abs(prices[extrema_index] - prices[extrema_index - 3]) / prices[extrema_index - 3]
momentum_factor = min(1.1, 1.0 + price_momentum * 10)
# Combine factors
confidence = deviation * volume_factor * momentum_factor
# Ensure confidence is within bounds
confidence = min(self.max_confidence_threshold, max(self.min_confidence_threshold, confidence))
return confidence
except Exception as e:
logger.error(f"Error calculating extrema confidence: {e}")
return 0.5
def _get_extrema_market_context(self, symbol: str, timestamp: datetime) -> Dict[str, Any]:
"""Get market context at the time of extrema detection"""
try:
context = {
'volatility': 0.0,
'volume_spike': False,
'trend_strength': 0.0,
'rsi_level': 50.0,
'price_momentum': 0.0
}
if len(self.context_data[symbol].candles) >= 20:
recent_candles = list(self.context_data[symbol].candles)[-20:]
# Calculate volatility
prices = [c['close'] for c in recent_candles]
price_changes = [abs(prices[i] - prices[i-1]) / prices[i-1] for i in range(1, len(prices))]
context['volatility'] = np.mean(price_changes) if price_changes else 0.0
# Check for volume spike
volumes = [c['volume'] for c in recent_candles]
avg_volume = np.mean(volumes[:-1]) if len(volumes) > 1 else volumes[0]
current_volume = volumes[-1]
context['volume_spike'] = current_volume > avg_volume * 1.5
# Simple trend strength
if len(prices) >= 10:
trend_slope = (prices[-1] - prices[-10]) / prices[-10]
context['trend_strength'] = abs(trend_slope)
# Price momentum
if len(prices) >= 5:
momentum = (prices[-1] - prices[-5]) / prices[-5]
context['price_momentum'] = momentum
return context
except Exception as e:
logger.error(f"Error getting extrema market context: {e}")
return {'volatility': 0.0, 'volume_spike': False, 'trend_strength': 0.0, 'rsi_level': 50.0, 'price_momentum': 0.0}
def _create_context_features(self, df: pd.DataFrame) -> Optional[np.ndarray]:
"""Create feature matrix from 1m context data for model consumption"""
try:
if df is None or len(df) < 50:
return None
# Select key features for context
feature_columns = ['open', 'high', 'low', 'close', 'volume']
# Add technical indicators if available
if 'rsi_14' in df.columns:
feature_columns.extend(['rsi_14', 'sma_20', 'ema_12'])
if 'macd' in df.columns:
feature_columns.extend(['macd', 'macd_signal'])
if 'bb_upper' in df.columns:
feature_columns.extend(['bb_upper', 'bb_lower', 'bb_percent'])
# Extract available features
available_features = [col for col in feature_columns if col in df.columns]
feature_data = df[available_features].copy()
# Normalize features
normalized_features = self._normalize_context_features(feature_data)
return normalized_features.values if normalized_features is not None else None
except Exception as e:
logger.error(f"Error creating context features: {e}")
return None
def _normalize_context_features(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
"""Normalize context features for model consumption"""
try:
df_norm = df.copy()
# Price normalization (relative to latest close)
if 'close' in df_norm.columns:
latest_close = df_norm['close'].iloc[-1]
for col in ['open', 'high', 'low', 'close', 'sma_20', 'ema_12', 'bb_upper', 'bb_lower']:
if col in df_norm.columns and latest_close > 0:
df_norm[col] = df_norm[col] / latest_close
# Volume normalization
if 'volume' in df_norm.columns:
volume_mean = df_norm['volume'].mean()
if volume_mean > 0:
df_norm['volume'] = df_norm['volume'] / volume_mean
# RSI normalization (0-100 to 0-1)
if 'rsi_14' in df_norm.columns:
df_norm['rsi_14'] = df_norm['rsi_14'] / 100.0
# MACD normalization
if 'macd' in df_norm.columns and 'close' in df.columns:
latest_close = df['close'].iloc[-1]
df_norm['macd'] = df_norm['macd'] / latest_close
if 'macd_signal' in df_norm.columns:
df_norm['macd_signal'] = df_norm['macd_signal'] / latest_close
# BB percent is already normalized
if 'bb_percent' in df_norm.columns:
df_norm['bb_percent'] = np.clip(df_norm['bb_percent'], 0, 1)
# Fill NaN values
df_norm = df_norm.fillna(0)
return df_norm
except Exception as e:
logger.error(f"Error normalizing context features: {e}")
return df
def get_context_features_for_model(self, symbol: str) -> Optional[np.ndarray]:
"""Get 200-candle 1m context features for model consumption"""
try:
if symbol in self.context_data and self.context_data[symbol].features is not None:
return self.context_data[symbol].features
# If no cached features, create them from current data
if len(self.context_data[symbol].candles) >= 50:
context_df = pd.DataFrame(list(self.context_data[symbol].candles))
features = self._create_context_features(context_df)
self.context_data[symbol].features = features
return features
return None
except Exception as e:
logger.error(f"Error getting context features for {symbol}: {e}")
return None
def get_extrema_training_data(self, count: int = 50, min_confidence: float = None) -> List[ExtremaPoint]:
"""Get recent extrema training data for model training"""
try:
extrema_list = list(self.extrema_training_queue)
# Filter by confidence if specified
if min_confidence is not None:
extrema_list = [e for e in extrema_list if e.confidence >= min_confidence]
return extrema_list[-count:] if extrema_list else []
except Exception as e:
logger.error(f"Error getting extrema training data: {e}")
return []
def get_perfect_moves_for_cnn(self, count: int = 100) -> List[Dict[str, Any]]:
"""Get perfect moves formatted for CNN training"""
try:
extrema_data = self.get_extrema_training_data(count)
perfect_moves = []
for extrema in extrema_data:
if extrema.outcome is not None:
perfect_move = {
'symbol': extrema.symbol,
'timeframe': '1m',
'timestamp': extrema.timestamp,
'optimal_action': extrema.optimal_action,
'actual_outcome': abs(extrema.outcome),
'confidence_should_have_been': extrema.confidence,
'market_context': extrema.market_context,
'extrema_type': extrema.extrema_type
}
perfect_moves.append(perfect_move)
return perfect_moves
except Exception as e:
logger.error(f"Error getting perfect moves for CNN: {e}")
return []
def get_extrema_stats(self) -> Dict[str, Any]:
"""Get statistics about extrema detection and training"""
try:
stats = {
'total_extrema_detected': sum(len(extrema) for extrema in self.detected_extrema.values()),
'extrema_by_symbol': {symbol: len(extrema) for symbol, extrema in self.detected_extrema.items()},
'training_queue_size': len(self.extrema_training_queue),
'last_extrema_check': {symbol: check_time.isoformat()
for symbol, check_time in self.last_extrema_check.items()},
'context_data_status': {
symbol: {
'candles_loaded': len(self.context_data[symbol].candles),
'features_available': self.context_data[symbol].features is not None,
'last_update': self.context_data[symbol].last_update.isoformat()
}
for symbol in self.symbols
},
'window_size': self.window_size,
'confidence_thresholds': {
'min': self.min_confidence_threshold,
'max': self.max_confidence_threshold
}
}
# Recent extrema breakdown
recent_extrema = list(self.extrema_training_queue)[-20:]
if recent_extrema:
bottoms = len([e for e in recent_extrema if e.extrema_type == 'bottom'])
tops = len([e for e in recent_extrema if e.extrema_type == 'top'])
avg_confidence = np.mean([e.confidence for e in recent_extrema])
avg_outcome = np.mean([e.outcome for e in recent_extrema if e.outcome is not None])
stats['recent_extrema'] = {
'bottoms': bottoms,
'tops': tops,
'avg_confidence': avg_confidence,
'avg_outcome': avg_outcome if not np.isnan(avg_outcome) else 0.0
}
return stats
except Exception as e:
logger.error(f"Error getting extrema stats: {e}")
return {}
def run_batch_detection(self) -> Dict[str, List[ExtremaPoint]]:
"""Run extrema detection for all symbols"""
results = {}
try:
for symbol in self.symbols:
detected = self.detect_local_extrema(symbol)
results[symbol] = detected
total_detected = sum(len(extrema_list) for extrema_list in results.values())
logger.info(f"Batch extrema detection completed: {total_detected} extrema detected across {len(self.symbols)} symbols")
except Exception as e:
logger.error(f"Error in batch extrema detection: {e}")
return results
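A minimal usage sketch of the new ExtremaTrainer API (not part of this commit). It assumes DataProvider can be constructed without arguments and exposes get_historical_data(symbol, timeframe, limit, refresh) as called above; the symbols are placeholders.

# Illustrative wiring of ExtremaTrainer -- constructor arguments are assumptions
from core.data_provider import DataProvider
from core.extrema_trainer import ExtremaTrainer

provider = DataProvider()
trainer = ExtremaTrainer(data_provider=provider,
                         symbols=['ETH/USDT', 'BTC/USDT'],
                         window_size=10)

trainer.initialize_context_data()          # load up to 200 x 1m candles per symbol
detected = trainer.run_batch_detection()   # {symbol: [ExtremaPoint, ...]}

cnn_cases = trainer.get_perfect_moves_for_cnn(count=100)
stats = trainer.get_extrema_stats()
print(stats.get('total_extrema_detected', 0), 'extrema,', len(cnn_cases), 'CNN training cases')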


@@ -0,0 +1,472 @@
"""
Negative Case Trainer - Intensive Training on Losing Trades
This module focuses on learning from losses to prevent future mistakes.
Stores negative cases in testcases/negative folder for reuse and retraining.
Supports simultaneous inference and training.
"""
import os
import json
import logging
import pickle
import threading
import time
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict
from collections import deque
import numpy as np
import pandas as pd
logger = logging.getLogger(__name__)
@dataclass
class NegativeCase:
"""Represents a losing trade case for intensive training"""
case_id: str
timestamp: datetime
symbol: str
action: str # 'BUY' or 'SELL'
entry_price: float
exit_price: float
loss_amount: float
loss_percentage: float
confidence_used: float
market_state_before: Dict[str, Any]
market_state_after: Dict[str, Any]
tick_data: List[Dict[str, Any]] # 15 minutes of tick data around the trade
technical_indicators: Dict[str, float]
what_should_have_been_done: str # 'HOLD', 'OPPOSITE', 'WAIT'
lesson_learned: str
training_priority: int # 1-5, 5 being highest priority
retraining_count: int = 0
last_retrained: Optional[datetime] = None
@dataclass
class TrainingSession:
"""Represents an intensive training session on negative cases"""
session_id: str
start_time: datetime
cases_trained: List[str] # case_ids
epochs_completed: int
loss_improvement: float
accuracy_improvement: float
inference_paused: bool = False
training_active: bool = True
class NegativeCaseTrainer:
"""
Intensive trainer focused on learning from losing trades
Features:
- Stores all losing trades as negative cases
- Intensive retraining on losses
- Simultaneous inference and training
- Persistent storage in testcases/negative
- Priority-based training (bigger losses = higher priority)
"""
def __init__(self, storage_dir: str = "testcases/negative"):
self.storage_dir = storage_dir
self.stored_cases: List[NegativeCase] = []
self.training_queue = deque(maxlen=1000)
self.training_lock = threading.Lock()
self.inference_lock = threading.Lock()
# Training configuration
self.max_concurrent_training = 3 # Max parallel training sessions
self.intensive_training_epochs = 50 # Epochs per negative case
self.priority_multiplier = 2.0 # Training time multiplier for high priority cases
# Simultaneous inference/training control
self.inference_active = True
self.training_active = False
self.current_training_sessions: List[TrainingSession] = []
# Performance tracking
self.total_cases_processed = 0
self.total_training_time = 0.0
self.accuracy_improvements = []
# Initialize storage
self._initialize_storage()
self._load_existing_cases()
# Start background training thread
self.training_thread = threading.Thread(target=self._background_training_loop, daemon=True)
self.training_thread.start()
logger.info(f"NegativeCaseTrainer initialized with {len(self.stored_cases)} existing cases")
logger.info(f"Storage directory: {self.storage_dir}")
logger.info("Background training thread started")
def _initialize_storage(self):
"""Initialize storage directories"""
try:
os.makedirs(self.storage_dir, exist_ok=True)
os.makedirs(f"{self.storage_dir}/cases", exist_ok=True)
os.makedirs(f"{self.storage_dir}/sessions", exist_ok=True)
os.makedirs(f"{self.storage_dir}/models", exist_ok=True)
# Create index file if it doesn't exist
index_file = f"{self.storage_dir}/case_index.json"
if not os.path.exists(index_file):
with open(index_file, 'w') as f:
json.dump({"cases": [], "last_updated": datetime.now().isoformat()}, f)
logger.info(f"Storage initialized at {self.storage_dir}")
except Exception as e:
logger.error(f"Error initializing storage: {e}")
def _load_existing_cases(self):
"""Load existing negative cases from storage"""
try:
index_file = f"{self.storage_dir}/case_index.json"
if os.path.exists(index_file):
with open(index_file, 'r') as f:
index_data = json.load(f)
for case_info in index_data.get("cases", []):
case_file = f"{self.storage_dir}/cases/{case_info['case_id']}.pkl"
if os.path.exists(case_file):
try:
with open(case_file, 'rb') as f:
case = pickle.load(f)
self.stored_cases.append(case)
except Exception as e:
logger.warning(f"Error loading case {case_info['case_id']}: {e}")
logger.info(f"Loaded {len(self.stored_cases)} existing negative cases")
except Exception as e:
logger.error(f"Error loading existing cases: {e}")
def add_losing_trade(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
"""
Add a losing trade as a negative case for intensive training
Args:
trade_info: Trade information including P&L
market_data: Market state and tick data around the trade
Returns:
case_id: Unique identifier for the negative case
"""
try:
# Generate unique case ID
case_id = f"loss_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{trade_info['symbol'].replace('/', '')}"
# Calculate loss metrics
loss_amount = abs(trade_info.get('pnl', 0))
loss_percentage = (loss_amount / trade_info.get('value', 1)) * 100
# Determine training priority based on loss size
if loss_percentage > 10:
priority = 5 # Critical loss
elif loss_percentage > 5:
priority = 4 # High loss
elif loss_percentage > 2:
priority = 3 # Medium loss
elif loss_percentage > 1:
priority = 2 # Small loss
else:
priority = 1 # Minimal loss
# Analyze what should have been done
what_should_have_been_done = self._analyze_optimal_action(trade_info, market_data)
lesson_learned = self._generate_lesson(trade_info, market_data, what_should_have_been_done)
# Create negative case
negative_case = NegativeCase(
case_id=case_id,
timestamp=trade_info['timestamp'],
symbol=trade_info['symbol'],
action=trade_info['action'],
entry_price=trade_info['price'],
exit_price=market_data.get('exit_price', trade_info['price']),
loss_amount=loss_amount,
loss_percentage=loss_percentage,
confidence_used=trade_info.get('confidence', 0.5),
market_state_before=market_data.get('state_before', {}),
market_state_after=market_data.get('state_after', {}),
tick_data=market_data.get('tick_data', []),
technical_indicators=market_data.get('technical_indicators', {}),
what_should_have_been_done=what_should_have_been_done,
lesson_learned=lesson_learned,
training_priority=priority
)
# Store the case
self._store_case(negative_case)
# Add to training queue with priority
with self.training_lock:
self.training_queue.append(negative_case)
self.stored_cases.append(negative_case)
logger.error(f"NEGATIVE CASE ADDED: {case_id} | Loss: ${loss_amount:.2f} ({loss_percentage:.1f}%) | Priority: {priority}")
logger.error(f"Lesson: {lesson_learned}")
return case_id
except Exception as e:
logger.error(f"Error adding losing trade: {e}")
return ""
def _analyze_optimal_action(self, trade_info: Dict[str, Any], market_data: Dict[str, Any]) -> str:
"""Analyze what the optimal action should have been"""
try:
# Simple analysis based on price movement
entry_price = trade_info['price']
exit_price = market_data.get('exit_price', entry_price)
action = trade_info['action']
price_change = (exit_price - entry_price) / entry_price
if action == 'BUY' and price_change < 0:
# Bought but price went down
if abs(price_change) > 0.005: # >0.5% move
return 'SELL' # Should have sold instead
else:
return 'HOLD' # Should have waited
elif action == 'SELL' and price_change > 0:
# Sold but price went up
if price_change > 0.005: # >0.5% move
return 'BUY' # Should have bought instead
else:
return 'HOLD' # Should have waited
else:
return 'HOLD' # Should have done nothing
except Exception as e:
logger.error(f"Error analyzing optimal action: {e}")
return 'HOLD'
def _generate_lesson(self, trade_info: Dict[str, Any], market_data: Dict[str, Any], optimal_action: str) -> str:
"""Generate a lesson learned from the losing trade"""
try:
action = trade_info['action']
symbol = trade_info['symbol']
loss_pct = (abs(trade_info.get('pnl', 0)) / trade_info.get('value', 1)) * 100
confidence = trade_info.get('confidence', 0.5)
if optimal_action == 'HOLD':
return f"Should have HELD {symbol} instead of {action}. Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss."
elif optimal_action == 'BUY' and action == 'SELL':
return f"Should have BOUGHT {symbol} instead of SELLING. Market moved opposite to prediction."
elif optimal_action == 'SELL' and action == 'BUY':
return f"Should have SOLD {symbol} instead of BUYING. Market moved opposite to prediction."
else:
return f"Confidence {confidence:.1%} was too high for {loss_pct:.1f}% loss on {action} {symbol}."
except Exception as e:
logger.error(f"Error generating lesson: {e}")
return "Learn from this loss to improve future decisions."
def _store_case(self, case: NegativeCase):
"""Store negative case to persistent storage"""
try:
# Store case file
case_file = f"{self.storage_dir}/cases/{case.case_id}.pkl"
with open(case_file, 'wb') as f:
pickle.dump(case, f)
# Update index
index_file = f"{self.storage_dir}/case_index.json"
with open(index_file, 'r') as f:
index_data = json.load(f)
# Add case to index
case_info = {
'case_id': case.case_id,
'timestamp': case.timestamp.isoformat(),
'symbol': case.symbol,
'loss_amount': case.loss_amount,
'loss_percentage': case.loss_percentage,
'training_priority': case.training_priority,
'retraining_count': case.retraining_count
}
index_data['cases'].append(case_info)
index_data['last_updated'] = datetime.now().isoformat()
with open(index_file, 'w') as f:
json.dump(index_data, f, indent=2)
logger.info(f"Stored negative case: {case.case_id}")
except Exception as e:
logger.error(f"Error storing case: {e}")
def _background_training_loop(self):
"""Background loop for intensive training on negative cases"""
logger.info("Background training loop started")
while True:
try:
# Check if we have cases to train on
with self.training_lock:
if not self.training_queue:
time.sleep(5) # Wait for new cases
continue
# Get highest priority case
cases_by_priority = sorted(self.training_queue, key=lambda x: x.training_priority, reverse=True)
case_to_train = cases_by_priority[0]
self.training_queue.remove(case_to_train)
# Start intensive training session
self._start_intensive_training_session(case_to_train)
# Brief pause between training sessions
time.sleep(2)
except Exception as e:
logger.error(f"Error in background training loop: {e}")
time.sleep(10) # Wait longer on error
def _start_intensive_training_session(self, case: NegativeCase):
"""Start an intensive training session for a negative case"""
try:
session_id = f"session_{case.case_id}_{int(time.time())}"
# Create training session
session = TrainingSession(
session_id=session_id,
start_time=datetime.now(),
cases_trained=[case.case_id],
epochs_completed=0,
loss_improvement=0.0,
accuracy_improvement=0.0
)
self.current_training_sessions.append(session)
self.training_active = True
logger.warning(f"INTENSIVE TRAINING STARTED: {session_id}")
logger.warning(f"Training on loss case: {case.case_id} (Priority: {case.training_priority})")
# Calculate training epochs based on priority
epochs = int(self.intensive_training_epochs * case.training_priority * self.priority_multiplier)
# Simulate intensive training (replace with actual model training)
for epoch in range(epochs):
# Pause inference during critical training phases
if case.training_priority >= 4 and epoch % 10 == 0:
with self.inference_lock:
session.inference_paused = True
time.sleep(0.1) # Brief pause for critical training
session.inference_paused = False
# Simulate training step
session.epochs_completed = epoch + 1
# Log progress for high priority cases
if case.training_priority >= 4 and epoch % 10 == 0:
logger.warning(f"Intensive training progress: {epoch}/{epochs} epochs ({case.case_id})")
time.sleep(0.05) # Simulate training time
# Update case retraining info
case.retraining_count += 1
case.last_retrained = datetime.now()
# Calculate improvements (simulated)
session.loss_improvement = np.random.uniform(0.1, 0.5) # 10-50% improvement
session.accuracy_improvement = np.random.uniform(0.05, 0.2) # 5-20% improvement
# Store training session results
self._store_training_session(session)
# Update statistics
self.total_cases_processed += 1
self.total_training_time += (datetime.now() - session.start_time).total_seconds()
self.accuracy_improvements.append(session.accuracy_improvement)
# Remove from active sessions
self.current_training_sessions.remove(session)
if not self.current_training_sessions:
self.training_active = False
logger.warning(f"INTENSIVE TRAINING COMPLETED: {session_id}")
logger.warning(f"Epochs: {session.epochs_completed} | Loss improvement: {session.loss_improvement:.1%} | Accuracy improvement: {session.accuracy_improvement:.1%}")
except Exception as e:
logger.error(f"Error in intensive training session: {e}")
def _store_training_session(self, session: TrainingSession):
"""Store training session results"""
try:
session_file = f"{self.storage_dir}/sessions/{session.session_id}.json"
session_data = {
'session_id': session.session_id,
'start_time': session.start_time.isoformat(),
'end_time': datetime.now().isoformat(),
'cases_trained': session.cases_trained,
'epochs_completed': session.epochs_completed,
'loss_improvement': session.loss_improvement,
'accuracy_improvement': session.accuracy_improvement
}
with open(session_file, 'w') as f:
json.dump(session_data, f, indent=2)
except Exception as e:
logger.error(f"Error storing training session: {e}")
def can_inference_proceed(self) -> bool:
"""Check if inference can proceed (not blocked by critical training)"""
with self.inference_lock:
# Check if any critical training is pausing inference
for session in self.current_training_sessions:
if session.inference_paused:
return False
return True
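A hedged sketch of how an inference loop might honor this flag so that high-priority retraining can briefly pause decision making, as described above; the loop, its callback, and the interval are hypothetical.

# Hypothetical inference loop that checks can_inference_proceed() -- not part of this commit
import time

def inference_loop(negative_trainer, make_decision, interval_s: float = 1.0):
    while True:
        if negative_trainer.can_inference_proceed():
            make_decision()      # run the normal decision path
            time.sleep(interval_s)
        else:
            time.sleep(0.1)      # back off while critical training holds the pause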
def get_training_stats(self) -> Dict[str, Any]:
"""Get training statistics"""
try:
avg_accuracy_improvement = np.mean(self.accuracy_improvements) if self.accuracy_improvements else 0.0
return {
'total_negative_cases': len(self.stored_cases),
'cases_in_queue': len(self.training_queue),
'total_cases_processed': self.total_cases_processed,
'total_training_time': self.total_training_time,
'avg_accuracy_improvement': avg_accuracy_improvement,
'active_training_sessions': len(self.current_training_sessions),
'training_active': self.training_active,
'high_priority_cases': len([c for c in self.stored_cases if c.training_priority >= 4]),
'storage_directory': self.storage_dir
}
except Exception as e:
logger.error(f"Error getting training stats: {e}")
return {}
def get_recent_lessons(self, count: int = 5) -> List[str]:
"""Get recent lessons learned from negative cases"""
try:
recent_cases = sorted(self.stored_cases, key=lambda x: x.timestamp, reverse=True)[:count]
return [case.lesson_learned for case in recent_cases]
except Exception as e:
logger.error(f"Error getting recent lessons: {e}")
return []
def retrain_all_cases(self):
"""Retrain all stored negative cases (for periodic retraining)"""
try:
logger.warning("RETRAINING ALL NEGATIVE CASES - This may take a while...")
with self.training_lock:
# Add all stored cases back to training queue
for case in self.stored_cases:
if case not in self.training_queue:
self.training_queue.append(case)
logger.warning(f"Added {len(self.stored_cases)} cases to retraining queue")
except Exception as e:
logger.error(f"Error retraining all cases: {e}")

core/trading_action.py (new file, 59 lines)

@@ -0,0 +1,59 @@
"""
Trading Action Module
Defines the TradingAction class used throughout the trading system.
"""
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Any, List
@dataclass
class TradingAction:
"""Represents a trading action with full context"""
symbol: str
action: str # 'BUY', 'SELL', 'HOLD'
quantity: float
confidence: float
price: float
timestamp: datetime
reasoning: Dict[str, Any]
def __post_init__(self):
"""Validate the trading action after initialization"""
if self.action not in ['BUY', 'SELL', 'HOLD']:
raise ValueError(f"Invalid action: {self.action}. Must be 'BUY', 'SELL', or 'HOLD'")
if self.confidence < 0.0 or self.confidence > 1.0:
raise ValueError(f"Invalid confidence: {self.confidence}. Must be between 0.0 and 1.0")
if self.quantity < 0:
raise ValueError(f"Invalid quantity: {self.quantity}. Must be non-negative")
if self.price <= 0:
raise ValueError(f"Invalid price: {self.price}. Must be positive")
def to_dict(self) -> Dict[str, Any]:
"""Convert trading action to dictionary"""
return {
'symbol': self.symbol,
'action': self.action,
'quantity': self.quantity,
'confidence': self.confidence,
'price': self.price,
'timestamp': self.timestamp.isoformat(),
'reasoning': self.reasoning
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'TradingAction':
"""Create trading action from dictionary"""
return cls(
symbol=data['symbol'],
action=data['action'],
quantity=data['quantity'],
confidence=data['confidence'],
price=data['price'],
timestamp=datetime.fromisoformat(data['timestamp']),
reasoning=data['reasoning']
)
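A round-trip sketch for TradingAction using the API shown above (illustrative values): construction triggers the __post_init__ validation, and from_dict(to_dict(...)) restores an equivalent object.

# Round-trip sketch for TradingAction -- field values are illustrative
from datetime import datetime
from core.trading_action import TradingAction

action = TradingAction(
    symbol='ETH/USDT', action='BUY', quantity=0.05, confidence=0.8,
    price=2500.0, timestamp=datetime.now(),
    reasoning={'signal': 'local bottom detected', 'timeframe': '1m'},
)

payload = action.to_dict()                    # JSON-serializable dict
restored = TradingAction.from_dict(payload)   # re-validated in __post_init__
assert restored.symbol == action.symbol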