Merge commit 'd49a473ed6f4aef55bfdd47d6370e53582be6b7b' into cleanup

This commit is contained in:
Dobromir Popov
2025-10-01 00:32:19 +03:00
353 changed files with 81004 additions and 35899 deletions

View File

@@ -1,16 +0,0 @@
"""
Neural Network Trading System
============================
A comprehensive neural network trading system that uses deep learning models
to analyze cryptocurrency price data and generate trading signals.
The system consists of:
1. Data Interface: Connects to realtime trading data
2. CNN Model: Deep convolutional neural network for feature extraction
3. Transformer Model: Processes high-level features for improved pattern recognition
4. MoE: Mixture of Experts model that combines multiple neural networks
"""
__version__ = '0.1.0'
__author__ = 'Gogo2 Project'

View File

@@ -1,11 +0,0 @@
"""
Neural Network Data
=================
This package is used to store datasets and model outputs.
It does not contain any code, but serves as a storage location for:
- Training datasets
- Evaluation results
- Inference outputs
- Model checkpoints
"""

View File

@@ -1,6 +0,0 @@
"""Trading environments for reinforcement learning.

Exposes the environments used to train trading agents.
"""

from NN.environments.trading_env import TradingEnvironment

__all__ = ["TradingEnvironment"]

View File

@@ -1,532 +0,0 @@
import numpy as np
import pandas as pd
from typing import Dict, Tuple, List, Any, Optional
import logging
import gym
from gym import spaces
import random
# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TradingEnvironment(gym.Env):
    """
    Trading environment implementing the gym interface for reinforcement learning.

    2-Action System:
      - 0: SELL (or close long position)
      - 1: BUY (or close short position)

    Intelligent Position Management:
      - When neutral: actions enter positions
      - When positioned: actions can close or flip positions
      - Different thresholds for entry vs exit decisions

    State:
      - OHLCV data from multiple timeframes
      - Technical indicators
      - Position data and unrealized PnL
    """

    def __init__(
        self,
        data_interface,
        initial_balance: float = 10000.0,
        transaction_fee: float = 0.0002,
        window_size: int = 20,
        max_position: float = 1.0,
        reward_scaling: float = 1.0,
        entry_threshold: float = 0.6,  # Higher threshold for entering positions
        exit_threshold: float = 0.3,   # Lower threshold for exiting positions
    ):
        """
        Initialize the trading environment with the 2-action system.

        Args:
            data_interface: DataInterface instance used to get market data
            initial_balance: Initial balance in the base currency
            transaction_fee: Fee for each transaction as a fraction of trade value
            window_size: Number of candles in the observation window
            max_position: Maximum position size as a fraction of balance
            reward_scaling: Scale factor for rewards
            entry_threshold: Confidence threshold for entering new positions
            exit_threshold: Confidence threshold for exiting positions
        """
        super().__init__()

        self.data_interface = data_interface
        self.initial_balance = initial_balance
        self.transaction_fee = transaction_fee
        self.window_size = window_size
        self.max_position = max_position
        self.reward_scaling = reward_scaling
        self.entry_threshold = entry_threshold
        self.exit_threshold = exit_threshold

        # Load data for primary timeframe (the first one is treated as primary)
        self.timeframe = self.data_interface.timeframes[0]
        self.reset_data()

        # Define action and observation spaces for the 2-action system
        self.action_space = spaces.Discrete(2)  # 0=SELL, 1=BUY

        # Observation: one OHLCV window per timeframe plus account-state features
        n_timeframes = len(self.data_interface.timeframes)
        n_features = 5  # OHLCV data by default
        # position, balance, unrealized_pnl, entry_price, position_duration
        additional_features = 5
        total_features = (n_timeframes * n_features * self.window_size) + additional_features
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(total_features,), dtype=np.float32
        )
        # Use tuple for state_shape that EnhancedCNN expects
        self.state_shape = (total_features,)

        # Position tracking for 2-action system
        self.position = 0.0     # -1 (short), 0 (neutral), 1 (long)
        self.entry_price = 0.0  # Price at which position was entered
        self.entry_step = 0     # Step at which position was entered

        # Trade statistics.
        # FIX: these attributes were read and updated (in step() and
        # _close_position()) but never initialized, causing an
        # AttributeError the first time a position was closed.
        self.total_pnl = 0.0
        self.winning_trades = 0
        self.losing_trades = 0

        # Initialize state
        self.reset()

    def reset_data(self):
        """Reload per-timeframe dataframes from the data interface.

        Raises:
            ValueError: if no timeframe has any data available.
        """
        # Get data for each timeframe
        self.data = {}
        for tf in self.data_interface.timeframes:
            df = self.data_interface.dataframes[tf]
            if df is not None and not df.empty:
                self.data[tf] = df
        if not self.data:
            raise ValueError("No data available for training")

        # Use the primary timeframe for step count
        self.prices = self.data[self.timeframe]['close'].values
        self.timestamps = self.data[self.timeframe].index.values
        self.max_steps = len(self.prices) - self.window_size - 1

    def reset(self):
        """Reset the environment to its initial state.

        Returns:
            np.ndarray: the initial observation.
        """
        # Reset trading variables
        self.balance = self.initial_balance
        self.trades = []
        self.rewards = []

        # FIX: reset position state and trade statistics as well, so that
        # consecutive episodes are independent (previously a position opened
        # in one episode leaked into the next).
        self.position = 0.0
        self.entry_price = 0.0
        self.entry_step = 0
        self.total_pnl = 0.0
        self.winning_trades = 0
        self.losing_trades = 0

        # Reset step counter
        self.current_step = self.window_size

        # Get initial observation
        return self._get_observation()

    def step(self, action):
        """
        Take a step in the environment using the 2-action system with
        intelligent position management.

        Args:
            action: Action to take (0: SELL, 1: BUY)

        Returns:
            tuple: (observation, reward, done, info)
        """
        # Get current state before taking action
        prev_balance = self.balance
        prev_position = self.position
        prev_price = self.prices[self.current_step]

        # Take action with intelligent position management
        info = {}
        reward = 0
        last_position_info = None

        # Get current price
        current_price = self.prices[self.current_step]
        next_price = self.prices[self.current_step + 1] if self.current_step + 1 < len(self.prices) else current_price

        # Implement 2-action system with position management
        if action == 0:  # SELL action
            if self.position == 0:  # No position - enter short
                self._open_position(-1.0 * self.max_position, current_price)
                logger.info(f"ENTER SHORT at step {self.current_step}, price: {current_price:.4f}")
                reward = -self.transaction_fee  # Entry cost
            elif self.position > 0:  # Long position - close it
                close_pnl, last_position_info = self._close_position(current_price)
                reward += close_pnl * self.reward_scaling
                logger.info(f"CLOSE LONG at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")
            elif self.position < 0:  # Already short - potentially flip to long if very strong signal
                # For now, just hold the short position (no action)
                pass
        elif action == 1:  # BUY action
            if self.position == 0:  # No position - enter long
                self._open_position(1.0 * self.max_position, current_price)
                logger.info(f"ENTER LONG at step {self.current_step}, price: {current_price:.4f}")
                reward = -self.transaction_fee  # Entry cost
            elif self.position < 0:  # Short position - close it
                close_pnl, last_position_info = self._close_position(current_price)
                reward += close_pnl * self.reward_scaling
                logger.info(f"CLOSE SHORT at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")
            elif self.position > 0:  # Already long - potentially flip to short if very strong signal
                # For now, just hold the long position (no action)
                pass

        # Calculate unrealized PnL and add to reward if holding position
        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(next_price)
            reward += unrealized_pnl * self.reward_scaling * 0.1  # Scale down unrealized PnL

            # Apply time-based holding penalty to encourage decisive actions
            position_duration = self.current_step - self.entry_step
            holding_penalty = min(position_duration * 0.0001, 0.01)  # Max 1% penalty
            reward -= holding_penalty
        # Reward staying neutral when uncertain (no clear setup)
        else:
            reward += 0.0001  # Small reward for not trading without clear signals

        # Move to next step
        self.current_step += 1

        # Get new observation
        observation = self._get_observation()

        # Check if episode is done
        done = self.current_step >= len(self.prices) - 1

        # If done, close any remaining positions
        if done and self.position != 0:
            final_pnl, last_position_info = self._close_position(current_price)
            reward += final_pnl * self.reward_scaling
            info['final_pnl'] = final_pnl
            info['final_balance'] = self.balance
            logger.info(f"Episode ended. Final balance: {self.balance:.4f}, Return: {(self.balance/self.initial_balance-1)*100:.2f}%")

        # Track trade result if position changed or position was closed
        if prev_position != self.position or last_position_info is not None:
            # Calculate realized PnL if position was closed
            realized_pnl = 0
            position_info = {}
            if last_position_info is not None:
                # Use the position information from closing
                realized_pnl = last_position_info['pnl']
                position_info = last_position_info
            else:
                # Calculate manually based on balance change
                realized_pnl = self.balance - prev_balance if prev_position != 0 else 0

            # Record detailed trade information
            trade_result = {
                'step': self.current_step,
                'timestamp': self.timestamps[self.current_step],
                'action': action,
                'action_name': ['SELL', 'BUY'][action],
                'price': current_price,
                'position_changed': prev_position != self.position,
                'prev_position': prev_position,
                'new_position': self.position,
                'position_size': abs(self.position) if self.position != 0 else abs(prev_position),
                'entry_price': position_info.get('entry_price', self.entry_price),
                'exit_price': position_info.get('exit_price', current_price),
                'realized_pnl': realized_pnl,
                'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0,
                'pnl': realized_pnl,  # Total PnL (realized for this step)
                'balance_before': prev_balance,
                'balance_after': self.balance,
                'trade_fee': position_info.get('fee', abs(self.position - prev_position) * current_price * self.transaction_fee)
            }
            info['trade_result'] = trade_result
            # NOTE(review): _close_position() also appends its own record to
            # self.trades, so a close produces two entries; get_last_positions()
            # and render() filter on 'realized_pnl' and therefore only count
            # this step-level record. Left as-is to preserve behavior.
            self.trades.append(trade_result)

            # Log trade details
            logger.info(f"Trade executed - Action: {['SELL', 'BUY'][action]}, "
                        f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
                        f"Balance: {self.balance:.4f}")

        # Store reward
        self.rewards.append(reward)

        # Update info dict with current state
        info.update({
            'step': self.current_step,
            'price': current_price,
            'prev_price': prev_price,
            'price_change': (current_price - prev_price) / prev_price if prev_price != 0 else 0,
            'balance': self.balance,
            'position': self.position,
            'entry_price': self.entry_price,
            'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0.0,
            'total_trades': len(self.trades),
            'total_pnl': self.total_pnl,
            'return_pct': (self.balance/self.initial_balance-1)*100
        })

        return observation, reward, done, info

    def _calculate_unrealized_pnl(self, current_price):
        """Calculate unrealized PnL (as a fraction) for the current position."""
        if self.position == 0 or self.entry_price == 0:
            return 0.0

        if self.position > 0:  # Long position
            return self.position * (current_price / self.entry_price - 1.0)
        else:  # Short position
            return -self.position * (1.0 - current_price / self.entry_price)

    def _open_position(self, position_size: float, entry_price: float):
        """Open a new position and deduct the entry fee from the balance.

        Args:
            position_size: signed size; negative for short, positive for long
            entry_price: price at which the position is entered
        """
        self.position = position_size
        self.entry_price = entry_price
        self.entry_step = self.current_step

        # Calculate position value
        position_value = abs(position_size) * entry_price

        # Apply transaction fee
        fee = position_value * self.transaction_fee
        self.balance -= fee

        logger.info(f"Opened position: {position_size:.4f} at {entry_price:.4f}, fee: {fee:.4f}")

    def _close_position(self, exit_price: float) -> Tuple[float, Dict]:
        """Close the current position.

        Args:
            exit_price: price at which the position is closed

        Returns:
            Tuple[float, Dict]: (net PnL after fees, details of the closed position).
            Returns (0.0, {}) if there is no open position.
        """
        if self.position == 0:
            return 0.0, {}

        # Calculate PnL
        if self.position > 0:  # Long position
            pnl = (exit_price - self.entry_price) / self.entry_price
        else:  # Short position
            pnl = (self.entry_price - exit_price) / self.entry_price

        # Apply transaction fees (entry + exit)
        position_value = abs(self.position) * exit_price
        exit_fee = position_value * self.transaction_fee
        total_fees = exit_fee  # Entry fee already applied when opening

        # Net PnL after fees
        net_pnl = pnl - (total_fees / (abs(self.position) * self.entry_price))

        # Update balance
        self.balance *= (1 + net_pnl)
        self.total_pnl += net_pnl

        # Track trade
        position_info = {
            'position_size': self.position,
            'entry_price': self.entry_price,
            'exit_price': exit_price,
            'pnl': net_pnl,
            'duration': self.current_step - self.entry_step,
            'entry_step': self.entry_step,
            'exit_step': self.current_step
        }
        self.trades.append(position_info)

        # Update trade statistics
        if net_pnl > 0:
            self.winning_trades += 1
        else:
            self.losing_trades += 1

        logger.info(f"Closed position: {self.position:.4f}, PnL: {net_pnl:.4f}, Duration: {position_info['duration']} steps")

        # Reset position
        self.position = 0.0
        self.entry_price = 0.0
        self.entry_step = 0

        return net_pnl, position_info

    def _get_observation(self):
        """
        Build the current observation.

        Returns:
            np.ndarray: flattened observation vector of length
            `self.state_shape[0]`, dtype float32 (matching observation_space).
        """
        observations = []

        # Get data from each timeframe
        for tf in self.data_interface.timeframes:
            if tf in self.data:
                # Get the window of data for this timeframe
                df = self.data[tf]
                start_idx = self._align_timeframe_index(tf)
                if start_idx is not None and start_idx >= 0 and start_idx + self.window_size <= len(df):
                    window = df.iloc[start_idx:start_idx + self.window_size]

                    # Extract OHLCV data
                    ohlcv = window[['open', 'high', 'low', 'close', 'volume']].values

                    # Normalize OHLCV data relative to the last close price
                    last_close = ohlcv[-1, 3]  # Last close price
                    ohlcv_normalized = np.zeros_like(ohlcv)
                    ohlcv_normalized[:, 0] = ohlcv[:, 0] / last_close - 1.0  # open
                    ohlcv_normalized[:, 1] = ohlcv[:, 1] / last_close - 1.0  # high
                    ohlcv_normalized[:, 2] = ohlcv[:, 2] / last_close - 1.0  # low
                    ohlcv_normalized[:, 3] = ohlcv[:, 3] / last_close - 1.0  # close

                    # Normalize volume (relative to moving average of volume)
                    if 'volume' in window.columns:
                        volume_ma = ohlcv[:, 4].mean()
                        if volume_ma > 0:
                            ohlcv_normalized[:, 4] = ohlcv[:, 4] / volume_ma - 1.0
                        else:
                            ohlcv_normalized[:, 4] = 0.0
                    else:
                        ohlcv_normalized[:, 4] = 0.0

                    # Flatten and add to observations
                    observations.append(ohlcv_normalized.flatten())
                else:
                    # Fill with zeros if not enough data
                    observations.append(np.zeros(self.window_size * 5))

        # Add position and account information.
        # FIX: emit 5 values, matching the `additional_features = 5` term used
        # to size the observation space in __init__ (previously only 3 values
        # were produced, so the observation length never matched
        # self.observation_space).
        current_price = self.prices[self.current_step]
        if self.entry_price != 0 and current_price != 0:
            entry_rel = self.entry_price / current_price - 1.0
        else:
            entry_rel = 0.0
        duration = float(self.current_step - self.entry_step) if self.position != 0 else 0.0
        position_info = np.array([
            self.position / self.max_position,             # Normalized position (-1 to 1)
            self.balance / self.initial_balance - 1.0,     # Normalized balance change
            self._calculate_unrealized_pnl(current_price), # Unrealized PnL
            entry_rel,                                     # Entry price relative to current price
            duration / 100.0,                              # Position duration, roughly normalized
        ])
        observations.append(position_info)

        # Concatenate and cast to the dtype declared by the observation space
        return np.concatenate(observations).astype(np.float32)

    def _align_timeframe_index(self, timeframe):
        """
        Align the index of a higher timeframe with the current step in the
        primary timeframe.

        Args:
            timeframe: The timeframe to align

        Returns:
            int: The starting index in the higher timeframe
        """
        if timeframe == self.timeframe:
            return self.current_step - self.window_size

        # Get timestamps for current primary timeframe step
        primary_ts = self.timestamps[self.current_step]

        # Find closest index in the higher timeframe
        higher_ts = self.data[timeframe].index.values
        idx = np.searchsorted(higher_ts, primary_ts)

        # Adjust to get the starting index
        start_idx = max(0, idx - self.window_size)
        return start_idx

    def get_last_positions(self, n=5):
        """
        Get detailed information about the last n closed positions.

        Args:
            n: Number of last positions to return

        Returns:
            list: List of dictionaries containing position details
        """
        if not self.trades:
            return []

        # Filter trades to only include those that closed positions
        position_trades = [t for t in self.trades if t.get('realized_pnl', 0) != 0 or (t.get('prev_position', 0) != 0 and t.get('new_position', 0) == 0)]

        positions = []
        last_n_trades = position_trades[-n:] if len(position_trades) >= n else position_trades
        for trade in last_n_trades:
            position_info = {
                'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
                'action': trade.get('action_name', ['SELL', 'BUY'][trade['action']]),
                'entry_price': trade.get('entry_price', 0.0),
                'exit_price': trade.get('exit_price', trade['price']),
                'position_size': trade.get('position_size', self.max_position),
                'realized_pnl': trade.get('realized_pnl', 0.0),
                'fee': trade.get('trade_fee', 0.0),
                'pnl': trade.get('pnl', 0.0),
                'pnl_percentage': (trade.get('pnl', 0.0) / self.initial_balance) * 100,
                'balance_before': trade.get('balance_before', 0.0),
                'balance_after': trade.get('balance_after', 0.0),
                'duration': trade.get('duration', 'N/A')
            }
            positions.append(position_info)

        return positions

    def render(self, mode='human'):
        """Print a text summary of the current environment state."""
        current_step = self.current_step
        current_price = self.prices[current_step]

        # Display basic information
        print(f"\nTrading Environment Status:")
        print(f"============================")
        print(f"Step: {current_step}/{len(self.prices)-1}")
        print(f"Current Price: {current_price:.4f}")
        print(f"Current Balance: {self.balance:.4f}")
        print(f"Current Position: {self.position:.4f}")
        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(current_price)
            print(f"Entry Price: {self.entry_price:.4f}")
            print(f"Unrealized PnL: {unrealized_pnl:.4f} ({unrealized_pnl/self.balance*100:.2f}%)")
        print(f"Total PnL: {self.total_pnl:.4f} ({self.total_pnl/self.initial_balance*100:.2f}%)")
        print(f"Total Trades: {len(self.trades)}")
        if len(self.trades) > 0:
            win_trades = [t for t in self.trades if t.get('realized_pnl', 0) > 0]
            win_count = len(win_trades)
            # Count trades that closed positions (not just changed them)
            closed_positions = [t for t in self.trades if t.get('realized_pnl', 0) != 0]
            closed_count = len(closed_positions)
            win_rate = win_count / closed_count if closed_count > 0 else 0
            print(f"Positions Closed: {closed_count}")
            print(f"Winning Positions: {win_count}")
            print(f"Win Rate: {win_rate:.2f}")

        # Display last 5 positions
        print("\nLast 5 Positions:")
        print("================")
        last_positions = self.get_last_positions(5)
        if not last_positions:
            print("No closed positions yet.")
        for pos in last_positions:
            print(f"Time: {pos['timestamp']}")
            print(f"Action: {pos['action']}")
            print(f"Entry: {pos['entry_price']:.4f}, Exit: {pos['exit_price']:.4f}")
            print(f"Size: {pos['position_size']:.4f}")
            print(f"PnL: {pos['realized_pnl']:.4f} ({pos['pnl_percentage']:.2f}%)")
            print(f"Fee: {pos['fee']:.4f}")
            print(f"Balance: {pos['balance_before']:.4f} -> {pos['balance_after']:.4f}")
            print("----------------")

        return

    def close(self):
        """Close the environment (no resources to release)."""
        pass

View File

@@ -1,162 +0,0 @@
# Trading Agent System
This directory contains the implementation of a modular trading agent system that integrates with the neural network models and can execute trades on various cryptocurrency exchanges.
## Overview
The trading agent system is designed to:
1. Connect to different cryptocurrency exchanges using a common interface
2. Execute trades based on signals from neural network models
3. Manage risk through position sizing, trade limits, and cooldown periods
4. Monitor and report on trading activity
## Components
### Exchange Interfaces
- `ExchangeInterface`: Abstract base class defining the common interface for all exchange implementations
- `BinanceInterface`: Implementation for the Binance exchange, with support for both mainnet and testnet
- `MEXCInterface`: Implementation for the MEXC exchange
### Trading Agent
The `TradingAgent` class (`trading_agent.py`) manages trading activities:
- Connects to the configured exchange
- Processes trading signals from neural network models
- Applies trading rules and risk management
- Tracks and reports trading performance
### Neural Network Orchestrator
The `NeuralNetworkOrchestrator` class (`neural_network_orchestrator.py`) coordinates between models and trading:
- Manages the neural network inference process
- Routes model signals to the trading agent
- Provides integration with the RealTimeChart for visualization
## Usage
### Basic Usage
```python
import time

from NN.exchanges import BinanceInterface, MEXCInterface
from NN.trading_agent import TradingAgent
# Initialize an exchange interface
exchange = BinanceInterface(
api_key="your_api_key",
api_secret="your_api_secret",
test_mode=True # Use testnet
)
# Connect to the exchange
exchange.connect()
# Create a trading agent
agent = TradingAgent(
exchange_name="binance",
api_key="your_api_key",
api_secret="your_api_secret",
test_mode=True,
trade_symbols=["BTC/USDT", "ETH/USDT"],
position_size=0.1,
max_trades_per_day=5,
trade_cooldown_minutes=60
)
# Start the trading agent
agent.start()
# Process a trading signal
agent.process_signal(
symbol="BTC/USDT",
action="BUY",
confidence=0.85,
timestamp=int(time.time())
)
# Stop the trading agent when done
agent.stop()
```
### Integration with Neural Network Models
The system is designed to be integrated with neural network models through the `NeuralNetworkOrchestrator`:
```python
from NN.neural_network_orchestrator import NeuralNetworkOrchestrator
# Configure exchange
exchange_config = {
"exchange": "binance",
"api_key": "your_api_key",
"api_secret": "your_api_secret",
"test_mode": True,
"trade_symbols": ["BTC/USDT", "ETH/USDT"],
"position_size": 0.1,
"max_trades_per_day": 5,
"trade_cooldown_minutes": 60
}
# Initialize orchestrator
orchestrator = NeuralNetworkOrchestrator(
model=model,
data_interface=data_interface,
chart=chart,
symbols=["BTC/USDT", "ETH/USDT"],
timeframes=["1m", "5m", "1h", "4h", "1d"],
exchange_config=exchange_config
)
# Start inference and trading
orchestrator.start_inference()
```
## Configuration
### Exchange-Specific Configuration
- **Binance**: Supports both mainnet and testnet environments
- **MEXC**: Supports mainnet only (no test environment available)
### Trading Agent Configuration
- `exchange_name`: Name of exchange ('binance', 'mexc')
- `api_key`: API key for the exchange
- `api_secret`: API secret for the exchange
- `test_mode`: Whether to use test/sandbox environment
- `trade_symbols`: List of trading symbols to monitor
- `position_size`: Size of each position as a fraction of balance (0.0-1.0)
- `max_trades_per_day`: Maximum number of trades to execute per day
- `trade_cooldown_minutes`: Minimum time between trades in minutes
## Adding New Exchanges
To add support for a new exchange:
1. Create a new class that inherits from `ExchangeInterface`
2. Implement all required methods (see `exchange_interface.py`)
3. Add the new exchange to the imports in `__init__.py`
4. Update the `_create_exchange` method in `TradingAgent` to support the new exchange
Example:
```python
class KrakenInterface(ExchangeInterface):
"""Kraken Exchange API Interface"""
def __init__(self, api_key=None, api_secret=None, test_mode=True):
super().__init__(api_key, api_secret, test_mode)
# Initialize Kraken-specific attributes
# Implement all required methods...
```
## Security Considerations
- API keys should have trade permissions but not withdrawal permissions
- Use environment variables or secure storage for API credentials
- Always test with small position sizes before deploying with larger amounts
- Consider using test mode/testnet for initial testing

View File

@@ -1,5 +0,0 @@
"""Exchange interface implementations available to the trading agent."""

from .exchange_interface import ExchangeInterface
from .mexc_interface import MEXCInterface
from .binance_interface import BinanceInterface

__all__ = [
    "ExchangeInterface",
    "MEXCInterface",
    "BinanceInterface",
]

View File

