Merge commit 'd49a473ed6f4aef55bfdd47d6370e53582be6b7b' into cleanup
@@ -1,16 +0,0 @@
"""
Neural Network Trading System
=============================

A comprehensive neural network trading system that uses deep learning models
to analyze cryptocurrency price data and generate trading signals.

The system consists of:
1. Data Interface: Connects to realtime trading data
2. CNN Model: Deep convolutional neural network for feature extraction
3. Transformer Model: Processes high-level features for improved pattern recognition
4. MoE: Mixture of Experts model that combines multiple neural networks
"""

__version__ = '0.1.0'
__author__ = 'Gogo2 Project'
@@ -1,11 +0,0 @@
"""
Neural Network Data
===================

This package is used to store datasets and model outputs.
It does not contain any code, but serves as a storage location for:
- Training datasets
- Evaluation results
- Inference outputs
- Model checkpoints
"""
@@ -1,6 +0,0 @@
# Trading environments for reinforcement learning
# This module contains environments for training trading agents

from NN.environments.trading_env import TradingEnvironment

__all__ = ['TradingEnvironment']
@@ -1,532 +0,0 @@
import numpy as np
import pandas as pd
from typing import Dict, Tuple, List, Any, Optional
import logging
import gym
from gym import spaces
import random

# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TradingEnvironment(gym.Env):
    """
    Trading environment implementing the gym interface for reinforcement learning.

    2-Action System:
    - 0: SELL (or close long position)
    - 1: BUY (or close short position)

    Intelligent Position Management:
    - When neutral: actions enter positions
    - When positioned: actions can close or flip positions
    - Different thresholds for entry vs exit decisions

    State:
    - OHLCV data from multiple timeframes
    - Technical indicators
    - Position data and unrealized PnL
    """

    def __init__(
        self,
        data_interface,
        initial_balance: float = 10000.0,
        transaction_fee: float = 0.0002,
        window_size: int = 20,
        max_position: float = 1.0,
        reward_scaling: float = 1.0,
        entry_threshold: float = 0.6,  # Higher threshold for entering positions
        exit_threshold: float = 0.3,   # Lower threshold for exiting positions
    ):
        """
        Initialize the trading environment with the 2-action system.

        Args:
            data_interface: DataInterface instance used to get market data
            initial_balance: Initial balance in the base currency
            transaction_fee: Fee for each transaction as a fraction of trade value
            window_size: Number of candles in the observation window
            max_position: Maximum position size as a fraction of balance
            reward_scaling: Scale factor for rewards
            entry_threshold: Confidence threshold for entering new positions
            exit_threshold: Confidence threshold for exiting positions
        """
        super().__init__()

        self.data_interface = data_interface
        self.initial_balance = initial_balance
        self.transaction_fee = transaction_fee
        self.window_size = window_size
        self.max_position = max_position
        self.reward_scaling = reward_scaling
        self.entry_threshold = entry_threshold
        self.exit_threshold = exit_threshold

        # Load data for the primary timeframe (the first one is assumed primary)
        self.timeframe = self.data_interface.timeframes[0]
        self.reset_data()

        # Define action and observation spaces for the 2-action system
        self.action_space = spaces.Discrete(2)  # 0=SELL, 1=BUY

        # For the observation space, we consider multiple timeframes with OHLCV data
        # and additional features like technical indicators, position info, etc.
        n_timeframes = len(self.data_interface.timeframes)
        n_features = 5  # OHLCV data by default

        # Additional features appended by _get_observation:
        # position, balance, unrealized_pnl (must match _get_observation's output)
        additional_features = 3

        # Calculate the total feature dimension
        total_features = (n_timeframes * n_features * self.window_size) + additional_features

        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(total_features,), dtype=np.float32
        )

        # Use the tuple state_shape that EnhancedCNN expects
        self.state_shape = (total_features,)

        # Position tracking for the 2-action system
        self.position = 0.0     # -1 (short), 0 (neutral), 1 (long)
        self.entry_price = 0.0  # Price at which the position was entered
        self.entry_step = 0     # Step at which the position was entered

        # Initialize state
        self.reset()

    def reset_data(self):
        """Reset data and load a fresh set of price data for training"""
        # Get data for each timeframe
        self.data = {}
        for tf in self.data_interface.timeframes:
            df = self.data_interface.dataframes[tf]
            if df is not None and not df.empty:
                self.data[tf] = df

        if not self.data:
            raise ValueError("No data available for training")

        # Use the primary timeframe for the step count
        self.prices = self.data[self.timeframe]['close'].values
        self.timestamps = self.data[self.timeframe].index.values
        self.max_steps = len(self.prices) - self.window_size - 1

    def reset(self):
        """Reset the environment to its initial state"""
        # Reset trading variables
        self.balance = self.initial_balance
        self.trades = []
        self.rewards = []

        # Reset position state and trade statistics; these attributes are read by
        # step(), _close_position() and render() and must exist before the first step
        self.position = 0.0
        self.entry_price = 0.0
        self.entry_step = 0
        self.total_pnl = 0.0
        self.winning_trades = 0
        self.losing_trades = 0

        # Reset step counter
        self.current_step = self.window_size

        # Get initial observation
        observation = self._get_observation()

        return observation

    def step(self, action):
        """
        Take a step in the environment using the 2-action system with intelligent position management.

        Args:
            action: Action to take (0: SELL, 1: BUY)

        Returns:
            tuple: (observation, reward, done, info)
        """
        # Get the current state before taking the action
        prev_balance = self.balance
        prev_position = self.position
        prev_price = self.prices[self.current_step]

        # Take the action with intelligent position management
        info = {}
        reward = 0
        last_position_info = None

        # Get the current price
        current_price = self.prices[self.current_step]
        next_price = self.prices[self.current_step + 1] if self.current_step + 1 < len(self.prices) else current_price

        # Implement the 2-action system with position management
        if action == 0:  # SELL action
            if self.position == 0:  # No position - enter short
                self._open_position(-1.0 * self.max_position, current_price)
                logger.info(f"ENTER SHORT at step {self.current_step}, price: {current_price:.4f}")
                reward = -self.transaction_fee  # Entry cost

            elif self.position > 0:  # Long position - close it
                close_pnl, last_position_info = self._close_position(current_price)
                reward += close_pnl * self.reward_scaling
                logger.info(f"CLOSE LONG at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")

            elif self.position < 0:  # Already short - potentially flip to long on a very strong signal
                # For now, just hold the short position (no action)
                pass

        elif action == 1:  # BUY action
            if self.position == 0:  # No position - enter long
                self._open_position(1.0 * self.max_position, current_price)
                logger.info(f"ENTER LONG at step {self.current_step}, price: {current_price:.4f}")
                reward = -self.transaction_fee  # Entry cost

            elif self.position < 0:  # Short position - close it
                close_pnl, last_position_info = self._close_position(current_price)
                reward += close_pnl * self.reward_scaling
                logger.info(f"CLOSE SHORT at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")

            elif self.position > 0:  # Already long - potentially flip to short on a very strong signal
                # For now, just hold the long position (no action)
                pass

        # Calculate unrealized PnL and add it to the reward when holding a position
        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(next_price)
            reward += unrealized_pnl * self.reward_scaling * 0.1  # Scale down unrealized PnL

            # Apply a time-based holding penalty to encourage decisive actions
            position_duration = self.current_step - self.entry_step
            holding_penalty = min(position_duration * 0.0001, 0.01)  # Max 1% penalty
            reward -= holding_penalty

        # Reward staying neutral when uncertain (no clear setup)
        else:
            reward += 0.0001  # Small reward for not trading without clear signals

        # Move to the next step
        self.current_step += 1

        # Get the new observation
        observation = self._get_observation()

        # Check whether the episode is done
        done = self.current_step >= len(self.prices) - 1

        # If done, close any remaining position
        if done and self.position != 0:
            final_pnl, last_position_info = self._close_position(current_price)
            reward += final_pnl * self.reward_scaling
            info['final_pnl'] = final_pnl
            info['final_balance'] = self.balance
            logger.info(f"Episode ended. Final balance: {self.balance:.4f}, Return: {(self.balance/self.initial_balance-1)*100:.2f}%")

        # Track the trade result if the position changed or a position was closed
        if prev_position != self.position or last_position_info is not None:
            # Calculate the realized PnL if a position was closed
            realized_pnl = 0
            position_info = {}

            if last_position_info is not None:
                # Use the position information from closing
                realized_pnl = last_position_info['pnl']
                position_info = last_position_info
            else:
                # Calculate manually based on the balance change
                realized_pnl = self.balance - prev_balance if prev_position != 0 else 0

            # Record detailed trade information
            trade_result = {
                'step': self.current_step,
                'timestamp': self.timestamps[self.current_step],
                'action': action,
                'action_name': ['SELL', 'BUY'][action],
                'price': current_price,
                'position_changed': prev_position != self.position,
                'prev_position': prev_position,
                'new_position': self.position,
                'position_size': abs(self.position) if self.position != 0 else abs(prev_position),
                'entry_price': position_info.get('entry_price', self.entry_price),
                'exit_price': position_info.get('exit_price', current_price),
                'realized_pnl': realized_pnl,
                'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0,
                'pnl': realized_pnl,  # Total PnL (realized for this step)
                'balance_before': prev_balance,
                'balance_after': self.balance,
                'trade_fee': position_info.get('fee', abs(self.position - prev_position) * current_price * self.transaction_fee)
            }
            info['trade_result'] = trade_result
            self.trades.append(trade_result)

            # Log trade details
            logger.info(f"Trade executed - Action: {['SELL', 'BUY'][action]}, "
                        f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
                        f"Balance: {self.balance:.4f}")

        # Store the reward
        self.rewards.append(reward)

        # Update the info dict with the current state
        info.update({
            'step': self.current_step,
            'price': current_price,
            'prev_price': prev_price,
            'price_change': (current_price - prev_price) / prev_price if prev_price != 0 else 0,
            'balance': self.balance,
            'position': self.position,
            'entry_price': self.entry_price,
            'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0.0,
            'total_trades': len(self.trades),
            'total_pnl': self.total_pnl,
            'return_pct': (self.balance/self.initial_balance-1)*100
        })

        return observation, reward, done, info

    def _calculate_unrealized_pnl(self, current_price):
        """Calculate the unrealized PnL for the current position"""
        if self.position == 0 or self.entry_price == 0:
            return 0.0

        if self.position > 0:  # Long position
            return self.position * (current_price / self.entry_price - 1.0)
        else:  # Short position
            return -self.position * (1.0 - current_price / self.entry_price)
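    # Worked example of the PnL convention above (illustrative numbers):
    # a long of size 1.0 entered at 100.0 and marked at 110.0 yields
    # 1.0 * (110.0 / 100.0 - 1.0) = 0.10, i.e. +10%; the mirror-image short
    # entered at 100.0 and marked at 90.0 yields the same +0.10.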
    def _open_position(self, position_size: float, entry_price: float):
        """Open a new position"""
        self.position = position_size
        self.entry_price = entry_price
        self.entry_step = self.current_step

        # Calculate the position value
        position_value = abs(position_size) * entry_price

        # Apply the transaction fee
        fee = position_value * self.transaction_fee
        self.balance -= fee

        logger.info(f"Opened position: {position_size:.4f} at {entry_price:.4f}, fee: {fee:.4f}")

    def _close_position(self, exit_price: float) -> Tuple[float, Dict]:
        """Close the current position and return its PnL"""
        if self.position == 0:
            return 0.0, {}

        # Calculate the PnL
        if self.position > 0:  # Long position
            pnl = (exit_price - self.entry_price) / self.entry_price
        else:  # Short position
            pnl = (self.entry_price - exit_price) / self.entry_price

        # Apply the exit transaction fee (the entry fee was already applied when opening)
        position_value = abs(self.position) * exit_price
        exit_fee = position_value * self.transaction_fee
        total_fees = exit_fee

        # Net PnL after fees
        net_pnl = pnl - (total_fees / (abs(self.position) * self.entry_price))

        # Update the balance
        self.balance *= (1 + net_pnl)
        self.total_pnl += net_pnl

        # Track the trade
        position_info = {
            'position_size': self.position,
            'entry_price': self.entry_price,
            'exit_price': exit_price,
            'pnl': net_pnl,
            'duration': self.current_step - self.entry_step,
            'entry_step': self.entry_step,
            'exit_step': self.current_step
        }

        self.trades.append(position_info)

        # Update trade statistics
        if net_pnl > 0:
            self.winning_trades += 1
        else:
            self.losing_trades += 1

        logger.info(f"Closed position: {self.position:.4f}, PnL: {net_pnl:.4f}, Duration: {position_info['duration']} steps")

        # Reset the position
        self.position = 0.0
        self.entry_price = 0.0
        self.entry_step = 0

        return net_pnl, position_info

    def _get_observation(self):
        """
        Get the current observation.

        Returns:
            np.array: The observation vector
        """
        observations = []

        # Get data from each timeframe
        for tf in self.data_interface.timeframes:
            if tf in self.data:
                # Get the window of data for this timeframe
                df = self.data[tf]
                start_idx = self._align_timeframe_index(tf)

                if start_idx is not None and start_idx >= 0 and start_idx + self.window_size <= len(df):
                    window = df.iloc[start_idx:start_idx + self.window_size]

                    # Extract OHLCV data
                    ohlcv = window[['open', 'high', 'low', 'close', 'volume']].values

                    # Normalize OHLCV data against the last close price
                    last_close = ohlcv[-1, 3]  # Last close price
                    ohlcv_normalized = np.zeros_like(ohlcv)
                    ohlcv_normalized[:, 0] = ohlcv[:, 0] / last_close - 1.0  # open
                    ohlcv_normalized[:, 1] = ohlcv[:, 1] / last_close - 1.0  # high
                    ohlcv_normalized[:, 2] = ohlcv[:, 2] / last_close - 1.0  # low
                    ohlcv_normalized[:, 3] = ohlcv[:, 3] / last_close - 1.0  # close

                    # Normalize volume (relative to the moving average of volume)
                    if 'volume' in window.columns:
                        volume_ma = ohlcv[:, 4].mean()
                        if volume_ma > 0:
                            ohlcv_normalized[:, 4] = ohlcv[:, 4] / volume_ma - 1.0
                        else:
                            ohlcv_normalized[:, 4] = 0.0
                    else:
                        ohlcv_normalized[:, 4] = 0.0

                    # Flatten and add to the observations
                    observations.append(ohlcv_normalized.flatten())
                else:
                    # Fill with zeros if there is not enough data
                    observations.append(np.zeros(self.window_size * 5))

        # Add position and balance information
        current_price = self.prices[self.current_step]
        position_info = np.array([
            self.position / self.max_position,          # Normalized position (-1 to 1)
            self.balance / self.initial_balance - 1.0,  # Normalized balance change
            self._calculate_unrealized_pnl(current_price)  # Unrealized PnL
        ])

        observations.append(position_info)

        # Concatenate all observations
        observation = np.concatenate(observations)
        return observation

    def _align_timeframe_index(self, timeframe):
        """
        Align the index of a higher timeframe with the current step in the primary timeframe.

        Args:
            timeframe: The timeframe to align

        Returns:
            int: The starting index in the higher timeframe
        """
        if timeframe == self.timeframe:
            return self.current_step - self.window_size

        # Get the timestamp for the current primary timeframe step
        primary_ts = self.timestamps[self.current_step]

        # Find the closest index in the higher timeframe
        higher_ts = self.data[timeframe].index.values
        idx = np.searchsorted(higher_ts, primary_ts)

        # Adjust to get the starting index
        start_idx = max(0, idx - self.window_size)
        return start_idx

    def get_last_positions(self, n=5):
        """
        Get detailed information about the last n positions.

        Args:
            n: Number of last positions to return

        Returns:
            list: List of dictionaries containing position details
        """
        if not self.trades:
            return []

        # Filter trades to include only those that closed positions
        position_trades = [t for t in self.trades if t.get('realized_pnl', 0) != 0 or (t.get('prev_position', 0) != 0 and t.get('new_position', 0) == 0)]

        positions = []
        last_n_trades = position_trades[-n:] if len(position_trades) >= n else position_trades

        for trade in last_n_trades:
            position_info = {
                'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
                'action': trade.get('action_name', ['SELL', 'BUY'][trade['action']]),
                'entry_price': trade.get('entry_price', 0.0),
                'exit_price': trade.get('exit_price', trade['price']),
                'position_size': trade.get('position_size', self.max_position),
                'realized_pnl': trade.get('realized_pnl', 0.0),
                'fee': trade.get('trade_fee', 0.0),
                'pnl': trade.get('pnl', 0.0),
                'pnl_percentage': (trade.get('pnl', 0.0) / self.initial_balance) * 100,
                'balance_before': trade.get('balance_before', 0.0),
                'balance_after': trade.get('balance_after', 0.0),
                'duration': trade.get('duration', 'N/A')
            }
            positions.append(position_info)

        return positions

    def render(self, mode='human'):
        """Render the environment"""
        current_step = self.current_step
        current_price = self.prices[current_step]

        # Display basic information
        print("\nTrading Environment Status:")
        print("============================")
        print(f"Step: {current_step}/{len(self.prices)-1}")
        print(f"Current Price: {current_price:.4f}")
        print(f"Current Balance: {self.balance:.4f}")
        print(f"Current Position: {self.position:.4f}")

        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(current_price)
            print(f"Entry Price: {self.entry_price:.4f}")
            print(f"Unrealized PnL: {unrealized_pnl:.4f} ({unrealized_pnl/self.balance*100:.2f}%)")

        print(f"Total PnL: {self.total_pnl:.4f} ({self.total_pnl/self.initial_balance*100:.2f}%)")
        print(f"Total Trades: {len(self.trades)}")

        if len(self.trades) > 0:
            win_trades = [t for t in self.trades if t.get('realized_pnl', 0) > 0]
            win_count = len(win_trades)
            # Count trades that closed positions (not just changed them)
            closed_positions = [t for t in self.trades if t.get('realized_pnl', 0) != 0]
            closed_count = len(closed_positions)
            win_rate = win_count / closed_count if closed_count > 0 else 0
            print(f"Positions Closed: {closed_count}")
            print(f"Winning Positions: {win_count}")
            print(f"Win Rate: {win_rate:.2f}")

        # Display the last 5 positions
        print("\nLast 5 Positions:")
        print("================")
        last_positions = self.get_last_positions(5)

        if not last_positions:
            print("No closed positions yet.")

        for pos in last_positions:
            print(f"Time: {pos['timestamp']}")
            print(f"Action: {pos['action']}")
            print(f"Entry: {pos['entry_price']:.4f}, Exit: {pos['exit_price']:.4f}")
            print(f"Size: {pos['position_size']:.4f}")
            print(f"PnL: {pos['realized_pnl']:.4f} ({pos['pnl_percentage']:.2f}%)")
            print(f"Fee: {pos['fee']:.4f}")
            print(f"Balance: {pos['balance_before']:.4f} -> {pos['balance_after']:.4f}")
            print("----------------")

        return

    def close(self):
        """Close the environment"""
        pass
@@ -1,162 +0,0 @@
# Trading Agent System

This directory contains the implementation of a modular trading agent system that integrates with the neural network models and can execute trades on various cryptocurrency exchanges.

## Overview

The trading agent system is designed to:

1. Connect to different cryptocurrency exchanges using a common interface
2. Execute trades based on signals from neural network models
3. Manage risk through position sizing, trade limits, and cooldown periods
4. Monitor and report on trading activity

## Components

### Exchange Interfaces

- `ExchangeInterface`: Abstract base class defining the common interface for all exchange implementations
- `BinanceInterface`: Implementation for the Binance exchange, with support for both mainnet and testnet
- `MEXCInterface`: Implementation for the MEXC exchange

### Trading Agent

The `TradingAgent` class (`trading_agent.py`) manages trading activities:

- Connects to the configured exchange
- Processes trading signals from neural network models
- Applies trading rules and risk management
- Tracks and reports trading performance

### Neural Network Orchestrator

The `NeuralNetworkOrchestrator` class (`neural_network_orchestrator.py`) coordinates between models and trading:

- Manages the neural network inference process
- Routes model signals to the trading agent
- Provides integration with the RealTimeChart for visualization

## Usage

### Basic Usage

```python
import time

from NN.exchanges import BinanceInterface, MEXCInterface
from NN.trading_agent import TradingAgent

# Initialize an exchange interface
exchange = BinanceInterface(
    api_key="your_api_key",
    api_secret="your_api_secret",
    test_mode=True  # Use testnet
)

# Connect to the exchange
exchange.connect()

# Create a trading agent
agent = TradingAgent(
    exchange_name="binance",
    api_key="your_api_key",
    api_secret="your_api_secret",
    test_mode=True,
    trade_symbols=["BTC/USDT", "ETH/USDT"],
    position_size=0.1,
    max_trades_per_day=5,
    trade_cooldown_minutes=60
)

# Start the trading agent
agent.start()

# Process a trading signal
agent.process_signal(
    symbol="BTC/USDT",
    action="BUY",
    confidence=0.85,
    timestamp=int(time.time())
)

# Stop the trading agent when done
agent.stop()
```

### Integration with Neural Network Models

The system is designed to be integrated with neural network models through the `NeuralNetworkOrchestrator`:

```python
from NN.neural_network_orchestrator import NeuralNetworkOrchestrator

# Configure the exchange
exchange_config = {
    "exchange": "binance",
    "api_key": "your_api_key",
    "api_secret": "your_api_secret",
    "test_mode": True,
    "trade_symbols": ["BTC/USDT", "ETH/USDT"],
    "position_size": 0.1,
    "max_trades_per_day": 5,
    "trade_cooldown_minutes": 60
}

# Initialize the orchestrator
orchestrator = NeuralNetworkOrchestrator(
    model=model,
    data_interface=data_interface,
    chart=chart,
    symbols=["BTC/USDT", "ETH/USDT"],
    timeframes=["1m", "5m", "1h", "4h", "1d"],
    exchange_config=exchange_config
)

# Start inference and trading
orchestrator.start_inference()
```

## Configuration

### Exchange-Specific Configuration

- **Binance**: Supports both mainnet and testnet environments
- **MEXC**: Supports mainnet only (no test environment available)

### Trading Agent Configuration

- `exchange_name`: Name of the exchange ('binance', 'mexc')
- `api_key`: API key for the exchange
- `api_secret`: API secret for the exchange
- `test_mode`: Whether to use the test/sandbox environment
- `trade_symbols`: List of trading symbols to monitor
- `position_size`: Size of each position as a fraction of balance (0.0-1.0)
- `max_trades_per_day`: Maximum number of trades to execute per day
- `trade_cooldown_minutes`: Minimum time between trades in minutes

## Adding New Exchanges

To add support for a new exchange:

1. Create a new class that inherits from `ExchangeInterface`
2. Implement all required methods (see `exchange_interface.py`)
3. Add the new exchange to the imports in `__init__.py`
4. Update the `_create_exchange` method in `TradingAgent` to support the new exchange

Example:

```python
class KrakenInterface(ExchangeInterface):
    """Kraken Exchange API Interface"""

    def __init__(self, api_key=None, api_secret=None, test_mode=True):
        super().__init__(api_key, api_secret, test_mode)
        # Initialize Kraken-specific attributes

    # Implement all required methods...
```

## Security Considerations

- API keys should have trade permissions but not withdrawal permissions
- Use environment variables or secure storage for API credentials (see the sketch below)
- Always test with small position sizes before deploying with larger amounts
- Consider using test mode/testnet for initial testing
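
As a minimal sketch of the environment-variable approach (the `BINANCE_API_KEY`/`BINANCE_API_SECRET` names are illustrative, not defined by this codebase):

```python
import os

from NN.exchanges import BinanceInterface

# Read credentials from the process environment instead of hard-coding them
api_key = os.environ.get("BINANCE_API_KEY")
api_secret = os.environ.get("BINANCE_API_SECRET")
if not api_key or not api_secret:
    raise RuntimeError("Exchange credentials are not set in the environment")

exchange = BinanceInterface(api_key=api_key, api_secret=api_secret, test_mode=True)
exchange.connect()
```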
@@ -1,5 +0,0 @@
from .exchange_interface import ExchangeInterface
from .mexc_interface import MEXCInterface
from .binance_interface import BinanceInterface

__all__ = ['ExchangeInterface', 'MEXCInterface', 'BinanceInterface']
@@ -1,276 +0,0 @@
import logging
import time
from typing import Dict, Any, List, Optional
import requests
import hmac
import hashlib
from urllib.parse import urlencode

from .exchange_interface import ExchangeInterface

logger = logging.getLogger(__name__)

class BinanceInterface(ExchangeInterface):
    """Binance Exchange API Interface"""

    def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True):
        """Initialize the Binance exchange interface.

        Args:
            api_key: Binance API key
            api_secret: Binance API secret
            test_mode: If True, use the testnet environment
        """
        super().__init__(api_key, api_secret, test_mode)

        # Use testnet URLs in test mode
        if test_mode:
            self.base_url = "https://testnet.binance.vision"
        else:
            self.base_url = "https://api.binance.com"

        self.api_version = "v3"

    def connect(self) -> bool:
        """Connect to the Binance API. This is a no-op for the REST API."""
        if not self.api_key or not self.api_secret:
            logger.warning("Binance API credentials not provided. Running in read-only mode.")
            return False

        try:
            # Test connectivity by pinging the server and checking account info
            ping_result = self._send_public_request('GET', 'ping')

            if self.api_key and self.api_secret:
                # Check account connectivity
                self.get_account_info()

            logger.info(f"Successfully connected to Binance API ({'testnet' if self.test_mode else 'live'})")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to Binance API: {str(e)}")
            return False

    def _generate_signature(self, params: Dict[str, Any]) -> str:
        """Generate an HMAC SHA256 signature for authenticated requests."""
        query_string = urlencode(params)
        signature = hmac.new(
            self.api_secret.encode('utf-8'),
            query_string.encode('utf-8'),
            hashlib.sha256
        ).hexdigest()
        return signature
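    # Example of the signing scheme above (illustrative values): for
    # params = {'timestamp': 1690000000000}, the signature is the HMAC-SHA256
    # of the query string "timestamp=1690000000000", keyed with the API secret,
    # and is then appended as the 'signature' request parameter.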
    def _send_public_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
        """Send a public request to the Binance API."""
        url = f"{self.base_url}/api/{self.api_version}/{endpoint}"

        try:
            if method.upper() == 'GET':
                response = requests.get(url, params=params)
            else:
                response = requests.post(url, json=params)

            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error in public request to {endpoint}: {str(e)}")
            raise

    def _send_private_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
        """Send a private/authenticated request to the Binance API."""
        if not self.api_key or not self.api_secret:
            raise ValueError("API key and secret are required for private requests")

        if params is None:
            params = {}

        # Add a timestamp
        params['timestamp'] = int(time.time() * 1000)

        # Generate the signature
        signature = self._generate_signature(params)
        params['signature'] = signature

        # Set headers
        headers = {
            'X-MBX-APIKEY': self.api_key
        }

        url = f"{self.base_url}/api/{self.api_version}/{endpoint}"

        try:
            if method.upper() == 'GET':
                response = requests.get(url, params=params, headers=headers)
            elif method.upper() == 'POST':
                response = requests.post(url, data=params, headers=headers)
            elif method.upper() == 'DELETE':
                response = requests.delete(url, params=params, headers=headers)
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            # Log a detailed error if available
            if response.status_code != 200:
                logger.error(f"Binance API error: {response.text}")

            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error in private request to {endpoint}: {str(e)}")
            raise

    def get_account_info(self) -> Dict[str, Any]:
        """Get account information."""
        return self._send_private_request('GET', 'account')

    def get_balance(self, asset: str) -> float:
        """Get the balance of a specific asset.

        Args:
            asset: Asset symbol (e.g., 'BTC', 'USDT')

        Returns:
            float: Available balance of the asset
        """
        try:
            account_info = self._send_private_request('GET', 'account')
            balances = account_info.get('balances', [])

            for balance in balances:
                if balance['asset'] == asset:
                    return float(balance['free'])

            # Asset not found
            return 0.0
        except Exception as e:
            logger.error(f"Error getting balance for {asset}: {str(e)}")
            return 0.0

    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker data for a symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')

        Returns:
            dict: Ticker data including price information
        """
        binance_symbol = symbol.replace('/', '')
        try:
            ticker = self._send_public_request('GET', 'ticker/24hr', {'symbol': binance_symbol})

            # Convert to a standardized format
            result = {
                'symbol': symbol,
                'bid': float(ticker['bidPrice']),
                'ask': float(ticker['askPrice']),
                'last': float(ticker['lastPrice']),
                'volume': float(ticker['volume']),
                'timestamp': int(ticker['closeTime'])
            }
            return result
        except Exception as e:
            logger.error(f"Error getting ticker for {symbol}: {str(e)}")
            raise

    def place_order(self, symbol: str, side: str, order_type: str,
                    quantity: float, price: float = None) -> Dict[str, Any]:
        """Place an order on the exchange.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            side: Order side ('buy' or 'sell')
            order_type: Order type ('market', 'limit', etc.)
            quantity: Order quantity
            price: Order price (for limit orders)

        Returns:
            dict: Order information including the order ID
        """
        binance_symbol = symbol.replace('/', '')
        params = {
            'symbol': binance_symbol,
            'side': side.upper(),
            'type': order_type.upper(),
            'quantity': quantity,
        }

        if order_type.lower() == 'limit' and price is not None:
            params['price'] = price
            params['timeInForce'] = 'GTC'  # Good Till Cancelled

        # Use the test order endpoint in test mode
        endpoint = 'order/test' if self.test_mode else 'order'

        try:
            order_result = self._send_private_request('POST', endpoint, params)
            return order_result
        except Exception as e:
            logger.error(f"Error placing {side} {order_type} order for {symbol}: {str(e)}")
            raise

    def cancel_order(self, symbol: str, order_id: str) -> bool:
        """Cancel an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order to cancel

        Returns:
            bool: True if the cancellation succeeded, False otherwise
        """
        binance_symbol = symbol.replace('/', '')
        params = {
            'symbol': binance_symbol,
            'orderId': order_id
        }

        try:
            cancel_result = self._send_private_request('DELETE', 'order', params)
            return True
        except Exception as e:
            logger.error(f"Error cancelling order {order_id} for {symbol}: {str(e)}")
            return False

    def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
        """Get the status of an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order

        Returns:
            dict: Order status information
        """
        binance_symbol = symbol.replace('/', '')
        params = {
            'symbol': binance_symbol,
            'orderId': order_id
        }

        try:
            order_info = self._send_private_request('GET', 'order', params)
            return order_info
        except Exception as e:
            logger.error(f"Error getting order status for {order_id} on {symbol}: {str(e)}")
            raise

    def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]:
        """Get all open orders, optionally filtered by symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT'), or None for all symbols

        Returns:
            list: List of open orders
        """
        params = {}
        if symbol:
            params['symbol'] = symbol.replace('/', '')

        try:
            open_orders = self._send_private_request('GET', 'openOrders', params)
            return open_orders
        except Exception as e:
            logger.error(f"Error getting open orders: {str(e)}")
            return []
@@ -1,191 +0,0 @@
import abc
import logging
from typing import Dict, Any, List, Tuple, Optional

logger = logging.getLogger(__name__)

class ExchangeInterface(abc.ABC):
    """Base class for all exchange interfaces.

    This abstract class defines the required methods that all exchange
    implementations must provide to ensure compatibility with the trading system.
    """

    def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True):
        """Initialize the exchange interface.

        Args:
            api_key: API key for the exchange
            api_secret: API secret for the exchange
            test_mode: If True, use the test/sandbox environment
        """
        self.api_key = api_key
        self.api_secret = api_secret
        self.test_mode = test_mode
        self.client = None
        self.last_price_cache = {}

    @abc.abstractmethod
    def connect(self) -> bool:
        """Connect to the exchange API.

        Returns:
            bool: True if the connection succeeded, False otherwise
        """
        pass

    @abc.abstractmethod
    def get_balance(self, asset: str) -> float:
        """Get the balance of a specific asset.

        Args:
            asset: Asset symbol (e.g., 'BTC', 'USDT')

        Returns:
            float: Available balance of the asset
        """
        pass

    @abc.abstractmethod
    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker data for a symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')

        Returns:
            dict: Ticker data including price information
        """
        pass

    @abc.abstractmethod
    def place_order(self, symbol: str, side: str, order_type: str,
                    quantity: float, price: float = None) -> Dict[str, Any]:
        """Place an order on the exchange.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            side: Order side ('buy' or 'sell')
            order_type: Order type ('market', 'limit', etc.)
            quantity: Order quantity
            price: Order price (for limit orders)

        Returns:
            dict: Order information including the order ID
        """
        pass

    @abc.abstractmethod
    def cancel_order(self, symbol: str, order_id: str) -> bool:
        """Cancel an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order to cancel

        Returns:
            bool: True if the cancellation succeeded, False otherwise
        """
        pass

    @abc.abstractmethod
    def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
        """Get the status of an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order

        Returns:
            dict: Order status information
        """
        pass

    @abc.abstractmethod
    def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]:
        """Get all open orders, optionally filtered by symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT'), or None for all symbols

        Returns:
            list: List of open orders
        """
        pass

    def get_last_price(self, symbol: str) -> float:
        """Get the last known price for a symbol; may use a cached value.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')

        Returns:
            float: Last price
        """
        try:
            ticker = self.get_ticker(symbol)
            price = float(ticker['last'])
            self.last_price_cache[symbol] = price
            return price
        except Exception as e:
            logger.error(f"Error getting price for {symbol}: {str(e)}")
            # Return the cached price if available
            return self.last_price_cache.get(symbol, 0.0)

    def execute_trade(self, symbol: str, action: str, quantity: float = None,
                      percent_of_balance: float = None) -> Optional[Dict[str, Any]]:
        """Execute a trade based on a signal.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            action: Trade action ('BUY', 'SELL')
            quantity: Specific quantity to trade
            percent_of_balance: Alternative to quantity - percentage of the available balance to use

        Returns:
            dict: Order information, or None if the order failed
        """
        if action not in ['BUY', 'SELL']:
            logger.error(f"Invalid action: {action}. Must be 'BUY' or 'SELL'")
            return None

        side = action.lower()

        try:
            # Determine the base and quote assets from the symbol (e.g., BTC/USDT -> BTC, USDT)
            base_asset, quote_asset = symbol.split('/')

            # Calculate the quantity if percent_of_balance is provided
            if quantity is None and percent_of_balance is not None:
                if percent_of_balance <= 0 or percent_of_balance > 1:
                    logger.error(f"Invalid percent_of_balance: {percent_of_balance}. Must be between 0 and 1")
                    return None

                if side == 'buy':
                    # For a buy, size against the quote asset (e.g., USDT)
                    balance = self.get_balance(quote_asset)
                    price = self.get_last_price(symbol)
                    quantity = (balance * percent_of_balance) / price
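                    # Example (illustrative figures): a 1000 USDT balance with
                    # percent_of_balance=0.1 at a price of 50000 gives
                    # (1000 * 0.1) / 50000 = 0.002 BTC.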
                else:
                    # For a sell, size against the base asset (e.g., BTC)
                    balance = self.get_balance(base_asset)
                    quantity = balance * percent_of_balance

            if not quantity or quantity <= 0:
                logger.error(f"Invalid quantity: {quantity}")
                return None

            # Place a market order
            order = self.place_order(
                symbol=symbol,
                side=side,
                order_type='market',
                quantity=quantity
            )

            logger.info(f"Executed {side.upper()} order for {quantity} {base_asset} at market price")
            return order

        except Exception as e:
            logger.error(f"Error executing {action} trade for {symbol}: {str(e)}")
            return None
@@ -1,520 +0,0 @@
import logging
import time
from typing import Dict, Any, List, Optional
import requests
import hmac
import hashlib
from urllib.parse import urlencode, quote_plus
import json  # Used for json.dumps when signing POST bodies

from .exchange_interface import ExchangeInterface

logger = logging.getLogger(__name__)


# https://github.com/mexcdevelop/mexc-api-postman/blob/main/MEXC%20V3.postman_collection.json
# MEXC V3.postman_collection.json

class MEXCInterface(ExchangeInterface):
    """MEXC Exchange API Interface"""

    def __init__(self, api_key: str = "", api_secret: str = "", test_mode: bool = True, trading_mode: str = 'simulation'):
        """Initialize the MEXC exchange interface.

        Args:
            api_key: MEXC API key
            api_secret: MEXC API secret
            test_mode: If True, use a test/sandbox environment (note: MEXC doesn't have a true sandbox)
            trading_mode: 'simulation', 'testnet', or 'live'. Determines which API endpoints are used.
        """
        super().__init__(api_key, api_secret, test_mode)

        self.trading_mode = trading_mode  # Store the trading mode

        # MEXC API base URLs
        self.base_url = "https://api.mexc.com"  # Live API URL
        if self.trading_mode == 'testnet':
            # Note: MEXC does not have a separate testnet for spot trading.
            # We use the live API for 'testnet' mode and rely on 'simulation' for true dry runs.
            logger.warning("MEXC does not have a separate testnet for spot trading. Using live API for 'testnet' mode.")

        self.api_version = "api/v3"
        self.recv_window = 5000  # 5-second window for request validity

        # Session for HTTP requests
        self.session = requests.Session()

        logger.info(f"MEXCInterface initialized in {self.trading_mode} mode. Ensure the correct API endpoints are being used.")

    def connect(self) -> bool:
        """Test the connection to the MEXC API by fetching account info."""
        if not self.api_key or not self.api_secret:
            logger.error("MEXC API key or secret not set. Cannot connect.")
            return False

        # Test the connection by making a small, authenticated request
        try:
            account_info = self.get_account_info()
            if account_info:
                logger.info("Successfully connected to MEXC API and retrieved account info.")
                return True
            else:
                logger.error("Failed to connect to MEXC API: could not retrieve account info.")
                return False
        except Exception as e:
            logger.error(f"Exception during MEXC API connection test: {e}")
            return False

    def _format_spot_symbol(self, symbol: str) -> str:
        """Format a symbol to the MEXC spot API standard (e.g., 'ETH/USDT' -> 'ETHUSDC')."""
        if '/' in symbol:
            base, quote = symbol.split('/')
            # Convert USDT to USDC for MEXC spot trading
            if quote.upper() == 'USDT':
                quote = 'USDC'
            return f"{base.upper()}{quote.upper()}"
        else:
            # Convert USDT to USDC for symbols like ETHUSDT
            symbol = symbol.upper()
            if symbol.endswith('USDT'):
                symbol = symbol.replace('USDT', 'USDC')
            return symbol

    def _format_futures_symbol(self, symbol: str) -> str:
        """Format a symbol to the MEXC futures API standard (e.g., 'ETH/USDT' -> 'ETH_USDT')."""
        # This method is included for completeness but should not be used for spot trading
        return symbol.replace('/', '_').upper()

    def _generate_signature(self, timestamp: str, method: str, endpoint: str, params: Dict[str, Any]) -> str:
        """Generate a signature for private API calls using MEXC's official method."""
        # The MEXC signature format varies by method:
        # - GET/DELETE: URL-encoded query string of alphabetically sorted parameters.
        # - POST: JSON string of the parameters (no sorting needed).
        # The API secret is used as the HMAC SHA256 key.

        # Remove the signature from params to avoid circular inclusion
        clean_params = {k: v for k, v in params.items() if k != 'signature'}

        parameter_string: str

        if method.upper() == "POST":
            # For POST requests, the signed parameter string is a JSON string.
            # Sort the keys for consistent JSON generation across runs, even though
            # MEXC says sorting is not required for POST params; it's good practice.
            parameter_string = json.dumps(clean_params, sort_keys=True, separators=(',', ':'))
        else:
            # For GET/DELETE requests, parameters are joined in dictionary order with '&'
            sorted_params = sorted(clean_params.items())
            parameter_string = '&'.join(f"{key}={str(value)}" for key, value in sorted_params)

        # The string to be signed is: accessKey + timestamp + parameter string
        string_to_sign = f"{self.api_key}{timestamp}{parameter_string}"
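        # Illustrative composition (not real credentials): with timestamp
        # "1690000000000" and GET params {'symbol': 'ETHUSDC'}, the string to
        # sign is "<api_key>1690000000000symbol=ETHUSDC".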
|
||||
logger.debug(f"MEXC string to sign (method {method}): {string_to_sign}")
|
||||
|
||||
# Generate HMAC SHA256 signature
|
||||
signature = hmac.new(
|
||||
self.api_secret.encode('utf-8'),
|
||||
string_to_sign.encode('utf-8'),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
logger.debug(f"MEXC generated signature: {signature}")
|
||||
return signature
|
||||
|
||||
def _send_public_request(self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""Send a public API request to MEXC."""
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
url = f"{self.base_url}/{self.api_version}/{endpoint}"
|
||||
|
||||
headers = {'Accept': 'application/json'}
|
||||
|
||||
try:
|
||||
response = requests.request(method, url, params=params, headers=headers, timeout=10)
|
||||
response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
|
||||
return response.json()
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
logger.error(f"HTTP error in public request to {endpoint}: {response.status_code} {response.reason}")
|
||||
logger.error(f"Response content: {response.text}")
|
||||
return {}
|
||||
except requests.exceptions.ConnectionError as conn_err:
|
||||
logger.error(f"Connection error in public request to {endpoint}: {conn_err}")
|
||||
return {}
|
||||
except requests.exceptions.Timeout as timeout_err:
|
||||
logger.error(f"Timeout error in public request to {endpoint}: {timeout_err}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in public request to {endpoint}: {e}")
|
||||
return {}
|
||||
|
||||
def _send_private_request(self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
|
||||
"""Send a private request to the exchange with proper signature"""
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
timestamp = str(int(time.time() * 1000))
|
||||
# Add timestamp and recvWindow to params for signature and request
|
||||
params['timestamp'] = timestamp
|
||||
params['recvWindow'] = self.recv_window
|
||||
signature = self._generate_signature(timestamp, method, endpoint, params)
|
||||
params['signature'] = signature
|
||||
|
||||
headers = {
|
||||
"X-MEXC-APIKEY": self.api_key,
|
||||
"Request-Time": timestamp
|
||||
}
|
||||
|
||||
# For spot API, use the correct endpoint format
|
||||
if not endpoint.startswith('api/v3/'):
|
||||
endpoint = f"api/v3/{endpoint}"
|
||||
url = f"{self.base_url}/{endpoint}"
|
||||
try:
|
||||
if method.upper() == "GET":
|
||||
response = self.session.get(url, headers=headers, params=params, timeout=10)
|
||||
elif method.upper() == "POST":
|
||||
# MEXC expects POST parameters as JSON in the request body, not as query string
|
||||
# The signature is generated from the JSON string of parameters.
|
||||
# We need to exclude 'signature' from the JSON body sent, as it's for the header.
|
||||
params_for_body = {k: v for k, v in params.items() if k != 'signature'}
|
||||
response = self.session.post(url, headers=headers, json=params_for_body, timeout=10)
|
||||
else:
|
||||
logger.error(f"Unsupported method: {method}")
|
||||
return None
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
# For successful responses, return the data directly
|
||||
# MEXC doesn't always use 'success' field for successful operations
|
||||
if response.status_code == 200:
|
||||
return data
|
||||
else:
|
||||
logger.error(f"API error: Status Code: {response.status_code}, Response: {response.text}")
|
||||
return None
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
logger.error(f"HTTP error for {endpoint}: Status Code: {response.status_code}, Response: {response.text}")
|
||||
logger.error(f"HTTP error details: {http_err}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Request error for {endpoint}: {e}")
|
||||
return None
|
||||
|
||||
def get_account_info(self) -> Dict[str, Any]:
|
||||
"""Get account information"""
|
||||
endpoint = "account"
|
||||
result = self._send_private_request("GET", endpoint, {})
|
||||
return result if result is not None else {}
|
||||
|
||||
def get_balance(self, asset: str) -> float:
|
||||
"""Get available balance for a specific asset."""
|
||||
account_info = self.get_account_info()
|
||||
if account_info and 'balances' in account_info:
|
||||
for balance in account_info['balances']:
|
||||
if balance.get('asset') == asset.upper():
|
||||
return float(balance.get('free', 0.0))
|
||||
logger.warning(f"Could not retrieve free balance for {asset}")
|
||||
return 0.0
|
||||
|
||||
def get_ticker(self, symbol: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get ticker information for a symbol."""
|
||||
formatted_symbol = self._format_spot_symbol(symbol)
|
||||
endpoint = "ticker/24hr"
|
||||
params = {'symbol': formatted_symbol}
|
||||
|
||||
response = self._send_public_request('GET', endpoint, params)
|
||||
|
||||
if isinstance(response, dict):
|
||||
ticker_data: Dict[str, Any] = response
|
||||
elif isinstance(response, list) and len(response) > 0:
|
||||
found_ticker = next((item for item in response if item.get('symbol') == formatted_symbol), None)
|
||||
if found_ticker:
|
||||
ticker_data = found_ticker
|
||||
else:
|
||||
logger.error(f"Ticker data for {formatted_symbol} not found in response list.")
|
||||
return None
|
||||
else:
|
||||
logger.error(f"Unexpected ticker response format: {response}")
|
||||
return None
|
||||
|
||||
# At this point, ticker_data is guaranteed to be a Dict[str, Any] due to the above logic
|
||||
# If it was None, we would have returned early.
|
||||
|
||||
# Extract relevant info and format for universal use
|
||||
last_price = float(ticker_data.get('lastPrice', 0))
|
||||
bid_price = float(ticker_data.get('bidPrice', 0))
|
||||
ask_price = float(ticker_data.get('askPrice', 0))
|
||||
volume = float(ticker_data.get('volume', 0)) # Base asset volume
|
||||
|
||||
# Determine price change and percent change
|
||||
price_change = float(ticker_data.get('priceChange', 0))
|
||||
price_change_percent = float(ticker_data.get('priceChangePercent', 0))
|
||||
|
||||
logger.info(f"MEXC: Got ticker from {endpoint} for {symbol}: ${last_price:.2f}")
|
||||
|
||||
return {
|
||||
'symbol': formatted_symbol,
|
||||
'last': last_price,
|
||||
'bid': bid_price,
|
||||
'ask': ask_price,
|
||||
'volume': volume,
|
||||
'high': float(ticker_data.get('highPrice', 0)),
|
||||
'low': float(ticker_data.get('lowPrice', 0)),
|
||||
'change': price_change_percent, # This is usually priceChangePercent
|
||||
'exchange': 'MEXC',
|
||||
'raw_data': ticker_data
|
||||
}
|
||||
|
||||
def get_api_symbols(self) -> List[str]:
|
||||
"""Get list of symbols supported for API trading"""
|
||||
try:
|
||||
endpoint = "selfSymbols"
|
||||
result = self._send_private_request("GET", endpoint, {})
|
||||
if result and 'data' in result:
|
||||
return result['data']
|
||||
elif isinstance(result, list):
|
||||
return result
|
||||
else:
|
||||
logger.warning(f"Unexpected response format for API symbols: {result}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API symbols: {e}")
|
||||
return []
|
||||
|
||||
def is_symbol_supported(self, symbol: str) -> bool:
|
||||
"""Check if a symbol is supported for API trading"""
|
||||
formatted_symbol = self._format_spot_symbol(symbol)
|
||||
supported_symbols = self.get_api_symbols()
|
||||
return formatted_symbol in supported_symbols
|
||||
|
||||
    def place_order(self, symbol: str, side: str, order_type: str, quantity: float, price: Optional[float] = None) -> Dict[str, Any]:
        """Place a new order on MEXC."""
        formatted_symbol = self._format_spot_symbol(symbol)

        # Check if symbol is supported for API trading
        if not self.is_symbol_supported(symbol):
            supported_symbols = self.get_api_symbols()
            logger.error(f"Symbol {formatted_symbol} is not supported for API trading")
            logger.info(f"Supported symbols include: {supported_symbols[:10]}...")  # Show first 10
            return {}

        # Format quantity according to symbol precision requirements
        formatted_quantity = self._format_quantity_for_symbol(formatted_symbol, quantity)
        if formatted_quantity is None:
            logger.error(f"MEXC: Failed to format quantity {quantity} for {formatted_symbol}")
            return {}

        # Handle order type restrictions for specific symbols
        final_order_type = self._adjust_order_type_for_symbol(formatted_symbol, order_type.upper())

        # Get a price for limit orders if none was supplied
        final_price = price
        if final_order_type == 'LIMIT' and price is None:
            # Fall back to the current market price
            ticker = self.get_ticker(symbol)
            if ticker and 'last' in ticker:
                final_price = ticker['last']
                logger.info(f"MEXC: Using market price ${final_price:.2f} for LIMIT order")
            else:
                logger.error(f"MEXC: Could not get market price for LIMIT order on {formatted_symbol}")
                return {}

        endpoint = "order"

        params: Dict[str, Any] = {
            'symbol': formatted_symbol,
            'side': side.upper(),
            'type': final_order_type,
            'quantity': str(formatted_quantity)  # Quantity must be sent as a string
        }
        if final_price is not None:
            params['price'] = str(final_price)  # Price must be a string for limit orders

        logger.info(f"MEXC: Placing {side.upper()} {final_order_type} order for {formatted_quantity} {formatted_symbol} at price {final_price}")

        try:
            # MEXC API endpoint for placing orders is /api/v3/order (POST)
            order_result = self._send_private_request('POST', endpoint, params)
            if order_result is not None:
                logger.info(f"MEXC: Order placed successfully: {order_result}")
                return order_result
            else:
                logger.error("MEXC: Error placing order: request returned None")
                return {}
        except Exception as e:
            logger.error(f"MEXC: Exception placing order: {e}")
            return {}

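# Minimal call-sequence sketch for place_order(). The constructor arguments
# and the response fields used here are assumptions based on the surrounding
# code, not a confirmed MEXC contract.
mexc = MEXCInterface(api_key="...", api_secret="...")  # constructor assumed from context
order = mexc.place_order('ETH/USDC', side='BUY', order_type='LIMIT',
                         quantity=0.01, price=2500.0)
if order:
    print(f"Order accepted, id={order.get('orderId')}")
else:
    print("Order rejected or failed; see logs for details")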
    def _format_quantity_for_symbol(self, formatted_symbol: str, quantity: float) -> Optional[float]:
        """Format quantity according to symbol precision requirements"""
        try:
            # Symbol-specific precision rules
            if formatted_symbol == 'ETHUSDC':
                # ETHUSDC requires max 5 decimal places, step size 0.000001
                formatted_qty = round(quantity, 5)
                # Snap to the minimum step size
                step_size = 0.000001
                formatted_qty = round(formatted_qty / step_size) * step_size
                # Round again to strip floating point error
                formatted_qty = round(formatted_qty, 6)
                logger.info(f"MEXC: Formatted ETHUSDC quantity {quantity} -> {formatted_qty}")
                return formatted_qty
            elif formatted_symbol == 'BTCUSDC':
                # Assume similar precision for BTC
                formatted_qty = round(quantity, 6)
                step_size = 0.000001
                formatted_qty = round(formatted_qty / step_size) * step_size
                formatted_qty = round(formatted_qty, 6)
                return formatted_qty
            else:
                # Default formatting - 6 decimal places
                return round(quantity, 6)
        except Exception as e:
            logger.error(f"Error formatting quantity for {formatted_symbol}: {e}")
            return None

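# The hard-coded per-symbol rules above are fragile. A more general approach,
# sketched here under the assumption that step sizes would come from the
# exchange's exchangeInfo filters, is to quantize with decimal.Decimal and
# avoid the float round-trip entirely:
from decimal import Decimal

def quantize_to_step(quantity: float, step_size: str) -> float:
    """Round a quantity down to an exact multiple of step_size.

    Rounding down (never up) keeps the order within the available balance
    after quantization.
    """
    step = Decimal(step_size)
    qty = Decimal(str(quantity))
    return float((qty // step) * step)

# Example: 0.0123456789 with a 0.000001 step -> 0.012345
print(quantize_to_step(0.0123456789, "0.000001"))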
    def _adjust_order_type_for_symbol(self, formatted_symbol: str, order_type: str) -> str:
        """Adjust order type based on symbol restrictions"""
        if formatted_symbol == 'ETHUSDC':
            # ETHUSDC only supports LIMIT and LIMIT_MAKER orders
            if order_type == 'MARKET':
                logger.info(f"MEXC: Converting MARKET order to LIMIT for {formatted_symbol} (MARKET not supported)")
                return 'LIMIT'
        return order_type

    def cancel_order(self, symbol: str, order_id: str) -> Dict[str, Any]:
        """Cancel an existing order on MEXC."""
        formatted_symbol = self._format_spot_symbol(symbol)
        endpoint = "order"
        params = {
            'symbol': formatted_symbol,
            'orderId': order_id
        }
        logger.info(f"MEXC: Cancelling order {order_id} for {formatted_symbol}")
        try:
            # MEXC API endpoint for cancelling orders is /api/v3/order (DELETE)
            cancel_result = self._send_private_request('DELETE', endpoint, params)
            if cancel_result:
                logger.info(f"MEXC: Order cancelled successfully: {cancel_result}")
                return cancel_result
            else:
                logger.error(f"MEXC: Error cancelling order: empty response ({cancel_result})")
                return {}
        except Exception as e:
            logger.error(f"MEXC: Exception cancelling order: {e}")
            return {}

    def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
        """Get the status of an order on MEXC."""
        formatted_symbol = self._format_spot_symbol(symbol)
        endpoint = "order"
        params = {
            'symbol': formatted_symbol,
            'orderId': order_id
        }
        logger.info(f"MEXC: Getting status for order {order_id} for {formatted_symbol}")
        try:
            # MEXC API endpoint for order status is /api/v3/order (GET)
            status_result = self._send_private_request('GET', endpoint, params)
            if status_result:
                logger.info(f"MEXC: Order status retrieved: {status_result}")
                return status_result
            else:
                logger.error(f"MEXC: Error getting order status: empty response ({status_result})")
                return {}
        except Exception as e:
            logger.error(f"MEXC: Exception getting order status: {e}")
            return {}

    def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get all open orders on MEXC for a symbol or all symbols."""
        endpoint = "openOrders"
        params = {}
        if symbol:
            params['symbol'] = self._format_spot_symbol(symbol)

        logger.info(f"MEXC: Getting open orders for {symbol if symbol else 'all symbols'}")
        try:
            # MEXC API endpoint for open orders is /api/v3/openOrders (GET)
            open_orders = self._send_private_request('GET', endpoint, params)
            if open_orders and isinstance(open_orders, list):
                logger.info(f"MEXC: Retrieved {len(open_orders)} open orders.")
                return open_orders
            else:
                logger.error(f"MEXC: Error getting open orders: {open_orders}")
                return []
        except Exception as e:
            logger.error(f"MEXC: Exception getting open orders: {e}")
            return []

    def get_my_trades(self, symbol: str, limit: int = 100) -> List[Dict[str, Any]]:
        """Get trade history for a specific symbol."""
        formatted_symbol = self._format_spot_symbol(symbol)
        endpoint = "myTrades"
        params = {'symbol': formatted_symbol, 'limit': limit}

        logger.info(f"MEXC: Getting trade history for {formatted_symbol} (limit: {limit})")
        try:
            # MEXC API endpoint for trade history is /api/v3/myTrades (GET)
            trade_history = self._send_private_request('GET', endpoint, params)
            if trade_history and isinstance(trade_history, list):
                logger.info(f"MEXC: Retrieved {len(trade_history)} trade records.")
                return trade_history
            else:
                logger.error(f"MEXC: Error getting trade history: {trade_history}")
                return []
        except Exception as e:
            logger.error(f"MEXC: Exception getting trade history: {e}")
            return []

    def get_server_time(self) -> int:
        """Get current MEXC server time in milliseconds."""
        endpoint = "time"
        response = self._send_public_request('GET', endpoint)
        if response and 'serverTime' in response:
            return int(response['serverTime'])
        logger.error("Failed to get MEXC server time.")
        return int(time.time() * 1000)  # Fallback to local time

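# Signed requests are timestamp-sensitive, so callers may want to check local
# clock drift against the server time. Sketch only; the `mexc` instance and
# the 1000 ms tolerance are illustrative assumptions.
import time

def clock_drift_ms(exchange) -> int:
    """Difference between exchange server time and the local clock, in ms."""
    return exchange.get_server_time() - int(time.time() * 1000)

drift = clock_drift_ms(mexc)  # 'mexc' assumed constructed earlier
if abs(drift) > 1000:
    print(f"Warning: local clock is {drift} ms off exchange time; signed requests may be rejected")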
    def get_all_balances(self) -> Dict[str, Dict[str, float]]:
        """Get all asset balances from MEXC account."""
        account_info = self.get_account_info()
        balances = {}
        if account_info and 'balances' in account_info:
            for balance in account_info['balances']:
                asset = balance.get('asset')
                free = float(balance.get('free', 0.0))
                locked = float(balance.get('locked', 0.0))
                if asset:
                    balances[asset.upper()] = {'free': free, 'locked': locked, 'total': free + locked}
        return balances

    def get_trading_fees(self) -> Dict[str, Any]:
        """Get current trading fee rates from MEXC API"""
        endpoint = "account/commission"
        response = self._send_private_request('GET', endpoint)
        if response and 'data' in response:
            fees_data = response['data']
            return {
                'maker': float(fees_data.get('makerCommission', 0.0)),
                'taker': float(fees_data.get('takerCommission', 0.0)),
                'default': float(fees_data.get('defaultCommission', 0.0))
            }
        logger.error("Failed to get trading fees from MEXC API.")
        return {}

    def get_symbol_trading_fees(self, symbol: str) -> Dict[str, Any]:
        """Get trading fee rates for a specific symbol from MEXC API"""
        formatted_symbol = self._format_spot_symbol(symbol)
        endpoint = "account/commission"
        params = {'symbol': formatted_symbol}
        response = self._send_private_request('GET', endpoint, params)
        if response and 'data' in response:
            fees_data = response['data']
            return {
                'maker': float(fees_data.get('makerCommission', 0.0)),
                'taker': float(fees_data.get('takerCommission', 0.0)),
                'default': float(fees_data.get('defaultCommission', 0.0))
            }
        logger.error(f"Failed to get trading fees for {symbol} from MEXC API.")
        return {}
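# Applying the fee rates: assuming the returned maker/taker values are
# fractional rates (e.g. 0.0005 for 0.05%), a fill-cost estimate is a
# one-liner. All numbers below are illustrative, not real MEXC rates.
fees = {'maker': 0.0, 'taker': 0.0005}      # shape matches get_trading_fees()
notional = 0.01 * 2500.0                    # quantity * fill price = 25.0 USDC
taker_cost = notional * fees['taker']       # 25.0 * 0.0005 = 0.0125 USDC
print(f"Estimated taker fee: {taker_cost:.4f} USDC")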
@@ -1,254 +0,0 @@
"""
Trading Agent Test Script

This script demonstrates how to use the swappable exchange modules
to connect to and interact with different cryptocurrency exchanges.

Usage:
    python -m NN.exchanges.trading_agent_test --exchange binance --test-mode
"""

import argparse
import logging
import os
import sys
import time
from typing import Optional

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("exchange_test.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("exchange_test")

# Import exchange interfaces
try:
    from .exchange_interface import ExchangeInterface
    from .binance_interface import BinanceInterface
    from .mexc_interface import MEXCInterface
except ImportError:
    # When running as a standalone script
    from exchange_interface import ExchangeInterface
    from binance_interface import BinanceInterface
    from mexc_interface import MEXCInterface

def create_exchange(exchange_name: str, api_key: Optional[str] = None, api_secret: Optional[str] = None, test_mode: bool = True) -> ExchangeInterface:
    """Create an exchange interface instance.

    Args:
        exchange_name: Name of the exchange ('binance' or 'mexc')
        api_key: API key for the exchange
        api_secret: API secret for the exchange
        test_mode: If True, use test/sandbox environment

    Returns:
        ExchangeInterface: The exchange interface instance
    """
    exchange_name = exchange_name.lower()

    if exchange_name == 'binance':
        return BinanceInterface(api_key, api_secret, test_mode)
    elif exchange_name == 'mexc':
        return MEXCInterface(api_key, api_secret, test_mode)
    else:
        raise ValueError(f"Unsupported exchange: {exchange_name}. Supported exchanges: binance, mexc")

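# Quick sanity check of the factory. The environment variable names follow
# the convention used in main() below; with them unset this runs read-only.
import os

exchange = create_exchange(
    'mexc',
    api_key=os.environ.get('MEXC_API_KEY'),
    api_secret=os.environ.get('MEXC_API_SECRET'),
    test_mode=True,
)
print(type(exchange).__name__)  # -> MEXCInterface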
def test_exchange(exchange: ExchangeInterface, symbols: Optional[list] = None):
    """Test the exchange interface.

    Args:
        exchange: Exchange interface instance
        symbols: List of symbols to test with (e.g., ['BTC/USDT', 'ETH/USDT'])
    """
    if symbols is None:
        symbols = ['BTC/USDT', 'ETH/USDT']

    # Test connection
    logger.info("Testing connection to exchange...")
    connected = exchange.connect()
    if not connected and hasattr(exchange, 'api_key') and exchange.api_key:
        logger.error("Failed to connect to exchange. Make sure your API credentials are correct.")
        return False
    elif not connected:
        logger.warning("Running in read-only mode without API credentials.")
    else:
        logger.info("Connection successful with API credentials!")

    # Test getting ticker data
    ticker_success = True
    for symbol in symbols:
        try:
            logger.info(f"Getting ticker data for {symbol}...")
            ticker = exchange.get_ticker(symbol)
            logger.info(f"Ticker for {symbol}: Last price: {ticker['last']}, Volume: {ticker['volume']}")
        except Exception as e:
            logger.error(f"Error getting ticker for {symbol}: {str(e)}")
            ticker_success = False

    if not ticker_success:
        logger.error("Failed to get ticker data. Exchange interface test failed.")
        return False

    # Test getting account balances if API keys are provided
    if hasattr(exchange, 'api_key') and exchange.api_key:
        logger.info("Testing account balance retrieval...")
        try:
            for base_asset in ['BTC', 'ETH', 'USDT']:
                balance = exchange.get_balance(base_asset)
                logger.info(f"Balance for {base_asset}: {balance}")
        except Exception as e:
            logger.error(f"Error getting account balances: {str(e)}")
            logger.warning("Balance retrieval failed, but this is not critical if ticker data works.")
    else:
        logger.warning("API keys not provided. Skipping balance checks.")

    logger.info("Exchange interface test completed successfully.")
    return True

def execute_test_trades(exchange: ExchangeInterface, symbol: str, test_trade_amount: float = 0.001):
    """Execute test trades.

    Args:
        exchange: Exchange interface instance
        symbol: Symbol to trade (e.g., 'BTC/USDT')
        test_trade_amount: Amount to use for test trades
    """
    if not hasattr(exchange, 'api_key') or not exchange.api_key:
        logger.warning("API keys not provided. Skipping test trades.")
        return

    logger.info(f"Executing test trades for {symbol} with amount {test_trade_amount}...")

    # Get current ticker for the symbol
    try:
        ticker = exchange.get_ticker(symbol)
        logger.info(f"Current price for {symbol}: {ticker['last']}")
    except Exception as e:
        logger.error(f"Error getting ticker for {symbol}: {str(e)}")
        return

    # Execute a buy order
    try:
        logger.info(f"Placing a test BUY order for {test_trade_amount} {symbol}...")
        buy_order = exchange.execute_trade(symbol, 'BUY', quantity=test_trade_amount)
        if buy_order:
            logger.info(f"BUY order executed: {buy_order}")
            order_id = buy_order.get('orderId')

            # Get order status
            if order_id:
                time.sleep(2)  # Wait for order to process
                status = exchange.get_order_status(symbol, order_id)
                logger.info(f"Order status: {status}")
        else:
            logger.error("BUY order failed.")
    except Exception as e:
        logger.error(f"Error executing BUY order: {str(e)}")

    # Wait before selling
    time.sleep(5)

    # Execute a sell order
    try:
        logger.info(f"Placing a test SELL order for {test_trade_amount} {symbol}...")
        sell_order = exchange.execute_trade(symbol, 'SELL', quantity=test_trade_amount)
        if sell_order:
            logger.info(f"SELL order executed: {sell_order}")
            order_id = sell_order.get('orderId')

            # Get order status
            if order_id:
                time.sleep(2)  # Wait for order to process
                status = exchange.get_order_status(symbol, order_id)
                logger.info(f"Order status: {status}")
        else:
            logger.error("SELL order failed.")
    except Exception as e:
        logger.error(f"Error executing SELL order: {str(e)}")

    # Get open orders
    try:
        logger.info("Getting open orders...")
        open_orders = exchange.get_open_orders(symbol)
        if open_orders:
            logger.info(f"Open orders: {open_orders}")

            # Cancel any open orders
            for order in open_orders:
                order_id = order.get('orderId')
                if order_id:
                    logger.info(f"Cancelling order {order_id}...")
                    cancelled = exchange.cancel_order(symbol, order_id)
                    logger.info(f"Order cancelled: {cancelled}")
        else:
            logger.info("No open orders.")
    except Exception as e:
        logger.error(f"Error getting/cancelling open orders: {str(e)}")

def main():
    """Main function for testing exchange interfaces."""
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Test exchange interfaces")
    parser.add_argument('--exchange', type=str, default='binance', choices=['binance', 'mexc'],
                        help='Exchange to test')
    parser.add_argument('--api-key', type=str, default=None,
                        help='API key for the exchange')
    parser.add_argument('--api-secret', type=str, default=None,
                        help='API secret for the exchange')
    parser.add_argument('--test-mode', action='store_true',
                        help='Use test/sandbox environment')
    parser.add_argument('--symbols', nargs='+', default=['BTC/USDT', 'ETH/USDT'],
                        help='Symbols to test with')
    parser.add_argument('--execute-trades', action='store_true',
                        help='Execute test trades (use with caution!)')
    parser.add_argument('--test-trade-amount', type=float, default=0.001,
                        help='Amount to use for test trades')

    args = parser.parse_args()

    # Use environment variables for API keys if not provided
    api_key = args.api_key or os.environ.get(f"{args.exchange.upper()}_API_KEY")
    api_secret = args.api_secret or os.environ.get(f"{args.exchange.upper()}_API_SECRET")

    # Create exchange interface
    try:
        exchange = create_exchange(
            exchange_name=args.exchange,
            api_key=api_key,
            api_secret=api_secret,
            test_mode=args.test_mode
        )

        logger.info(f"Created {args.exchange} exchange interface")
        logger.info(f"Test mode: {args.test_mode}")

        # Test exchange
        if test_exchange(exchange, args.symbols):
            logger.info("Exchange interface test passed!")

            # Execute test trades if requested
            if args.execute_trades:
                logger.warning("Executing test trades. This will use real funds!")
                execute_test_trades(
                    exchange=exchange,
                    symbol=args.symbols[0],
                    test_trade_amount=args.test_trade_amount
                )
        else:
            logger.error("Exchange interface test failed!")

    except Exception as e:
        logger.error(f"Error testing exchange interface: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return 1

    return 0

if __name__ == "__main__":
    sys.exit(main())
@@ -1,21 +0,0 @@
"""
Neural Network Models
====================

This package contains the neural network models used in the trading system:
- CNN Model: Deep convolutional neural network for feature extraction
- DQN Agent: Deep Q-Network for reinforcement learning
- COB RL Model: Specialized RL model for order book data
- Advanced Transformer: High-performance transformer for trading

PyTorch implementation only.
"""

from NN.models.cnn_model import EnhancedCNNModel as CNNModel
from NN.models.dqn_agent import DQNAgent
from NN.models.cob_rl_model import MassiveRLNetwork, COBRLModelInterface
from NN.models.advanced_transformer_trading import AdvancedTradingTransformer, TradingTransformerConfig
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface

__all__ = ['CNNModel', 'DQNAgent', 'MassiveRLNetwork', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig',
           'ModelInterface', 'CNNModelInterface', 'RLAgentInterface', 'ExtremaTrainerInterface']
@@ -267,7 +267,17 @@ class COBRLModelInterface(ModelInterface):
 
         logger.info(f"COB RL Model Interface initialized on {self.device}")
 
+    def to(self, device):
+        """PyTorch-style device movement method"""
+        self.device = device
+        self.model = self.model.to(device)
+        return self
+
-    def predict(self, cob_features) -> Dict[str, Any]:
+    def predict(self, cob_features: np.ndarray) -> Dict[str, Any]:
         """Make prediction using the model"""
         self.model.eval()
         with torch.no_grad():
File diff suppressed because it is too large
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.optim as optim
 import numpy as np
 import os
+import time
 import logging
 import torch.nn.functional as F
 from typing import List, Tuple, Dict, Any, Optional, Union
@@ -80,6 +81,9 @@ class EnhancedCNN(nn.Module):
         self.n_actions = n_actions
         self.confidence_threshold = confidence_threshold
 
+        # Training data storage
+        self.training_data = []
+
         # Calculate input dimensions
         if isinstance(input_shape, (list, tuple)):
             if len(input_shape) == 3:  # [channels, height, width]
@@ -265,8 +269,9 @@ class EnhancedCNN(nn.Module):
             nn.Linear(256, 3)  # 0=bottom, 1=top, 2=neither
         )
 
-        # ULTRA MASSIVE multi-timeframe price prediction heads
-        self.price_pred_immediate = nn.Sequential(
+        # ULTRA MASSIVE price direction prediction head
+        # Outputs single direction and confidence values
+        self.price_direction_head = nn.Sequential(
             nn.Linear(1024, 1024),  # Increased from 512
             nn.ReLU(),
             nn.Dropout(0.3),
@@ -275,32 +280,13 @@ class EnhancedCNN(nn.Module):
             nn.Dropout(0.3),
             nn.Linear(512, 256),  # Increased from 128
             nn.ReLU(),
-            nn.Linear(256, 3)  # Up, Down, Sideways
+            nn.Linear(256, 2)  # [direction, confidence]
         )
 
-        self.price_pred_midterm = nn.Sequential(
-            nn.Linear(1024, 1024),  # Increased from 512
-            nn.ReLU(),
-            nn.Dropout(0.3),
-            nn.Linear(1024, 512),  # Increased from 256
-            nn.ReLU(),
-            nn.Dropout(0.3),
-            nn.Linear(512, 256),  # Increased from 128
-            nn.ReLU(),
-            nn.Linear(256, 3)  # Up, Down, Sideways
-        )
-
-        self.price_pred_longterm = nn.Sequential(
-            nn.Linear(1024, 1024),  # Increased from 512
-            nn.ReLU(),
-            nn.Dropout(0.3),
-            nn.Linear(1024, 512),  # Increased from 256
-            nn.ReLU(),
-            nn.Dropout(0.3),
-            nn.Linear(512, 256),  # Increased from 128
-            nn.ReLU(),
-            nn.Linear(256, 3)  # Up, Down, Sideways
-        )
+        # Direction activation (tanh for -1 to 1)
+        self.direction_activation = nn.Tanh()
+        # Confidence activation (sigmoid for 0 to 1)
+        self.confidence_activation = nn.Sigmoid()
 
         # ULTRA MASSIVE value prediction with ensemble approaches
         self.price_pred_value = nn.Sequential(
@@ -371,21 +357,45 @@ class EnhancedCNN(nn.Module):
             nn.Linear(128, 4)  # Low risk, medium risk, high risk, extreme risk
         )
 
+    def _memory_barrier(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Create a memory barrier to prevent in-place operation issues"""
+        return tensor.detach().clone().requires_grad_(tensor.requires_grad)
+
     def _check_rebuild_network(self, features):
-        """Check if network needs to be rebuilt for different feature dimensions"""
+        """DEPRECATED: Network should have fixed architecture - no runtime rebuilding"""
         if features != self.feature_dim:
-            logger.info(f"Rebuilding network for new feature dimension: {features} (was {self.feature_dim})")
-            self.feature_dim = features
-            self._build_network()
-            # Move to device after rebuilding
-            self.to(self.device)
-            return True
+            logger.error(f"CRITICAL: Input feature dimension mismatch! Expected {self.feature_dim}, got {features}")
+            logger.error("This indicates a bug in data preprocessing - input should be fixed size!")
+            logger.error("Network architecture should NOT change at runtime!")
+            raise ValueError(f"Input dimension mismatch: expected {self.feature_dim}, got {features}")
         return False
 
     def forward(self, x):
         """Forward pass through the ULTRA MASSIVE network"""
         batch_size = x.size(0)
 
+        # Validate input dimensions to prevent zero-element tensor issues
+        if x.numel() == 0:
+            logger.error(f"Forward pass received empty tensor with shape {x.shape}")
+            # Return default outputs for all 5 expected values to prevent crash
+            default_q_values = torch.zeros(batch_size, self.n_actions, device=x.device)
+            default_extrema = torch.zeros(batch_size, 3, device=x.device)  # bottom/top/neither
+            default_price_pred = torch.zeros(batch_size, 1, device=x.device)
+            default_features = torch.zeros(batch_size, 1024, device=x.device)
+            default_advanced = torch.zeros(batch_size, 1, device=x.device)
+            return default_q_values, default_extrema, default_price_pred, default_features, default_advanced
+
+        # Check for zero feature dimensions
+        if len(x.shape) > 1 and any(dim == 0 for dim in x.shape[1:]):
+            logger.error(f"Forward pass received tensor with zero feature dimensions: {x.shape}")
+            # Return default outputs for all 5 expected values to prevent crash
+            default_q_values = torch.zeros(batch_size, self.n_actions, device=x.device)
+            default_extrema = torch.zeros(batch_size, 3, device=x.device)  # bottom/top/neither
+            default_price_pred = torch.zeros(batch_size, 1, device=x.device)
+            default_features = torch.zeros(batch_size, 1024, device=x.device)
+            default_advanced = torch.zeros(batch_size, 1, device=x.device)
+            return default_q_values, default_extrema, default_price_pred, default_features, default_advanced
+
         # Process different input shapes
         if len(x.shape) > 2:
             # Handle 4D input [batch, timeframes, window, features] or 3D input [batch, timeframes, features]
@@ -397,10 +407,11 @@ class EnhancedCNN(nn.Module):
             # Now x is 3D: [batch, timeframes, features]
             x_reshaped = x
 
-            # Check if the feature dimension has changed and rebuild if necessary
-            if x_reshaped.size(1) * x_reshaped.size(2) != self.feature_dim:
-                total_features = x_reshaped.size(1) * x_reshaped.size(2)
-                self._check_rebuild_network(total_features)
+            # Validate input dimensions (should be fixed)
+            total_features = x_reshaped.size(1) * x_reshaped.size(2)
+            if total_features != self.feature_dim:
+                logger.error(f"Input dimension mismatch: expected {self.feature_dim}, got {total_features}")
+                raise ValueError(f"Input dimension mismatch: expected {self.feature_dim}, got {total_features}")
 
             # Apply ultra massive convolutions
             x_conv = self.conv_layers(x_reshaped)
@@ -413,9 +424,10 @@ class EnhancedCNN(nn.Module):
             # For 2D input [batch, features]
             x_flat = x
 
-            # Check if dimensions have changed
+            # Validate input dimensions (should be fixed)
             if x_flat.size(1) != self.feature_dim:
-                self._check_rebuild_network(x_flat.size(1))
+                logger.error(f"Input dimension mismatch: expected {self.feature_dim}, got {x_flat.size(1)}")
+                raise ValueError(f"Input dimension mismatch: expected {self.feature_dim}, got {x_flat.size(1)}")
 
         # Apply ULTRA MASSIVE FC layers to get base features
         features = self.fc_layers(x_flat)  # [batch, 1024]
@@ -464,10 +476,14 @@ class EnhancedCNN(nn.Module):
         # Extrema predictions (bottom/top/neither detection)
         extrema_pred = self.extrema_head(features_refined)
 
-        # Multi-timeframe price movement predictions
-        price_immediate = self.price_pred_immediate(features_refined)
-        price_midterm = self.price_pred_midterm(features_refined)
-        price_longterm = self.price_pred_longterm(features_refined)
+        # Price direction predictions
+        price_direction_raw = self.price_direction_head(features_refined)
+
+        # Apply separate activations to direction and confidence
+        direction = self.direction_activation(price_direction_raw[:, 0:1])   # -1 to 1
+        confidence = self.confidence_activation(price_direction_raw[:, 1:2]) # 0 to 1
+        price_direction_pred = torch.cat([direction, confidence], dim=1)     # [batch, 2]
 
         price_values = self.price_pred_value(features_refined)
 
         # Additional specialized predictions for enhanced accuracy
@@ -476,38 +492,42 @@ class EnhancedCNN(nn.Module):
         market_regime_pred = self.market_regime_head(features_refined)
         risk_pred = self.risk_head(features_refined)
 
-        # Package all price predictions
-        price_predictions = {
-            'immediate': price_immediate,
-            'midterm': price_midterm,
-            'longterm': price_longterm,
-            'values': price_values
-        }
+        # Use the price direction prediction directly (already [batch, 2])
+        price_direction_tensor = price_direction_pred
 
-        # Package additional predictions for enhanced decision making
-        advanced_predictions = {
-            'volatility': volatility_pred,
-            'support_resistance': support_resistance_pred,
-            'market_regime': market_regime_pred,
-            'risk_assessment': risk_pred
-        }
+        # Package additional predictions into a single tensor (use volatility as primary)
+        # For compatibility with the DQN agent, volatility_pred is returned as the advanced prediction tensor
+        advanced_pred_tensor = volatility_pred
 
-        return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions
+        return q_values, extrema_pred, price_direction_tensor, features_refined, advanced_pred_tensor
 
-    def act(self, state, explore=True):
+    def act(self, state, explore=True) -> Tuple[int, float, List[float]]:
         """Enhanced action selection with ultra massive model predictions"""
         if explore and np.random.random() < 0.1:  # 10% random exploration
-            return np.random.choice(self.n_actions)
+            rand_action = int(np.random.randint(self.n_actions))
+            uniform_probs = [1.0 / self.n_actions] * self.n_actions
+            return rand_action, uniform_probs[rand_action], uniform_probs
 
         self.eval()
-        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
+
+        # Accept both NumPy arrays and already-built torch tensors
+        if isinstance(state, torch.Tensor):
+            state_tensor = state.detach().to(self.device)
+            if state_tensor.dim() == 1:
+                state_tensor = state_tensor.unsqueeze(0)
+        else:
+            # Convert to tensor **directly on the target device** to avoid intermediate CPU copies
+            state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
+            if state_tensor.dim() == 1:
+                state_tensor = state_tensor.unsqueeze(0)
 
         with torch.no_grad():
-            q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor)
+            q_values, extrema_pred, price_direction_predictions, features, advanced_predictions = self(state_tensor)
+
+            # Process price direction predictions
+            if price_direction_predictions is not None:
+                self.process_price_direction_predictions(price_direction_predictions)
 
         # Apply softmax to get action probabilities
-        action_probs = torch.softmax(q_values, dim=1)
-        action = torch.argmax(action_probs, dim=1).item()
+        action_probs_tensor = torch.softmax(q_values, dim=1)
+        action_idx = int(torch.argmax(action_probs_tensor, dim=1).item())
+        confidence = float(action_probs_tensor[0, action_idx].item())  # Confidence of the chosen action
+        action_probs = action_probs_tensor.squeeze(0).tolist()  # Convert to list of floats for return
 
         # Log advanced predictions for better decision making
         if hasattr(self, '_log_predictions') and self._log_predictions:
@@ -537,7 +557,180 @@ class EnhancedCNN(nn.Module):
             logger.info(f"  Market Regime: {regime_labels[regime_class]} ({regime[regime_class]:.3f})")
             logger.info(f"  Risk Level: {risk_labels[risk_class]} ({risk[risk_class]:.3f})")
 
-        return action
+        return action_idx, confidence, action_probs
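# Caller-side sketch of the new act() contract: three values instead of a
# bare action index. The gating rule is an assumption for illustration only.
action_idx, confidence, action_probs = model.act(state, explore=False)
print(f"action={action_idx} confidence={confidence:.3f} probs={action_probs}")
if confidence < model.confidence_threshold:
    pass  # e.g. skip trading this step when the model is unsure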
+    def process_price_direction_predictions(self, price_direction_pred: torch.Tensor) -> Dict[str, float]:
+        """
+        Process price direction predictions and convert to standardized format
+
+        Args:
+            price_direction_pred: Tensor of shape (batch_size, 2) containing [direction, confidence]
+
+        Returns:
+            Dict with direction (-1 to 1) and confidence (0 to 1)
+        """
+        try:
+            if price_direction_pred is None or price_direction_pred.numel() == 0:
+                return {}
+
+            # Extract direction and confidence values
+            direction_value = float(price_direction_pred[0, 0].item())   # -1 to 1
+            confidence_value = float(price_direction_pred[0, 1].item())  # 0 to 1
+
+            processed_directions = {
+                'direction': direction_value,
+                'confidence': confidence_value
+            }
+
+            # Store for later access
+            self.last_price_direction = processed_directions
+
+            return processed_directions
+
+        except Exception as e:
+            logger.error(f"Error processing price direction predictions: {e}")
+            return {}
+
+    def get_price_direction_vector(self) -> Dict[str, float]:
+        """
+        Get the current price direction and confidence
+
+        Returns:
+            Dict with direction (-1 to 1) and confidence (0 to 1)
+        """
+        return getattr(self, 'last_price_direction', {})
+
+    def get_price_direction_summary(self) -> Dict[str, Any]:
+        """
+        Get a summary of the price direction prediction
+
+        Returns:
+            Dict containing direction and confidence information
+        """
+        try:
+            last_direction = getattr(self, 'last_price_direction', {})
+            if not last_direction:
+                return {
+                    'direction_value': 0.0,
+                    'confidence_value': 0.0,
+                    'direction_label': "SIDEWAYS",
+                    'discrete_direction': 0,
+                    'strength': 0.0,
+                    'weighted_strength': 0.0
+                }
+
+            direction_value = last_direction['direction']
+            confidence_value = last_direction['confidence']
+
+            # Convert to discrete direction
+            if direction_value > 0.1:
+                direction_label = "UP"
+                discrete_direction = 1
+            elif direction_value < -0.1:
+                direction_label = "DOWN"
+                discrete_direction = -1
+            else:
+                direction_label = "SIDEWAYS"
+                discrete_direction = 0
+
+            return {
+                'direction_value': float(direction_value),
+                'confidence_value': float(confidence_value),
+                'direction_label': direction_label,
+                'discrete_direction': discrete_direction,
+                'strength': abs(float(direction_value)),
+                'weighted_strength': abs(float(direction_value)) * float(confidence_value)
+            }
+
+        except Exception as e:
+            logger.error(f"Error calculating price direction summary: {e}")
+            return {
+                'direction_value': 0.0,
+                'confidence_value': 0.0,
+                'direction_label': "SIDEWAYS",
+                'discrete_direction': 0,
+                'strength': 0.0,
+                'weighted_strength': 0.0
+            }
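# weighted_strength combines magnitude and confidence, which makes it a
# natural input for position sizing. Hedged sketch; the scaling constant and
# the 0.05 cut-off are illustrative assumptions, not values from this repo.
summary = model.get_price_direction_summary()
base_size = 1.0                                  # max size, in units of max_position
size = base_size * summary['weighted_strength']  # 0 when flat or unsure
side = {1: 'BUY', -1: 'SELL', 0: None}[summary['discrete_direction']]
if side and size > 0.05:
    print(f"{side} with size {size:.2f}")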
+    def add_training_data(self, state, action, reward, position_pnl=0.0, has_position=False):
+        """
+        Add training data to the model's training buffer with position-based reward enhancement
+
+        Args:
+            state: Input state
+            action: Action taken
+            reward: Base reward received
+            position_pnl: Current position P&L (0.0 if no position)
+            has_position: Whether we currently have an open position
+        """
+        try:
+            # Enhance reward based on position status
+            enhanced_reward = self._calculate_position_enhanced_reward(
+                reward, action, position_pnl, has_position
+            )
+
+            self.training_data.append({
+                'state': state,
+                'action': action,
+                'reward': enhanced_reward,
+                'base_reward': reward,  # Keep original reward for analysis
+                'position_pnl': position_pnl,
+                'has_position': has_position,
+                'timestamp': time.time()
+            })
+
+            # Keep only the last 1000 training samples
+            if len(self.training_data) > 1000:
+                self.training_data = self.training_data[-1000:]
+
+        except Exception as e:
+            logger.error(f"Error adding training data: {e}")
+
+    def _calculate_position_enhanced_reward(self, base_reward, action, position_pnl, has_position):
+        """
+        Calculate a position-enhanced reward that incentivizes profitable trades and closing losing ones
+
+        Args:
+            base_reward: Original reward from price prediction accuracy
+            action: Action taken ('BUY', 'SELL', 'HOLD')
+            position_pnl: Current position P&L
+            has_position: Whether we have an open position
+
+        Returns:
+            Enhanced reward that incentivizes profitable behavior
+        """
+        try:
+            enhanced_reward = base_reward
+
+            if has_position and position_pnl != 0.0:
+                # Position-based reward adjustments
+                pnl_factor = position_pnl / 100.0  # Normalize P&L to a reasonable scale
+
+                if position_pnl > 0:  # Profitable position
+                    if action == "HOLD":
+                        # Reward holding profitable positions (let winners run)
+                        enhanced_reward += abs(pnl_factor) * 0.5
+                    elif action in ["BUY", "SELL"]:
+                        # Moderate reward for taking action on profitable positions
+                        enhanced_reward += abs(pnl_factor) * 0.3
+
+                elif position_pnl < 0:  # Losing position
+                    if action == "HOLD":
+                        # Penalty for holding losing positions (cut losses)
+                        enhanced_reward -= abs(pnl_factor) * 0.8
+                    elif action in ["BUY", "SELL"]:
+                        # Reward for taking action to close losing positions
+                        enhanced_reward += abs(pnl_factor) * 0.6
+
+            # Clamp the reward so it cannot become extreme
+            enhanced_reward = max(-5.0, min(5.0, enhanced_reward))
+
+            return enhanced_reward
+
+        except Exception as e:
+            logger.error(f"Error calculating position-enhanced reward: {e}")
+            return base_reward
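# Worked example of the shaping rule above (inputs are arbitrary, not
# recorded values): a losing position is punished for HOLD and rewarded
# for closing.
base_reward, position_pnl = 0.1, -50.0              # P&L normalized by /100 below
pnl_factor = position_pnl / 100.0                   # -0.5
hold_reward = base_reward - abs(pnl_factor) * 0.8   # 0.1 - 0.4 = -0.3
close_reward = base_reward + abs(pnl_factor) * 0.6  # 0.1 + 0.3 =  0.4
print(round(hold_reward, 3), round(close_reward, 3))  # -0.3 0.4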
 
     def save(self, path):
         """Save model weights and architecture"""

@@ -1 +0,0 @@
{"best_reward": 4791516.572471984, "best_episode": 3250, "best_pnl": 826842167451289.1, "best_win_rate": 0.47368421052631576, "date": "2025-04-01 10:19:16"}
@@ -1,20 +0,0 @@
{
    "supervised": {
        "epochs_completed": 22650,
        "best_val_pnl": 0.0,
        "best_epoch": 50,
        "best_win_rate": 0
    },
    "reinforcement": {
        "episodes_completed": 0,
        "best_reward": -Infinity,
        "best_episode": 0,
        "best_win_rate": 0
    },
    "hybrid": {
        "iterations_completed": 453,
        "best_combined_score": 0.0,
        "training_started": "2025-04-09T10:30:42.510856",
        "last_update": "2025-04-09T10:40:02.217840"
    }
}
@@ -1,326 +0,0 @@
{
    "epochs_completed": 8,
    "best_val_pnl": 0.0,
    "best_epoch": 1,
    "best_win_rate": 0.0,
    "training_started": "2025-04-02T10:43:58.946682",
    "last_update": "2025-04-02T10:44:10.940892",
    "epochs": [
        {"epoch": 1, "train_loss": 1.0950355529785156, "val_loss": 1.1657923062642415, "train_acc": 0.3255208333333333, "val_acc": 0.0, "train_pnl": 0.0, "val_pnl": 0.0, "train_win_rate": 0.0, "val_win_rate": 0.0, "best_position_size": 0.1, "signal_distribution": {"train": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}, "val": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}}, "timestamp": "2025-04-02T10:44:01.840889", "data_age": 2, "cumulative_pnl": {"train": 0.0, "val": 0.0}, "total_trades": {"train": 0, "val": 0}, "overall_win_rate": {"train": 0.0, "val": 0.0}},
        {"epoch": 2, "train_loss": 1.0831659038861592, "val_loss": 1.1212460199991863, "train_acc": 0.390625, "val_acc": 0.0, "train_pnl": 0.0, "val_pnl": 0.0, "train_win_rate": 0.0, "val_win_rate": 0.0, "best_position_size": 0.1, "signal_distribution": {"train": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}, "val": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}}, "timestamp": "2025-04-02T10:44:03.134833", "data_age": 4, "cumulative_pnl": {"train": 0.0, "val": 0.0}, "total_trades": {"train": 0, "val": 0}, "overall_win_rate": {"train": 0.0, "val": 0.0}},
        {"epoch": 3, "train_loss": 1.0740693012873332, "val_loss": 1.0992945830027263, "train_acc": 0.4739583333333333, "val_acc": 0.0, "train_pnl": 0.0, "val_pnl": 0.0, "train_win_rate": 0.0, "val_win_rate": 0.0, "best_position_size": 0.1, "signal_distribution": {"train": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}, "val": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}}, "timestamp": "2025-04-02T10:44:04.425272", "data_age": 5, "cumulative_pnl": {"train": 0.0, "val": 0.0}, "total_trades": {"train": 0, "val": 0}, "overall_win_rate": {"train": 0.0, "val": 0.0}},
        {"epoch": 4, "train_loss": 1.0747728943824768, "val_loss": 1.0821794271469116, "train_acc": 0.4609375, "val_acc": 0.3229166666666667, "train_pnl": 0.0, "val_pnl": 0.0, "train_win_rate": 0.0, "val_win_rate": 0.0, "best_position_size": 0.1, "signal_distribution": {"train": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}, "val": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}}, "timestamp": "2025-04-02T10:44:05.716421", "data_age": 6, "cumulative_pnl": {"train": 0.0, "val": 0.0}, "total_trades": {"train": 0, "val": 0}, "overall_win_rate": {"train": 0.0, "val": 0.0}},
        {"epoch": 5, "train_loss": 1.0489931503931682, "val_loss": 1.0669521888097127, "train_acc": 0.5833333333333334, "val_acc": 1.0, "train_pnl": 0.0, "val_pnl": 0.0, "train_win_rate": 0.0, "val_win_rate": 0.0, "best_position_size": 0.1, "signal_distribution": {"train": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}, "val": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}}, "timestamp": "2025-04-02T10:44:07.007935", "data_age": 8, "cumulative_pnl": {"train": 0.0, "val": 0.0}, "total_trades": {"train": 0, "val": 0}, "overall_win_rate": {"train": 0.0, "val": 0.0}},
        {"epoch": 6, "train_loss": 1.0533669590950012, "val_loss": 1.0505590836207073, "train_acc": 0.5104166666666666, "val_acc": 1.0, "train_pnl": 0.0, "val_pnl": 0.0, "train_win_rate": 0.0, "val_win_rate": 0.0, "best_position_size": 0.1, "signal_distribution": {"train": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}, "val": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}}, "timestamp": "2025-04-02T10:44:08.296061", "data_age": 9, "cumulative_pnl": {"train": 0.0, "val": 0.0}, "total_trades": {"train": 0, "val": 0}, "overall_win_rate": {"train": 0.0, "val": 0.0}},
        {"epoch": 7, "train_loss": 1.0456886688868205, "val_loss": 1.0351698795954387, "train_acc": 0.5651041666666666, "val_acc": 1.0, "train_pnl": 0.0, "val_pnl": 0.0, "train_win_rate": 0.0, "val_win_rate": 0.0, "best_position_size": 0.1, "signal_distribution": {"train": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}, "val": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}}, "timestamp": "2025-04-02T10:44:09.607584", "data_age": 10, "cumulative_pnl": {"train": 0.0, "val": 0.0}, "total_trades": {"train": 0, "val": 0}, "overall_win_rate": {"train": 0.0, "val": 0.0}},
        {"epoch": 8, "train_loss": 1.040040671825409, "val_loss": 1.0227736632029216, "train_acc": 0.6119791666666666, "val_acc": 1.0, "train_pnl": 0.0, "val_pnl": 0.0, "train_win_rate": 0.0, "val_win_rate": 0.0, "best_position_size": 0.1, "signal_distribution": {"train": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}, "val": {"BUY": 1.0, "SELL": 0.0, "HOLD": 0.0}}, "timestamp": "2025-04-02T10:44:10.940892", "data_age": 11, "cumulative_pnl": {"train": 0.0, "val": 0.0}, "total_trades": {"train": 0, "val": 0}, "overall_win_rate": {"train": 0.0, "val": 0.0}}
    ],
    "cumulative_pnl": {"train": 0.0, "val": 0.0},
    "total_trades": {"train": 0, "val": 0},
    "total_wins": {"train": 0, "val": 0}
}
@@ -1,192 +0,0 @@
{
    "epochs_completed": 7,
    "best_val_pnl": 0.002028853100759435,
    "best_epoch": 6,
    "best_win_rate": 0.5157894736842106,
    "training_started": "2025-03-31T02:50:10.418670",
    "last_update": "2025-03-31T02:50:15.227593",
    "epochs": [
        {"epoch": 1, "train_loss": 1.1206786036491394, "val_loss": 1.0542699098587036, "train_acc": 0.11197916666666667, "val_acc": 0.25, "train_pnl": 0.0, "val_pnl": 0.0, "train_win_rate": 0.0, "val_win_rate": 0.0, "best_position_size": 0.1, "signal_distribution": {"train": {"BUY": 0.0, "SELL": 0.0, "HOLD": 1.0}, "val": {"BUY": 0.0, "SELL": 0.0, "HOLD": 1.0}}, "timestamp": "2025-03-31T02:50:12.881423", "data_age": 2},
        {"epoch": 2, "train_loss": 1.1266120672225952, "val_loss": 1.072133183479309, "train_acc": 0.1171875, "val_acc": 0.25, "train_pnl": 0.0, "val_pnl": 0.0, "train_win_rate": 0.0, "val_win_rate": 0.0, "best_position_size": 0.1, "signal_distribution": {"train": {"BUY": 0.0, "SELL": 0.0, "HOLD": 1.0}, "val": {"BUY": 0.0, "SELL": 0.0, "HOLD": 1.0}}, "timestamp": "2025-03-31T02:50:13.186840", "data_age": 2},
        {"epoch": 3, "train_loss": 1.1415620843569438, "val_loss": 1.1701548099517822, "train_acc": 0.1015625, "val_acc": 0.5208333333333334, "train_pnl": 0.0, "val_pnl": 0.0, "train_win_rate": 0.0, "val_win_rate": 0.0, "best_position_size": 0.1, "signal_distribution": {"train": {"BUY": 0.0, "SELL": 0.0, "HOLD": 1.0}, "val": {"BUY": 0.0, "SELL": 0.0, "HOLD": 1.0}}, "timestamp": "2025-03-31T02:50:13.442018", "data_age": 3},
        {"epoch": 4, "train_loss": 1.1331567962964375, "val_loss": 1.070081114768982, "train_acc": 0.09375, "val_acc": 0.22916666666666666, "train_pnl": 0.010650217327384765, "val_pnl": -0.0007049481907895126, "train_win_rate": 0.49279538904899134, "val_win_rate": 0.40625, "best_position_size": 0.1, "signal_distribution": {"train": {"BUY": 0.0, "SELL": 0.9036458333333334, "HOLD": 0.09635416666666667}, "val": {"BUY": 0.0, "SELL": 0.3333333333333333, "HOLD": 0.6666666666666666}}, "timestamp": "2025-03-31T02:50:13.739899", "data_age": 3},
        {"epoch": 5, "train_loss": 1.10965762535731, "val_loss": 1.0485950708389282, "train_acc": 0.12239583333333333, "val_acc": 0.17708333333333334, "train_pnl": 0.011924086862580204, "val_pnl": 0.0, "train_win_rate": 0.5070422535211268, "val_win_rate": 0.0, "best_position_size": 0.1, "signal_distribution": {"train": {"BUY": 0.0, "SELL": 0.7395833333333334, "HOLD": 0.2604166666666667}, "val": {"BUY": 0.0, "SELL": 0.0, "HOLD": 1.0}}, "timestamp": "2025-03-31T02:50:14.073439", "data_age": 3},
        {"epoch": 6, "train_loss": 1.1272419293721516, "val_loss": 1.084235429763794, "train_acc": 0.1015625, "val_acc": 0.22916666666666666, "train_pnl": 0.014825159601390072, "val_pnl": 0.00405770620151887, "train_win_rate": 0.4908616187989556, "val_win_rate": 0.5157894736842106, "best_position_size": 2.0, "signal_distribution": {"train": {"BUY": 0.0, "SELL": 1.0, "HOLD": 0.0}, "val": {"BUY": 0.0, "SELL": 1.0, "HOLD": 0.0}}, "timestamp": "2025-03-31T02:50:14.658295", "data_age": 4},
        {"epoch": 7, "train_loss": 1.1171108484268188, "val_loss": 1.0741244554519653, "train_acc": 0.1171875, "val_acc": 0.22916666666666666, "train_pnl": 0.0059474696523706605, "val_pnl": 0.00405770620151887, "train_win_rate": 0.4838709677419355, "val_win_rate": 0.5157894736842106, "best_position_size": 2.0, "signal_distribution": {"train": {"BUY": 0.0, "SELL": 0.7291666666666666, "HOLD": 0.2708333333333333}, "val": {"BUY": 0.0, "SELL": 1.0, "HOLD": 0.0}}, "timestamp": "2025-03-31T02:50:15.227593", "data_age": 4}
    ]
}
NN/models/standardized_cnn.py (new file, 512 lines)
@@ -0,0 +1,512 @@
"""
|
||||
Standardized CNN Model for Multi-Modal Trading System
|
||||
|
||||
This module extends the existing EnhancedCNN to work with standardized BaseDataInput format
|
||||
and provides ModelOutput for cross-model feeding.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the project root to the path to import core modules
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
from core.data_models import BaseDataInput, ModelOutput, create_model_output
|
||||
from .enhanced_cnn import EnhancedCNN, SelfAttention, ResidualBlock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class StandardizedCNN(nn.Module):
|
||||
"""
|
||||
Standardized CNN Model that accepts BaseDataInput and outputs ModelOutput
|
||||
|
||||
Features:
|
||||
- Accepts standardized BaseDataInput format
|
||||
- Processes COB+OHLCV data: 300 frames (1s,1m,1h,1d) ETH + 300s 1s BTC
|
||||
- Includes COB ±20 buckets and MA (1s,5s,15s,60s) of COB imbalance ±5 buckets
|
||||
- Outputs BUY/SELL trading action with confidence scores
|
||||
- Provides hidden states for cross-model feeding
|
||||
- Integrates with checkpoint management system
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = "standardized_cnn_v1", confidence_threshold: float = 0.6):
|
||||
"""
|
||||
Initialize the standardized CNN model
|
||||
|
||||
Args:
|
||||
model_name: Name identifier for this model instance
|
||||
confidence_threshold: Minimum confidence threshold for predictions
|
||||
"""
|
||||
super(StandardizedCNN, self).__init__()
|
||||
|
||||
self.model_name = model_name
|
||||
self.model_type = "cnn"
|
||||
self.confidence_threshold = confidence_threshold
|
||||
|
||||
# Calculate expected input dimensions from BaseDataInput
|
||||
self.expected_feature_dim = self._calculate_expected_features()
|
||||
|
||||
# Initialize the underlying enhanced CNN with calculated dimensions
|
||||
self.enhanced_cnn = EnhancedCNN(
|
||||
input_shape=self.expected_feature_dim,
|
||||
n_actions=3, # BUY, SELL, HOLD
|
||||
confidence_threshold=confidence_threshold
|
||||
)
|
||||
|
||||
# Additional layers for processing BaseDataInput structure
|
||||
self.input_processor = self._build_input_processor()
|
||||
|
||||
# Output processing layers
|
||||
self.output_processor = self._build_output_processor()
|
||||
|
||||
# Optional numeric return head (predicts percent change for 1s,1m,1h,1d)
|
||||
# Uses cnn_features (1024) to regress predicted returns per timeframe
|
||||
self.return_head = nn.Sequential(
|
||||
nn.Linear(1024, 256),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(256, 4) # [return_1s, return_1m, return_1h, return_1d]
|
||||
)
|
||||
|
||||
# Device management
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.to(self.device)
|
||||
try:
|
||||
import torch.backends.cudnn as cudnn
|
||||
cudnn.benchmark = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"StandardizedCNN '{model_name}' initialized")
|
||||
logger.info(f"Expected feature dimension: {self.expected_feature_dim}")
|
||||
logger.info(f"Device: {self.device}")
|
||||
|
||||
def _calculate_expected_features(self) -> int:
|
||||
"""
|
||||
Calculate expected feature dimension from BaseDataInput structure
|
||||
|
||||
Based on actual BaseDataInput.get_feature_vector():
|
||||
- OHLCV ETH: 300 frames x 4 timeframes x 5 features = 6000
|
||||
- OHLCV BTC: 300 frames x 5 features = 1500
|
||||
- COB features: ~184 features (actual from implementation)
|
||||
- Technical indicators: 100 features (padded)
|
||||
- Last predictions: 50 features (padded)
|
||||
Total: ~7834 features (actual measured)
|
||||
"""
|
||||
return 7834 # Based on actual BaseDataInput.get_feature_vector() measurement
|
||||
|
||||
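# Sanity check: the hard-coded 7834 equals the sum of the component sizes
# documented in the docstring above (component counts as stated, not
# re-measured here).
eth_ohlcv = 300 * 4 * 5        # 6000
btc_ohlcv = 300 * 5            # 1500
cob, indicators, last_preds = 184, 100, 50
assert eth_ohlcv + btc_ohlcv + cob + indicators + last_preds == 7834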
def _build_input_processor(self) -> nn.Module:
|
||||
"""
|
||||
Build input processing layers for BaseDataInput
|
||||
|
||||
Returns:
|
||||
nn.Module: Input processing layers
|
||||
"""
|
||||
return nn.Sequential(
|
||||
# Initial processing of raw BaseDataInput features
|
||||
nn.Linear(self.expected_feature_dim, 4096),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
nn.BatchNorm1d(4096),
|
||||
|
||||
# Feature refinement
|
||||
nn.Linear(4096, 2048),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
nn.BatchNorm1d(2048),
|
||||
|
||||
# Final feature extraction
|
||||
nn.Linear(2048, 1024),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.1)
|
||||
)
|
||||
|

    def _build_output_processor(self) -> nn.Module:
        """
        Build output processing layers for the standardized ModelOutput.

        Returns:
            nn.Module: Output processing layers
        """
        return nn.Sequential(
            # Process CNN outputs for the standardized format
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Final action prediction
            nn.Linear(512, 3),  # BUY, SELL, HOLD
            nn.Softmax(dim=1)
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """
        Forward pass through the standardized CNN.

        Args:
            x: Input tensor from BaseDataInput.get_feature_vector()

        Returns:
            Tuple of (action_probabilities, hidden_states_dict, predicted_returns)
        """
        batch_size = x.size(0)

        # Validate input dimensions
        if x.size(1) != self.expected_feature_dim:
            logger.warning(f"Input dimension mismatch: expected {self.expected_feature_dim}, got {x.size(1)}")
            # Pad or truncate as needed
            if x.size(1) < self.expected_feature_dim:
                padding = torch.zeros(batch_size, self.expected_feature_dim - x.size(1), device=x.device)
                x = torch.cat([x, padding], dim=1)
            else:
                x = x[:, :self.expected_feature_dim]

        # Process input through the input processor
        processed_features = self.input_processor(x)  # [batch, 1024]

        # Get enhanced CNN predictions, adding the sequence dimension the
        # enhanced CNN expects
        cnn_input = processed_features.unsqueeze(1)

        try:
            q_values, extrema_pred, price_pred, cnn_features, advanced_pred = self.enhanced_cnn(cnn_input)
        except Exception as e:
            logger.warning(f"Enhanced CNN forward pass failed: {e}, using fallback")
            # Fallback to direct processing
            cnn_features = processed_features
            q_values = torch.zeros(batch_size, 3, device=x.device)
            extrema_pred = torch.zeros(batch_size, 3, device=x.device)
            price_pred = torch.zeros(batch_size, 3, device=x.device)
            advanced_pred = torch.zeros(batch_size, 5, device=x.device)

        # Process outputs for the standardized format
        action_probs = self.output_processor(cnn_features)  # [batch, 3]

        # Predict numeric returns per timeframe from cnn_features
        predicted_returns = self.return_head(cnn_features)  # [batch, 4]

        # Prepare hidden states for cross-model feeding
        hidden_states = {
            'processed_features': processed_features.detach(),
            'cnn_features': cnn_features.detach(),
            'q_values': q_values.detach(),
            'extrema_predictions': extrema_pred.detach(),
            'price_predictions': price_pred.detach(),
            'advanced_predictions': advanced_pred.detach(),
            'attention_weights': torch.ones(batch_size, 1, device=x.device)  # Placeholder
        }

        return action_probs, hidden_states, predicted_returns.detach()

    def predict_from_base_input(self, base_input: BaseDataInput) -> ModelOutput:
        """
        Make a prediction from BaseDataInput and return a standardized ModelOutput.

        Args:
            base_input: Standardized input data

        Returns:
            ModelOutput: Standardized model output
        """
        try:
            # Convert BaseDataInput to a feature vector
            feature_vector = base_input.get_feature_vector()

            # Convert to tensor and add the batch dimension
            input_tensor = torch.tensor(feature_vector, dtype=torch.float32, device=self.device).unsqueeze(0)

            # Set model to evaluation mode
            self.eval()

            with torch.no_grad():
                # Forward pass
                action_probs, hidden_states, predicted_returns = self.forward(input_tensor)

                # Get action and confidence
                action_probs_np = action_probs.squeeze(0).cpu().numpy()
                action_idx = np.argmax(action_probs_np)
                confidence = float(action_probs_np[action_idx])

                # Map action index to action name
                action_names = ['BUY', 'SELL', 'HOLD']
                action = action_names[action_idx]

                # Prepare the predictions dictionary
                predictions = {
                    'action': action,
                    'buy_probability': float(action_probs_np[0]),
                    'sell_probability': float(action_probs_np[1]),
                    'hold_probability': float(action_probs_np[2]),
                    'action_probabilities': action_probs_np.tolist(),
                    'extrema_detected': self._interpret_extrema(hidden_states.get('extrema_predictions')),
                    'price_direction': self._interpret_price_direction(hidden_states.get('price_predictions')),
                    'market_conditions': self._interpret_advanced_predictions(hidden_states.get('advanced_predictions'))
                }

                # Add numeric predicted returns per timeframe if available
                try:
                    pr = predicted_returns.squeeze(0).cpu().numpy().tolist()
                    # Only attach per-timeframe returns if all four are present
                    if isinstance(pr, list) and len(pr) >= 4:
                        predictions['predicted_returns'] = pr[:4]
                        predictions['predicted_return_1s'] = float(pr[0])
                        predictions['predicted_return_1m'] = float(pr[1])
                        predictions['predicted_return_1h'] = float(pr[2])
                        predictions['predicted_return_1d'] = float(pr[3])
                except Exception:
                    pass

                # Prepare hidden states for cross-model feeding (convert tensors to lists)
                cross_model_states = {}
                for key, tensor in hidden_states.items():
                    if isinstance(tensor, torch.Tensor):
                        cross_model_states[key] = tensor.squeeze(0).cpu().numpy().tolist()
                    else:
                        cross_model_states[key] = tensor

                # Create metadata
                metadata = {
                    'model_version': '1.0',
                    'confidence_threshold': self.confidence_threshold,
                    'feature_dimension': self.expected_feature_dim,
                    'processing_time_ms': 0,  # Could add timing if needed
                    'input_validation': base_input.validate()
                }

                # Create the standardized ModelOutput
                model_output = ModelOutput(
                    model_type=self.model_type,
                    model_name=self.model_name,
                    symbol=base_input.symbol,
                    timestamp=datetime.now(),
                    confidence=confidence,
                    predictions=predictions,
                    hidden_states=cross_model_states,
                    metadata=metadata
                )

                return model_output

        except Exception as e:
            logger.error(f"Error in CNN prediction: {e}")
            # Return a default output
            return self._create_default_output(base_input.symbol)

    def _interpret_extrema(self, extrema_tensor: Optional[torch.Tensor]) -> str:
        """Interpret extrema predictions"""
        if extrema_tensor is None:
            return "unknown"

        try:
            extrema_probs = torch.softmax(extrema_tensor.squeeze(0), dim=0)
            extrema_idx = torch.argmax(extrema_probs).item()
            extrema_labels = ['bottom', 'top', 'neither']
            return extrema_labels[extrema_idx]
        except Exception:
            return "unknown"

    def _interpret_price_direction(self, price_tensor: Optional[torch.Tensor]) -> str:
        """Interpret price direction predictions"""
        if price_tensor is None:
            return "unknown"

        try:
            price_probs = torch.softmax(price_tensor.squeeze(0), dim=0)
            price_idx = torch.argmax(price_probs).item()
            price_labels = ['up', 'down', 'sideways']
            return price_labels[price_idx]
        except Exception:
            return "unknown"

    def _interpret_advanced_predictions(self, advanced_tensor: Optional[torch.Tensor]) -> Dict[str, str]:
        """Interpret advanced market predictions"""
        if advanced_tensor is None:
            return {"volatility": "unknown", "risk": "unknown"}

        try:
            # The advanced head is assumed to encode volatility as its first 5 classes
            if advanced_tensor.size(-1) >= 5:
                volatility_probs = torch.softmax(advanced_tensor.squeeze(0)[:5], dim=0)
                volatility_idx = torch.argmax(volatility_probs).item()
                volatility_labels = ['very_low', 'low', 'medium', 'high', 'very_high']
                volatility = volatility_labels[volatility_idx]
            else:
                volatility = "unknown"

            return {
                "volatility": volatility,
                "risk": "medium"  # Placeholder
            }
        except Exception:
            return {"volatility": "unknown", "risk": "unknown"}

    def _create_default_output(self, symbol: str) -> ModelOutput:
        """Create a default ModelOutput for error cases"""
        return create_model_output(
            model_type=self.model_type,
            model_name=self.model_name,
            symbol=symbol,
            action='HOLD',
            confidence=0.5,
            metadata={'error': True, 'default_output': True}
        )

    def train_step(self, base_inputs: List[BaseDataInput], targets: List[str],
                   optimizer: torch.optim.Optimizer) -> float:
        """
        Perform a single training step.

        Args:
            base_inputs: List of BaseDataInput for training
            targets: List of target actions ('BUY', 'SELL', 'HOLD')
            optimizer: PyTorch optimizer

        Returns:
            float: Training loss
        """
        self.train()

        try:
            # Convert inputs to tensors
            feature_vectors = []
            for base_input in base_inputs:
                feature_vector = base_input.get_feature_vector()
                feature_vectors.append(feature_vector)

            input_tensor = torch.tensor(np.array(feature_vectors), dtype=torch.float32, device=self.device)

            # Convert targets to a tensor
            action_to_idx = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
            target_indices = [action_to_idx.get(target, 2) for target in targets]
            target_tensor = torch.tensor(target_indices, dtype=torch.long, device=self.device)

            # Forward pass (forward returns action_probs, hidden_states, predicted_returns)
            action_probs, _, _ = self.forward(input_tensor)

            # Calculate loss. The output processor already applies softmax, so use
            # NLL on log-probabilities rather than F.cross_entropy, which would
            # apply a second softmax internally.
            loss = F.nll_loss(torch.log(action_probs + 1e-8), target_tensor)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            return float(loss.item())

        except Exception as e:
            logger.error(f"Error in training step: {e}")
            return float('inf')
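
    # Minimal training-loop sketch (illustrative; `samples` and `labels` are
    # hypothetical, pre-collected BaseDataInput objects and action strings):
    #
    #   model = StandardizedCNN(model_name="cnn_v1", confidence_threshold=0.6)
    #   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    #   for epoch in range(10):
    #       loss = model.train_step(samples, labels, optimizer)
    #       logger.info(f"epoch={epoch} loss={loss:.4f}")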

    def evaluate(self, base_inputs: List[BaseDataInput], targets: List[str]) -> Dict[str, float]:
        """
        Evaluate model performance.

        Args:
            base_inputs: List of BaseDataInput for evaluation
            targets: List of target actions

        Returns:
            Dict containing evaluation metrics
        """
        self.eval()

        try:
            correct = 0
            total = len(base_inputs)
            total_confidence = 0.0

            with torch.no_grad():
                for base_input, target in zip(base_inputs, targets):
                    model_output = self.predict_from_base_input(base_input)
                    predicted_action = model_output.predictions['action']

                    if predicted_action == target:
                        correct += 1

                    total_confidence += model_output.confidence

            accuracy = correct / total if total > 0 else 0.0
            avg_confidence = total_confidence / total if total > 0 else 0.0

            return {
                'accuracy': accuracy,
                'avg_confidence': avg_confidence,
                'correct_predictions': correct,
                'total_predictions': total
            }

        except Exception as e:
            logger.error(f"Error in evaluation: {e}")
            return {'accuracy': 0.0, 'avg_confidence': 0.0, 'correct_predictions': 0, 'total_predictions': 0}

    def save_checkpoint(self, filepath: str, metadata: Optional[Dict[str, Any]] = None):
        """
        Save a model checkpoint.

        Args:
            filepath: Path to save the checkpoint
            metadata: Optional metadata to save with the checkpoint
        """
        try:
            checkpoint = {
                'model_state_dict': self.state_dict(),
                'model_name': self.model_name,
                'model_type': self.model_type,
                'confidence_threshold': self.confidence_threshold,
                'expected_feature_dim': self.expected_feature_dim,
                'metadata': metadata or {},
                'timestamp': datetime.now().isoformat()
            }

            torch.save(checkpoint, filepath)
            logger.info(f"Checkpoint saved to {filepath}")

        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")

    def load_checkpoint(self, filepath: str) -> bool:
        """
        Load a model checkpoint.

        Args:
            filepath: Path to the checkpoint file

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        try:
            checkpoint = torch.load(filepath, map_location=self.device)

            # Load model state
            self.load_state_dict(checkpoint['model_state_dict'])

            # Load configuration
            self.model_name = checkpoint.get('model_name', self.model_name)
            self.confidence_threshold = checkpoint.get('confidence_threshold', self.confidence_threshold)
            self.expected_feature_dim = checkpoint.get('expected_feature_dim', self.expected_feature_dim)

            logger.info(f"Checkpoint loaded from {filepath}")
            return True

        except Exception as e:
            logger.error(f"Error loading checkpoint: {e}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information"""
        return {
            'model_name': self.model_name,
            'model_type': self.model_type,
            'confidence_threshold': self.confidence_threshold,
            'expected_feature_dim': self.expected_feature_dim,
            'device': str(self.device),
            'parameter_count': sum(p.numel() for p in self.parameters()),
            'trainable_parameters': sum(p.numel() for p in self.parameters() if p.requires_grad)
        }
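
# Illustrative inference flow (added for clarity; `build_base_input` is a
# hypothetical helper -- the real system constructs BaseDataInput from live data):
#
#   cnn = StandardizedCNN(model_name="standardized_cnn_demo", confidence_threshold=0.6)
#   base_input = build_base_input(symbol="ETH/USDT")
#   output = cnn.predict_from_base_input(base_input)
#   print(output.predictions['action'], output.confidence)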
@@ -1,821 +0,0 @@
"""
Transformer Neural Network for timeseries analysis

This module implements a Transformer model with attention mechanisms for cryptocurrency price analysis.
It also includes a Mixture of Experts model that combines predictions from multiple models.
"""

import os
import json
import logging
import datetime

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (
    Input, Dense, Dropout, BatchNormalization,
    Concatenate, Layer, LayerNormalization, MultiHeadAttention,
    Add, GlobalAveragePooling1D, Conv1D, Reshape
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

logger = logging.getLogger(__name__)

class TransformerBlock(Layer):
    """
    Transformer block implementation with multi-head attention and feed-forward networks.
    """
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super(TransformerBlock, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=False):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        # Serialize only the constructor arguments; the sub-layers are rebuilt
        # in __init__, so storing the layer objects themselves would break
        # model deserialization
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'rate': self.rate
        })
        return config

class PositionalEncoding(Layer):
    """
    Positional encoding layer to add position information to input embeddings.
    """
    def __init__(self, position, d_model, **kwargs):
        super(PositionalEncoding, self).__init__(**kwargs)
        self.position = position
        self.d_model = d_model
        self.pos_encoding = self.positional_encoding(position, d_model)

    def get_angles(self, position, i, d_model):
        angles = 1 / tf.pow(10000, (2 * (i // 2)) / tf.cast(d_model, tf.float32))
        return position * angles

    def positional_encoding(self, position, d_model):
        angle_rads = self.get_angles(
            position=tf.range(position, dtype=tf.float32)[:, tf.newaxis],
            i=tf.range(d_model, dtype=tf.float32)[tf.newaxis, :],
            d_model=d_model
        )

        # Apply sin to even indices in the array
        sines = tf.math.sin(angle_rads[:, 0::2])

        # Apply cos to odd indices in the array
        cosines = tf.math.cos(angle_rads[:, 1::2])

        pos_encoding = tf.concat([sines, cosines], axis=-1)
        pos_encoding = pos_encoding[tf.newaxis, ...]

        return tf.cast(pos_encoding, tf.float32)

    def call(self, inputs):
        return inputs + self.pos_encoding[:, :tf.shape(inputs)[1], :]

    def get_config(self):
        # Serialize only the constructor arguments; pos_encoding is recomputed
        # from (position, d_model) in __init__
        config = super().get_config()
        config.update({
            'position': self.position,
            'd_model': self.d_model
        })
        return config
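
# Shape note (illustrative, not part of the original file): the encoding
# broadcasts a (1, position, d_model) tensor onto the input sequence,
# truncated to the actual sequence length.
#
#   pe = PositionalEncoding(position=20, d_model=32)
#   x = tf.zeros((1, 20, 32))
#   assert pe(x).shape == (1, 20, 32)
#
# Note that sines and cosines are concatenated along the feature axis here
# rather than interleaved as in the original Transformer paper; each position
# still receives a unique encoding.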

class TransformerModel:
    """
    Transformer Neural Network for time series analysis.

    This model uses self-attention mechanisms to capture relationships between
    different time points in the input data.
    """

    def __init__(self, ts_input_shape=(20, 5), feature_input_shape=64, output_size=1, model_dir="NN/models/saved"):
        """
        Initialize the Transformer model.

        Args:
            ts_input_shape (tuple): Shape of time series input data (sequence_length, features)
            feature_input_shape (int): Size of the additional feature input (e.g., from CNN)
            output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
            model_dir (str): Directory to save trained models
        """
        self.ts_input_shape = ts_input_shape
        self.feature_input_shape = feature_input_shape
        self.output_size = output_size
        self.model_dir = model_dir
        self.model = None
        self.history = None

        # Create the model directory if it doesn't exist
        os.makedirs(self.model_dir, exist_ok=True)

        logger.info(f"Initialized Transformer model with TS input shape {ts_input_shape}, "
                    f"feature input shape {feature_input_shape}, and output size {output_size}")

    def build_model(self, embed_dim=32, num_heads=4, ff_dim=64, num_transformer_blocks=2, dropout_rate=0.1, learning_rate=0.001):
        """
        Build the Transformer model architecture.

        Args:
            embed_dim (int): Embedding dimension for the transformer
            num_heads (int): Number of attention heads
            ff_dim (int): Hidden dimension of the feed-forward network
            num_transformer_blocks (int): Number of transformer blocks
            dropout_rate (float): Dropout rate for regularization
            learning_rate (float): Learning rate for the Adam optimizer

        Returns:
            The compiled model
        """
        # Time series input
        ts_inputs = Input(shape=self.ts_input_shape, name="ts_input")

        # Additional feature input (e.g., from CNN)
        feature_inputs = Input(shape=(self.feature_input_shape,), name="feature_input")

        # Process the time series with the transformer:
        # first, project the input to the embedding dimension
        x = Conv1D(embed_dim, 1, activation="relu")(ts_inputs)

        # Add positional encoding
        x = PositionalEncoding(self.ts_input_shape[0], embed_dim)(x)

        # Add transformer blocks
        for _ in range(num_transformer_blocks):
            x = TransformerBlock(embed_dim, num_heads, ff_dim, dropout_rate)(x)

        # Global pooling to get a single vector representation
        x = GlobalAveragePooling1D()(x)
        x = Dropout(dropout_rate)(x)

        # Combine with the additional features
        combined = Concatenate()([x, feature_inputs])

        # Dense layers for the final classification/regression
        x = Dense(64, activation="relu")(combined)
        x = BatchNormalization()(x)
        x = Dropout(dropout_rate)(x)

        # Output layer
        if self.output_size == 1:
            # Binary classification (up/down)
            outputs = Dense(1, activation='sigmoid', name='output')(x)
            loss = 'binary_crossentropy'
            metrics = ['accuracy']
        elif self.output_size == 3:
            # Multi-class classification (buy/hold/sell)
            outputs = Dense(3, activation='softmax', name='output')(x)
            loss = 'categorical_crossentropy'
            metrics = ['accuracy']
        else:
            # Regression
            outputs = Dense(self.output_size, activation='linear', name='output')(x)
            loss = 'mse'
            metrics = ['mae']

        # Create and compile the model
        self.model = Model(inputs=[ts_inputs, feature_inputs], outputs=outputs)

        # Compile with the Adam optimizer
        self.model.compile(
            optimizer=Adam(learning_rate=learning_rate),
            loss=loss,
            metrics=metrics
        )

        # Log the model summary
        self.model.summary(print_fn=lambda line: logger.info(line))

        return self.model
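
    # Usage sketch (illustrative): build a 3-class model sized for the default
    # (20, 5) OHLCV window and 64 auxiliary features.
    #
    #   model = TransformerModel(ts_input_shape=(20, 5), feature_input_shape=64, output_size=3)
    #   model.build_model(embed_dim=32, num_heads=4, ff_dim=64, num_transformer_blocks=2)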

    def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2,
              callbacks=None, class_weights=None):
        """
        Train the Transformer model on the provided data.

        Args:
            X_ts (numpy.ndarray): Time series input features
            X_features (numpy.ndarray): Additional input features
            y (numpy.ndarray): Target labels
            batch_size (int): Batch size
            epochs (int): Number of epochs
            validation_split (float): Fraction of data to use for validation
            callbacks (list): List of Keras callbacks
            class_weights (dict): Class weights for imbalanced datasets

        Returns:
            History object containing training metrics
        """
        if self.model is None:
            self.build_model()

        # Default callbacks if none provided
        if callbacks is None:
            # Create a timestamp for model checkpoints
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

            callbacks = [
                EarlyStopping(
                    monitor='val_loss',
                    patience=10,
                    restore_best_weights=True
                ),
                ReduceLROnPlateau(
                    monitor='val_loss',
                    factor=0.5,
                    patience=5,
                    min_lr=1e-6
                ),
                ModelCheckpoint(
                    filepath=os.path.join(self.model_dir, f"transformer_model_{timestamp}.h5"),
                    monitor='val_loss',
                    save_best_only=True
                )
            ]

        # One-hot encode y for multi-class classification if needed
        if self.output_size == 3 and len(y.shape) == 1:
            y = tf.keras.utils.to_categorical(y, num_classes=3)

        # Train the model
        logger.info(f"Training Transformer model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}")
        self.history = self.model.fit(
            [X_ts, X_features], y,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=validation_split,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=2
        )

        # Save the trained model
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        model_path = os.path.join(self.model_dir, f"transformer_model_final_{timestamp}.h5")
        self.model.save(model_path)
        logger.info(f"Model saved to {model_path}")

        # Save the training history
        history_path = os.path.join(self.model_dir, f"transformer_model_history_{timestamp}.json")
        with open(history_path, 'w') as f:
            # Convert numpy values to Python native types for JSON serialization
            history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
            json.dump(history_dict, f, indent=2)

        return self.history
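
    # Class-weight sketch for imbalanced BUY/HOLD/SELL labels (illustrative;
    # `y_train` is a hypothetical integer label array with classes 0/1/2):
    #
    #   counts = np.bincount(y_train, minlength=3)
    #   class_weights = {i: len(y_train) / (3 * c) for i, c in enumerate(counts) if c > 0}
    #   model.train(X_ts, X_features, y_train, class_weights=class_weights)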

    def evaluate(self, X_ts, X_features, y):
        """
        Evaluate the model on test data.

        Args:
            X_ts (numpy.ndarray): Time series input features
            X_features (numpy.ndarray): Additional input features
            y (numpy.ndarray): Target labels

        Returns:
            dict: Evaluation metrics
        """
        if self.model is None:
            raise ValueError("Model has not been built or trained yet")

        # Convert y to one-hot encoding for multi-class classification
        if self.output_size == 3 and len(y.shape) == 1:
            y = tf.keras.utils.to_categorical(y, num_classes=3)

        # Evaluate the model
        logger.info(f"Evaluating Transformer model on {len(X_ts)} samples")
        eval_results = self.model.evaluate([X_ts, X_features], y, verbose=0)

        metrics = {}
        for metric, value in zip(self.model.metrics_names, eval_results):
            metrics[metric] = value
            logger.info(f"{metric}: {value:.4f}")

        return metrics

    def predict(self, X_ts, X_features=None):
        """
        Make predictions on new data.

        Args:
            X_ts (numpy.ndarray): Time series input features
            X_features (numpy.ndarray): Additional input features

        Returns:
            tuple: (y_pred, y_proba) where:
                y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
                y_proba is the class probability
        """
        if self.model is None:
            raise ValueError("Model has not been built or trained yet")

        # Ensure X_ts has the right shape
        if len(X_ts.shape) == 2:
            # Single sample, add the batch dimension
            X_ts = np.expand_dims(X_ts, axis=0)

        # Ensure X_features has the right shape
        if X_features is None:
            # Extract features from the time series data if no external features are provided
            X_features = self._extract_features_from_timeseries(X_ts)
        elif len(X_features.shape) == 1:
            # Single sample, add the batch dimension
            X_features = np.expand_dims(X_features, axis=0)

        # Get predictions
        y_proba = self.model.predict([X_ts, X_features])

        # Process based on the output type
        if self.output_size == 1:
            # Binary classification
            y_pred = (y_proba > 0.5).astype(int).flatten()
            return y_pred, y_proba.flatten()
        elif self.output_size == 3:
            # Multi-class classification
            y_pred = np.argmax(y_proba, axis=1)
            return y_pred, y_proba
        else:
            # Regression
            return y_proba, y_proba

    def _extract_features_from_timeseries(self, X_ts: np.ndarray) -> np.ndarray:
        """Extract meaningful features from time series data instead of using dummy zeros"""
        try:
            batch_size = X_ts.shape[0]
            features = []

            for i in range(batch_size):
                sample = X_ts[i]  # Shape: (timesteps, features)

                # Extract statistical features from each feature dimension
                sample_features = []

                for feature_idx in range(sample.shape[1]):
                    feature_data = sample[:, feature_idx]

                    # Basic statistical features
                    sample_features.extend([
                        np.mean(feature_data),            # Mean
                        np.std(feature_data),             # Standard deviation
                        np.min(feature_data),             # Minimum
                        np.max(feature_data),             # Maximum
                        np.percentile(feature_data, 25),  # 25th percentile
                        np.percentile(feature_data, 75),  # 75th percentile
                    ])

                    # Trend features
                    if len(feature_data) > 1:
                        # Linear trend (slope)
                        x = np.arange(len(feature_data))
                        slope = np.polyfit(x, feature_data, 1)[0]
                        sample_features.append(slope)

                        # Rate of change
                        rate_of_change = (feature_data[-1] - feature_data[0]) / feature_data[0] if feature_data[0] != 0 else 0
                        sample_features.append(rate_of_change)
                    else:
                        sample_features.extend([0.0, 0.0])

                # Pad or truncate to the expected feature size
                while len(sample_features) < self.feature_input_shape:
                    sample_features.append(0.0)
                sample_features = sample_features[:self.feature_input_shape]

                features.append(sample_features)

            return np.array(features, dtype=np.float32)

        except Exception as e:
            logger.error(f"Error extracting features from time series: {e}")
            # Fall back to zeros if extraction fails
            return np.zeros((X_ts.shape[0], self.feature_input_shape), dtype=np.float32)
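
    # Prediction sketch (illustrative): a single (20, 5) window is accepted and
    # batched automatically; features are derived from the window when omitted.
    #
    #   window = np.random.rand(20, 5).astype('float32')
    #   y_pred, y_proba = model.predict(window)  # X_features auto-extracted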
    def save(self, filepath=None):
        """
        Save the model to disk.

        Args:
            filepath (str): Path to save the model

        Returns:
            str: Path where the model was saved
        """
        if self.model is None:
            raise ValueError("Model has not been built yet")

        if filepath is None:
            # Create a default filepath with a timestamp
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            filepath = os.path.join(self.model_dir, f"transformer_model_{timestamp}.h5")

        self.model.save(filepath)
        logger.info(f"Model saved to {filepath}")
        return filepath

    def load(self, filepath):
        """
        Load a saved model from disk.

        Args:
            filepath (str): Path to the saved model

        Returns:
            The loaded model
        """
        # Register custom layers
        custom_objects = {
            'TransformerBlock': TransformerBlock,
            'PositionalEncoding': PositionalEncoding
        }

        self.model = load_model(filepath, custom_objects=custom_objects)
        logger.info(f"Model loaded from {filepath}")
        return self.model

    def plot_training_history(self):
        """
        Plot the training history (loss and metrics).

        Returns:
            str: Path to the saved plot
        """
        if self.history is None:
            raise ValueError("Model has not been trained yet")

        plt.figure(figsize=(12, 5))

        # Plot loss
        plt.subplot(1, 2, 1)
        plt.plot(self.history.history['loss'], label='Training Loss')
        if 'val_loss' in self.history.history:
            plt.plot(self.history.history['val_loss'], label='Validation Loss')
        plt.title('Model Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        # Plot accuracy (or MAE for regression models)
        plt.subplot(1, 2, 2)

        if 'accuracy' in self.history.history:
            plt.plot(self.history.history['accuracy'], label='Training Accuracy')
            if 'val_accuracy' in self.history.history:
                plt.plot(self.history.history['val_accuracy'], label='Validation Accuracy')
            plt.title('Model Accuracy')
            plt.ylabel('Accuracy')
        elif 'mae' in self.history.history:
            plt.plot(self.history.history['mae'], label='Training MAE')
            if 'val_mae' in self.history.history:
                plt.plot(self.history.history['val_mae'], label='Validation MAE')
            plt.title('Model MAE')
            plt.ylabel('MAE')

        plt.xlabel('Epoch')
        plt.legend()

        plt.tight_layout()

        # Save the figure
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        fig_path = os.path.join(self.model_dir, f"transformer_training_history_{timestamp}.png")
        plt.savefig(fig_path)
        plt.close()

        logger.info(f"Training history plot saved to {fig_path}")
        return fig_path


class MixtureOfExpertsModel:
    """
    Mixture of Experts (MoE) model.

    This model combines predictions from multiple expert models (such as CNN and Transformer)
    using a weighted ensemble approach.
    """

    def __init__(self, output_size=1, model_dir="NN/models/saved"):
        """
        Initialize the MoE model.

        Args:
            output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
            model_dir (str): Directory to save trained models
        """
        self.output_size = output_size
        self.model_dir = model_dir
        self.model = None
        self.history = None
        self.experts = {}

        # Create the model directory if it doesn't exist
        os.makedirs(self.model_dir, exist_ok=True)

        logger.info(f"Initialized Mixture of Experts model with output size {output_size}")

    def add_expert(self, name, model):
        """
        Add an expert model to the MoE.

        Args:
            name (str): Name of the expert model
            model: The expert model instance

        Returns:
            None
        """
        self.experts[name] = model
        logger.info(f"Added expert model '{name}' to MoE")

    def build_model(self, ts_input_shape=(20, 5), expert_weights=None, learning_rate=0.001):
        """
        Build the MoE model by combining expert models.

        Args:
            ts_input_shape (tuple): Shape of time series input data
            expert_weights (dict): Weights for each expert model
            learning_rate (float): Learning rate for the Adam optimizer

        Returns:
            The compiled model
        """
        # Time series input
        ts_inputs = Input(shape=ts_input_shape, name="ts_input")

        # Additional feature input (from CNN)
        feature_inputs = Input(shape=(64,), name="feature_input")  # Default feature size

        # Process with each expert model
        expert_outputs = []
        expert_names = []

        for name, expert in self.experts.items():
            # Skip if the expert model is not valid
            if expert is None:
                logger.warning(f"Expert model '{name}' is None, skipping")
                continue

            try:
                # Different handling based on the model type
                if name == 'cnn':
                    # The CNN expert takes only the time series input
                    expert_output = expert(ts_inputs)
                    expert_outputs.append(expert_output)
                    expert_names.append(name)
                elif name == 'transformer':
                    # The Transformer expert takes both time series and feature inputs
                    expert_output = expert([ts_inputs, feature_inputs])
                    expert_outputs.append(expert_output)
                    expert_names.append(name)
                else:
                    logger.warning(f"Unknown expert model type: {name}")
            except Exception as e:
                logger.error(f"Error adding expert '{name}': {str(e)}")

        if not expert_outputs:
            logger.error("No valid expert models found")
            return None

        # Determine expert weighting
        if expert_weights is None:
            # Equal weighting
            weights = [1.0 / len(expert_outputs)] * len(expert_outputs)
        else:
            # User-provided weights, normalized to sum to 1
            weights = [expert_weights.get(name, 1.0 / len(expert_outputs)) for name in expert_names]
            total_weight = sum(weights)
            weights = [w / total_weight for w in weights]

        # Combine expert outputs using a weighted average
        if len(expert_outputs) == 1:
            # Only one expert, use its output directly
            combined_output = expert_outputs[0]
        else:
            # Multiple experts: scale each output by its weight and sum
            weighted_outputs = [output * weight for output, weight in zip(expert_outputs, weights)]
            combined_output = Add()(weighted_outputs)

        # Create the MoE model
        moe_model = Model(inputs=[ts_inputs, feature_inputs], outputs=combined_output)

        # Compile the model
        if self.output_size == 1:
            # Binary classification
            moe_model.compile(
                optimizer=Adam(learning_rate=learning_rate),
                loss='binary_crossentropy',
                metrics=['accuracy']
            )
        elif self.output_size == 3:
            # Multi-class classification for BUY/HOLD/SELL
            moe_model.compile(
                optimizer=Adam(learning_rate=learning_rate),
                loss='categorical_crossentropy',
                metrics=['accuracy']
            )
        else:
            # Regression
            moe_model.compile(
                optimizer=Adam(learning_rate=learning_rate),
                loss='mse',
                metrics=['mae']
            )

        self.model = moe_model

        # Log the model summary
        self.model.summary(print_fn=lambda line: logger.info(line))

        logger.info(f"Built MoE model with weights: {weights}")
        return self.model
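
    # Weight-normalization example (illustrative): with
    # expert_weights={'cnn': 3.0, 'transformer': 1.0}, the combined output is
    # 0.75 * cnn_output + 0.25 * transformer_output after normalization.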

    def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2,
              callbacks=None, class_weights=None):
        """
        Train the MoE model on the provided data.

        Args:
            X_ts (numpy.ndarray): Time series input features
            X_features (numpy.ndarray): Additional input features
            y (numpy.ndarray): Target labels
            batch_size (int): Batch size
            epochs (int): Number of epochs
            validation_split (float): Fraction of data to use for validation
            callbacks (list): List of Keras callbacks
            class_weights (dict): Class weights for imbalanced datasets

        Returns:
            History object containing training metrics
        """
        if self.model is None:
            logger.error("MoE model has not been built yet")
            return None

        # Default callbacks if none provided
        if callbacks is None:
            # Create a timestamp for model checkpoints
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

            callbacks = [
                EarlyStopping(
                    monitor='val_loss',
                    patience=10,
                    restore_best_weights=True
                ),
                ReduceLROnPlateau(
                    monitor='val_loss',
                    factor=0.5,
                    patience=5,
                    min_lr=1e-6
                ),
                ModelCheckpoint(
                    filepath=os.path.join(self.model_dir, f"moe_model_{timestamp}.h5"),
                    monitor='val_loss',
                    save_best_only=True
                )
            ]

        # One-hot encode y for multi-class classification if needed
        if self.output_size == 3 and len(y.shape) == 1:
            y = tf.keras.utils.to_categorical(y, num_classes=3)

        # Train the model
        logger.info(f"Training MoE model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}")
        self.history = self.model.fit(
            [X_ts, X_features], y,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=validation_split,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=2
        )

        # Save the trained model
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        model_path = os.path.join(self.model_dir, f"moe_model_final_{timestamp}.h5")
        self.model.save(model_path)
        logger.info(f"Model saved to {model_path}")

        # Save the training history
        history_path = os.path.join(self.model_dir, f"moe_model_history_{timestamp}.json")
        with open(history_path, 'w') as f:
            # Convert numpy values to Python native types for JSON serialization
            history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
            json.dump(history_dict, f, indent=2)

        return self.history

    def predict(self, X_ts, X_features=None):
        """
        Make predictions on new data.

        Args:
            X_ts (numpy.ndarray): Time series input features
            X_features (numpy.ndarray): Additional input features

        Returns:
            tuple: (y_pred, y_proba) where:
                y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
                y_proba is the class probability
        """
        if self.model is None:
            raise ValueError("Model has not been built or trained yet")

        # Ensure X_ts has the right shape
        if len(X_ts.shape) == 2:
            # Single sample, add the batch dimension
            X_ts = np.expand_dims(X_ts, axis=0)

        # Ensure X_features has the right shape
        if X_features is None:
            # Create dummy features with zeros
            X_features = np.zeros((X_ts.shape[0], 64))  # Default feature size
        elif len(X_features.shape) == 1:
            # Single sample, add the batch dimension
            X_features = np.expand_dims(X_features, axis=0)

        # Get predictions
        y_proba = self.model.predict([X_ts, X_features])

        # Process based on the output type
        if self.output_size == 1:
            # Binary classification
            y_pred = (y_proba > 0.5).astype(int).flatten()
            return y_pred, y_proba.flatten()
        elif self.output_size == 3:
            # Multi-class classification
            y_pred = np.argmax(y_proba, axis=1)
            return y_pred, y_proba
        else:
            # Regression
            return y_proba, y_proba

    def save(self, filepath=None):
        """
        Save the model to disk.

        Args:
            filepath (str): Path to save the model

        Returns:
            str: Path where the model was saved
        """
        if self.model is None:
            raise ValueError("Model has not been built yet")

        if filepath is None:
            # Create a default filepath with a timestamp
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            filepath = os.path.join(self.model_dir, f"moe_model_{timestamp}.h5")

        self.model.save(filepath)
        logger.info(f"Model saved to {filepath}")
        return filepath

    def load(self, filepath):
        """
        Load a saved model from disk.

        Args:
            filepath (str): Path to the saved model

        Returns:
            The loaded model
        """
        # Register custom layers
        custom_objects = {
            'TransformerBlock': TransformerBlock,
            'PositionalEncoding': PositionalEncoding
        }

        self.model = load_model(filepath, custom_objects=custom_objects)
        logger.info(f"Model loaded from {filepath}")
        return self.model

# Example usage:
if __name__ == "__main__":
    # A complete implementation would wire these models into the trading system
    print("Transformer and MoE models defined.")
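
    # Illustrative end-to-end shape check (added for clarity; synthetic data only):
    # build a Transformer expert, register it in the MoE, and combine.
    transformer = TransformerModel(ts_input_shape=(20, 5), feature_input_shape=64, output_size=3)
    transformer.build_model()

    moe = MixtureOfExpertsModel(output_size=3)
    moe.add_expert('transformer', transformer.model)
    moe.build_model(ts_input_shape=(20, 5), expert_weights={'transformer': 1.0})

    # Predict on random data to verify shapes end-to-end
    X_ts = np.random.rand(2, 20, 5).astype('float32')
    X_feat = np.random.rand(2, 64).astype('float32')
    y_pred, y_proba = moe.predict(X_ts, X_feat)
    print(f"predicted classes: {y_pred}, probabilities shape: {y_proba.shape}")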
@@ -1,88 +0,0 @@
#!/usr/bin/env python
"""
Start TensorBoard for monitoring neural network training
"""

import os
import sys
import subprocess
import webbrowser
from time import sleep

def start_tensorboard(logdir="NN/models/saved/logs", port=6006, open_browser=True):
    """
    Start TensorBoard in a subprocess.

    Args:
        logdir: Directory containing TensorBoard logs
        port: Port to run TensorBoard on
        open_browser: Whether to open a browser automatically
    """
    # Make sure the log directory exists
    os.makedirs(logdir, exist_ok=True)

    # Build the command
    cmd = [
        sys.executable,
        "-m",
        "tensorboard.main",
        f"--logdir={logdir}",
        f"--port={port}",
        "--bind_all"
    ]

    print(f"Starting TensorBoard with logs from {logdir} on port {port}")
    print(f"Command: {' '.join(cmd)}")

    # Start TensorBoard in a subprocess
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True
    )

    # Wait for TensorBoard to start up
    for line in process.stdout:
        print(line.strip())
        if "TensorBoard" in line and "http://" in line:
            # TensorBoard is running; extract the URL
            url = None
            for part in line.split():
                if part.startswith(("http://", "https://")):
                    url = part
                    break

            # Open a browser if requested and a URL was found
            if open_browser and url:
                print(f"Opening TensorBoard in browser: {url}")
                webbrowser.open(url)

            break

    # Return the process for the caller to manage
    return process

if __name__ == "__main__":
    import argparse

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Start TensorBoard for NN training visualization")
    parser.add_argument("--logdir", default="NN/models/saved/logs", help="Directory containing TensorBoard logs")
    parser.add_argument("--port", type=int, default=6006, help="Port to run TensorBoard on")
    parser.add_argument("--no-browser", action="store_true", help="Don't open browser automatically")

    args = parser.parse_args()

    # Start TensorBoard
    process = start_tensorboard(args.logdir, args.port, not args.no_browser)

    try:
        # Keep the script running until Ctrl+C
        print("TensorBoard is running. Press Ctrl+C to stop.")
        while True:
            sleep(1)
    except KeyboardInterrupt:
        print("Stopping TensorBoard...")
        process.terminate()
        process.wait()
@@ -27,8 +27,18 @@ import torch
import torch.nn as nn
import torch.optim as optim

# Import prediction tracking
from core.prediction_database import get_prediction_db

# Import checkpoint management
try:
    from utils.checkpoint_manager import get_checkpoint_manager, save_checkpoint
    CHECKPOINT_MANAGER_AVAILABLE = True
except ImportError:
    CHECKPOINT_MANAGER_AVAILABLE = False

logger = logging.getLogger(__name__)

if not CHECKPOINT_MANAGER_AVAILABLE:
    # Warn only after the logger exists; warning inside the except block above
    # would reference the logger before it is defined
    logger.warning("Checkpoint manager not available. Model persistence will be disabled.")

@@ -61,12 +71,19 @@ class EnhancedRealtimeTrainingSystem:
        # Experience buffers
        self.experience_buffer = deque(maxlen=self.training_config['memory_size'])
        self.validation_buffer = deque(maxlen=1000)
        self.priority_buffer = deque(maxlen=2000)  # High-priority experiences

        # Training counters - CRITICAL for checkpoint management
        self.training_iteration = 0
        self.dqn_training_count = 0
        self.cnn_training_count = 0
        self.cob_training_count = 0

        # Performance tracking
        self.performance_history = {
            'dqn_losses': deque(maxlen=1000),
            'cnn_losses': deque(maxlen=1000),
            'cob_rl_losses': deque(maxlen=1000),  # Added COB RL loss tracking
            'prediction_accuracy': deque(maxlen=500),
            'trading_performance': deque(maxlen=200),
            'validation_scores': deque(maxlen=100)
@@ -764,18 +781,33 @@ class EnhancedRealtimeTrainingSystem:
                # Statistical features across time for each aggregated dimension
                for feature_idx in range(agg_matrix.shape[1]):
                    feature_series = agg_matrix[:, feature_idx]

                    # Clean the feature series to prevent division warnings
                    feature_series_clean = feature_series[np.isfinite(feature_series)]

                    if len(feature_series_clean) > 0:
                        # Safe percentile calculation
                        try:
                            percentile_25 = np.percentile(feature_series_clean, 25)
                            percentile_75 = np.percentile(feature_series_clean, 75)
                        except (ValueError, RuntimeWarning):
                            percentile_25 = np.median(feature_series_clean)
                            percentile_75 = np.median(feature_series_clean)

                        combined_features.extend([
                            np.mean(feature_series_clean),
                            np.std(feature_series_clean),
                            np.min(feature_series_clean),
                            np.max(feature_series_clean),
                            feature_series_clean[-1] - feature_series_clean[0] if len(feature_series_clean) > 1 else 0,  # Total change
                            np.mean(np.diff(feature_series_clean)) if len(feature_series_clean) > 1 else 0,  # Average momentum
                            np.std(np.diff(feature_series_clean)) if len(feature_series_clean) > 2 else 0,  # Momentum volatility
                            percentile_25,  # 25th percentile
                            percentile_75,  # 75th percentile
                            len([x for x in np.diff(feature_series_clean) if x > 0]) / max(len(feature_series_clean) - 1, 1) if len(feature_series_clean) > 1 else 0.5  # Positive change ratio
                        ])
                    else:
                        # All values are NaN or inf, use zeros
                        combined_features.extend([0.0] * 10)
            else:
                combined_features.extend([0.0] * (15 * 10))  # 15 features * 10 statistics

@@ -913,13 +945,14 @@ class EnhancedRealtimeTrainingSystem:
            lows = np.array([bar['low'] for bar in self.real_time_data['ohlcv_1m']])

            # Update indicators, guarding against division by zero
            price_mean = np.mean(prices[-20:])
            self.technical_indicators = {
                'sma_10': np.mean(prices[-10:]),
                'sma_20': np.mean(prices[-20:]),
                'rsi': self._calculate_rsi(prices, 14),
                'volatility': np.std(prices[-20:]) / price_mean if price_mean > 0 else 0,
                'volume_sma': np.mean(volumes[-10:]),
                'price_momentum': (prices[-1] - prices[-5]) / prices[-5] if len(prices) >= 5 and prices[-5] > 0 else 0,
                'atr': np.mean(highs[-14:] - lows[-14:]) if len(prices) >= 14 else 0
            }

@@ -935,8 +968,8 @@ class EnhancedRealtimeTrainingSystem:
            current_time = time.time()
            current_bar = self.real_time_data['ohlcv_1m'][-1]

            # Create comprehensive state features with default dimensions
            state_features = self._build_comprehensive_state(100)  # Use the default 100 for general experiences

            # Create an experience with proper reward calculation
            experience = {
@@ -959,8 +992,8 @@ class EnhancedRealtimeTrainingSystem:
        except Exception as e:
            logger.debug(f"Error creating training experiences: {e}")

    def _build_comprehensive_state(self, target_dimensions: int = 100) -> np.ndarray:
        """Build comprehensive state vector for RL training with adaptive dimensions"""
        try:
            state_features = []

@@ -1003,15 +1036,138 @@ class EnhancedRealtimeTrainingSystem:
|
||||
state_features.append(np.cos(2 * np.pi * now.hour / 24))
|
||||
state_features.append(now.weekday() / 6.0) # Day of week
|
||||
|
||||
# Pad to fixed size (100 features)
|
||||
while len(state_features) < 100:
|
||||
# Current count: 10 (prices) + 7 (indicators) + 1 (volume) + 5 (COB) + 3 (time) = 26 base features
|
||||
|
||||
# 6. Enhanced features for larger dimensions
|
||||
if target_dimensions > 50:
|
||||
# Add more price history
|
||||
if len(self.real_time_data['ohlcv_1m']) >= 20:
|
||||
extended_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-20:]]
|
||||
base_price = extended_prices[0]
|
||||
extended_normalized = [(p - base_price) / base_price for p in extended_prices[10:]] # Additional 10
|
||||
state_features.extend(extended_normalized)
|
||||
else:
|
||||
state_features.extend([0.0] * 10)
|
||||
|
||||
# Add volume history
|
||||
if len(self.real_time_data['ohlcv_1m']) >= 10:
|
||||
volume_history = [bar['volume'] for bar in list(self.real_time_data['ohlcv_1m'])[-10:]]
|
||||
avg_vol = np.mean(volume_history) if volume_history else 1.0
|
||||
# Prevent division by zero
|
||||
if avg_vol == 0:
|
||||
avg_vol = 1.0
|
||||
normalized_volumes = [v / avg_vol for v in volume_history]
|
||||
state_features.extend(normalized_volumes)
|
||||
else:
|
||||
state_features.extend([0.0] * 10)
|
||||
|
||||
# Add extended COB features
|
||||
extended_cob = self._extract_cob_features()
|
||||
state_features.extend(extended_cob[5:]) # Remaining COB features
|
||||
|
||||
# Add 5m timeframe data if available
|
||||
if len(self.real_time_data['ohlcv_5m']) >= 5:
|
||||
tf_5m_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_5m'])[-5:]]
|
||||
if tf_5m_prices:
|
||||
base_5m = tf_5m_prices[0]
|
||||
# Prevent division by zero
|
||||
if base_5m == 0:
|
||||
base_5m = 1.0
|
||||
normalized_5m = [(p - base_5m) / base_5m for p in tf_5m_prices]
|
||||
state_features.extend(normalized_5m)
|
||||
else:
|
||||
state_features.extend([0.0] * 5)
|
||||
else:
|
||||
state_features.extend([0.0] * 5)
|
||||
|
||||
# 7. Adaptive padding/truncation based on target dimensions
|
||||
current_length = len(state_features)
|
||||
|
||||
if target_dimensions > current_length:
|
||||
# Pad with additional engineered features
|
||||
remaining = target_dimensions - current_length
|
||||
|
||||
# Add statistical features if we have data
|
||||
if len(self.real_time_data['ohlcv_1m']) >= 20:
|
||||
all_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-20:]]
|
||||
all_volumes = [bar['volume'] for bar in list(self.real_time_data['ohlcv_1m'])[-20:]]
|
||||
|
||||
# Statistical features
|
||||
additional_features = [
|
||||
np.std(all_prices) / np.mean(all_prices) if np.mean(all_prices) > 0 else 0, # Price CV
|
||||
np.std(all_volumes) / np.mean(all_volumes) if np.mean(all_volumes) > 0 else 0, # Volume CV
|
||||
(max(all_prices) - min(all_prices)) / np.mean(all_prices) if np.mean(all_prices) > 0 else 0, # Price range
|
||||
# Safe correlation calculation
|
||||
np.corrcoef(all_prices, all_volumes)[0, 1] if (len(all_prices) == len(all_volumes) and len(all_prices) > 1 and
|
||||
np.std(all_prices) > 0 and np.std(all_volumes) > 0) else 0, # Price-volume correlation
|
||||
]
|
||||
|
||||
# Add momentum features
|
||||
for window in [3, 5, 10]:
|
||||
if len(all_prices) >= window:
|
||||
momentum = (all_prices[-1] - all_prices[-window]) / all_prices[-window] if all_prices[-window] > 0 else 0
|
||||
additional_features.append(momentum)
|
||||
else:
|
||||
additional_features.append(0.0)
|
||||
|
||||
# Extend to fill remaining space
|
||||
while len(additional_features) < remaining and len(additional_features) < 50:
|
||||
additional_features.extend([
|
||||
np.sin(len(additional_features) * 0.1), # Sine waves for variety
|
||||
np.cos(len(additional_features) * 0.1),
|
||||
np.tanh(len(additional_features) * 0.01)
|
||||
])
|
||||
|
||||
state_features.extend(additional_features[:remaining])
|
||||
else:
|
||||
# Fill with structured zeros/patterns if no data
|
||||
pattern_features = []
|
||||
for i in range(remaining):
|
||||
pattern_features.append(np.sin(i * 0.01)) # Small oscillating pattern
|
||||
state_features.extend(pattern_features)
|
||||
|
||||
# Ensure exact target dimension
|
||||
state_features = state_features[:target_dimensions]
|
||||
while len(state_features) < target_dimensions:
|
||||
state_features.append(0.0)
|
||||
|
||||
return np.array(state_features[:100])
|
||||
return np.array(state_features)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building state: {e}")
|
||||
return np.zeros(100)
|
||||
return np.zeros(target_dimensions)
|
||||
|
||||
    def _get_model_expected_dimensions(self, model_type: str) -> int:
        """Get expected input dimensions for different model types"""
        try:
            if model_type == 'dqn':
                # Try to get DQN expected dimensions from model
                if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
                    and self.orchestrator.rl_agent and hasattr(self.orchestrator.rl_agent, 'policy_net')):
                    # Get first layer input size
                    first_layer = list(self.orchestrator.rl_agent.policy_net.children())[0]
                    if hasattr(first_layer, 'in_features'):
                        return first_layer.in_features
                return 403  # Default for DQN based on error logs

            elif model_type == 'cnn':
                # CNN might have different input expectations
                if (self.orchestrator and hasattr(self.orchestrator, 'cnn_model')
                    and self.orchestrator.cnn_model):
                    # Try to get CNN input size
                    if hasattr(self.orchestrator.cnn_model, 'input_shape'):
                        return self.orchestrator.cnn_model.input_shape
                return 300  # Default for CNN based on error logs

            elif model_type == 'cob_rl':
                return 2000  # COB RL expects 2000 features

            else:
                return 100  # Default

        except Exception as e:
            logger.debug(f"Error getting model dimensions for {model_type}: {e}")
            return 100  # Fallback

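A standalone sketch of the first-layer introspection used above (illustrative only; the
network here is a hypothetical stand-in for rl_agent.policy_net, and scanning modules()
is a slightly more defensive variant of list(model.children())[0]):

import torch.nn as nn

policy_net = nn.Sequential(nn.Linear(403, 256), nn.ReLU(), nn.Linear(256, 3))

def expected_input_dims(model: nn.Module, default: int = 100) -> int:
    """Return in_features of the first Linear layer found, or a fallback default."""
    for layer in model.modules():
        if isinstance(layer, nn.Linear):
            return layer.in_features
    return default

print(expected_input_dims(policy_net))  # -> 403
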
    def _extract_cob_features(self) -> List[float]:
        """Extract features from COB data"""

@@ -1131,8 +1287,8 @@ class EnhancedRealtimeTrainingSystem:
                        total_loss += loss
                        training_iterations += 1
                elif hasattr(rl_agent, 'replay'):
-                   # Fallback to replay method
-                   loss = rl_agent.replay(batch_size=len(batch))
+                   # Fallback to replay method - DQNAgent.replay() doesn't accept batch_size parameter
+                   loss = rl_agent.replay()
                    if loss is not None:
                        total_loss += loss
                        training_iterations += 1

@@ -1142,6 +1298,10 @@ class EnhancedRealtimeTrainingSystem:

            self.dqn_training_count += 1

+           # Save checkpoint after training
+           if training_iterations > 0 and avg_loss > 0:
+               self._save_model_checkpoint('dqn_agent', rl_agent, avg_loss)

            # Log progress every 10 training sessions
            if self.dqn_training_count % 10 == 0:
                logger.info(f"DQN TRAINING: Session {self.dqn_training_count}, "

@@ -1175,6 +1335,18 @@ class EnhancedRealtimeTrainingSystem:
            aggregated_matrix = self.get_cob_training_matrix(symbol, '1s_aggregated')

            if combined_features is not None:
+               # Ensure features are exactly 2000 dimensions
+               if len(combined_features) != 2000:
+                   logger.warning(f"COB features wrong size: {len(combined_features)}, padding/truncating to 2000")
+                   if len(combined_features) < 2000:
+                       # Pad with zeros
+                       padded_features = np.zeros(2000, dtype=np.float32)
+                       padded_features[:len(combined_features)] = combined_features
+                       combined_features = padded_features
+                   else:
+                       # Truncate to 2000
+                       combined_features = combined_features[:2000]

                # Create enhanced COB training experience
                current_price = self._get_current_price_from_data(symbol)
                if current_price:

@@ -1184,29 +1356,14 @@ class EnhancedRealtimeTrainingSystem:
                    # Calculate reward based on COB prediction accuracy
                    reward = self._calculate_cob_reward(symbol, action, combined_features)

-                   # Create comprehensive state vector for COB RL
+                   # Create comprehensive state vector for COB RL (exactly 2000 dimensions)
                    state = combined_features  # 2000-dimensional state

-                   # Store experience in COB RL agent
-                   if hasattr(cob_rl_agent, 'store_experience'):
-                       experience = {
-                           'state': state,
-                           'action': action,
-                           'reward': reward,
-                           'next_state': state,  # Will be updated with next observation
-                           'done': False,
-                           'symbol': symbol,
-                           'timestamp': datetime.now(),
-                           'price': current_price,
-                           'cob_features': {
-                               'raw_tick_available': raw_tick_matrix is not None,
-                               'aggregated_available': aggregated_matrix is not None,
-                               'imbalance': combined_features[0] if len(combined_features) > 0 else 0,
-                               'spread': combined_features[1] if len(combined_features) > 1 else 0,
-                               'liquidity': combined_features[4] if len(combined_features) > 4 else 0
-                           }
-                       }
-                       cob_rl_agent.store_experience(experience)
+                   if hasattr(cob_rl_agent, 'remember'):
+                       # Use tuple format for DQN agent compatibility
+                       experience_tuple = (state, action, reward, state, False)  # next_state = current state for now
+                       cob_rl_agent.remember(state, action, reward, state, False)
                        training_updates += 1

            # Perform COB RL training if enough experiences
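The pad/truncate step in the hunk above, factored into a small helper for clarity (a
sketch, not part of this commit; the 2000-dimension target matches the COB RL default):

import numpy as np

def fit_length(features: np.ndarray, target: int = 2000) -> np.ndarray:
    """Zero-pad or truncate a feature vector to exactly `target` elements."""
    out = np.zeros(target, dtype=np.float32)
    n = min(len(features), target)
    out[:n] = features[:n]
    return out

assert fit_length(np.ones(1500)).shape == (2000,)  # padded with zeros
assert fit_length(np.ones(2500)).shape == (2000,)  # truncated
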
@@ -1479,16 +1636,29 @@ class EnhancedRealtimeTrainingSystem:
            # Moving averages
            if len(prev_prices) >= 5:
                ma5 = sum(prev_prices[-5:]) / 5
-               tech_features.append((current_price - ma5) / ma5)
+               # Prevent division by zero
+               if ma5 != 0:
+                   tech_features.append((current_price - ma5) / ma5)
+               else:
+                   tech_features.append(0.0)

            if len(prev_prices) >= 10:
                ma10 = sum(prev_prices[-10:]) / 10
-               tech_features.append((current_price - ma10) / ma10)
+               # Prevent division by zero
+               if ma10 != 0:
+                   tech_features.append((current_price - ma10) / ma10)
+               else:
+                   tech_features.append(0.0)

            # Volatility measure
            if len(prev_prices) >= 5:
-               volatility = np.std(prev_prices[-5:]) / np.mean(prev_prices[-5:])
-               tech_features.append(volatility)
+               price_mean = np.mean(prev_prices[-5:])
+               # Prevent division by zero
+               if price_mean != 0:
+                   volatility = np.std(prev_prices[-5:]) / price_mean
+                   tech_features.append(volatility)
+               else:
+                   tech_features.append(0.0)

            # Pad technical features to 200
            while len(tech_features) < 200:
@@ -1670,6 +1840,14 @@ class EnhancedRealtimeTrainingSystem:
            features_tensor = torch.from_numpy(features).float().to(device)
            targets_tensor = torch.from_numpy(targets).long().to(device)

+           # FIXED: Move tensors to same device as model
+           device = next(model.parameters()).device
+           features_tensor = features_tensor.to(device)
+           targets_tensor = targets_tensor.to(device)
+
+           # Move criterion to same device as well
+           criterion = criterion.to(device)

            # Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width)
            # Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20)
            # This depends on the actual CNN model architecture. Assuming a simple CNN that expects (batch, channels, height, width)

@@ -1700,6 +1878,7 @@ class EnhancedRealtimeTrainingSystem:

            outputs = model(features_tensor)

<<<<<<< HEAD
            # Extract logits from model output (model returns a dictionary)
            if isinstance(outputs, dict):
                logits = outputs['logits']

@@ -1713,6 +1892,19 @@ class EnhancedRealtimeTrainingSystem:
                logger.error(f"CNN output is not a tensor: {type(logits)}")
                return 0.0

=======
            # FIXED: Handle case where model returns tuple (extract the logits)
            if isinstance(outputs, tuple):
                # Assume the first element is the main output (logits)
                logits = outputs[0]
            elif isinstance(outputs, dict):
                # Handle dictionary output (get main prediction)
                logits = outputs.get('logits', outputs.get('predictions', outputs.get('output', list(outputs.values())[0])))
            else:
                # Single tensor output
                logits = outputs

>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
            loss = criterion(logits, targets_tensor)

            loss.backward()

@@ -1721,8 +1913,122 @@ class EnhancedRealtimeTrainingSystem:
            return loss.item()

        except Exception as e:
-           logger.error(f"Error in CNN training: {e}")
+           logger.error(f"RT TRAINING: Error in CNN training: {e}")
            return 1.0  # Return default loss value in case of error

    def _sample_prioritized_experiences(self) -> List[Dict]:
        """Sample prioritized experiences for training"""
        try:
            experiences = []

            # Sample from priority buffer first (high-priority experiences)
            if self.priority_buffer:
                priority_samples = min(len(self.priority_buffer), self.training_config['batch_size'] // 2)
                priority_experiences = random.sample(list(self.priority_buffer), priority_samples)
                experiences.extend(priority_experiences)

            # Sample from regular experience buffer
            if self.experience_buffer:
                remaining_samples = self.training_config['batch_size'] - len(experiences)
                if remaining_samples > 0:
                    regular_samples = min(len(self.experience_buffer), remaining_samples)
                    regular_experiences = random.sample(list(self.experience_buffer), regular_samples)
                    experiences.extend(regular_experiences)

            # Convert experiences to DQN format
            dqn_experiences = []
            for exp in experiences:
                # Create next state by shifting current state (simple approximation)
                next_state = exp['state'].copy() if hasattr(exp['state'], 'copy') else exp['state']

                # Simple reward based on recent market movement
                reward = self._calculate_experience_reward(exp)

                # Action mapping: 0=BUY, 1=SELL, 2=HOLD
                action = self._determine_action_from_experience(exp)

                dqn_exp = {
                    'state': exp['state'],
                    'action': action,
                    'reward': reward,
                    'next_state': next_state,
                    'done': False  # Episodes don't really "end" in continuous trading
                }

                dqn_experiences.append(dqn_exp)

            return dqn_experiences

        except Exception as e:
            logger.error(f"Error sampling prioritized experiences: {e}")
            return []

    def _calculate_experience_reward(self, experience: Dict) -> float:
        """Calculate reward for an experience"""
        try:
            # Simple reward based on technical indicators and market events
            reward = 0.0

            # Reward based on market events
            if experience.get('market_events', 0) > 0:
                reward += 0.1  # Bonus for learning from market events

            # Reward based on technical indicators
            tech_indicators = experience.get('technical_indicators', {})
            if tech_indicators:
                # Reward for strong momentum
                momentum = tech_indicators.get('price_momentum', 0)
                reward += np.tanh(momentum * 10)  # Bounded reward

                # Penalize high volatility
                volatility = tech_indicators.get('volatility', 0)
                reward -= min(volatility * 5, 0.2)  # Penalty for high volatility

            # Reward based on COB features
            cob_features = experience.get('cob_features', [])
            if cob_features and len(cob_features) > 0:
                # Reward for strong order book imbalance
                imbalance = cob_features[0] if len(cob_features) > 0 else 0
                reward += abs(imbalance) * 0.1  # Reward for any imbalance signal

            return max(-1.0, min(1.0, reward))  # Clamp to [-1, 1]

        except Exception as e:
            logger.debug(f"Error calculating experience reward: {e}")
            return 0.0

    def _determine_action_from_experience(self, experience: Dict) -> int:
        """Determine action from experience data"""
        try:
            # Use technical indicators to determine action
            tech_indicators = experience.get('technical_indicators', {})

            if tech_indicators:
                momentum = tech_indicators.get('price_momentum', 0)
                rsi = tech_indicators.get('rsi', 50)

                # Simple logic based on momentum and RSI
                if momentum > 0.005 and rsi < 70:  # Upward momentum, not overbought
                    return 0  # BUY
                elif momentum < -0.005 and rsi > 30:  # Downward momentum, not oversold
                    return 1  # SELL
                else:
                    return 2  # HOLD

            # Fallback to COB-based action
            cob_features = experience.get('cob_features', [])
            if cob_features and len(cob_features) > 0:
                imbalance = cob_features[0]
                if imbalance > 0.1:
                    return 0  # BUY (bid imbalance)
                elif imbalance < -0.1:
                    return 1  # SELL (ask imbalance)

            return 2  # Default to HOLD

        except Exception as e:
            logger.debug(f"Error determining action from experience: {e}")
            return 2  # Default to HOLD

    def _perform_validation(self):
        """Perform validation to track model performance"""

@@ -2084,17 +2390,21 @@ class EnhancedRealtimeTrainingSystem:
    def _generate_forward_dqn_prediction(self, symbol: str, current_time: float):
        """Generate a DQN prediction for future price movement"""
        try:
-           # Get current market state (only historical data)
-           current_state = self._build_comprehensive_state()
+           # Get current market state with DQN-specific dimensions
+           target_dims = self._get_model_expected_dimensions('dqn')
+           current_state = self._build_comprehensive_state(target_dims)
            current_price = self._get_current_price_from_data(symbol)

-           if current_price is None:
+           # SKIP prediction if price is invalid
+           if current_price is None or current_price <= 0:
+               logger.debug(f"Skipping DQN prediction for {symbol}: invalid price {current_price}")
                return

            # Use DQN model to predict action (if available)
            if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
                and self.orchestrator.rl_agent):

<<<<<<< HEAD
                # Use RL agent to make prediction
                current_state = self._get_dqn_state_features(symbol)
                if current_state is None:

@@ -2112,6 +2422,28 @@ class EnhancedRealtimeTrainingSystem:

                confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33

=======
                # Get action from DQN agent
                action = self.orchestrator.rl_agent.act(current_state, explore=False)

                # Get Q-values by manually calling the model
                q_values = self._get_dqn_q_values(current_state)

                # Calculate confidence from Q-values
                if q_values is not None and len(q_values) > 0:
                    # Convert to probabilities and get confidence
                    probs = torch.softmax(torch.tensor(q_values), dim=0).numpy()
                    confidence = float(max(probs))
                    q_values = q_values.tolist() if hasattr(q_values, 'tolist') else list(q_values)
                else:
                    confidence = 0.33
                    q_values = [0.33, 0.33, 0.34]  # Default uniform distribution

                # Handle case where action is None (HOLD)
                if action is None:
                    action = 2  # Map None to HOLD action

>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
            else:
                # Fallback to technical analysis-based prediction
                action, q_values, confidence = self._technical_analysis_prediction(symbol)

@@ -2138,8 +2470,8 @@ class EnhancedRealtimeTrainingSystem:
            if symbol in self.pending_predictions:
                self.pending_predictions[symbol].append(prediction)

-           # Add to recent predictions for display (only if confident enough)
-           if confidence > 0.4:
+           # Add to recent predictions for display (only if confident enough AND valid price)
+           if confidence > 0.4 and current_price > 0:
                display_prediction = {
                    'timestamp': prediction_time,
                    'price': current_price,

@@ -2152,6 +2484,7 @@ class EnhancedRealtimeTrainingSystem:

            self.last_prediction_time[symbol] = int(current_time)

<<<<<<< HEAD
            # Robust action labeling
            if action is None:
                action_label = 'HOLD'

@@ -2163,10 +2496,46 @@ class EnhancedRealtimeTrainingSystem:
                action_label = 'UNKNOWN'

            logger.info(f"Forward DQN prediction: {symbol} action={action_label} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
=======
            logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} price=${current_price:.2f} target={target_time.strftime('%H:%M:%S')} dims={len(current_state)}")
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b

        except Exception as e:
            logger.error(f"Error generating forward DQN prediction: {e}")

    def _get_dqn_q_values(self, state: np.ndarray) -> Optional[np.ndarray]:
        """Get Q-values from DQN agent without performing action selection"""
        try:
            if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
                return None

            rl_agent = self.orchestrator.rl_agent

            # Convert state to tensor
            if isinstance(state, np.ndarray):
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(rl_agent.device)
            else:
                state_tensor = state.unsqueeze(0).to(rl_agent.device)

            # Get Q-values directly from policy network
            with torch.no_grad():
                policy_output = rl_agent.policy_net(state_tensor)

            # Handle different output formats
            if isinstance(policy_output, dict):
                q_values = policy_output.get('q_values', policy_output.get('Q_values', list(policy_output.values())[0]))
            elif isinstance(policy_output, tuple):
                q_values = policy_output[0]  # Assume first element is Q-values
            else:
                q_values = policy_output

            # Convert to numpy
            return q_values.cpu().data.numpy()[0]

        except Exception as e:
            logger.debug(f"Error getting DQN Q-values: {e}")
            return None

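How the softmax-confidence step above behaves in isolation (a numpy re-statement of the
torch.softmax call; the Q-values are illustrative only):

import numpy as np

def confidence_from_q(q_values) -> float:
    """Max softmax probability across the action Q-values."""
    q = np.asarray(q_values, dtype=np.float64)
    q = q - q.max()  # shift for numerical stability; softmax is shift-invariant
    probs = np.exp(q) / np.exp(q).sum()
    return float(probs.max())

print(confidence_from_q([1.2, 0.3, -0.5]))  # ~0.63: a modest preference for action 0
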
    def _generate_forward_cnn_prediction(self, symbol: str, current_time: float):
        """Generate a CNN prediction for future price direction"""
        try:

@@ -2174,9 +2543,15 @@ class EnhancedRealtimeTrainingSystem:
            current_price = self._get_current_price_from_data(symbol)
            price_sequence = self._get_historical_price_sequence(symbol, periods=15)

-           if current_price is None or len(price_sequence) < 15:
+           # SKIP prediction if price is invalid
+           if current_price is None or current_price <= 0:
+               logger.debug(f"Skipping CNN prediction for {symbol}: invalid price {current_price}")
                return

+           if len(price_sequence) < 15:
+               logger.debug(f"Skipping CNN prediction for {symbol}: insufficient data")
+               return

            # Use CNN model to predict direction (if available)
            if (self.orchestrator and hasattr(self.orchestrator, 'cnn_model')
                and self.orchestrator.cnn_model):

@@ -2229,8 +2604,8 @@ class EnhancedRealtimeTrainingSystem:
            if symbol in self.pending_predictions:
                self.pending_predictions[symbol].append(prediction)

-           # Add to recent predictions for display (only if confident enough)
-           if confidence > 0.5:
+           # Add to recent predictions for display (only if confident enough AND valid prices)
+           if confidence > 0.5 and current_price > 0 and predicted_price > 0:
                display_prediction = {
                    'timestamp': prediction_time,
                    'current_price': current_price,

@@ -2241,7 +2616,7 @@ class EnhancedRealtimeTrainingSystem:
            if symbol in self.recent_cnn_predictions:
                self.recent_cnn_predictions[symbol].append(display_prediction)

-           logger.info(f"Forward CNN prediction: {symbol} direction={['DOWN','SAME','UP'][direction]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
+           logger.info(f"Forward CNN prediction: {symbol} direction={['DOWN','SAME','UP'][direction]} confidence={confidence:.2f} price=${current_price:.2f} -> ${predicted_price:.2f} target={target_time.strftime('%H:%M:%S')}")

        except Exception as e:
            logger.error(f"Error generating forward CNN prediction: {e}")

    def _get_current_price_from_data(self, symbol: str) -> Optional[float]:
        """Get current price from real-time data streams"""
        try:
+           # First, try to get from data provider (most reliable)
+           if self.data_provider:
+               price = self.data_provider.get_current_price(symbol)
+               if price and price > 0:
+                   return price
+
+           # Fallback to internal buffer
            if len(self.real_time_data['ohlcv_1m']) > 0:
-               return self.real_time_data['ohlcv_1m'][-1]['close']
+               price = self.real_time_data['ohlcv_1m'][-1]['close']
+               if price and price > 0:
+                   return price
+
+           # Fallback to orchestrator price
+           if self.orchestrator:
+               price = self.orchestrator._get_current_price(symbol)
+               if price and price > 0:
+                   return price

            return None
        except Exception as e:
            logger.debug(f"Error getting current price: {e}")
@@ -2428,4 +2819,56 @@ class EnhancedRealtimeTrainingSystem:

        except Exception as e:
            logger.debug(f"Error estimating price change: {e}")
            return 0.0

    def _save_model_checkpoint(self, model_name: str, model_obj, loss: float):
        """
        Save model checkpoint after training if performance improved

        This is CRITICAL for preserving training progress across restarts.
        """
        try:
            if not CHECKPOINT_MANAGER_AVAILABLE:
                return

            # Get checkpoint manager
            checkpoint_manager = get_checkpoint_manager()
            if not checkpoint_manager:
                return

            # Prepare performance metrics
            performance_metrics = {
                'loss': loss,
                'training_samples': len(self.experience_buffer),
                'timestamp': datetime.now().isoformat()
            }

            # Prepare training metadata
            training_metadata = {
                'timestamp': datetime.now().isoformat(),
                'training_iteration': self.training_iteration,
                'model_type': model_name
            }

            # Determine model type based on model name
            model_type = model_name
            if 'dqn' in model_name.lower():
                model_type = 'dqn'
            elif 'cnn' in model_name.lower():
                model_type = 'cnn'
            elif 'cob' in model_name.lower():
                model_type = 'cob_rl'

            # Save checkpoint
            checkpoint_path = save_checkpoint(
                model=model_obj,
                model_name=model_name,
                model_type=model_type,
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )

            if checkpoint_path:
                logger.info(f"💾 Saved checkpoint for {model_name}: {checkpoint_path} (loss: {loss:.4f})")

        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")

NN/training/samples/curl.txt (new file, 24 lines)
@@ -0,0 +1,24 @@
so, curl example:
curl http://localhost:1234/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "deepseek-r1-0528-qwen3-8b",
    "messages": [
      { "role": "system", "content": "what will be the next row in this sequence:
symbol MAIN SYMBOL (ETH) REF1 (BTC) REF2 (SPX) REF3 (SOL)
timeframe 1s 1m 1h 1d 1s 1s 1s
datapoint O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp
2025-01-15T10:00:00Z 3421.5 3421.75 3421.25 3421.6 125.4 2025-01-15T10:00:00Z 3422.1 3424.8 3420.5 3423.25 1245.7 2025-01-15T10:00:00Z 3420 3428.5 3418.75 3425.1 12847.2 2025-01-15T10:00:00Z 3415.25 3435.6 3410.8 3430.4 145238.6 2025-01-15T10:00:00Z 97850.2 97852.4 97848.1 97851.3 8.7 2025-01-15T10:00:00Z 5925.4 5926.1 5924.8 5925.7 0 2025-01-15T10:00:00Z 191.22 191.45 191.08 191.35 1247.3
2025-01-15T10:00:01Z 3421.6 3421.85 3421.45 3421.75 98.2 2025-01-15T10:01:00Z 3423.25 3425.9 3421.8 3424.6 1189.3 2025-01-15T11:00:00Z 3425.1 3432.2 3422.4 3429.8 11960.5 2025-01-16T10:00:00Z 3430.4 3445.2 3425.15 3440.85 138947.1 2025-01-15T10:00:01Z 97851.3 97853.8 97849.5 97852.9 9.1 2025-01-15T10:00:01Z 5925.7 5926.3 5925.2 5925.9 0 2025-01-15T10:00:01Z 191.35 191.58 191.15 191.48 1156.7
2025-01-15T10:00:02Z 3421.75 3421.95 3421.55 3421.8 110.6 2025-01-15T10:02:00Z 3424.6 3427.15 3423.4 3425.9 1356.8 2025-01-15T12:00:00Z 3429.8 3436.7 3427.2 3434.5 13205.9 2025-01-17T10:00:00Z 3440.85 3455.3 3438.9 3450.75 142568.3 2025-01-15T10:00:02Z 97852.9 97855.2 97850.7 97854.6 7.9 2025-01-15T10:00:02Z 5925.9 5926.5 5925.4 5926.1 0 2025-01-15T10:00:02Z 191.48 191.72 191.28 191.61 1298.4
2025-01-15T10:00:03Z 3421.8 3422.05 3421.65 3421.9 87.3 2025-01-15T10:03:00Z 3425.9 3428.4 3424.2 3427.1 1423.5 2025-01-15T13:00:00Z 3434.5 3441.8 3432.1 3438.2 14087.6 2025-01-18T10:00:00Z 3450.75 3465.4 3448.6 3460.2 149825.7 2025-01-15T10:00:03Z 97854.6 97857.1 97852.3 97856.8 8.4 2025-01-15T10:00:03Z 5926.1 5926.7 5925.6 5926.3 0 2025-01-15T10:00:03Z 191.61 191.85 191.42 191.74 1187.9
2025-01-15T10:00:04Z 3421.9 3422.15 3421.75 3422.0 134.7 2025-01-15T10:04:00Z 3427.1 3429.6 3425.8 3428.3 1298.2 2025-01-15T14:00:00Z 3438.2 3445.6 3436.4 3442.1 12734.8 2025-01-19T10:00:00Z 3460.2 3475.8 3457.4 3470.6 156742.4 2025-01-15T10:00:04Z 97856.8 97859.4 97854.9 97858.2 9.2 2025-01-15T10:00:04Z 5926.3 5926.9 5925.8 5926.5 0 2025-01-15T10:00:04Z 191.74 191.98 191.55 191.87 1342.6
2025-01-15T10:00:05Z 3422.0 3422.25 3421.85 3422.1 156.8 2025-01-15T10:05:00Z 3428.3 3430.8 3426.9 3429.5 1467.9 2025-01-15T15:00:00Z 3442.1 3449.3 3440.7 3446.8 11823.4 2025-01-20T10:00:00Z 3470.6 3485.2 3467.9 3480.1 163456.8 2025-01-15T10:00:05Z 97858.2 97860.7 97856.4 97859.8 8.8 2025-01-15T10:00:05Z 5926.5 5927.1 5926.0 5926.7 0 2025-01-15T10:00:05Z 191.87 192.11 191.68 192.0 1278.5

" },
      { "role": "user", "content": "2025-01-15T10:00:06Z" }
    ],
    "temperature": 0.7,
    "max_tokens": -1,
    "stream": false
  }'
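The same request issued from Python (a sketch; the response shape assumes an
OpenAI-compatible /v1/chat/completions server, as targeted by the curl call above):

import requests

resp = requests.post(
    "http://localhost:1234/v1/chat/completions",
    json={
        "model": "deepseek-r1-0528-qwen3-8b",
        "messages": [
            {"role": "system", "content": "<header and recent rows, as in the curl example>"},
            {"role": "user", "content": "2025-01-15T10:00:06Z"},
        ],
        "temperature": 0.7,
        "stream": False,
    },
    timeout=120,
)
print(resp.json()["choices"][0]["message"]["content"])
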
NN/training/samples/readme.md (new file, 16 lines)
@@ -0,0 +1,16 @@
<!-- example text dump: -->

symbol MAIN SYMBOL (ETH) REF1 (BTC) REF2 (SPX) REF3 (SOL)
timeframe 1s 1m 1h 1d 1s 1s 1s
datapoint O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp
2025-01-15T10:00:00Z 3421.5 3421.75 3421.25 3421.6 125.4 2025-01-15T10:00:00Z 3422.1 3424.8 3420.5 3423.25 1245.7 2025-01-15T10:00:00Z 3420 3428.5 3418.75 3425.1 12847.2 2025-01-15T10:00:00Z 3415.25 3435.6 3410.8 3430.4 145238.6 2025-01-15T10:00:00Z 97850.2 97852.4 97848.1 97851.3 8.7 2025-01-15T10:00:00Z 5925.4 5926.1 5924.8 5925.7 0 2025-01-15T10:00:00Z 191.22 191.45 191.08 191.35 1247.3
2025-01-15T10:00:01Z 3421.6 3421.85 3421.45 3421.75 98.2 2025-01-15T10:01:00Z 3423.25 3425.9 3421.8 3424.6 1189.3 2025-01-15T11:00:00Z 3425.1 3432.2 3422.4 3429.8 11960.5 2025-01-16T10:00:00Z 3430.4 3445.2 3425.15 3440.85 138947.1 2025-01-15T10:00:01Z 97851.3 97853.8 97849.5 97852.9 9.1 2025-01-15T10:00:01Z 5925.7 5926.3 5925.2 5925.9 0 2025-01-15T10:00:01Z 191.35 191.58 191.15 191.48 1156.7
2025-01-15T10:00:02Z 3421.75 3421.95 3421.55 3421.8 110.6 2025-01-15T10:02:00Z 3424.6 3427.15 3423.4 3425.9 1356.8 2025-01-15T12:00:00Z 3429.8 3436.7 3427.2 3434.5 13205.9 2025-01-17T10:00:00Z 3440.85 3455.3 3438.9 3450.75 142568.3 2025-01-15T10:00:02Z 97852.9 97855.2 97850.7 97854.6 7.9 2025-01-15T10:00:02Z 5925.9 5926.5 5925.4 5926.1 0 2025-01-15T10:00:02Z 191.48 191.72 191.28 191.61 1298.4
2025-01-15T10:00:03Z 3421.8 3422.05 3421.65 3421.9 87.3 2025-01-15T10:03:00Z 3425.9 3428.4 3424.2 3427.1 1423.5 2025-01-15T13:00:00Z 3434.5 3441.8 3432.1 3438.2 14087.6 2025-01-18T10:00:00Z 3450.75 3465.4 3448.6 3460.2 149825.7 2025-01-15T10:00:03Z 97854.6 97857.1 97852.3 97856.8 8.4 2025-01-15T10:00:03Z 5926.1 5926.7 5925.6 5926.3 0 2025-01-15T10:00:03Z 191.61 191.85 191.42 191.74 1187.9
2025-01-15T10:00:04Z 3421.9 3422.15 3421.75 3422.0 134.7 2025-01-15T10:04:00Z 3427.1 3429.6 3425.8 3428.3 1298.2 2025-01-15T14:00:00Z 3438.2 3445.6 3436.4 3442.1 12734.8 2025-01-19T10:00:00Z 3460.2 3475.8 3457.4 3470.6 156742.4 2025-01-15T10:00:04Z 97856.8 97859.4 97854.9 97858.2 9.2 2025-01-15T10:00:04Z 5926.3 5926.9 5925.8 5926.5 0 2025-01-15T10:00:04Z 191.74 191.98 191.55 191.87 1342.6
2025-01-15T10:00:05Z 3422.0 3422.25 3421.85 3422.1 156.8 2025-01-15T10:05:00Z 3428.3 3430.8 3426.9 3429.5 1467.9 2025-01-15T15:00:00Z 3442.1 3449.3 3440.7 3446.8 11823.4 2025-01-20T10:00:00Z 3470.6 3485.2 3467.9 3480.1 163456.8 2025-01-15T10:00:05Z 97858.2 97860.7 97856.4 97859.8 8.8 2025-01-15T10:00:05Z 5926.5 5927.1 5926.0 5926.7 0 2025-01-15T10:00:05Z 191.87 192.11 191.68 192.0 1278.5

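One way to split such a dump row back into per-stream values (a sketch; it assumes each
stream contributes six whitespace-separated fields, timestamp first, as in the sample
rows above):

def parse_row(line: str, group: int = 6):
    """Yield (timestamp, open, high, low, close, volume) per stream."""
    tokens = line.split()
    for i in range(0, len(tokens) - group + 1, group):
        ts, *values = tokens[i:i + group]
        yield (ts, *map(float, values))

row = ("2025-01-15T10:00:00Z 3421.5 3421.75 3421.25 3421.6 125.4 "
       "2025-01-15T10:00:00Z 3422.1 3424.8 3420.5 3423.25 1245.7")
for stream in parse_row(row):
    print(stream)  # ('2025-01-15T10:00:00Z', 3421.5, 3421.75, 3421.25, 3421.6, 125.4) ...
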
NN/training/samples/txt/market_data_20250826_2129.txt (new file, 4 lines)
@@ -0,0 +1,4 @@
symbol MAIN SYMBOL (ETH) REF1 (BTC) REF2 (SPX) REF3 (SOL)
timeframe 1s 1m 1h 1d 1s 1s 1s
datapoint O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp
2025-08-26T21:29:44Z 0 0 0 0 0 2025-08-26T21:29:44Z 0 0 0 0 0 2025-08-26T21:29:44Z 0 0 0 0 0 2025-08-26T21:29:44Z 0 0 0 0 0 2025-08-26T21:29:44Z 0 0 0 0 0 2025-08-26T21:29:44Z 5500.00 5520.00 5495.00 5510.00 1000000.0 2025-08-26T21:29:44Z 0 0 0 0 0 2025-08-26T21:29:44Z
@@ -1,13 +0,0 @@
"""
Neural Network Utilities
========================

This package contains utility functions and classes used in the neural network trading system:
- Data Interface: Connects to realtime trading data and processes it for the neural network models
"""

from .data_interface import DataInterface
from .trading_env import TradingEnvironment
from .signal_interpreter import SignalInterpreter

__all__ = ['DataInterface', 'TradingEnvironment', 'SignalInterpreter']
@@ -209,8 +209,8 @@ class DataInterface:
        curr_close = data[window_size-1:-1, 3]
        price_changes = (next_close - curr_close) / curr_close

-       # Define thresholds for price movement classification
-       threshold = 0.0005  # 0.05% threshold - smaller to encourage more signals
+       # Define thresholds for price movement classification
+       threshold = 0.001  # 0.10% threshold - prefer bigger moves to beat fees
        y = np.zeros(len(price_changes), dtype=int)
        y[price_changes > threshold] = 2  # Up
        y[price_changes < -threshold] = 0  # Down

@@ -1,123 +0,0 @@
"""
Enhanced Data Interface with additional NN trading parameters
"""
from typing import List, Optional, Tuple
import numpy as np
import pandas as pd
from datetime import datetime
from .data_interface import DataInterface

class MultiDataInterface(DataInterface):
    """
    Enhanced data interface that supports window_size and output_size parameters
    for neural network trading models.
    """

    def __init__(self, symbol: str,
                 timeframes: List[str],
                 window_size: int = 20,
                 output_size: int = 3,
                 data_dir: str = "NN/data"):
        """
        Initialize with window_size and output_size for NN predictions.
        """
        super().__init__(symbol, timeframes, data_dir)
        self.window_size = window_size
        self.output_size = output_size
        self.scalers = {}  # Store scalers for each timeframe
        self.min_window_threshold = 100  # Minimum candles needed for training

    def get_feature_count(self) -> int:
        """
        Get number of features (OHLCV) for NN input.
        """
        return 5  # open, high, low, close, volume

    def prepare_training_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Prepare training data with windowed sequences"""
        # Get historical data for primary timeframe
        primary_tf = self.timeframes[0]
        df = self.get_historical_data(timeframe=primary_tf,
                                      n_candles=self.min_window_threshold + 1000)

        if df is None or len(df) < self.min_window_threshold:
            raise ValueError(f"Insufficient data for training. Need at least {self.min_window_threshold} candles")

        # Prepare OHLCV sequences
        ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values

        # Create sequences and labels
        X = []
        y = []

        for i in range(len(ohlcv) - self.window_size - self.output_size):
            # Input sequence
            seq = ohlcv[i:i+self.window_size]
            X.append(seq)

            # Output target (price movement direction)
            close_prices = ohlcv[i+self.window_size:i+self.window_size+self.output_size, 3]  # Close prices
            price_changes = np.diff(close_prices)

            if self.output_size == 1:
                # Binary classification (up/down)
                label = 1 if price_changes[0] > 0 else 0
            elif self.output_size == 3:
                # 3-class classification (buy/hold/sell)
                if price_changes[0] > 0.002:  # Significant rise
                    label = 0  # Buy
                elif price_changes[0] < -0.002:  # Significant drop
                    label = 2  # Sell
                else:
                    label = 1  # Hold
            else:
                raise ValueError(f"Unsupported output_size: {self.output_size}")

            y.append(label)

        # Convert to numpy arrays
        X = np.array(X)
        y = np.array(y)

        # Split into train/validation (80/20)
        split_idx = int(0.8 * len(X))
        X_train, y_train = X[:split_idx], y[:split_idx]
        X_val, y_val = X[split_idx:], y[split_idx:]

        return X_train, y_train, X_val, y_val

    def prepare_prediction_data(self) -> np.ndarray:
        """Prepare most recent window for predictions"""
        primary_tf = self.timeframes[0]
        df = self.get_historical_data(timeframe=primary_tf,
                                      n_candles=self.window_size,
                                      use_cache=False)

        if df is None or len(df) < self.window_size:
            raise ValueError(f"Need at least {self.window_size} candles for prediction")

        ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values[-self.window_size:]
        return np.array([ohlcv])  # Add batch dimension

    def process_predictions(self, predictions: np.ndarray):
        """Convert prediction probabilities to trading signals"""
        signals = []
        for pred in predictions:
            if self.output_size == 1:
                signal = "BUY" if pred[0] > 0.5 else "SELL"
                confidence = np.abs(pred[0] - 0.5) * 2  # Convert to 0-1 scale
            elif self.output_size == 3:
                action_idx = np.argmax(pred)
                signal = ["BUY", "HOLD", "SELL"][action_idx]
                confidence = pred[action_idx]
            else:
                signal = "HOLD"
                confidence = 0.0

            signals.append({
                'action': signal,
                'confidence': confidence,
                'timestamp': datetime.now().isoformat()
            })

        return signals
@@ -1,364 +0,0 @@
"""
Realtime Analyzer for Neural Network Trading System

This module implements real-time analysis of market data using trained neural network models.
"""

import logging
import time
import numpy as np
from threading import Thread
from queue import Queue
from datetime import datetime
import asyncio
import websockets
import json
import os
import pandas as pd
from collections import deque

logger = logging.getLogger(__name__)

class RealtimeAnalyzer:
    """
    Handles real-time analysis of market data using trained neural network models.

    Features:
    - Connects to real-time data sources (websockets)
    - Processes tick data into multiple timeframes (1s, 1m, 1h, 1d)
    - Uses trained models to analyze all timeframes
    - Generates trading signals
    - Manages risk and position sizing
    - Logs all trading decisions
    """

    def __init__(self, data_interface, model, symbol="BTC/USDT", timeframes=None):
        """
        Initialize the realtime analyzer.

        Args:
            data_interface (DataInterface): Preconfigured data interface
            model: Trained neural network model
            symbol (str): Trading pair symbol
            timeframes (list): List of timeframes to monitor (default: ['1s', '1m', '1h', '1d'])
        """
        self.data_interface = data_interface
        self.model = model
        self.symbol = symbol
        self.timeframes = timeframes or ['1s', '1m', '1h', '1d']
        self.running = False
        self.data_queue = Queue()
        self.prediction_interval = 10  # Seconds between predictions
        self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
        self.ws = None
        self.tick_storage = deque(maxlen=10000)  # Store up to 10,000 ticks
        self.candle_cache = {
            '1s': deque(maxlen=5000),
            '1m': deque(maxlen=5000),
            '1h': deque(maxlen=5000),
            '1d': deque(maxlen=5000)
        }

        logger.info(f"RealtimeAnalyzer initialized for {symbol} with timeframes: {self.timeframes}")

    def start(self):
        """Start the realtime analysis process."""
        if self.running:
            logger.warning("Realtime analyzer already running")
            return

        self.running = True

        # Start WebSocket connection thread
        self.ws_thread = Thread(target=self._run_websocket, daemon=True)
        self.ws_thread.start()

        # Start data processing thread
        self.processing_thread = Thread(target=self._process_data, daemon=True)
        self.processing_thread.start()

        # Start analysis thread
        self.analysis_thread = Thread(target=self._analyze_data, daemon=True)
        self.analysis_thread.start()

        logger.info("Realtime analysis started")

    def stop(self):
        """Stop the realtime analysis process."""
        self.running = False
        if self.ws:
            asyncio.run(self.ws.close())
        if hasattr(self, 'ws_thread'):
            self.ws_thread.join(timeout=1)
        if hasattr(self, 'processing_thread'):
            self.processing_thread.join(timeout=1)
        if hasattr(self, 'analysis_thread'):
            self.analysis_thread.join(timeout=1)
        logger.info("Realtime analysis stopped")

    def _run_websocket(self):
        """Thread function for running WebSocket connection."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self._connect_websocket())

    async def _connect_websocket(self):
        """Connect to WebSocket and receive data."""
        while self.running:
            try:
                logger.info(f"Connecting to WebSocket: {self.ws_url}")
                async with websockets.connect(self.ws_url) as ws:
                    self.ws = ws
                    logger.info("WebSocket connected")

                    while self.running:
                        try:
                            message = await ws.recv()
                            data = json.loads(message)

                            if 'e' in data and data['e'] == 'trade':
                                tick = {
                                    'timestamp': data['T'],
                                    'price': float(data['p']),
                                    'volume': float(data['q']),
                                    'symbol': self.symbol
                                }
                                self.tick_storage.append(tick)
                                self.data_queue.put(tick)

                        except websockets.exceptions.ConnectionClosed:
                            logger.warning("WebSocket connection closed")
                            break
                        except Exception as e:
                            logger.error(f"Error receiving WebSocket message: {str(e)}")
                            time.sleep(1)

            except Exception as e:
                logger.error(f"WebSocket connection error: {str(e)}")
                time.sleep(5)  # Wait before reconnecting

    def _process_data(self):
        """Process incoming tick data into candles for all timeframes."""
        logger.info("Starting data processing thread")

        while self.running:
            try:
                # Process any new ticks
                while not self.data_queue.empty():
                    tick = self.data_queue.get()

                    # Convert timestamp to datetime
                    timestamp = datetime.fromtimestamp(tick['timestamp'] / 1000)

                    # Process for each timeframe
                    for timeframe in self.timeframes:
                        interval = self._get_interval_seconds(timeframe)
                        if interval is None:
                            continue

                        # Round timestamp to nearest candle interval
                        candle_ts = int(tick['timestamp'] // (interval * 1000)) * (interval * 1000)

                        # Get or create candle for this timeframe
                        if not self.candle_cache[timeframe]:
                            # First candle for this timeframe
                            candle = {
                                'timestamp': candle_ts,
                                'open': tick['price'],
                                'high': tick['price'],
                                'low': tick['price'],
                                'close': tick['price'],
                                'volume': tick['volume']
                            }
                            self.candle_cache[timeframe].append(candle)
                        else:
                            # Update existing candle
                            last_candle = self.candle_cache[timeframe][-1]

                            if last_candle['timestamp'] == candle_ts:
                                # Update current candle
                                last_candle['high'] = max(last_candle['high'], tick['price'])
                                last_candle['low'] = min(last_candle['low'], tick['price'])
                                last_candle['close'] = tick['price']
                                last_candle['volume'] += tick['volume']
                            else:
                                # New candle
                                candle = {
                                    'timestamp': candle_ts,
                                    'open': tick['price'],
                                    'high': tick['price'],
                                    'low': tick['price'],
                                    'close': tick['price'],
                                    'volume': tick['volume']
                                }
                                self.candle_cache[timeframe].append(candle)

                time.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in data processing: {str(e)}")
                time.sleep(1)

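A worked example of the candle bucketing above: integer division floors each tick's
millisecond timestamp to the start of its interval (so "round" in the comment really
means floor; the tick value below is illustrative):

interval = 60                    # 1m candles, in seconds
tick_ms = 1736935237250          # a tick at 2025-01-15T10:00:37.250Z

candle_ts = int(tick_ms // (interval * 1000)) * (interval * 1000)
print(candle_ts)                 # 1736935200000 == 2025-01-15T10:00:00Z
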
    def _get_interval_seconds(self, timeframe):
        """Convert timeframe string to seconds."""
        intervals = {
            '1s': 1,
            '1m': 60,
            '1h': 3600,
            '1d': 86400
        }
        return intervals.get(timeframe)

    def _analyze_data(self):
        """Thread function for analyzing data and generating signals."""
        logger.info("Starting analysis thread")

        last_prediction_time = 0

        while self.running:
            try:
                current_time = time.time()

                # Only make predictions at the specified interval
                if current_time - last_prediction_time < self.prediction_interval:
                    time.sleep(0.1)
                    continue

                # Prepare input data from all timeframes
                input_data = {}
                valid = True

                for timeframe in self.timeframes:
                    if not self.candle_cache[timeframe]:
                        logger.warning(f"No data available for timeframe {timeframe}")
                        valid = False
                        break

                    # Get last N candles for this timeframe
                    candles = list(self.candle_cache[timeframe])[-self.data_interface.window_size:]

                    # Convert to numpy array
                    ohlcv = np.array([
                        [c['open'], c['high'], c['low'], c['close'], c['volume']]
                        for c in candles
                    ])

                    # Normalize data
                    ohlcv_normalized = (ohlcv - ohlcv.mean(axis=0)) / (ohlcv.std(axis=0) + 1e-8)
                    input_data[timeframe] = ohlcv_normalized

                if not valid:
                    time.sleep(0.1)
                    continue

                # Make prediction using the model
                try:
                    prediction = self.model.predict(input_data)

                    # Get latest timestamp from 1s timeframe
                    latest_ts = self.candle_cache['1s'][-1]['timestamp'] if self.candle_cache['1s'] else int(time.time() * 1000)

                    # Process prediction
                    self._process_prediction(
                        prediction=prediction,
                        timeframe='multi',
                        timestamp=latest_ts
                    )

                    last_prediction_time = current_time
                except Exception as e:
                    logger.error(f"Error making prediction: {str(e)}")

                time.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in analysis: {str(e)}")
                time.sleep(1)

    def _process_prediction(self, prediction, timeframe, timestamp):
        """
        Process model prediction and generate trading signals.

        Args:
            prediction: Model prediction output
            timeframe (str): Timeframe the prediction is for ('multi' for combined)
            timestamp: Timestamp of the prediction (ms)
        """
        # Convert prediction to trading signal
        signal, confidence = self._prediction_to_signal(prediction)

        # Convert timestamp to datetime
        try:
            dt = datetime.fromtimestamp(timestamp / 1000)
        except:
            dt = datetime.now()

        # Log the signal with all timeframes
        logger.info(
            f"Signal generated - Timeframes: {', '.join(self.timeframes)}, "
            f"Timestamp: {dt}, "
            f"Signal: {signal} (Confidence: {confidence:.2f})"
        )

        # In a real implementation, we would execute trades here
        # For now, we'll just log the signals

    def _prediction_to_signal(self, prediction):
        """
        Convert model prediction to trading signal and confidence.

        Args:
            prediction: Model prediction output (can be dict for multi-timeframe)

        Returns:
            tuple: (signal, confidence) where signal is BUY/SELL/HOLD,
                   confidence is probability (0-1)
        """
        if isinstance(prediction, dict):
            # Multi-timeframe prediction - combine signals
            signals = []
            confidences = []

            for tf, pred in prediction.items():
                if len(pred.shape) == 1:
                    # Binary classification
                    signal = "BUY" if pred[0] > 0.5 else "SELL"
                    confidence = pred[0] if signal == "BUY" else 1 - pred[0]
                else:
                    # Multi-class
                    class_idx = np.argmax(pred)
                    signal = ["SELL", "HOLD", "BUY"][class_idx]
                    confidence = pred[class_idx]

                signals.append(signal)
                confidences.append(confidence)

            # Simple voting system - count BUY/SELL signals
            buy_count = signals.count("BUY")
            sell_count = signals.count("SELL")

            if buy_count > sell_count:
                final_signal = "BUY"
                final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "BUY"])
            elif sell_count > buy_count:
                final_signal = "SELL"
                final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "SELL"])
            else:
                final_signal = "HOLD"
                final_confidence = np.mean(confidences)

            return final_signal, final_confidence

        else:
            # Single prediction
            if len(prediction.shape) == 1:
                # Binary classification
                signal = "BUY" if prediction[0] > 0.5 else "SELL"
                confidence = prediction[0] if signal == "BUY" else 1 - prediction[0]
            else:
                # Multi-class
                class_idx = np.argmax(prediction)
                signal = ["SELL", "HOLD", "BUY"][class_idx]
                confidence = prediction[class_idx]

            return signal, confidence

@@ -50,7 +50,7 @@ class SignalInterpreter:
        self.oscillation_filter_enabled = self.config.get('oscillation_filter_enabled', False)  # Disable oscillation filter by default

        # Sensitivity parameters
-       self.min_price_movement = self.config.get('min_price_movement', 0.0001)  # Lower price movement threshold
+       self.min_price_movement = self.config.get('min_price_movement', 0.001)  # Align with deadzone; prefer bigger moves
        self.hold_cooldown = self.config.get('hold_cooldown', 1)  # Shorter hold cooldown
        self.consecutive_signals_required = self.config.get('consecutive_signals_required', 1)  # Require only one signal

@@ -20,8 +20,10 @@ class TradingEnvironment(gym.Env):
                 window_size: int = 20,
                 risk_aversion: float = 0.2,  # Controls how much to penalize volatility
                 price_scaling: str = 'zscore',  # 'zscore', 'minmax', or 'raw'
-                reward_scaling: float = 10.0,  # Scale factor for rewards
-                episode_penalty: float = 0.1):  # Penalty for active positions at end of episode
+                reward_scaling: float = 10.0,  # Scale factor for rewards
+                episode_penalty: float = 0.1,  # Penalty for active positions at end of episode
+                min_profit_after_fees: float = 0.0005  # Deadzone: require >= 5 bps beyond fees
+                ):
        super(TradingEnvironment, self).__init__()

        self.data = data

@@ -33,6 +35,7 @@ class TradingEnvironment(gym.Env):
        self.price_scaling = price_scaling
        self.reward_scaling = reward_scaling
        self.episode_penalty = episode_penalty
+       self.min_profit_after_fees = max(0.0, float(min_profit_after_fees))

        # Preprocess data if needed
        self._preprocess_data()

@@ -177,8 +180,14 @@ class TradingEnvironment(gym.Env):
                price_diff = current_price - self.entry_price
                pnl = price_diff / self.entry_price - 2 * self.fee_rate  # Account for entry and exit fees

-               # Adjust reward based on PnL and risk
-               reward = pnl * self.reward_scaling
+               # Deadzone to discourage micro profits
+               if pnl > 0 and pnl < self.min_profit_after_fees:
+                   reward = -self.fee_rate
+               elif pnl < 0 and abs(pnl) < self.min_profit_after_fees:
+                   reward = pnl * self.reward_scaling * 0.5
+               else:
+                   effective_pnl = pnl - (self.min_profit_after_fees if pnl > 0 else 0.0)
+                   reward = effective_pnl * self.reward_scaling

                # Track trade performance
                self.total_trades += 1

@@ -212,8 +221,12 @@ class TradingEnvironment(gym.Env):
                price_diff = current_price - self.entry_price
                unrealized_pnl = price_diff / self.entry_price

-               # Small reward/penalty based on unrealized P&L
-               reward = unrealized_pnl * 0.05  # Scale down to encourage holding good positions
+               # Encourage holding only if unrealized edge exceeds deadzone
+               unrealized_edge = unrealized_pnl
+               if abs(unrealized_edge) >= self.min_profit_after_fees:
+                   reward = unrealized_edge * (self.reward_scaling * 0.2)
+               else:
+                   reward = -0.0002

        elif self.position < 0:  # Short position
            if action == 0:  # BUY (close short)

@@ -221,8 +234,13 @@ class TradingEnvironment(gym.Env):
                price_diff = self.entry_price - current_price
                pnl = price_diff / self.entry_price - 2 * self.fee_rate  # Account for entry and exit fees

-               # Adjust reward based on PnL and risk
-               reward = pnl * self.reward_scaling
+               if pnl > 0 and pnl < self.min_profit_after_fees:
+                   reward = -self.fee_rate
+               elif pnl < 0 and abs(pnl) < self.min_profit_after_fees:
+                   reward = pnl * self.reward_scaling * 0.5
+               else:
+                   effective_pnl = pnl - (self.min_profit_after_fees if pnl > 0 else 0.0)
+                   reward = effective_pnl * self.reward_scaling

                # Track trade performance
                self.total_trades += 1

@@ -256,8 +274,12 @@ class TradingEnvironment(gym.Env):
                price_diff = self.entry_price - current_price
                unrealized_pnl = price_diff / self.entry_price

-               # Small reward/penalty based on unrealized P&L
-               reward = unrealized_pnl * 0.05  # Scale down to encourage holding good positions
+               # Encourage holding only if unrealized edge exceeds deadzone
+               unrealized_edge = unrealized_pnl
+               if abs(unrealized_edge) >= self.min_profit_after_fees:
+                   reward = unrealized_edge * (self.reward_scaling * 0.2)
+               else:
+                   reward = -0.0002

            # Record the action
            self.actions_taken.append(action)
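The deadzone reward shaping introduced across the hunks above, collected into one
standalone function (a sketch using assumed parameter defaults; not part of this commit):

def shaped_reward(pnl: float, fee_rate: float = 0.0002,
                  min_profit_after_fees: float = 0.0005,
                  reward_scaling: float = 10.0) -> float:
    """Penalize micro wins, soften micro losses, scale everything else."""
    if 0 < pnl < min_profit_after_fees:
        return -fee_rate  # micro profit: treat the round trip as wasted fees
    if pnl < 0 and abs(pnl) < min_profit_after_fees:
        return pnl * reward_scaling * 0.5  # micro loss: half penalty
    effective_pnl = pnl - (min_profit_after_fees if pnl > 0 else 0.0)
    return effective_pnl * reward_scaling

print(shaped_reward(0.0003))  # -0.0002: inside the deadzone, discouraged
print(shaped_reward(0.002))   # 0.015: (0.002 - 0.0005) * 10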