124 lines
4.7 KiB
Python
124 lines
4.7 KiB
Python
"""
|
|
Enhanced data interface that adds neural-network trading parameters
(input window size and output size) on top of DataInterface.
|
|
"""
|
|
from typing import List, Optional, Tuple
|
|
import numpy as np
|
|
import pandas as pd
|
|
from datetime import datetime
|
|
from .data_interface import DataInterface
|
|
|
|
class MultiDataInterface(DataInterface):
    """
    Enhanced data interface that supports window_size and output_size parameters
    for neural network trading models.

    Attributes:
        window_size: Number of consecutive candles in each input sequence.
        output_size: Label / prediction format. 1 = binary (up/down),
            3 = three-class (buy/hold/sell); other values are rejected when
            building training data.
        scalers: Per-timeframe scaler storage (initialised empty; not
            populated by this class).
        min_window_threshold: Minimum number of candles required to train.
    """

    # Feature columns, in the order they appear in every OHLCV row.
    _OHLCV_COLUMNS = ['open', 'high', 'low', 'close', 'volume']
    # Position of the close price inside an OHLCV row.
    _CLOSE_IDX = 3
    # Close-to-close change beyond which a 3-class label becomes BUY/SELL
    # instead of HOLD.
    # NOTE(review): this is an absolute price delta, not a fraction; for most
    # instruments 0.002 is negligible — confirm whether 0.2% was intended.
    _HOLD_THRESHOLD = 0.002

    def __init__(self, symbol: str,
                 timeframes: List[str],
                 window_size: int = 20,
                 output_size: int = 3,
                 data_dir: str = "NN/data"):
        """
        Initialize with window_size and output_size for NN predictions.

        Args:
            symbol: Instrument symbol, passed through to DataInterface.
            timeframes: Timeframes to use; the first one is treated as the
                primary timeframe for training and prediction.
            window_size: Candles per input sequence.
            output_size: 1 for binary labels, 3 for buy/hold/sell labels.
            data_dir: Data directory, passed through to DataInterface.
        """
        super().__init__(symbol, timeframes, data_dir)
        self.window_size = window_size
        self.output_size = output_size
        self.scalers = {}  # Store scalers for each timeframe
        self.min_window_threshold = 100  # Minimum candles needed for training

    def get_feature_count(self) -> int:
        """
        Get number of features (OHLCV) for NN input.

        Returns:
            5: open, high, low, close, volume.
        """
        return len(self._OHLCV_COLUMNS)

    def prepare_training_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Prepare training data with windowed sequences from the primary timeframe.

        Each sample is ``window_size`` consecutive OHLCV rows. Its label is
        derived from the change between the last close inside the window and
        the close of the candle immediately after it:

        * ``output_size == 1``: 1 if that change is positive, else 0.
        * ``output_size == 3``: 0 (buy) if the change exceeds
          ``_HOLD_THRESHOLD``, 2 (sell) if it is below ``-_HOLD_THRESHOLD``,
          otherwise 1 (hold).

        Returns:
            ``(X_train, y_train, X_val, y_val)``. X arrays have shape
            ``(n, window_size, 5)``; y arrays hold integer labels. The split
            is chronological: the first 80% of samples are training data and
            the rest are validation data.

        Raises:
            ValueError: if ``output_size`` is unsupported, if fewer than
                ``min_window_threshold`` candles are available, or if the
                data cannot form a single labelled window.
        """
        # Validate up front so the error does not depend on how much data
        # happens to be available.
        if self.output_size not in (1, 3):
            raise ValueError(f"Unsupported output_size: {self.output_size}")

        # Get historical data for primary timeframe
        primary_tf = self.timeframes[0]
        df = self.get_historical_data(timeframe=primary_tf,
                                      n_candles=self.min_window_threshold + 1000)

        if df is None or len(df) < self.min_window_threshold:
            raise ValueError(f"Insufficient data for training. Need at least {self.min_window_threshold} candles")

        ohlcv = df[self._OHLCV_COLUMNS].values
        closes = ohlcv[:, self._CLOSE_IDX]

        # A window starting at row i covers rows i .. i+window_size-1 and its
        # label needs row i+window_size, so that is the number of usable starts.
        n_samples = len(ohlcv) - self.window_size
        if self.window_size < 1 or n_samples <= 0:
            raise ValueError(
                f"Need more than {self.window_size} candles (and window_size >= 1) "
                f"to build a labelled training window; got {len(ohlcv)}"
            )

        # Input sequences.
        X = np.stack([ohlcv[i:i + self.window_size] for i in range(n_samples)])

        # Change from each window's own last close to the next close. Using the
        # in-window close as the base guarantees there is always a change to
        # label, including when output_size == 1.
        deltas = closes[self.window_size:] - closes[self.window_size - 1:-1]

        if self.output_size == 1:
            # Binary classification (up/down)
            y = (deltas > 0).astype(int)
        else:
            # 3-class classification: 0 = buy, 1 = hold, 2 = sell
            y = np.select(
                [deltas > self._HOLD_THRESHOLD, deltas < -self._HOLD_THRESHOLD],
                [0, 2],
                default=1,
            ).astype(int)

        # Chronological train/validation split (80/20).
        split_idx = int(0.8 * len(X))
        X_train, y_train = X[:split_idx], y[:split_idx]
        X_val, y_val = X[split_idx:], y[split_idx:]

        return X_train, y_train, X_val, y_val

    def prepare_prediction_data(self) -> np.ndarray:
        """
        Prepare the most recent window for predictions.

        Returns:
            Array of shape ``(1, window_size, 5)``: the latest
            ``window_size`` OHLCV rows of the primary timeframe, with a
            leading batch dimension.

        Raises:
            ValueError: if fewer than ``window_size`` candles are returned.
        """
        primary_tf = self.timeframes[0]
        df = self.get_historical_data(timeframe=primary_tf,
                                      n_candles=self.window_size,
                                      use_cache=False)

        if df is None or len(df) < self.window_size:
            raise ValueError(f"Need at least {self.window_size} candles for prediction")

        ohlcv = df[self._OHLCV_COLUMNS].values[-self.window_size:]
        return np.array([ohlcv])  # Add batch dimension

    def process_predictions(self, predictions: np.ndarray) -> List[dict]:
        """
        Convert prediction probabilities to trading signals.

        Args:
            predictions: One entry per sample.
                For ``output_size == 1``, an entry is either a scalar or a
                row whose first element is read as P(up).
                For ``output_size == 3``, an entry is a row of scores ordered
                (BUY, HOLD, SELL).

        Returns:
            A list of dicts, each with these keys:

            * ``'action'``: the signal name.
            * ``'confidence'``: a Python float, in [0, 1] when the inputs are
              probabilities.
            * ``'timestamp'``: an ISO-8601 string from ``datetime.now()``
              (naive local time).

            Any other ``output_size`` yields a zero-confidence HOLD.
        """
        signals = []
        for pred in predictions:
            # Accept both (n, k) arrays and flat (n,) arrays of scalars.
            pred = np.atleast_1d(pred)
            if self.output_size == 1:
                p_up = float(pred[0])
                signal = "BUY" if p_up > 0.5 else "SELL"
                confidence = abs(p_up - 0.5) * 2  # Convert to 0-1 scale
            elif self.output_size == 3:
                action_idx = int(np.argmax(pred))
                signal = ["BUY", "HOLD", "SELL"][action_idx]
                confidence = float(pred[action_idx])
            else:
                signal = "HOLD"
                confidence = 0.0

            signals.append({
                'action': signal,
                'confidence': confidence,
                'timestamp': datetime.now().isoformat()
            })

        return signals
|