@@ -1,276 +0,0 @@
import logging
import time
from typing import Dict, Any, List, Optional
import requests
import hmac
import hashlib
from urllib.parse import urlencode
from .exchange_interface import ExchangeInterface
logger = logging.getLogger(__name__)
class BinanceInterface(ExchangeInterface):
    """Binance Exchange API Interface (REST).

    Supports the live exchange and the Spot testnet. All HTTP calls carry an
    explicit timeout so a stalled connection cannot block the caller forever.
    """

    # FIX: requests.* calls previously had no timeout, so a dead connection
    # would hang indefinitely. Seconds before an HTTP request is aborted.
    REQUEST_TIMEOUT = 30

    def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True):
        """Initialize Binance exchange interface.

        Args:
            api_key: Binance API key
            api_secret: Binance API secret
            test_mode: If True, use testnet environment
        """
        super().__init__(api_key, api_secret, test_mode)

        # Use testnet URLs if in test mode
        if test_mode:
            self.base_url = "https://testnet.binance.vision"
        else:
            self.base_url = "https://api.binance.com"
        self.api_version = "v3"

    def connect(self) -> bool:
        """Verify connectivity to the Binance API.

        Returns:
            bool: True if the server responds and the account is reachable,
            False if credentials are missing or any request fails.
        """
        if not self.api_key or not self.api_secret:
            logger.warning("Binance API credentials not provided. Running in read-only mode.")
            return False

        try:
            # Test connection by pinging server and checking account info
            self._send_public_request('GET', 'ping')
            # Credentials are guaranteed present here (early return above),
            # so verify account connectivity directly.
            self.get_account_info()
            logger.info(f"Successfully connected to Binance API ({'testnet' if self.test_mode else 'live'})")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to Binance API: {str(e)}")
            return False

    def _generate_signature(self, params: Dict[str, Any]) -> str:
        """Generate the HMAC-SHA256 signature Binance requires on signed endpoints.

        Args:
            params: request parameters (already including `timestamp`)

        Returns:
            str: hex-encoded signature of the urlencoded query string
        """
        query_string = urlencode(params)
        signature = hmac.new(
            self.api_secret.encode('utf-8'),
            query_string.encode('utf-8'),
            hashlib.sha256
        ).hexdigest()
        return signature

    def _send_public_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
        """Send an unauthenticated request to the Binance API.

        Args:
            method: HTTP method ('GET' or anything else, treated as POST)
            endpoint: API endpoint path (e.g. 'ping', 'ticker/24hr')
            params: optional query/body parameters

        Returns:
            dict: decoded JSON response

        Raises:
            Exception: re-raises any request/HTTP error after logging it.
        """
        url = f"{self.base_url}/api/{self.api_version}/{endpoint}"

        try:
            if method.upper() == 'GET':
                response = requests.get(url, params=params, timeout=self.REQUEST_TIMEOUT)
            else:
                response = requests.post(url, json=params, timeout=self.REQUEST_TIMEOUT)

            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error in public request to {endpoint}: {str(e)}")
            raise

    def _send_private_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
        """Send an authenticated (signed) request to the Binance API.

        Args:
            method: HTTP method ('GET', 'POST' or 'DELETE')
            endpoint: API endpoint path (e.g. 'account', 'order')
            params: optional request parameters; a timestamp and signature
                are added automatically

        Returns:
            dict: decoded JSON response

        Raises:
            ValueError: if credentials are missing or the method is unsupported.
            Exception: re-raises any request/HTTP error after logging it.
        """
        if not self.api_key or not self.api_secret:
            raise ValueError("API key and secret are required for private requests")

        if params is None:
            params = {}

        # Add timestamp (milliseconds, required by Binance signed endpoints)
        params['timestamp'] = int(time.time() * 1000)

        # Generate signature over the full parameter set
        signature = self._generate_signature(params)
        params['signature'] = signature

        # Set headers
        headers = {
            'X-MBX-APIKEY': self.api_key
        }

        url = f"{self.base_url}/api/{self.api_version}/{endpoint}"

        try:
            if method.upper() == 'GET':
                response = requests.get(url, params=params, headers=headers, timeout=self.REQUEST_TIMEOUT)
            elif method.upper() == 'POST':
                response = requests.post(url, data=params, headers=headers, timeout=self.REQUEST_TIMEOUT)
            elif method.upper() == 'DELETE':
                response = requests.delete(url, params=params, headers=headers, timeout=self.REQUEST_TIMEOUT)
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            # Log detailed error if available
            if response.status_code != 200:
                logger.error(f"Binance API error: {response.text}")

            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error in private request to {endpoint}: {str(e)}")
            raise

    def get_account_info(self) -> Dict[str, Any]:
        """Get account information (balances, permissions, etc.)."""
        return self._send_private_request('GET', 'account')

    def get_balance(self, asset: str) -> float:
        """Get balance of a specific asset.

        Args:
            asset: Asset symbol (e.g., 'BTC', 'USDT')

        Returns:
            float: Available ('free') balance of the asset; 0.0 if the asset
            is not found or the request fails.
        """
        try:
            # Reuse get_account_info() rather than duplicating the raw request
            account_info = self.get_account_info()
            balances = account_info.get('balances', [])
            for balance in balances:
                if balance['asset'] == asset:
                    return float(balance['free'])
            # Asset not found
            return 0.0
        except Exception as e:
            logger.error(f"Error getting balance for {asset}: {str(e)}")
            return 0.0

    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker data for a symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')

        Returns:
            dict: Ticker data including price information
        """
        binance_symbol = symbol.replace('/', '')

        try:
            ticker = self._send_public_request('GET', 'ticker/24hr', {'symbol': binance_symbol})

            # Convert to a standardized format
            result = {
                'symbol': symbol,
                'bid': float(ticker['bidPrice']),
                'ask': float(ticker['askPrice']),
                'last': float(ticker['lastPrice']),
                'volume': float(ticker['volume']),
                'timestamp': int(ticker['closeTime'])
            }
            return result
        except Exception as e:
            logger.error(f"Error getting ticker for {symbol}: {str(e)}")
            raise

    def place_order(self, symbol: str, side: str, order_type: str,
                    quantity: float, price: float = None) -> Dict[str, Any]:
        """Place an order on the exchange.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            side: Order side ('buy' or 'sell')
            order_type: Order type ('market', 'limit', etc.)
            quantity: Order quantity
            price: Order price (for limit orders)

        Returns:
            dict: Order information including order ID
        """
        binance_symbol = symbol.replace('/', '')

        params = {
            'symbol': binance_symbol,
            'side': side.upper(),
            'type': order_type.upper(),
            'quantity': quantity,
        }

        if order_type.lower() == 'limit' and price is not None:
            params['price'] = price
            params['timeInForce'] = 'GTC'  # Good Till Cancelled

        # Use test order endpoint in test mode (validates without executing)
        endpoint = 'order/test' if self.test_mode else 'order'

        try:
            order_result = self._send_private_request('POST', endpoint, params)
            return order_result
        except Exception as e:
            logger.error(f"Error placing {side} {order_type} order for {symbol}: {str(e)}")
            raise

    def cancel_order(self, symbol: str, order_id: str) -> bool:
        """Cancel an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order to cancel

        Returns:
            bool: True if cancellation successful, False otherwise
        """
        binance_symbol = symbol.replace('/', '')

        params = {
            'symbol': binance_symbol,
            'orderId': order_id
        }

        try:
            self._send_private_request('DELETE', 'order', params)
            return True
        except Exception as e:
            logger.error(f"Error cancelling order {order_id} for {symbol}: {str(e)}")
            return False

    def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
        """Get status of an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order

        Returns:
            dict: Order status information
        """
        binance_symbol = symbol.replace('/', '')

        params = {
            'symbol': binance_symbol,
            'orderId': order_id
        }

        try:
            order_info = self._send_private_request('GET', 'order', params)
            return order_info
        except Exception as e:
            logger.error(f"Error getting order status for {order_id} on {symbol}: {str(e)}")
            raise

    def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]:
        """Get all open orders, optionally filtered by symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT'), or None for all symbols

        Returns:
            list: List of open orders; empty list on error.
        """
        params = {}
        if symbol:
            params['symbol'] = symbol.replace('/', '')

        try:
            open_orders = self._send_private_request('GET', 'openOrders', params)
            return open_orders
        except Exception as e:
            logger.error(f"Error getting open orders: {str(e)}")
            return []

View File

@@ -1,191 +0,0 @@
import abc
import logging
from typing import Dict, Any, List, Tuple, Optional
logger = logging.getLogger(__name__)


class ExchangeInterface(abc.ABC):
    """Base class for all exchange interfaces.

    This abstract class defines the required methods that all exchange
    implementations must provide to ensure compatibility with the trading system.
    """

    def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True):
        """Initialize the exchange interface.

        Args:
            api_key: API key for the exchange
            api_secret: API secret for the exchange
            test_mode: If True, use test/sandbox environment
        """
        self.api_key = api_key
        self.api_secret = api_secret
        self.test_mode = test_mode
        # Concrete subclasses assign their HTTP/SDK client here.
        self.client = None
        # Last successfully fetched price per symbol, used as a fallback.
        self.last_price_cache = {}

    @abc.abstractmethod
    def connect(self) -> bool:
        """Connect to the exchange API.

        Returns:
            bool: True if connection successful, False otherwise
        """
        pass

    @abc.abstractmethod
    def get_balance(self, asset: str) -> float:
        """Get balance of a specific asset.

        Args:
            asset: Asset symbol (e.g., 'BTC', 'USDT')

        Returns:
            float: Available balance of the asset
        """
        pass

    @abc.abstractmethod
    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker data for a symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')

        Returns:
            dict: Ticker data including price information
        """
        pass

    @abc.abstractmethod
    def place_order(self, symbol: str, side: str, order_type: str,
                    quantity: float, price: float = None) -> Dict[str, Any]:
        """Place an order on the exchange.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            side: Order side ('buy' or 'sell')
            order_type: Order type ('market', 'limit', etc.)
            quantity: Order quantity
            price: Order price (for limit orders)

        Returns:
            dict: Order information including order ID
        """
        pass

    @abc.abstractmethod
    def cancel_order(self, symbol: str, order_id: str) -> bool:
        """Cancel an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order to cancel

        Returns:
            bool: True if cancellation successful, False otherwise
        """
        pass

    @abc.abstractmethod
    def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
        """Get status of an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order

        Returns:
            dict: Order status information
        """
        pass

    @abc.abstractmethod
    def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]:
        """Get all open orders, optionally filtered by symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT'), or None for all symbols

        Returns:
            list: List of open orders
        """
        pass

    def get_last_price(self, symbol: str) -> float:
        """Get last known price for a symbol, may use cached value.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')

        Returns:
            float: Last price, or the cached value (0.0 if never fetched)
                when the live ticker request fails
        """
        try:
            ticker = self.get_ticker(symbol)
            price = float(ticker['last'])
            self.last_price_cache[symbol] = price
            return price
        except Exception as e:
            logger.error(f"Error getting price for {symbol}: {str(e)}")
            # Return cached price if available
            return self.last_price_cache.get(symbol, 0.0)

    def execute_trade(self, symbol: str, action: str, quantity: float = None,
                      percent_of_balance: float = None) -> Optional[Dict[str, Any]]:
        """Execute a market trade based on a signal.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            action: Trade action ('BUY' or 'SELL')
            quantity: Specific quantity to trade
            percent_of_balance: Alternative to quantity - fraction (0, 1] of
                available balance to use

        Returns:
            dict: Order information or None if order failed
        """
        if action not in ['BUY', 'SELL']:
            logger.error(f"Invalid action: {action}. Must be 'BUY' or 'SELL'")
            return None
        side = action.lower()
        try:
            # Determine base and quote assets from symbol (e.g., BTC/USDT -> BTC, USDT)
            base_asset, quote_asset = symbol.split('/')
            # Calculate quantity if percent_of_balance is provided
            if quantity is None and percent_of_balance is not None:
                if percent_of_balance <= 0 or percent_of_balance > 1:
                    logger.error(f"Invalid percent_of_balance: {percent_of_balance}. Must be between 0 and 1")
                    return None
                if side == 'buy':
                    # For buy, spend the quote asset (e.g., USDT); convert to base via price
                    balance = self.get_balance(quote_asset)
                    price = self.get_last_price(symbol)
                    if not price or price <= 0:
                        # Guard: get_last_price falls back to 0.0 when no price is
                        # available, which would otherwise divide by zero below.
                        logger.error(f"No valid price available for {symbol} (got {price}); cannot size BUY order")
                        return None
                    quantity = (balance * percent_of_balance) / price
                else:
                    # For sell, use base asset (e.g., BTC) directly
                    balance = self.get_balance(base_asset)
                    quantity = balance * percent_of_balance
            if not quantity or quantity <= 0:
                logger.error(f"Invalid quantity: {quantity}")
                return None
            # Place market order
            order = self.place_order(
                symbol=symbol,
                side=side,
                order_type='market',
                quantity=quantity
            )
            logger.info(f"Executed {side.upper()} order for {quantity} {base_asset} at market price")
            return order
        except Exception as e:
            logger.error(f"Error executing {action} trade for {symbol}: {str(e)}")
            return None

View File

@@ -1,520 +0,0 @@
import logging
import time
from typing import Dict, Any, List, Optional
import requests
import hmac
import hashlib
from urllib.parse import urlencode, quote_plus
import json # Added for json.dumps
from .exchange_interface import ExchangeInterface
logger = logging.getLogger(__name__)
# https://github.com/mexcdevelop/mexc-api-postman/blob/main/MEXC%20V3.postman_collection.json
# MEXC V3.postman_collection.json
class MEXCInterface(ExchangeInterface):
"""MEXC Exchange API Interface"""
def __init__(self, api_key: str = "", api_secret: str = "", test_mode: bool = True, trading_mode: str = 'simulation'):
"""Initialize MEXC exchange interface.
Args:
api_key: MEXC API key
api_secret: MEXC API secret
test_mode: If True, use test/sandbox environment (Note: MEXC doesn't have a true sandbox)
trading_mode: 'simulation', 'testnet', or 'live'. Determines API endpoints used.
"""
super().__init__(api_key, api_secret, test_mode)
self.trading_mode = trading_mode # Store the trading mode
# MEXC API Base URLs
self.base_url = "https://api.mexc.com" # Live API URL
if self.trading_mode == 'testnet':
# Note: MEXC does not have a separate testnet for spot trading.
# We use the live API for 'testnet' mode and rely on 'simulation' for true dry-runs.
logger.warning("MEXC does not have a separate testnet for spot trading. Using live API for 'testnet' mode.")
self.api_version = "api/v3"
self.recv_window = 5000 # 5 seconds window for request validity
# Session for HTTP requests
self.session = requests.Session()
logger.info(f"MEXCInterface initialized in {self.trading_mode} mode. Ensure correct API endpoints are being used.")
def connect(self) -> bool:
    """Test connection to the MEXC API by fetching account info.

    Returns:
        bool: True when credentials are set and account info is retrievable.
    """
    if not self.api_key or not self.api_secret:
        logger.error("MEXC API key or secret not set. Cannot connect.")
        return False
    # An authenticated account-info call doubles as a connectivity check.
    try:
        if self.get_account_info():
            logger.info("Successfully connected to MEXC API and retrieved account info.")
            return True
        logger.error("Failed to connect to MEXC API: Could not retrieve account info.")
        return False
    except Exception as e:
        logger.error(f"Exception during MEXC API connection test: {e}")
        return False
