Multi-pair inputs (WIP), working training?
@@ -3,12 +3,17 @@ Neural Network Models
 ====================
 
 This package contains the neural network models used in the trading system:
 - CNN Model: Deep convolutional neural network for feature extraction
 - Transformer Model: Processes high-level features for improved pattern recognition
 - MoE: Mixture of Experts model that combines multiple neural networks
 
+PyTorch implementation only.
+
 """
 
-from NN.models.cnn_model import CNNModel
-from NN.models.transformer_model import TransformerModel, TransformerBlock, MixtureOfExpertsModel
+from NN.models.cnn_model_pytorch import CNNModelPyTorch as CNNModel
+from NN.models.transformer_model_pytorch import (
+    TransformerModelPyTorch as TransformerModel,
+    MixtureOfExpertsModelPyTorch as MixtureOfExpertsModel
+)
 
-__all__ = ['CNNModel', 'TransformerModel', 'TransformerBlock', 'MixtureOfExpertsModel']
+__all__ = ['CNNModel', 'TransformerModel', 'MixtureOfExpertsModel']
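The re-exports above keep the public API stable: downstream code keeps importing the old names while getting the PyTorch classes. A minimal sketch (constructor arguments taken from realtime_main.py later in this commit):

    from NN.models import CNNModel, TransformerModel, MixtureOfExpertsModel

    # CNNModel now resolves to CNNModelPyTorch; call sites are unchanged.
    model = CNNModel(window_size=30, num_features=15, output_size=5,
                     timeframes=['5m', '15m', '1h'])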
@@ -227,8 +227,8 @@ def realtime(data_interface, model, args):
     logger.info("Starting real-time mode...")
 
     try:
-        # Skip realtime import for training mode
-        RealtimeAnalyzer = None
+        # Import realtime analyzer
+        from NN.utils.realtime_analyzer import RealtimeAnalyzer
 
         # Load the latest model
         model_dir = os.path.join('models')
NN/realtime_main.py (new file, 285 lines)
@@ -0,0 +1,285 @@
#!/usr/bin/env python3
"""
Neural Network Trading System Main Module

This module serves as the main entry point for the NN trading system,
coordinating data flow between different components and implementing
training and inference pipelines.
"""

import os
import sys
import logging
import argparse
from datetime import datetime

# Create logs directory if it doesn't exist
# (must happen before the FileHandler below tries to open its log file)
os.makedirs('logs', exist_ok=True)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(os.path.join('logs', f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'))
    ]
)

logger = logging.getLogger('NN')

def parse_arguments():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description='Neural Network Trading System')

    parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'], default='train',
                        help='Mode to run (train, predict, realtime)')
    parser.add_argument('--symbol', type=str, default='BTC/USD',
                        help='Main trading pair symbol (default: BTC/USD)')
    parser.add_argument('--context-pairs', type=str, nargs='*', default=[],
                        help='Additional context trading pairs')
    parser.add_argument('--timeframes', type=str, nargs='+', default=['5m', '15m', '1h'],
                        help='Timeframes to use (default: 5m,15m,1h)')
    parser.add_argument('--window-size', type=int, default=30,
                        help='Window size for input data (default: 30)')
    parser.add_argument('--output-size', type=int, default=5,
                        help='Output size (1=up/down, 3=BUY/HOLD/SELL, 5=with extrema)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs for training')
    parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'], default='cnn',
                        help='Model type to use')
    parser.add_argument('--framework', type=str, choices=['tensorflow', 'pytorch'], default='pytorch',
                        help='Deep learning framework to use')

    return parser.parse_args()

def main():
    """Main entry point for the NN trading system"""
    # Parse arguments
    args = parse_arguments()

    logger.info(f"Starting NN Trading System in {args.mode} mode")
    logger.info(f"Main Symbol: {args.symbol}")
    if args.context_pairs:
        logger.info(f"Context Pairs: {args.context_pairs}")
    logger.info(f"Timeframes: {args.timeframes}")
    logger.info(f"Window Size: {args.window_size}")
    logger.info(f"Output Size: {args.output_size} (1=up/down, 3=BUY/HOLD/SELL, 5=with extrema)")
    logger.info(f"Model Type: {args.model_type}")
    logger.info(f"Framework: {args.framework}")

    # Import the appropriate modules based on the framework
    if args.framework == 'pytorch':
        try:
            import torch
            logger.info(f"Using PyTorch {torch.__version__}")

            # Import PyTorch-based modules
            from NN.utils.multi_data_interface import MultiDataInterface

            if args.model_type == 'cnn':
                from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
            elif args.model_type == 'transformer':
                from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
            elif args.model_type == 'moe':
                from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
            else:
                logger.error(f"Unknown model type: {args.model_type}")
                return

        except ImportError as e:
            logger.error(f"Failed to import PyTorch modules: {str(e)}")
            logger.error("Please make sure PyTorch is installed or use the TensorFlow framework.")
            return

    elif args.framework == 'tensorflow':
        try:
            import tensorflow as tf
            logger.info(f"Using TensorFlow {tf.__version__}")

            # Import TensorFlow-based modules
            from NN.utils.multi_data_interface import MultiDataInterface

            if args.model_type == 'cnn':
                from NN.models.cnn_model import CNNModel as Model
            elif args.model_type == 'transformer':
                from NN.models.transformer_model import TransformerModel as Model
            elif args.model_type == 'moe':
                from NN.models.transformer_model import MixtureOfExpertsModel as Model
            else:
                logger.error(f"Unknown model type: {args.model_type}")
                return

        except ImportError as e:
            logger.error(f"Failed to import TensorFlow modules: {str(e)}")
            logger.error("Please make sure TensorFlow is installed or use the PyTorch framework.")
            return
    else:
        logger.error(f"Unknown framework: {args.framework}")
        return

    # Initialize data interface
    try:
        logger.info("Initializing data interface...")
        data_interface = MultiDataInterface(
            symbol=args.symbol,
            timeframes=args.timeframes,
            window_size=args.window_size,
            output_size=args.output_size
        )
    except Exception as e:
        logger.error(f"Failed to initialize data interface: {str(e)}")
        return

    # Initialize model
    try:
        logger.info(f"Initializing {args.model_type.upper()} model...")
        # Calculate actual feature count (OHLCV per timeframe)
        num_features = 5 * len(args.timeframes)
        model = Model(
            window_size=args.window_size,
            num_features=num_features,
            output_size=args.output_size,
            timeframes=args.timeframes
        )
    except Exception as e:
        logger.error(f"Failed to initialize model: {str(e)}")
        return

    # Execute the requested mode
    if args.mode == 'train':
        train(data_interface, model, args)
    elif args.mode == 'predict':
        predict(data_interface, model, args)
    elif args.mode == 'realtime':
        realtime(data_interface, model, args)
    else:
        logger.error(f"Unknown mode: {args.mode}")
        return

    logger.info("Neural Network Trading System finished successfully")

def train(data_interface, model, args):
    """Train the model using the data interface"""
    logger.info("Starting training mode...")

    try:
        # Prepare training data
        logger.info("Preparing training data...")
        X, y, _ = data_interface.prepare_nn_input(
            timeframes=args.timeframes,
            n_candles=1000,
            window_size=args.window_size
        )
        logger.info(f"Training data shape: {X.shape}")
        logger.info(f"Target data shape: {y.shape}")

        # Split into train/validation sets (80/20)
        split_idx = int(len(X) * 0.8)
        X_train, y_train = X[:split_idx], y[:split_idx]
        X_val, y_val = X[split_idx:], y[split_idx:]

        # Train the model
        logger.info("Training model...")
        model.train(
            X_train, y_train,
            X_val, y_val,
            batch_size=args.batch_size,
            epochs=args.epochs
        )

        # Save the model
        model_path = os.path.join(
            'models',
            f"{args.model_type}_{args.symbol.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        )
        logger.info(f"Saving model to {model_path}...")
        model.save(model_path)

        # Evaluate the model
        logger.info("Evaluating model...")
        metrics = model.evaluate(X_val, y_val)
        logger.info(f"Evaluation metrics: {metrics}")

    except Exception as e:
        logger.error(f"Error in training mode: {str(e)}")
        return

def predict(data_interface, model, args):
    """Make predictions using the trained model"""
    logger.info("Starting prediction mode...")

    try:
        # Load the latest model
        model_dir = os.path.join('models')
        model_files = [f for f in os.listdir(model_dir) if f.startswith(args.model_type)]

        if not model_files:
            logger.error(f"No saved model found for type {args.model_type}")
            return

        latest_model = sorted(model_files)[-1]
        model_path = os.path.join(model_dir, latest_model)

        logger.info(f"Loading model from {model_path}...")
        model.load(model_path)

        # Prepare prediction data
        logger.info("Preparing prediction data...")
        X_pred = data_interface.prepare_prediction_data()

        # Make predictions
        logger.info("Making predictions...")
        predictions = model.predict(X_pred)

        # Process and display predictions
        logger.info("Processing predictions...")
        data_interface.process_predictions(predictions)

    except Exception as e:
        logger.error(f"Error in prediction mode: {str(e)}")
        return

def realtime(data_interface, model, args):
    """Run the model in real-time mode"""
    logger.info("Starting real-time mode...")

    try:
        # Import realtime analyzer
        from NN.utils.realtime_analyzer import RealtimeAnalyzer

        # Load the latest model
        model_dir = os.path.join('models')
        model_files = [f for f in os.listdir(model_dir) if f.startswith(args.model_type)]

        if not model_files:
            logger.error(f"No saved model found for type {args.model_type}")
            return

        latest_model = sorted(model_files)[-1]
        model_path = os.path.join(model_dir, latest_model)

        logger.info(f"Loading model from {model_path}...")
        model.load(model_path)

        # Initialize realtime analyzer
        logger.info("Initializing real-time analyzer...")
        realtime_analyzer = RealtimeAnalyzer(
            data_interface=data_interface,
            model=model,
            symbol=args.symbol,
            timeframes=args.timeframes
        )

        # Start real-time analysis
        logger.info("Starting real-time analysis...")
        realtime_analyzer.start()

    except Exception as e:
        logger.error(f"Error in real-time mode: {str(e)}")
        return

if __name__ == "__main__":
    main()
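For reference, a minimal way to drive this entry point programmatically, equivalent to running `python NN/realtime_main.py --mode train ...` (the flags are the ones defined in parse_arguments above):

    import sys
    from NN.realtime_main import main

    sys.argv = ['realtime_main.py', '--mode', 'train', '--symbol', 'BTC/USD',
                '--timeframes', '5m', '15m', '1h', '--model-type', 'cnn',
                '--framework', 'pytorch']
    main()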
@@ -17,16 +17,15 @@ logger = logging.getLogger(__name__)
 
 class DataInterface:
     """
     Handles data collection, processing, and preparation for neural network models.
-
-    This class is responsible for:
-    1. Fetching historical data
-    2. Preprocessing data for neural network input
-    3. Generating training datasets
-    4. Handling real-time data integration
+    Enhanced Data Interface supporting:
+    - Multiple trading pairs (up to 3)
+    - Multiple timeframes per pair (1s, 1m, 1h, 1d + custom)
+    - Technical indicators (up to 20)
+    - Cross-timeframe normalization
+    - Real-time tick streaming
     """
 
-    def __init__(self, symbol="BTC/USDT", timeframes=None, data_dir="NN/data"):
+    def __init__(self, symbol=None, timeframes=None, data_dir="NN/data"):
         """
         Initialize the data interface.
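A sketch of how the reworked interface is consumed elsewhere in this commit (constructor and prepare_nn_input arguments as used in realtime_main.py; the module path assumes data_interface.py sits next to multi_data_interface.py in NN/utils):

    from NN.utils.data_interface import DataInterface

    di = DataInterface(symbol='BTC/USDT', timeframes=['1m', '1h', '1d'])
    X, y, timestamps = di.prepare_nn_input(timeframes=['1m', '1h', '1d'],
                                           n_candles=1000, window_size=30)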
@@ -157,9 +156,9 @@ class DataInterface:
             else:
                 cycle = np.sin(i / 24 * np.pi) * 0.01  # Daily cycle
 
-            # Calculate price change with random walk + cycles
-            price_change = price * (drift + volatility * np.random.randn() + cycle)
-            price += price_change
+            # Calculate price change with random walk + cycles (clamped to prevent overflow)
+            price_change = price * np.clip(drift + volatility * np.random.randn() + cycle, -0.1, 0.1)
+            price = np.clip(price + price_change, 1000, 100000)  # Keep price in reasonable range
 
             # Generate OHLC from the price
             open_price = price
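A standalone sketch of why the clamping matters (the drift and volatility values here are illustrative, not the class's actual parameters): an unclamped multiplicative random walk can blow up over long simulations, while the clipped version stays inside a sane band.

    import numpy as np

    rng = np.random.default_rng(0)
    price = 50000.0
    for _ in range(10_000):
        # Per-step return clamped to +/-10%, mirroring the np.clip above
        step = np.clip(0.0001 + 0.01 * rng.standard_normal(), -0.1, 0.1)
        price = float(np.clip(price * (1 + step), 1000, 100000))
    print(price)  # always within [1000, 100000]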
@@ -171,8 +170,8 @@ class DataInterface:
             high_price = max(high_price, open_price, close_price)
             low_price = min(low_price, open_price, close_price)
 
-            # Generate volume (higher for larger price movements)
-            volume = abs(price_change) * (10000 + 5000 * np.random.rand())
+            # Generate volume (higher for larger price movements) with safe calculation
+            volume = 10000 + 5000 * np.random.rand() + abs(price_change) / price * 10000
 
             prices.append((open_price, high_price, low_price, close_price))
             volumes.append(volume)
@@ -217,19 +216,41 @@ class DataInterface:
             logger.error("No data available for feature creation")
             return None, None, None
 
-        # For simplicity, we'll use just one timeframe for now
-        # In a more complex implementation, we would merge multiple timeframes
-        primary_tf = timeframes[0]
-        if primary_tf not in dfs:
-            logger.error(f"Primary timeframe {primary_tf} not available")
+        # Create features for each timeframe
+        features = []
+        targets = []
+        timestamps = []
+
+        for tf in timeframes:
+            if tf in dfs:
+                X, y, ts = self._create_features(dfs[tf], window_size)
+                features.append(X)
+                if len(targets) == 0:  # Only need targets from one timeframe
+                    targets = y
+                    timestamps = ts
+
+        if not features:
             return None, None, None
 
-        df = dfs[primary_tf]
+        # Stack features from all timeframes along the time dimension
+        # Reshape each timeframe's features to [samples, window, 1, features]
+        reshaped_features = [f.reshape(f.shape[0], f.shape[1], 1, f.shape[2])
+                             for f in features]
+        # Concatenate along the channel dimension
+        X = np.concatenate(reshaped_features, axis=2)
+        # Reshape to [samples, window, features*timeframes]
+        X = X.reshape(X.shape[0], X.shape[1], -1)
 
-        # Create features
-        X, y, timestamps = self._create_features(df, window_size)
+        # Validate data
+        if np.any(np.isnan(X)) or np.any(np.isinf(X)):
+            logger.error("Generated features contain NaN or infinite values")
+            return None, None, None
 
-        return X, y, timestamps
+        # Ensure all values are finite and normalized
+        X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=-1.0)
+        X = np.clip(X, -1e6, 1e6)  # Clip extreme values
+
+        return X, targets, timestamps
 
     def _create_features(self, df, window_size):
         """
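The stacking above flattens per-timeframe feature blocks into a single channel axis. A small sketch of the shape arithmetic for 3 timeframes, a window of 30, and 5 OHLCV features:

    import numpy as np

    # Each timeframe yields [samples, window, features] = (100, 30, 5)
    features = [np.random.rand(100, 30, 5) for _ in range(3)]
    reshaped = [f.reshape(f.shape[0], f.shape[1], 1, f.shape[2]) for f in features]
    X = np.concatenate(reshaped, axis=2)       # (100, 30, 3, 5)
    X = X.reshape(X.shape[0], X.shape[1], -1)  # (100, 30, 15) = features * timeframes
    assert X.shape == (100, 30, 15)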
@@ -248,9 +269,28 @@ class DataInterface:
         # Extract OHLCV columns
         ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values
 
-        # Scale the data
-        scaler = MinMaxScaler()
-        ohlcv_scaled = scaler.fit_transform(ohlcv)
+        # Validate data before scaling
+        if np.any(np.isnan(ohlcv)) or np.any(np.isinf(ohlcv)):
+            logger.error("Input data contains NaN or infinite values")
+            return None, None, None
+
+        # Handle potential constant columns (avoid division by zero in scaler)
+        ohlcv = np.nan_to_num(ohlcv, nan=0.0)
+        ranges = np.ptp(ohlcv, axis=0)
+        for i in range(len(ranges)):
+            if ranges[i] == 0:  # Constant column
+                ohlcv[:, i] = 1 if i == 3 else 0  # Set close to 1, others to 0
+
+        # Scale the data with safety checks
+        try:
+            scaler = MinMaxScaler()
+            ohlcv_scaled = scaler.fit_transform(ohlcv)
+            if np.any(np.isnan(ohlcv_scaled)) or np.any(np.isinf(ohlcv_scaled)):
+                logger.error("Scaling produced invalid values")
+                return None, None, None
+        except Exception as e:
+            logger.error(f"Scaling failed: {str(e)}")
+            return None, None, None
 
         # Store the scaler for later use
         timeframe = next((tf for tf in self.timeframes if self.dataframes.get(tf) is not None and
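For context (a standalone sketch, not part of the commit): sklearn's MinMaxScaler maps a zero-range column to all zeros rather than NaN, so the guard above is less about avoiding a crash and more about replacing flat columns with explicit sentinel values before scaling.

    import numpy as np
    from sklearn.preprocessing import MinMaxScaler

    ohlcv = np.array([[100.0, 101.0, 99.0, 100.0, 0.0],
                      [102.0, 103.0, 101.0, 102.0, 0.0]])  # constant volume column

    scaled = MinMaxScaler().fit_transform(ohlcv)
    print(scaled[:, 4])  # [0. 0.] -- a flat column carries no signal after scaling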
@@ -343,6 +383,11 @@ class DataInterface:
         logger.info(f"Dataset generated and saved: {dataset_name}")
         return dataset_info
 
+    def get_feature_count(self):
+        """Get the number of features per input sample"""
+        # OHLCV (5 features) per timeframe
+        return 5 * len(self.timeframes)
+
     def prepare_realtime_input(self, timeframe='1h', n_candles=30, window_size=20):
         """
         Prepare a single input sample from the most recent data for real-time inference.
@@ -387,4 +432,4 @@ class DataInterface:
         # Get timestamp of the most recent candle
         timestamp = df['timestamp'].iloc[-1]
 
-        return X, timestamp
\ No newline at end of file
+        return X, timestamp
NN/utils/multi_data_interface.py (new file, 123 lines)
@@ -0,0 +1,123 @@
"""
Enhanced Data Interface with additional NN trading parameters
"""
from typing import List, Optional, Tuple
import numpy as np
import pandas as pd
from datetime import datetime
from .data_interface import DataInterface

class MultiDataInterface(DataInterface):
    """
    Enhanced data interface that supports window_size and output_size parameters
    for neural network trading models.
    """

    def __init__(self, symbol: str,
                 timeframes: List[str],
                 window_size: int = 20,
                 output_size: int = 3,
                 data_dir: str = "NN/data"):
        """
        Initialize with window_size and output_size for NN predictions.
        """
        super().__init__(symbol, timeframes, data_dir)
        self.window_size = window_size
        self.output_size = output_size
        self.scalers = {}  # Store scalers for each timeframe
        self.min_window_threshold = 100  # Minimum candles needed for training

    def get_feature_count(self) -> int:
        """
        Get number of features (OHLCV) for NN input.
        """
        return 5  # open, high, low, close, volume

    def prepare_training_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Prepare training data with windowed sequences"""
        # Get historical data for primary timeframe
        primary_tf = self.timeframes[0]
        df = self.get_historical_data(timeframe=primary_tf,
                                      n_candles=self.min_window_threshold + 1000)

        if df is None or len(df) < self.min_window_threshold:
            raise ValueError(f"Insufficient data for training. Need at least {self.min_window_threshold} candles")

        # Prepare OHLCV sequences
        ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values

        # Create sequences and labels
        X = []
        y = []

        for i in range(len(ohlcv) - self.window_size - self.output_size):
            # Input sequence
            seq = ohlcv[i:i+self.window_size]
            X.append(seq)

            # Output target (price movement direction)
            close_prices = ohlcv[i+self.window_size:i+self.window_size+self.output_size, 3]  # Close prices
            price_changes = np.diff(close_prices)

            if self.output_size == 1:
                # Binary classification (up/down): compare the next close with
                # the last close of the input window (np.diff over a single
                # future close would be empty and crash on indexing)
                label = 1 if close_prices[0] > seq[-1, 3] else 0
            elif self.output_size == 3:
                # 3-class classification (buy/hold/sell), thresholded on the
                # fractional move so 0.002 means 0.2%, not 0.002 price units
                pct_change = price_changes[0] / close_prices[0]
                if pct_change > 0.002:  # Significant rise
                    label = 0  # Buy
                elif pct_change < -0.002:  # Significant drop
                    label = 2  # Sell
                else:
                    label = 1  # Hold
            else:
                raise ValueError(f"Unsupported output_size: {self.output_size}")

            y.append(label)

        # Convert to numpy arrays
        X = np.array(X)
        y = np.array(y)

        # Split into train/validation (80/20)
        split_idx = int(0.8 * len(X))
        X_train, y_train = X[:split_idx], y[:split_idx]
        X_val, y_val = X[split_idx:], y[split_idx:]

        return X_train, y_train, X_val, y_val

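To make the 3-class labeling rule concrete, a standalone sketch with illustrative prices (label ids follow the mapping above: 0=BUY, 1=HOLD, 2=SELL):

    import numpy as np

    close_prices = np.array([100.0, 100.5, 100.6])   # next output_size=3 closes
    price_changes = np.diff(close_prices)            # [0.5, 0.1]
    pct_change = price_changes[0] / close_prices[0]  # 0.005 -> a 0.5% move

    label = 0 if pct_change > 0.002 else (2 if pct_change < -0.002 else 1)
    print(label)  # 0 (BUY): 0.5% exceeds the 0.2% threshold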
    def prepare_prediction_data(self) -> np.ndarray:
        """Prepare most recent window for predictions"""
        primary_tf = self.timeframes[0]
        df = self.get_historical_data(timeframe=primary_tf,
                                      n_candles=self.window_size,
                                      use_cache=False)

        if df is None or len(df) < self.window_size:
            raise ValueError(f"Need at least {self.window_size} candles for prediction")

        ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values[-self.window_size:]
        return np.array([ohlcv])  # Add batch dimension

    def process_predictions(self, predictions: np.ndarray):
        """Convert prediction probabilities to trading signals"""
        signals = []
        for pred in predictions:
            if self.output_size == 1:
                signal = "BUY" if pred[0] > 0.5 else "SELL"
                confidence = np.abs(pred[0] - 0.5) * 2  # Convert to 0-1 scale
            elif self.output_size == 3:
                action_idx = np.argmax(pred)
                signal = ["BUY", "HOLD", "SELL"][action_idx]
                confidence = pred[action_idx]
            else:
                signal = "HOLD"
                confidence = 0.0

            signals.append({
                'action': signal,
                'confidence': confidence,
                'timestamp': datetime.now().isoformat()
            })

        return signals
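A quick sketch of the signal arithmetic in process_predictions for a 3-class output (values illustrative):

    import numpy as np

    pred = np.array([0.7, 0.2, 0.1])  # softmax over [BUY, HOLD, SELL]
    action_idx = int(np.argmax(pred))
    signal = ["BUY", "HOLD", "SELL"][action_idx]
    confidence = float(pred[action_idx])
    print(signal, confidence)  # BUY 0.7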
NN/utils/realtime_analyzer.py (new file, 182 lines)
@@ -0,0 +1,182 @@
"""
Realtime Analyzer for Neural Network Trading System

This module implements real-time analysis of market data using trained neural network models.
"""

import logging
import time
import numpy as np
from threading import Thread
from queue import Queue
from datetime import datetime

logger = logging.getLogger(__name__)

class RealtimeAnalyzer:
    """
    Handles real-time analysis of market data using trained neural network models.

    Features:
    - Connects to real-time data sources (websockets)
    - Processes incoming data through the neural network
    - Generates trading signals
    - Manages risk and position sizing
    - Logs all trading decisions
    """

    def __init__(self, data_interface, model, symbol="BTC/USDT", timeframes=None):
        """
        Initialize the realtime analyzer.

        Args:
            data_interface (DataInterface): Preconfigured data interface
            model: Trained neural network model
            symbol (str): Trading pair symbol
            timeframes (list): List of timeframes to monitor
        """
        self.data_interface = data_interface
        self.model = model
        self.symbol = symbol
        self.timeframes = timeframes or ['1h']
        self.running = False
        self.data_queue = Queue()
        self.prediction_interval = 60  # Seconds between predictions

        logger.info(f"RealtimeAnalyzer initialized for {symbol}")

    def start(self):
        """Start the realtime analysis process."""
        if self.running:
            logger.warning("Realtime analyzer already running")
            return

        self.running = True

        # Start data collection thread
        self.data_thread = Thread(target=self._collect_data, daemon=True)
        self.data_thread.start()

        # Start analysis thread
        self.analysis_thread = Thread(target=self._analyze_data, daemon=True)
        self.analysis_thread.start()

        logger.info("Realtime analysis started")

    def stop(self):
        """Stop the realtime analysis process."""
        self.running = False
        if hasattr(self, 'data_thread'):
            self.data_thread.join(timeout=1)
        if hasattr(self, 'analysis_thread'):
            self.analysis_thread.join(timeout=1)
        logger.info("Realtime analysis stopped")

    def _collect_data(self):
        """Thread function for collecting real-time data."""
        logger.info("Starting data collection thread")

        # In a real implementation, this would connect to websockets/API
        # For now, we'll simulate data collection from the data interface
        while self.running:
            try:
                # Get latest data for each timeframe
                for timeframe in self.timeframes:
                    # Get recent data (simulating real-time updates)
                    X, timestamp = self.data_interface.prepare_realtime_input(
                        timeframe=timeframe,
                        n_candles=30,
                        window_size=self.data_interface.window_size
                    )

                    if X is not None:
                        self.data_queue.put({
                            'timeframe': timeframe,
                            'data': X,
                            'timestamp': timestamp
                        })

                # Throttle data collection
                time.sleep(1)

            except Exception as e:
                logger.error(f"Error in data collection: {str(e)}")
                time.sleep(5)  # Wait before retrying

    def _analyze_data(self):
        """Thread function for analyzing data and generating signals."""
        logger.info("Starting analysis thread")

        last_prediction_time = 0

        while self.running:
            try:
                current_time = time.time()

                # Only make predictions at the specified interval
                if current_time - last_prediction_time < self.prediction_interval:
                    time.sleep(0.1)
                    continue

                # Get latest data from queue
                if not self.data_queue.empty():
                    data_item = self.data_queue.get()

                    # Make prediction
                    prediction = self.model.predict(data_item['data'])

                    # Process prediction
                    self._process_prediction(
                        prediction=prediction,
                        timeframe=data_item['timeframe'],
                        timestamp=data_item['timestamp']
                    )

                    last_prediction_time = current_time

                time.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in analysis: {str(e)}")
                time.sleep(1)  # Wait before retrying

    def _process_prediction(self, prediction, timeframe, timestamp):
        """
        Process model prediction and generate trading signals.

        Args:
            prediction: Model prediction output
            timeframe (str): Timeframe the prediction is for
            timestamp: Timestamp of the prediction
        """
        # Convert prediction to trading signal
        signal = self._prediction_to_signal(prediction)

        # Log the signal
        logger.info(
            f"Signal generated - Timeframe: {timeframe}, "
            f"Timestamp: {timestamp}, "
            f"Signal: {signal}"
        )

        # In a real implementation, we would execute trades here
        # For now, we'll just log the signals

    def _prediction_to_signal(self, prediction):
        """
        Convert model prediction to trading signal.

        Args:
            prediction: Model prediction output

        Returns:
            str: Trading signal (BUY, SELL, HOLD)
        """
        # Simple threshold-based signal generation
        if len(prediction.shape) == 1:
            # Binary classification
            return "BUY" if prediction[0] > 0.5 else "SELL"
        else:
            # Multi-class classification (3 outputs); index order matches the
            # training labels in MultiDataInterface (0=BUY, 1=HOLD, 2=SELL)
            class_idx = np.argmax(prediction)
            return ["BUY", "HOLD", "SELL"][class_idx]
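A minimal sketch of driving the analyzer directly; data_interface and model are placeholders for objects built as in realtime_main.py above:

    import time
    from NN.utils.realtime_analyzer import RealtimeAnalyzer

    analyzer = RealtimeAnalyzer(
        data_interface=data_interface,  # e.g. a MultiDataInterface instance
        model=model,                    # trained model exposing .predict()
        symbol='BTC/USDT',
        timeframes=['1m', '1h'],
    )
    analyzer.start()   # spawns the data-collection and analysis threads
    time.sleep(300)    # let it emit signals for five minutes
    analyzer.stop()    # joins both threads (1s timeout each)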