extrema trainer WIP

This commit is contained in:
Dobromir Popov
2025-08-10 03:20:13 +03:00
parent 1cc8509e87
commit b2faa9b6ca
5 changed files with 315 additions and 35 deletions

View File

@ -43,6 +43,22 @@ class ExtremaPoint:
market_context: Dict[str, Any]
outcome: Optional[float] = None # Price change after extrema
@dataclass
class PredictedPivot:
"""Represents a prediction of the next pivot point within a capped horizon"""
symbol: str
created_at: datetime
current_price: float
predicted_time: datetime
predicted_price: float
horizon_seconds: int
target_type: str # 'top' or 'bottom'
confidence: float
evaluated: bool = False
success: Optional[bool] = None
error_abs: Optional[float] = None # absolute price error at eval
time_error_s: Optional[int] = None # time offset at eval
@dataclass
class ContextData:
"""200-candle 1m context data for enhanced model performance"""
@ -103,6 +119,10 @@ class ExtremaTrainer:
'successful_predictions': 0,
'failed_predictions': 0,
'detection_accuracy': 0.0,
'prediction_evaluations': 0,
'prediction_successes': 0,
'prediction_mae': 0.0, # mean absolute error for price
'prediction_mte': 0.0, # mean time error seconds
'last_training_time': None
}
@ -114,6 +134,140 @@ class ExtremaTrainer:
logger.info(f"Window size: {window_size}, Context update frequency: {self.context_update_frequency}s")
logger.info(f"Checkpoint management: {enable_checkpoints}, Model name: {model_name}")
# Next pivot prediction management
self.prediction_window_seconds = 300 # cap at 5 minutes
self.pending_predictions: Dict[str, deque] = {symbol: deque(maxlen=200) for symbol in symbols}
self.last_prediction: Dict[str, Optional[PredictedPivot]] = {symbol: None for symbol in symbols}
# === Prediction API ===
def predict_next_pivot(self, symbol: str, now: Optional[datetime] = None, current_price: Optional[float] = None) -> Optional[PredictedPivot]:
"""Predict next pivot point (time, price) within 5 minutes using real context.
Strategy (baseline, fully real-data-driven):
- Determine last detected extrema type from recent detections; target the opposite type next.
- Estimate horizon by median time gap between recent extrema (capped to 300s, floored at 30s).
- Estimate amplitude by median absolute price change between recent extrema; project from current_price in the direction implied by target type.
- Confidence derived from recent detection confidence averages (bounded).
"""
try:
if symbol not in self.detected_extrema:
return None
now = now or datetime.now()
# Use current price from provider if not passed
if current_price is None:
try:
if hasattr(self.data_provider, 'get_current_price'):
current_price = self.data_provider.get_current_price(symbol) or 0.0
except Exception:
current_price = 0.0
if not current_price or current_price <= 0:
return None
recent = list(self.detected_extrema[symbol])[-10:]
if not recent:
return None
# Determine last extrema
last_ext = recent[-1]
target_type = 'top' if last_ext.extrema_type == 'bottom' else 'bottom'
# Estimate horizon as median delta between last extrema timestamps
gaps = []
for i in range(1, len(recent)):
gaps.append((recent[i].timestamp - recent[i-1].timestamp).total_seconds())
median_gap = int(np.median(gaps)) if gaps else 60
horizon_s = max(30, min(self.prediction_window_seconds, median_gap))
# Estimate amplitude as median absolute change between extrema
price_changes = []
for i in range(1, len(recent)):
price_changes.append(abs(recent[i].price - recent[i-1].price))
median_amp = float(np.median(price_changes)) if price_changes else current_price * 0.002 # ~0.2%
predicted_price = current_price + (median_amp if target_type == 'top' else -median_amp)
predicted_time = now + timedelta(seconds=horizon_s)
# Confidence from average of recent detection confidences
conf_vals = [e.confidence for e in recent]
confidence = float(np.mean(conf_vals)) if conf_vals else 0.5
confidence = max(0.1, min(0.95, confidence))
pred = PredictedPivot(
symbol=symbol,
created_at=now,
current_price=current_price,
predicted_time=predicted_time,
predicted_price=predicted_price,
horizon_seconds=horizon_s,
target_type=target_type,
confidence=confidence
)
self.pending_predictions[symbol].append(pred)
self.last_prediction[symbol] = pred
return pred
except Exception as e:
logger.error(f"Error predicting next pivot for {symbol}: {e}")
return None
def get_latest_prediction(self, symbol: str) -> Optional[PredictedPivot]:
return self.last_prediction.get(symbol)
def evaluate_pending_predictions(self, symbol: str) -> int:
"""Evaluate pending predictions within the 5-minute window using detected extrema.
Returns number of evaluations performed.
"""
try:
if symbol not in self.pending_predictions:
return 0
now = datetime.now()
evaluated = 0
# Build a quick index of detected extrema within last 10 minutes
recent_extrema = [e for e in self.detected_extrema[symbol] if (now - e.timestamp).total_seconds() <= 600]
for pred in list(self.pending_predictions[symbol]):
if pred.evaluated:
continue
# If evaluation horizon passed, evaluate against nearest extrema in time
if (now - pred.created_at).total_seconds() >= min(self.prediction_window_seconds, pred.horizon_seconds):
# Find extrema closest in time after creation
candidate = None
min_dt = None
for e in recent_extrema:
if e.timestamp >= pred.created_at and e.extrema_type == pred.target_type:
dt = abs((e.timestamp - pred.predicted_time).total_seconds())
if min_dt is None or dt < min_dt:
min_dt = dt
candidate = e
if candidate is not None:
price_err = abs(candidate.price - pred.predicted_price)
time_err = int(abs((candidate.timestamp - pred.predicted_time).total_seconds()))
# Decide success with simple thresholds
price_tol = max(0.001 * pred.current_price, 0.5) # 0.1% or $0.5
time_tol = 90 # 1.5 minutes
success = (price_err <= price_tol) and (time_err <= time_tol)
pred.evaluated = True
pred.success = success
pred.error_abs = price_err
pred.time_error_s = time_err
self.training_stats['prediction_evaluations'] += 1
if success:
self.training_stats['prediction_successes'] += 1
# Update running means
n = self.training_stats['prediction_evaluations']
prev_mae = self.training_stats['prediction_mae']
prev_mte = self.training_stats['prediction_mte']
self.training_stats['prediction_mae'] = ((prev_mae * (n - 1)) + price_err) / n
self.training_stats['prediction_mte'] = ((prev_mte * (n - 1)) + time_err) / n
evaluated += 1
# Optionally checkpoint on batch
if evaluated > 0:
self.save_checkpoint(force_save=False)
return evaluated
except Exception as e:
logger.error(f"Error evaluating predictions for {symbol}: {e}")
return 0
def load_best_checkpoint(self):
"""Load the best checkpoint for this extrema trainer"""
try:
@ -182,6 +336,10 @@ class ExtremaTrainer:
symbol: list(extrema_deque)
for symbol, extrema_deque in self.detected_extrema.items()
},
'last_prediction': {
symbol: (self._serialize_prediction(pred) if pred else None)
for symbol, pred in self.last_prediction.items()
},
'window_size': self.window_size,
'symbols': self.symbols
}
@ -216,6 +374,25 @@ class ExtremaTrainer:
except Exception as e:
logger.error(f"Error saving ExtremaTrainer checkpoint: {e}")
return False
def _serialize_prediction(self, pred: PredictedPivot) -> Dict[str, Any]:
try:
return {
'symbol': pred.symbol,
'created_at': pred.created_at.isoformat(),
'current_price': pred.current_price,
'predicted_time': pred.predicted_time.isoformat(),
'predicted_price': pred.predicted_price,
'horizon_seconds': pred.horizon_seconds,
'target_type': pred.target_type,
'confidence': pred.confidence,
'evaluated': pred.evaluated,
'success': pred.success,
'error_abs': pred.error_abs,
'time_error_s': pred.time_error_s,
}
except Exception:
return {}
def initialize_context_data(self) -> Dict[str, bool]:
"""Initialize 200-candle 1m context data for all symbols"""

View File

@ -976,15 +976,21 @@ class TradingOrchestrator:
# The presence of features indicates a signal. We'll return a generic HOLD
# with a neutral confidence. This can be refined if ExtremaTrainer provides
# more specific BUY/SELL signals directly.
return {
"action": "HOLD",
"confidence": 0.5,
"probabilities": {
"BUY": 0.33,
"SELL": 0.33,
"HOLD": 0.34,
},
}
# Provide next-pivot prediction vector capped at 5 min
pred = self.model.predict_next_pivot(symbol=symbol)
if pred:
return {
"action": "HOLD",
"confidence": pred.confidence,
"prediction": {
"target_type": pred.target_type,
"predicted_time": pred.predicted_time,
"predicted_price": pred.predicted_price,
"horizon_seconds": pred.horizon_seconds,
},
}
# Fallback neutral
return {"action": "HOLD", "confidence": 0.5}
return None
except Exception as e:
logger.error(