def _format_spot_symbol(self, symbol: str) -> str:
"""Formats a symbol to MEXC spot API standard (e.g., 'ETH/USDT' -> 'ETHUSDC')."""
if '/' in symbol:
base, quote = symbol.split('/')
# Convert USDT to USDC for MEXC spot trading
if quote.upper() == 'USDT':
quote = 'USDC'
return f"{base.upper()}{quote.upper()}"
else:
# Convert USDT to USDC for symbols like ETHUSDT
symbol = symbol.upper()
if symbol.endswith('USDT'):
symbol = symbol.replace('USDT', 'USDC')
return symbol
def _format_futures_symbol(self, symbol: str) -> str:
"""Formats a symbol to MEXC futures API standard (e.g., 'ETH/USDT' -> 'ETH_USDT')."""
# This method is included for completeness but should not be used for spot trading
return symbol.replace('/', '_').upper()
def _generate_signature(self, timestamp: str, method: str, endpoint: str, params: Dict[str, Any]) -> str:
"""Generate signature for private API calls using MEXC's official method"""
# MEXC signature format varies by method:
# For GET/DELETE: URL-encoded query string of alphabetically sorted parameters.
# For POST: JSON string of parameters (no sorting needed).
# The API-Secret is used as the HMAC SHA256 key.
# Remove signature from params to avoid circular inclusion
clean_params = {k: v for k, v in params.items() if k != 'signature'}
parameter_string: str
if method.upper() == "POST":
# For POST requests, the signature parameter is a JSON string
# Ensure sorting keys for consistent JSON string generation across runs
# even though MEXC says sorting is not required for POST params, it's good practice.
parameter_string = json.dumps(clean_params, sort_keys=True, separators=(',', ':'))
else:
# For GET/DELETE requests, parameters are spliced in dictionary order with & interval
sorted_params = sorted(clean_params.items())
parameter_string = '&'.join(f"{key}={str(value)}" for key, value in sorted_params)
# The string to be signed is: accessKey + timestamp + obtained parameter string.
string_to_sign = f"{self.api_key}{timestamp}{parameter_string}"
logger.debug(f"MEXC string to sign (method {method}): {string_to_sign}")
# Generate HMAC SHA256 signature
signature = hmac.new(
self.api_secret.encode('utf-8'),
string_to_sign.encode('utf-8'),
hashlib.sha256
).hexdigest()
logger.debug(f"MEXC generated signature: {signature}")
return signature
def _send_public_request(self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Send an unauthenticated request to a MEXC public endpoint.

    Returns:
        dict: Decoded JSON payload, or an empty dict on any failure
        (HTTP error status, connection problem, timeout, bad JSON).
    """
    request_params = params if params is not None else {}
    url = f"{self.base_url}/{self.api_version}/{endpoint}"
    headers = {'Accept': 'application/json'}
    try:
        response = requests.request(method, url, params=request_params, headers=headers, timeout=10)
        response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
        return response.json()
    except requests.exceptions.HTTPError:
        # response is guaranteed bound here: raise_for_status runs after assignment.
        logger.error(f"HTTP error in public request to {endpoint}: {response.status_code} {response.reason}")
        logger.error(f"Response content: {response.text}")
        return {}
    except requests.exceptions.ConnectionError as conn_err:
        logger.error(f"Connection error in public request to {endpoint}: {conn_err}")
        return {}
    except requests.exceptions.Timeout as timeout_err:
        logger.error(f"Timeout error in public request to {endpoint}: {timeout_err}")
        return {}
    except Exception as e:
        logger.error(f"Error in public request to {endpoint}: {e}")
        return {}
def _send_private_request(self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
"""Send a private request to the exchange with proper signature"""
if params is None:
params = {}
timestamp = str(int(time.time() * 1000))
# Add timestamp and recvWindow to params for signature and request
params['timestamp'] = timestamp
params['recvWindow'] = self.recv_window
signature = self._generate_signature(timestamp, method, endpoint, params)
params['signature'] = signature
headers = {
"X-MEXC-APIKEY": self.api_key,
"Request-Time": timestamp
}
# For spot API, use the correct endpoint format
if not endpoint.startswith('api/v3/'):
endpoint = f"api/v3/{endpoint}"
url = f"{self.base_url}/{endpoint}"
try:
if method.upper() == "GET":
response = self.session.get(url, headers=headers, params=params, timeout=10)
elif method.upper() == "POST":
# MEXC expects POST parameters as JSON in the request body, not as query string
# The signature is generated from the JSON string of parameters.
# We need to exclude 'signature' from the JSON body sent, as it's for the header.
params_for_body = {k: v for k, v in params.items() if k != 'signature'}
response = self.session.post(url, headers=headers, json=params_for_body, timeout=10)
else:
logger.error(f"Unsupported method: {method}")
return None
response.raise_for_status()
data = response.json()
# For successful responses, return the data directly
# MEXC doesn't always use 'success' field for successful operations
if response.status_code == 200:
return data
else:
logger.error(f"API error: Status Code: {response.status_code}, Response: {response.text}")
return None
except requests.exceptions.HTTPError as http_err:
logger.error(f"HTTP error for {endpoint}: Status Code: {response.status_code}, Response: {response.text}")
logger.error(f"HTTP error details: {http_err}")
return None
except Exception as e:
logger.error(f"Request error for {endpoint}: {e}")
return None
def get_account_info(self) -> Dict[str, Any]:
    """Fetch account information from the authenticated /account endpoint.

    Returns:
        dict: Account payload, or an empty dict when the request fails.
    """
    result = self._send_private_request("GET", "account", {})
    return {} if result is None else result
def get_balance(self, asset: str) -> float:
    """Return the free (available) balance of *asset*, or 0.0 if not found."""
    account_info = self.get_account_info()
    target = asset.upper()
    for entry in (account_info or {}).get('balances', []):
        if entry.get('asset') == target:
            return float(entry.get('free', 0.0))
    logger.warning(f"Could not retrieve free balance for {asset}")
    return 0.0
def get_ticker(self, symbol: str) -> Optional[Dict[str, Any]]:
"""Get ticker information for a symbol."""
formatted_symbol = self._format_spot_symbol(symbol)
endpoint = "ticker/24hr"
params = {'symbol': formatted_symbol}
response = self._send_public_request('GET', endpoint, params)
if isinstance(response, dict):
ticker_data: Dict[str, Any] = response
elif isinstance(response, list) and len(response) > 0:
found_ticker = next((item for item in response if item.get('symbol') == formatted_symbol), None)
if found_ticker:
ticker_data = found_ticker
else:
logger.error(f"Ticker data for {formatted_symbol} not found in response list.")
return None
else:
logger.error(f"Unexpected ticker response format: {response}")
return None
# At this point, ticker_data is guaranteed to be a Dict[str, Any] due to the above logic
# If it was None, we would have returned early.
# Extract relevant info and format for universal use
last_price = float(ticker_data.get('lastPrice', 0))
bid_price = float(ticker_data.get('bidPrice', 0))
ask_price = float(ticker_data.get('askPrice', 0))
volume = float(ticker_data.get('volume', 0)) # Base asset volume
# Determine price change and percent change
price_change = float(ticker_data.get('priceChange', 0))
price_change_percent = float(ticker_data.get('priceChangePercent', 0))
logger.info(f"MEXC: Got ticker from {endpoint} for {symbol}: ${last_price:.2f}")
return {
'symbol': formatted_symbol,
'last': last_price,
'bid': bid_price,
'ask': ask_price,
'volume': volume,
'high': float(ticker_data.get('highPrice', 0)),
'low': float(ticker_data.get('lowPrice', 0)),
'change': price_change_percent, # This is usually priceChangePercent
'exchange': 'MEXC',
'raw_data': ticker_data
}
def get_api_symbols(self) -> List[str]:
    """Return the list of symbols enabled for API trading, [] on failure."""
    try:
        result = self._send_private_request("GET", "selfSymbols", {})
        # Endpoint may answer either {'data': [...]} or a bare list.
        if result and 'data' in result:
            return result['data']
        if isinstance(result, list):
            return result
        logger.warning(f"Unexpected response format for API symbols: {result}")
        return []
    except Exception as e:
        logger.error(f"Error getting API symbols: {e}")
        return []
def is_symbol_supported(self, symbol: str) -> bool:
    """Check whether *symbol* (exchange format) is enabled for API trading."""
    return self._format_spot_symbol(symbol) in self.get_api_symbols()
def place_order(self, symbol: str, side: str, order_type: str, quantity: float, price: Optional[float] = None) -> Dict[str, Any]:
"""Place a new order on MEXC."""
formatted_symbol = self._format_spot_symbol(symbol)
# Check if symbol is supported for API trading
if not self.is_symbol_supported(symbol):
supported_symbols = self.get_api_symbols()
logger.error(f"Symbol {formatted_symbol} is not supported for API trading")
logger.info(f"Supported symbols include: {supported_symbols[:10]}...") # Show first 10
return {}
# Format quantity according to symbol precision requirements
formatted_quantity = self._format_quantity_for_symbol(formatted_symbol, quantity)
if formatted_quantity is None:
logger.error(f"MEXC: Failed to format quantity {quantity} for {formatted_symbol}")
return {}
# Handle order type restrictions for specific symbols
final_order_type = self._adjust_order_type_for_symbol(formatted_symbol, order_type.upper())
# Get price for limit orders
final_price = price
if final_order_type == 'LIMIT' and price is None:
# Get current market price
ticker = self.get_ticker(symbol)
if ticker and 'last' in ticker:
final_price = ticker['last']
logger.info(f"MEXC: Using market price ${final_price:.2f} for LIMIT order")
else:
logger.error(f"MEXC: Could not get market price for LIMIT order on {formatted_symbol}")
return {}
endpoint = "order"
params: Dict[str, Any] = {
'symbol': formatted_symbol,
'side': side.upper(),
'type': final_order_type,
'quantity': str(formatted_quantity) # Quantity must be a string
}
if final_price is not None:
params['price'] = str(final_price) # Price must be a string for limit orders
logger.info(f"MEXC: Placing {side.upper()} {final_order_type} order for {formatted_quantity} {formatted_symbol} at price {final_price}")
try:
# MEXC API endpoint for placing orders is /api/v3/order (POST)
order_result = self._send_private_request('POST', endpoint, params)
if order_result is not None:
logger.info(f"MEXC: Order placed successfully: {order_result}")
return order_result
else:
logger.error(f"MEXC: Error placing order: request returned None")
return {}
except Exception as e:
logger.error(f"MEXC: Exception placing order: {e}")
return {}
def _format_quantity_for_symbol(self, formatted_symbol: str, quantity: float) -> Optional[float]:
"""Format quantity according to symbol precision requirements"""
try:
# Symbol-specific precision rules
if formatted_symbol == 'ETHUSDC':
# ETHUSDC requires max 5 decimal places, step size 0.000001
formatted_qty = round(quantity, 5)
# Ensure it meets minimum step size
step_size = 0.000001
formatted_qty = round(formatted_qty / step_size) * step_size
# Round again to remove floating point errors
formatted_qty = round(formatted_qty, 6)
logger.info(f"MEXC: Formatted ETHUSDC quantity {quantity} -> {formatted_qty}")
return formatted_qty
elif formatted_symbol == 'BTCUSDC':
# Assume similar precision for BTC
formatted_qty = round(quantity, 6)
step_size = 0.000001
formatted_qty = round(formatted_qty / step_size) * step_size
formatted_qty = round(formatted_qty, 6)
return formatted_qty
else:
# Default formatting - 6 decimal places
return round(quantity, 6)
except Exception as e:
logger.error(f"Error formatting quantity for {formatted_symbol}: {e}")
return None
def _adjust_order_type_for_symbol(self, formatted_symbol: str, order_type: str) -> str:
"""Adjust order type based on symbol restrictions"""
if formatted_symbol == 'ETHUSDC':
# ETHUSDC only supports LIMIT and LIMIT_MAKER orders
if order_type == 'MARKET':
logger.info(f"MEXC: Converting MARKET order to LIMIT for {formatted_symbol} (MARKET not supported)")
return 'LIMIT'
return order_type
def cancel_order(self, symbol: str, order_id: str) -> Dict[str, Any]:
"""Cancel an existing order on MEXC."""
formatted_symbol = self._format_spot_symbol(symbol)
endpoint = "order"
params = {
'symbol': formatted_symbol,
'orderId': order_id
}
logger.info(f"MEXC: Cancelling order {order_id} for {formatted_symbol}")
try:
# MEXC API endpoint for cancelling orders is /api/v3/order (DELETE)
cancel_result = self._send_private_request('DELETE', endpoint, params)
if cancel_result:
logger.info(f"MEXC: Order cancelled successfully: {cancel_result}")
return cancel_result
else:
logger.error(f"MEXC: Error cancelling order: {cancel_result}")
return {}
except Exception as e:
logger.error(f"MEXC: Exception cancelling order: {e}")
return {}
def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
"""Get the status of an order on MEXC."""
formatted_symbol = self._format_spot_symbol(symbol)
endpoint = "order"
params = {
'symbol': formatted_symbol,
'orderId': order_id
}
logger.info(f"MEXC: Getting status for order {order_id} for {formatted_symbol}")
try:
# MEXC API endpoint for order status is /api/v3/order (GET)
status_result = self._send_private_request('GET', endpoint, params)
if status_result:
logger.info(f"MEXC: Order status retrieved: {status_result}")
return status_result
else:
logger.error(f"MEXC: Error getting order status: {status_result}")
return {}
except Exception as e:
logger.error(f"MEXC: Exception getting order status: {e}")
return {}
def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get all open orders on MEXC for a symbol or all symbols."""
endpoint = "openOrders"
params = {}
if symbol:
params['symbol'] = self._format_spot_symbol(symbol)
logger.info(f"MEXC: Getting open orders for {symbol if symbol else 'all symbols'}")
try:
# MEXC API endpoint for open orders is /api/v3/openOrders (GET)
open_orders = self._send_private_request('GET', endpoint, params)
if open_orders and isinstance(open_orders, list):
logger.info(f"MEXC: Retrieved {len(open_orders)} open orders.")
return open_orders
else:
logger.error(f"MEXC: Error getting open orders: {open_orders}")
return []
except Exception as e:
logger.error(f"MEXC: Exception getting open orders: {e}")
return []
def get_my_trades(self, symbol: str, limit: int = 100) -> List[Dict[str, Any]]:
"""Get trade history for a specific symbol."""
formatted_symbol = self._format_spot_symbol(symbol)
endpoint = "myTrades"
params = {'symbol': formatted_symbol, 'limit': limit}
logger.info(f"MEXC: Getting trade history for {formatted_symbol} (limit: {limit})")
try:
# MEXC API endpoint for trade history is /api/v3/myTrades (GET)
trade_history = self._send_private_request('GET', endpoint, params)
if trade_history and isinstance(trade_history, list):
logger.info(f"MEXC: Retrieved {len(trade_history)} trade records.")
return trade_history
else:
logger.error(f"MEXC: Error getting trade history: {trade_history}")
return []
except Exception as e:
logger.error(f"MEXC: Exception getting trade history: {e}")
return []
def get_server_time(self) -> int:
    """Return the current MEXC server time in milliseconds.

    Falls back to local wall-clock time when the public /time call fails.
    """
    payload = self._send_public_request('GET', 'time')
    if payload and 'serverTime' in payload:
        return int(payload['serverTime'])
    logger.error("Failed to get MEXC server time.")
    return int(time.time() * 1000)  # Fallback to local time
def get_all_balances(self) -> Dict[str, Dict[str, float]]:
    """Get all asset balances from the MEXC account.

    Returns:
        Mapping of upper-cased asset symbol to
        {'free': float, 'locked': float, 'total': float}.

    Note: missing or None 'free'/'locked' fields are treated as 0.0.
    The previous code called float(balance.get('free')) with no default,
    which raised TypeError on partial balance entries (sibling get_balance
    already defaults to 0.0).
    """
    account_info = self.get_account_info()
    balances: Dict[str, Dict[str, float]] = {}
    if account_info and 'balances' in account_info:
        for balance in account_info['balances']:
            asset = balance.get('asset')
            # `or 0.0` also covers explicit None values, not just missing keys.
            free = float(balance.get('free') or 0.0)
            locked = float(balance.get('locked') or 0.0)
            if asset:
                balances[asset.upper()] = {'free': free, 'locked': locked, 'total': free + locked}
    return balances
def get_trading_fees(self) -> Dict[str, Any]:
    """Fetch the account-level maker/taker/default commission rates.

    Returns:
        dict with 'maker', 'taker', 'default' float rates, or {} on failure.
    """
    response = self._send_private_request('GET', 'account/commission')
    if response and 'data' in response:
        fee_info = response['data']
        return {
            'maker': float(fee_info.get('makerCommission', 0.0)),
            'taker': float(fee_info.get('takerCommission', 0.0)),
            'default': float(fee_info.get('defaultCommission', 0.0)),
        }
    logger.error("Failed to get trading fees from MEXC API.")
    return {}
def get_symbol_trading_fees(self, symbol: str) -> Dict[str, Any]:
    """Fetch maker/taker/default commission rates for a single symbol.

    Returns:
        dict with 'maker', 'taker', 'default' float rates, or {} on failure.
    """
    formatted_symbol = self._format_spot_symbol(symbol)
    response = self._send_private_request('GET', 'account/commission', {'symbol': formatted_symbol})
    if response and 'data' in response:
        fee_info = response['data']
        return {
            'maker': float(fee_info.get('makerCommission', 0.0)),
            'taker': float(fee_info.get('takerCommission', 0.0)),
            'default': float(fee_info.get('defaultCommission', 0.0)),
        }
    logger.error(f"Failed to get trading fees for {symbol} from MEXC API.")
    return {}

View File

@@ -1,254 +0,0 @@
"""
Trading Agent Test Script
This script demonstrates how to use the swappable exchange modules
to connect to and interact with different cryptocurrency exchanges.
Usage:
python -m NN.exchanges.trading_agent_test --exchange binance --test-mode
"""
import argparse
import logging
import os
import sys
import time
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("exchange_test.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger("exchange_test")
# Import exchange interfaces
try:
from .exchange_interface import ExchangeInterface
from .binance_interface import BinanceInterface
from .mexc_interface import MEXCInterface
except ImportError:
# When running as standalone script
from exchange_interface import ExchangeInterface
from binance_interface import BinanceInterface
from mexc_interface import MEXCInterface
def create_exchange(exchange_name: str, api_key: str = None, api_secret: str = None, test_mode: bool = True) -> ExchangeInterface:
    """Create an exchange interface instance.

    Args:
        exchange_name: Name of the exchange ('binance' or 'mexc')
        api_key: API key for the exchange
        api_secret: API secret for the exchange
        test_mode: If True, use test/sandbox environment

    Returns:
        ExchangeInterface: The exchange interface instance

    Raises:
        ValueError: If the exchange name is not one of the supported ones.
    """
    normalized = exchange_name.lower()
    if normalized == 'binance':
        return BinanceInterface(api_key, api_secret, test_mode)
    if normalized == 'mexc':
        return MEXCInterface(api_key, api_secret, test_mode)
    raise ValueError(f"Unsupported exchange: {normalized}. Supported exchanges: binance, mexc")
def test_exchange(exchange: ExchangeInterface, symbols: list = None):
"""Test the exchange interface.
Args:
exchange: Exchange interface instance
symbols: List of symbols to test with (e.g., ['BTC/USDT', 'ETH/USDT'])
"""
if symbols is None:
symbols = ['BTC/USDT', 'ETH/USDT']
# Test connection
logger.info(f"Testing connection to exchange...")
connected = exchange.connect()
if not connected and hasattr(exchange, 'api_key') and exchange.api_key:
logger.error("Failed to connect to exchange. Make sure your API credentials are correct.")
return False
elif not connected:
logger.warning("Running in read-only mode without API credentials.")
else:
logger.info("Connection successful with API credentials!")
# Test getting ticker data
ticker_success = True
for symbol in symbols:
try:
logger.info(f"Getting ticker data for {symbol}...")
ticker = exchange.get_ticker(symbol)
logger.info(f"Ticker for {symbol}: Last price: {ticker['last']}, Volume: {ticker['volume']}")
except Exception as e:
logger.error(f"Error getting ticker for {symbol}: {str(e)}")
ticker_success = False
if not ticker_success:
logger.error("Failed to get ticker data. Exchange interface test failed.")
return False
# Test getting account balances if API keys are provided
if hasattr(exchange, 'api_key') and exchange.api_key:
logger.info("Testing account balance retrieval...")
try:
for base_asset in ['BTC', 'ETH', 'USDT']:
balance = exchange.get_balance(base_asset)
logger.info(f"Balance for {base_asset}: {balance}")
except Exception as e:
logger.error(f"Error getting account balances: {str(e)}")
logger.warning("Balance retrieval failed, but this is not critical if ticker data works.")
else:
logger.warning("API keys not provided. Skipping balance checks.")
logger.info("Exchange interface test completed successfully in read-only mode.")
return True
def execute_test_trades(exchange: ExchangeInterface, symbol: str, test_trade_amount: float = 0.001):
"""Execute test trades.
Args:
exchange: Exchange interface instance
symbol: Symbol to trade (e.g., 'BTC/USDT')
test_trade_amount: Amount to use for test trades
"""
if not hasattr(exchange, 'api_key') or not exchange.api_key:
logger.warning("API keys not provided. Skipping test trades.")
return
logger.info(f"Executing test trades for {symbol} with amount {test_trade_amount}...")
# Get current ticker for the symbol
try:
ticker = exchange.get_ticker(symbol)
logger.info(f"Current price for {symbol}: {ticker['last']}")
except Exception as e:
logger.error(f"Error getting ticker for {symbol}: {str(e)}")
return
# Execute a buy order
try:
logger.info(f"Placing a test BUY order for {test_trade_amount} {symbol}...")
buy_order = exchange.execute_trade(symbol, 'BUY', quantity=test_trade_amount)
if buy_order:
logger.info(f"BUY order executed: {buy_order}")
order_id = buy_order.get('orderId')
# Get order status
if order_id:
time.sleep(2) # Wait for order to process
status = exchange.get_order_status(symbol, order_id)
logger.info(f"Order status: {status}")
else:
logger.error("BUY order failed.")
except Exception as e:
logger.error(f"Error executing BUY order: {str(e)}")
# Wait before selling
time.sleep(5)
# Execute a sell order
try:
logger.info(f"Placing a test SELL order for {test_trade_amount} {symbol}...")
sell_order = exchange.execute_trade(symbol, 'SELL', quantity=test_trade_amount)
if sell_order:
logger.info(f"SELL order executed: {sell_order}")
order_id = sell_order.get('orderId')
# Get order status
if order_id:
time.sleep(2) # Wait for order to process
status = exchange.get_order_status(symbol, order_id)
logger.info(f"Order status: {status}")
else:
logger.error("SELL order failed.")
except Exception as e:
logger.error(f"Error executing SELL order: {str(e)}")
# Get open orders
try:
logger.info("Getting open orders...")
open_orders = exchange.get_open_orders(symbol)
if open_orders:
logger.info(f"Open orders: {open_orders}")
# Cancel any open orders
for order in open_orders:
order_id = order.get('orderId')
if order_id:
logger.info(f"Cancelling order {order_id}...")
cancelled = exchange.cancel_order(symbol, order_id)
logger.info(f"Order cancelled: {cancelled}")
else:
logger.info("No open orders.")
except Exception as e:
logger.error(f"Error getting/cancelling open orders: {str(e)}")
def main():
"""Main function for testing exchange interfaces."""
# Parse command-line arguments
parser = argparse.ArgumentParser(description="Test exchange interfaces")
parser.add_argument('--exchange', type=str, default='binance', choices=['binance', 'mexc'],
help='Exchange to test')
parser.add_argument('--api-key', type=str, default=None,
help='API key for the exchange')
parser.add_argument('--api-secret', type=str, default=None,
help='API secret for the exchange')
parser.add_argument('--test-mode', action='store_true',
help='Use test/sandbox environment')
parser.add_argument('--symbols', nargs='+', default=['BTC/USDT', 'ETH/USDT'],
help='Symbols to test with')
parser.add_argument('--execute-trades', action='store_true',
help='Execute test trades (use with caution!)')
parser.add_argument('--test-trade-amount', type=float, default=0.001,
help='Amount to use for test trades')
args = parser.parse_args()
# Use environment variables for API keys if not provided
api_key = args.api_key or os.environ.get(f"{args.exchange.upper()}_API_KEY")
api_secret = args.api_secret or os.environ.get(f"{args.exchange.upper()}_API_SECRET")
# Create exchange interface
try:
exchange = create_exchange(
exchange_name=args.exchange,
api_key=api_key,
api_secret=api_secret,
test_mode=args.test_mode
)
logger.info(f"Created {args.exchange} exchange interface")
logger.info(f"Test mode: {args.test_mode}")
# Test exchange
if test_exchange(exchange, args.symbols):
logger.info("Exchange interface test passed!")
# Execute test trades if requested
if args.execute_trades:
logger.warning("Executing test trades. This will use real funds!")
execute_test_trades(
exchange=exchange,
symbol=args.symbols[0],
test_trade_amount=args.test_trade_amount
)
else:
logger.error("Exchange interface test failed!")
except Exception as e:
logger.error(f"Error testing exchange interface: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return 1
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -1,21 +0,0 @@
"""
Neural Network Models
====================
This package contains the neural network models used in the trading system:
- CNN Model: Deep convolutional neural network for feature extraction
- DQN Agent: Deep Q-Network for reinforcement learning
- COB RL Model: Specialized RL model for order book data
- Advanced Transformer: High-performance transformer for trading
PyTorch implementation only.
"""
from NN.models.cnn_model import EnhancedCNNModel as CNNModel
from NN.models.dqn_agent import DQNAgent
from NN.models.cob_rl_model import MassiveRLNetwork, COBRLModelInterface
from NN.models.advanced_transformer_trading import AdvancedTradingTransformer, TradingTransformerConfig
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface
__all__ = ['CNNModel', 'DQNAgent', 'MassiveRLNetwork', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig',
'ModelInterface', 'CNNModelInterface', 'RLAgentInterface', 'ExtremaTrainerInterface']

View File

@@ -267,7 +267,17 @@ class COBRLModelInterface(ModelInterface):
logger.info(f"COB RL Model Interface initialized on {self.device}")
<<<<<<< HEAD
def predict(self, cob_features) -> Dict[str, Any]:
=======
def to(self, device):
"""PyTorch-style device movement method"""
self.device = device
self.model = self.model.to(device)
return self
def predict(self, cob_features: np.ndarray) -> Dict[str, Any]:
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
"""Make prediction using the model"""
self.model.eval()
with torch.no_grad():

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,7 @@ import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import time
import logging
import torch.nn.functional as F
from typing import List, Tuple, Dict, Any, Optional, Union
@@ -80,6 +81,9 @@ class EnhancedCNN(nn.Module):
self.n_actions = n_actions
self.confidence_threshold = confidence_threshold
# Training data storage
self.training_data = []
# Calculate input dimensions
if isinstance(input_shape, (list, tuple)):
if len(input_shape) == 3: # [channels, height, width]
@@ -265,8 +269,9 @@ class EnhancedCNN(nn.Module):
nn.Linear(256, 3) # 0=bottom, 1=top, 2=neither
)
# ULTRA MASSIVE multi-timeframe price prediction heads
self.price_pred_immediate = nn.Sequential(
# ULTRA MASSIVE price direction prediction head
# Outputs single direction and confidence values
self.price_direction_head = nn.Sequential(
nn.Linear(1024, 1024), # Increased from 512
nn.ReLU(),
nn.Dropout(0.3),
@@ -275,32 +280,13 @@ class EnhancedCNN(nn.Module):
nn.Dropout(0.3),
nn.Linear(512, 256), # Increased from 128
nn.ReLU(),
nn.Linear(256, 3) # Up, Down, Sideways
nn.Linear(256, 2) # [direction, confidence]
)
self.price_pred_midterm = nn.Sequential(
nn.Linear(1024, 1024), # Increased from 512
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1024, 512), # Increased from 256
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256), # Increased from 128
nn.ReLU(),
nn.Linear(256, 3) # Up, Down, Sideways
)
self.price_pred_longterm = nn.Sequential(
nn.Linear(1024, 1024), # Increased from 512
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1024, 512), # Increased from 256
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256), # Increased from 128
nn.ReLU(),
nn.Linear(256, 3) # Up, Down, Sideways
)
# Direction activation (tanh for -1 to 1)
self.direction_activation = nn.Tanh()
# Confidence activation (sigmoid for 0 to 1)
self.confidence_activation = nn.Sigmoid()
# ULTRA MASSIVE value prediction with ensemble approaches
self.price_pred_value = nn.Sequential(
@@ -371,21 +357,45 @@ class EnhancedCNN(nn.Module):
nn.Linear(128, 4) # Low risk, medium risk, high risk, extreme risk
)
def _memory_barrier(self, tensor: torch.Tensor) -> torch.Tensor:
"""Create a memory barrier to prevent in-place operation issues"""
return tensor.detach().clone().requires_grad_(tensor.requires_grad)
def _check_rebuild_network(self, features):
"""Check if network needs to be rebuilt for different feature dimensions"""
"""DEPRECATED: Network should have fixed architecture - no runtime rebuilding"""
if features != self.feature_dim:
logger.info(f"Rebuilding network for new feature dimension: {features} (was {self.feature_dim})")
self.feature_dim = features
self._build_network()
# Move to device after rebuilding
self.to(self.device)
return True
logger.error(f"CRITICAL: Input feature dimension mismatch! Expected {self.feature_dim}, got {features}")
logger.error("This indicates a bug in data preprocessing - input should be fixed size!")
logger.error("Network architecture should NOT change at runtime!")
raise ValueError(f"Input dimension mismatch: expected {self.feature_dim}, got {features}")
return False
def forward(self, x):
"""Forward pass through the ULTRA MASSIVE network"""
batch_size = x.size(0)
# Validate input dimensions to prevent zero-element tensor issues
if x.numel() == 0:
logger.error(f"Forward pass received empty tensor with shape {x.shape}")
# Return default outputs for all 5 expected values to prevent crash
default_q_values = torch.zeros(batch_size, self.n_actions, device=x.device)
default_extrema = torch.zeros(batch_size, 3, device=x.device) # bottom/top/neither
default_price_pred = torch.zeros(batch_size, 1, device=x.device)
default_features = torch.zeros(batch_size, 1024, device=x.device)
default_advanced = torch.zeros(batch_size, 1, device=x.device)
return default_q_values, default_extrema, default_price_pred, default_features, default_advanced
# Check for zero feature dimensions
if len(x.shape) > 1 and any(dim == 0 for dim in x.shape[1:]):
logger.error(f"Forward pass received tensor with zero feature dimensions: {x.shape}")
# Return default outputs for all 5 expected values to prevent crash
default_q_values = torch.zeros(batch_size, self.n_actions, device=x.device)
default_extrema = torch.zeros(batch_size, 3, device=x.device) # bottom/top/neither
default_price_pred = torch.zeros(batch_size, 1, device=x.device)
default_features = torch.zeros(batch_size, 1024, device=x.device)
default_advanced = torch.zeros(batch_size, 1, device=x.device)
return default_q_values, default_extrema, default_price_pred, default_features, default_advanced
# Process different input shapes
if len(x.shape) > 2:
# Handle 4D input [batch, timeframes, window, features] or 3D input [batch, timeframes, features]
@@ -397,10 +407,11 @@ class EnhancedCNN(nn.Module):
# Now x is 3D: [batch, timeframes, features]
x_reshaped = x
# Check if the feature dimension has changed and rebuild if necessary
if x_reshaped.size(1) * x_reshaped.size(2) != self.feature_dim:
total_features = x_reshaped.size(1) * x_reshaped.size(2)
self._check_rebuild_network(total_features)
# Validate input dimensions (should be fixed)
total_features = x_reshaped.size(1) * x_reshaped.size(2)
if total_features != self.feature_dim:
logger.error(f"Input dimension mismatch: expected {self.feature_dim}, got {total_features}")
raise ValueError(f"Input dimension mismatch: expected {self.feature_dim}, got {total_features}")
# Apply ultra massive convolutions
x_conv = self.conv_layers(x_reshaped)
@@ -413,9 +424,10 @@ class EnhancedCNN(nn.Module):
# For 2D input [batch, features]
x_flat = x
# Check if dimensions have changed
# Validate input dimensions (should be fixed)
if x_flat.size(1) != self.feature_dim:
self._check_rebuild_network(x_flat.size(1))
logger.error(f"Input dimension mismatch: expected {self.feature_dim}, got {x_flat.size(1)}")
raise ValueError(f"Input dimension mismatch: expected {self.feature_dim}, got {x_flat.size(1)}")
# Apply ULTRA MASSIVE FC layers to get base features
features = self.fc_layers(x_flat) # [batch, 1024]
@@ -464,10 +476,14 @@ class EnhancedCNN(nn.Module):
# Extrema predictions (bottom/top/neither detection)
extrema_pred = self.extrema_head(features_refined)
# Multi-timeframe price movement predictions
price_immediate = self.price_pred_immediate(features_refined)
price_midterm = self.price_pred_midterm(features_refined)
price_longterm = self.price_pred_longterm(features_refined)
# Price direction predictions
price_direction_raw = self.price_direction_head(features_refined)
# Apply separate activations to direction and confidence
direction = self.direction_activation(price_direction_raw[:, 0:1]) # -1 to 1
confidence = self.confidence_activation(price_direction_raw[:, 1:2]) # 0 to 1
price_direction_pred = torch.cat([direction, confidence], dim=1) # [batch, 2]
price_values = self.price_pred_value(features_refined)
# Additional specialized predictions for enhanced accuracy
@@ -476,38 +492,42 @@ class EnhancedCNN(nn.Module):
market_regime_pred = self.market_regime_head(features_refined)
risk_pred = self.risk_head(features_refined)
# Package all price predictions
price_predictions = {
'immediate': price_immediate,
'midterm': price_midterm,
'longterm': price_longterm,
'values': price_values
}
# Use the price direction prediction directly (already [batch, 2])
price_direction_tensor = price_direction_pred
# Package additional predictions for enhanced decision making
advanced_predictions = {
'volatility': volatility_pred,
'support_resistance': support_resistance_pred,
'market_regime': market_regime_pred,
'risk_assessment': risk_pred
}
# Package additional predictions into a single tensor (use volatility as primary)
# For compatibility with DQN agent, we return volatility_pred as the advanced prediction tensor
advanced_pred_tensor = volatility_pred
return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions
return q_values, extrema_pred, price_direction_tensor, features_refined, advanced_pred_tensor
def act(self, state, explore=True):
def act(self, state, explore=True) -> Tuple[int, float, List[float]]:
"""Enhanced action selection with ultra massive model predictions"""
if explore and np.random.random() < 0.1: # 10% random exploration
return np.random.choice(self.n_actions)
self.eval()
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
# Accept both NumPy arrays and already-built torch tensors
if isinstance(state, torch.Tensor):
state_tensor = state.detach().to(self.device)
if state_tensor.dim() == 1:
state_tensor = state_tensor.unsqueeze(0)
else:
# Convert to tensor **directly on the target device** to avoid intermediate CPU copies
state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
if state_tensor.dim() == 1:
state_tensor = state_tensor.unsqueeze(0)
with torch.no_grad():
q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor)
q_values, extrema_pred, price_direction_predictions, features, advanced_predictions = self(state_tensor)
# Process price direction predictions
if price_direction_predictions is not None:
self.process_price_direction_predictions(price_direction_predictions)
# Apply softmax to get action probabilities
action_probs = torch.softmax(q_values, dim=1)
action = torch.argmax(action_probs, dim=1).item()
action_probs_tensor = torch.softmax(q_values, dim=1)
action_idx = int(torch.argmax(action_probs_tensor, dim=1).item())
confidence = float(action_probs_tensor[0, action_idx].item()) # Confidence of the chosen action
action_probs = action_probs_tensor.squeeze(0).tolist() # Convert to list of floats for return
# Log advanced predictions for better decision making
if hasattr(self, '_log_predictions') and self._log_predictions:
@@ -537,7 +557,180 @@ class EnhancedCNN(nn.Module):
logger.info(f" Market Regime: {regime_labels[regime_class]} ({regime[regime_class]:.3f})")
logger.info(f" Risk Level: {risk_labels[risk_class]} ({risk[risk_class]:.3f})")
return action
return action_idx, confidence, action_probs
def process_price_direction_predictions(self, price_direction_pred: torch.Tensor) -> Dict[str, float]:
    """Convert a raw [direction, confidence] prediction tensor into a dict.

    Args:
        price_direction_pred: Tensor of shape (batch_size, 2) whose first
            sample holds [direction (-1..1), confidence (0..1)].

    Returns:
        Dict with 'direction' and 'confidence' keys, or {} for missing/empty
        input or on error.
    """
    try:
        # Nothing to process for a missing or empty prediction tensor.
        if price_direction_pred is None or price_direction_pred.numel() == 0:
            return {}
        result = {
            'direction': float(price_direction_pred[0, 0].item()),   # -1 to 1
            'confidence': float(price_direction_pred[0, 1].item()),  # 0 to 1
        }
        # Cache so get_price_direction_vector()/get_price_direction_summary()
        # can read the latest prediction later.
        self.last_price_direction = result
        return result
    except Exception as e:
        logger.error(f"Error processing price direction predictions: {e}")
        return {}
