"""
|
|
Realtime Analyzer for Neural Network Trading System
|
|
|
|
This module implements real-time analysis of market data using trained neural network models.
|
|
"""

import asyncio
import json
import logging
import time
from collections import deque
from datetime import datetime
from queue import Queue
from threading import Thread

import numpy as np
import websockets

logger = logging.getLogger(__name__)


class RealtimeAnalyzer:
    """
    Handles real-time analysis of market data using trained neural network models.

    Features:
    - Connects to a real-time data source (Binance trade WebSocket)
    - Processes tick data into multiple timeframes (1s, 1m, 1h, 1d)
    - Uses trained models to analyze all timeframes
    - Generates trading signals
    - Logs all trading decisions (trade execution, risk management, and
      position sizing are not implemented here; signals are only logged)
    """

    def __init__(self, data_interface, model, symbol="BTC/USDT", timeframes=None):
        """
        Initialize the realtime analyzer.

        Args:
            data_interface (DataInterface): Preconfigured data interface
            model: Trained neural network model
            symbol (str): Trading pair symbol
            timeframes (list): List of timeframes to monitor (default: ['1s', '1m', '1h', '1d'])
        """
        self.data_interface = data_interface
        self.model = model
        self.symbol = symbol
        self.timeframes = timeframes or ['1s', '1m', '1h', '1d']
        self.running = False
        self.data_queue = Queue()
        self.prediction_interval = 10  # Seconds between predictions
        # Binance stream name, e.g. "BTC/USDT" -> "btcusdt@trade"
        self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
        self.ws = None
        self.ws_loop = None  # Event loop that owns the WebSocket connection
        self.tick_storage = deque(maxlen=10000)  # Store up to 10,000 ticks
        # One candle buffer per configured timeframe, so custom timeframe
        # lists do not raise KeyError during processing.
        self.candle_cache = {tf: deque(maxlen=5000) for tf in self.timeframes}

        logger.info(f"RealtimeAnalyzer initialized for {symbol} with timeframes: {self.timeframes}")

    def start(self):
        """Start the realtime analysis process."""
        if self.running:
            logger.warning("Realtime analyzer already running")
            return

        self.running = True

        # Start WebSocket connection thread
        self.ws_thread = Thread(target=self._run_websocket, daemon=True)
        self.ws_thread.start()

        # Start data processing thread
        self.processing_thread = Thread(target=self._process_data, daemon=True)
        self.processing_thread.start()

        # Start analysis thread
        self.analysis_thread = Thread(target=self._analyze_data, daemon=True)
        self.analysis_thread.start()

        logger.info("Realtime analysis started")
    def stop(self):
        """Stop the realtime analysis process."""
        self.running = False
        if self.ws and self.ws_loop:
            # The WebSocket belongs to the connection thread's event loop, so
            # schedule the close on that loop instead of calling asyncio.run()
            # from this thread, which would target the wrong loop.
            asyncio.run_coroutine_threadsafe(self.ws.close(), self.ws_loop)
        if hasattr(self, 'ws_thread'):
            self.ws_thread.join(timeout=1)
        if hasattr(self, 'processing_thread'):
            self.processing_thread.join(timeout=1)
        if hasattr(self, 'analysis_thread'):
            self.analysis_thread.join(timeout=1)
        logger.info("Realtime analysis stopped")
    def _run_websocket(self):
        """Thread function for running the WebSocket connection."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self.ws_loop = loop  # Expose the loop so stop() can close the socket
        loop.run_until_complete(self._connect_websocket())
    async def _connect_websocket(self):
        """Connect to the WebSocket and receive trade data."""
        while self.running:
            try:
                logger.info(f"Connecting to WebSocket: {self.ws_url}")
                async with websockets.connect(self.ws_url) as ws:
                    self.ws = ws
                    logger.info("WebSocket connected")

                    while self.running:
                        try:
                            message = await ws.recv()
                            data = json.loads(message)

                            # Binance trade events carry event type 'e' == 'trade'
                            if data.get('e') == 'trade':
                                tick = {
                                    'timestamp': data['T'],  # Trade time in ms
                                    'price': float(data['p']),
                                    'volume': float(data['q']),
                                    'symbol': self.symbol
                                }
                                self.tick_storage.append(tick)
                                self.data_queue.put(tick)

                        except websockets.exceptions.ConnectionClosed:
                            logger.warning("WebSocket connection closed")
                            break
                        except Exception as e:
                            logger.error(f"Error receiving WebSocket message: {str(e)}")
                            # Use asyncio.sleep inside a coroutine; time.sleep
                            # would block the whole event loop.
                            await asyncio.sleep(1)

            except Exception as e:
                logger.error(f"WebSocket connection error: {str(e)}")
                await asyncio.sleep(5)  # Wait before reconnecting
    def _process_data(self):
        """Process incoming tick data into candles for all timeframes."""
        logger.info("Starting data processing thread")

        while self.running:
            try:
                # Process any new ticks
                while not self.data_queue.empty():
                    tick = self.data_queue.get()

                    # Process the tick for each configured timeframe
                    for timeframe in self.timeframes:
                        interval = self._get_interval_seconds(timeframe)
                        if interval is None:
                            continue

                        # Bucket the tick timestamp (ms) to the start of its
                        # candle interval
                        candle_ts = int(tick['timestamp'] // (interval * 1000)) * (interval * 1000)
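                        # Worked example (illustrative numbers): a tick at
                        # t = 1,717,000,123,456 ms with interval = 60 s buckets
                        # to 1,717,000,123,456 // 60_000 * 60_000
                        # = 1,717,000,080,000 ms, i.e. the top of that minute.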

                        # Get or create candle for this timeframe
                        if not self.candle_cache[timeframe]:
                            # First candle for this timeframe
                            candle = {
                                'timestamp': candle_ts,
                                'open': tick['price'],
                                'high': tick['price'],
                                'low': tick['price'],
                                'close': tick['price'],
                                'volume': tick['volume']
                            }
                            self.candle_cache[timeframe].append(candle)
                        else:
                            # Update existing candle
                            last_candle = self.candle_cache[timeframe][-1]

                            if last_candle['timestamp'] == candle_ts:
                                # Update current candle
                                last_candle['high'] = max(last_candle['high'], tick['price'])
                                last_candle['low'] = min(last_candle['low'], tick['price'])
                                last_candle['close'] = tick['price']
                                last_candle['volume'] += tick['volume']
                            else:
                                # New candle
                                candle = {
                                    'timestamp': candle_ts,
                                    'open': tick['price'],
                                    'high': tick['price'],
                                    'low': tick['price'],
                                    'close': tick['price'],
                                    'volume': tick['volume']
                                }
                                self.candle_cache[timeframe].append(candle)

                time.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in data processing: {str(e)}")
                time.sleep(1)
    def _get_interval_seconds(self, timeframe):
        """Convert timeframe string to seconds."""
        intervals = {
            '1s': 1,
            '1m': 60,
            '1h': 3600,
            '1d': 86400
        }
        return intervals.get(timeframe)
    def _analyze_data(self):
        """Thread function for analyzing data and generating signals."""
        logger.info("Starting analysis thread")

        last_prediction_time = 0

        while self.running:
            try:
                current_time = time.time()

                # Only make predictions at the specified interval
                if current_time - last_prediction_time < self.prediction_interval:
                    time.sleep(0.1)
                    continue

                # Prepare input data from all timeframes
                input_data = {}
                valid = True

                for timeframe in self.timeframes:
                    if not self.candle_cache[timeframe]:
                        logger.warning(f"No data available for timeframe {timeframe}")
                        valid = False
                        break

                    # Get the last window_size candles for this timeframe
                    candles = list(self.candle_cache[timeframe])[-self.data_interface.window_size:]
                    if len(candles) < self.data_interface.window_size:
                        # Not enough history yet to fill the model window
                        valid = False
                        break

                    # Convert to numpy array of shape (window_size, 5)
                    ohlcv = np.array([
                        [c['open'], c['high'], c['low'], c['close'], c['volume']]
                        for c in candles
                    ])

                    # Normalize data (per-window z-score)
                    ohlcv_normalized = (ohlcv - ohlcv.mean(axis=0)) / (ohlcv.std(axis=0) + 1e-8)
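                    # Note (assumption, not verified here): this z-score uses
                    # statistics of the current window only, so it presumes the
                    # model was trained on windows normalized the same way;
                    # otherwise live feature scales will not match training.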
                    input_data[timeframe] = ohlcv_normalized

                if not valid:
                    time.sleep(0.1)
                    continue

                # Make prediction using the model
                try:
                    prediction = self.model.predict(input_data)

                    # Take the latest candle timestamp from the first
                    # configured timeframe ('1s' may not be monitored)
                    base_tf = self.timeframes[0]
                    if self.candle_cache[base_tf]:
                        latest_ts = self.candle_cache[base_tf][-1]['timestamp']
                    else:
                        latest_ts = int(time.time() * 1000)

                    # Process prediction
                    self._process_prediction(
                        prediction=prediction,
                        timeframe='multi',
                        timestamp=latest_ts
                    )

                    last_prediction_time = current_time
                except Exception as e:
                    logger.error(f"Error making prediction: {str(e)}")

                time.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in analysis: {str(e)}")
                time.sleep(1)
    def _process_prediction(self, prediction, timeframe, timestamp):
        """
        Process model prediction and generate trading signals.

        Args:
            prediction: Model prediction output
            timeframe (str): Timeframe the prediction is for ('multi' for combined)
            timestamp: Timestamp of the prediction (ms)
        """
        # Convert prediction to trading signal
        signal, confidence = self._prediction_to_signal(prediction)

        # Convert timestamp to datetime, falling back to now() on bad input
        try:
            dt = datetime.fromtimestamp(timestamp / 1000)
        except (TypeError, ValueError, OSError, OverflowError):
            dt = datetime.now()

        # Log the signal with all timeframes
        logger.info(
            f"Signal generated - Timeframes: {', '.join(self.timeframes)}, "
            f"Timestamp: {dt}, "
            f"Signal: {signal} (Confidence: {confidence:.2f})"
        )

        # In a real implementation, trades would be executed here.
        # For now, signals are only logged.
    def _prediction_to_signal(self, prediction):
        """
        Convert model prediction to trading signal and confidence.

        Args:
            prediction: Model prediction output (can be dict for multi-timeframe)

        Returns:
            tuple: (signal, confidence) where signal is BUY/SELL/HOLD and
                confidence is a probability in [0, 1]
        """
        if isinstance(prediction, dict):
            # Multi-timeframe prediction - combine signals
            signals = []
            confidences = []

            for tf, pred in prediction.items():
                # Flatten (1, n) outputs to 1-D so the indexing below is safe
                pred = np.asarray(pred).ravel()
                if pred.size == 1:
                    # Binary classification: a single BUY probability
                    signal = "BUY" if pred[0] > 0.5 else "SELL"
                    confidence = pred[0] if signal == "BUY" else 1 - pred[0]
                else:
                    # Multi-class: probabilities ordered [SELL, HOLD, BUY]
                    class_idx = int(np.argmax(pred))
                    signal = ["SELL", "HOLD", "BUY"][class_idx]
                    confidence = pred[class_idx]

                signals.append(signal)
                confidences.append(confidence)

            # Simple voting system - count BUY/SELL signals
            buy_count = signals.count("BUY")
            sell_count = signals.count("SELL")

            if buy_count > sell_count:
                final_signal = "BUY"
                final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "BUY"])
            elif sell_count > buy_count:
                final_signal = "SELL"
                final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "SELL"])
            else:
                final_signal = "HOLD"
                final_confidence = np.mean(confidences)

            return final_signal, final_confidence

        else:
            # Single prediction; flatten (1, n) outputs to 1-D first
            pred = np.asarray(prediction).ravel()
            if pred.size == 1:
                # Binary classification: a single BUY probability
                signal = "BUY" if pred[0] > 0.5 else "SELL"
                confidence = pred[0] if signal == "BUY" else 1 - pred[0]
            else:
                # Multi-class: probabilities ordered [SELL, HOLD, BUY]
                class_idx = int(np.argmax(pred))
                signal = ["SELL", "HOLD", "BUY"][class_idx]
                confidence = pred[class_idx]

            return signal, confidence
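

# Minimal usage sketch (assumptions: the surrounding project provides a
# DataInterface with a window_size attribute and a model exposing
# predict(dict_of_timeframe_arrays); the stand-ins below are placeholders,
# not verified project APIs).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    class _DummyDataInterface:
        """Stand-in data interface providing only the model window size."""
        window_size = 30

    class _DummyModel:
        """Stand-in model returning a fixed BUY probability per timeframe."""
        def predict(self, input_data):
            return {tf: np.array([0.6]) for tf in input_data}

    analyzer = RealtimeAnalyzer(
        data_interface=_DummyDataInterface(),
        model=_DummyModel(),
        symbol="BTC/USDT",
        timeframes=['1s', '1m'],
    )
    analyzer.start()
    try:
        # Predictions begin once every timeframe has window_size candles;
        # with '1m' monitored that takes about 30 minutes of live data.
        time.sleep(60)
    finally:
        analyzer.stop()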