working training

Dobromir Popov 2025-03-29 02:18:25 +02:00
parent 0b2000e3e7
commit 2255a8363a
4 changed files with 314 additions and 154 deletions

View File

@@ -24,60 +24,67 @@ logger = logging.getLogger(__name__)
class CNNPyTorch(nn.Module):
"""PyTorch CNN model for time series analysis"""
def __init__(self, input_shape, output_size=3):
def __init__(self, input_shape, output_size=5):
"""
Initialize the CNN model.
Initialize the enhanced CNN model.
Args:
input_shape (tuple): Shape of input data (window_size, features)
output_size (int): Size of output (1 for regression, 3 for classification)
output_size (int): Always 5 for our trading signals
"""
super(CNNPyTorch, self).__init__()
window_size, num_features = input_shape
# Architecture parameters
filters = [32, 64, 128]
kernel_sizes = [3, 5, 7]
lstm_units = 100
dense_units = 64
kernel_size = 5
dropout_rate = 0.3
# Create parallel convolutional pathways
self.conv_paths = nn.ModuleList()
for f, k in zip(filters, kernel_sizes):
path = nn.Sequential(
nn.Conv1d(num_features, f, kernel_size=k, padding='same'),
# Enhanced CNN Architecture
self.conv_layers = nn.Sequential(
# Block 1
nn.Conv1d(num_features, 64, kernel_size, padding='same'),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.BatchNorm1d(f),
nn.MaxPool1d(kernel_size=2, stride=1, padding=1),
nn.Dropout(dropout_rate)
)
self.conv_paths.append(path)
# Calculate output size from conv paths
conv_output_size = sum(filters) * window_size
# LSTM layer
self.lstm = nn.LSTM(
input_size=sum(filters),
hidden_size=lstm_units,
batch_first=True,
bidirectional=True
)
# Dense layers
self.flatten = nn.Flatten()
self.dense1 = nn.Sequential(
nn.Linear(lstm_units * 2 * window_size, dense_units),
# Block 2
nn.Conv1d(64, 128, kernel_size, padding='same'),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.BatchNorm1d(dense_units),
nn.Dropout(dropout_rate)
nn.MaxPool1d(2),
# Block 3
nn.Conv1d(128, 256, kernel_size, padding='same'),
nn.BatchNorm1d(256),
nn.ReLU(),
# Block 4
nn.Conv1d(256, 512, kernel_size, padding='same'),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.MaxPool1d(2)
)
# Output layer
self.output = nn.Linear(dense_units, output_size)
# Calculate flattened size after conv and pooling
conv_output_size = 512 * (window_size // 4)
# Enhanced dense layers
self.dense_block = nn.Sequential(
nn.Flatten(),
nn.Linear(conv_output_size, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(256, 128),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.Linear(128, output_size)
)
# Activation based on output size
if output_size == 1:
@@ -89,7 +96,7 @@ class CNNPyTorch(nn.Module):
def forward(self, x):
"""
Forward pass through the network.
Forward pass through enhanced network.
Args:
x: Input tensor of shape [batch_size, window_size, features]
@@ -97,35 +104,15 @@ class CNNPyTorch(nn.Module):
Returns:
Output tensor of shape [batch_size, output_size]
"""
batch_size, window_size, num_features = x.shape
# Transpose for conv1d: [batch, features, window]
x_t = x.transpose(1, 2)
# Process through parallel conv paths
conv_outputs = []
for path in self.conv_paths:
conv_outputs.append(path(x_t))
# Process through all CNN layers
conv_out = self.conv_layers(x_t)
# Concatenate conv outputs
conv_concat = torch.cat(conv_outputs, dim=1)
# Process through dense layers
output = self.dense_block(conv_out)
# Transpose back for LSTM: [batch, window, features]
conv_concat = conv_concat.transpose(1, 2)
# LSTM processing
lstm_out, _ = self.lstm(conv_concat)
# Flatten
flattened = self.flatten(lstm_out)
# Dense processing
dense_out = self.dense1(flattened)
# Output
output = self.output(dense_out)
# Apply activation
return self.activation(output)
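As a sanity check on the new architecture (a minimal sketch, assuming only the layer definitions above; the batch and window sizes are illustrative): the two MaxPool1d(2) layers each halve the time dimension, which is where conv_output_size = 512 * (window_size // 4) comes from.

    import torch
    import torch.nn as nn

    # Sketch: verify the flattened size fed into dense_block.
    # BatchNorm layers are omitted here since they do not change shapes.
    window_size, num_features = 20, 5                # illustrative sizes
    x = torch.randn(8, window_size, num_features)    # [batch, window, features]
    conv = nn.Sequential(
        nn.Conv1d(num_features, 64, 5, padding='same'), nn.ReLU(),
        nn.Conv1d(64, 128, 5, padding='same'), nn.ReLU(), nn.MaxPool1d(2),
        nn.Conv1d(128, 256, 5, padding='same'), nn.ReLU(),
        nn.Conv1d(256, 512, 5, padding='same'), nn.ReLU(), nn.MaxPool1d(2),
    )
    out = conv(x.transpose(1, 2))                    # [8, 512, window_size // 4]
    assert out.flatten(1).shape[1] == 512 * (window_size // 4)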
@@ -137,7 +124,7 @@ class CNNModelPyTorch:
predictions with the CNN model.
"""
def __init__(self, window_size, num_features, output_size=3, timeframes=None):
def __init__(self, window_size, num_features, output_size=5, timeframes=None):
"""
Initialize the CNN model.
@@ -506,41 +493,27 @@ class CNNModelPyTorch:
def extract_hidden_features(self, X):
"""
Extract hidden features from the model.
Extract hidden features from the model - outputs from last dense layer before output.
Args:
X: Input data
Returns:
Hidden features
Hidden features (output from penultimate dense layer)
"""
# Convert to PyTorch tensor
X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
# Forward pass through the model up to the last hidden layer
# Forward pass through the model
self.model.eval()
with torch.no_grad():
# Get features before the output layer
# Get features through CNN layers
x_t = X_tensor.transpose(1, 2)
conv_out = self.model.conv_layers(x_t)
# Process through parallel conv paths
conv_outputs = []
for path in self.model.conv_paths:
conv_outputs.append(path(x_t))
# Process through all dense layers except the output layer
features = conv_out
for layer in self.model.dense_block[:-2]: # Exclude the final Linear layer and the ReLU before it
features = layer(features)
# Concatenate conv outputs
conv_concat = torch.cat(conv_outputs, dim=1)
# Transpose back for LSTM
conv_concat = conv_concat.transpose(1, 2)
# LSTM processing
lstm_out, _ = self.model.lstm(conv_concat)
# Flatten
flattened = self.model.flatten(lstm_out)
# Dense processing
hidden_features = self.model.dense1(flattened)
return hidden_features.cpu().numpy()
return features.cpu().numpy()
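A hypothetical usage sketch (the `model` instance, window size, and feature count are assumptions, not code from this commit); the returned 128-dim activations are what the MoE model mentioned in the notes further down would consume:

    import numpy as np

    # Sketch: pull hidden features for a batch of input windows.
    X = np.random.randn(16, 20, 5).astype(np.float32)  # [batch, window, features]
    features = model.extract_hidden_features(X)        # -> ndarray of shape [16, 128]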

View File

@@ -118,9 +118,7 @@ def main():
logger.info("Initializing data interface...")
data_interface = DataInterface(
symbol=args.symbol,
timeframes=args.timeframes,
window_size=args.window_size,
output_size=args.output_size
timeframes=args.timeframes
)
except Exception as e:
logger.error(f"Failed to initialize data interface: {str(e)}")

View File

@@ -10,6 +10,12 @@ import numpy as np
from threading import Thread
from queue import Queue
from datetime import datetime
import asyncio
import websockets
import json
import os
import pandas as pd
from collections import deque
logger = logging.getLogger(__name__)
@@ -19,7 +25,8 @@ class RealtimeAnalyzer:
Features:
- Connects to real-time data sources (websockets)
- Processes incoming data through the neural network
- Processes tick data into multiple timeframes (1s, 1m, 1h, 1d)
- Uses trained models to analyze all timeframes
- Generates trading signals
- Manages risk and position sizing
- Logs all trading decisions
@@ -33,17 +40,26 @@ class RealtimeAnalyzer:
data_interface (DataInterface): Preconfigured data interface
model: Trained neural network model
symbol (str): Trading pair symbol
timeframes (list): List of timeframes to monitor
timeframes (list): List of timeframes to monitor (default: ['1s', '1m', '1h', '1d'])
"""
self.data_interface = data_interface
self.model = model
self.symbol = symbol
self.timeframes = timeframes or ['1h']
self.timeframes = timeframes or ['1s', '1m', '1h', '1d']
self.running = False
self.data_queue = Queue()
self.prediction_interval = 60 # Seconds between predictions
self.prediction_interval = 10 # Seconds between predictions
self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
self.ws = None
self.tick_storage = deque(maxlen=10000) # Store up to 10,000 ticks
self.candle_cache = {
'1s': deque(maxlen=5000),
'1m': deque(maxlen=5000),
'1h': deque(maxlen=5000),
'1d': deque(maxlen=5000)
}
logger.info(f"RealtimeAnalyzer initialized for {symbol}")
logger.info(f"RealtimeAnalyzer initialized for {symbol} with timeframes: {self.timeframes}")
def start(self):
"""Start the realtime analysis process."""
@@ -53,9 +69,13 @@ class RealtimeAnalyzer:
self.running = True
# Start data collection thread
self.data_thread = Thread(target=self._collect_data, daemon=True)
self.data_thread.start()
# Start WebSocket connection thread
self.ws_thread = Thread(target=self._run_websocket, daemon=True)
self.ws_thread.start()
# Start data processing thread
self.processing_thread = Thread(target=self._process_data, daemon=True)
self.processing_thread.start()
# Start analysis thread
self.analysis_thread = Thread(target=self._analyze_data, daemon=True)
@@ -66,42 +86,128 @@ class RealtimeAnalyzer:
def stop(self):
"""Stop the realtime analysis process."""
self.running = False
if hasattr(self, 'data_thread'):
self.data_thread.join(timeout=1)
if self.ws and hasattr(self, 'ws_loop'):
# Close the socket on its own event loop; asyncio.run() here would
# conflict with the loop already running in the WebSocket thread.
asyncio.run_coroutine_threadsafe(self.ws.close(), self.ws_loop)
if hasattr(self, 'ws_thread'):
self.ws_thread.join(timeout=1)
if hasattr(self, 'processing_thread'):
self.processing_thread.join(timeout=1)
if hasattr(self, 'analysis_thread'):
self.analysis_thread.join(timeout=1)
logger.info("Realtime analysis stopped")
def _collect_data(self):
"""Thread function for collecting real-time data."""
logger.info("Starting data collection thread")
def _run_websocket(self):
"""Thread function for running WebSocket connection."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.ws_loop = loop # Keep a handle so stop() can close the socket safely
loop.run_until_complete(self._connect_websocket())
# In a real implementation, this would connect to websockets/API
# For now, we'll simulate data collection from the data interface
async def _connect_websocket(self):
"""Connect to WebSocket and receive data."""
while self.running:
try:
# Get latest data for each timeframe
for timeframe in self.timeframes:
# Get recent data (simulating real-time updates)
X, timestamp = self.data_interface.prepare_realtime_input(
timeframe=timeframe,
n_candles=30,
window_size=self.data_interface.window_size
)
logger.info(f"Connecting to WebSocket: {self.ws_url}")
async with websockets.connect(self.ws_url) as ws:
self.ws = ws
logger.info("WebSocket connected")
if X is not None:
self.data_queue.put({
'timeframe': timeframe,
'data': X,
'timestamp': timestamp
})
while self.running:
try:
message = await ws.recv()
data = json.loads(message)
# Throttle data collection
if 'e' in data and data['e'] == 'trade':
tick = {
'timestamp': data['T'],
'price': float(data['p']),
'volume': float(data['q']),
'symbol': self.symbol
}
self.tick_storage.append(tick)
self.data_queue.put(tick)
except websockets.exceptions.ConnectionClosed:
logger.warning("WebSocket connection closed")
break
except Exception as e:
logger.error(f"Error receiving WebSocket message: {str(e)}")
await asyncio.sleep(1) # Non-blocking; time.sleep() would stall the event loop
except Exception as e:
logger.error(f"Error in data collection: {str(e)}")
time.sleep(5) # Wait before retrying
logger.error(f"WebSocket connection error: {str(e)}")
await asyncio.sleep(5) # Wait before reconnecting (non-blocking)
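For context, the handler above reads only the 'e', 'T', 'p', and 'q' fields of Binance trade events. A sketch with an abridged, illustrative payload (consult the Binance WebSocket documentation for the full schema):

    import json

    # Abridged trade event; all values are illustrative.
    message = '{"e": "trade", "s": "BTCUSDT", "T": 1672515782134, "p": "16541.20", "q": "0.014"}'
    data = json.loads(message)
    tick = {
        'timestamp': data['T'],      # trade time in milliseconds
        'price': float(data['p']),   # price arrives as a string
        'volume': float(data['q']),  # quantity arrives as a string
    }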
def _process_data(self):
"""Process incoming tick data into candles for all timeframes."""
logger.info("Starting data processing thread")
while self.running:
try:
# Process any new ticks
while not self.data_queue.empty():
tick = self.data_queue.get()
# Convert timestamp to datetime
timestamp = datetime.fromtimestamp(tick['timestamp'] / 1000)
# Process for each timeframe
for timeframe in self.timeframes:
interval = self._get_interval_seconds(timeframe)
if interval is None:
continue
# Round timestamp to nearest candle interval
candle_ts = int(tick['timestamp'] // (interval * 1000)) * (interval * 1000)
# Get or create candle for this timeframe
if not self.candle_cache[timeframe]:
# First candle for this timeframe
candle = {
'timestamp': candle_ts,
'open': tick['price'],
'high': tick['price'],
'low': tick['price'],
'close': tick['price'],
'volume': tick['volume']
}
self.candle_cache[timeframe].append(candle)
else:
# Update existing candle
last_candle = self.candle_cache[timeframe][-1]
if last_candle['timestamp'] == candle_ts:
# Update current candle
last_candle['high'] = max(last_candle['high'], tick['price'])
last_candle['low'] = min(last_candle['low'], tick['price'])
last_candle['close'] = tick['price']
last_candle['volume'] += tick['volume']
else:
# New candle
candle = {
'timestamp': candle_ts,
'open': tick['price'],
'high': tick['price'],
'low': tick['price'],
'close': tick['price'],
'volume': tick['volume']
}
self.candle_cache[timeframe].append(candle)
time.sleep(0.1)
except Exception as e:
logger.error(f"Error in data processing: {str(e)}")
time.sleep(1)
def _get_interval_seconds(self, timeframe):
"""Convert timeframe string to seconds."""
intervals = {
'1s': 1,
'1m': 60,
'1h': 3600,
'1d': 86400
}
return intervals.get(timeframe)
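A worked example of the candle bucketing above (numbers are illustrative): flooring the millisecond timestamp to the interval maps every tick within the same minute onto one candle.

    # Sketch: a tick part-way through a minute lands in that minute's candle.
    tick_ms = 1_700_000_442_500   # illustrative tick timestamp (ms)
    interval = 60                 # '1m' -> 60 seconds
    candle_ts = int(tick_ms // (interval * 1000)) * (interval * 1000)
    # candle_ts == 1_700_000_400_000; every tick with a timestamp in
    # [candle_ts, candle_ts + 60_000) updates the same '1m' candle.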
def _analyze_data(self):
"""Thread function for analyzing data and generating signals."""
@@ -118,27 +224,56 @@ class RealtimeAnalyzer:
time.sleep(0.1)
continue
# Get latest data from queue
if not self.data_queue.empty():
data_item = self.data_queue.get()
# Prepare input data from all timeframes
input_data = {}
valid = True
# Make prediction
prediction = self.model.predict(data_item['data'])
for timeframe in self.timeframes:
if not self.candle_cache[timeframe]:
logger.warning(f"No data available for timeframe {timeframe}")
valid = False
break
# Get last N candles for this timeframe
candles = list(self.candle_cache[timeframe])[-self.data_interface.window_size:]
# Convert to numpy array
ohlcv = np.array([
[c['open'], c['high'], c['low'], c['close'], c['volume']]
for c in candles
])
# Normalize data
ohlcv_normalized = (ohlcv - ohlcv.mean(axis=0)) / (ohlcv.std(axis=0) + 1e-8)
input_data[timeframe] = ohlcv_normalized
if not valid:
time.sleep(0.1)
continue
# Make prediction using the model
try:
prediction = self.model.predict(input_data)
# Get latest timestamp from 1s timeframe
latest_ts = self.candle_cache['1s'][-1]['timestamp'] if self.candle_cache['1s'] else int(time.time() * 1000)
# Process prediction
self._process_prediction(
prediction=prediction,
timeframe=data_item['timeframe'],
timestamp=data_item['timestamp']
timeframe='multi',
timestamp=latest_ts
)
last_prediction_time = current_time
except Exception as e:
logger.error(f"Error making prediction: {str(e)}")
time.sleep(0.1)
except Exception as e:
logger.error(f"Error in analysis: {str(e)}")
time.sleep(1) # Wait before retrying
time.sleep(1)
def _process_prediction(self, prediction, timeframe, timestamp):
"""
@@ -146,17 +281,23 @@ class RealtimeAnalyzer:
Args:
prediction: Model prediction output
timeframe (str): Timeframe the prediction is for
timestamp: Timestamp of the prediction
timeframe (str): Timeframe the prediction is for ('multi' for combined)
timestamp: Timestamp of the prediction (ms)
"""
# Convert prediction to trading signal
signal = self._prediction_to_signal(prediction)
signal, confidence = self._prediction_to_signal(prediction)
# Log the signal
# Convert timestamp to datetime
try:
dt = datetime.fromtimestamp(timestamp / 1000)
except Exception: # a bare except would also swallow KeyboardInterrupt
dt = datetime.now()
# Log the signal with all timeframes
logger.info(
f"Signal generated - Timeframe: {timeframe}, "
f"Timestamp: {timestamp}, "
f"Signal: {signal}"
f"Signal generated - Timeframes: {', '.join(self.timeframes)}, "
f"Timestamp: {dt}, "
f"Signal: {signal} (Confidence: {confidence:.2f})"
)
# In a real implementation, we would execute trades here
@@ -164,19 +305,60 @@ class RealtimeAnalyzer:
def _prediction_to_signal(self, prediction):
"""
Convert model prediction to trading signal.
Convert model prediction to trading signal and confidence.
Args:
prediction: Model prediction output
prediction: Model prediction output (can be dict for multi-timeframe)
Returns:
str: Trading signal (BUY, SELL, HOLD)
tuple: (signal, confidence) where signal is BUY/SELL/HOLD,
confidence is probability (0-1)
"""
# Simple threshold-based signal generation
if isinstance(prediction, dict):
# Multi-timeframe prediction - combine signals
signals = []
confidences = []
for tf, pred in prediction.items():
if len(pred.shape) == 1:
# Binary classification
signal = "BUY" if pred[0] > 0.5 else "SELL"
confidence = pred[0] if signal == "BUY" else 1 - pred[0]
else:
# Multi-class
class_idx = np.argmax(pred)
signal = ["SELL", "HOLD", "BUY"][class_idx]
confidence = pred.flatten()[class_idx] # flatten: pred may arrive shaped [1, 3]
signals.append(signal)
confidences.append(confidence)
# Simple voting system - count BUY/SELL signals
buy_count = signals.count("BUY")
sell_count = signals.count("SELL")
if buy_count > sell_count:
final_signal = "BUY"
final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "BUY"])
elif sell_count > buy_count:
final_signal = "SELL"
final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "SELL"])
else:
final_signal = "HOLD"
final_confidence = np.mean(confidences)
return final_signal, final_confidence
else:
# Single prediction
if len(prediction.shape) == 1:
# Binary classification
return "BUY" if prediction[0] > 0.5 else "SELL"
signal = "BUY" if prediction[0] > 0.5 else "SELL"
confidence = prediction[0] if signal == "BUY" else 1 - prediction[0]
else:
# Multi-class classification (3 outputs)
# Multi-class
class_idx = np.argmax(prediction)
return ["SELL", "HOLD", "BUY"][class_idx]
signal = ["SELL", "HOLD", "BUY"][class_idx]
confidence = prediction.flatten()[class_idx] # flatten: prediction may arrive shaped [1, 3]
return signal, confidence
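A worked example of the voting logic (a sketch with illustrative numbers, assuming per-timeframe outputs shaped [1, 3] over [SELL, HOLD, BUY]):

    import numpy as np

    prediction = {
        '1m': np.array([[0.10, 0.20, 0.70]]),  # argmax 2 -> BUY, confidence 0.70
        '1h': np.array([[0.15, 0.25, 0.60]]),  # argmax 2 -> BUY, confidence 0.60
        '1d': np.array([[0.55, 0.30, 0.15]]),  # argmax 0 -> SELL, confidence 0.55
    }
    # buy_count (2) > sell_count (1), so _prediction_to_signal returns
    # ("BUY", 0.65): the mean confidence over the agreeing timeframes.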

View File

@@ -37,3 +37,10 @@ Backend tkagg is interactive backend. Turning interactive mode on.
Remodel our NN architecture. We should support up to 3 pairs simultaneously, so the input can be 3 pairs; each pair will have up to 5 timeframes: 1s (ticks, unspecified length), 1m, 1h, 1d, plus one additional. We should normalize them in a way that preserves the relations between them, so that one price normalizes to the same value across all timeframes (one option is sketched after these notes). In addition to the 5 OHLCV features we will add up to 20 extra features for various technical indicators. The 1s timeframe will be streamed in real time. The MoE model should handle all of that, and we still need access to the last CNN hidden layers in the MoE model so we can extract the learned feature representations.
Now let's run our "NN Training Pipeline" debug config. For now we start with a single pair, BTC/USD; later we'll add up to 3 pairs for context. The NN will always have only 1 "main" pair: the one where buy/sell actions are applied and for which a price prediction is calculated for each frame. We'll also try to predict the next local extrema, which should help us stay profitable.
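One possible shape for the shared normalization described above (a sketch, not code from this repo; the function name and the choice of the latest 1m close as anchor are assumptions): dividing the OHLC values of every timeframe by one reference price makes identical prices map to identical normalized values everywhere.

    import numpy as np

    # Sketch: `frames` maps timeframe -> ndarray [window, 5] of OHLCV.
    def normalize_shared(frames):
        ref_price = frames['1m'][-1, 3]  # latest 1m close as the shared anchor
        out = {}
        for tf, ohlcv in frames.items():
            norm = ohlcv.astype(np.float64).copy()
            norm[:, :4] = norm[:, :4] / ref_price - 1.0          # prices vs. anchor
            norm[:, 4] = norm[:, 4] / (norm[:, 4].max() + 1e-8)  # volume per frame
            out[tf] = norm
        return out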
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch --epochs 1000
python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --framework pytorch --symbol BTC/USDT --timeframes 1m 5m 1h 4h --epochs 10 --batch-size 32 --window-size 20 --output-size 3
python NN/realtime_main.py --mode train --model-type cnn --framework pytorch --symbol BTC/USDT --timeframes 1m 5m 1h 4h --epochs 10 --batch-size 32 --window-size 20 --output-size 3