def get_price_direction_vector(self) -> Dict[str, float]:
    """Return the most recently cached direction/confidence dict.

    Returns:
        Dict with direction (-1 to 1) and confidence (0 to 1), or {} when
        no prediction has been processed yet.
    """
    if hasattr(self, 'last_price_direction'):
        return self.last_price_direction
    return {}
def get_price_direction_summary(self) -> Dict[str, Any]:
    """Summarize the cached price direction prediction.

    Returns:
        Dict containing the raw direction/confidence values, a discrete
        direction (-1/0/1), a human-readable label, and strength metrics.
        A neutral SIDEWAYS summary is returned when nothing is cached or
        an error occurs.
    """
    neutral = {
        'direction_value': 0.0,
        'confidence_value': 0.0,
        'direction_label': "SIDEWAYS",
        'discrete_direction': 0,
        'strength': 0.0,
        'weighted_strength': 0.0
    }
    try:
        cached = getattr(self, 'last_price_direction', {})
        if not cached:
            return neutral
        direction = float(cached['direction'])
        confidence = float(cached['confidence'])
        # Map the continuous direction onto UP/DOWN/SIDEWAYS using a
        # +/-0.1 dead zone around zero.
        if direction > 0.1:
            label, discrete = "UP", 1
        elif direction < -0.1:
            label, discrete = "DOWN", -1
        else:
            label, discrete = "SIDEWAYS", 0
        return {
            'direction_value': direction,
            'confidence_value': confidence,
            'direction_label': label,
            'discrete_direction': discrete,
            'strength': abs(direction),
            # Strength scaled by how confident the model is in it.
            'weighted_strength': abs(direction) * confidence
        }
    except Exception as e:
        logger.error(f"Error calculating price direction summary: {e}")
        return neutral
def add_training_data(self, state, action, reward, position_pnl=0.0, has_position=False):
    """Record one training sample in the rolling buffer.

    The stored reward is shaped by _calculate_position_enhanced_reward;
    the unmodified reward is kept under 'base_reward' for later analysis.

    Args:
        state: Input state for the sample.
        action: Action taken.
        reward: Base reward received.
        position_pnl: Current position P&L (0.0 if no position).
        has_position: Whether a position is currently open.
    """
    try:
        shaped = self._calculate_position_enhanced_reward(
            reward, action, position_pnl, has_position
        )
        sample = {
            'state': state,
            'action': action,
            'reward': shaped,
            'base_reward': reward,  # Keep original reward for analysis
            'position_pnl': position_pnl,
            'has_position': has_position,
            'timestamp': time.time()
        }
        self.training_data.append(sample)
        # Bound the buffer to the most recent 1000 samples.
        if len(self.training_data) > 1000:
            self.training_data = self.training_data[-1000:]
    except Exception as e:
        logger.error(f"Error adding training data: {e}")
def _calculate_position_enhanced_reward(self, base_reward, action, position_pnl, has_position):
"""
Calculate position-enhanced reward to incentivize profitable trades and closing losing ones
Args:
base_reward: Original reward from price prediction accuracy
action: Action taken ('BUY', 'SELL', 'HOLD')
position_pnl: Current position P&L
has_position: Whether we have an open position
Returns:
Enhanced reward that incentivizes profitable behavior
"""
try:
enhanced_reward = base_reward
if has_position and position_pnl != 0.0:
# Position-based reward adjustments
pnl_factor = position_pnl / 100.0 # Normalize P&L to reasonable scale
if position_pnl > 0: # Profitable position
if action == "HOLD":
# Reward holding profitable positions (let winners run)
enhanced_reward += abs(pnl_factor) * 0.5
elif action in ["BUY", "SELL"]:
# Moderate reward for taking action on profitable positions
enhanced_reward += abs(pnl_factor) * 0.3
elif position_pnl < 0: # Losing position
if action == "HOLD":
# Penalty for holding losing positions (cut losses)
enhanced_reward -= abs(pnl_factor) * 0.8
elif action in ["BUY", "SELL"]:
# Reward for taking action to close losing positions
enhanced_reward += abs(pnl_factor) * 0.6
# Ensure reward doesn't become extreme
enhanced_reward = max(-5.0, min(5.0, enhanced_reward))
return enhanced_reward
except Exception as e:
logger.error(f"Error calculating position-enhanced reward: {e}")
return base_reward
def save(self, path):
"""Save model weights and architecture"""

View File

@@ -1 +0,0 @@
{"best_reward": 4791516.572471984, "best_episode": 3250, "best_pnl": 826842167451289.1, "best_win_rate": 0.47368421052631576, "date": "2025-04-01 10:19:16"}

View File

@@ -1,20 +0,0 @@
{
"supervised": {
"epochs_completed": 22650,
"best_val_pnl": 0.0,
"best_epoch": 50,
"best_win_rate": 0
},
"reinforcement": {
"episodes_completed": 0,
"best_reward": -Infinity,
"best_episode": 0,
"best_win_rate": 0
},
"hybrid": {
"iterations_completed": 453,
"best_combined_score": 0.0,
"training_started": "2025-04-09T10:30:42.510856",
"last_update": "2025-04-09T10:40:02.217840"
}
}

View File

@@ -1,326 +0,0 @@
{
"epochs_completed": 8,
"best_val_pnl": 0.0,
"best_epoch": 1,
"best_win_rate": 0.0,
"training_started": "2025-04-02T10:43:58.946682",
"last_update": "2025-04-02T10:44:10.940892",
"epochs": [
{
"epoch": 1,
"train_loss": 1.0950355529785156,
"val_loss": 1.1657923062642415,
"train_acc": 0.3255208333333333,
"val_acc": 0.0,
"train_pnl": 0.0,
"val_pnl": 0.0,
"train_win_rate": 0.0,
"val_win_rate": 0.0,
"best_position_size": 0.1,
"signal_distribution": {
"train": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
},
"val": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
}
},
"timestamp": "2025-04-02T10:44:01.840889",
"data_age": 2,
"cumulative_pnl": {
"train": 0.0,
"val": 0.0
},
"total_trades": {
"train": 0,
"val": 0
},
"overall_win_rate": {
"train": 0.0,
"val": 0.0
}
},
{
"epoch": 2,
"train_loss": 1.0831659038861592,
"val_loss": 1.1212460199991863,
"train_acc": 0.390625,
"val_acc": 0.0,
"train_pnl": 0.0,
"val_pnl": 0.0,
"train_win_rate": 0.0,
"val_win_rate": 0.0,
"best_position_size": 0.1,
"signal_distribution": {
"train": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
},
"val": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
}
},
"timestamp": "2025-04-02T10:44:03.134833",
"data_age": 4,
"cumulative_pnl": {
"train": 0.0,
"val": 0.0
},
"total_trades": {
"train": 0,
"val": 0
},
"overall_win_rate": {
"train": 0.0,
"val": 0.0
}
},
{
"epoch": 3,
"train_loss": 1.0740693012873332,
"val_loss": 1.0992945830027263,
"train_acc": 0.4739583333333333,
"val_acc": 0.0,
"train_pnl": 0.0,
"val_pnl": 0.0,
"train_win_rate": 0.0,
"val_win_rate": 0.0,
"best_position_size": 0.1,
"signal_distribution": {
"train": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
},
"val": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
}
},
"timestamp": "2025-04-02T10:44:04.425272",
"data_age": 5,
"cumulative_pnl": {
"train": 0.0,
"val": 0.0
},
"total_trades": {
"train": 0,
"val": 0
},
"overall_win_rate": {
"train": 0.0,
"val": 0.0
}
},
{
"epoch": 4,
"train_loss": 1.0747728943824768,
"val_loss": 1.0821794271469116,
"train_acc": 0.4609375,
"val_acc": 0.3229166666666667,
"train_pnl": 0.0,
"val_pnl": 0.0,
"train_win_rate": 0.0,
"val_win_rate": 0.0,
"best_position_size": 0.1,
"signal_distribution": {
"train": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
},
"val": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
}
},
"timestamp": "2025-04-02T10:44:05.716421",
"data_age": 6,
"cumulative_pnl": {
"train": 0.0,
"val": 0.0
},
"total_trades": {
"train": 0,
"val": 0
},
"overall_win_rate": {
"train": 0.0,
"val": 0.0
}
},
{
"epoch": 5,
"train_loss": 1.0489931503931682,
"val_loss": 1.0669521888097127,
"train_acc": 0.5833333333333334,
"val_acc": 1.0,
"train_pnl": 0.0,
"val_pnl": 0.0,
"train_win_rate": 0.0,
"val_win_rate": 0.0,
"best_position_size": 0.1,
"signal_distribution": {
"train": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
},
"val": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
}
},
"timestamp": "2025-04-02T10:44:07.007935",
"data_age": 8,
"cumulative_pnl": {
"train": 0.0,
"val": 0.0
},
"total_trades": {
"train": 0,
"val": 0
},
"overall_win_rate": {
"train": 0.0,
"val": 0.0
}
},
{
"epoch": 6,
"train_loss": 1.0533669590950012,
"val_loss": 1.0505590836207073,
"train_acc": 0.5104166666666666,
"val_acc": 1.0,
"train_pnl": 0.0,
"val_pnl": 0.0,
"train_win_rate": 0.0,
"val_win_rate": 0.0,
"best_position_size": 0.1,
"signal_distribution": {
"train": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
},
"val": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
}
},
"timestamp": "2025-04-02T10:44:08.296061",
"data_age": 9,
"cumulative_pnl": {
"train": 0.0,
"val": 0.0
},
"total_trades": {
"train": 0,
"val": 0
},
"overall_win_rate": {
"train": 0.0,
"val": 0.0
}
},
{
"epoch": 7,
"train_loss": 1.0456886688868205,
"val_loss": 1.0351698795954387,
"train_acc": 0.5651041666666666,
"val_acc": 1.0,
"train_pnl": 0.0,
"val_pnl": 0.0,
"train_win_rate": 0.0,
"val_win_rate": 0.0,
"best_position_size": 0.1,
"signal_distribution": {
"train": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
},
"val": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
}
},
"timestamp": "2025-04-02T10:44:09.607584",
"data_age": 10,
"cumulative_pnl": {
"train": 0.0,
"val": 0.0
},
"total_trades": {
"train": 0,
"val": 0
},
"overall_win_rate": {
"train": 0.0,
"val": 0.0
}
},
{
"epoch": 8,
"train_loss": 1.040040671825409,
"val_loss": 1.0227736632029216,
"train_acc": 0.6119791666666666,
"val_acc": 1.0,
"train_pnl": 0.0,
"val_pnl": 0.0,
"train_win_rate": 0.0,
"val_win_rate": 0.0,
"best_position_size": 0.1,
"signal_distribution": {
"train": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
},
"val": {
"BUY": 1.0,
"SELL": 0.0,
"HOLD": 0.0
}
},
"timestamp": "2025-04-02T10:44:10.940892",
"data_age": 11,
"cumulative_pnl": {
"train": 0.0,
"val": 0.0
},
"total_trades": {
"train": 0,
"val": 0
},
"overall_win_rate": {
"train": 0.0,
"val": 0.0
}
}
],
"cumulative_pnl": {
"train": 0.0,
"val": 0.0
},
"total_trades": {
"train": 0,
"val": 0
},
"total_wins": {
"train": 0,
"val": 0
}
}

View File

@@ -1,192 +0,0 @@
{
"epochs_completed": 7,
"best_val_pnl": 0.002028853100759435,
"best_epoch": 6,
"best_win_rate": 0.5157894736842106,
"training_started": "2025-03-31T02:50:10.418670",
"last_update": "2025-03-31T02:50:15.227593",
"epochs": [
{
"epoch": 1,
"train_loss": 1.1206786036491394,
"val_loss": 1.0542699098587036,
"train_acc": 0.11197916666666667,
"val_acc": 0.25,
"train_pnl": 0.0,
"val_pnl": 0.0,
"train_win_rate": 0.0,
"val_win_rate": 0.0,
"best_position_size": 0.1,
"signal_distribution": {
"train": {
"BUY": 0.0,
"SELL": 0.0,
"HOLD": 1.0
},
"val": {
"BUY": 0.0,
"SELL": 0.0,
"HOLD": 1.0
}
},
"timestamp": "2025-03-31T02:50:12.881423",
"data_age": 2
},
{
"epoch": 2,
"train_loss": 1.1266120672225952,
"val_loss": 1.072133183479309,
"train_acc": 0.1171875,
"val_acc": 0.25,
"train_pnl": 0.0,
"val_pnl": 0.0,
"train_win_rate": 0.0,
"val_win_rate": 0.0,
"best_position_size": 0.1,
"signal_distribution": {
"train": {
"BUY": 0.0,
"SELL": 0.0,
"HOLD": 1.0
},
"val": {
"BUY": 0.0,
"SELL": 0.0,
"HOLD": 1.0
}
},
"timestamp": "2025-03-31T02:50:13.186840",
"data_age": 2
},
{
"epoch": 3,
"train_loss": 1.1415620843569438,
"val_loss": 1.1701548099517822,
"train_acc": 0.1015625,
"val_acc": 0.5208333333333334,
"train_pnl": 0.0,
"val_pnl": 0.0,
"train_win_rate": 0.0,
"val_win_rate": 0.0,
"best_position_size": 0.1,
"signal_distribution": {
"train": {
"BUY": 0.0,
"SELL": 0.0,
"HOLD": 1.0
},
"val": {
"BUY": 0.0,
"SELL": 0.0,
"HOLD": 1.0
}
},
"timestamp": "2025-03-31T02:50:13.442018",
"data_age": 3
},
{
"epoch": 4,
"train_loss": 1.1331567962964375,
"val_loss": 1.070081114768982,
"train_acc": 0.09375,
"val_acc": 0.22916666666666666,
"train_pnl": 0.010650217327384765,
"val_pnl": -0.0007049481907895126,
"train_win_rate": 0.49279538904899134,
"val_win_rate": 0.40625,
"best_position_size": 0.1,
"signal_distribution": {
"train": {
"BUY": 0.0,
"SELL": 0.9036458333333334,
"HOLD": 0.09635416666666667
},
"val": {
"BUY": 0.0,
"SELL": 0.3333333333333333,
"HOLD": 0.6666666666666666
}
},
"timestamp": "2025-03-31T02:50:13.739899",
"data_age": 3
},
{
"epoch": 5,
"train_loss": 1.10965762535731,
"val_loss": 1.0485950708389282,
"train_acc": 0.12239583333333333,
"val_acc": 0.17708333333333334,
"train_pnl": 0.011924086862580204,
"val_pnl": 0.0,
"train_win_rate": 0.5070422535211268,
"val_win_rate": 0.0,
"best_position_size": 0.1,
"signal_distribution": {
"train": {
"BUY": 0.0,
"SELL": 0.7395833333333334,
"HOLD": 0.2604166666666667
},
"val": {
"BUY": 0.0,
"SELL": 0.0,
"HOLD": 1.0
}
},
"timestamp": "2025-03-31T02:50:14.073439",
"data_age": 3
},
{
"epoch": 6,
"train_loss": 1.1272419293721516,
"val_loss": 1.084235429763794,
"train_acc": 0.1015625,
"val_acc": 0.22916666666666666,
"train_pnl": 0.014825159601390072,
"val_pnl": 0.00405770620151887,
"train_win_rate": 0.4908616187989556,
"val_win_rate": 0.5157894736842106,
"best_position_size": 2.0,
"signal_distribution": {
"train": {
"BUY": 0.0,
"SELL": 1.0,
"HOLD": 0.0
},
"val": {
"BUY": 0.0,
"SELL": 1.0,
"HOLD": 0.0
}
},
"timestamp": "2025-03-31T02:50:14.658295",
"data_age": 4
},
{
"epoch": 7,
"train_loss": 1.1171108484268188,
"val_loss": 1.0741244554519653,
"train_acc": 0.1171875,
"val_acc": 0.22916666666666666,
"train_pnl": 0.0059474696523706605,
"val_pnl": 0.00405770620151887,
"train_win_rate": 0.4838709677419355,
"val_win_rate": 0.5157894736842106,
"best_position_size": 2.0,
"signal_distribution": {
"train": {
"BUY": 0.0,
"SELL": 0.7291666666666666,
"HOLD": 0.2708333333333333
},
"val": {
"BUY": 0.0,
"SELL": 1.0,
"HOLD": 0.0
}
},
"timestamp": "2025-03-31T02:50:15.227593",
"data_age": 4
}
]
}

View File

