#!/usr/bin/env python3
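"""
CNN Pivot Predictor Core Module

This module contains the CNN-based pivot prediction logic, kept separate from the web UI.

Note: no trained CNN model is wired in yet; CNNPivotPredictor.generate_predictions
currently emits randomized demo predictions.
"""
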
import logging
|
|
import time
|
|
import numpy as np
|
|
import pandas as pd
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Any, Optional, Tuple
|
|
import json
|
|
import os
|
|
from dataclasses import dataclass
|
|
|
|
# Root logging configuration: INFO level with timestamped records.
# logging.basicConfig does nothing if the root logger already has handlers,
# so an application that configures logging first keeps its own setup.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Logger shared by everything in this module.
logger = logging.getLogger(__name__)

@dataclass()
class PivotPrediction:
    """A single predicted pivot (turning point) emitted by the predictor.

    Instances are produced by CNNPivotPredictor.generate_predictions, one per
    level and direction that clears the confidence threshold.
    """

    level: int  # market-structure level this prediction belongs to (1-based)
    type: str  # pivot direction: 'HIGH' or 'LOW'
    predicted_price: float  # price at which the pivot is expected
    confidence: float  # model confidence attached to the prediction
    timestamp: datetime  # time at which the pivot is expected to occur
    current_price: float  # market price when the prediction was made
    model_inputs: Optional[Dict] = None  # summary of the inputs used, if recorded

@dataclass()
class ActualPivot:
    """A pivot point that was actually observed in the market.

    These are compared against PivotPrediction objects to build training data.
    """

    type: str  # pivot direction: 'HIGH' or 'LOW'
    price: float  # observed pivot price
    timestamp: datetime  # when the pivot occurred
    strength: int  # strength score supplied by the detector (scale not defined in this module)
    confirmed: bool = False  # whether the pivot has been confirmed; defaults to unconfirmed

@dataclass()
class TrainingDataPoint:
    """One prediction paired with the actual pivot it was compared against.

    Built by CNNPivotPredictor.capture_training_data and later written to
    JSON by CNNPivotPredictor._save_training_data.
    """

    prediction: PivotPrediction  # the prediction being evaluated
    actual_pivot: Optional[ActualPivot]  # the observed pivot, or None if there is none
    prediction_accuracy: Optional[float]  # price accuracy score, if computed
    time_accuracy: Optional[float]  # timing accuracy score, if computed
    captured_at: datetime  # when this comparison was recorded

class CNNPivotPredictor:
    """Core CNN pivot prediction engine.

    Generates HIGH/LOW pivot predictions for each configured Williams Market
    Structure level, pairs them with actual pivots as those are detected, and
    periodically writes the pairs to JSON files as training data.

    NOTE: no CNN model is loaded yet (``model_available`` is only ever set to
    False here); ``generate_predictions`` produces randomized demo predictions.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Create a predictor and make sure the training-data directory exists.

        Args:
            config: Optional overrides for ``_default_config()``. Keys that are
                not supplied fall back to their defaults, so a partial config
                such as ``{'confidence_threshold': 0.5}`` is valid.
        """
        # Merge over the defaults instead of replacing them: a partial config
        # used to drop required keys, and generate_predictions then failed with
        # a KeyError that it swallowed into an empty result.
        self.config = {**self._default_config(), **(config or {})}
        self.current_predictions: List[PivotPrediction] = []
        self.training_data: List[TrainingDataPoint] = []
        self.model_available = False

        # Directory that receives the JSON dumps written by _save_training_data.
        self.training_data_dir = self.config['training_data_dir']
        os.makedirs(self.training_data_dir, exist_ok=True)

        logger.info("CNN Pivot Predictor initialized")

    def _default_config(self) -> Dict:
        """Return the default configuration; user-supplied keys override these."""
        return {
            'prediction_levels': 5,  # Williams Market Structure levels
            'confidence_threshold': 0.3,  # keep predictions with confidence strictly above this
            'model_timesteps': 900,  # rows of history summarized for the model
            'model_features': 50,  # feature count reported in the model inputs
            'prediction_horizon_minutes': 30,  # NOTE(review): not read by this class
            'prediction_expiry_minutes': 60,  # drop predictions this long past their target time
            'min_active_predictions': 5,  # regenerate when fewer predictions remain
            'match_window_minutes': 30,  # max |predicted - actual| pivot time for a match
            'save_every_n_points': 5,  # dump training data after this many captures
            'training_data_dir': 'data/cnn_training',
        }

    def _match_window_seconds(self) -> float:
        """Width of the prediction/actual matching window, in seconds."""
        return float(self.config.get('match_window_minutes', 30)) * 60.0

    def generate_predictions(self, market_data: pd.DataFrame, current_price: float) -> List["PivotPrediction"]:
        """
        Generate pivot predictions for every configured level.

        NOTE: demo implementation. Confidences and prices are drawn with
        ``np.random.uniform``; no CNN model is consulted yet.

        For each level 1..``prediction_levels`` this may emit one HIGH
        prediction (target time level*5 minutes ahead, price 10-50 above
        ``current_price``) and one LOW prediction (level*7 minutes ahead,
        price 15-40 below). Each is kept only if its confidence is strictly
        above ``confidence_threshold``. On success the result also replaces
        ``self.current_predictions``.

        Args:
            market_data: DataFrame with OHLCV data; ``_prepare_model_inputs``
                reads its 'low', 'high' and 'volume' columns.
            current_price: Current market price.

        Returns:
            List of pivot predictions. If anything fails, the error is logged,
            an empty list is returned and ``current_predictions`` is unchanged.
        """
        try:
            current_time = datetime.now()
            threshold = self.config['confidence_threshold']
            # The summary depends only on market_data, so build it once rather
            # than once per prediction; each prediction gets its own copy so
            # they never share a mutable dict.
            model_inputs = self._prepare_model_inputs(market_data)
            predictions = []

            for level in range(1, self.config['prediction_levels'] + 1):
                # HIGH pivot prediction
                high_confidence = np.random.uniform(0.4, 0.9)
                if high_confidence > threshold:
                    predictions.append(PivotPrediction(
                        level=level,
                        type='HIGH',
                        predicted_price=current_price + np.random.uniform(10, 50),
                        confidence=high_confidence,
                        timestamp=current_time + timedelta(minutes=level * 5),
                        current_price=current_price,
                        model_inputs=dict(model_inputs),
                    ))

                # LOW pivot prediction
                low_confidence = np.random.uniform(0.3, 0.8)
                if low_confidence > threshold:
                    predictions.append(PivotPrediction(
                        level=level,
                        type='LOW',
                        predicted_price=current_price - np.random.uniform(15, 40),
                        confidence=low_confidence,
                        timestamp=current_time + timedelta(minutes=level * 7),
                        current_price=current_price,
                        model_inputs=dict(model_inputs),
                    ))

            self.current_predictions = predictions
            logger.info(f"Generated {len(predictions)} CNN pivot predictions")
            return predictions

        except Exception as e:
            logger.exception(f"Error generating CNN predictions: {e}")
            return []

    def _prepare_model_inputs(self, market_data: pd.DataFrame) -> Dict:
        """Summarize the most recent window of ``market_data``.

        Returns ``{'insufficient_data': True}`` when fewer than
        ``config['model_timesteps']`` rows are available. Otherwise returns the
        window length, the configured feature count, the (min 'low',
        max 'high') price range, the mean 'volume' and a creation timestamp.

        Raises:
            KeyError: if ``market_data`` lacks the 'low', 'high' or 'volume' column.
        """
        if len(market_data) < self.config['model_timesteps']:
            return {'insufficient_data': True}

        # Only the trailing window is summarized; 'features' is the configured
        # count and is not extracted from the frame here.
        recent_data = market_data.tail(self.config['model_timesteps'])

        return {
            'timesteps': len(recent_data),
            'features': self.config['model_features'],
            'price_range': (recent_data['low'].min(), recent_data['high'].max()),
            'volume_avg': recent_data['volume'].mean(),
            'timestamp': datetime.now().isoformat()
        }

    def update_predictions(self, market_data: pd.DataFrame, current_price: float) -> List["PivotPrediction"]:
        """Drop stale predictions and regenerate when too few remain.

        Predictions whose target timestamp is more than
        ``prediction_expiry_minutes`` (default 60) in the past are discarded.
        If fewer than ``min_active_predictions`` (default 5) survive, a fresh
        set is generated and returned (it replaces the old set). Otherwise the
        surviving predictions are returned.
        """
        cutoff = datetime.now() - timedelta(minutes=self.config.get('prediction_expiry_minutes', 60))
        self.current_predictions = [
            pred for pred in self.current_predictions
            if pred.timestamp > cutoff
        ]

        if len(self.current_predictions) < self.config.get('min_active_predictions', 5):
            return self.generate_predictions(market_data, current_price)

        return self.current_predictions

    def capture_training_data(self, actual_pivot: "ActualPivot") -> None:
        """
        Capture training data by comparing predictions with an actual pivot.

        Every active prediction of the same type whose target time lies within
        the match window (``match_window_minutes``, default 30) of the actual
        pivot is paired with it and scored. Matched predictions are not removed,
        so one prediction can be paired with several actual pivots. Training
        data is written to disk each time the buffer length reaches a multiple
        of ``save_every_n_points`` (default 5). Errors are logged, not raised.

        Args:
            actual_pivot: Detected actual pivot point.
        """
        try:
            current_time = datetime.now()
            window = self._match_window_seconds()
            save_every = max(1, int(self.config.get('save_every_n_points', 5)))

            matching_predictions = [
                pred for pred in self.current_predictions
                if pred.type == actual_pivot.type
                and abs((pred.timestamp - actual_pivot.timestamp).total_seconds()) < window
            ]

            for prediction in matching_predictions:
                price_accuracy = self._calculate_price_accuracy(prediction, actual_pivot)
                time_accuracy = self._calculate_time_accuracy(prediction, actual_pivot)

                self.training_data.append(TrainingDataPoint(
                    prediction=prediction,
                    actual_pivot=actual_pivot,
                    prediction_accuracy=price_accuracy,
                    time_accuracy=time_accuracy,
                    captured_at=current_time,
                ))
                logger.info(f"Captured training data point: {prediction.type} pivot with {price_accuracy:.2%} accuracy")

                # Save periodically; a successful save empties the buffer.
                if len(self.training_data) % save_every == 0:
                    self._save_training_data()

        except Exception as e:
            logger.exception(f"Error capturing training data: {e}")

    def _calculate_price_accuracy(self, prediction: "PivotPrediction", actual: "ActualPivot") -> float:
        """Return 1 - relative price error, clamped to [0, 1].

        Returns 0.0 when the actual price is 0 (relative error undefined).
        The absolute value of the actual price is used so that negative prices
        cannot yield an accuracy above 1.
        """
        if actual.price == 0:
            return 0.0

        price_diff = abs(prediction.predicted_price - actual.price)
        return max(0.0, 1.0 - (price_diff / abs(actual.price)))

    def _calculate_time_accuracy(self, prediction: "PivotPrediction", actual: "ActualPivot") -> float:
        """Return 1 - |time difference| / match window, clamped to [0, 1]."""
        time_diff_seconds = abs((prediction.timestamp - actual.timestamp).total_seconds())
        return max(0.0, 1.0 - (time_diff_seconds / self._match_window_seconds()))

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """``json`` fallback for numpy scalars/arrays and datetimes.

        Needed because e.g. ``Series.min()`` on an integer column yields a
        numpy integer, which the json module cannot serialize on its own.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, datetime):
            return obj.isoformat()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    @staticmethod
    def _serialize_point(point: "TrainingDataPoint") -> Dict:
        """Convert one training data point to a JSON-ready dict (datetimes as ISO strings)."""
        pred = point.prediction
        actual = point.actual_pivot
        return {
            'prediction': {
                'level': pred.level,
                'type': pred.type,
                'predicted_price': pred.predicted_price,
                'confidence': pred.confidence,
                'timestamp': pred.timestamp.isoformat(),
                'current_price': pred.current_price,
                'model_inputs': pred.model_inputs,
            },
            'actual_pivot': {
                'type': actual.type,
                'price': actual.price,
                'timestamp': actual.timestamp.isoformat(),
                'strength': actual.strength,
            } if actual is not None else None,
            'prediction_accuracy': point.prediction_accuracy,
            'time_accuracy': point.time_accuracy,
            'captured_at': point.captured_at.isoformat(),
        }

    def _save_training_data(self) -> None:
        """Write buffered training data to a new JSON file, then clear the buffer.

        Does nothing when the buffer is empty. On failure the error is logged
        and the buffer is kept, so the points are retried on the next save.
        """
        if not self.training_data:
            return

        try:
            # Microsecond resolution: with second-resolution names, two saves
            # within the same second reused a filename and the later dump
            # overwrote the earlier one.
            filename = f"cnn_training_data_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.json"
            filepath = os.path.join(self.training_data_dir, filename)

            data_to_save = [self._serialize_point(point) for point in self.training_data]
            # Serialize completely before opening the file so that an
            # unserializable value cannot leave a truncated file behind.
            payload = json.dumps(data_to_save, indent=2, default=self._json_default)
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(payload)

            logger.info(f"Saved {len(data_to_save)} training data points to {filepath}")

            # Clear processed data only after a successful write.
            self.training_data = []

        except Exception as e:
            logger.exception(f"Error saving training data: {e}")

    def get_prediction_stats(self) -> Dict:
        """Summarize the active predictions by confidence band.

        Bands: high > 0.7, low <= 0.5, medium in between. The same keys are
        returned whether or not any predictions are active.
        """
        confidences = [p.confidence for p in self.current_predictions]
        if not confidences:
            return {
                'active_predictions': 0,
                'high_confidence': 0,
                'medium_confidence': 0,
                'low_confidence': 0,
                'avg_confidence': 0.0,
            }

        high_conf = sum(1 for c in confidences if c > 0.7)
        low_conf = sum(1 for c in confidences if c <= 0.5)

        return {
            'active_predictions': len(confidences),
            'high_confidence': high_conf,
            'medium_confidence': len(confidences) - high_conf - low_conf,
            'low_confidence': low_conf,
            'avg_confidence': float(np.mean(confidences)),
        }

    def get_training_stats(self) -> Dict:
        """Summarize the training data still buffered (not yet saved).

        Averages skip None values but include genuine 0.0 scores; an average
        with no values to average is reported as 0.0.
        """
        price_scores = [p.prediction_accuracy for p in self.training_data if p.prediction_accuracy is not None]
        time_scores = [p.time_accuracy for p in self.training_data if p.time_accuracy is not None]
        return {
            'captured_points': len(self.training_data),
            'avg_price_accuracy': float(np.mean(price_scores)) if price_scores else 0.0,
            'avg_time_accuracy': float(np.mean(time_scores)) if time_scores else 0.0,
        }