working training
This commit is contained in:
@ -24,60 +24,67 @@ logger = logging.getLogger(__name__)
|
||||
class CNNPyTorch(nn.Module):
|
||||
"""PyTorch CNN model for time series analysis"""
|
||||
|
||||
def __init__(self, input_shape, output_size=3):
|
||||
def __init__(self, input_shape, output_size=5):
|
||||
"""
|
||||
Initialize the CNN model.
|
||||
Initialize the enhanced CNN model.
|
||||
|
||||
Args:
|
||||
input_shape (tuple): Shape of input data (window_size, features)
|
||||
output_size (int): Size of output (1 for regression, 3 for classification)
|
||||
output_size (int): Always 5 for our trading signals
|
||||
"""
|
||||
super(CNNPyTorch, self).__init__()
|
||||
|
||||
window_size, num_features = input_shape
|
||||
|
||||
# Architecture parameters
|
||||
filters = [32, 64, 128]
|
||||
kernel_sizes = [3, 5, 7]
|
||||
lstm_units = 100
|
||||
dense_units = 64
|
||||
kernel_size = 5
|
||||
dropout_rate = 0.3
|
||||
|
||||
# Create parallel convolutional pathways
|
||||
self.conv_paths = nn.ModuleList()
|
||||
|
||||
for f, k in zip(filters, kernel_sizes):
|
||||
path = nn.Sequential(
|
||||
nn.Conv1d(num_features, f, kernel_size=k, padding='same'),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(f),
|
||||
nn.MaxPool1d(kernel_size=2, stride=1, padding=1),
|
||||
nn.Dropout(dropout_rate)
|
||||
)
|
||||
self.conv_paths.append(path)
|
||||
|
||||
# Calculate output size from conv paths
|
||||
conv_output_size = sum(filters) * window_size
|
||||
|
||||
# LSTM layer
|
||||
self.lstm = nn.LSTM(
|
||||
input_size=sum(filters),
|
||||
hidden_size=lstm_units,
|
||||
batch_first=True,
|
||||
bidirectional=True
|
||||
)
|
||||
|
||||
# Dense layers
|
||||
self.flatten = nn.Flatten()
|
||||
self.dense1 = nn.Sequential(
|
||||
nn.Linear(lstm_units * 2 * window_size, dense_units),
|
||||
# Enhanced CNN Architecture
|
||||
self.conv_layers = nn.Sequential(
|
||||
# Block 1
|
||||
nn.Conv1d(num_features, 64, kernel_size, padding='same'),
|
||||
nn.BatchNorm1d(64),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(dense_units),
|
||||
nn.Dropout(dropout_rate)
|
||||
|
||||
# Block 2
|
||||
nn.Conv1d(64, 128, kernel_size, padding='same'),
|
||||
nn.BatchNorm1d(128),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool1d(2),
|
||||
|
||||
# Block 3
|
||||
nn.Conv1d(128, 256, kernel_size, padding='same'),
|
||||
nn.BatchNorm1d(256),
|
||||
nn.ReLU(),
|
||||
|
||||
# Block 4
|
||||
nn.Conv1d(256, 512, kernel_size, padding='same'),
|
||||
nn.BatchNorm1d(512),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool1d(2)
|
||||
)
|
||||
|
||||
# Output layer
|
||||
self.output = nn.Linear(dense_units, output_size)
|
||||
# Calculate flattened size after conv and pooling
|
||||
conv_output_size = 512 * (window_size // 4)
|
||||
|
||||
# Enhanced dense layers
|
||||
self.dense_block = nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(conv_output_size, 512),
|
||||
nn.BatchNorm1d(512),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout_rate),
|
||||
|
||||
nn.Linear(512, 256),
|
||||
nn.BatchNorm1d(256),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout_rate),
|
||||
|
||||
nn.Linear(256, 128),
|
||||
nn.BatchNorm1d(128),
|
||||
nn.ReLU(),
|
||||
|
||||
nn.Linear(128, output_size)
|
||||
)
|
||||
|
||||
# Activation based on output size
|
||||
if output_size == 1:
|
||||
@ -89,7 +96,7 @@ class CNNPyTorch(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the network.
|
||||
Forward pass through enhanced network.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape [batch_size, window_size, features]
|
||||
@ -97,35 +104,15 @@ class CNNPyTorch(nn.Module):
|
||||
Returns:
|
||||
Output tensor of shape [batch_size, output_size]
|
||||
"""
|
||||
batch_size, window_size, num_features = x.shape
|
||||
|
||||
# Transpose for conv1d: [batch, features, window]
|
||||
x_t = x.transpose(1, 2)
|
||||
|
||||
# Process through parallel conv paths
|
||||
conv_outputs = []
|
||||
for path in self.conv_paths:
|
||||
conv_outputs.append(path(x_t))
|
||||
# Process through all CNN layers
|
||||
conv_out = self.conv_layers(x_t)
|
||||
|
||||
# Concatenate conv outputs
|
||||
conv_concat = torch.cat(conv_outputs, dim=1)
|
||||
# Process through dense layers
|
||||
output = self.dense_block(conv_out)
|
||||
|
||||
# Transpose back for LSTM: [batch, window, features]
|
||||
conv_concat = conv_concat.transpose(1, 2)
|
||||
|
||||
# LSTM processing
|
||||
lstm_out, _ = self.lstm(conv_concat)
|
||||
|
||||
# Flatten
|
||||
flattened = self.flatten(lstm_out)
|
||||
|
||||
# Dense processing
|
||||
dense_out = self.dense1(flattened)
|
||||
|
||||
# Output
|
||||
output = self.output(dense_out)
|
||||
|
||||
# Apply activation
|
||||
return self.activation(output)
|
||||
|
||||
|
||||
@ -137,7 +124,7 @@ class CNNModelPyTorch:
|
||||
predictions with the CNN model.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size, num_features, output_size=3, timeframes=None):
|
||||
def __init__(self, window_size, num_features, output_size=5, timeframes=None):
|
||||
"""
|
||||
Initialize the CNN model.
|
||||
|
||||
@ -506,41 +493,27 @@ class CNNModelPyTorch:
|
||||
|
||||
def extract_hidden_features(self, X):
|
||||
"""
|
||||
Extract hidden features from the model.
|
||||
Extract hidden features from the model - outputs from last dense layer before output.
|
||||
|
||||
Args:
|
||||
X: Input data
|
||||
|
||||
Returns:
|
||||
Hidden features
|
||||
Hidden features (output from penultimate dense layer)
|
||||
"""
|
||||
# Convert to PyTorch tensor
|
||||
X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
|
||||
|
||||
# Forward pass through the model up to the last hidden layer
|
||||
# Forward pass through the model
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
# Get features before the output layer
|
||||
# Get features through CNN layers
|
||||
x_t = X_tensor.transpose(1, 2)
|
||||
conv_out = self.model.conv_layers(x_t)
|
||||
|
||||
# Process through parallel conv paths
|
||||
conv_outputs = []
|
||||
for path in self.model.conv_paths:
|
||||
conv_outputs.append(path(x_t))
|
||||
|
||||
# Concatenate conv outputs
|
||||
conv_concat = torch.cat(conv_outputs, dim=1)
|
||||
|
||||
# Transpose back for LSTM
|
||||
conv_concat = conv_concat.transpose(1, 2)
|
||||
|
||||
# LSTM processing
|
||||
lstm_out, _ = self.model.lstm(conv_concat)
|
||||
|
||||
# Flatten
|
||||
flattened = self.model.flatten(lstm_out)
|
||||
|
||||
# Dense processing
|
||||
hidden_features = self.model.dense1(flattened)
|
||||
# Process through all dense layers except the output layer
|
||||
features = conv_out
|
||||
for layer in self.model.dense_block[:-2]: # Exclude last linear layer and dropout
|
||||
features = layer(features)
|
||||
|
||||
return hidden_features.cpu().numpy()
|
||||
return features.cpu().numpy()
|
||||
|
@ -118,9 +118,7 @@ def main():
|
||||
logger.info("Initializing data interface...")
|
||||
data_interface = DataInterface(
|
||||
symbol=args.symbol,
|
||||
timeframes=args.timeframes,
|
||||
window_size=args.window_size,
|
||||
output_size=args.output_size
|
||||
timeframes=args.timeframes
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize data interface: {str(e)}")
|
||||
|
@ -10,6 +10,12 @@ import numpy as np
|
||||
from threading import Thread
|
||||
from queue import Queue
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
import os
|
||||
import pandas as pd
|
||||
from collections import deque
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -19,7 +25,8 @@ class RealtimeAnalyzer:
|
||||
|
||||
Features:
|
||||
- Connects to real-time data sources (websockets)
|
||||
- Processes incoming data through the neural network
|
||||
- Processes tick data into multiple timeframes (1s, 1m, 1h, 1d)
|
||||
- Uses trained models to analyze all timeframes
|
||||
- Generates trading signals
|
||||
- Manages risk and position sizing
|
||||
- Logs all trading decisions
|
||||
@ -33,17 +40,26 @@ class RealtimeAnalyzer:
|
||||
data_interface (DataInterface): Preconfigured data interface
|
||||
model: Trained neural network model
|
||||
symbol (str): Trading pair symbol
|
||||
timeframes (list): List of timeframes to monitor
|
||||
timeframes (list): List of timeframes to monitor (default: ['1s', '1m', '1h', '1d'])
|
||||
"""
|
||||
self.data_interface = data_interface
|
||||
self.model = model
|
||||
self.symbol = symbol
|
||||
self.timeframes = timeframes or ['1h']
|
||||
self.timeframes = timeframes or ['1s', '1m', '1h', '1d']
|
||||
self.running = False
|
||||
self.data_queue = Queue()
|
||||
self.prediction_interval = 60 # Seconds between predictions
|
||||
self.prediction_interval = 10 # Seconds between predictions
|
||||
self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
|
||||
self.ws = None
|
||||
self.tick_storage = deque(maxlen=10000) # Store up to 10,000 ticks
|
||||
self.candle_cache = {
|
||||
'1s': deque(maxlen=5000),
|
||||
'1m': deque(maxlen=5000),
|
||||
'1h': deque(maxlen=5000),
|
||||
'1d': deque(maxlen=5000)
|
||||
}
|
||||
|
||||
logger.info(f"RealtimeAnalyzer initialized for {symbol}")
|
||||
logger.info(f"RealtimeAnalyzer initialized for {symbol} with timeframes: {self.timeframes}")
|
||||
|
||||
def start(self):
|
||||
"""Start the realtime analysis process."""
|
||||
@ -53,9 +69,13 @@ class RealtimeAnalyzer:
|
||||
|
||||
self.running = True
|
||||
|
||||
# Start data collection thread
|
||||
self.data_thread = Thread(target=self._collect_data, daemon=True)
|
||||
self.data_thread.start()
|
||||
# Start WebSocket connection thread
|
||||
self.ws_thread = Thread(target=self._run_websocket, daemon=True)
|
||||
self.ws_thread.start()
|
||||
|
||||
# Start data processing thread
|
||||
self.processing_thread = Thread(target=self._process_data, daemon=True)
|
||||
self.processing_thread.start()
|
||||
|
||||
# Start analysis thread
|
||||
self.analysis_thread = Thread(target=self._analyze_data, daemon=True)
|
||||
@ -66,42 +86,128 @@ class RealtimeAnalyzer:
|
||||
def stop(self):
|
||||
"""Stop the realtime analysis process."""
|
||||
self.running = False
|
||||
if hasattr(self, 'data_thread'):
|
||||
self.data_thread.join(timeout=1)
|
||||
if self.ws:
|
||||
asyncio.run(self.ws.close())
|
||||
if hasattr(self, 'ws_thread'):
|
||||
self.ws_thread.join(timeout=1)
|
||||
if hasattr(self, 'processing_thread'):
|
||||
self.processing_thread.join(timeout=1)
|
||||
if hasattr(self, 'analysis_thread'):
|
||||
self.analysis_thread.join(timeout=1)
|
||||
logger.info("Realtime analysis stopped")
|
||||
|
||||
def _collect_data(self):
|
||||
"""Thread function for collecting real-time data."""
|
||||
logger.info("Starting data collection thread")
|
||||
|
||||
# In a real implementation, this would connect to websockets/API
|
||||
# For now, we'll simulate data collection from the data interface
|
||||
def _run_websocket(self):
|
||||
"""Thread function for running WebSocket connection."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(self._connect_websocket())
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Connect to WebSocket and receive data."""
|
||||
while self.running:
|
||||
try:
|
||||
# Get latest data for each timeframe
|
||||
for timeframe in self.timeframes:
|
||||
# Get recent data (simulating real-time updates)
|
||||
X, timestamp = self.data_interface.prepare_realtime_input(
|
||||
timeframe=timeframe,
|
||||
n_candles=30,
|
||||
window_size=self.data_interface.window_size
|
||||
)
|
||||
logger.info(f"Connecting to WebSocket: {self.ws_url}")
|
||||
async with websockets.connect(self.ws_url) as ws:
|
||||
self.ws = ws
|
||||
logger.info("WebSocket connected")
|
||||
|
||||
if X is not None:
|
||||
self.data_queue.put({
|
||||
'timeframe': timeframe,
|
||||
'data': X,
|
||||
'timestamp': timestamp
|
||||
})
|
||||
while self.running:
|
||||
try:
|
||||
message = await ws.recv()
|
||||
data = json.loads(message)
|
||||
|
||||
if 'e' in data and data['e'] == 'trade':
|
||||
tick = {
|
||||
'timestamp': data['T'],
|
||||
'price': float(data['p']),
|
||||
'volume': float(data['q']),
|
||||
'symbol': self.symbol
|
||||
}
|
||||
self.tick_storage.append(tick)
|
||||
self.data_queue.put(tick)
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.warning("WebSocket connection closed")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving WebSocket message: {str(e)}")
|
||||
time.sleep(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket connection error: {str(e)}")
|
||||
time.sleep(5) # Wait before reconnecting
|
||||
|
||||
def _process_data(self):
|
||||
"""Process incoming tick data into candles for all timeframes."""
|
||||
logger.info("Starting data processing thread")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Process any new ticks
|
||||
while not self.data_queue.empty():
|
||||
tick = self.data_queue.get()
|
||||
|
||||
# Convert timestamp to datetime
|
||||
timestamp = datetime.fromtimestamp(tick['timestamp'] / 1000)
|
||||
|
||||
# Process for each timeframe
|
||||
for timeframe in self.timeframes:
|
||||
interval = self._get_interval_seconds(timeframe)
|
||||
if interval is None:
|
||||
continue
|
||||
|
||||
# Round timestamp to nearest candle interval
|
||||
candle_ts = int(tick['timestamp'] // (interval * 1000)) * (interval * 1000)
|
||||
|
||||
# Get or create candle for this timeframe
|
||||
if not self.candle_cache[timeframe]:
|
||||
# First candle for this timeframe
|
||||
candle = {
|
||||
'timestamp': candle_ts,
|
||||
'open': tick['price'],
|
||||
'high': tick['price'],
|
||||
'low': tick['price'],
|
||||
'close': tick['price'],
|
||||
'volume': tick['volume']
|
||||
}
|
||||
self.candle_cache[timeframe].append(candle)
|
||||
else:
|
||||
# Update existing candle
|
||||
last_candle = self.candle_cache[timeframe][-1]
|
||||
|
||||
if last_candle['timestamp'] == candle_ts:
|
||||
# Update current candle
|
||||
last_candle['high'] = max(last_candle['high'], tick['price'])
|
||||
last_candle['low'] = min(last_candle['low'], tick['price'])
|
||||
last_candle['close'] = tick['price']
|
||||
last_candle['volume'] += tick['volume']
|
||||
else:
|
||||
# New candle
|
||||
candle = {
|
||||
'timestamp': candle_ts,
|
||||
'open': tick['price'],
|
||||
'high': tick['price'],
|
||||
'low': tick['price'],
|
||||
'close': tick['price'],
|
||||
'volume': tick['volume']
|
||||
}
|
||||
self.candle_cache[timeframe].append(candle)
|
||||
|
||||
# Throttle data collection
|
||||
time.sleep(1)
|
||||
time.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in data collection: {str(e)}")
|
||||
time.sleep(5) # Wait before retrying
|
||||
logger.error(f"Error in data processing: {str(e)}")
|
||||
time.sleep(1)
|
||||
|
||||
def _get_interval_seconds(self, timeframe):
|
||||
"""Convert timeframe string to seconds."""
|
||||
intervals = {
|
||||
'1s': 1,
|
||||
'1m': 60,
|
||||
'1h': 3600,
|
||||
'1d': 86400
|
||||
}
|
||||
return intervals.get(timeframe)
|
||||
|
||||
def _analyze_data(self):
|
||||
"""Thread function for analyzing data and generating signals."""
|
||||
@ -118,27 +224,56 @@ class RealtimeAnalyzer:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
# Get latest data from queue
|
||||
if not self.data_queue.empty():
|
||||
data_item = self.data_queue.get()
|
||||
# Prepare input data from all timeframes
|
||||
input_data = {}
|
||||
valid = True
|
||||
|
||||
for timeframe in self.timeframes:
|
||||
if not self.candle_cache[timeframe]:
|
||||
logger.warning(f"No data available for timeframe {timeframe}")
|
||||
valid = False
|
||||
break
|
||||
|
||||
# Get last N candles for this timeframe
|
||||
candles = list(self.candle_cache[timeframe])[-self.data_interface.window_size:]
|
||||
|
||||
# Make prediction
|
||||
prediction = self.model.predict(data_item['data'])
|
||||
# Convert to numpy array
|
||||
ohlcv = np.array([
|
||||
[c['open'], c['high'], c['low'], c['close'], c['volume']]
|
||||
for c in candles
|
||||
])
|
||||
|
||||
# Normalize data
|
||||
ohlcv_normalized = (ohlcv - ohlcv.mean(axis=0)) / (ohlcv.std(axis=0) + 1e-8)
|
||||
input_data[timeframe] = ohlcv_normalized
|
||||
|
||||
if not valid:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
# Make prediction using the model
|
||||
try:
|
||||
prediction = self.model.predict(input_data)
|
||||
|
||||
# Get latest timestamp from 1s timeframe
|
||||
latest_ts = self.candle_cache['1s'][-1]['timestamp'] if self.candle_cache['1s'] else int(time.time() * 1000)
|
||||
|
||||
# Process prediction
|
||||
self._process_prediction(
|
||||
prediction=prediction,
|
||||
timeframe=data_item['timeframe'],
|
||||
timestamp=data_item['timestamp']
|
||||
timeframe='multi',
|
||||
timestamp=latest_ts
|
||||
)
|
||||
|
||||
last_prediction_time = current_time
|
||||
except Exception as e:
|
||||
logger.error(f"Error making prediction: {str(e)}")
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in analysis: {str(e)}")
|
||||
time.sleep(1) # Wait before retrying
|
||||
time.sleep(1)
|
||||
|
||||
def _process_prediction(self, prediction, timeframe, timestamp):
|
||||
"""
|
||||
@ -146,17 +281,23 @@ class RealtimeAnalyzer:
|
||||
|
||||
Args:
|
||||
prediction: Model prediction output
|
||||
timeframe (str): Timeframe the prediction is for
|
||||
timestamp: Timestamp of the prediction
|
||||
timeframe (str): Timeframe the prediction is for ('multi' for combined)
|
||||
timestamp: Timestamp of the prediction (ms)
|
||||
"""
|
||||
# Convert prediction to trading signal
|
||||
signal = self._prediction_to_signal(prediction)
|
||||
signal, confidence = self._prediction_to_signal(prediction)
|
||||
|
||||
# Log the signal
|
||||
# Convert timestamp to datetime
|
||||
try:
|
||||
dt = datetime.fromtimestamp(timestamp / 1000)
|
||||
except:
|
||||
dt = datetime.now()
|
||||
|
||||
# Log the signal with all timeframes
|
||||
logger.info(
|
||||
f"Signal generated - Timeframe: {timeframe}, "
|
||||
f"Timestamp: {timestamp}, "
|
||||
f"Signal: {signal}"
|
||||
f"Signal generated - Timeframes: {', '.join(self.timeframes)}, "
|
||||
f"Timestamp: {dt}, "
|
||||
f"Signal: {signal} (Confidence: {confidence:.2f})"
|
||||
)
|
||||
|
||||
# In a real implementation, we would execute trades here
|
||||
@ -164,19 +305,60 @@ class RealtimeAnalyzer:
|
||||
|
||||
def _prediction_to_signal(self, prediction):
|
||||
"""
|
||||
Convert model prediction to trading signal.
|
||||
Convert model prediction to trading signal and confidence.
|
||||
|
||||
Args:
|
||||
prediction: Model prediction output
|
||||
prediction: Model prediction output (can be dict for multi-timeframe)
|
||||
|
||||
Returns:
|
||||
str: Trading signal (BUY, SELL, HOLD)
|
||||
tuple: (signal, confidence) where signal is BUY/SELL/HOLD,
|
||||
confidence is probability (0-1)
|
||||
"""
|
||||
# Simple threshold-based signal generation
|
||||
if len(prediction.shape) == 1:
|
||||
# Binary classification
|
||||
return "BUY" if prediction[0] > 0.5 else "SELL"
|
||||
if isinstance(prediction, dict):
|
||||
# Multi-timeframe prediction - combine signals
|
||||
signals = []
|
||||
confidences = []
|
||||
|
||||
for tf, pred in prediction.items():
|
||||
if len(pred.shape) == 1:
|
||||
# Binary classification
|
||||
signal = "BUY" if pred[0] > 0.5 else "SELL"
|
||||
confidence = pred[0] if signal == "BUY" else 1 - pred[0]
|
||||
else:
|
||||
# Multi-class
|
||||
class_idx = np.argmax(pred)
|
||||
signal = ["SELL", "HOLD", "BUY"][class_idx]
|
||||
confidence = pred[class_idx]
|
||||
|
||||
signals.append(signal)
|
||||
confidences.append(confidence)
|
||||
|
||||
# Simple voting system - count BUY/SELL signals
|
||||
buy_count = signals.count("BUY")
|
||||
sell_count = signals.count("SELL")
|
||||
|
||||
if buy_count > sell_count:
|
||||
final_signal = "BUY"
|
||||
final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "BUY"])
|
||||
elif sell_count > buy_count:
|
||||
final_signal = "SELL"
|
||||
final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "SELL"])
|
||||
else:
|
||||
final_signal = "HOLD"
|
||||
final_confidence = np.mean(confidences)
|
||||
|
||||
return final_signal, final_confidence
|
||||
|
||||
else:
|
||||
# Multi-class classification (3 outputs)
|
||||
class_idx = np.argmax(prediction)
|
||||
return ["SELL", "HOLD", "BUY"][class_idx]
|
||||
# Single prediction
|
||||
if len(prediction.shape) == 1:
|
||||
# Binary classification
|
||||
signal = "BUY" if prediction[0] > 0.5 else "SELL"
|
||||
confidence = prediction[0] if signal == "BUY" else 1 - prediction[0]
|
||||
else:
|
||||
# Multi-class
|
||||
class_idx = np.argmax(prediction)
|
||||
signal = ["SELL", "HOLD", "BUY"][class_idx]
|
||||
confidence = prediction[class_idx]
|
||||
|
||||
return signal, confidence
|
||||
|
Reference in New Issue
Block a user