@@ -0,0 +1,512 @@
"""
Standardized CNN Model for Multi-Modal Trading System
This module extends the existing EnhancedCNN to work with standardized BaseDataInput format
and provides ModelOutput for cross-model feeding.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
import sys
import os
# Add the project root to the path to import core modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from core.data_models import BaseDataInput, ModelOutput, create_model_output
from .enhanced_cnn import EnhancedCNN, SelfAttention, ResidualBlock
logger = logging.getLogger(__name__)
class StandardizedCNN(nn.Module):
"""
Standardized CNN Model that accepts BaseDataInput and outputs ModelOutput
Features:
- Accepts standardized BaseDataInput format
- Processes COB+OHLCV data: 300 frames (1s,1m,1h,1d) ETH + 300s 1s BTC
- Includes COB ±20 buckets and MA (1s,5s,15s,60s) of COB imbalance ±5 buckets
- Outputs BUY/SELL trading action with confidence scores
- Provides hidden states for cross-model feeding
- Integrates with checkpoint management system
"""
def __init__(self, model_name: str = "standardized_cnn_v1", confidence_threshold: float = 0.6):
    """Initialize the standardized CNN model.

    Submodules are registered in a fixed order (enhanced_cnn,
    input_processor, output_processor, return_head) so checkpoint
    state_dict keys stay stable across runs.

    Args:
        model_name: Name identifier for this model instance
        confidence_threshold: Minimum confidence threshold for predictions
    """
    super(StandardizedCNN, self).__init__()
    self.model_name = model_name
    self.model_type = "cnn"
    self.confidence_threshold = confidence_threshold
    # Fixed feature-vector width expected from BaseDataInput (see
    # _calculate_expected_features for the breakdown).
    self.expected_feature_dim = self._calculate_expected_features()
    # Underlying enhanced CNN sized for the standardized input
    self.enhanced_cnn = EnhancedCNN(
        input_shape=self.expected_feature_dim,
        n_actions=3,  # BUY, SELL, HOLD
        confidence_threshold=confidence_threshold
    )
    # MLP that compresses the raw BaseDataInput vector to 1024 features
    self.input_processor = self._build_input_processor()
    # Head mapping CNN features to BUY/SELL/HOLD probabilities
    self.output_processor = self._build_output_processor()
    # Optional numeric return head (predicts percent change for 1s,1m,1h,1d)
    # Uses cnn_features (1024) to regress predicted returns per timeframe
    self.return_head = nn.Sequential(
        nn.Linear(1024, 256),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(256, 4)  # [return_1s, return_1m, return_1h, return_1d]
    )
    # Device management: prefer GPU when available
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.to(self.device)
    # Best-effort cuDNN autotuning; failure here (e.g. CPU-only build)
    # must not prevent model construction.
    try:
        import torch.backends.cudnn as cudnn
        cudnn.benchmark = True
    except Exception:
        pass
    logger.info(f"StandardizedCNN '{model_name}' initialized")
    logger.info(f"Expected feature dimension: {self.expected_feature_dim}")
    logger.info(f"Device: {self.device}")
def _calculate_expected_features(self) -> int:
"""
Calculate expected feature dimension from BaseDataInput structure
Based on actual BaseDataInput.get_feature_vector():
- OHLCV ETH: 300 frames x 4 timeframes x 5 features = 6000
- OHLCV BTC: 300 frames x 5 features = 1500
- COB features: ~184 features (actual from implementation)
- Technical indicators: 100 features (padded)
- Last predictions: 50 features (padded)
Total: ~7834 features (actual measured)
"""
return 7834 # Based on actual BaseDataInput.get_feature_vector() measurement
def _build_input_processor(self) -> nn.Module:
"""
Build input processing layers for BaseDataInput
Returns:
nn.Module: Input processing layers
"""
return nn.Sequential(
# Initial processing of raw BaseDataInput features
nn.Linear(self.expected_feature_dim, 4096),
nn.ReLU(),
nn.Dropout(0.2),
nn.BatchNorm1d(4096),
# Feature refinement
nn.Linear(4096, 2048),
nn.ReLU(),
nn.Dropout(0.2),
nn.BatchNorm1d(2048),
# Final feature extraction
nn.Linear(2048, 1024),
nn.ReLU(),
nn.Dropout(0.1)
)
def _build_output_processor(self) -> nn.Module:
"""
Build output processing layers for standardized ModelOutput
Returns:
nn.Module: Output processing layers
"""
return nn.Sequential(
# Process CNN outputs for standardized format
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.2),
# Final action prediction
nn.Linear(512, 3), # BUY, SELL, HOLD
nn.Softmax(dim=1)
)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """
    Forward pass through the standardized CNN.

    Note: the original annotation declared a 2-tuple, but three values are
    returned (predicted_returns was added later); the annotation and docs
    are corrected here without changing behavior.

    Args:
        x: Input tensor from BaseDataInput.get_feature_vector(), shape
           [batch, expected_feature_dim]. Mismatched widths are zero-padded
           or truncated rather than rejected.

    Returns:
        Tuple of (action_probabilities, hidden_states_dict, predicted_returns):
        - action_probabilities: [batch, 3] softmax over BUY/SELL/HOLD
        - hidden_states_dict: detached intermediate tensors for
          cross-model feeding
        - predicted_returns: detached [batch, 4] numeric returns for
          (1s, 1m, 1h, 1d)
    """
    batch_size = x.size(0)
    # Validate input dimensions; pad/truncate so a slightly mismatched
    # feature vector does not crash inference.
    if x.size(1) != self.expected_feature_dim:
        logger.warning(f"Input dimension mismatch: expected {self.expected_feature_dim}, got {x.size(1)}")
        if x.size(1) < self.expected_feature_dim:
            padding = torch.zeros(batch_size, self.expected_feature_dim - x.size(1), device=x.device)
            x = torch.cat([x, padding], dim=1)
        else:
            x = x[:, :self.expected_feature_dim]
    # Compress the raw feature vector to the shared 1024-dim embedding
    processed_features = self.input_processor(x)  # [batch, 1024]
    # The enhanced CNN expects a sequence dimension; add one.
    cnn_input = processed_features.unsqueeze(1)
    try:
        q_values, extrema_pred, price_pred, cnn_features, advanced_pred = self.enhanced_cnn(cnn_input)
    except Exception as e:
        logger.warning(f"Enhanced CNN forward pass failed: {e}, using fallback")
        # Fallback: use the processed features directly with zeroed
        # auxiliary outputs so downstream code keeps working.
        cnn_features = processed_features
        q_values = torch.zeros(batch_size, 3, device=x.device)
        extrema_pred = torch.zeros(batch_size, 3, device=x.device)
        price_pred = torch.zeros(batch_size, 3, device=x.device)
        advanced_pred = torch.zeros(batch_size, 5, device=x.device)
    # Standardized action probabilities over BUY/SELL/HOLD
    action_probs = self.output_processor(cnn_features)  # [batch, 3]
    # Numeric per-timeframe return regression from the same features
    predicted_returns = self.return_head(cnn_features)  # [batch, 4]
    # Detached intermediates for cross-model feeding
    hidden_states = {
        'processed_features': processed_features.detach(),
        'cnn_features': cnn_features.detach(),
        'q_values': q_values.detach(),
        'extrema_predictions': extrema_pred.detach(),
        'price_predictions': price_pred.detach(),
        'advanced_predictions': advanced_pred.detach(),
        'attention_weights': torch.ones(batch_size, 1, device=x.device)  # Placeholder
    }
    return action_probs, hidden_states, predicted_returns.detach()
def predict_from_base_input(self, base_input: BaseDataInput) -> ModelOutput:
    """
    Make a prediction from BaseDataInput and return a standardized ModelOutput.

    Args:
        base_input: Standardized input data (project type; must provide
            get_feature_vector(), validate(), and a .symbol attribute)

    Returns:
        ModelOutput: Standardized model output.  On any failure a default
        HOLD output is returned instead of raising.
    """
    try:
        # Convert BaseDataInput to feature vector
        feature_vector = base_input.get_feature_vector()
        # Convert to tensor and add batch dimension
        input_tensor = torch.tensor(feature_vector, dtype=torch.float32, device=self.device).unsqueeze(0)
        # Set model to evaluation mode (disables dropout / BN updates)
        self.eval()
        with torch.no_grad():
            # Forward pass returns (action_probs, hidden_states, predicted_returns)
            action_probs, hidden_states, predicted_returns = self.forward(input_tensor)
        # Get action and confidence
        action_probs_np = action_probs.squeeze(0).cpu().numpy()
        action_idx = np.argmax(action_probs_np)
        confidence = float(action_probs_np[action_idx])
        # Map action index to action name
        # NOTE(review): assumes the output head's class order is BUY/SELL/HOLD,
        # matching the probability keys below — confirm against output_processor.
        action_names = ['BUY', 'SELL', 'HOLD']
        action = action_names[action_idx]
        # Prepare predictions dictionary
        predictions = {
            'action': action,
            'buy_probability': float(action_probs_np[0]),
            'sell_probability': float(action_probs_np[1]),
            'hold_probability': float(action_probs_np[2]),
            'action_probabilities': action_probs_np.tolist(),
            'extrema_detected': self._interpret_extrema(hidden_states.get('extrema_predictions')),
            'price_direction': self._interpret_price_direction(hidden_states.get('price_predictions')),
            'market_conditions': self._interpret_advanced_predictions(hidden_states.get('advanced_predictions'))
        }
        # Add numeric predicted returns per timeframe if available
        try:
            pr = predicted_returns.squeeze(0).cpu().numpy().tolist()
            # Ensure length 4; if not, safely handle
            if isinstance(pr, list) and len(pr) >= 4:
                predictions['predicted_returns'] = pr[:4]
                predictions['predicted_return_1s'] = float(pr[0])
                predictions['predicted_return_1m'] = float(pr[1])
                predictions['predicted_return_1h'] = float(pr[2])
                predictions['predicted_return_1d'] = float(pr[3])
        except Exception:
            # Predicted returns are best-effort extras; never fail the prediction
            pass
        # Prepare hidden states for cross-model feeding (convert tensors to numpy)
        cross_model_states = {}
        for key, tensor in hidden_states.items():
            if isinstance(tensor, torch.Tensor):
                cross_model_states[key] = tensor.squeeze(0).cpu().numpy().tolist()
            else:
                cross_model_states[key] = tensor
        # Create metadata
        metadata = {
            'model_version': '1.0',
            'confidence_threshold': self.confidence_threshold,
            'feature_dimension': self.expected_feature_dim,
            'processing_time_ms': 0,  # Could add timing if needed
            'input_validation': base_input.validate()
        }
        # Create standardized ModelOutput
        model_output = ModelOutput(
            model_type=self.model_type,
            model_name=self.model_name,
            symbol=base_input.symbol,
            timestamp=datetime.now(),
            confidence=confidence,
            predictions=predictions,
            hidden_states=cross_model_states,
            metadata=metadata
        )
        return model_output
    except Exception as e:
        logger.error(f"Error in CNN prediction: {e}")
        # Return default output
        return self._create_default_output(base_input.symbol)
def _interpret_extrema(self, extrema_tensor: Optional[torch.Tensor]) -> str:
    """
    Interpret raw extrema logits as a point label.

    Args:
        extrema_tensor: Tensor of 3 extrema logits (optionally with a leading
            batch dimension of 1), or None when unavailable.

    Returns:
        str: One of 'bottom', 'top', 'neither', or 'unknown' on missing or
        malformed input.
    """
    if extrema_tensor is None:
        return "unknown"
    try:
        extrema_probs = torch.softmax(extrema_tensor.squeeze(0), dim=0)
        extrema_idx = torch.argmax(extrema_probs).item()
        extrema_labels = ['bottom', 'top', 'neither']
        return extrema_labels[extrema_idx]
    # BUG FIX: was a bare `except:`, which also swallows SystemExit and
    # KeyboardInterrupt; narrow to Exception.
    except Exception:
        return "unknown"
def _interpret_price_direction(self, price_tensor: Optional[torch.Tensor]) -> str:
    """
    Interpret raw price-direction logits as a trend label.

    Args:
        price_tensor: Tensor of 3 direction logits (optionally with a leading
            batch dimension of 1), or None when unavailable.

    Returns:
        str: One of 'up', 'down', 'sideways', or 'unknown' on missing or
        malformed input.
    """
    if price_tensor is None:
        return "unknown"
    try:
        price_probs = torch.softmax(price_tensor.squeeze(0), dim=0)
        price_idx = torch.argmax(price_probs).item()
        price_labels = ['up', 'down', 'sideways']
        return price_labels[price_idx]
    # BUG FIX: was a bare `except:`, which also swallows SystemExit and
    # KeyboardInterrupt; narrow to Exception.
    except Exception:
        return "unknown"
def _interpret_advanced_predictions(self, advanced_tensor: Optional[torch.Tensor]) -> Dict[str, str]:
    """
    Interpret advanced market prediction logits into coarse labels.

    Args:
        advanced_tensor: Tensor whose last dimension carries at least 5
            volatility logits, or None when unavailable.

    Returns:
        Dict[str, str]: {'volatility': ..., 'risk': ...}; both 'unknown' on
        missing or malformed input.  'risk' is currently a placeholder.
    """
    if advanced_tensor is None:
        return {"volatility": "unknown", "risk": "unknown"}
    try:
        # Assuming advanced predictions include volatility (5 classes)
        if advanced_tensor.size(-1) >= 5:
            volatility_probs = torch.softmax(advanced_tensor.squeeze(0)[:5], dim=0)
            volatility_idx = torch.argmax(volatility_probs).item()
            volatility_labels = ['very_low', 'low', 'medium', 'high', 'very_high']
            volatility = volatility_labels[volatility_idx]
        else:
            volatility = "unknown"
        return {
            "volatility": volatility,
            "risk": "medium"  # Placeholder until a risk head exists
        }
    # BUG FIX: was a bare `except:`, which also swallows SystemExit and
    # KeyboardInterrupt; narrow to Exception.
    except Exception:
        return {"volatility": "unknown", "risk": "unknown"}
def _create_default_output(self, symbol: str) -> ModelOutput:
    """Build a safe fallback ModelOutput (HOLD at 0.5 confidence) for error paths."""
    fallback = create_model_output(
        model_type=self.model_type,
        model_name=self.model_name,
        symbol=symbol,
        action='HOLD',
        confidence=0.5,
        metadata={'error': True, 'default_output': True},
    )
    return fallback
def train_step(self, base_inputs: List[BaseDataInput], targets: List[str],
               optimizer: torch.optim.Optimizer) -> float:
    """
    Perform a single supervised training step.

    Args:
        base_inputs: List of BaseDataInput for training
        targets: List of target actions ('BUY', 'SELL', 'HOLD')
        optimizer: PyTorch optimizer updating this model's parameters

    Returns:
        float: Training loss for the batch, or inf on failure/empty batch
    """
    self.train()
    try:
        if not base_inputs:
            # Guard against an empty batch instead of erroring inside torch
            return float('inf')
        # Convert inputs to a single [batch, features] tensor
        feature_vectors = [base_input.get_feature_vector() for base_input in base_inputs]
        input_tensor = torch.tensor(np.array(feature_vectors), dtype=torch.float32, device=self.device)
        # Encode targets as class indices; unknown labels fall back to HOLD (2)
        action_to_idx = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
        target_indices = [action_to_idx.get(target, 2) for target in targets]
        target_tensor = torch.tensor(target_indices, dtype=torch.long, device=self.device)
        # BUG FIX: forward() returns (action_probs, hidden_states,
        # predicted_returns); the original two-element unpack raised
        # ValueError on every call, so training always returned inf.
        action_probs, _, _ = self.forward(input_tensor)
        # NOTE(review): cross_entropy expects logits; if output_processor ends
        # in a softmax this double-normalizes — confirm the head's activation.
        loss = F.cross_entropy(action_probs, target_tensor)
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return float(loss.item())
    except Exception as e:
        logger.error(f"Error in training step: {e}")
        return float('inf')
def evaluate(self, base_inputs: List[BaseDataInput], targets: List[str]) -> Dict[str, float]:
    """
    Evaluate model performance on a labelled set of inputs.

    Args:
        base_inputs: List of BaseDataInput for evaluation
        targets: List of target actions, aligned with base_inputs

    Returns:
        Dict with 'accuracy', 'avg_confidence', 'correct_predictions',
        and 'total_predictions' (all zeros on error or empty input).
    """
    self.eval()
    try:
        total = len(base_inputs)
        correct = 0
        confidence_sum = 0.0
        with torch.no_grad():
            for sample, expected in zip(base_inputs, targets):
                output = self.predict_from_base_input(sample)
                if output.predictions['action'] == expected:
                    correct += 1
                confidence_sum += output.confidence
        if total > 0:
            accuracy = correct / total
            avg_confidence = confidence_sum / total
        else:
            accuracy = 0.0
            avg_confidence = 0.0
        return {
            'accuracy': accuracy,
            'avg_confidence': avg_confidence,
            'correct_predictions': correct,
            'total_predictions': total,
        }
    except Exception as e:
        logger.error(f"Error in evaluation: {e}")
        return {'accuracy': 0.0, 'avg_confidence': 0.0, 'correct_predictions': 0, 'total_predictions': 0}
def save_checkpoint(self, filepath: str, metadata: Optional[Dict[str, Any]] = None):
    """
    Save model weights plus configuration to a checkpoint file.

    Args:
        filepath: Path to save checkpoint
        metadata: Optional metadata to embed in the checkpoint
    """
    try:
        payload = {
            'model_state_dict': self.state_dict(),
            'model_name': self.model_name,
            'model_type': self.model_type,
            'confidence_threshold': self.confidence_threshold,
            'expected_feature_dim': self.expected_feature_dim,
            'metadata': metadata if metadata is not None else {},
            'timestamp': datetime.now().isoformat(),
        }
        torch.save(payload, filepath)
        logger.info(f"Checkpoint saved to {filepath}")
    except Exception as e:
        # Saving is best-effort; log rather than crash the training loop
        logger.error(f"Error saving checkpoint: {e}")
def load_checkpoint(self, filepath: str) -> bool:
    """
    Load a checkpoint previously written by save_checkpoint.

    Args:
        filepath: Path to checkpoint file

    Returns:
        bool: True on success, False otherwise.
    """
    try:
        checkpoint = torch.load(filepath, map_location=self.device)
        # Restore weights first; a missing key here fails the whole load
        self.load_state_dict(checkpoint['model_state_dict'])
        # Restore configuration, keeping current values when keys are absent
        for attr in ('model_name', 'confidence_threshold', 'expected_feature_dim'):
            setattr(self, attr, checkpoint.get(attr, getattr(self, attr)))
        logger.info(f"Checkpoint loaded from {filepath}")
        return True
    except Exception as e:
        logger.error(f"Error loading checkpoint: {e}")
        return False
def get_model_info(self) -> Dict[str, Any]:
    """Return a summary dict of this model's configuration and parameter counts."""
    # Single pass over parameters to count total and trainable weights
    param_total = 0
    trainable_total = 0
    for param in self.parameters():
        count = param.numel()
        param_total += count
        if param.requires_grad:
            trainable_total += count
    return {
        'model_name': self.model_name,
        'model_type': self.model_type,
        'confidence_threshold': self.confidence_threshold,
        'expected_feature_dim': self.expected_feature_dim,
        'device': str(self.device),
        'parameter_count': param_total,
        'trainable_parameters': trainable_total,
    }

View File

