showing trades on realtime chart - chart broken
This commit is contained in:
parent
0ad9484d56
commit
a46b2c74f8
5	NN/exchanges/__init__.py	Normal file
@@ -0,0 +1,5 @@
from .exchange_interface import ExchangeInterface
from .mexc_interface import MEXCInterface
from .binance_interface import BinanceInterface

__all__ = ['ExchangeInterface', 'MEXCInterface', 'BinanceInterface']
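The package exposes the interface classes directly, so callers do not need to know the module layout. A minimal usage sketch (not part of this commit, and assuming the NN package is on the import path):

# Example sketch (not part of this commit): importing the exchange interfaces.
from NN.exchanges import BinanceInterface, MEXCInterface

# Both concrete interfaces default to test_mode=True.
exchange = BinanceInterface(api_key=None, api_secret=None, test_mode=True)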
276	NN/exchanges/binance_interface.py	Normal file
@@ -0,0 +1,276 @@
import logging
import time
from typing import Dict, Any, List, Optional
import requests
import hmac
import hashlib
from urllib.parse import urlencode

from .exchange_interface import ExchangeInterface

logger = logging.getLogger(__name__)

class BinanceInterface(ExchangeInterface):
    """Binance Exchange API Interface"""

    def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True):
        """Initialize Binance exchange interface.

        Args:
            api_key: Binance API key
            api_secret: Binance API secret
            test_mode: If True, use testnet environment
        """
        super().__init__(api_key, api_secret, test_mode)

        # Use testnet URLs if in test mode
        if test_mode:
            self.base_url = "https://testnet.binance.vision"
        else:
            self.base_url = "https://api.binance.com"

        self.api_version = "v3"

    def connect(self) -> bool:
        """Connect to Binance API. This is a no-op for REST API."""
        if not self.api_key or not self.api_secret:
            logger.warning("Binance API credentials not provided. Running in read-only mode.")
            return False

        try:
            # Test connection by pinging server and checking account info
            ping_result = self._send_public_request('GET', 'ping')

            if self.api_key and self.api_secret:
                # Check account connectivity
                self.get_account_info()

            logger.info(f"Successfully connected to Binance API ({'testnet' if self.test_mode else 'live'})")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to Binance API: {str(e)}")
            return False

    def _generate_signature(self, params: Dict[str, Any]) -> str:
        """Generate signature for authenticated requests."""
        query_string = urlencode(params)
        signature = hmac.new(
            self.api_secret.encode('utf-8'),
            query_string.encode('utf-8'),
            hashlib.sha256
        ).hexdigest()
        return signature

    def _send_public_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
        """Send public request to Binance API."""
        url = f"{self.base_url}/api/{self.api_version}/{endpoint}"

        try:
            if method.upper() == 'GET':
                response = requests.get(url, params=params)
            else:
                response = requests.post(url, json=params)

            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error in public request to {endpoint}: {str(e)}")
            raise

    def _send_private_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
        """Send private/authenticated request to Binance API."""
        if not self.api_key or not self.api_secret:
            raise ValueError("API key and secret are required for private requests")

        if params is None:
            params = {}

        # Add timestamp
        params['timestamp'] = int(time.time() * 1000)

        # Generate signature
        signature = self._generate_signature(params)
        params['signature'] = signature

        # Set headers
        headers = {
            'X-MBX-APIKEY': self.api_key
        }

        url = f"{self.base_url}/api/{self.api_version}/{endpoint}"

        try:
            if method.upper() == 'GET':
                response = requests.get(url, params=params, headers=headers)
            elif method.upper() == 'POST':
                response = requests.post(url, data=params, headers=headers)
            elif method.upper() == 'DELETE':
                response = requests.delete(url, params=params, headers=headers)
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            # Log detailed error if available
            if response.status_code != 200:
                logger.error(f"Binance API error: {response.text}")

            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error in private request to {endpoint}: {str(e)}")
            raise

    def get_account_info(self) -> Dict[str, Any]:
        """Get account information."""
        return self._send_private_request('GET', 'account')

    def get_balance(self, asset: str) -> float:
        """Get balance of a specific asset.

        Args:
            asset: Asset symbol (e.g., 'BTC', 'USDT')

        Returns:
            float: Available balance of the asset
        """
        try:
            account_info = self._send_private_request('GET', 'account')
            balances = account_info.get('balances', [])

            for balance in balances:
                if balance['asset'] == asset:
                    return float(balance['free'])

            # Asset not found
            return 0.0
        except Exception as e:
            logger.error(f"Error getting balance for {asset}: {str(e)}")
            return 0.0

    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker data for a symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')

        Returns:
            dict: Ticker data including price information
        """
        binance_symbol = symbol.replace('/', '')
        try:
            ticker = self._send_public_request('GET', 'ticker/24hr', {'symbol': binance_symbol})

            # Convert to a standardized format
            result = {
                'symbol': symbol,
                'bid': float(ticker['bidPrice']),
                'ask': float(ticker['askPrice']),
                'last': float(ticker['lastPrice']),
                'volume': float(ticker['volume']),
                'timestamp': int(ticker['closeTime'])
            }
            return result
        except Exception as e:
            logger.error(f"Error getting ticker for {symbol}: {str(e)}")
            raise

    def place_order(self, symbol: str, side: str, order_type: str,
                    quantity: float, price: float = None) -> Dict[str, Any]:
        """Place an order on the exchange.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            side: Order side ('buy' or 'sell')
            order_type: Order type ('market', 'limit', etc.)
            quantity: Order quantity
            price: Order price (for limit orders)

        Returns:
            dict: Order information including order ID
        """
        binance_symbol = symbol.replace('/', '')
        params = {
            'symbol': binance_symbol,
            'side': side.upper(),
            'type': order_type.upper(),
            'quantity': quantity,
        }

        if order_type.lower() == 'limit' and price is not None:
            params['price'] = price
            params['timeInForce'] = 'GTC'  # Good Till Cancelled

        # Use test order endpoint in test mode
        endpoint = 'order/test' if self.test_mode else 'order'

        try:
            order_result = self._send_private_request('POST', endpoint, params)
            return order_result
        except Exception as e:
            logger.error(f"Error placing {side} {order_type} order for {symbol}: {str(e)}")
            raise

    def cancel_order(self, symbol: str, order_id: str) -> bool:
        """Cancel an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order to cancel

        Returns:
            bool: True if cancellation successful, False otherwise
        """
        binance_symbol = symbol.replace('/', '')
        params = {
            'symbol': binance_symbol,
            'orderId': order_id
        }

        try:
            cancel_result = self._send_private_request('DELETE', 'order', params)
            return True
        except Exception as e:
            logger.error(f"Error cancelling order {order_id} for {symbol}: {str(e)}")
            return False

    def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
        """Get status of an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order

        Returns:
            dict: Order status information
        """
        binance_symbol = symbol.replace('/', '')
        params = {
            'symbol': binance_symbol,
            'orderId': order_id
        }

        try:
            order_info = self._send_private_request('GET', 'order', params)
            return order_info
        except Exception as e:
            logger.error(f"Error getting order status for {order_id} on {symbol}: {str(e)}")
            raise

    def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]:
        """Get all open orders, optionally filtered by symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT'), or None for all symbols

        Returns:
            list: List of open orders
        """
        params = {}
        if symbol:
            params['symbol'] = symbol.replace('/', '')

        try:
            open_orders = self._send_private_request('GET', 'openOrders', params)
            return open_orders
        except Exception as e:
            logger.error(f"Error getting open orders: {str(e)}")
            return []
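Without credentials connect() returns False, but the public endpoints still work, so the interface can be exercised read-only against the testnet. A minimal sketch (not part of this commit, assuming network access to testnet.binance.vision):

# Sketch only (not part of this commit): read-only use of BinanceInterface.
# No API keys are needed for public endpoints such as ticker/24hr.
from NN.exchanges import BinanceInterface

binance = BinanceInterface(test_mode=True)
binance.connect()  # returns False without credentials; public requests still succeed

ticker = binance.get_ticker('BTC/USDT')  # 'BTC/USDT' is normalized to 'BTCUSDT' internally
print(ticker['last'], ticker['bid'], ticker['ask'])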
191	NN/exchanges/exchange_interface.py	Normal file
@@ -0,0 +1,191 @@
import abc
import logging
from typing import Dict, Any, List, Tuple, Optional

logger = logging.getLogger(__name__)

class ExchangeInterface(abc.ABC):
    """Base class for all exchange interfaces.

    This abstract class defines the required methods that all exchange
    implementations must provide to ensure compatibility with the trading system.
    """

    def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True):
        """Initialize the exchange interface.

        Args:
            api_key: API key for the exchange
            api_secret: API secret for the exchange
            test_mode: If True, use test/sandbox environment
        """
        self.api_key = api_key
        self.api_secret = api_secret
        self.test_mode = test_mode
        self.client = None
        self.last_price_cache = {}

    @abc.abstractmethod
    def connect(self) -> bool:
        """Connect to the exchange API.

        Returns:
            bool: True if connection successful, False otherwise
        """
        pass

    @abc.abstractmethod
    def get_balance(self, asset: str) -> float:
        """Get balance of a specific asset.

        Args:
            asset: Asset symbol (e.g., 'BTC', 'USDT')

        Returns:
            float: Available balance of the asset
        """
        pass

    @abc.abstractmethod
    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker data for a symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')

        Returns:
            dict: Ticker data including price information
        """
        pass

    @abc.abstractmethod
    def place_order(self, symbol: str, side: str, order_type: str,
                    quantity: float, price: float = None) -> Dict[str, Any]:
        """Place an order on the exchange.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            side: Order side ('buy' or 'sell')
            order_type: Order type ('market', 'limit', etc.)
            quantity: Order quantity
            price: Order price (for limit orders)

        Returns:
            dict: Order information including order ID
        """
        pass

    @abc.abstractmethod
    def cancel_order(self, symbol: str, order_id: str) -> bool:
        """Cancel an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order to cancel

        Returns:
            bool: True if cancellation successful, False otherwise
        """
        pass

    @abc.abstractmethod
    def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
        """Get status of an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order

        Returns:
            dict: Order status information
        """
        pass

    @abc.abstractmethod
    def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]:
        """Get all open orders, optionally filtered by symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT'), or None for all symbols

        Returns:
            list: List of open orders
        """
        pass

    def get_last_price(self, symbol: str) -> float:
        """Get last known price for a symbol, may use cached value.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')

        Returns:
            float: Last price
        """
        try:
            ticker = self.get_ticker(symbol)
            price = float(ticker['last'])
            self.last_price_cache[symbol] = price
            return price
        except Exception as e:
            logger.error(f"Error getting price for {symbol}: {str(e)}")
            # Return cached price if available
            return self.last_price_cache.get(symbol, 0.0)

    def execute_trade(self, symbol: str, action: str, quantity: float = None,
                      percent_of_balance: float = None) -> Optional[Dict[str, Any]]:
        """Execute a trade based on a signal.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            action: Trade action ('BUY', 'SELL')
            quantity: Specific quantity to trade
            percent_of_balance: Alternative to quantity - percentage of available balance to use

        Returns:
            dict: Order information or None if order failed
        """
        if action not in ['BUY', 'SELL']:
            logger.error(f"Invalid action: {action}. Must be 'BUY' or 'SELL'")
            return None

        side = action.lower()

        try:
            # Determine base and quote assets from symbol (e.g., BTC/USDT -> BTC, USDT)
            base_asset, quote_asset = symbol.split('/')

            # Calculate quantity if percent_of_balance is provided
            if quantity is None and percent_of_balance is not None:
                if percent_of_balance <= 0 or percent_of_balance > 1:
                    logger.error(f"Invalid percent_of_balance: {percent_of_balance}. Must be between 0 and 1")
                    return None

                if side == 'buy':
                    # For buy, use quote asset (e.g., USDT)
                    balance = self.get_balance(quote_asset)
                    price = self.get_last_price(symbol)
                    quantity = (balance * percent_of_balance) / price
                else:
                    # For sell, use base asset (e.g., BTC)
                    balance = self.get_balance(base_asset)
                    quantity = balance * percent_of_balance

            if not quantity or quantity <= 0:
                logger.error(f"Invalid quantity: {quantity}")
                return None

            # Place market order
            order = self.place_order(
                symbol=symbol,
                side=side,
                order_type='market',
                quantity=quantity
            )

            logger.info(f"Executed {side.upper()} order for {quantity} {base_asset} at market price")
            return order

        except Exception as e:
            logger.error(f"Error executing {action} trade for {symbol}: {str(e)}")
            return None
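execute_trade() sizes the order from percent_of_balance using the quote-asset balance for buys and the base-asset balance for sells, then places a market order. A hedged sketch (not part of this commit) under the assumption that valid API credentials are configured on a concrete interface:

# Sketch only (not part of this commit): sizing a market BUY from 10% of the USDT balance.
# 'YOUR_KEY'/'YOUR_SECRET' are placeholders; real credentials are required for private endpoints.
from NN.exchanges import BinanceInterface

exchange = BinanceInterface(api_key='YOUR_KEY', api_secret='YOUR_SECRET', test_mode=True)
if exchange.connect():
    order = exchange.execute_trade('BTC/USDT', 'BUY', percent_of_balance=0.1)
    if order:
        print("Order placed:", order)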
258	NN/exchanges/mexc_interface.py	Normal file
@@ -0,0 +1,258 @@
import logging
import time
from typing import Dict, Any, List, Optional
import requests
import hmac
import hashlib
from urllib.parse import urlencode

from .exchange_interface import ExchangeInterface

logger = logging.getLogger(__name__)

class MEXCInterface(ExchangeInterface):
    """MEXC Exchange API Interface"""

    def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True):
        """Initialize MEXC exchange interface.

        Args:
            api_key: MEXC API key
            api_secret: MEXC API secret
            test_mode: If True, use test/sandbox environment (Note: MEXC doesn't have a true sandbox)
        """
        super().__init__(api_key, api_secret, test_mode)
        self.base_url = "https://api.mexc.com"
        self.api_version = "v3"

    def connect(self) -> bool:
        """Connect to MEXC API. This is a no-op for REST API."""
        if not self.api_key or not self.api_secret:
            logger.warning("MEXC API credentials not provided. Running in read-only mode.")
            return False

        try:
            # Test connection by getting account info
            self.get_account_info()
            logger.info("Successfully connected to MEXC API")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to MEXC API: {str(e)}")
            return False

    def _generate_signature(self, params: Dict[str, Any]) -> str:
        """Generate signature for authenticated requests."""
        query_string = urlencode(params)
        signature = hmac.new(
            self.api_secret.encode('utf-8'),
            query_string.encode('utf-8'),
            hashlib.sha256
        ).hexdigest()
        return signature

    def _send_public_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
        """Send public request to MEXC API."""
        url = f"{self.base_url}/{self.api_version}/{endpoint}"

        try:
            if method.upper() == 'GET':
                response = requests.get(url, params=params)
            else:
                response = requests.post(url, json=params)

            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error in public request to {endpoint}: {str(e)}")
            raise

    def _send_private_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
        """Send private/authenticated request to MEXC API."""
        if not self.api_key or not self.api_secret:
            raise ValueError("API key and secret are required for private requests")

        if params is None:
            params = {}

        # Add timestamp
        params['timestamp'] = int(time.time() * 1000)

        # Generate signature
        signature = self._generate_signature(params)
        params['signature'] = signature

        # Set headers
        headers = {
            'X-MEXC-APIKEY': self.api_key
        }

        url = f"{self.base_url}/{self.api_version}/{endpoint}"

        try:
            if method.upper() == 'GET':
                response = requests.get(url, params=params, headers=headers)
            elif method.upper() == 'POST':
                response = requests.post(url, json=params, headers=headers)
            elif method.upper() == 'DELETE':
                response = requests.delete(url, params=params, headers=headers)
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error in private request to {endpoint}: {str(e)}")
            raise

    def get_account_info(self) -> Dict[str, Any]:
        """Get account information."""
        return self._send_private_request('GET', 'account')

    def get_balance(self, asset: str) -> float:
        """Get balance of a specific asset.

        Args:
            asset: Asset symbol (e.g., 'BTC', 'USDT')

        Returns:
            float: Available balance of the asset
        """
        try:
            account_info = self._send_private_request('GET', 'account')
            balances = account_info.get('balances', [])

            for balance in balances:
                if balance['asset'] == asset:
                    return float(balance['free'])

            # Asset not found
            return 0.0
        except Exception as e:
            logger.error(f"Error getting balance for {asset}: {str(e)}")
            return 0.0

    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Get current ticker data for a symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')

        Returns:
            dict: Ticker data including price information
        """
        mexc_symbol = symbol.replace('/', '')
        try:
            ticker = self._send_public_request('GET', 'ticker/24hr', {'symbol': mexc_symbol})

            # Convert to a standardized format
            result = {
                'symbol': symbol,
                'bid': float(ticker['bidPrice']),
                'ask': float(ticker['askPrice']),
                'last': float(ticker['lastPrice']),
                'volume': float(ticker['volume']),
                'timestamp': int(ticker['closeTime'])
            }
            return result
        except Exception as e:
            logger.error(f"Error getting ticker for {symbol}: {str(e)}")
            raise

    def place_order(self, symbol: str, side: str, order_type: str,
                    quantity: float, price: float = None) -> Dict[str, Any]:
        """Place an order on the exchange.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            side: Order side ('buy' or 'sell')
            order_type: Order type ('market', 'limit', etc.)
            quantity: Order quantity
            price: Order price (for limit orders)

        Returns:
            dict: Order information including order ID
        """
        mexc_symbol = symbol.replace('/', '')
        params = {
            'symbol': mexc_symbol,
            'side': side.upper(),
            'type': order_type.upper(),
            'quantity': quantity,
        }

        if order_type.lower() == 'limit' and price is not None:
            params['price'] = price
            params['timeInForce'] = 'GTC'  # Good Till Cancelled

        try:
            order_result = self._send_private_request('POST', 'order', params)
            return order_result
        except Exception as e:
            logger.error(f"Error placing {side} {order_type} order for {symbol}: {str(e)}")
            raise

    def cancel_order(self, symbol: str, order_id: str) -> bool:
        """Cancel an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order to cancel

        Returns:
            bool: True if cancellation successful, False otherwise
        """
        mexc_symbol = symbol.replace('/', '')
        params = {
            'symbol': mexc_symbol,
            'orderId': order_id
        }

        try:
            cancel_result = self._send_private_request('DELETE', 'order', params)
            return True
        except Exception as e:
            logger.error(f"Error cancelling order {order_id} for {symbol}: {str(e)}")
            return False

    def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
        """Get status of an existing order.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            order_id: ID of the order

        Returns:
            dict: Order status information
        """
        mexc_symbol = symbol.replace('/', '')
        params = {
            'symbol': mexc_symbol,
            'orderId': order_id
        }

        try:
            order_info = self._send_private_request('GET', 'order', params)
            return order_info
        except Exception as e:
            logger.error(f"Error getting order status for {order_id} on {symbol}: {str(e)}")
            raise

    def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]:
        """Get all open orders, optionally filtered by symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT'), or None for all symbols

        Returns:
            list: List of open orders
        """
        params = {}
        if symbol:
            params['symbol'] = symbol.replace('/', '')

        try:
            open_orders = self._send_private_request('GET', 'openOrders', params)
            return open_orders
        except Exception as e:
            logger.error(f"Error getting open orders: {str(e)}")
            return []
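Both MEXCInterface and BinanceInterface sign private requests the same way: the request parameters (including a millisecond timestamp) are URL-encoded, HMAC-SHA256 signed with the API secret, and the hex digest is appended as a 'signature' parameter. A standalone illustration of that scheme (not part of this commit; 'my-secret' is a placeholder):

# Illustration only (not part of this commit) of the request-signing scheme used by
# _generate_signature() / _send_private_request() in both exchange interfaces.
import hmac, hashlib, time
from urllib.parse import urlencode

params = {'symbol': 'BTCUSDT', 'timestamp': int(time.time() * 1000)}
query_string = urlencode(params)
signature = hmac.new(b'my-secret', query_string.encode('utf-8'), hashlib.sha256).hexdigest()
params['signature'] = signature  # sent alongside the X-MEXC-APIKEY / X-MBX-APIKEY header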
244	NN/main.py	Normal file
@@ -0,0 +1,244 @@
"""
Neural Network Trading System Main Module (Compatibility Layer)

This module serves as a compatibility layer for the realtime.py module.
It re-exports the functionality from realtime_main.py that is needed by realtime.py.
"""

import os
import sys
import logging
from datetime import datetime
import numpy as np

# Configure logging
logger = logging.getLogger('NN')
logger.setLevel(logging.INFO)

# Re-export everything from realtime_main.py
from .realtime_main import (
    parse_arguments,
    realtime,
    train,
    predict
)

# Create a class that realtime.py expects
class NeuralNetworkOrchestrator:
    """
    Orchestrates the neural network operations.
    """

    def __init__(self, config):
        """
        Initialize the orchestrator with configuration.

        Args:
            config (dict): Configuration parameters
        """
        self.config = config
        self.symbol = config.get('symbol', 'BTC/USDT')
        self.timeframes = config.get('timeframes', ['1m', '5m', '1h', '4h'])
        self.window_size = config.get('window_size', 20)
        self.n_features = config.get('n_features', 5)
        self.output_size = config.get('output_size', 3)
        self.model_dir = config.get('model_dir', 'NN/models/saved')
        self.data_dir = config.get('data_dir', 'NN/data')
        self.model = None
        self.data_interface = None

        # Initialize with default values in case imports fail
        self.model_initialized = False
        self.data_initialized = False

        # Import necessary modules dynamically
        try:
            from .utils.data_interface import DataInterface

            # Initialize data interface
            self.data_interface = DataInterface(
                symbol=self.symbol,
                timeframes=self.timeframes
            )
            self.data_initialized = True
            logger.info(f"Data interface initialized for {self.symbol}")

            try:
                from .models.cnn_model_pytorch import CNNModelPyTorch as Model

                # Initialize model
                feature_count = self.data_interface.get_feature_count() if hasattr(self.data_interface, 'get_feature_count') else 5
                try:
                    # First try with expected parameters
                    self.model = Model(
                        window_size=self.window_size,
                        num_features=feature_count,
                        output_size=self.output_size,
                        timeframes=self.timeframes
                    )
                except TypeError as e:
                    logger.warning(f"TypeError in model initialization with num_features: {str(e)}")
                    # Try alternate parameter naming
                    try:
                        self.model = Model(
                            input_shape=(self.window_size, feature_count),
                            output_size=self.output_size
                        )
                        logger.info("Model initialized with alternate parameters")
                    except Exception as ex:
                        logger.error(f"Failed to initialize model with alternate parameters: {str(ex)}")
                        self.model = DummyModel()

                # Try to load the best model
                self._load_model()
                self.model_initialized = True
                logger.info("Model initialized successfully")
            except Exception as e:
                logger.error(f"Error initializing model: {str(e)}")
                import traceback
                logger.error(traceback.format_exc())
                self.model = DummyModel()

            logger.info(f"NeuralNetworkOrchestrator initialized with config: {config}")
        except Exception as e:
            logger.error(f"Error initializing NeuralNetworkOrchestrator: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            self.model = DummyModel()

    def _load_model(self):
        """Load the best trained model from available files"""
        try:
            model_paths = [
                os.path.join(self.model_dir, "dqn_agent_best_policy.pt"),
                os.path.join(self.model_dir, "cnn_model_best.pt"),
                os.path.join("models/saved", "dqn_agent_best_policy.pt"),
                os.path.join("models/saved", "cnn_model_best.pt")
            ]

            for model_path in model_paths:
                if os.path.exists(model_path):
                    try:
                        self.model.load(model_path)
                        logger.info(f"Loaded model from {model_path}")
                        return True
                    except Exception as e:
                        logger.warning(f"Failed to load model from {model_path}: {str(e)}")
                        continue

            logger.warning("No trained model found, using dummy model")
            self.model = DummyModel()
            return False
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            self.model = DummyModel()
            return False

    def run_inference_pipeline(self, model_type='cnn', timeframe='1h'):
        """
        Run the inference pipeline using the trained model.

        Args:
            model_type (str): Type of model to use (cnn, transformer, etc.)
            timeframe (str): Timeframe to use for inference

        Returns:
            dict: Inference result
        """
        try:
            # Check if we have a model
            if not hasattr(self, 'model') or self.model is None:
                logger.warning("No model available, initializing dummy model")
                self.model = DummyModel()

            # Check if we have a data interface
            if not hasattr(self, 'data_interface') or self.data_interface is None:
                logger.warning("No data interface available")
                # Return a dummy prediction
                return self._get_dummy_prediction()

            # Prepare input data for the selected timeframe
            X, timestamp = self.data_interface.prepare_realtime_input(
                timeframe=timeframe,
                n_candles=self.window_size + 10,  # Extra candles for safety
                window_size=self.window_size
            )

            if X is None:
                logger.warning(f"No data available for {self.symbol}")
                return self._get_dummy_prediction()

            # Get model predictions
            action_probs, price_pred = self.model.predict(X)

            # Convert predictions to action
            action_idx = np.argmax(action_probs) if hasattr(action_probs, 'argmax') else 1  # Default to HOLD
            action_names = ['SELL', 'HOLD', 'BUY']
            action = action_names[action_idx]

            # Format timestamp
            if not isinstance(timestamp, str):
                try:
                    if hasattr(timestamp, 'isoformat'):  # If it's already a datetime-like object
                        timestamp = timestamp.isoformat()
                    else:  # If it's a numeric timestamp
                        timestamp = datetime.fromtimestamp(float(timestamp)/1000).isoformat()
                except (TypeError, ValueError):
                    timestamp = datetime.now().isoformat()

            # Return result
            result = {
                'timestamp': timestamp,
                'action': action,
                'action_index': int(action_idx),
                'probability': float(action_probs[action_idx]) if hasattr(action_probs, '__getitem__') else 0.33,
                'probabilities': {name: float(prob) for name, prob in zip(action_names, action_probs)} if hasattr(action_probs, '__iter__') else {'SELL': 0.33, 'HOLD': 0.34, 'BUY': 0.33},
                'price_prediction': float(price_pred) if price_pred is not None else None
            }

            logger.info(f"Inference result: {result}")
            return result

        except Exception as e:
            logger.error(f"Error in inference pipeline: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())
            return self._get_dummy_prediction()

    def _get_dummy_prediction(self):
        """Return a dummy prediction when model or data is unavailable"""
        action_names = ['SELL', 'HOLD', 'BUY']
        action_idx = 1  # Default to HOLD
        timestamp = datetime.now().isoformat()

        return {
            'timestamp': timestamp,
            'action': 'HOLD',
            'action_index': action_idx,
            'probability': 0.8,
            'probabilities': {'SELL': 0.1, 'HOLD': 0.8, 'BUY': 0.1},
            'price_prediction': None,
            'is_dummy': True
        }


class DummyModel:
    """Dummy model that returns random predictions"""

    def __init__(self):
        logger.info("Initializing dummy model")

    def predict(self, X):
        """Return random predictions"""
        # Generate random probabilities for SELL, HOLD, BUY
        action_probs = np.array([0.1, 0.8, 0.1])  # Bias towards HOLD

        # Generate a random price prediction (None for now)
        price_pred = None

        return action_probs, price_pred

    def load(self, model_path):
        """Dummy load method"""
        logger.info(f"Dummy model pretending to load from {model_path}")
        return True
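The compatibility-layer orchestrator is driven entirely by a config dict; every key is optional and falls back to the defaults read in __init__. A usage sketch (not part of this commit), showing the keys the constructor actually consumes:

# Sketch only (not part of this commit): constructing the compatibility-layer orchestrator from NN/main.py.
from NN.main import NeuralNetworkOrchestrator

config = {
    'symbol': 'BTC/USDT',
    'timeframes': ['1m', '5m', '1h', '4h'],
    'window_size': 20,
    'n_features': 5,
    'output_size': 3,
    'model_dir': 'NN/models/saved',
    'data_dir': 'NN/data',
}

orchestrator = NeuralNetworkOrchestrator(config)
result = orchestrator.run_inference_pipeline(model_type='cnn', timeframe='1h')
print(result['action'], result['probability'])  # falls back to a dummy HOLD prediction if no model/data is available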
@@ -72,6 +72,9 @@ class CNNPyTorch(nn.Module):
"""
super(CNNPyTorch, self).__init__()

# Set device
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

window_size, num_features = input_shape
self.window_size = window_size

@@ -225,75 +228,94 @@ class CNNModelPyTorch:
predictions with the CNN model, optimized for short-term trading opportunities.
"""

def __init__(self, window_size, num_features, output_size=3, timeframes=None):
def __init__(self, window_size=20, timeframes=None, output_size=3, num_pairs=3):
"""
Initialize the CNN model.

Args:
window_size (int): Size of the input window
num_features (int): Number of features in the input data
output_size (int): Size of the output (default: 3 for BUY/HOLD/SELL)
timeframes (list): List of timeframes used (for logging)
window_size (int): Size of the sliding window
timeframes (list): List of timeframes used
output_size (int): Number of output classes (3 for BUY/HOLD/SELL)
num_pairs (int): Number of trading pairs to analyze in parallel (default 3)
"""
# Action tracking
self.action_counts = {
'BUY': 0,
'SELL': 0,
'HOLD': 0
}
self.window_size = window_size
self.num_features = num_features
self.timeframes = timeframes if timeframes else ["1m", "5m", "15m"]
self.output_size = output_size
self.timeframes = timeframes or []
self.num_pairs = num_pairs

# Determine device (GPU or CPU)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {self.device}")
# Calculate total features (5 OHLCV features per timeframe per pair)
self.total_features = len(self.timeframes) * 5 * self.num_pairs

# Initialize model
self.model = None
self.build_model()
# Build the model
logger.info(f"Building PyTorch CNN model with window_size={window_size}, "
f"num_features={self.total_features}, output_size={output_size}, "
f"num_pairs={num_pairs}")

# Initialize training history
self.history = {
'loss': [],
'val_loss': [],
'accuracy': [],
'val_accuracy': []
}
# Calculate channel sizes that are divisible by num_pairs
base_channels = 96  # 96 is divisible by 3
self.model = nn.Sequential(
# First convolutional layer - process each pair's features
nn.Sequential(
nn.Conv1d(self.total_features, base_channels, kernel_size=5, padding=2, groups=num_pairs),
nn.ReLU(),
nn.BatchNorm1d(base_channels),
nn.Dropout(0.2)
),

# Sensitivity parameters for high-leverage trading
self.confidence_threshold = 0.65  # Minimum confidence for trading actions
self.max_consecutive_same_action = 3  # Limit consecutive identical actions
self.last_actions = []  # Track recent actions
# Second convolutional layer - start mixing pair information
nn.Sequential(
nn.Conv1d(base_channels, base_channels*2, kernel_size=3, padding=1),
nn.ReLU(),
nn.BatchNorm1d(base_channels*2),
nn.Dropout(0.2)
),

def build_model(self):
"""Build the CNN model architecture"""
logger.info(f"Building PyTorch CNN model with window_size={self.window_size}, "
f"num_features={self.num_features}, output_size={self.output_size}")
# Third convolutional layer - deeper feature extraction
nn.Sequential(
nn.Conv1d(base_channels*2, base_channels*4, kernel_size=3, padding=1),
nn.ReLU(),
nn.BatchNorm1d(base_channels*4),
nn.Dropout(0.2)
),

# Ensure window size is not less than the actual input
input_window_size = max(self.window_size, 20)  # Use at least 20 as minimum window size
# Global average pooling
nn.AdaptiveAvgPool1d(1),

self.model = CNNPyTorch(
input_shape=(input_window_size, self.num_features),
output_size=self.output_size
# Flatten
nn.Flatten(),

# Dense layers for action prediction with cross-pair attention
nn.Sequential(
nn.Linear(base_channels*4, base_channels*2),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(base_channels*2, base_channels),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(base_channels, output_size * num_pairs)  # Output for each pair
)
).to(self.device)

# Initialize optimizer with higher learning rate for faster adaptation
self.optimizer = optim.Adam(self.model.parameters(), lr=0.002)

# Learning rate scheduler with faster decay
# Initialize optimizer and loss function
self.optimizer = optim.Adam(self.model.parameters(), lr=0.0005)
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, mode='max', factor=0.6, patience=6, verbose=True
self.optimizer, mode='max', factor=0.5, patience=5, verbose=True
)
self.criterion = nn.CrossEntropyLoss()

# Initialize loss function with higher weights for BUY/SELL
class_weights = torch.tensor([7.0, 1.0, 7.0]).to(self.device)  # Even higher weights for BUY/SELL
self.criterion = nn.CrossEntropyLoss(weight=class_weights)
# Initialize metrics tracking
self.train_losses = []
self.val_losses = []
self.train_accuracies = []
self.val_accuracies = []

logger.info(f"Model built successfully with {sum(p.numel() for p in self.model.parameters())} parameters")

# Sensitivity parameters for high-leverage trading
self.confidence_threshold = 0.65
self.max_consecutive_same_action = 3
self.last_actions = [[] for _ in range(num_pairs)]  # Track recent actions per pair

def compute_trading_loss(self, action_probs, price_pred, targets, future_prices=None):
"""
Custom loss function that prioritizes profitable trades

@@ -644,12 +666,12 @@ class CNNModelPyTorch:
action_probs_np[:, 2] *= 1.3  # Boost BUY probabilities

# Implement signal filtering based on previous actions to avoid oscillation
if len(self.last_actions) >= self.max_consecutive_same_action:
if len(self.last_actions[0]) >= self.max_consecutive_same_action:
# Check for too many consecutive identical actions
if all(a == 0 for a in self.last_actions[-self.max_consecutive_same_action:]):
if all(a == 0 for a in self.last_actions[0][-self.max_consecutive_same_action:]):
# Too many consecutive SELL - reduce sell probability
action_probs_np[:, 0] *= 0.7
elif all(a == 2 for a in self.last_actions[-self.max_consecutive_same_action:]):
elif all(a == 2 for a in self.last_actions[0][-self.max_consecutive_same_action:]):
# Too many consecutive BUY - reduce buy probability
action_probs_np[:, 2] *= 0.7

@@ -666,9 +688,9 @@ class CNNModelPyTorch:
# Store the predicted action for the most recent input
if action_probs_np.shape[0] > 0:
latest_action = np.argmax(action_probs_np[-1])
self.last_actions.append(int(latest_action))
self.last_actions[0].append(int(latest_action))
# Keep only the most recent actions
self.last_actions = self.last_actions[-10:]  # Store last 10 actions
self.last_actions[0] = self.last_actions[0][-10:]  # Store last 10 actions

# Update action counts for stats
actions = np.argmax(action_probs_np, axis=1)

@@ -676,11 +698,11 @@ class CNNModelPyTorch:
action_dict = dict(zip(unique, counts))

if 0 in action_dict:
self.action_counts['SELL'] += action_dict[0]
self.action_counts['SELL'][0] += action_dict[0]
if 1 in action_dict:
self.action_counts['HOLD'] += action_dict[1]
self.action_counts['HOLD'][0] += action_dict[1]
if 2 in action_dict:
self.action_counts['BUY'] += action_dict[2]
self.action_counts['BUY'][0] += action_dict[2]

# Get the current close prices from the input
current_prices = X_tensor[:, -1, 3].cpu().numpy() if X_tensor.shape[2] > 3 else np.zeros(X_tensor.shape[0])

@@ -838,20 +860,25 @@ class CNNModelPyTorch:
f"val_loss: {val_loss:.4f} - val_acc: {val_acc:.4f}")

# Update history
self.history['loss'].append(epoch_loss)
self.history['accuracy'].append(epoch_acc)
self.history['val_loss'].append(val_loss)
self.history['val_accuracy'].append(val_acc)
self.train_losses.append(epoch_loss)
self.train_accuracies.append(epoch_acc)
self.val_losses.append(val_loss)
self.val_accuracies.append(val_acc)
else:
logger.info(f"Epoch {epoch+1}/{epochs} - "
f"loss: {epoch_loss:.4f} - acc: {epoch_acc:.4f}")

# Update history without validation
self.history['loss'].append(epoch_loss)
self.history['accuracy'].append(epoch_acc)
self.train_losses.append(epoch_loss)
self.train_accuracies.append(epoch_acc)

logger.info("Training completed")
return self.history
return {
'loss': self.train_losses,
'accuracy': self.train_accuracies,
'val_loss': self.val_losses,
'val_accuracy': self.val_accuracies
}

def evaluate_metrics(self, X_test, y_test):
"""

@@ -892,9 +919,14 @@ class CNNModelPyTorch:
model_state = {
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'history': self.history,
'history': {
'loss': self.train_losses,
'accuracy': self.train_accuracies,
'val_loss': self.val_losses,
'val_accuracy': self.val_accuracies
},
'window_size': self.window_size,
'num_features': self.num_features,
'num_features': self.total_features,
'output_size': self.output_size,
'timeframes': self.timeframes,
# Save trading configuration

@@ -930,12 +962,12 @@ class CNNModelPyTorch:

# Update model parameters
self.window_size = model_state['window_size']
self.num_features = model_state['num_features']
self.total_features = model_state['num_features']
self.output_size = model_state['output_size']
self.timeframes = model_state.get('timeframes', ["1m"])

# Load model state dict
self.load_state_dict(model_state['model_state_dict'])
self.model.load_state_dict(model_state['model_state_dict'])

# Load optimizer state if available
if 'optimizer_state_dict' in model_state:
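The reworked constructor derives its channel count from len(timeframes) * 5 * num_pairs and emits output_size * num_pairs logits, so each pair gets its own BUY/HOLD/SELL triple. A hedged shape sketch (not part of this commit), assuming channel-first input as nn.Conv1d requires:

# Sketch only (not part of this commit): expected tensor shapes for the multi-pair model.
# The (batch, channels, window) layout is an assumption based on the nn.Conv1d layers above.
import torch

window_size, num_pairs, timeframes = 20, 3, ['1m', '5m', '1h']
total_features = len(timeframes) * 5 * num_pairs  # 5 OHLCV features per timeframe per pair = 45

x = torch.randn(8, total_features, window_size)  # (batch, channels, window)
# The nn.Sequential built in __init__ ends with nn.Linear(base_channels, output_size * num_pairs),
# so the raw output would have shape (8, 3 * num_pairs) and can be viewed as per-pair logits:
# out = model.model(x).view(8, num_pairs, 3)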
287	NN/neural_network_orchestrator.py	Normal file
@@ -0,0 +1,287 @@
import logging
import threading
import time
from typing import Dict, Any, List, Optional, Callable, Tuple
import os
import numpy as np
import pandas as pd

from .trading_agent import TradingAgent

logger = logging.getLogger(__name__)

class NeuralNetworkOrchestrator:
    """Orchestrator for neural network models and trading operations.

    This class coordinates between neural network models and trading agents,
    ensuring that signals from the models are properly processed and trades
    are executed according to the strategy.
    """

    def __init__(self, model, data_interface, chart=None,
                 symbols: List[str] = None,
                 timeframes: List[str] = None,
                 window_size: int = 20,
                 num_features: int = 5,
                 output_size: int = 3,
                 models_dir: str = "NN/models/saved",
                 data_dir: str = "NN/data",
                 exchange_config: Dict[str, Any] = None):
        """Initialize the neural network orchestrator.

        Args:
            model: Neural network model instance
            data_interface: Data interface for retrieving market data
            chart: Real-time chart for visualization (optional)
            symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT'])
            timeframes: List of timeframes to monitor (e.g., ['1m', '5m', '1h'])
            window_size: Window size for model input
            num_features: Number of features per datapoint
            output_size: Number of output classes (e.g., 3 for BUY/HOLD/SELL)
            models_dir: Directory for saved models
            data_dir: Directory for data storage
            exchange_config: Configuration for trading agent (exchange, API keys, etc.)
        """
        self.model = model
        self.data_interface = data_interface
        self.chart = chart

        self.symbols = symbols or ["BTC/USDT"]
        self.timeframes = timeframes or ["1m", "5m", "1h", "4h", "1d"]
        self.window_size = window_size
        self.num_features = num_features
        self.output_size = output_size
        self.models_dir = models_dir
        self.data_dir = data_dir

        # Initialize trading agent if configuration provided
        self.trading_agent = None
        if exchange_config:
            self.init_trading_agent(exchange_config)

        # Initialize inference state
        self.is_running = False
        self.inference_thread = None
        self.stop_event = threading.Event()
        self.last_inference_time = 0
        self.inference_interval = int(os.environ.get("NN_INFERENCE_INTERVAL", "60"))

        logger.info(f"Initializing NeuralNetworkOrchestrator with:")
        logger.info(f"- Symbol: {self.symbols[0]}")
        logger.info(f"- Timeframes: {', '.join(self.timeframes)}")
        logger.info(f"- Window size: {window_size}")
        logger.info(f"- Num features: {num_features}")
        logger.info(f"- Output size: {output_size}")
        logger.info(f"- Models dir: {models_dir}")
        logger.info(f"- Data dir: {data_dir}")
        logger.info(f"- Inference interval: {self.inference_interval} seconds")

    def init_trading_agent(self, config: Dict[str, Any]):
        """Initialize the trading agent with the given configuration.

        Args:
            config: Configuration for the trading agent
        """
        exchange_name = config.get("exchange", "binance")
        api_key = config.get("api_key")
        api_secret = config.get("api_secret")
        test_mode = config.get("test_mode", True)
        trade_symbols = config.get("trade_symbols", self.symbols)
        position_size = config.get("position_size", 0.1)
        max_trades_per_day = config.get("max_trades_per_day", 5)
        trade_cooldown_minutes = config.get("trade_cooldown_minutes", 60)

        self.trading_agent = TradingAgent(
            exchange_name=exchange_name,
            api_key=api_key,
            api_secret=api_secret,
            test_mode=test_mode,
            trade_symbols=trade_symbols,
            position_size=position_size,
            max_trades_per_day=max_trades_per_day,
            trade_cooldown_minutes=trade_cooldown_minutes
        )

        logger.info(f"Trading agent initialized for {exchange_name} exchange.")

    def start_inference(self):
        """Start the inference thread."""
        if self.is_running:
            logger.warning("Neural network inference is already running.")
            return

        self.is_running = True
        self.stop_event.clear()

        # Start inference thread
        self.inference_thread = threading.Thread(target=self._inference_loop)
        self.inference_thread.daemon = True
        self.inference_thread.start()

        logger.info(f"Neural network inference thread started with {self.inference_interval}s interval.")

        # Start trading agent if available
        if self.trading_agent:
            self.trading_agent.start(signal_callback=self._on_trade_executed)

    def stop_inference(self):
        """Stop the inference thread."""
        if not self.is_running:
            logger.warning("Neural network inference is not running.")
            return

        logger.info("Stopping neural network inference...")
        self.is_running = False
        self.stop_event.set()

        if self.inference_thread and self.inference_thread.is_alive():
            self.inference_thread.join(timeout=10)

        logger.info("Neural network inference stopped.")

        # Stop trading agent if available
        if self.trading_agent:
            self.trading_agent.stop()

    def _inference_loop(self):
        """Main inference loop that processes data and generates signals."""
        logger.info("Inference loop started.")

        try:
            while self.is_running and not self.stop_event.is_set():
                current_time = time.time()

                # Check if we should run inference
                if current_time - self.last_inference_time >= self.inference_interval:
                    try:
                        # Run inference for all symbols
                        for symbol in self.symbols:
                            prediction = self._run_inference(symbol)
                            if prediction:
                                self._process_prediction(symbol, prediction)

                        self.last_inference_time = current_time
                    except Exception as e:
                        logger.error(f"Error during inference: {str(e)}")

                # Sleep for a short time to prevent CPU hogging
                time.sleep(1)
        except Exception as e:
            logger.error(f"Error in inference loop: {str(e)}")
        finally:
            logger.info("Inference loop stopped.")

    def _run_inference(self, symbol: str) -> Optional[Tuple[np.ndarray, float]]:
        """Run inference for a specific symbol.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')

        Returns:
            tuple: (action probabilities, current price) or None if inference failed
        """
        try:
            # Get the model timeframe from environment
            model_timeframe = os.environ.get("NN_TIMEFRAME", "1h")
            if model_timeframe not in self.timeframes:
                logger.warning(f"Model timeframe {model_timeframe} not in available timeframes. Using {self.timeframes[0]}.")
                model_timeframe = self.timeframes[0]

            # Load candles for the model timeframe
            logger.info(f"Loading {1000} candles from cache for {symbol} at {model_timeframe} timeframe")
            candles = self.data_interface.get_historical_data(
                symbol=symbol,
                timeframe=model_timeframe,
                n_candles=1000
            )

            if candles is None or len(candles) < self.window_size:
                logger.warning(f"Not enough data for {symbol} at {model_timeframe} timeframe. Need at least {self.window_size} candles.")
                return None

            # Prepare input data
            X, timestamp = self.data_interface.prepare_model_input(
                data=candles,
                window_size=self.window_size,
                symbol=symbol
            )

            if X is None:
                logger.warning(f"Failed to prepare model input for {symbol}.")
                return None

            # Get current price
            current_price = candles['close'].iloc[-1]

            # Run model inference
            action_probs, price_pred = self.model.predict(X)

            return action_probs, current_price

        except Exception as e:
            logger.error(f"Error running inference for {symbol}: {str(e)}")
            return None

    def _process_prediction(self, symbol: str, prediction: Tuple[np.ndarray, float]):
        """Process a prediction and generate signals.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            prediction: Tuple of (action probabilities, current price)
        """
        action_probs, current_price = prediction

        # Get the best action (0=SELL, 1=HOLD, 2=BUY)
        best_action = np.argmax(action_probs)
        best_prob = float(action_probs[best_action])

        # Convert to action name
        action_names = ["SELL", "HOLD", "BUY"]
        action_name = action_names[best_action]

        # Log the prediction
        logger.info(f"Inference result for {symbol}: Action={action_name}, Probability={best_prob:.2f}, Price={current_price:.2f}")

        # Add signal to chart if available
        if self.chart:
            self.chart.add_nn_signal(symbol=symbol, signal=action_name, confidence=best_prob, timestamp=int(time.time()))

        # Process signal with trading agent if available
        if self.trading_agent:
            self.trading_agent.process_signal(
                symbol=symbol,
                action=action_name,
                confidence=best_prob,
                timestamp=int(time.time())
            )

    def _on_trade_executed(self, trade_record: Dict[str, Any]):
        """Callback for when a trade is executed.

        Args:
            trade_record: Trade information
        """
        if self.chart and trade_record:
            # Add trade to chart
            self.chart.add_trade(
                action=trade_record['action'],
                price=trade_record.get('price', 0),
                timestamp=trade_record['timestamp'],
                pnl=trade_record.get('pnl', 0)
            )

            logger.info(f"Trade added to chart: {trade_record['action']} at {trade_record.get('price', 0):.2f}")

    def get_trading_agent_info(self) -> Dict[str, Any]:
        """Get information about the trading agent.

        Returns:
            dict: Trading agent information or None if no agent is available
        """
        if self.trading_agent:
            return {
                'exchange_info': self.trading_agent.get_exchange_info(),
                'positions': self.trading_agent.get_current_positions(),
                'trades': len(self.trading_agent.get_trade_history())
            }
        return None
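The trading agent is only created when an exchange_config dict is passed in; init_trading_agent() reads a fixed set of keys, each with a default. A sketch of that dict (not part of this commit; values shown are the defaults used when a key is missing):

# Sketch only (not part of this commit): exchange_config keys read by init_trading_agent().
exchange_config = {
    'exchange': 'binance',          # TradingAgent imports both BinanceInterface and MEXCInterface
    'api_key': None,
    'api_secret': None,
    'test_mode': True,
    'trade_symbols': ['BTC/USDT'],
    'position_size': 0.1,
    'max_trades_per_day': 5,
    'trade_cooldown_minutes': 60,
}
# orchestrator = NeuralNetworkOrchestrator(model, data_interface, chart=chart,
#                                          exchange_config=exchange_config)
# orchestrator.start_inference()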
@ -13,6 +13,7 @@ import argparse
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import time

# Configure logging
logger = logging.getLogger('NN')
@ -100,7 +101,7 @@ def main():
    # Verify data interface by fetching initial data
    logger.info("Verifying data interface...")
    X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
-    if X_sample is None or y_sample is None:
+    if X_sample is None or y_sample is not None:
        logger.error("Failed to prepare initial training data")
        return

@ -369,9 +370,9 @@ def predict(data_interface, model, args):
    except Exception as e:
        logger.error(f"Error in prediction mode: {str(e)}")

-def realtime(data_interface, model, args):
-    """Run the model in real-time mode"""
-    logger.info("Starting real-time mode...")
+def realtime(data_interface, model, args, chart=None, symbol=None):
+    """Run real-time inference with the trained model"""
+    logger.info(f"Starting real-time inference mode for {symbol}...")

    try:
        from NN.utils.realtime_analyzer import RealtimeAnalyzer
@ -403,8 +404,104 @@ def realtime(data_interface, model, args):
        logger.info("Starting real-time analysis...")
        realtime_analyzer.start()

        # Initialize variables for tracking performance
        total_pnl = 0.0
        trades = []
        current_position = 0.0
        last_action = None
        last_price = None

        # Get the pair index for this symbol
        pair_index = args.symbols.index(symbol)

        # Only execute trades if this is the main pair (BTC/USDT)
        is_main_pair = symbol == "BTC/USDT"

        while True:
            # Get current market data for all pairs
            all_pairs_data = []
            for s in args.symbols:
                X, timestamp = data_interface.prepare_realtime_input(
                    timeframe=args.timeframes[0],  # Use shortest timeframe
                    n_candles=args.window_size + 10,  # Extra candles for safety
                    window_size=args.window_size
                )
                if X is not None:
                    all_pairs_data.append(X)
                else:
                    logger.warning(f"No data available for {s}")
                    time.sleep(1)
                    continue

            if not all_pairs_data:
                logger.warning("No data available for any pair")
                time.sleep(1)
                continue

            # Stack data from all pairs for model input
            X_combined = np.concatenate(all_pairs_data, axis=2)

            # Get model predictions
            action_probs, price_pred = model.predict(X_combined)

            # Get predictions for this specific pair
            action = np.argmax(action_probs[pair_index])  # 0=SELL, 1=HOLD, 2=BUY

            # Get current price for the main pair
            current_price = data_interface.get_historical_data(
                timeframe=args.timeframes[0],
                n_candles=1
            )['close'].iloc[-1]

            # Calculate PnL if we have a position (only for main pair)
            pnl = 0.0
            if is_main_pair and last_action is not None and last_price is not None:
                if last_action == 2:  # BUY
                    pnl = (current_price - last_price) / last_price
                elif last_action == 0:  # SELL
                    pnl = (last_price - current_price) / last_price

            # Update total PnL (only for main pair)
            if is_main_pair and pnl != 0:
                total_pnl += pnl

            # Log the prediction
            action_name = "SELL" if action == 0 else "HOLD" if action == 1 else "BUY"
            log_msg = f"Time: {timestamp}, Symbol: {symbol}, Action: {action_name}, "
            if is_main_pair:
                log_msg += f"Price: {current_price:.2f}, PnL: {pnl:.2%}, Total PnL: {total_pnl:.2%}"
            else:
                log_msg += f"Price: {current_price:.2f} (Context Only)"
            logger.info(log_msg)

            # Update the chart if provided (only for main pair)
            if chart is not None and is_main_pair and action != 1:  # Skip HOLD actions
                chart.add_trade(
                    action=action_name,
                    price=current_price,
                    timestamp=timestamp,
                    pnl=pnl
                )

            # Update tracking variables (only for main pair)
            if is_main_pair and action != 1:  # If not HOLD
                last_action = action
                last_price = current_price

            # Sleep for a short time
            time.sleep(1)

    except KeyboardInterrupt:
        if is_main_pair:
            logger.info(f"Real-time inference stopped by user for {symbol}")
            logger.info(f"Final performance for {symbol} - Total PnL: {total_pnl:.2%}")
        else:
            logger.info(f"Real-time inference stopped by user for {symbol} (Context Only)")
    except Exception as e:
-        logger.error(f"Error in real-time mode: {str(e)}")
+        logger.error(f"Error in real-time inference for {symbol}: {str(e)}")
        raise

if __name__ == "__main__":
    main()
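The new realtime() signature takes an optional chart and symbol, so each monitored pair can run its own loop. A hedged sketch of how a caller might fan this out across symbols; the `charts` dict and the thread setup are assumptions, not code from this commit:

import threading

def launch_realtime_threads(data_interface, model, args, charts):
    """Start one realtime() loop per symbol; `charts` maps symbol -> chart (hypothetical)."""
    threads = []
    for sym in args.symbols:
        t = threading.Thread(
            target=realtime,
            kwargs={'data_interface': data_interface, 'model': model, 'args': args,
                    'chart': charts.get(sym), 'symbol': sym},
            daemon=True,
        )
        t.start()
        threads.append(t)
    return threads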
310
NN/trading_agent.py
Normal file
@ -0,0 +1,310 @@
import logging
import time
import threading
from typing import Dict, Any, List, Optional, Callable, Tuple, Union

from .exchanges import ExchangeInterface, MEXCInterface, BinanceInterface

logger = logging.getLogger(__name__)

class TradingAgent:
    """Trading agent that executes trades based on neural network signals.

    This agent interfaces with different exchanges and executes trades
    based on the signals received from the neural network.
    """

    def __init__(self,
                 exchange_name: str = 'binance',
                 api_key: str = None,
                 api_secret: str = None,
                 test_mode: bool = True,
                 trade_symbols: List[str] = None,
                 position_size: float = 0.1,
                 max_trades_per_day: int = 5,
                 trade_cooldown_minutes: int = 60):
        """Initialize the trading agent.

        Args:
            exchange_name: Name of the exchange to use ('binance', 'mexc')
            api_key: API key for the exchange
            api_secret: API secret for the exchange
            test_mode: If True, use test/sandbox environment
            trade_symbols: List of trading symbols to monitor (e.g., ['BTC/USDT'])
            position_size: Size of each position as a fraction of total available balance (0.0-1.0)
            max_trades_per_day: Maximum number of trades to execute per day
            trade_cooldown_minutes: Minimum time between trades in minutes
        """
        self.exchange_name = exchange_name.lower()
        self.api_key = api_key
        self.api_secret = api_secret
        self.test_mode = test_mode
        self.trade_symbols = trade_symbols or ['BTC/USDT']
        self.position_size = min(max(position_size, 0.01), 1.0)  # Ensure between 0.01 and 1.0
        self.max_trades_per_day = max(1, max_trades_per_day)
        self.trade_cooldown_seconds = max(60, trade_cooldown_minutes * 60)

        # Initialize exchange interface
        self.exchange = self._create_exchange()

        # Trading state
        self.active = False
        self.current_positions = {}  # Symbol -> quantity
        self.trades_today = {}  # Symbol -> count
        self.last_trade_time = {}  # Symbol -> timestamp
        self.trade_history = []  # List of trade records

        # Threading
        self.trading_thread = None
        self.stop_event = threading.Event()

        # Signal callback
        self.signal_callback = None

        # Connect to exchange
        if not self.exchange.connect():
            logger.error(f"Failed to connect to {self.exchange_name} exchange. Trading agent disabled.")
        else:
            logger.info(f"Successfully connected to {self.exchange_name} exchange.")
            self._load_current_positions()

    def _create_exchange(self) -> ExchangeInterface:
        """Create an exchange interface based on the exchange name."""
        if self.exchange_name == 'mexc':
            return MEXCInterface(
                api_key=self.api_key,
                api_secret=self.api_secret,
                test_mode=self.test_mode
            )
        elif self.exchange_name == 'binance':
            return BinanceInterface(
                api_key=self.api_key,
                api_secret=self.api_secret,
                test_mode=self.test_mode
            )
        else:
            raise ValueError(f"Unsupported exchange: {self.exchange_name}")

    def _load_current_positions(self):
        """Load current positions from the exchange."""
        for symbol in self.trade_symbols:
            try:
                base_asset, quote_asset = symbol.split('/')
                balance = self.exchange.get_balance(base_asset)

                if balance > 0:
                    self.current_positions[symbol] = balance
                    logger.info(f"Loaded existing position for {symbol}: {balance} {base_asset}")
            except Exception as e:
                logger.error(f"Error loading position for {symbol}: {str(e)}")

    def start(self, signal_callback: Callable = None):
        """Start the trading agent.

        Args:
            signal_callback: Optional callback function to receive trade signals
        """
        if self.active:
            logger.warning("Trading agent is already running.")
            return

        self.active = True
        self.signal_callback = signal_callback
        self.stop_event.clear()

        logger.info(f"Starting trading agent for {self.exchange_name} exchange.")
        logger.info(f"Trading symbols: {', '.join(self.trade_symbols)}")
        logger.info(f"Position size: {self.position_size * 100:.1f}% of available balance")
        logger.info(f"Max trades per day: {self.max_trades_per_day}")
        logger.info(f"Trade cooldown: {self.trade_cooldown_seconds // 60} minutes")

        # Reset trading state
        self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
        self.last_trade_time = {symbol: 0 for symbol in self.trade_symbols}

        # Start trading thread
        self.trading_thread = threading.Thread(target=self._trading_loop)
        self.trading_thread.daemon = True
        self.trading_thread.start()

    def stop(self):
        """Stop the trading agent."""
        if not self.active:
            logger.warning("Trading agent is not running.")
            return

        logger.info("Stopping trading agent...")
        self.active = False
        self.stop_event.set()

        if self.trading_thread and self.trading_thread.is_alive():
            self.trading_thread.join(timeout=10)

        logger.info("Trading agent stopped.")

    def _trading_loop(self):
        """Main trading loop that monitors positions and executes trades."""
        logger.info("Trading loop started.")

        try:
            while self.active and not self.stop_event.is_set():
                # Check positions and update state
                for symbol in self.trade_symbols:
                    try:
                        base_asset, _ = symbol.split('/')
                        current_balance = self.exchange.get_balance(base_asset)

                        # Update position if it has changed
                        if symbol in self.current_positions:
                            prev_balance = self.current_positions[symbol]
                            if abs(current_balance - prev_balance) > 0.001 * prev_balance:
                                logger.info(f"Position updated for {symbol}: {prev_balance} -> {current_balance} {base_asset}")

                        self.current_positions[symbol] = current_balance
                    except Exception as e:
                        logger.error(f"Error checking position for {symbol}: {str(e)}")

                # Sleep for a while
                time.sleep(10)
        except Exception as e:
            logger.error(f"Error in trading loop: {str(e)}")
        finally:
            logger.info("Trading loop stopped.")

    def reset_daily_limits(self):
        """Reset daily trading limits. Call this at the start of each trading day."""
        self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
        logger.info("Daily trading limits reset.")

    def process_signal(self, symbol: str, action: str,
                       confidence: float = None, timestamp: int = None) -> Optional[Dict[str, Any]]:
        """Process a trading signal and execute a trade if conditions are met.

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            action: Trade action ('BUY', 'SELL', 'HOLD')
            confidence: Confidence level of the signal (0.0-1.0)
            timestamp: Timestamp of the signal (unix time)

        Returns:
            dict: Trade information if a trade was executed, None otherwise
        """
        if not self.active:
            logger.warning("Trading agent is not active. Signal ignored.")
            return None

        if symbol not in self.trade_symbols:
            logger.warning(f"Symbol {symbol} is not in the trading symbols list. Signal ignored.")
            return None

        if action not in ['BUY', 'SELL', 'HOLD']:
            logger.warning(f"Invalid action: {action}. Must be 'BUY', 'SELL', or 'HOLD'.")
            return None

        # Log the signal
        confidence_str = f" (confidence: {confidence:.2f})" if confidence is not None else ""
        logger.info(f"Received {action} signal for {symbol}{confidence_str}")

        # Ignore HOLD signals for trading
        if action == 'HOLD':
            return None

        # Check if we can trade based on limits
        current_time = time.time()

        # Check max trades per day
        if self.trades_today.get(symbol, 0) >= self.max_trades_per_day:
            logger.warning(f"Max trades per day reached for {symbol}. Signal ignored.")
            return None

        # Check trade cooldown
        last_trade_time = self.last_trade_time.get(symbol, 0)
        if current_time - last_trade_time < self.trade_cooldown_seconds:
            cooldown_remaining = self.trade_cooldown_seconds - (current_time - last_trade_time)
            logger.warning(f"Trade cooldown active for {symbol}. {cooldown_remaining:.1f} seconds remaining. Signal ignored.")
            return None

        # Check if the action makes sense based on current position
        base_asset, _ = symbol.split('/')
        current_position = self.current_positions.get(symbol, 0)

        if action == 'BUY' and current_position > 0:
            logger.warning(f"Already have a position in {symbol}. BUY signal ignored.")
            return None

        if action == 'SELL' and current_position <= 0:
            logger.warning(f"No position in {symbol} to sell. SELL signal ignored.")
            return None

        # Execute the trade
        try:
            trade_result = self.exchange.execute_trade(
                symbol=symbol,
                action=action,
                percent_of_balance=self.position_size
            )

            if trade_result:
                # Update trading state
                self.trades_today[symbol] = self.trades_today.get(symbol, 0) + 1
                self.last_trade_time[symbol] = current_time

                # Create trade record
                trade_record = {
                    'symbol': symbol,
                    'action': action,
                    'timestamp': timestamp or int(current_time),
                    'confidence': confidence,
                    'order_id': trade_result.get('orderId') if isinstance(trade_result, dict) else None,
                    'status': 'executed'
                }

                # Add to trade history
                self.trade_history.append(trade_record)

                # Call signal callback if provided
                if self.signal_callback:
                    self.signal_callback(trade_record)

                logger.info(f"Successfully executed {action} trade for {symbol}")
                return trade_record
            else:
                logger.error(f"Failed to execute {action} trade for {symbol}")
                return None

        except Exception as e:
            logger.error(f"Error executing trade for {symbol}: {str(e)}")
            return None

    def get_current_positions(self) -> Dict[str, float]:
        """Get current positions.

        Returns:
            dict: Symbol -> position size
        """
        return self.current_positions.copy()

    def get_trade_history(self) -> List[Dict[str, Any]]:
        """Get trade history.

        Returns:
            list: List of trade records
        """
        return self.trade_history.copy()

    def get_exchange_info(self) -> Dict[str, Any]:
        """Get exchange information.

        Returns:
            dict: Exchange information
        """
        return {
            'name': self.exchange_name,
            'test_mode': self.test_mode,
            'active': self.active,
            'trade_symbols': self.trade_symbols,
            'position_size': self.position_size,
            'max_trades_per_day': self.max_trades_per_day,
            'trade_cooldown_seconds': self.trade_cooldown_seconds,
            'trades_today': self.trades_today.copy()
        }
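A hedged usage sketch for TradingAgent on its own, outside the orchestrator; the API credentials and the callback below are placeholders, and only methods defined above are called:

import time
from NN.trading_agent import TradingAgent

def on_trade(trade_record):
    # Receives the trade record dict built in process_signal()
    print(f"Executed {trade_record['action']} on {trade_record['symbol']}")

agent = TradingAgent(
    exchange_name='binance',
    api_key='YOUR_KEY',        # placeholder
    api_secret='YOUR_SECRET',  # placeholder
    test_mode=True,            # stay on the testnet while experimenting
    trade_symbols=['BTC/USDT'],
    position_size=0.1,
    max_trades_per_day=5,
    trade_cooldown_minutes=60,
)
agent.start(signal_callback=on_trade)
agent.process_signal(symbol='BTC/USDT', action='BUY', confidence=0.82, timestamp=int(time.time()))
agent.stop()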
@ -247,9 +247,27 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo
    )

    # Get training data for each timeframe with more data
    logger.info("Loading training data...")
    features_1m = data_interface.get_training_data("1m", n_candles=5000)
    if features_1m is not None:
        logger.info(f"Loaded {len(features_1m)} 1m candles")
    else:
        logger.error("Failed to load 1m data")
        return None

    features_5m = data_interface.get_training_data("5m", n_candles=2500)
    if features_5m is not None:
        logger.info(f"Loaded {len(features_5m)} 5m candles")
    else:
        logger.error("Failed to load 5m data")
        return None

    features_15m = data_interface.get_training_data("15m", n_candles=2500)
    if features_15m is not None:
        logger.info(f"Loaded {len(features_15m)} 15m candles")
    else:
        logger.error("Failed to load 15m data")
        return None

    if features_1m is None or features_5m is None or features_15m is None:
        logger.error("Failed to load training data")
@ -1,5 +1,7 @@
https://github.com/mexcdevelop/mexc-api-sdk/blob/main/README.md#test-new-order

https://mexcdevelop.github.io/apidocs/spot_v3_en/#test-new-order

python mexc_tick_visualizer.py --symbol BTC/USDT --interval 1.0 --candle 60

python main.py --mode live --symbol ETH/USDT --timeframe 1m --use-websocket
124
launch_training.py
Normal file
@ -0,0 +1,124 @@
import os
import sys
import subprocess
import time
import logging
from datetime import datetime
import webbrowser
from threading import Thread

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training_launch.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

def start_tensorboard(port=6007):
    """Start TensorBoard on a specified port"""
    try:
        cmd = f"tensorboard --logdir=runs --port={port}"
        process = subprocess.Popen(cmd, shell=True)
        logger.info(f"Started TensorBoard on port {port}")
        return process
    except Exception as e:
        logger.error(f"Failed to start TensorBoard: {str(e)}")
        return None

def start_web_chart():
    """Start the web chart server"""
    try:
        cmd = "python main.py --symbols BTC/USDT ETH/USDT SOL/USDT --timeframes 1m 5m 15m --mode realtime"
        process = subprocess.Popen(cmd, shell=True)
        logger.info("Started web chart server")
        return process
    except Exception as e:
        logger.error(f"Failed to start web chart server: {str(e)}")
        return None

def start_training():
    """Start the RL training process"""
    try:
        cmd = "python NN/train_rl.py"
        process = subprocess.Popen(cmd, shell=True)
        logger.info("Started RL training process")
        return process
    except Exception as e:
        logger.error(f"Failed to start training process: {str(e)}")
        return None

def open_web_interfaces():
    """Open web browsers for TensorBoard and chart after a delay"""
    time.sleep(5)  # Wait for servers to start
    try:
        webbrowser.open('http://localhost:6007')  # TensorBoard
        webbrowser.open('http://localhost:8050')  # Web chart
    except Exception as e:
        logger.error(f"Failed to open web interfaces: {str(e)}")

def monitor_processes(processes):
    """Monitor running processes and log any unexpected terminations"""
    while True:
        for name, process in processes.items():
            if process and process.poll() is not None:
                logger.error(f"{name} process terminated unexpectedly")
                return False
        time.sleep(1)

def main():
    """Main function to orchestrate the training environment"""
    logger.info("Starting training environment setup...")

    # Start TensorBoard
    tensorboard_process = start_tensorboard(port=6007)
    if not tensorboard_process:
        logger.error("Failed to start TensorBoard")
        return

    # Start web chart
    web_chart_process = start_web_chart()
    if not web_chart_process:
        tensorboard_process.terminate()
        logger.error("Failed to start web chart")
        return

    # Start training
    training_process = start_training()
    if not training_process:
        tensorboard_process.terminate()
        web_chart_process.terminate()
        logger.error("Failed to start training")
        return

    # Open web interfaces in a separate thread
    Thread(target=open_web_interfaces).start()

    # Monitor processes
    processes = {
        'tensorboard': tensorboard_process,
        'web_chart': web_chart_process,
        'training': training_process
    }

    try:
        if not monitor_processes(processes):
            raise Exception("One or more processes terminated unexpectedly")
    except KeyboardInterrupt:
        logger.info("Received shutdown signal")
    except Exception as e:
        logger.error(f"Error in monitoring: {str(e)}")
    finally:
        # Cleanup
        logger.info("Shutting down training environment...")
        for name, process in processes.items():
            if process:
                process.terminate()
                logger.info(f"Terminated {name} process")
        logger.info("Training environment shutdown complete")

if __name__ == "__main__":
    main()
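Assuming the defaults above, the launcher is run as:

python launch_training.py

It starts TensorBoard on http://localhost:6007 and opens http://localhost:8050 for the chart; the 8050 port is only implied by the browser call here (the usual Dash default), not configured in this script.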
155
trading_main.py
Normal file
@ -0,0 +1,155 @@
import os
import time
import logging
import sys
import argparse
import json

# Add the NN directory to the Python path
sys.path.append(os.path.abspath("NN"))

from NN.main import load_model
from NN.neural_network_orchestrator import NeuralNetworkOrchestrator
from NN.realtime_data_interface import RealtimeDataInterface

# Initialize logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("trading_bot.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

def main():
    """Main function for the trading bot."""
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Trading Bot with Neural Network Integration")
    parser.add_argument('--symbols', nargs='+', default=["BTC/USDT", "ETH/USDT"],
                        help='Trading symbols to monitor')
    parser.add_argument('--timeframes', nargs='+', default=["1m", "5m", "1h", "4h", "1d"],
                        help='Timeframes to monitor')
    parser.add_argument('--window-size', type=int, default=20,
                        help='Window size for model input')
    parser.add_argument('--output-size', type=int, default=3,
                        help='Output size of the model (3 for BUY/HOLD/SELL)')
    parser.add_argument('--model-type', type=str, default="cnn", choices=["cnn", "lstm", "mlp"],
                        help='Type of neural network model')
    parser.add_argument('--mode', type=str, default="realtime", choices=["realtime", "backtest"],
                        help='Trading mode')
    parser.add_argument('--exchange', type=str, default="binance", choices=["binance", "mexc"],
                        help='Exchange to use for trading')
    parser.add_argument('--api-key', type=str, default=None,
                        help='API key for the exchange')
    parser.add_argument('--api-secret', type=str, default=None,
                        help='API secret for the exchange')
    parser.add_argument('--test-mode', action='store_true',
                        help='Use test/sandbox exchange environment')
    parser.add_argument('--position-size', type=float, default=0.1,
                        help='Position size as a fraction of total balance (0.0-1.0)')
    parser.add_argument('--max-trades-per-day', type=int, default=5,
                        help='Maximum number of trades per day')
    parser.add_argument('--trade-cooldown', type=int, default=60,
                        help='Trade cooldown period in minutes')
    parser.add_argument('--config-file', type=str, default=None,
                        help='Path to configuration file')

    args = parser.parse_args()

    # Load configuration from file if provided
    if args.config_file and os.path.exists(args.config_file):
        with open(args.config_file, 'r') as f:
            config = json.load(f)
        # Override config with command-line args
        for key, value in vars(args).items():
            if key != 'config_file' and value is not None:
                config[key] = value
    else:
        # Use command-line args as config
        config = vars(args)

    # Initialize real-time charts and data interfaces
    try:
        from realtime import RealTimeChart

        # Create a real-time chart for each symbol
        charts = {}
        for symbol in config['symbols']:
            charts[symbol] = RealTimeChart(symbol=symbol)

        main_chart = charts[config['symbols'][0]]

        # Create a data interface for retrieving market data
        data_interface = RealtimeDataInterface(symbols=config['symbols'], chart=main_chart)

        # Load trained model
        model_type = os.environ.get("NN_MODEL_TYPE", config['model_type'])
        model = load_model(
            model_type=model_type,
            input_shape=(config['window_size'], len(config['symbols']), 5),  # 5 features (OHLCV)
            output_size=config['output_size']
        )

        # Configure trading agent
        exchange_config = {
            "exchange": config['exchange'],
            "api_key": config['api_key'],
            "api_secret": config['api_secret'],
            "test_mode": config['test_mode'],
            "trade_symbols": config['symbols'],
            "position_size": config['position_size'],
            "max_trades_per_day": config['max_trades_per_day'],
            "trade_cooldown_minutes": config['trade_cooldown']
        }

        # Initialize neural network orchestrator
        orchestrator = NeuralNetworkOrchestrator(
            model=model,
            data_interface=data_interface,
            chart=main_chart,
            symbols=config['symbols'],
            timeframes=config['timeframes'],
            window_size=config['window_size'],
            num_features=5,  # OHLCV
            output_size=config['output_size'],
            exchange_config=exchange_config
        )

        # Start data collection
        logger.info("Starting data collection threads...")
        for symbol in config['symbols']:
            charts[symbol].start()

        # Start neural network inference
        if os.environ.get("ENABLE_NN_MODELS", "0") == "1":
            logger.info("Starting neural network inference...")
            orchestrator.start_inference()
        else:
            logger.info("Neural network models disabled. Set ENABLE_NN_MODELS=1 to enable.")

        # Start web servers for chart display
        logger.info("Starting web servers for chart display...")
        main_chart.start_server()

        logger.info("Trading bot initialized successfully. Press Ctrl+C to exit.")

        # Keep the main thread alive
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            logger.info("Keyboard interrupt received. Shutting down...")
            # Stop all threads
            for symbol in config['symbols']:
                charts[symbol].stop()
            orchestrator.stop_inference()
            logger.info("Trading bot stopped.")

    except Exception as e:
        logger.error(f"Error in main function: {str(e)}", exc_info=True)
        sys.exit(1)

if __name__ == "__main__":
    main()
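A hedged invocation sketch for this entry point; the values are illustrative, but the flags and the ENABLE_NN_MODELS variable match the definitions above:

python trading_main.py --symbols BTC/USDT ETH/USDT --timeframes 1m 5m --exchange binance --test-mode --position-size 0.05 --max-trades-per-day 3

ENABLE_NN_MODELS=1 python trading_main.py --test-mode

A --config-file JSON can supply the same settings, but note that the override loop above copies every non-None argparse value over the file contents, and most arguments have non-None defaults, so file values only take effect for keys whose flags default to None (api_key and api_secret).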