gogo2/core/cnn_pivot_predictor.py
#!/usr/bin/env python3
"""
CNN Pivot Predictor Core Module

This module handles all CNN-based pivot prediction logic, separated from the web UI.
"""

import json
import logging
import os
import time
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple

import numpy as np
import pandas as pd

# Setup logging with ASCII-only output
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

@dataclass
class PivotPrediction:
    """Dataclass for CNN pivot predictions"""
    level: int
    type: str  # 'HIGH' or 'LOW'
    predicted_price: float
    confidence: float
    timestamp: datetime
    current_price: float
    model_inputs: Optional[Dict] = None

@dataclass
class ActualPivot:
    """Dataclass for actual detected pivots"""
    type: str  # 'HIGH' or 'LOW'
    price: float
    timestamp: datetime
    strength: int
    confirmed: bool = False

@dataclass
class TrainingDataPoint:
    """Dataclass for capturing training comparison data"""
    prediction: PivotPrediction
    actual_pivot: Optional[ActualPivot]
    prediction_accuracy: Optional[float]
    time_accuracy: Optional[float]
    captured_at: datetime

class CNNPivotPredictor:
    """Core CNN pivot prediction engine"""

    def __init__(self, config: Optional[Dict] = None):
        self.config = config or self._default_config()
        self.current_predictions: List[PivotPrediction] = []
        self.training_data: List[TrainingDataPoint] = []
        self.model_available = False

        # Initialize data storage paths
        self.training_data_dir = "data/cnn_training"
        os.makedirs(self.training_data_dir, exist_ok=True)

        logger.info("CNN Pivot Predictor initialized")

    def _default_config(self) -> Dict:
        """Default configuration for CNN predictor"""
        return {
            'prediction_levels': 5,  # Williams Market Structure levels
            'confidence_threshold': 0.3,
            'model_timesteps': 900,
            'model_features': 50,
            'prediction_horizon_minutes': 30
        }

    def generate_predictions(self, market_data: pd.DataFrame, current_price: float) -> List[PivotPrediction]:
        """
        Generate CNN pivot predictions based on current market data

        Args:
            market_data: DataFrame with OHLCV data
            current_price: Current market price

        Returns:
            List of pivot predictions
        """
        try:
            current_time = datetime.now()
            predictions = []

            # For demo purposes, generate sample predictions.
            # In production, this would use the actual CNN model.
            for level in range(1, self.config['prediction_levels'] + 1):
                # HIGH pivot prediction
                high_confidence = np.random.uniform(0.4, 0.9)
                if high_confidence > self.config['confidence_threshold']:
                    high_price = current_price + np.random.uniform(10, 50)
                    high_prediction = PivotPrediction(
                        level=level,
                        type='HIGH',
                        predicted_price=high_price,
                        confidence=high_confidence,
                        timestamp=current_time + timedelta(minutes=level * 5),
                        current_price=current_price,
                        model_inputs=self._prepare_model_inputs(market_data)
                    )
                    predictions.append(high_prediction)

                # LOW pivot prediction
                low_confidence = np.random.uniform(0.3, 0.8)
                if low_confidence > self.config['confidence_threshold']:
                    low_price = current_price - np.random.uniform(15, 40)
                    low_prediction = PivotPrediction(
                        level=level,
                        type='LOW',
                        predicted_price=low_price,
                        confidence=low_confidence,
                        timestamp=current_time + timedelta(minutes=level * 7),
                        current_price=current_price,
                        model_inputs=self._prepare_model_inputs(market_data)
                    )
                    predictions.append(low_prediction)

            self.current_predictions = predictions
            logger.info(f"Generated {len(predictions)} CNN pivot predictions")
            return predictions

        except Exception as e:
            logger.error(f"Error generating CNN predictions: {e}")
            return []

    def _prepare_model_inputs(self, market_data: pd.DataFrame) -> Dict:
        """Prepare model inputs for CNN prediction"""
        if len(market_data) < self.config['model_timesteps']:
            return {'insufficient_data': True}

        # Extract the most recent window (model_timesteps rows, model_features features);
        # cast numpy scalars to plain floats so the dict stays JSON-serializable
        recent_data = market_data.tail(self.config['model_timesteps'])
        return {
            'timesteps': len(recent_data),
            'features': self.config['model_features'],
            'price_range': (float(recent_data['low'].min()), float(recent_data['high'].max())),
            'volume_avg': float(recent_data['volume'].mean()),
            'timestamp': datetime.now().isoformat()
        }

    def update_predictions(self, market_data: pd.DataFrame, current_price: float) -> List[PivotPrediction]:
        """Update existing predictions or generate new ones"""
        # Drop predictions whose target time passed more than 60 minutes ago
        current_time = datetime.now()
        self.current_predictions = [
            pred for pred in self.current_predictions
            if pred.timestamp > current_time - timedelta(minutes=60)
        ]

        # Generate new predictions if too few remain
        if len(self.current_predictions) < 5:
            return self.generate_predictions(market_data, current_price)

        return self.current_predictions

    def capture_training_data(self, actual_pivot: ActualPivot) -> None:
        """
        Capture training data by comparing predictions with actual pivots

        Args:
            actual_pivot: Detected actual pivot point
        """
        try:
            current_time = datetime.now()

            # Find predictions of the same type within a 30-minute window of the actual pivot
            matching_predictions = [
                pred for pred in self.current_predictions
                if (pred.type == actual_pivot.type and
                    abs((pred.timestamp - actual_pivot.timestamp).total_seconds()) < 1800)
            ]

            for prediction in matching_predictions:
                # Calculate accuracy metrics
                price_accuracy = self._calculate_price_accuracy(prediction, actual_pivot)
                time_accuracy = self._calculate_time_accuracy(prediction, actual_pivot)

                training_point = TrainingDataPoint(
                    prediction=prediction,
                    actual_pivot=actual_pivot,
                    prediction_accuracy=price_accuracy,
                    time_accuracy=time_accuracy,
                    captured_at=current_time
                )
                self.training_data.append(training_point)

                logger.info(f"Captured training data point: {prediction.type} pivot with {price_accuracy:.2%} accuracy")

            # Save training data periodically (guard against writing an empty file)
            if self.training_data and len(self.training_data) % 5 == 0:
                self._save_training_data()

        except Exception as e:
            logger.error(f"Error capturing training data: {e}")

    def _calculate_price_accuracy(self, prediction: PivotPrediction, actual: ActualPivot) -> float:
        """Calculate price prediction accuracy"""
        if actual.price == 0:
            return 0.0
        price_diff = abs(prediction.predicted_price - actual.price)
        accuracy = max(0.0, 1.0 - (price_diff / actual.price))
        return accuracy

    def _calculate_time_accuracy(self, prediction: PivotPrediction, actual: ActualPivot) -> float:
        """Calculate timing prediction accuracy"""
        time_diff_seconds = abs((prediction.timestamp - actual.timestamp).total_seconds())
        max_acceptable_diff = 1800  # 30 minutes
        accuracy = max(0.0, 1.0 - (time_diff_seconds / max_acceptable_diff))
        return accuracy
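
    # Illustrative (hypothetical) numbers for the two accuracy metrics above:
    # a HIGH pivot predicted at 3010 against an actual pivot at 3000 gives a price
    # accuracy of 1 - (10 / 3000) ~= 0.9967, and a prediction timestamped 12 minutes
    # (720 s) away from the actual pivot gives a time accuracy of 1 - (720 / 1800) = 0.60.
    # Errors beyond 100% of the actual price or beyond 30 minutes clamp to 0.0.
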
    def _save_training_data(self) -> None:
        """Save training data to JSON file"""
        try:
            filename = f"cnn_training_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            filepath = os.path.join(self.training_data_dir, filename)

            # Convert to serializable format
            data_to_save = []
            for point in self.training_data:
                data_to_save.append({
                    'prediction': {
                        'level': point.prediction.level,
                        'type': point.prediction.type,
                        'predicted_price': point.prediction.predicted_price,
                        'confidence': point.prediction.confidence,
                        'timestamp': point.prediction.timestamp.isoformat(),
                        'current_price': point.prediction.current_price,
                        'model_inputs': point.prediction.model_inputs
                    },
                    'actual_pivot': {
                        'type': point.actual_pivot.type,
                        'price': point.actual_pivot.price,
                        'timestamp': point.actual_pivot.timestamp.isoformat(),
                        'strength': point.actual_pivot.strength
                    } if point.actual_pivot else None,
                    'prediction_accuracy': point.prediction_accuracy,
                    'time_accuracy': point.time_accuracy,
                    'captured_at': point.captured_at.isoformat()
                })

            with open(filepath, 'w') as f:
                json.dump(data_to_save, f, indent=2)

            logger.info(f"Saved {len(data_to_save)} training data points to {filepath}")

            # Clear processed data
            self.training_data = []

        except Exception as e:
            logger.error(f"Error saving training data: {e}")

    def get_prediction_stats(self) -> Dict:
        """Get current prediction statistics"""
        if not self.current_predictions:
            return {
                'active_predictions': 0,
                'high_confidence': 0,
                'medium_confidence': 0,
                'low_confidence': 0,
                'avg_confidence': 0.0
            }

        high_conf = len([p for p in self.current_predictions if p.confidence > 0.7])
        low_conf = len([p for p in self.current_predictions if p.confidence <= 0.5])
        return {
            'active_predictions': len(self.current_predictions),
            'high_confidence': high_conf,
            'medium_confidence': len(self.current_predictions) - high_conf - low_conf,
            'low_confidence': low_conf,
            'avg_confidence': float(np.mean([p.confidence for p in self.current_predictions]))
        }

    def get_training_stats(self) -> Dict:
        """Get training data capture statistics"""
        # Filter on `is not None` so legitimate 0.0 accuracies are included,
        # and guard the filtered lists to avoid np.mean([]) returning nan
        price_accuracies = [p.prediction_accuracy for p in self.training_data if p.prediction_accuracy is not None]
        time_accuracies = [p.time_accuracy for p in self.training_data if p.time_accuracy is not None]
        return {
            'captured_points': len(self.training_data),
            'avg_price_accuracy': float(np.mean(price_accuracies)) if price_accuracies else 0.0,
            'avg_time_accuracy': float(np.mean(time_accuracies)) if time_accuracies else 0.0
        }
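

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module's public API): exercises the
# demo prediction path end to end with synthetic OHLCV data. The column names
# ('open', 'high', 'low', 'close', 'volume'), the 1-minute bars, and the
# hypothetical actual pivot fed back for training capture are illustrative
# assumptions, not taken from the production data feed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Build a synthetic OHLCV frame long enough to satisfy model_timesteps (900 rows)
    rng = np.random.default_rng(42)
    closes = 3000 + np.cumsum(rng.normal(0, 2, size=1000))
    demo_data = pd.DataFrame({
        'open': closes,
        'high': closes + rng.uniform(1, 5, size=1000),
        'low': closes - rng.uniform(1, 5, size=1000),
        'close': closes,
        'volume': rng.uniform(10, 100, size=1000),
    }, index=pd.date_range(end=datetime.now(), periods=1000, freq='1min'))

    predictor = CNNPivotPredictor()
    predictions = predictor.generate_predictions(demo_data, current_price=float(closes[-1]))
    logger.info(f"Prediction stats: {predictor.get_prediction_stats()}")

    # Feed back a hypothetical actual pivot near the first prediction to
    # exercise training-data capture and the accuracy metrics
    if predictions:
        sample = predictions[0]
        actual = ActualPivot(
            type=sample.type,
            price=sample.predicted_price + float(rng.normal(0, 5)),
            timestamp=sample.timestamp + timedelta(minutes=2),
            strength=1,
        )
        predictor.capture_training_data(actual)
        logger.info(f"Training stats: {predictor.get_training_stats()}")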