@@ -1,821 +0,0 @@
"""
Transformer Neural Network for timeseries analysis
This module implements a Transformer model with attention mechanisms for cryptocurrency price analysis.
It also includes a Mixture of Experts model that combines predictions from multiple models.
"""
import os
import logging
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (
Input, Dense, Dropout, BatchNormalization,
Concatenate, Layer, LayerNormalization, MultiHeadAttention,
Add, GlobalAveragePooling1D, Conv1D, Reshape
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import datetime
import json
logger = logging.getLogger(__name__)
class TransformerBlock(Layer):
    """
    Transformer block implementation with multi-head attention and feed-forward networks.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        # **kwargs lets Keras pass through standard Layer arguments (e.g. name)
        super(TransformerBlock, self).__init__(**kwargs)
        # Keep constructor hyperparameters so get_config() can serialize them
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=False):
        # Self-attention sub-layer with residual connection + layer norm
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        # Position-wise feed-forward sub-layer with residual connection + layer norm
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        # BUG FIX: the original returned live sub-layer objects from
        # get_config(), which are not JSON-serializable and break model
        # save/load.  Keras expects the constructor arguments here; the
        # sub-layers are rebuilt in __init__ on deserialization.
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'rate': self.rate,
        })
        return config
class PositionalEncoding(Layer):
    """
    Positional encoding layer to add position information to input embeddings.
    """

    def __init__(self, position, d_model, **kwargs):
        # **kwargs lets Keras pass through standard Layer arguments (e.g. name)
        super(PositionalEncoding, self).__init__(**kwargs)
        self.position = position
        self.d_model = d_model
        # Precompute the full (1, position, d_model) sinusoidal encoding once
        self.pos_encoding = self.positional_encoding(position, d_model)

    def get_angles(self, position, i, d_model):
        # Classic sinusoidal angle schedule: pos / 10000^(2i / d_model)
        angles = 1 / tf.pow(10000, (2 * (i // 2)) / tf.cast(d_model, tf.float32))
        return position * angles

    def positional_encoding(self, position, d_model):
        angle_rads = self.get_angles(
            position=tf.range(position, dtype=tf.float32)[:, tf.newaxis],
            i=tf.range(d_model, dtype=tf.float32)[tf.newaxis, :],
            d_model=d_model
        )
        # Apply sin to even indices in the array
        sines = tf.math.sin(angle_rads[:, 0::2])
        # Apply cos to odd indices in the array
        cosines = tf.math.cos(angle_rads[:, 1::2])
        pos_encoding = tf.concat([sines, cosines], axis=-1)
        pos_encoding = pos_encoding[tf.newaxis, ...]
        return tf.cast(pos_encoding, tf.float32)

    def call(self, inputs):
        # Broadcast the precomputed encoding over the batch dimension,
        # truncated to the actual sequence length of the inputs
        return inputs + self.pos_encoding[:, :tf.shape(inputs)[1], :]

    def get_config(self):
        # BUG FIX: the original serialized the precomputed pos_encoding tensor,
        # which is not JSON-serializable and breaks model save/load.  Only the
        # constructor arguments belong in the config; the encoding is rebuilt
        # in __init__ on deserialization.
        config = super().get_config()
        config.update({
            'position': self.position,
            'd_model': self.d_model,
        })
        return config
class TransformerModel:
    """
    Transformer Neural Network for time series analysis.

    This model uses self-attention mechanisms to capture relationships between
    different time points in the input data.
    """

    def __init__(self, ts_input_shape=(20, 5), feature_input_shape=64, output_size=1, model_dir="NN/models/saved"):
        """
        Initialize the Transformer model.

        Args:
            ts_input_shape (tuple): Shape of time series input data (sequence_length, features)
            feature_input_shape (int): Shape of additional feature input (e.g., from CNN)
            output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
            model_dir (str): Directory to save trained models
        """
        self.ts_input_shape = ts_input_shape
        self.feature_input_shape = feature_input_shape
        self.output_size = output_size
        self.model_dir = model_dir
        self.model = None  # built lazily by build_model() / train() / load()
        self.history = None  # populated by train()
        # Create model directory if it doesn't exist
        os.makedirs(self.model_dir, exist_ok=True)
        logger.info(f"Initialized Transformer model with TS input shape {ts_input_shape}, "
                    f"feature input shape {feature_input_shape}, and output size {output_size}")
def build_model(self, embed_dim=32, num_heads=4, ff_dim=64, num_transformer_blocks=2, dropout_rate=0.1, learning_rate=0.001):
    """
    Build the Transformer model architecture.

    Args:
        embed_dim (int): Embedding dimension for transformer
        num_heads (int): Number of attention heads
        ff_dim (int): Hidden dimension of the feed forward network
        num_transformer_blocks (int): Number of transformer blocks
        dropout_rate (float): Dropout rate for regularization
        learning_rate (float): Learning rate for Adam optimizer

    Returns:
        The compiled Keras model (also stored on self.model).
    """
    # Time series input
    ts_inputs = Input(shape=self.ts_input_shape, name="ts_input")
    # Additional feature input (e.g., from CNN)
    feature_inputs = Input(shape=(self.feature_input_shape,), name="feature_input")
    # Process time series with transformer
    # First, project the input to the embedding dimension (1x1 conv acts as a
    # per-timestep linear projection)
    x = Conv1D(embed_dim, 1, activation="relu")(ts_inputs)
    # Add positional encoding
    x = PositionalEncoding(self.ts_input_shape[0], embed_dim)(x)
    # Add transformer blocks
    for _ in range(num_transformer_blocks):
        x = TransformerBlock(embed_dim, num_heads, ff_dim, dropout_rate)(x)
    # Global pooling to get a single vector representation
    x = GlobalAveragePooling1D()(x)
    x = Dropout(dropout_rate)(x)
    # Combine with additional features
    combined = Concatenate()([x, feature_inputs])
    # Dense layers for final classification/regression
    x = Dense(64, activation="relu")(combined)
    x = BatchNormalization()(x)
    x = Dropout(dropout_rate)(x)
    # Output layer: activation/loss chosen from output_size
    if self.output_size == 1:
        # Binary classification (up/down)
        outputs = Dense(1, activation='sigmoid', name='output')(x)
        loss = 'binary_crossentropy'
        metrics = ['accuracy']
    elif self.output_size == 3:
        # Multi-class classification (buy/hold/sell)
        outputs = Dense(3, activation='softmax', name='output')(x)
        loss = 'categorical_crossentropy'
        metrics = ['accuracy']
    else:
        # Regression
        outputs = Dense(self.output_size, activation='linear', name='output')(x)
        loss = 'mse'
        metrics = ['mae']
    # Create and compile model
    self.model = Model(inputs=[ts_inputs, feature_inputs], outputs=outputs)
    # Compile with Adam optimizer
    self.model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss=loss,
        metrics=metrics
    )
    # Log model summary
    self.model.summary(print_fn=lambda x: logger.info(x))
    return self.model
def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2,
          callbacks=None, class_weights=None):
    """
    Train the Transformer model on the provided data.

    Args:
        X_ts (numpy.ndarray): Time series input features
        X_features (numpy.ndarray): Additional input features
        y (numpy.ndarray): Target labels
        batch_size (int): Batch size
        epochs (int): Number of epochs
        validation_split (float): Fraction of data to use for validation
        callbacks (list): List of Keras callbacks
        class_weights (dict): Class weights for imbalanced datasets

    Returns:
        History object containing training metrics (also stored on self.history).
    """
    if self.model is None:
        # Build with default hyperparameters when not built explicitly
        self.build_model()
    # Default callbacks if none provided
    if callbacks is None:
        # Create a timestamp for model checkpoints
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        callbacks = [
            EarlyStopping(
                monitor='val_loss',
                patience=10,
                restore_best_weights=True
            ),
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,
                min_lr=1e-6
            ),
            ModelCheckpoint(
                filepath=os.path.join(self.model_dir, f"transformer_model_{timestamp}.h5"),
                monitor='val_loss',
                save_best_only=True
            )
        ]
    # Check if y needs to be one-hot encoded for multi-class
    if self.output_size == 3 and len(y.shape) == 1:
        y = tf.keras.utils.to_categorical(y, num_classes=3)
    # Train the model
    logger.info(f"Training Transformer model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}")
    self.history = self.model.fit(
        [X_ts, X_features], y,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=validation_split,
        callbacks=callbacks,
        class_weight=class_weights,
        verbose=2
    )
    # Save the trained model
    # NOTE(review): this timestamp is regenerated after training, so the
    # final-model filename will not match the checkpoint filename above.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    model_path = os.path.join(self.model_dir, f"transformer_model_final_{timestamp}.h5")
    self.model.save(model_path)
    logger.info(f"Model saved to {model_path}")
    # Save training history
    history_path = os.path.join(self.model_dir, f"transformer_model_history_{timestamp}.json")
    with open(history_path, 'w') as f:
        # Convert numpy values to Python native types for JSON serialization
        history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
        json.dump(history_dict, f, indent=2)
    return self.history
def evaluate(self, X_ts, X_features, y):
    """
    Evaluate the model on test data.

    Args:
        X_ts (numpy.ndarray): Time series input features
        X_features (numpy.ndarray): Additional input features
        y (numpy.ndarray): Target labels

    Returns:
        dict: Evaluation metrics keyed by Keras metric name.

    Raises:
        ValueError: If the model has not been built or trained yet.
    """
    if self.model is None:
        raise ValueError("Model has not been built or trained yet")
    # Convert y to one-hot encoding for multi-class
    if self.output_size == 3 and len(y.shape) == 1:
        y = tf.keras.utils.to_categorical(y, num_classes=3)
    # Evaluate model
    logger.info(f"Evaluating Transformer model on {len(X_ts)} samples")
    eval_results = self.model.evaluate([X_ts, X_features], y, verbose=0)
    metrics = {}
    for metric, value in zip(self.model.metrics_names, eval_results):
        metrics[metric] = value
        logger.info(f"{metric}: {value:.4f}")
    return metrics
def predict(self, X_ts, X_features=None):
    """
    Make predictions on new data.

    Args:
        X_ts (numpy.ndarray): Time series input of shape (timesteps, features)
            or (batch, timesteps, features)
        X_features (numpy.ndarray): Optional additional feature input; when
            omitted, statistical features are derived from X_ts

    Returns:
        tuple: (y_pred, y_proba) where:
            y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
            y_proba is the class probability

    Raises:
        ValueError: If the model has not been built or trained yet.
    """
    if self.model is None:
        raise ValueError("Model has not been built or trained yet")
    # Ensure X_ts has a batch dimension
    if len(X_ts.shape) == 2:
        # Single sample, add batch dimension
        X_ts = np.expand_dims(X_ts, axis=0)
    # Ensure X_features has the right shape
    if X_features is None:
        # Extract features from time series data if no external features provided
        X_features = self._extract_features_from_timeseries(X_ts)
    elif len(X_features.shape) == 1:
        # Single sample, add batch dimension
        X_features = np.expand_dims(X_features, axis=0)
    # BUG FIX: in the original, the definition of
    # _extract_features_from_timeseries was wedged in here, detaching the
    # prediction/return code below from this method.  The helper is now a
    # proper sibling method and predict() is contiguous.
    # Get predictions
    y_proba = self.model.predict([X_ts, X_features])
    # Process based on output type
    if self.output_size == 1:
        # Binary classification
        y_pred = (y_proba > 0.5).astype(int).flatten()
        return y_pred, y_proba.flatten()
    elif self.output_size == 3:
        # Multi-class classification
        y_pred = np.argmax(y_proba, axis=1)
        return y_pred, y_proba
    else:
        # Regression
        return y_proba, y_proba

def _extract_features_from_timeseries(self, X_ts: np.ndarray) -> np.ndarray:
    """
    Extract statistical features from time series data instead of dummy zeros.

    Per sample and per input channel: mean, std, min, max, 25th/75th
    percentiles, linear-trend slope and relative rate of change.  The result
    is padded/truncated to self.feature_input_shape columns.
    """
    try:
        batch_size = X_ts.shape[0]
        features = []
        for i in range(batch_size):
            sample = X_ts[i]  # Shape: (timesteps, features)
            sample_features = []
            for feature_idx in range(sample.shape[1]):
                feature_data = sample[:, feature_idx]
                # Basic statistical features
                sample_features.extend([
                    np.mean(feature_data),
                    np.std(feature_data),
                    np.min(feature_data),
                    np.max(feature_data),
                    np.percentile(feature_data, 25),
                    np.percentile(feature_data, 75),
                ])
                # Trend features
                if len(feature_data) > 1:
                    # Linear trend (slope)
                    x = np.arange(len(feature_data))
                    slope = np.polyfit(x, feature_data, 1)[0]
                    sample_features.append(slope)
                    # Rate of change (guard against division by zero)
                    rate_of_change = (feature_data[-1] - feature_data[0]) / feature_data[0] if feature_data[0] != 0 else 0
                    sample_features.append(rate_of_change)
                else:
                    sample_features.extend([0.0, 0.0])
            # Pad or truncate to expected feature size
            while len(sample_features) < self.feature_input_shape:
                sample_features.append(0.0)
            sample_features = sample_features[:self.feature_input_shape]
            features.append(sample_features)
        return np.array(features, dtype=np.float32)
    except Exception as e:
        logger.error(f"Error extracting features from time series: {e}")
        # Fallback to zeros if extraction fails
        return np.zeros((X_ts.shape[0], self.feature_input_shape), dtype=np.float32)
def save(self, filepath=None):
    """
    Save the underlying Keras model to disk.

    Args:
        filepath (str): Destination path; when omitted a timestamped default
            under model_dir is used.

    Returns:
        str: Path where the model was saved.

    Raises:
        ValueError: If the model has not been built yet.
    """
    if self.model is None:
        raise ValueError("Model has not been built yet")
    if filepath is None:
        # Create a default filepath with timestamp
        stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(self.model_dir, f"transformer_model_{stamp}.h5")
    self.model.save(filepath)
    logger.info(f"Model saved to {filepath}")
    return filepath
def load(self, filepath):
    """
    Load a saved model from disk, registering this module's custom layers.

    Args:
        filepath (str): Path to the saved model

    Returns:
        The loaded Keras model (also stored on self.model).
    """
    # Custom layers must be registered or deserialization fails
    registered_layers = {
        'TransformerBlock': TransformerBlock,
        'PositionalEncoding': PositionalEncoding,
    }
    self.model = load_model(filepath, custom_objects=registered_layers)
    logger.info(f"Model loaded from {filepath}")
    return self.model
def plot_training_history(self):
    """
    Plot training history (loss and metrics) and save it as a PNG.

    Returns:
        str: Path to the saved plot.

    Raises:
        ValueError: If the model has not been trained yet.
    """
    if self.history is None:
        raise ValueError("Model has not been trained yet")
    plt.figure(figsize=(12, 5))
    # Plot loss
    plt.subplot(1, 2, 1)
    plt.plot(self.history.history['loss'], label='Training Loss')
    if 'val_loss' in self.history.history:
        plt.plot(self.history.history['val_loss'], label='Validation Loss')
    plt.title('Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    # Plot accuracy (classification) or MAE (regression), whichever was tracked
    plt.subplot(1, 2, 2)
    if 'accuracy' in self.history.history:
        plt.plot(self.history.history['accuracy'], label='Training Accuracy')
        if 'val_accuracy' in self.history.history:
            plt.plot(self.history.history['val_accuracy'], label='Validation Accuracy')
        plt.title('Model Accuracy')
        plt.ylabel('Accuracy')
    elif 'mae' in self.history.history:
        plt.plot(self.history.history['mae'], label='Training MAE')
        if 'val_mae' in self.history.history:
            plt.plot(self.history.history['val_mae'], label='Validation MAE')
        plt.title('Model MAE')
        plt.ylabel('MAE')
    plt.xlabel('Epoch')
    plt.legend()
    plt.tight_layout()
    # Save figure with a timestamped filename under model_dir
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    fig_path = os.path.join(self.model_dir, f"transformer_training_history_{timestamp}.png")
    plt.savefig(fig_path)
    plt.close()
    logger.info(f"Training history plot saved to {fig_path}")
    return fig_path
class MixtureOfExpertsModel:
    """
    Mixture of Experts (MoE) model.

    This model combines predictions from multiple expert models (such as CNN and Transformer)
    using a weighted ensemble approach.
    """

    def __init__(self, output_size=1, model_dir="NN/models/saved"):
        """
        Initialize the MoE model.

        Args:
            output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
            model_dir (str): Directory to save trained models
        """
        self.output_size = output_size
        self.model_dir = model_dir
        self.model = None  # combined Keras model, built by build_model()
        self.history = None  # populated by train()
        self.experts = {}  # name -> expert model, registered via add_expert()
        # Create model directory if it doesn't exist
        os.makedirs(self.model_dir, exist_ok=True)
        logger.info(f"Initialized Mixture of Experts model with output size {output_size}")
def add_expert(self, name, model):
    """
    Register an expert model under the given name.

    Args:
        name (str): Name of the expert model; build_model() dispatches on
            the names 'cnn' and 'transformer'
        model: The expert model instance (re-registering a name replaces it)

    Returns:
        None
    """
    self.experts[name] = model
    logger.info(f"Added expert model '{name}' to MoE")
def build_model(self, ts_input_shape=(20, 5), expert_weights=None, learning_rate=0.001):
    """
    Build the MoE model by combining the registered expert models.

    Args:
        ts_input_shape (tuple): Shape of time series input data
        expert_weights (dict): Optional per-expert weights; missing experts
            default to an equal share, and weights are normalized to sum to 1
        learning_rate (float): Learning rate for Adam optimizer

    Returns:
        The compiled model, or None if no valid experts are registered.
    """
    # Time series input
    ts_inputs = Input(shape=ts_input_shape, name="ts_input")
    # Additional feature input (from CNN)
    feature_inputs = Input(shape=(64,), name="feature_input")  # Default size for features
    # Process with each expert model
    expert_outputs = []
    expert_names = []
    for name, expert in self.experts.items():
        # Skip if expert model is not valid or doesn't have a call/predict method
        if expert is None:
            logger.warning(f"Expert model '{name}' is None, skipping")
            continue
        try:
            # Different handling based on model type
            # NOTE(review): experts are called directly on symbolic tensors,
            # which assumes each expert is a callable Keras model — confirm
            # for wrapper classes like TransformerModel above.
            if name == 'cnn':
                # CNN model takes only time series input
                expert_output = expert(ts_inputs)
                expert_outputs.append(expert_output)
                expert_names.append(name)
            elif name == 'transformer':
                # Transformer model takes both time series and feature inputs
                expert_output = expert([ts_inputs, feature_inputs])
                expert_outputs.append(expert_output)
                expert_names.append(name)
            else:
                logger.warning(f"Unknown expert model type: {name}")
        except Exception as e:
            logger.error(f"Error adding expert '{name}': {str(e)}")
    if not expert_outputs:
        logger.error("No valid expert models found")
        return None
    # Use expert weighting
    if expert_weights is None:
        # Equal weighting
        weights = [1.0 / len(expert_outputs)] * len(expert_outputs)
    else:
        # User-provided weights
        weights = [expert_weights.get(name, 1.0 / len(expert_outputs)) for name in expert_names]
        # Normalize weights
        weights = [w / sum(weights) for w in weights]
    # Combine expert outputs using weighted average
    if len(expert_outputs) == 1:
        # Only one expert, use its output directly
        combined_output = expert_outputs[0]
    else:
        # Multiple experts, compute weighted average
        weighted_outputs = [output * weight for output, weight in zip(expert_outputs, weights)]
        combined_output = Add()(weighted_outputs)
    # Create the MoE model
    moe_model = Model(inputs=[ts_inputs, feature_inputs], outputs=combined_output)
    # Compile the model; loss/metrics chosen from output_size
    if self.output_size == 1:
        # Binary classification
        moe_model.compile(
            optimizer=Adam(learning_rate=learning_rate),
            loss='binary_crossentropy',
            metrics=['accuracy']
        )
    elif self.output_size == 3:
        # Multi-class classification for BUY/HOLD/SELL
        moe_model.compile(
            optimizer=Adam(learning_rate=learning_rate),
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )
    else:
        # Regression
        moe_model.compile(
            optimizer=Adam(learning_rate=learning_rate),
            loss='mse',
            metrics=['mae']
        )
    self.model = moe_model
    # Log model summary
    self.model.summary(print_fn=lambda x: logger.info(x))
    logger.info(f"Built MoE model with weights: {weights}")
    return self.model
def train(self, X_ts, X_features, y, batch_size=32, epochs=100, validation_split=0.2,
          callbacks=None, class_weights=None):
    """
    Train the MoE model on the provided data.

    Args:
        X_ts (numpy.ndarray): Time series input features
        X_features (numpy.ndarray): Additional input features
        y (numpy.ndarray): Target labels
        batch_size (int): Batch size
        epochs (int): Number of epochs
        validation_split (float): Fraction of data to use for validation
        callbacks (list): List of Keras callbacks
        class_weights (dict): Class weights for imbalanced datasets

    Returns:
        History object containing training metrics, or None when the model
        has not been built via build_model().
    """
    if self.model is None:
        logger.error("MoE model has not been built yet")
        return None
    # Default callbacks if none provided
    if callbacks is None:
        # Create a timestamp for model checkpoints
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        callbacks = [
            EarlyStopping(
                monitor='val_loss',
                patience=10,
                restore_best_weights=True
            ),
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,
                min_lr=1e-6
            ),
            ModelCheckpoint(
                filepath=os.path.join(self.model_dir, f"moe_model_{timestamp}.h5"),
                monitor='val_loss',
                save_best_only=True
            )
        ]
    # Check if y needs to be one-hot encoded for multi-class
    if self.output_size == 3 and len(y.shape) == 1:
        y = tf.keras.utils.to_categorical(y, num_classes=3)
    # Train the model
    logger.info(f"Training MoE model with {len(X_ts)} samples, batch size {batch_size}, epochs {epochs}")
    self.history = self.model.fit(
        [X_ts, X_features], y,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=validation_split,
        callbacks=callbacks,
        class_weight=class_weights,
        verbose=2
    )
    # Save the trained model
    # NOTE(review): timestamp is regenerated here, so the final filename will
    # not match the checkpoint filename created above.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    model_path = os.path.join(self.model_dir, f"moe_model_final_{timestamp}.h5")
    self.model.save(model_path)
    logger.info(f"Model saved to {model_path}")
    # Save training history
    history_path = os.path.join(self.model_dir, f"moe_model_history_{timestamp}.json")
    with open(history_path, 'w') as f:
        # Convert numpy values to Python native types for JSON serialization
        history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
        json.dump(history_dict, f, indent=2)
    return self.history
def predict(self, X_ts, X_features=None):
    """
    Make predictions on new data.

    Args:
        X_ts (numpy.ndarray): Time series input of shape (timesteps, features)
            or (batch, timesteps, features)
        X_features (numpy.ndarray): Optional additional feature input; zero
            placeholders of width 64 are substituted when omitted

    Returns:
        tuple: (y_pred, y_proba) where:
            y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
            y_proba is the class probability

    Raises:
        ValueError: If the model has not been built or trained yet.
    """
    if self.model is None:
        raise ValueError("Model has not been built or trained yet")
    # Promote a single 2-D sample to a batch of one
    if len(X_ts.shape) == 2:
        X_ts = np.expand_dims(X_ts, axis=0)
    if X_features is None:
        # No external features supplied: fall back to zero-filled placeholders
        X_features = np.zeros((X_ts.shape[0], 64))
    elif len(X_features.shape) == 1:
        X_features = np.expand_dims(X_features, axis=0)
    # Run the combined model
    y_proba = self.model.predict([X_ts, X_features])
    # Post-process according to the configured output type
    if self.output_size == 1:
        # Binary classification: threshold at 0.5
        predictions = (y_proba > 0.5).astype(int).flatten()
        return predictions, y_proba.flatten()
    if self.output_size == 3:
        # Multi-class classification: argmax over class probabilities
        predictions = np.argmax(y_proba, axis=1)
        return predictions, y_proba
    # Regression: return raw outputs for both slots
    return y_proba, y_proba
def save(self, filepath=None):
    """
    Save the combined MoE Keras model to disk.

    Args:
        filepath (str): Destination path; when omitted a timestamped default
            under model_dir is used.

    Returns:
        str: Path where the model was saved.

    Raises:
        ValueError: If the model has not been built yet.
    """
    if self.model is None:
        raise ValueError("Model has not been built yet")
    if filepath is None:
        # Create a default filepath with timestamp
        stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(self.model_dir, f"moe_model_{stamp}.h5")
    self.model.save(filepath)
    logger.info(f"Model saved to {filepath}")
    return filepath
def load(self, filepath):
    """
    Load a saved MoE model from disk, registering the custom layers.

    Args:
        filepath (str): Path to the saved model

    Returns:
        The loaded Keras model (also stored on self.model).
    """
    # Custom layers must be registered or deserialization fails
    registered_layers = {
        'TransformerBlock': TransformerBlock,
        'PositionalEncoding': PositionalEncoding,
    }
    self.model = load_model(filepath, custom_objects=registered_layers)
    logger.info(f"Model loaded from {filepath}")
    return self.model
# Example usage:
if __name__ == "__main__":
    # This module is meant to be imported; running it directly only confirms
    # that the definitions load without error.
    # This would be a complete implementation in a real system
    print("Transformer and MoE models defined, but not implemented here.")

View File

@@ -1,88 +0,0 @@
#!/usr/bin/env python
"""
Start TensorBoard for monitoring neural network training
"""
import os
import sys
import subprocess
import webbrowser
from time import sleep
def start_tensorboard(logdir="NN/models/saved/logs", port=6006, open_browser=True):
    """Launch TensorBoard as a child process and wait until it reports its URL.

    Args:
        logdir: Directory containing TensorBoard event files (created if missing).
        port: TCP port TensorBoard should listen on.
        open_browser: When True, open the reported URL in the default browser.

    Returns:
        subprocess.Popen: Handle to the running TensorBoard process; the
        caller is responsible for terminating it.
    """
    # TensorBoard needs the log directory to exist before it starts.
    os.makedirs(logdir, exist_ok=True)

    cmd = [
        sys.executable,
        "-m",
        "tensorboard.main",
        f"--logdir={logdir}",
        f"--port={port}",
        "--bind_all",
    ]
    print(f"Starting TensorBoard with logs from {logdir} on port {port}")
    print(f"Command: {' '.join(cmd)}")

    # Merge stderr into stdout so all startup messages arrive on one stream.
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    )

    # Echo startup output until TensorBoard announces its serving URL.
    for line in process.stdout:
        print(line.strip())
        if "TensorBoard" in line and "http://" in line:
            # Pick the first whitespace-separated token that looks like a URL.
            url = next(
                (token for token in line.split()
                 if token.startswith(("http://", "https://"))),
                None,
            )
            if open_browser and url:
                print(f"Opening TensorBoard in browser: {url}")
                webbrowser.open(url)
            break

    # Hand the process back for the caller to manage/terminate.
    return process
# Command-line entry point: parse arguments, launch TensorBoard, then idle
# until the user interrupts with Ctrl+C.
if __name__ == "__main__":
    import argparse
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Start TensorBoard for NN training visualization")
    parser.add_argument("--logdir", default="NN/models/saved/logs", help="Directory containing TensorBoard logs")
    parser.add_argument("--port", type=int, default=6006, help="Port to run TensorBoard on")
    parser.add_argument("--no-browser", action="store_true", help="Don't open browser automatically")
    args = parser.parse_args()
    # Start TensorBoard
    # Note the inversion: passing --no-browser means open_browser=False.
    process = start_tensorboard(args.logdir, args.port, not args.no_browser)
    try:
        # Keep the script running until Ctrl+C
        print("TensorBoard is running. Press Ctrl+C to stop.")
        while True:
            sleep(1)
    except KeyboardInterrupt:
        print("Stopping TensorBoard...")
        process.terminate()
        # Reap the child process so it does not linger as a zombie.
        process.wait()

View File

@@ -27,8 +27,18 @@ import torch
import torch.nn as nn
import torch.optim as optim
<<<<<<< HEAD
# Import prediction tracking
from core.prediction_database import get_prediction_db
=======
# Import checkpoint management
try:
from utils.checkpoint_manager import get_checkpoint_manager, save_checkpoint
CHECKPOINT_MANAGER_AVAILABLE = True
except ImportError:
CHECKPOINT_MANAGER_AVAILABLE = False
logger.warning("Checkpoint manager not available. Model persistence will be disabled.")
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
logger = logging.getLogger(__name__)
@@ -61,12 +71,19 @@ class EnhancedRealtimeTrainingSystem:
# Experience buffers
self.experience_buffer = deque(maxlen=self.training_config['memory_size'])
self.validation_buffer = deque(maxlen=1000)
# Training counters - CRITICAL for checkpoint management
self.training_iteration = 0
self.dqn_training_count = 0
self.cnn_training_count = 0
self.cob_training_count = 0
self.priority_buffer = deque(maxlen=2000) # High-priority experiences
# Performance tracking
self.performance_history = {
'dqn_losses': deque(maxlen=1000),
'cnn_losses': deque(maxlen=1000),
'cob_rl_losses': deque(maxlen=1000), # Added COB RL loss tracking
'prediction_accuracy': deque(maxlen=500),
'trading_performance': deque(maxlen=200),
'validation_scores': deque(maxlen=100)
@@ -764,18 +781,33 @@ class EnhancedRealtimeTrainingSystem:
# Statistical features across time for each aggregated dimension
for feature_idx in range(agg_matrix.shape[1]):
feature_series = agg_matrix[:, feature_idx]
combined_features.extend([
np.mean(feature_series),
np.std(feature_series),
np.min(feature_series),
np.max(feature_series),
feature_series[-1] - feature_series[0] if len(feature_series) > 1 else 0, # Total change
np.mean(np.diff(feature_series)) if len(feature_series) > 1 else 0, # Average momentum
np.std(np.diff(feature_series)) if len(feature_series) > 2 else 0, # Momentum volatility
np.percentile(feature_series, 25), # 25th percentile
np.percentile(feature_series, 75), # 75th percentile
len([x for x in np.diff(feature_series) if x > 0]) / max(len(feature_series) - 1, 1) if len(feature_series) > 1 else 0.5 # Positive change ratio
])
# Clean feature series to prevent division warnings
feature_series_clean = feature_series[np.isfinite(feature_series)]
if len(feature_series_clean) > 0:
# Safe percentile calculation
try:
percentile_25 = np.percentile(feature_series_clean, 25)
percentile_75 = np.percentile(feature_series_clean, 75)
except (ValueError, RuntimeWarning):
percentile_25 = np.median(feature_series_clean) if len(feature_series_clean) > 0 else 0
percentile_75 = np.median(feature_series_clean) if len(feature_series_clean) > 0 else 0
combined_features.extend([
np.mean(feature_series_clean),
np.std(feature_series_clean),
np.min(feature_series_clean),
np.max(feature_series_clean),
feature_series_clean[-1] - feature_series_clean[0] if len(feature_series_clean) > 1 else 0, # Total change
np.mean(np.diff(feature_series_clean)) if len(feature_series_clean) > 1 else 0, # Average momentum
np.std(np.diff(feature_series_clean)) if len(feature_series_clean) > 2 else 0, # Momentum volatility
percentile_25, # 25th percentile
percentile_75, # 75th percentile
len([x for x in np.diff(feature_series_clean) if x > 0]) / max(len(feature_series_clean) - 1, 1) if len(feature_series_clean) > 1 else 0.5 # Positive change ratio
])
else:
# All values are NaN or inf, use zeros
combined_features.extend([0.0] * 10)
else:
combined_features.extend([0.0] * (15 * 10)) # 15 features * 10 statistics
@@ -913,13 +945,14 @@ class EnhancedRealtimeTrainingSystem:
lows = np.array([bar['low'] for bar in self.real_time_data['ohlcv_1m']])
# Update indicators
price_mean = np.mean(prices[-20:])
self.technical_indicators = {
'sma_10': np.mean(prices[-10:]),
'sma_20': np.mean(prices[-20:]),
'rsi': self._calculate_rsi(prices, 14),
'volatility': np.std(prices[-20:]) / np.mean(prices[-20:]),
'volatility': np.std(prices[-20:]) / price_mean if price_mean > 0 else 0,
'volume_sma': np.mean(volumes[-10:]),
'price_momentum': (prices[-1] - prices[-5]) / prices[-5] if len(prices) >= 5 else 0,
'price_momentum': (prices[-1] - prices[-5]) / prices[-5] if len(prices) >= 5 and prices[-5] > 0 else 0,
'atr': np.mean(highs[-14:] - lows[-14:]) if len(prices) >= 14 else 0
}
@@ -935,8 +968,8 @@ class EnhancedRealtimeTrainingSystem:
current_time = time.time()
current_bar = self.real_time_data['ohlcv_1m'][-1]
# Create comprehensive state features
state_features = self._build_comprehensive_state()
# Create comprehensive state features with default dimensions
state_features = self._build_comprehensive_state(100) # Use default 100 for general experiences
# Create experience with proper reward calculation
experience = {
@@ -959,8 +992,8 @@ class EnhancedRealtimeTrainingSystem:
except Exception as e:
logger.debug(f"Error creating training experiences: {e}")
def _build_comprehensive_state(self) -> np.ndarray:
"""Build comprehensive state vector for RL training"""
def _build_comprehensive_state(self, target_dimensions: int = 100) -> np.ndarray:
"""Build comprehensive state vector for RL training with adaptive dimensions"""
try:
state_features = []
@@ -1003,15 +1036,138 @@ class EnhancedRealtimeTrainingSystem:
state_features.append(np.cos(2 * np.pi * now.hour / 24))
state_features.append(now.weekday() / 6.0) # Day of week
# Pad to fixed size (100 features)
while len(state_features) < 100:
# Current count: 10 (prices) + 7 (indicators) + 1 (volume) + 5 (COB) + 3 (time) = 26 base features
# 6. Enhanced features for larger dimensions
if target_dimensions > 50:
# Add more price history
if len(self.real_time_data['ohlcv_1m']) >= 20:
extended_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-20:]]
base_price = extended_prices[0]
extended_normalized = [(p - base_price) / base_price for p in extended_prices[10:]] # Additional 10
state_features.extend(extended_normalized)
else:
state_features.extend([0.0] * 10)
# Add volume history
if len(self.real_time_data['ohlcv_1m']) >= 10:
volume_history = [bar['volume'] for bar in list(self.real_time_data['ohlcv_1m'])[-10:]]
avg_vol = np.mean(volume_history) if volume_history else 1.0
# Prevent division by zero
if avg_vol == 0:
avg_vol = 1.0
normalized_volumes = [v / avg_vol for v in volume_history]
state_features.extend(normalized_volumes)
else:
state_features.extend([0.0] * 10)
# Add extended COB features
extended_cob = self._extract_cob_features()
state_features.extend(extended_cob[5:]) # Remaining COB features
# Add 5m timeframe data if available
if len(self.real_time_data['ohlcv_5m']) >= 5:
tf_5m_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_5m'])[-5:]]
if tf_5m_prices:
base_5m = tf_5m_prices[0]
# Prevent division by zero
if base_5m == 0:
base_5m = 1.0
normalized_5m = [(p - base_5m) / base_5m for p in tf_5m_prices]
state_features.extend(normalized_5m)
else:
state_features.extend([0.0] * 5)
else:
state_features.extend([0.0] * 5)
# 7. Adaptive padding/truncation based on target dimensions
current_length = len(state_features)
if target_dimensions > current_length:
# Pad with additional engineered features
remaining = target_dimensions - current_length
# Add statistical features if we have data
if len(self.real_time_data['ohlcv_1m']) >= 20:
all_prices = [bar['close'] for bar in list(self.real_time_data['ohlcv_1m'])[-20:]]
all_volumes = [bar['volume'] for bar in list(self.real_time_data['ohlcv_1m'])[-20:]]
# Statistical features
additional_features = [
np.std(all_prices) / np.mean(all_prices) if np.mean(all_prices) > 0 else 0, # Price CV
np.std(all_volumes) / np.mean(all_volumes) if np.mean(all_volumes) > 0 else 0, # Volume CV
(max(all_prices) - min(all_prices)) / np.mean(all_prices) if np.mean(all_prices) > 0 else 0, # Price range
# Safe correlation calculation
np.corrcoef(all_prices, all_volumes)[0, 1] if (len(all_prices) == len(all_volumes) and len(all_prices) > 1 and
np.std(all_prices) > 0 and np.std(all_volumes) > 0) else 0, # Price-volume correlation
]
# Add momentum features
for window in [3, 5, 10]:
if len(all_prices) >= window:
momentum = (all_prices[-1] - all_prices[-window]) / all_prices[-window] if all_prices[-window] > 0 else 0
additional_features.append(momentum)
else:
additional_features.append(0.0)
# Extend to fill remaining space
while len(additional_features) < remaining and len(additional_features) < 50:
additional_features.extend([
np.sin(len(additional_features) * 0.1), # Sine waves for variety
np.cos(len(additional_features) * 0.1),
np.tanh(len(additional_features) * 0.01)
])
state_features.extend(additional_features[:remaining])
else:
# Fill with structured zeros/patterns if no data
pattern_features = []
for i in range(remaining):
pattern_features.append(np.sin(i * 0.01)) # Small oscillating pattern
state_features.extend(pattern_features)
# Ensure exact target dimension
state_features = state_features[:target_dimensions]
while len(state_features) < target_dimensions:
state_features.append(0.0)
return np.array(state_features[:100])
return np.array(state_features)
except Exception as e:
logger.error(f"Error building state: {e}")
return np.zeros(100)
return np.zeros(target_dimensions)
def _get_model_expected_dimensions(self, model_type: str) -> int:
"""Get expected input dimensions for different model types"""
try:
if model_type == 'dqn':
# Try to get DQN expected dimensions from model
if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
and self.orchestrator.rl_agent and hasattr(self.orchestrator.rl_agent, 'policy_net')):
# Get first layer input size
first_layer = list(self.orchestrator.rl_agent.policy_net.children())[0]
if hasattr(first_layer, 'in_features'):
return first_layer.in_features
return 403 # Default for DQN based on error logs
elif model_type == 'cnn':
# CNN might have different input expectations
if (self.orchestrator and hasattr(self.orchestrator, 'cnn_model')
and self.orchestrator.cnn_model):
# Try to get CNN input size
if hasattr(self.orchestrator.cnn_model, 'input_shape'):
return self.orchestrator.cnn_model.input_shape
return 300 # Default for CNN based on error logs
elif model_type == 'cob_rl':
return 2000 # COB RL expects 2000 features
else:
return 100 # Default
except Exception as e:
logger.debug(f"Error getting model dimensions for {model_type}: {e}")
return 100 # Fallback
def _extract_cob_features(self) -> List[float]:
"""Extract features from COB data"""
@@ -1131,8 +1287,8 @@ class EnhancedRealtimeTrainingSystem:
total_loss += loss
training_iterations += 1
elif hasattr(rl_agent, 'replay'):
# Fallback to replay method
loss = rl_agent.replay(batch_size=len(batch))
# Fallback to replay method - DQNAgent.replay() doesn't accept batch_size parameter
loss = rl_agent.replay()
if loss is not None:
total_loss += loss
training_iterations += 1
@@ -1142,6 +1298,10 @@ class EnhancedRealtimeTrainingSystem:
self.dqn_training_count += 1
# Save checkpoint after training
if training_iterations > 0 and avg_loss > 0:
self._save_model_checkpoint('dqn_agent', rl_agent, avg_loss)
# Log progress every 10 training sessions
if self.dqn_training_count % 10 == 0:
logger.info(f"DQN TRAINING: Session {self.dqn_training_count}, "
@@ -1175,6 +1335,18 @@ class EnhancedRealtimeTrainingSystem:
aggregated_matrix = self.get_cob_training_matrix(symbol, '1s_aggregated')
if combined_features is not None:
# Ensure features are exactly 2000 dimensions
if len(combined_features) != 2000:
logger.warning(f"COB features wrong size: {len(combined_features)}, padding/truncating to 2000")
if len(combined_features) < 2000:
# Pad with zeros
padded_features = np.zeros(2000, dtype=np.float32)
padded_features[:len(combined_features)] = combined_features
combined_features = padded_features
else:
# Truncate to 2000
combined_features = combined_features[:2000]
# Create enhanced COB training experience
current_price = self._get_current_price_from_data(symbol)
if current_price:
@@ -1184,29 +1356,14 @@ class EnhancedRealtimeTrainingSystem:
# Calculate reward based on COB prediction accuracy
reward = self._calculate_cob_reward(symbol, action, combined_features)
# Create comprehensive state vector for COB RL
# Create comprehensive state vector for COB RL (exactly 2000 dimensions)
state = combined_features # 2000-dimensional state
# Store experience in COB RL agent
if hasattr(cob_rl_agent, 'store_experience'):
experience = {
'state': state,
'action': action,
'reward': reward,
'next_state': state, # Will be updated with next observation
'done': False,
'symbol': symbol,
'timestamp': datetime.now(),
'price': current_price,
'cob_features': {
'raw_tick_available': raw_tick_matrix is not None,
'aggregated_available': aggregated_matrix is not None,
'imbalance': combined_features[0] if len(combined_features) > 0 else 0,
'spread': combined_features[1] if len(combined_features) > 1 else 0,
'liquidity': combined_features[4] if len(combined_features) > 4 else 0
}
}
cob_rl_agent.store_experience(experience)
if hasattr(cob_rl_agent, 'remember'):
# Use tuple format for DQN agent compatibility
experience_tuple = (state, action, reward, state, False) # next_state = current state for now
cob_rl_agent.remember(state, action, reward, state, False)
training_updates += 1
# Perform COB RL training if enough experiences
@@ -1479,16 +1636,29 @@ class EnhancedRealtimeTrainingSystem:
# Moving averages
if len(prev_prices) >= 5:
ma5 = sum(prev_prices[-5:]) / 5
tech_features.append((current_price - ma5) / ma5)
# Prevent division by zero
if ma5 != 0:
tech_features.append((current_price - ma5) / ma5)
else:
tech_features.append(0.0)
if len(prev_prices) >= 10:
ma10 = sum(prev_prices[-10:]) / 10
tech_features.append((current_price - ma10) / ma10)
# Prevent division by zero
if ma10 != 0:
tech_features.append((current_price - ma10) / ma10)
else:
tech_features.append(0.0)
# Volatility measure
if len(prev_prices) >= 5:
volatility = np.std(prev_prices[-5:]) / np.mean(prev_prices[-5:])
tech_features.append(volatility)
price_mean = np.mean(prev_prices[-5:])
# Prevent division by zero
if price_mean != 0:
volatility = np.std(prev_prices[-5:]) / price_mean
tech_features.append(volatility)
else:
tech_features.append(0.0)
# Pad technical features to 200
while len(tech_features) < 200:
@@ -1670,6 +1840,14 @@ class EnhancedRealtimeTrainingSystem:
features_tensor = torch.from_numpy(features).float().to(device)
targets_tensor = torch.from_numpy(targets).long().to(device)
# FIXED: Move tensors to same device as model
device = next(model.parameters()).device
features_tensor = features_tensor.to(device)
targets_tensor = targets_tensor.to(device)
# Move criterion to same device as well
criterion = criterion.to(device)
# Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width)
# Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20)
# This depends on the actual CNN model architecture. Assuming a simple CNN that expects (batch, channels, height, width)
@@ -1700,6 +1878,7 @@ class EnhancedRealtimeTrainingSystem:
outputs = model(features_tensor)
<<<<<<< HEAD
# Extract logits from model output (model returns a dictionary)
if isinstance(outputs, dict):
logits = outputs['logits']
@@ -1713,6 +1892,19 @@ class EnhancedRealtimeTrainingSystem:
logger.error(f"CNN output is not a tensor: {type(logits)}")
return 0.0
=======
# FIXED: Handle case where model returns tuple (extract the logits)
if isinstance(outputs, tuple):
# Assume the first element is the main output (logits)
logits = outputs[0]
elif isinstance(outputs, dict):
# Handle dictionary output (get main prediction)
logits = outputs.get('logits', outputs.get('predictions', outputs.get('output', list(outputs.values())[0])))
else:
# Single tensor output
logits = outputs
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
loss = criterion(logits, targets_tensor)
loss.backward()
@@ -1721,8 +1913,122 @@ class EnhancedRealtimeTrainingSystem:
return loss.item()
except Exception as e:
logger.error(f"Error in CNN training: {e}")
logger.error(f"RT TRAINING: Error in CNN training: {e}")
return 1.0 # Return default loss value in case of error
def _sample_prioritized_experiences(self) -> List[Dict]:
"""Sample prioritized experiences for training"""
try:
experiences = []
# Sample from priority buffer first (high-priority experiences)
if self.priority_buffer:
priority_samples = min(len(self.priority_buffer), self.training_config['batch_size'] // 2)
priority_experiences = random.sample(list(self.priority_buffer), priority_samples)
experiences.extend(priority_experiences)
# Sample from regular experience buffer
if self.experience_buffer:
remaining_samples = self.training_config['batch_size'] - len(experiences)
if remaining_samples > 0:
regular_samples = min(len(self.experience_buffer), remaining_samples)
regular_experiences = random.sample(list(self.experience_buffer), regular_samples)
experiences.extend(regular_experiences)
# Convert experiences to DQN format
dqn_experiences = []
for exp in experiences:
# Create next state by shifting current state (simple approximation)
next_state = exp['state'].copy() if hasattr(exp['state'], 'copy') else exp['state']
# Simple reward based on recent market movement
reward = self._calculate_experience_reward(exp)
# Action mapping: 0=BUY, 1=SELL, 2=HOLD
action = self._determine_action_from_experience(exp)
dqn_exp = {
'state': exp['state'],
'action': action,
'reward': reward,
'next_state': next_state,
'done': False # Episodes don't really "end" in continuous trading
}
dqn_experiences.append(dqn_exp)
return dqn_experiences
except Exception as e:
logger.error(f"Error sampling prioritized experiences: {e}")
return []
def _calculate_experience_reward(self, experience: Dict) -> float:
"""Calculate reward for an experience"""
try:
# Simple reward based on technical indicators and market events
reward = 0.0
# Reward based on market events
if experience.get('market_events', 0) > 0:
reward += 0.1 # Bonus for learning from market events
# Reward based on technical indicators
tech_indicators = experience.get('technical_indicators', {})
if tech_indicators:
# Reward for strong momentum
momentum = tech_indicators.get('price_momentum', 0)
reward += np.tanh(momentum * 10) # Bounded reward
# Penalize high volatility
volatility = tech_indicators.get('volatility', 0)
reward -= min(volatility * 5, 0.2) # Penalty for high volatility
# Reward based on COB features
cob_features = experience.get('cob_features', [])
if cob_features and len(cob_features) > 0:
# Reward for strong order book imbalance
imbalance = cob_features[0] if len(cob_features) > 0 else 0
reward += abs(imbalance) * 0.1 # Reward for any imbalance signal
return max(-1.0, min(1.0, reward)) # Clamp to [-1, 1]
except Exception as e:
logger.debug(f"Error calculating experience reward: {e}")
return 0.0
def _determine_action_from_experience(self, experience: Dict) -> int:
"""Determine action from experience data"""
try:
# Use technical indicators to determine action
tech_indicators = experience.get('technical_indicators', {})
if tech_indicators:
momentum = tech_indicators.get('price_momentum', 0)
rsi = tech_indicators.get('rsi', 50)
# Simple logic based on momentum and RSI
if momentum > 0.005 and rsi < 70: # Upward momentum, not overbought
return 0 # BUY
elif momentum < -0.005 and rsi > 30: # Downward momentum, not oversold
return 1 # SELL
else:
return 2 # HOLD
# Fallback to COB-based action
cob_features = experience.get('cob_features', [])
if cob_features and len(cob_features) > 0:
imbalance = cob_features[0]
if imbalance > 0.1:
return 0 # BUY (bid imbalance)
elif imbalance < -0.1:
return 1 # SELL (ask imbalance)
return 2 # Default to HOLD
except Exception as e:
logger.debug(f"Error determining action from experience: {e}")
return 2 # Default to HOLD
def _perform_validation(self):
"""Perform validation to track model performance"""
@@ -2084,17 +2390,21 @@ class EnhancedRealtimeTrainingSystem:
def _generate_forward_dqn_prediction(self, symbol: str, current_time: float):
"""Generate a DQN prediction for future price movement"""
try:
# Get current market state (only historical data)
current_state = self._build_comprehensive_state()
# Get current market state with DQN-specific dimensions
target_dims = self._get_model_expected_dimensions('dqn')
current_state = self._build_comprehensive_state(target_dims)
current_price = self._get_current_price_from_data(symbol)
if current_price is None:
# SKIP prediction if price is invalid
if current_price is None or current_price <= 0:
logger.debug(f"Skipping DQN prediction for {symbol}: invalid price {current_price}")
return
# Use DQN model to predict action (if available)
if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
and self.orchestrator.rl_agent):
<<<<<<< HEAD
# Use RL agent to make prediction
current_state = self._get_dqn_state_features(symbol)
if current_state is None:
@@ -2112,6 +2422,28 @@ class EnhancedRealtimeTrainingSystem:
confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33
=======
# Get action from DQN agent
action = self.orchestrator.rl_agent.act(current_state, explore=False)
# Get Q-values by manually calling the model
q_values = self._get_dqn_q_values(current_state)
# Calculate confidence from Q-values
if q_values is not None and len(q_values) > 0:
# Convert to probabilities and get confidence
probs = torch.softmax(torch.tensor(q_values), dim=0).numpy()
confidence = float(max(probs))
q_values = q_values.tolist() if hasattr(q_values, 'tolist') else list(q_values)
else:
confidence = 0.33
q_values = [0.33, 0.33, 0.34] # Default uniform distribution
# Handle case where action is None (HOLD)
if action is None:
action = 2 # Map None to HOLD action
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
else:
# Fallback to technical analysis-based prediction
action, q_values, confidence = self._technical_analysis_prediction(symbol)
@@ -2138,8 +2470,8 @@ class EnhancedRealtimeTrainingSystem:
if symbol in self.pending_predictions:
self.pending_predictions[symbol].append(prediction)
# Add to recent predictions for display (only if confident enough)
if confidence > 0.4:
# Add to recent predictions for display (only if confident enough AND valid price)
if confidence > 0.4 and current_price > 0:
display_prediction = {
'timestamp': prediction_time,
'price': current_price,
@@ -2152,6 +2484,7 @@ class EnhancedRealtimeTrainingSystem:
self.last_prediction_time[symbol] = int(current_time)
<<<<<<< HEAD
# Robust action labeling
if action is None:
action_label = 'HOLD'
@@ -2163,10 +2496,46 @@ class EnhancedRealtimeTrainingSystem:
action_label = 'UNKNOWN'
logger.info(f"Forward DQN prediction: {symbol} action={action_label} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
=======
logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} price=${current_price:.2f} target={target_time.strftime('%H:%M:%S')} dims={len(current_state)}")
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
except Exception as e:
logger.error(f"Error generating forward DQN prediction: {e}")
def _get_dqn_q_values(self, state: np.ndarray) -> Optional[np.ndarray]:
"""Get Q-values from DQN agent without performing action selection"""
try:
if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
return None
rl_agent = self.orchestrator.rl_agent
# Convert state to tensor
if isinstance(state, np.ndarray):
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(rl_agent.device)
else:
state_tensor = state.unsqueeze(0).to(rl_agent.device)
# Get Q-values directly from policy network
with torch.no_grad():
policy_output = rl_agent.policy_net(state_tensor)
# Handle different output formats
if isinstance(policy_output, dict):
q_values = policy_output.get('q_values', policy_output.get('Q_values', list(policy_output.values())[0]))
elif isinstance(policy_output, tuple):
q_values = policy_output[0] # Assume first element is Q-values
else:
q_values = policy_output
# Convert to numpy
return q_values.cpu().data.numpy()[0]
except Exception as e:
logger.debug(f"Error getting DQN Q-values: {e}")
return None
def _generate_forward_cnn_prediction(self, symbol: str, current_time: float):
"""Generate a CNN prediction for future price direction"""
try:
@@ -2174,9 +2543,15 @@ class EnhancedRealtimeTrainingSystem:
current_price = self._get_current_price_from_data(symbol)
price_sequence = self._get_historical_price_sequence(symbol, periods=15)
if current_price is None or len(price_sequence) < 15:
# SKIP prediction if price is invalid
if current_price is None or current_price <= 0:
logger.debug(f"Skipping CNN prediction for {symbol}: invalid price {current_price}")
return
if len(price_sequence) < 15:
logger.debug(f"Skipping CNN prediction for {symbol}: insufficient data")
return
# Use CNN model to predict direction (if available)
if (self.orchestrator and hasattr(self.orchestrator, 'cnn_model')
and self.orchestrator.cnn_model):
@@ -2229,8 +2604,8 @@ class EnhancedRealtimeTrainingSystem:
if symbol in self.pending_predictions:
self.pending_predictions[symbol].append(prediction)
# Add to recent predictions for display (only if confident enough)
if confidence > 0.5:
# Add to recent predictions for display (only if confident enough AND valid prices)
if confidence > 0.5 and current_price > 0 and predicted_price > 0:
display_prediction = {
'timestamp': prediction_time,
'current_price': current_price,
@@ -2241,7 +2616,7 @@ class EnhancedRealtimeTrainingSystem:
if symbol in self.recent_cnn_predictions:
self.recent_cnn_predictions[symbol].append(display_prediction)
logger.info(f"Forward CNN prediction: {symbol} direction={['DOWN','SAME','UP'][direction]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
logger.info(f"Forward CNN prediction: {symbol} direction={['DOWN','SAME','UP'][direction]} confidence={confidence:.2f} price=${current_price:.2f} -> ${predicted_price:.2f} target={target_time.strftime('%H:%M:%S')}")
except Exception as e:
logger.error(f"Error generating forward CNN prediction: {e}")
@@ -2332,8 +2707,24 @@ class EnhancedRealtimeTrainingSystem:
def _get_current_price_from_data(self, symbol: str) -> Optional[float]:
"""Get current price from real-time data streams"""
try:
# First, try to get from data provider (most reliable)
if self.data_provider:
price = self.data_provider.get_current_price(symbol)
if price and price > 0:
return price
# Fallback to internal buffer
if len(self.real_time_data['ohlcv_1m']) > 0:
return self.real_time_data['ohlcv_1m'][-1]['close']
price = self.real_time_data['ohlcv_1m'][-1]['close']
if price and price > 0:
return price
# Fallback to orchestrator price
if self.orchestrator:
price = self.orchestrator._get_current_price(symbol)
if price and price > 0:
return price
return None
except Exception as e:
logger.debug(f"Error getting current price: {e}")
@@ -2428,4 +2819,56 @@ class EnhancedRealtimeTrainingSystem:
except Exception as e:
logger.debug(f"Error estimating price change: {e}")
return 0.0
return 0.0 d
ef _save_model_checkpoint(self, model_name: str, model_obj, loss: float):
"""
Save model checkpoint after training if performance improved
This is CRITICAL for preserving training progress across restarts.
"""
try:
if not CHECKPOINT_MANAGER_AVAILABLE:
return
# Get checkpoint manager
checkpoint_manager = get_checkpoint_manager()
if not checkpoint_manager:
return
# Prepare performance metrics
performance_metrics = {
'loss': loss,
'training_samples': len(self.experience_buffer),
'timestamp': datetime.now().isoformat()
}
# Prepare training metadata
training_metadata = {
'timestamp': datetime.now().isoformat(),
'training_iteration': self.training_iteration,
'model_type': model_name
}
# Determine model type based on model name
model_type = model_name
if 'dqn' in model_name.lower():
model_type = 'dqn'
elif 'cnn' in model_name.lower():
model_type = 'cnn'
elif 'cob' in model_name.lower():
model_type = 'cob_rl'
# Save checkpoint
checkpoint_path = save_checkpoint(
model=model_obj,
model_name=model_name,
model_type=model_type,
performance_metrics=performance_metrics,
training_metadata=training_metadata
)
if checkpoint_path:
logger.info(f"💾 Saved checkpoint for {model_name}: {checkpoint_path} (loss: {loss:.4f})")
except Exception as e:
logger.error(f"Error saving checkpoint for {model_name}: {e}")

View File

@@ -0,0 +1,24 @@
Example request using curl (note: JSON string values cannot contain raw newlines, so in a real request the multi-line table below must be collapsed or escaped with \n):
curl http://localhost:1234/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "deepseek-r1-0528-qwen3-8b",
"messages": [
{ "role": "system", "content": "what will be the next row in this sequence:
symbol MAIN SYMBOL (ETH) REF1 (BTC) REF2 (SPX) REF3 (SOL)
timeframe 1s 1m 1h 1d 1s 1s 1s
datapoint O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp
2025-01-15T10:00:00Z 3421.5 3421.75 3421.25 3421.6 125.4 2025-01-15T10:00:00Z 3422.1 3424.8 3420.5 3423.25 1245.7 2025-01-15T10:00:00Z 3420 3428.5 3418.75 3425.1 12847.2 2025-01-15T10:00:00Z 3415.25 3435.6 3410.8 3430.4 145238.6 2025-01-15T10:00:00Z 97850.2 97852.4 97848.1 97851.3 8.7 2025-01-15T10:00:00Z 5925.4 5926.1 5924.8 5925.7 0 2025-01-15T10:00:00Z 191.22 191.45 191.08 191.35 1247.3
2025-01-15T10:00:01Z 3421.6 3421.85 3421.45 3421.75 98.2 2025-01-15T10:01:00Z 3423.25 3425.9 3421.8 3424.6 1189.3 2025-01-15T11:00:00Z 3425.1 3432.2 3422.4 3429.8 11960.5 2025-01-16T10:00:00Z 3430.4 3445.2 3425.15 3440.85 138947.1 2025-01-15T10:00:01Z 97851.3 97853.8 97849.5 97852.9 9.1 2025-01-15T10:00:01Z 5925.7 5926.3 5925.2 5925.9 0 2025-01-15T10:00:01Z 191.35 191.58 191.15 191.48 1156.7
2025-01-15T10:00:02Z 3421.75 3421.95 3421.55 3421.8 110.6 2025-01-15T10:02:00Z 3424.6 3427.15 3423.4 3425.9 1356.8 2025-01-15T12:00:00Z 3429.8 3436.7 3427.2 3434.5 13205.9 2025-01-17T10:00:00Z 3440.85 3455.3 3438.9 3450.75 142568.3 2025-01-15T10:00:02Z 97852.9 97855.2 97850.7 97854.6 7.9 2025-01-15T10:00:02Z 5925.9 5926.5 5925.4 5926.1 0 2025-01-15T10:00:02Z 191.48 191.72 191.28 191.61 1298.4
2025-01-15T10:00:03Z 3421.8 3422.05 3421.65 3421.9 87.3 2025-01-15T10:03:00Z 3425.9 3428.4 3424.2 3427.1 1423.5 2025-01-15T13:00:00Z 3434.5 3441.8 3432.1 3438.2 14087.6 2025-01-18T10:00:00Z 3450.75 3465.4 3448.6 3460.2 149825.7 2025-01-15T10:00:03Z 97854.6 97857.1 97852.3 97856.8 8.4 2025-01-15T10:00:03Z 5926.1 5926.7 5925.6 5926.3 0 2025-01-15T10:00:03Z 191.61 191.85 191.42 191.74 1187.9
2025-01-15T10:00:04Z 3421.9 3422.15 3421.75 3422.0 134.7 2025-01-15T10:04:00Z 3427.1 3429.6 3425.8 3428.3 1298.2 2025-01-15T14:00:00Z 3438.2 3445.6 3436.4 3442.1 12734.8 2025-01-19T10:00:00Z 3460.2 3475.8 3457.4 3470.6 156742.4 2025-01-15T10:00:04Z 97856.8 97859.4 97854.9 97858.2 9.2 2025-01-15T10:00:04Z 5926.3 5926.9 5925.8 5926.5 0 2025-01-15T10:00:04Z 191.74 191.98 191.55 191.87 1342.6
2025-01-15T10:00:05Z 3422.0 3422.25 3421.85 3422.1 156.8 2025-01-15T10:05:00Z 3428.3 3430.8 3426.9 3429.5 1467.9 2025-01-15T15:00:00Z 3442.1 3449.3 3440.7 3446.8 11823.4 2025-01-20T10:00:00Z 3470.6 3485.2 3467.9 3480.1 163456.8 2025-01-15T10:00:05Z 97858.2 97860.7 97856.4 97859.8 8.8 2025-01-15T10:00:05Z 5926.5 5927.1 5926.0 5926.7 0 2025-01-15T10:00:05Z 191.87 192.11 191.68 192.0 1278.5
" },
{ "role": "user", "content": "2025-01-15T10:00:06Z" }
],
"temperature": 0.7,
"max_tokens": -1,
"stream": false
}'

View File

@@ -0,0 +1,16 @@
<!-- example text dump: -->
symbol MAIN SYMBOL (ETH) REF1 (BTC) REF2 (SPX) REF3 (SOL)
timeframe 1s 1m 1h 1d 1s 1s 1s
datapoint O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp
2025-01-15T10:00:00Z 3421.5 3421.75 3421.25 3421.6 125.4 2025-01-15T10:00:00Z 3422.1 3424.8 3420.5 3423.25 1245.7 2025-01-15T10:00:00Z 3420 3428.5 3418.75 3425.1 12847.2 2025-01-15T10:00:00Z 3415.25 3435.6 3410.8 3430.4 145238.6 2025-01-15T10:00:00Z 97850.2 97852.4 97848.1 97851.3 8.7 2025-01-15T10:00:00Z 5925.4 5926.1 5924.8 5925.7 0 2025-01-15T10:00:00Z 191.22 191.45 191.08 191.35 1247.3
2025-01-15T10:00:01Z 3421.6 3421.85 3421.45 3421.75 98.2 2025-01-15T10:01:00Z 3423.25 3425.9 3421.8 3424.6 1189.3 2025-01-15T11:00:00Z 3425.1 3432.2 3422.4 3429.8 11960.5 2025-01-16T10:00:00Z 3430.4 3445.2 3425.15 3440.85 138947.1 2025-01-15T10:00:01Z 97851.3 97853.8 97849.5 97852.9 9.1 2025-01-15T10:00:01Z 5925.7 5926.3 5925.2 5925.9 0 2025-01-15T10:00:01Z 191.35 191.58 191.15 191.48 1156.7
2025-01-15T10:00:02Z 3421.75 3421.95 3421.55 3421.8 110.6 2025-01-15T10:02:00Z 3424.6 3427.15 3423.4 3425.9 1356.8 2025-01-15T12:00:00Z 3429.8 3436.7 3427.2 3434.5 13205.9 2025-01-17T10:00:00Z 3440.85 3455.3 3438.9 3450.75 142568.3 2025-01-15T10:00:02Z 97852.9 97855.2 97850.7 97854.6 7.9 2025-01-15T10:00:02Z 5925.9 5926.5 5925.4 5926.1 0 2025-01-15T10:00:02Z 191.48 191.72 191.28 191.61 1298.4
2025-01-15T10:00:03Z 3421.8 3422.05 3421.65 3421.9 87.3 2025-01-15T10:03:00Z 3425.9 3428.4 3424.2 3427.1 1423.5 2025-01-15T13:00:00Z 3434.5 3441.8 3432.1 3438.2 14087.6 2025-01-18T10:00:00Z 3450.75 3465.4 3448.6 3460.2 149825.7 2025-01-15T10:00:03Z 97854.6 97857.1 97852.3 97856.8 8.4 2025-01-15T10:00:03Z 5926.1 5926.7 5925.6 5926.3 0 2025-01-15T10:00:03Z 191.61 191.85 191.42 191.74 1187.9
2025-01-15T10:00:04Z 3421.9 3422.15 3421.75 3422.0 134.7 2025-01-15T10:04:00Z 3427.1 3429.6 3425.8 3428.3 1298.2 2025-01-15T14:00:00Z 3438.2 3445.6 3436.4 3442.1 12734.8 2025-01-19T10:00:00Z 3460.2 3475.8 3457.4 3470.6 156742.4 2025-01-15T10:00:04Z 97856.8 97859.4 97854.9 97858.2 9.2 2025-01-15T10:00:04Z 5926.3 5926.9 5925.8 5926.5 0 2025-01-15T10:00:04Z 191.74 191.98 191.55 191.87 1342.6
2025-01-15T10:00:05Z 3422.0 3422.25 3421.85 3422.1 156.8 2025-01-15T10:05:00Z 3428.3 3430.8 3426.9 3429.5 1467.9 2025-01-15T15:00:00Z 3442.1 3449.3 3440.7 3446.8 11823.4 2025-01-20T10:00:00Z 3470.6 3485.2 3467.9 3480.1 163456.8 2025-01-15T10:00:05Z 97858.2 97860.7 97856.4 97859.8 8.8 2025-01-15T10:00:05Z 5926.5 5927.1 5926.0 5926.7 0 2025-01-15T10:00:05Z 191.87 192.11 191.68 192.0 1278.5

View File

@@ -0,0 +1,4 @@
symbol MAIN SYMBOL (ETH) REF1 (BTC) REF2 (SPX) REF3 (SOL)
timeframe 1s 1m 1h 1d 1s 1s 1s
datapoint O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp
2025-08-26T21:29:44Z 0 0 0 0 0 2025-08-26T21:29:44Z 0 0 0 0 0 2025-08-26T21:29:44Z 0 0 0 0 0 2025-08-26T21:29:44Z 0 0 0 0 0 2025-08-26T21:29:44Z 0 0 0 0 0 2025-08-26T21:29:44Z 5500.00 5520.00 5495.00 5510.00 1000000.0 2025-08-26T21:29:44Z 0 0 0 0 0 2025-08-26T21:29:44Z

View File

@@ -1,13 +0,0 @@
"""
Neural Network Utilities
======================
This package contains utility functions and classes used in the neural network trading system:
- Data Interface: Connects to realtime trading data and processes it for the neural network models
"""
from .data_interface import DataInterface
from .trading_env import TradingEnvironment
from .signal_interpreter import SignalInterpreter
__all__ = ['DataInterface', 'TradingEnvironment', 'SignalInterpreter']

View File

@@ -209,8 +209,8 @@ class DataInterface:
curr_close = data[window_size-1:-1, 3]
price_changes = (next_close - curr_close) / curr_close
# Define thresholds for price movement classification
threshold = 0.0005 # 0.05% threshold - smaller to encourage more signals
# Define thresholds for price movement classification
threshold = 0.001 # 0.10% threshold - prefer bigger moves to beat fees
y = np.zeros(len(price_changes), dtype=int)
y[price_changes > threshold] = 2 # Up
y[price_changes < -threshold] = 0 # Down

View File

@@ -1,123 +0,0 @@
"""
Enhanced Data Interface with additional NN trading parameters
"""
from typing import List, Optional, Tuple
import numpy as np
import pandas as pd
from datetime import datetime
from .data_interface import DataInterface
class MultiDataInterface(DataInterface):
    """
    Enhanced data interface that supports window_size and output_size parameters
    for neural network trading models.
    """

    def __init__(self, symbol: str,
                 timeframes: List[str],
                 window_size: int = 20,
                 output_size: int = 3,
                 data_dir: str = "NN/data"):
        """
        Initialize with window_size and output_size for NN predictions.

        Args:
            symbol: Trading pair symbol (e.g. "BTC/USDT").
            timeframes: Timeframes to load; the first entry is treated as primary.
            window_size: Number of candles per input sequence.
            output_size: 1 for binary up/down labels, 3 for buy/hold/sell labels.
            data_dir: Directory holding cached historical data.
        """
        super().__init__(symbol, timeframes, data_dir)
        self.window_size = window_size
        self.output_size = output_size
        self.scalers = {}  # Store scalers for each timeframe
        self.min_window_threshold = 100  # Minimum candles needed for training

    def get_feature_count(self) -> int:
        """
        Get number of features (OHLCV) for NN input.
        """
        return 5  # open, high, low, close, volume

    def prepare_training_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Prepare training data with windowed sequences.

        Returns:
            (X_train, y_train, X_val, y_val) where X arrays have shape
            (n_samples, window_size, 5) and y arrays hold integer class labels.

        Raises:
            ValueError: if insufficient history is available or output_size is
                not 1 or 3.
        """
        # Get historical data for primary timeframe
        primary_tf = self.timeframes[0]
        df = self.get_historical_data(timeframe=primary_tf,
                                      n_candles=self.min_window_threshold + 1000)
        if df is None or len(df) < self.min_window_threshold:
            raise ValueError(f"Insufficient data for training. Need at least {self.min_window_threshold} candles")

        # Prepare OHLCV sequences
        ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values

        # Create sequences and labels
        X = []
        y = []
        for i in range(len(ohlcv) - self.window_size - self.output_size):
            # Input sequence
            X.append(ohlcv[i:i + self.window_size])

            # Output target (price movement direction), measured relative to
            # the last close of the input window so thresholds are fractional.
            #
            # BUG FIX vs. original: the original (a) indexed np.diff() of a
            # length-1 array when output_size == 1 (IndexError), and (b)
            # compared an *absolute* price difference against 0.002, which is
            # meant as a percentage threshold ("significant" move).
            prev_close = ohlcv[i + self.window_size - 1, 3]
            future_closes = ohlcv[i + self.window_size:i + self.window_size + self.output_size, 3]
            rel_change = (future_closes[0] - prev_close) / prev_close

            if self.output_size == 1:
                # Binary classification (up/down)
                label = 1 if rel_change > 0 else 0
            elif self.output_size == 3:
                # 3-class classification (buy/hold/sell)
                if rel_change > 0.002:  # Significant rise (> 0.2%)
                    label = 0  # Buy
                elif rel_change < -0.002:  # Significant drop (< -0.2%)
                    label = 2  # Sell
                else:
                    label = 1  # Hold
            else:
                raise ValueError(f"Unsupported output_size: {self.output_size}")
            y.append(label)

        # Convert to numpy arrays
        X = np.array(X)
        y = np.array(y)

        # Split into train/validation (80/20), preserving time order
        split_idx = int(0.8 * len(X))
        X_train, y_train = X[:split_idx], y[:split_idx]
        X_val, y_val = X[split_idx:], y[split_idx:]
        return X_train, y_train, X_val, y_val

    def prepare_prediction_data(self) -> np.ndarray:
        """Prepare the most recent window of candles for prediction.

        Returns:
            Array of shape (1, window_size, 5) — batch dimension included.

        Raises:
            ValueError: if fewer than window_size candles are available.
        """
        primary_tf = self.timeframes[0]
        df = self.get_historical_data(timeframe=primary_tf,
                                      n_candles=self.window_size,
                                      use_cache=False)
        if df is None or len(df) < self.window_size:
            raise ValueError(f"Need at least {self.window_size} candles for prediction")
        ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values[-self.window_size:]
        return np.array([ohlcv])  # Add batch dimension

    def process_predictions(self, predictions: np.ndarray):
        """Convert prediction probabilities to trading signals.

        Args:
            predictions: Iterable of per-sample probability vectors.

        Returns:
            List of dicts with 'action' (BUY/HOLD/SELL), 'confidence' (0-1)
            and an ISO 'timestamp'.
        """
        signals = []
        for pred in predictions:
            if self.output_size == 1:
                signal = "BUY" if pred[0] > 0.5 else "SELL"
                confidence = np.abs(pred[0] - 0.5) * 2  # Convert to 0-1 scale
            elif self.output_size == 3:
                action_idx = np.argmax(pred)
                signal = ["BUY", "HOLD", "SELL"][action_idx]
                confidence = pred[action_idx]
            else:
                # Unknown configuration: emit a neutral signal rather than raise
                signal = "HOLD"
                confidence = 0.0
            signals.append({
                'action': signal,
                'confidence': confidence,
                'timestamp': datetime.now().isoformat()
            })
        return signals

View File

@@ -1,364 +0,0 @@
"""
Realtime Analyzer for Neural Network Trading System
This module implements real-time analysis of market data using trained neural network models.
"""
import logging
import time
import numpy as np
from threading import Thread
from queue import Queue
from datetime import datetime
import asyncio
import websockets
import json
import os
import pandas as pd
from collections import deque
logger = logging.getLogger(__name__)
class RealtimeAnalyzer:
    """
    Handles real-time analysis of market data using trained neural network models.

    Features:
    - Connects to real-time data sources (websockets)
    - Processes tick data into multiple timeframes (1s, 1m, 1h, 1d)
    - Uses trained models to analyze all timeframes
    - Generates trading signals
    - Manages risk and position sizing
    - Logs all trading decisions
    """

    def __init__(self, data_interface, model, symbol="BTC/USDT", timeframes=None):
        """
        Initialize the realtime analyzer.

        Args:
            data_interface (DataInterface): Preconfigured data interface
            model: Trained neural network model
            symbol (str): Trading pair symbol
            timeframes (list): List of timeframes to monitor (default: ['1s', '1m', '1h', '1d'])
        """
        self.data_interface = data_interface
        self.model = model
        self.symbol = symbol
        self.timeframes = timeframes or ['1s', '1m', '1h', '1d']
        self.running = False
        # Queue bridging the websocket thread (producer) and the processing
        # thread (consumer).
        self.data_queue = Queue()
        self.prediction_interval = 10  # Seconds between predictions
        # Binance trade-stream URL; "BTC/USDT" -> "btcusdt@trade"
        self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
        self.ws = None
        self.tick_storage = deque(maxlen=10000)  # Store up to 10,000 ticks
        # Rolling per-timeframe candle buffers built from ticks.
        self.candle_cache = {
            '1s': deque(maxlen=5000),
            '1m': deque(maxlen=5000),
            '1h': deque(maxlen=5000),
            '1d': deque(maxlen=5000)
        }
        logger.info(f"RealtimeAnalyzer initialized for {symbol} with timeframes: {self.timeframes}")

    def start(self):
        """Start the realtime analysis process.

        Spawns three daemon threads: websocket receiver, tick-to-candle
        processor, and the periodic model-analysis loop.
        """
        if self.running:
            logger.warning("Realtime analyzer already running")
            return
        self.running = True
        # Start WebSocket connection thread
        self.ws_thread = Thread(target=self._run_websocket, daemon=True)
        self.ws_thread.start()
        # Start data processing thread
        self.processing_thread = Thread(target=self._process_data, daemon=True)
        self.processing_thread.start()
        # Start analysis thread
        self.analysis_thread = Thread(target=self._analyze_data, daemon=True)
        self.analysis_thread.start()
        logger.info("Realtime analysis started")

    def stop(self):
        """Stop the realtime analysis process.

        Clears the running flag, closes the websocket and joins all worker
        threads with short timeouts.
        """
        self.running = False
        if self.ws:
            # NOTE(review): asyncio.run() creates a new event loop; closing a
            # connection owned by the websocket thread's loop from here may
            # raise — confirm behavior under the installed websockets version.
            asyncio.run(self.ws.close())
        if hasattr(self, 'ws_thread'):
            self.ws_thread.join(timeout=1)
        if hasattr(self, 'processing_thread'):
            self.processing_thread.join(timeout=1)
        if hasattr(self, 'analysis_thread'):
            self.analysis_thread.join(timeout=1)
        logger.info("Realtime analysis stopped")

    def _run_websocket(self):
        """Thread function for running WebSocket connection."""
        # Each thread needs its own asyncio event loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self._connect_websocket())

    async def _connect_websocket(self):
        """Connect to WebSocket and receive data.

        Reconnects with a 5-second backoff for as long as self.running is set.
        """
        while self.running:
            try:
                logger.info(f"Connecting to WebSocket: {self.ws_url}")
                async with websockets.connect(self.ws_url) as ws:
                    self.ws = ws
                    logger.info("WebSocket connected")
                    while self.running:
                        try:
                            message = await ws.recv()
                            data = json.loads(message)
                            # Binance trade events carry event type 'e' == 'trade'.
                            if 'e' in data and data['e'] == 'trade':
                                tick = {
                                    'timestamp': data['T'],
                                    'price': float(data['p']),
                                    'volume': float(data['q']),
                                    'symbol': self.symbol
                                }
                                self.tick_storage.append(tick)
                                self.data_queue.put(tick)
                        except websockets.exceptions.ConnectionClosed:
                            logger.warning("WebSocket connection closed")
                            break
                        except Exception as e:
                            logger.error(f"Error receiving WebSocket message: {str(e)}")
                            time.sleep(1)
            except Exception as e:
                logger.error(f"WebSocket connection error: {str(e)}")
                time.sleep(5)  # Wait before reconnecting

    def _process_data(self):
        """Process incoming tick data into candles for all timeframes."""
        logger.info("Starting data processing thread")
        while self.running:
            try:
                # Process any new ticks
                while not self.data_queue.empty():
                    tick = self.data_queue.get()
                    # Convert timestamp to datetime (tick timestamps are in ms)
                    timestamp = datetime.fromtimestamp(tick['timestamp'] / 1000)
                    # Process for each timeframe
                    for timeframe in self.timeframes:
                        interval = self._get_interval_seconds(timeframe)
                        if interval is None:
                            continue
                        # Round timestamp down to the nearest candle interval (ms)
                        candle_ts = int(tick['timestamp'] // (interval * 1000)) * (interval * 1000)
                        # Get or create candle for this timeframe
                        if not self.candle_cache[timeframe]:
                            # First candle for this timeframe
                            candle = {
                                'timestamp': candle_ts,
                                'open': tick['price'],
                                'high': tick['price'],
                                'low': tick['price'],
                                'close': tick['price'],
                                'volume': tick['volume']
                            }
                            self.candle_cache[timeframe].append(candle)
                        else:
                            # Update existing candle
                            last_candle = self.candle_cache[timeframe][-1]
                            if last_candle['timestamp'] == candle_ts:
                                # Update current candle
                                last_candle['high'] = max(last_candle['high'], tick['price'])
                                last_candle['low'] = min(last_candle['low'], tick['price'])
                                last_candle['close'] = tick['price']
                                last_candle['volume'] += tick['volume']
                            else:
                                # New candle
                                candle = {
                                    'timestamp': candle_ts,
                                    'open': tick['price'],
                                    'high': tick['price'],
                                    'low': tick['price'],
                                    'close': tick['price'],
                                    'volume': tick['volume']
                                }
                                self.candle_cache[timeframe].append(candle)
                time.sleep(0.1)
            except Exception as e:
                logger.error(f"Error in data processing: {str(e)}")
                time.sleep(1)

    def _get_interval_seconds(self, timeframe):
        """Convert timeframe string to seconds.

        Returns None for unknown timeframes, which callers skip.
        """
        intervals = {
            '1s': 1,
            '1m': 60,
            '1h': 3600,
            '1d': 86400
        }
        return intervals.get(timeframe)

    def _analyze_data(self):
        """Thread function for analyzing data and generating signals.

        Every prediction_interval seconds, assembles a normalized OHLCV window
        per timeframe and feeds it to the model.
        """
        logger.info("Starting analysis thread")
        last_prediction_time = 0
        while self.running:
            try:
                current_time = time.time()
                # Only make predictions at the specified interval
                if current_time - last_prediction_time < self.prediction_interval:
                    time.sleep(0.1)
                    continue
                # Prepare input data from all timeframes
                input_data = {}
                valid = True
                for timeframe in self.timeframes:
                    if not self.candle_cache[timeframe]:
                        logger.warning(f"No data available for timeframe {timeframe}")
                        valid = False
                        break
                    # Get last N candles for this timeframe
                    candles = list(self.candle_cache[timeframe])[-self.data_interface.window_size:]
                    # Convert to numpy array
                    ohlcv = np.array([
                        [c['open'], c['high'], c['low'], c['close'], c['volume']]
                        for c in candles
                    ])
                    # Normalize data (per-column z-score; epsilon avoids /0)
                    ohlcv_normalized = (ohlcv - ohlcv.mean(axis=0)) / (ohlcv.std(axis=0) + 1e-8)
                    input_data[timeframe] = ohlcv_normalized
                if not valid:
                    time.sleep(0.1)
                    continue
                # Make prediction using the model
                try:
                    # NOTE(review): assumes the model accepts a dict keyed by
                    # timeframe — confirm against the model's predict() API.
                    prediction = self.model.predict(input_data)
                    # Get latest timestamp from 1s timeframe
                    latest_ts = self.candle_cache['1s'][-1]['timestamp'] if self.candle_cache['1s'] else int(time.time() * 1000)
                    # Process prediction
                    self._process_prediction(
                        prediction=prediction,
                        timeframe='multi',
                        timestamp=latest_ts
                    )
                    last_prediction_time = current_time
                except Exception as e:
                    logger.error(f"Error making prediction: {str(e)}")
                time.sleep(0.1)
            except Exception as e:
                logger.error(f"Error in analysis: {str(e)}")
                time.sleep(1)

    def _process_prediction(self, prediction, timeframe, timestamp):
        """
        Process model prediction and generate trading signals.

        Args:
            prediction: Model prediction output
            timeframe (str): Timeframe the prediction is for ('multi' for combined)
            timestamp: Timestamp of the prediction (ms)
        """
        # Convert prediction to trading signal
        signal, confidence = self._prediction_to_signal(prediction)
        # Convert timestamp to datetime
        try:
            dt = datetime.fromtimestamp(timestamp / 1000)
        except:
            # Fall back to "now" on malformed timestamps
            dt = datetime.now()
        # Log the signal with all timeframes
        logger.info(
            f"Signal generated - Timeframes: {', '.join(self.timeframes)}, "
            f"Timestamp: {dt}, "
            f"Signal: {signal} (Confidence: {confidence:.2f})"
        )
        # In a real implementation, we would execute trades here
        # For now, we'll just log the signals

    def _prediction_to_signal(self, prediction):
        """
        Convert model prediction to trading signal and confidence.

        Args:
            prediction: Model prediction output (can be dict for multi-timeframe)

        Returns:
            tuple: (signal, confidence) where signal is BUY/SELL/HOLD,
                   confidence is probability (0-1)
        """
        if isinstance(prediction, dict):
            # Multi-timeframe prediction - combine signals
            signals = []
            confidences = []
            for tf, pred in prediction.items():
                if len(pred.shape) == 1:
                    # Binary classification
                    signal = "BUY" if pred[0] > 0.5 else "SELL"
                    confidence = pred[0] if signal == "BUY" else 1 - pred[0]
                else:
                    # Multi-class
                    class_idx = np.argmax(pred)
                    signal = ["SELL", "HOLD", "BUY"][class_idx]
                    confidence = pred[class_idx]
                signals.append(signal)
                confidences.append(confidence)
            # Simple voting system - count BUY/SELL signals
            buy_count = signals.count("BUY")
            sell_count = signals.count("SELL")
            if buy_count > sell_count:
                final_signal = "BUY"
                final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "BUY"])
            elif sell_count > buy_count:
                final_signal = "SELL"
                final_confidence = np.mean([c for s, c in zip(signals, confidences) if s == "SELL"])
            else:
                final_signal = "HOLD"
                final_confidence = np.mean(confidences)
            return final_signal, final_confidence
        else:
            # Single prediction
            if len(prediction.shape) == 1:
                # Binary classification
                signal = "BUY" if prediction[0] > 0.5 else "SELL"
                confidence = prediction[0] if signal == "BUY" else 1 - prediction[0]
            else:
                # Multi-class
                class_idx = np.argmax(prediction)
                signal = ["SELL", "HOLD", "BUY"][class_idx]
                confidence = prediction[class_idx]
            return signal, confidence

