working training
This commit is contained in:
parent
0b2000e3e7
commit
2255a8363a
@ -24,60 +24,67 @@ logger = logging.getLogger(__name__)
|
|||||||
class CNNPyTorch(nn.Module):
|
class CNNPyTorch(nn.Module):
|
||||||
"""PyTorch CNN model for time series analysis"""
|
"""PyTorch CNN model for time series analysis"""
|
||||||
|
|
||||||
def __init__(self, input_shape, output_size=3):
|
def __init__(self, input_shape, output_size=5):
|
||||||
"""
|
"""
|
||||||
Initialize the CNN model.
|
Initialize the enhanced CNN model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_shape (tuple): Shape of input data (window_size, features)
|
input_shape (tuple): Shape of input data (window_size, features)
|
||||||
output_size (int): Size of output (1 for regression, 3 for classification)
|
output_size (int): Always 5 for our trading signals
|
||||||
"""
|
"""
|
||||||
super(CNNPyTorch, self).__init__()
|
super(CNNPyTorch, self).__init__()
|
||||||
|
|
||||||
window_size, num_features = input_shape
|
window_size, num_features = input_shape
|
||||||
|
kernel_size = 5
|
||||||
# Architecture parameters
|
|
||||||
filters = [32, 64, 128]
|
|
||||||
kernel_sizes = [3, 5, 7]
|
|
||||||
lstm_units = 100
|
|
||||||
dense_units = 64
|
|
||||||
dropout_rate = 0.3
|
dropout_rate = 0.3
|
||||||
|
|
||||||
# Create parallel convolutional pathways
|
# Enhanced CNN Architecture
|
||||||
self.conv_paths = nn.ModuleList()
|
self.conv_layers = nn.Sequential(
|
||||||
|
# Block 1
|
||||||
for f, k in zip(filters, kernel_sizes):
|
nn.Conv1d(num_features, 64, kernel_size, padding='same'),
|
||||||
path = nn.Sequential(
|
nn.BatchNorm1d(64),
|
||||||
nn.Conv1d(num_features, f, kernel_size=k, padding='same'),
|
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.BatchNorm1d(f),
|
|
||||||
nn.MaxPool1d(kernel_size=2, stride=1, padding=1),
|
|
||||||
nn.Dropout(dropout_rate)
|
|
||||||
)
|
|
||||||
self.conv_paths.append(path)
|
|
||||||
|
|
||||||
# Calculate output size from conv paths
|
# Block 2
|
||||||
conv_output_size = sum(filters) * window_size
|
nn.Conv1d(64, 128, kernel_size, padding='same'),
|
||||||
|
nn.BatchNorm1d(128),
|
||||||
# LSTM layer
|
|
||||||
self.lstm = nn.LSTM(
|
|
||||||
input_size=sum(filters),
|
|
||||||
hidden_size=lstm_units,
|
|
||||||
batch_first=True,
|
|
||||||
bidirectional=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Dense layers
|
|
||||||
self.flatten = nn.Flatten()
|
|
||||||
self.dense1 = nn.Sequential(
|
|
||||||
nn.Linear(lstm_units * 2 * window_size, dense_units),
|
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.BatchNorm1d(dense_units),
|
nn.MaxPool1d(2),
|
||||||
nn.Dropout(dropout_rate)
|
|
||||||
|
# Block 3
|
||||||
|
nn.Conv1d(128, 256, kernel_size, padding='same'),
|
||||||
|
nn.BatchNorm1d(256),
|
||||||
|
nn.ReLU(),
|
||||||
|
|
||||||
|
# Block 4
|
||||||
|
nn.Conv1d(256, 512, kernel_size, padding='same'),
|
||||||
|
nn.BatchNorm1d(512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.MaxPool1d(2)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Output layer
|
# Calculate flattened size after conv and pooling
|
||||||
self.output = nn.Linear(dense_units, output_size)
|
conv_output_size = 512 * (window_size // 4)
|
||||||
|
|
||||||
|
# Enhanced dense layers
|
||||||
|
self.dense_block = nn.Sequential(
|
||||||
|
nn.Flatten(),
|
||||||
|
nn.Linear(conv_output_size, 512),
|
||||||
|
nn.BatchNorm1d(512),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(dropout_rate),
|
||||||
|
|
||||||
|
nn.Linear(512, 256),
|
||||||
|
nn.BatchNorm1d(256),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(dropout_rate),
|
||||||
|
|
||||||
|
nn.Linear(256, 128),
|
||||||
|
nn.BatchNorm1d(128),
|
||||||
|
nn.ReLU(),
|
||||||
|
|
||||||
|
nn.Linear(128, output_size)
|
||||||
|
)
|
||||||
|
|
||||||
# Activation based on output size
|
# Activation based on output size
|
||||||
if output_size == 1:
|
if output_size == 1:
|
||||||
@ -89,7 +96,7 @@ class CNNPyTorch(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Forward pass through the network.
|
Forward pass through enhanced network.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Input tensor of shape [batch_size, window_size, features]
|
x: Input tensor of shape [batch_size, window_size, features]
|
||||||
@ -97,35 +104,15 @@ class CNNPyTorch(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Output tensor of shape [batch_size, output_size]
|
Output tensor of shape [batch_size, output_size]
|
||||||
"""
|
"""
|
||||||
batch_size, window_size, num_features = x.shape
|
|
||||||
|
|
||||||
# Transpose for conv1d: [batch, features, window]
|
# Transpose for conv1d: [batch, features, window]
|
||||||
x_t = x.transpose(1, 2)
|
x_t = x.transpose(1, 2)
|
||||||
|
|
||||||
# Process through parallel conv paths
|
# Process through all CNN layers
|
||||||
conv_outputs = []
|
conv_out = self.conv_layers(x_t)
|
||||||
for path in self.conv_paths:
|
|
||||||
conv_outputs.append(path(x_t))
|
|
||||||
|
|
||||||
# Concatenate conv outputs
|
# Process through dense layers
|
||||||
conv_concat = torch.cat(conv_outputs, dim=1)
|
output = self.dense_block(conv_out)
|
||||||
|
|
||||||
# Transpose back for LSTM: [batch, window, features]
|
|
||||||
conv_concat = conv_concat.transpose(1, 2)
|
|
||||||
|
|
||||||
# LSTM processing
|
|
||||||
lstm_out, _ = self.lstm(conv_concat)
|
|
||||||
|
|
||||||
# Flatten
|
|
||||||
flattened = self.flatten(lstm_out)
|
|
||||||
|
|
||||||
# Dense processing
|
|
||||||
dense_out = self.dense1(flattened)
|
|
||||||
|
|
||||||
# Output
|
|
||||||
output = self.output(dense_out)
|
|
||||||
|
|
||||||
# Apply activation
|
|
||||||
return self.activation(output)
|
return self.activation(output)
|
||||||
|
|
||||||
|
|
||||||
@ -137,7 +124,7 @@ class CNNModelPyTorch:
|
|||||||
predictions with the CNN model.
|
predictions with the CNN model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, window_size, num_features, output_size=3, timeframes=None):
|
def __init__(self, window_size, num_features, output_size=5, timeframes=None):
|
||||||
"""
|
"""
|
||||||
Initialize the CNN model.
|
Initialize the CNN model.
|
||||||
|
|
||||||
@ -506,41 +493,27 @@ class CNNModelPyTorch:
|
|||||||
|
|
||||||
def extract_hidden_features(self, X):
|
def extract_hidden_features(self, X):
|
||||||
"""
|
"""
|
||||||
Extract hidden features from the model.
|
Extract hidden features from the model - outputs from last dense layer before output.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
X: Input data
|
X: Input data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Hidden features
|
Hidden features (output from penultimate dense layer)
|
||||||
"""
|
"""
|
||||||
# Convert to PyTorch tensor
|
# Convert to PyTorch tensor
|
||||||
X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
|
X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
|
||||||
|
|
||||||
# Forward pass through the model up to the last hidden layer
|
# Forward pass through the model
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Get features before the output layer
|
# Get features through CNN layers
|
||||||
x_t = X_tensor.transpose(1, 2)
|
x_t = X_tensor.transpose(1, 2)
|
||||||
|
conv_out = self.model.conv_layers(x_t)
|
||||||
|
|
||||||
# Process through parallel conv paths
|
# Process through all dense layers except the output layer
|
||||||
conv_outputs = []
|
features = conv_out
|
||||||
for path in self.model.conv_paths:
|
for layer in self.model.dense_block[:-2]: # Exclude last linear layer and dropout
|
||||||
conv_outputs.append(path(x_t))
|
features = layer(features)
|
||||||
|
|
||||||
# Concatenate conv outputs
|
return features.cpu().numpy()
|
||||||
conv_concat = torch.cat(conv_outputs, dim=1)
|
|
||||||
|
|
||||||
# Transpose back for LSTM
|
|
||||||
conv_concat = conv_concat.transpose(1, 2)
|
|
||||||
|
|
||||||
# LSTM processing
|
|
||||||
lstm_out, _ = self.model.lstm(conv_concat)
|
|
||||||
|
|
||||||
# Flatten
|
|
||||||
flattened = self.model.flatten(lstm_out)
|
|
||||||
|
|
||||||
# Dense processing
|
|
||||||
hidden_features = self.model.dense1(flattened)
|
|
||||||
|
|
||||||
return hidden_features.cpu().numpy()
|
|
||||||
|
@ -118,9 +118,7 @@ def main():
|
|||||||
logger.info("Initializing data interface...")
|
logger.info("Initializing data interface...")
|
||||||
data_interface = DataInterface(
|
data_interface = DataInterface(
|
||||||
symbol=args.symbol,
|
symbol=args.symbol,
|
||||||
timeframes=args.timeframes,
|
timeframes=args.timeframes
|
||||||
window_size=args.window_size,
|
|
||||||
output_size=args.output_size
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize data interface: {str(e)}")
|
logger.error(f"Failed to initialize data interface: {str(e)}")
|
||||||
|
@ -10,6 +10,12 @@ import numpy as np
|
|||||||
from threading import Thread
|
from threading import Thread
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import asyncio
|
||||||
|
import websockets
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -19,7 +25,8 @@ class RealtimeAnalyzer:
|
|||||||
|
|
||||||
Features:
|
Features:
|
||||||
- Connects to real-time data sources (websockets)
|
- Connects to real-time data sources (websockets)
|
||||||
- Processes incoming data through the neural network
|
- Processes tick data into multiple timeframes (1s, 1m, 1h, 1d)
|
||||||
|
- Uses trained models to analyze all timeframes
|
||||||
- Generates trading signals
|
- Generates trading signals
|
||||||
- Manages risk and position sizing
|
- Manages risk and position sizing
|
||||||
- Logs all trading decisions
|
- Logs all trading decisions
|
||||||
@ -33,17 +40,26 @@ class RealtimeAnalyzer:
|
|||||||
data_interface (DataInterface): Preconfigured data interface
|
data_interface (DataInterface): Preconfigured data interface
|
||||||
model: Trained neural network model
|
model: Trained neural network model
|
||||||
symbol (str): Trading pair symbol
|
symbol (str): Trading pair symbol
|
||||||
timeframes (list): List of timeframes to monitor
|
timeframes (list): List of timeframes to monitor (default: ['1s', '1m', '1h', '1d'])
|
||||||
"""
|
"""
|
||||||
self.data_interface = data_interface
|
self.data_interface = data_interface
|
||||||
self.model = model
|
self.model = model
|
||||||
self.symbol = symbol
|
self.symbol = symbol
|
||||||
self.timeframes = timeframes or ['1h']
|
self.timeframes = timeframes or ['1s', '1m', '1h', '1d']
|
||||||
self.running = False
|
self.running = False
|
||||||
self.data_queue = Queue()
|
self.data_queue = Queue()
|
||||||
self.prediction_interval = 60 # Seconds between predictions
|
self.prediction_interval = 10 # Seconds between predictions
|
||||||
|
self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
|
||||||
|
self.ws = None
|
||||||
|
self.tick_storage = deque(maxlen=10000) # Store up to 10,000 ticks
|
||||||
|
self.candle_cache = {
|
||||||
|
'1s': deque(maxlen=5000),
|
||||||
|
'1m': deque(maxlen=5000),
|
||||||
|
'1h': deque(maxlen=5000),
|
||||||
|
'1d': deque(maxlen=5000)
|
||||||
|
}
|
||||||
|
|
||||||
logger.info(f"RealtimeAnalyzer initialized for {symbol}")
|
logger.info(f"RealtimeAnalyzer initialized for {symbol} with timeframes: {self.timeframes}")
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Start the realtime analysis process."""
|
"""Start the realtime analysis process."""
|
||||||
@ -53,9 +69,13 @@ class RealtimeAnalyzer:
|
|||||||
|
|
||||||
self.running = True
|
self.running = True
|
||||||
|
|
||||||
# Start data collection thread
|
# Start WebSocket connection thread
|
||||||
self.data_thread = Thread(target=self._collect_data, daemon=True)
|
self.ws_thread = Thread(target=self._run_websocket, daemon=True)
|
||||||
self.data_thread.start()
|
self.ws_thread.start()
|
||||||
|
|
||||||
|
# Start data processing thread
|
||||||
|
self.processing_thread = Thread(target=self._process_data, daemon=True)
|
||||||
|
self.processing_thread.start()
|
||||||
|
|
||||||
# Start analysis thread
|
# Start analysis thread
|
||||||
self.analysis_thread = Thread(target=self._analyze_data, daemon=True)
|
self.analysis_thread = Thread(target=self._analyze_data, daemon=True)
|
||||||
@ -66,42 +86,128 @@ class RealtimeAnalyzer:
|
|||||||
def stop(self):
|
def stop(self):
|
||||||
"""Stop the realtime analysis process."""
|
"""Stop the realtime analysis process."""
|
||||||
self.running = False
|
self.running = False
|
||||||
if hasattr(self, 'data_thread'):
|
if self.ws:
|
||||||
self.data_thread.join(timeout=1)
|
asyncio.run(self.ws.close())
|
||||||
|
if hasattr(self, 'ws_thread'):
|
||||||
|
self.ws_thread.join(timeout=1)
|
||||||
|
if hasattr(self, 'processing_thread'):
|
||||||
|
self.processing_thread.join(timeout=1)
|
||||||
if hasattr(self, 'analysis_thread'):
|
if hasattr(self, 'analysis_thread'):
|
||||||
self.analysis_thread.join(timeout=1)
|
self.analysis_thread.join(timeout=1)
|
||||||
logger.info("Realtime analysis stopped")
|
logger.info("Realtime analysis stopped")
|
||||||
|
|
||||||
def _collect_data(self):
|
def _run_websocket(self):
|
||||||
"""Thread function for collecting real-time data."""
|
"""Thread function for running WebSocket connection."""
|
||||||
logger.info("Starting data collection thread")
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
loop.run_until_complete(self._connect_websocket())
|
||||||
|
|
||||||
# In a real implementation, this would connect to websockets/API
|
async def _connect_websocket(self):
|
||||||
# For now, we'll simulate data collection from the data interface
|
"""Connect to WebSocket and receive data."""
|
||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
# Get latest data for each timeframe
|
logger.info(f"Connecting to WebSocket: {self.ws_url}")
|
||||||
for timeframe in self.timeframes:
|
async with websockets.connect(self.ws_url) as ws:
|
||||||
# Get recent data (simulating real-time updates)
|
self.ws = ws
|
||||||
X, timestamp = self.data_interface.prepare_realtime_input(
|
logger.info("WebSocket connected")
|
||||||
timeframe=timeframe,
|
|
||||||
n_candles=30,
|
|
||||||
window_size=self.data_interface.window_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if X is not None:
|
while self.running:
|
||||||
self.data_queue.put({
|
try:
|
||||||
'timeframe': timeframe,
|
message = await ws.recv()
|
||||||
'data': X,
|
data = json.loads(message)
|
||||||
'timestamp': timestamp
|
|
||||||
})
|
|
||||||
|
|
||||||
# Throttle data collection
|
if 'e' in data and data['e'] == 'trade':
|
||||||
|
tick = {
|
||||||
|
'timestamp': data['T'],
|
||||||
|
'price': float(data['p']),
|
||||||
|
'volume': float(data['q']),
|
||||||
|
'symbol': self.symbol
|
||||||
|
}
|
||||||
|
self.tick_storage.append(tick)
|
||||||
|
self.data_queue.put(tick)
|
||||||
|
|
||||||
|
except websockets.exceptions.ConnectionClosed:
|
||||||
|
logger.warning("WebSocket connection closed")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error receiving WebSocket message: {str(e)}")
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in data collection: {str(e)}")
|
logger.error(f"WebSocket connection error: {str(e)}")
|
||||||
time.sleep(5) # Wait before retrying
|
time.sleep(5) # Wait before reconnecting
|
||||||
|
|
||||||
|
def _process_data(self):
|
||||||
|
"""Process incoming tick data into candles for all timeframes."""
|
||||||
|
logger.info("Starting data processing thread")
|
||||||
|
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
# Process any new ticks
|
||||||
|
while not self.data_queue.empty():
|
||||||
|
tick = self.data_queue.get()
|
||||||
|
|
||||||
|
# Convert timestamp to datetime
|
||||||
|
timestamp = datetime.fromtimestamp(tick['timestamp'] / 1000)
|
||||||
|
|
||||||
|
# Process for each timeframe
|
||||||
|
for timeframe in self.timeframes:
|
||||||
|
interval = self._get_interval_seconds(timeframe)
|
||||||
|
if interval is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Round timestamp to nearest candle interval
|
||||||
|
candle_ts = int(tick['timestamp'] // (interval * 1000)) * (interval * 1000)
|
||||||
|
|
||||||
|
# Get or create candle for this timeframe
|
||||||
|
if not self.candle_cache[timeframe]:
|
||||||
|
# First candle for this timeframe
|
||||||
|
candle = {
|
||||||
|
'timestamp': candle_ts,
|
||||||
|
'open': tick['price'],
|
||||||
|
'high': tick['price'],
|
||||||
|
'low': tick['price'],
|
||||||
|
'close': tick['price'],
|
||||||
|
'volume': tick['volume']
|
||||||
|
}
|
||||||
|
self.candle_cache[timeframe].append(candle)
|
||||||
|
else:
|
||||||
|
# Update existing candle
|
||||||
|
last_candle = self.candle_cache[timeframe][-1]
|
||||||
|
|
||||||
|
if last_candle['timestamp'] == candle_ts:
|
||||||
|
# Update current candle
|
||||||
|
last_candle['high'] = max(last_candle['high'], tick['price'])
|
||||||
|
last_candle['low'] = min(last_candle['low'], tick['price'])
|
||||||
|
last_candle['close'] = tick['price']
|
||||||
|
last_candle['volume'] += tick['volume']
|
||||||
|
else:
|
||||||
|
# New candle
|
||||||
|
candle = {
|
||||||
|
'timestamp': candle_ts,
|
||||||
|
'open': tick['price'],
|
||||||
|
'high': tick['price'],
|
||||||
|
'low': tick['price'],
|
||||||
|
'close': tick['price'],
|
||||||
|
'volume': tick['volume']
|
||||||
|
}
|
||||||
|
self.candle_cache[timeframe].append(candle)
|
||||||
|
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in data processing: {str(e)}")
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
def _get_interval_seconds(self, timeframe):
|
||||||
|
"""Convert timeframe string to seconds."""
|
||||||
|
intervals = {
|
||||||
|
'1s': 1,
|
||||||
|
'1m': 60,
|
||||||
|
'1h': 3600,
|
||||||
|
'1d': 86400
|
||||||
|
}
|
||||||
|
return intervals.get(timeframe)
|
||||||
|
|
||||||
def _analyze_data(self):
|
def _analyze_data(self):
|
||||||
"""Thread function for analyzing data and generating signals."""
|
"""Thread function for analyzing data and generating signals."""
|
||||||
@ -118,27 +224,56 @@ class RealtimeAnalyzer:
|
|||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get latest data from queue
|
# Prepare input data from all timeframes
|
||||||
if not self.data_queue.empty():
|
input_data = {}
|
||||||
data_item = self.data_queue.get()
|
valid = True
|
||||||
|
|
||||||
# Make prediction
|
for timeframe in self.timeframes:
|
||||||
prediction = self.model.predict(data_item['data'])
|
if not self.candle_cache[timeframe]:
|
||||||
|
logger.warning(f"No data available for timeframe {timeframe}")
|
||||||
|
valid = False
|
||||||
|
break
|
||||||
|
|
||||||
|
# Get last N candles for this timeframe
|
||||||
|
candles = list(self.candle_cache[timeframe])[-self.data_interface.window_size:]
|
||||||
|
|
||||||
|
# Convert to numpy array
|
||||||
|
ohlcv = np.array([
|
||||||
|
[c['open'], c['high'], c['low'], c['close'], c['volume']]
|
||||||
|
for c in candles
|
||||||
|
])
|
||||||
|
|
||||||
|
# Normalize data
|
||||||
|
ohlcv_normalized = (ohlcv - ohlcv.mean(axis=0)) / (ohlcv.std(axis=0) + 1e-8)
|
||||||
|
input_data[timeframe] = ohlcv_normalized
|
||||||
|
|
||||||
|
if not valid:
|
||||||
|
time.sleep(0.1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Make prediction using the model
|
||||||
|
try:
|
||||||
|
prediction = self.model.predict(input_data)
|
||||||
|
|
||||||
|
# Get latest timestamp from 1s timeframe
|
||||||
|
latest_ts = self.candle_cache['1s'][-1]['timestamp'] if self.candle_cache['1s'] else int(time.time() * 1000)
|
||||||
|
|
||||||
# Process prediction
|
# Process prediction
|
||||||
self._process_prediction(
|
self._process_prediction(
|
||||||
prediction=prediction,
|
prediction=prediction,
|
||||||
timeframe=data_item['timeframe'],
|
timeframe='multi',
|
||||||
timestamp=data_item['timestamp']
|
timestamp=latest_ts
|
||||||
)
|
)
|
||||||
|
|
||||||
last_prediction_time = current_time
|
last_prediction_time = current_time
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error making prediction: {str(e)}")
|
||||||
|
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in analysis: {str(e)}")
|
logger.error(f"Error in analysis: {str(e)}")
|
||||||
time.sleep(1) # Wait before retrying
|
time.sleep(1)
|
||||||
|
|
||||||
def _process_prediction(self, prediction, timeframe, timestamp):
|
def _process_prediction(self, prediction, timeframe, timestamp):
|
||||||
"""
|
"""
|
||||||
@ -146,17 +281,23 @@ class RealtimeAnalyzer:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
prediction: Model prediction output
|
prediction: Model prediction output
|
||||||
timeframe (str): Timeframe the prediction is for
|
timeframe (str): Timeframe the prediction is for ('multi' for combined)
|
||||||
timestamp: Timestamp of the prediction
|
timestamp: Timestamp of the prediction (ms)
|
||||||
"""
|
"""
|
||||||
# Convert prediction to trading signal
|
# Convert prediction to trading signal
|
||||||
signal = self._prediction_to_signal(prediction)
|
signal, confidence = self._prediction_to_signal(prediction)
|
||||||
|
|
||||||
# Log the signal
|
# Convert timestamp to datetime
|
||||||
|
try:
|
||||||
|
dt = datetime.fromtimestamp(timestamp / 1000)
|
||||||
|
except:
|
||||||
|
dt = datetime.now()
|
||||||
|
|
||||||
|
# Log the signal with all timeframes
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Signal generated - Timeframe: {timeframe}, "
|
f"Signal generated - Timeframes: {', '.join(self.timeframes)}, "
|
||||||
f"Timestamp: {timestamp}, "
|
f"Timestamp: {dt}, "
|
||||||
f"Signal: {signal}"
|
f"Signal: {signal} (Confidence: {confidence:.2f})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# In a real implementation, we would execute trades here
|
# In a real implementation, we would execute trades here
|
||||||
@ -164,19 +305,60 @@ class RealtimeAnalyzer:
|
|||||||
|
|
||||||
def _prediction_to_signal(self, prediction):
|
def _prediction_to_signal(self, prediction):
|
||||||
"""
|
"""
|
||||||
Convert model prediction to trading signal.
|
Convert model prediction to trading signal and confidence.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prediction: Model prediction output
|
prediction: Model prediction output (can be dict for multi-timeframe)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Trading signal (BUY, SELL, HOLD)
|
tuple: (signal, confidence) where signal is BUY/SELL/HOLD,
|
||||||
|
confidence is probability (0-1)
|
||||||
"""
|
"""
|
||||||
# Simple threshold-based signal generation
|
if isinstance(prediction, dict):
|
||||||
|
# Multi-timeframe prediction - combine signals
|
||||||
|
signals = []
|
||||||
|
confidences = []
|
||||||
|
|
||||||
|
for tf, pred in prediction.items():
|
||||||
|
if len(pred.shape) == 1:
|
||||||
|
# Binary classification
|
||||||
|
signal = "BUY" if pred[0] > 0.5 else "SELL"
|
||||||
|
confidence = pred[0] if signal == "BUY" else 1 - pred[0]
|
||||||
|
else:
|
||||||
|
# Multi-class
|
||||||
|
class_idx = np.argmax(pred)
|
||||||
|
signal = ["SELL", "HOLD", "BUY"][class_idx]
|
||||||
|
confidence = pred[class_idx]
|
||||||
|
|
||||||
|
signals.append(signal)
|
||||||
|
confidences.append(confidence)
|
||||||
|
|
||||||
|
# Simple voting system - count BUY/SELL signals
|
||||||
|
buy_count = signals.count("BUY")
|
||||||
|
sell_count = signals.count("SELL")
|
||||||
|
|
||||||
|
if buy_count > sell_count:
|
||||||
|
final_signal = "BUY"
|
||||||
|
final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "BUY"])
|
||||||
|
elif sell_count > buy_count:
|
||||||
|
final_signal = "SELL"
|
||||||
|
final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "SELL"])
|
||||||
|
else:
|
||||||
|
final_signal = "HOLD"
|
||||||
|
final_confidence = np.mean(confidences)
|
||||||
|
|
||||||
|
return final_signal, final_confidence
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Single prediction
|
||||||
if len(prediction.shape) == 1:
|
if len(prediction.shape) == 1:
|
||||||
# Binary classification
|
# Binary classification
|
||||||
return "BUY" if prediction[0] > 0.5 else "SELL"
|
signal = "BUY" if prediction[0] > 0.5 else "SELL"
|
||||||
|
confidence = prediction[0] if signal == "BUY" else 1 - prediction[0]
|
||||||
else:
|
else:
|
||||||
# Multi-class classification (3 outputs)
|
# Multi-class
|
||||||
class_idx = np.argmax(prediction)
|
class_idx = np.argmax(prediction)
|
||||||
return ["SELL", "HOLD", "BUY"][class_idx]
|
signal = ["SELL", "HOLD", "BUY"][class_idx]
|
||||||
|
confidence = prediction[class_idx]
|
||||||
|
|
||||||
|
return signal, confidence
|
||||||
|
@ -37,3 +37,10 @@ Backend tkagg is interactive backend. Turning interactive mode on.
|
|||||||
remodel our NN architecture. we should support up to 3 pairs simultaniously. so input can be 3 pairs: each pair will have up to 5 timeframes 1s(ticks, unspecified length), 1m, 1h, 1d + one additionall. we should normalize them in a way that preserves the relations between them (one price should be normalized to the same value across all tieframes). additionally to the 5 features OHLCV we will add up to 20 additional features for various technical indcators. 1s timeframe will be streamed in realtime. the MOE model should handle all that. we still need to access latest of the CNN hidden layers in the MOe model so we can extract learned features recognition
|
remodel our NN architecture. we should support up to 3 pairs simultaniously. so input can be 3 pairs: each pair will have up to 5 timeframes 1s(ticks, unspecified length), 1m, 1h, 1d + one additionall. we should normalize them in a way that preserves the relations between them (one price should be normalized to the same value across all tieframes). additionally to the 5 features OHLCV we will add up to 20 additional features for various technical indcators. 1s timeframe will be streamed in realtime. the MOE model should handle all that. we still need to access latest of the CNN hidden layers in the MOe model so we can extract learned features recognition
|
||||||
.
|
.
|
||||||
now let's run our "NN Training Pipeline" debug config. for now we start with single pair - BTC/USD. later we'll add up to 3 pairs for context. the NN will always have only 1 "main" pair - where the buy/sell actions are applied and which price prediction is calculater for each frame. we'll also try to predict the next local extrema that will help us be profitable
|
now let's run our "NN Training Pipeline" debug config. for now we start with single pair - BTC/USD. later we'll add up to 3 pairs for context. the NN will always have only 1 "main" pair - where the buy/sell actions are applied and which price prediction is calculater for each frame. we'll also try to predict the next local extrema that will help us be profitable
|
||||||
|
|
||||||
|
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch
|
||||||
|
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch --epochs 1000
|
||||||
|
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch --epochs 1000 --symbol BTC/USDT --timeframes 1m 5m 1h 4h --epochs 10 --batch-size 32 --window-size 20 --output-size 3
|
||||||
|
|
||||||
|
python NN/realtime-main.py --mode train --model-type cnn --framework pytorch --symbol BTC/USDT --timeframes 1m 5m 1h 4h --epochs 10 --batch-size 32 --window-size 20 --output-size 3
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user