View File

@@ -50,7 +50,7 @@ class SignalInterpreter:
self.oscillation_filter_enabled = self.config.get('oscillation_filter_enabled', False) # Disable oscillation filter by default
# Sensitivity parameters
self.min_price_movement = self.config.get('min_price_movement', 0.0001) # Lower price movement threshold
self.min_price_movement = self.config.get('min_price_movement', 0.001) # Align with deadzone; prefer bigger moves
self.hold_cooldown = self.config.get('hold_cooldown', 1) # Shorter hold cooldown
self.consecutive_signals_required = self.config.get('consecutive_signals_required', 1) # Require only one signal

View File

@@ -20,8 +20,10 @@ class TradingEnvironment(gym.Env):
window_size: int = 20,
risk_aversion: float = 0.2, # Controls how much to penalize volatility
price_scaling: str = 'zscore', # 'zscore', 'minmax', or 'raw'
reward_scaling: float = 10.0, # Scale factor for rewards
episode_penalty: float = 0.1): # Penalty for active positions at end of episode
reward_scaling: float = 10.0, # Scale factor for rewards
episode_penalty: float = 0.1, # Penalty for active positions at end of episode
min_profit_after_fees: float = 0.0005 # Deadzone: require >= 5 bps beyond fees
):
super(TradingEnvironment, self).__init__()
self.data = data
@@ -33,6 +35,7 @@ class TradingEnvironment(gym.Env):
self.price_scaling = price_scaling
self.reward_scaling = reward_scaling
self.episode_penalty = episode_penalty
self.min_profit_after_fees = max(0.0, float(min_profit_after_fees))
# Preprocess data if needed
self._preprocess_data()
@@ -177,8 +180,14 @@ class TradingEnvironment(gym.Env):
price_diff = current_price - self.entry_price
pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees
# Adjust reward based on PnL and risk
reward = pnl * self.reward_scaling
# Deadzone to discourage micro profits
if pnl > 0 and pnl < self.min_profit_after_fees:
reward = -self.fee_rate
elif pnl < 0 and abs(pnl) < self.min_profit_after_fees:
reward = pnl * self.reward_scaling * 0.5
else:
effective_pnl = pnl - (self.min_profit_after_fees if pnl > 0 else 0.0)
reward = effective_pnl * self.reward_scaling
# Track trade performance
self.total_trades += 1
@@ -212,8 +221,12 @@ class TradingEnvironment(gym.Env):
price_diff = current_price - self.entry_price
unrealized_pnl = price_diff / self.entry_price
# Small reward/penalty based on unrealized P&L
reward = unrealized_pnl * 0.05 # Scale down to encourage holding good positions
# Encourage holding only if unrealized edge exceeds deadzone
unrealized_edge = unrealized_pnl
if abs(unrealized_edge) >= self.min_profit_after_fees:
reward = unrealized_edge * (self.reward_scaling * 0.2)
else:
reward = -0.0002
elif self.position < 0: # Short position
if action == 0: # BUY (close short)
@@ -221,8 +234,13 @@ class TradingEnvironment(gym.Env):
price_diff = self.entry_price - current_price
pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees
# Adjust reward based on PnL and risk
reward = pnl * self.reward_scaling
if pnl > 0 and pnl < self.min_profit_after_fees:
reward = -self.fee_rate
elif pnl < 0 and abs(pnl) < self.min_profit_after_fees:
reward = pnl * self.reward_scaling * 0.5
else:
effective_pnl = pnl - (self.min_profit_after_fees if pnl > 0 else 0.0)
reward = effective_pnl * self.reward_scaling
# Track trade performance
self.total_trades += 1
@@ -256,8 +274,12 @@ class TradingEnvironment(gym.Env):
price_diff = self.entry_price - current_price
unrealized_pnl = price_diff / self.entry_price
# Small reward/penalty based on unrealized P&L
reward = unrealized_pnl * 0.05 # Scale down to encourage holding good positions
# Encourage holding only if unrealized edge exceeds deadzone
unrealized_edge = unrealized_pnl
if abs(unrealized_edge) >= self.min_profit_after_fees:
reward = unrealized_edge * (self.reward_scaling * 0.2)
else:
reward = -0.0002
# Record the action
self.actions_taken.append(action)