showing trades on realtime chart - chart broken

This commit is contained in:
Dobromir Popov 2025-03-31 14:22:33 +03:00
parent 0ad9484d56
commit a46b2c74f8
14 changed files with 3182 additions and 76 deletions

5
NN/exchanges/__init__.py Normal file
View File

@ -0,0 +1,5 @@
from .exchange_interface import ExchangeInterface
from .mexc_interface import MEXCInterface
from .binance_interface import BinanceInterface
__all__ = ['ExchangeInterface', 'MEXCInterface', 'BinanceInterface']

View File

@ -0,0 +1,276 @@
import logging
import time
from typing import Dict, Any, List, Optional
import requests
import hmac
import hashlib
from urllib.parse import urlencode
from .exchange_interface import ExchangeInterface
logger = logging.getLogger(__name__)
class BinanceInterface(ExchangeInterface):
"""Binance Exchange API Interface"""
def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True):
"""Initialize Binance exchange interface.
Args:
api_key: Binance API key
api_secret: Binance API secret
test_mode: If True, use testnet environment
"""
super().__init__(api_key, api_secret, test_mode)
# Use testnet URLs if in test mode
if test_mode:
self.base_url = "https://testnet.binance.vision"
else:
self.base_url = "https://api.binance.com"
self.api_version = "v3"
def connect(self) -> bool:
"""Connect to Binance API. This is a no-op for REST API."""
if not self.api_key or not self.api_secret:
logger.warning("Binance API credentials not provided. Running in read-only mode.")
return False
try:
# Test connection by pinging server and checking account info
ping_result = self._send_public_request('GET', 'ping')
if self.api_key and self.api_secret:
# Check account connectivity
self.get_account_info()
logger.info(f"Successfully connected to Binance API ({'testnet' if self.test_mode else 'live'})")
return True
except Exception as e:
logger.error(f"Failed to connect to Binance API: {str(e)}")
return False
def _generate_signature(self, params: Dict[str, Any]) -> str:
"""Generate signature for authenticated requests."""
query_string = urlencode(params)
signature = hmac.new(
self.api_secret.encode('utf-8'),
query_string.encode('utf-8'),
hashlib.sha256
).hexdigest()
return signature
def _send_public_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
"""Send public request to Binance API."""
url = f"{self.base_url}/api/{self.api_version}/{endpoint}"
try:
if method.upper() == 'GET':
response = requests.get(url, params=params)
else:
response = requests.post(url, json=params)
response.raise_for_status()
return response.json()
except Exception as e:
logger.error(f"Error in public request to {endpoint}: {str(e)}")
raise
def _send_private_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
"""Send private/authenticated request to Binance API."""
if not self.api_key or not self.api_secret:
raise ValueError("API key and secret are required for private requests")
if params is None:
params = {}
# Add timestamp
params['timestamp'] = int(time.time() * 1000)
# Generate signature
signature = self._generate_signature(params)
params['signature'] = signature
# Set headers
headers = {
'X-MBX-APIKEY': self.api_key
}
url = f"{self.base_url}/api/{self.api_version}/{endpoint}"
try:
if method.upper() == 'GET':
response = requests.get(url, params=params, headers=headers)
elif method.upper() == 'POST':
response = requests.post(url, data=params, headers=headers)
elif method.upper() == 'DELETE':
response = requests.delete(url, params=params, headers=headers)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
# Log detailed error if available
if response.status_code != 200:
logger.error(f"Binance API error: {response.text}")
response.raise_for_status()
return response.json()
except Exception as e:
logger.error(f"Error in private request to {endpoint}: {str(e)}")
raise
def get_account_info(self) -> Dict[str, Any]:
"""Get account information."""
return self._send_private_request('GET', 'account')
def get_balance(self, asset: str) -> float:
"""Get balance of a specific asset.
Args:
asset: Asset symbol (e.g., 'BTC', 'USDT')
Returns:
float: Available balance of the asset
"""
try:
account_info = self._send_private_request('GET', 'account')
balances = account_info.get('balances', [])
for balance in balances:
if balance['asset'] == asset:
return float(balance['free'])
# Asset not found
return 0.0
except Exception as e:
logger.error(f"Error getting balance for {asset}: {str(e)}")
return 0.0
def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker data for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
Returns:
dict: Ticker data including price information
"""
binance_symbol = symbol.replace('/', '')
try:
ticker = self._send_public_request('GET', 'ticker/24hr', {'symbol': binance_symbol})
# Convert to a standardized format
result = {
'symbol': symbol,
'bid': float(ticker['bidPrice']),
'ask': float(ticker['askPrice']),
'last': float(ticker['lastPrice']),
'volume': float(ticker['volume']),
'timestamp': int(ticker['closeTime'])
}
return result
except Exception as e:
logger.error(f"Error getting ticker for {symbol}: {str(e)}")
raise
def place_order(self, symbol: str, side: str, order_type: str,
quantity: float, price: float = None) -> Dict[str, Any]:
"""Place an order on the exchange.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
side: Order side ('buy' or 'sell')
order_type: Order type ('market', 'limit', etc.)
quantity: Order quantity
price: Order price (for limit orders)
Returns:
dict: Order information including order ID
"""
binance_symbol = symbol.replace('/', '')
params = {
'symbol': binance_symbol,
'side': side.upper(),
'type': order_type.upper(),
'quantity': quantity,
}
if order_type.lower() == 'limit' and price is not None:
params['price'] = price
params['timeInForce'] = 'GTC' # Good Till Cancelled
# Use test order endpoint in test mode
endpoint = 'order/test' if self.test_mode else 'order'
try:
order_result = self._send_private_request('POST', endpoint, params)
return order_result
except Exception as e:
logger.error(f"Error placing {side} {order_type} order for {symbol}: {str(e)}")
raise
def cancel_order(self, symbol: str, order_id: str) -> bool:
"""Cancel an existing order.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
order_id: ID of the order to cancel
Returns:
bool: True if cancellation successful, False otherwise
"""
binance_symbol = symbol.replace('/', '')
params = {
'symbol': binance_symbol,
'orderId': order_id
}
try:
cancel_result = self._send_private_request('DELETE', 'order', params)
return True
except Exception as e:
logger.error(f"Error cancelling order {order_id} for {symbol}: {str(e)}")
return False
def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
"""Get status of an existing order.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
order_id: ID of the order
Returns:
dict: Order status information
"""
binance_symbol = symbol.replace('/', '')
params = {
'symbol': binance_symbol,
'orderId': order_id
}
try:
order_info = self._send_private_request('GET', 'order', params)
return order_info
except Exception as e:
logger.error(f"Error getting order status for {order_id} on {symbol}: {str(e)}")
raise
def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]:
"""Get all open orders, optionally filtered by symbol.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT'), or None for all symbols
Returns:
list: List of open orders
"""
params = {}
if symbol:
params['symbol'] = symbol.replace('/', '')
try:
open_orders = self._send_private_request('GET', 'openOrders', params)
return open_orders
except Exception as e:
logger.error(f"Error getting open orders: {str(e)}")
return []

View File

@ -0,0 +1,191 @@
import abc
import logging
from typing import Dict, Any, List, Tuple, Optional
logger = logging.getLogger(__name__)
class ExchangeInterface(abc.ABC):
"""Base class for all exchange interfaces.
This abstract class defines the required methods that all exchange
implementations must provide to ensure compatibility with the trading system.
"""
def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True):
"""Initialize the exchange interface.
Args:
api_key: API key for the exchange
api_secret: API secret for the exchange
test_mode: If True, use test/sandbox environment
"""
self.api_key = api_key
self.api_secret = api_secret
self.test_mode = test_mode
self.client = None
self.last_price_cache = {}
@abc.abstractmethod
def connect(self) -> bool:
"""Connect to the exchange API.
Returns:
bool: True if connection successful, False otherwise
"""
pass
@abc.abstractmethod
def get_balance(self, asset: str) -> float:
"""Get balance of a specific asset.
Args:
asset: Asset symbol (e.g., 'BTC', 'USDT')
Returns:
float: Available balance of the asset
"""
pass
@abc.abstractmethod
def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker data for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
Returns:
dict: Ticker data including price information
"""
pass
@abc.abstractmethod
def place_order(self, symbol: str, side: str, order_type: str,
quantity: float, price: float = None) -> Dict[str, Any]:
"""Place an order on the exchange.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
side: Order side ('buy' or 'sell')
order_type: Order type ('market', 'limit', etc.)
quantity: Order quantity
price: Order price (for limit orders)
Returns:
dict: Order information including order ID
"""
pass
@abc.abstractmethod
def cancel_order(self, symbol: str, order_id: str) -> bool:
"""Cancel an existing order.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
order_id: ID of the order to cancel
Returns:
bool: True if cancellation successful, False otherwise
"""
pass
@abc.abstractmethod
def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
"""Get status of an existing order.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
order_id: ID of the order
Returns:
dict: Order status information
"""
pass
@abc.abstractmethod
def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]:
"""Get all open orders, optionally filtered by symbol.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT'), or None for all symbols
Returns:
list: List of open orders
"""
pass
def get_last_price(self, symbol: str) -> float:
"""Get last known price for a symbol, may use cached value.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
Returns:
float: Last price
"""
try:
ticker = self.get_ticker(symbol)
price = float(ticker['last'])
self.last_price_cache[symbol] = price
return price
except Exception as e:
logger.error(f"Error getting price for {symbol}: {str(e)}")
# Return cached price if available
return self.last_price_cache.get(symbol, 0.0)
def execute_trade(self, symbol: str, action: str, quantity: float = None,
percent_of_balance: float = None) -> Optional[Dict[str, Any]]:
"""Execute a trade based on a signal.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
action: Trade action ('BUY', 'SELL')
quantity: Specific quantity to trade
percent_of_balance: Alternative to quantity - percentage of available balance to use
Returns:
dict: Order information or None if order failed
"""
if action not in ['BUY', 'SELL']:
logger.error(f"Invalid action: {action}. Must be 'BUY' or 'SELL'")
return None
side = action.lower()
try:
# Determine base and quote assets from symbol (e.g., BTC/USDT -> BTC, USDT)
base_asset, quote_asset = symbol.split('/')
# Calculate quantity if percent_of_balance is provided
if quantity is None and percent_of_balance is not None:
if percent_of_balance <= 0 or percent_of_balance > 1:
logger.error(f"Invalid percent_of_balance: {percent_of_balance}. Must be between 0 and 1")
return None
if side == 'buy':
# For buy, use quote asset (e.g., USDT)
balance = self.get_balance(quote_asset)
price = self.get_last_price(symbol)
quantity = (balance * percent_of_balance) / price
else:
# For sell, use base asset (e.g., BTC)
balance = self.get_balance(base_asset)
quantity = balance * percent_of_balance
if not quantity or quantity <= 0:
logger.error(f"Invalid quantity: {quantity}")
return None
# Place market order
order = self.place_order(
symbol=symbol,
side=side,
order_type='market',
quantity=quantity
)
logger.info(f"Executed {side.upper()} order for {quantity} {base_asset} at market price")
return order
except Exception as e:
logger.error(f"Error executing {action} trade for {symbol}: {str(e)}")
return None

View File

@ -0,0 +1,258 @@
import logging
import time
from typing import Dict, Any, List, Optional
import requests
import hmac
import hashlib
from urllib.parse import urlencode
from .exchange_interface import ExchangeInterface
logger = logging.getLogger(__name__)
class MEXCInterface(ExchangeInterface):
"""MEXC Exchange API Interface"""
def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True):
"""Initialize MEXC exchange interface.
Args:
api_key: MEXC API key
api_secret: MEXC API secret
test_mode: If True, use test/sandbox environment (Note: MEXC doesn't have a true sandbox)
"""
super().__init__(api_key, api_secret, test_mode)
self.base_url = "https://api.mexc.com"
self.api_version = "v3"
def connect(self) -> bool:
"""Connect to MEXC API. This is a no-op for REST API."""
if not self.api_key or not self.api_secret:
logger.warning("MEXC API credentials not provided. Running in read-only mode.")
return False
try:
# Test connection by getting account info
self.get_account_info()
logger.info("Successfully connected to MEXC API")
return True
except Exception as e:
logger.error(f"Failed to connect to MEXC API: {str(e)}")
return False
def _generate_signature(self, params: Dict[str, Any]) -> str:
"""Generate signature for authenticated requests."""
query_string = urlencode(params)
signature = hmac.new(
self.api_secret.encode('utf-8'),
query_string.encode('utf-8'),
hashlib.sha256
).hexdigest()
return signature
def _send_public_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
"""Send public request to MEXC API."""
url = f"{self.base_url}/{self.api_version}/{endpoint}"
try:
if method.upper() == 'GET':
response = requests.get(url, params=params)
else:
response = requests.post(url, json=params)
response.raise_for_status()
return response.json()
except Exception as e:
logger.error(f"Error in public request to {endpoint}: {str(e)}")
raise
def _send_private_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Dict[str, Any]:
"""Send private/authenticated request to MEXC API."""
if not self.api_key or not self.api_secret:
raise ValueError("API key and secret are required for private requests")
if params is None:
params = {}
# Add timestamp
params['timestamp'] = int(time.time() * 1000)
# Generate signature
signature = self._generate_signature(params)
params['signature'] = signature
# Set headers
headers = {
'X-MEXC-APIKEY': self.api_key
}
url = f"{self.base_url}/{self.api_version}/{endpoint}"
try:
if method.upper() == 'GET':
response = requests.get(url, params=params, headers=headers)
elif method.upper() == 'POST':
response = requests.post(url, json=params, headers=headers)
elif method.upper() == 'DELETE':
response = requests.delete(url, params=params, headers=headers)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
return response.json()
except Exception as e:
logger.error(f"Error in private request to {endpoint}: {str(e)}")
raise
def get_account_info(self) -> Dict[str, Any]:
"""Get account information."""
return self._send_private_request('GET', 'account')
def get_balance(self, asset: str) -> float:
"""Get balance of a specific asset.
Args:
asset: Asset symbol (e.g., 'BTC', 'USDT')
Returns:
float: Available balance of the asset
"""
try:
account_info = self._send_private_request('GET', 'account')
balances = account_info.get('balances', [])
for balance in balances:
if balance['asset'] == asset:
return float(balance['free'])
# Asset not found
return 0.0
except Exception as e:
logger.error(f"Error getting balance for {asset}: {str(e)}")
return 0.0
def get_ticker(self, symbol: str) -> Dict[str, Any]:
"""Get current ticker data for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
Returns:
dict: Ticker data including price information
"""
mexc_symbol = symbol.replace('/', '')
try:
ticker = self._send_public_request('GET', 'ticker/24hr', {'symbol': mexc_symbol})
# Convert to a standardized format
result = {
'symbol': symbol,
'bid': float(ticker['bidPrice']),
'ask': float(ticker['askPrice']),
'last': float(ticker['lastPrice']),
'volume': float(ticker['volume']),
'timestamp': int(ticker['closeTime'])
}
return result
except Exception as e:
logger.error(f"Error getting ticker for {symbol}: {str(e)}")
raise
def place_order(self, symbol: str, side: str, order_type: str,
quantity: float, price: float = None) -> Dict[str, Any]:
"""Place an order on the exchange.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
side: Order side ('buy' or 'sell')
order_type: Order type ('market', 'limit', etc.)
quantity: Order quantity
price: Order price (for limit orders)
Returns:
dict: Order information including order ID
"""
mexc_symbol = symbol.replace('/', '')
params = {
'symbol': mexc_symbol,
'side': side.upper(),
'type': order_type.upper(),
'quantity': quantity,
}
if order_type.lower() == 'limit' and price is not None:
params['price'] = price
params['timeInForce'] = 'GTC' # Good Till Cancelled
try:
order_result = self._send_private_request('POST', 'order', params)
return order_result
except Exception as e:
logger.error(f"Error placing {side} {order_type} order for {symbol}: {str(e)}")
raise
def cancel_order(self, symbol: str, order_id: str) -> bool:
"""Cancel an existing order.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
order_id: ID of the order to cancel
Returns:
bool: True if cancellation successful, False otherwise
"""
mexc_symbol = symbol.replace('/', '')
params = {
'symbol': mexc_symbol,
'orderId': order_id
}
try:
cancel_result = self._send_private_request('DELETE', 'order', params)
return True
except Exception as e:
logger.error(f"Error cancelling order {order_id} for {symbol}: {str(e)}")
return False
def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
"""Get status of an existing order.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
order_id: ID of the order
Returns:
dict: Order status information
"""
mexc_symbol = symbol.replace('/', '')
params = {
'symbol': mexc_symbol,
'orderId': order_id
}
try:
order_info = self._send_private_request('GET', 'order', params)
return order_info
except Exception as e:
logger.error(f"Error getting order status for {order_id} on {symbol}: {str(e)}")
raise
def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]:
"""Get all open orders, optionally filtered by symbol.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT'), or None for all symbols
Returns:
list: List of open orders
"""
params = {}
if symbol:
params['symbol'] = symbol.replace('/', '')
try:
open_orders = self._send_private_request('GET', 'openOrders', params)
return open_orders
except Exception as e:
logger.error(f"Error getting open orders: {str(e)}")
return []

244
NN/main.py Normal file
View File

@ -0,0 +1,244 @@
"""
Neural Network Trading System Main Module (Compatibility Layer)
This module serves as a compatibility layer for the realtime.py module.
It re-exports the functionality from realtime_main.py that is needed by realtime.py.
"""
import os
import sys
import logging
from datetime import datetime
import numpy as np
# Configure logging
logger = logging.getLogger('NN')
logger.setLevel(logging.INFO)
# Re-export everything from realtime_main.py
from .realtime_main import (
parse_arguments,
realtime,
train,
predict
)
# Create a class that realtime.py expects
class NeuralNetworkOrchestrator:
"""
Orchestrates the neural network operations.
"""
def __init__(self, config):
"""
Initialize the orchestrator with configuration.
Args:
config (dict): Configuration parameters
"""
self.config = config
self.symbol = config.get('symbol', 'BTC/USDT')
self.timeframes = config.get('timeframes', ['1m', '5m', '1h', '4h'])
self.window_size = config.get('window_size', 20)
self.n_features = config.get('n_features', 5)
self.output_size = config.get('output_size', 3)
self.model_dir = config.get('model_dir', 'NN/models/saved')
self.data_dir = config.get('data_dir', 'NN/data')
self.model = None
self.data_interface = None
# Initialize with default values in case imports fail
self.model_initialized = False
self.data_initialized = False
# Import necessary modules dynamically
try:
from .utils.data_interface import DataInterface
# Initialize data interface
self.data_interface = DataInterface(
symbol=self.symbol,
timeframes=self.timeframes
)
self.data_initialized = True
logger.info(f"Data interface initialized for {self.symbol}")
try:
from .models.cnn_model_pytorch import CNNModelPyTorch as Model
# Initialize model
feature_count = self.data_interface.get_feature_count() if hasattr(self.data_interface, 'get_feature_count') else 5
try:
# First try with expected parameters
self.model = Model(
window_size=self.window_size,
num_features=feature_count,
output_size=self.output_size,
timeframes=self.timeframes
)
except TypeError as e:
logger.warning(f"TypeError in model initialization with num_features: {str(e)}")
# Try alternate parameter naming
try:
self.model = Model(
input_shape=(self.window_size, feature_count),
output_size=self.output_size
)
logger.info("Model initialized with alternate parameters")
except Exception as ex:
logger.error(f"Failed to initialize model with alternate parameters: {str(ex)}")
self.model = DummyModel()
# Try to load the best model
self._load_model()
self.model_initialized = True
logger.info("Model initialized successfully")
except Exception as e:
logger.error(f"Error initializing model: {str(e)}")
import traceback
logger.error(traceback.format_exc())
self.model = DummyModel()
logger.info(f"NeuralNetworkOrchestrator initialized with config: {config}")
except Exception as e:
logger.error(f"Error initializing NeuralNetworkOrchestrator: {str(e)}")
import traceback
logger.error(traceback.format_exc())
self.model = DummyModel()
def _load_model(self):
"""Load the best trained model from available files"""
try:
model_paths = [
os.path.join(self.model_dir, "dqn_agent_best_policy.pt"),
os.path.join(self.model_dir, "cnn_model_best.pt"),
os.path.join("models/saved", "dqn_agent_best_policy.pt"),
os.path.join("models/saved", "cnn_model_best.pt")
]
for model_path in model_paths:
if os.path.exists(model_path):
try:
self.model.load(model_path)
logger.info(f"Loaded model from {model_path}")
return True
except Exception as e:
logger.warning(f"Failed to load model from {model_path}: {str(e)}")
continue
logger.warning("No trained model found, using dummy model")
self.model = DummyModel()
return False
except Exception as e:
logger.error(f"Error loading model: {str(e)}")
self.model = DummyModel()
return False
def run_inference_pipeline(self, model_type='cnn', timeframe='1h'):
"""
Run the inference pipeline using the trained model.
Args:
model_type (str): Type of model to use (cnn, transformer, etc.)
timeframe (str): Timeframe to use for inference
Returns:
dict: Inference result
"""
try:
# Check if we have a model
if not hasattr(self, 'model') or self.model is None:
logger.warning("No model available, initializing dummy model")
self.model = DummyModel()
# Check if we have a data interface
if not hasattr(self, 'data_interface') or self.data_interface is None:
logger.warning("No data interface available")
# Return a dummy prediction
return self._get_dummy_prediction()
# Prepare input data for the selected timeframe
X, timestamp = self.data_interface.prepare_realtime_input(
timeframe=timeframe,
n_candles=self.window_size + 10, # Extra candles for safety
window_size=self.window_size
)
if X is None:
logger.warning(f"No data available for {self.symbol}")
return self._get_dummy_prediction()
# Get model predictions
action_probs, price_pred = self.model.predict(X)
# Convert predictions to action
action_idx = np.argmax(action_probs) if hasattr(action_probs, 'argmax') else 1 # Default to HOLD
action_names = ['SELL', 'HOLD', 'BUY']
action = action_names[action_idx]
# Format timestamp
if not isinstance(timestamp, str):
try:
if hasattr(timestamp, 'isoformat'): # If it's already a datetime-like object
timestamp = timestamp.isoformat()
else: # If it's a numeric timestamp
timestamp = datetime.fromtimestamp(float(timestamp)/1000).isoformat()
except (TypeError, ValueError):
timestamp = datetime.now().isoformat()
# Return result
result = {
'timestamp': timestamp,
'action': action,
'action_index': int(action_idx),
'probability': float(action_probs[action_idx]) if hasattr(action_probs, '__getitem__') else 0.33,
'probabilities': {name: float(prob) for name, prob in zip(action_names, action_probs)} if hasattr(action_probs, '__iter__') else {'SELL': 0.33, 'HOLD': 0.34, 'BUY': 0.33},
'price_prediction': float(price_pred) if price_pred is not None else None
}
logger.info(f"Inference result: {result}")
return result
except Exception as e:
logger.error(f"Error in inference pipeline: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return self._get_dummy_prediction()
def _get_dummy_prediction(self):
"""Return a dummy prediction when model or data is unavailable"""
action_names = ['SELL', 'HOLD', 'BUY']
action_idx = 1 # Default to HOLD
timestamp = datetime.now().isoformat()
return {
'timestamp': timestamp,
'action': 'HOLD',
'action_index': action_idx,
'probability': 0.8,
'probabilities': {'SELL': 0.1, 'HOLD': 0.8, 'BUY': 0.1},
'price_prediction': None,
'is_dummy': True
}
class DummyModel:
"""Dummy model that returns random predictions"""
def __init__(self):
logger.info("Initializing dummy model")
def predict(self, X):
"""Return random predictions"""
# Generate random probabilities for SELL, HOLD, BUY
action_probs = np.array([0.1, 0.8, 0.1]) # Bias towards HOLD
# Generate a random price prediction (None for now)
price_pred = None
return action_probs, price_pred
def load(self, model_path):
"""Dummy load method"""
logger.info(f"Dummy model pretending to load from {model_path}")
return True

View File

@ -72,6 +72,9 @@ class CNNPyTorch(nn.Module):
"""
super(CNNPyTorch, self).__init__()
# Set device
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
window_size, num_features = input_shape
self.window_size = window_size
@ -225,75 +228,94 @@ class CNNModelPyTorch:
predictions with the CNN model, optimized for short-term trading opportunities.
"""
def __init__(self, window_size, num_features, output_size=3, timeframes=None):
def __init__(self, window_size=20, timeframes=None, output_size=3, num_pairs=3):
"""
Initialize the CNN model.
Args:
window_size (int): Size of the input window
num_features (int): Number of features in the input data
output_size (int): Size of the output (default: 3 for BUY/HOLD/SELL)
timeframes (list): List of timeframes used (for logging)
window_size (int): Size of the sliding window
timeframes (list): List of timeframes used
output_size (int): Number of output classes (3 for BUY/HOLD/SELL)
num_pairs (int): Number of trading pairs to analyze in parallel (default 3)
"""
# Action tracking
self.action_counts = {
'BUY': 0,
'SELL': 0,
'HOLD': 0
}
self.window_size = window_size
self.num_features = num_features
self.timeframes = timeframes if timeframes else ["1m", "5m", "15m"]
self.output_size = output_size
self.timeframes = timeframes or []
self.num_pairs = num_pairs
# Determine device (GPU or CPU)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {self.device}")
# Calculate total features (5 OHLCV features per timeframe per pair)
self.total_features = len(self.timeframes) * 5 * self.num_pairs
# Initialize model
self.model = None
self.build_model()
# Build the model
logger.info(f"Building PyTorch CNN model with window_size={window_size}, "
f"num_features={self.total_features}, output_size={output_size}, "
f"num_pairs={num_pairs}")
# Initialize training history
self.history = {
'loss': [],
'val_loss': [],
'accuracy': [],
'val_accuracy': []
}
# Calculate channel sizes that are divisible by num_pairs
base_channels = 96 # 96 is divisible by 3
self.model = nn.Sequential(
# First convolutional layer - process each pair's features
nn.Sequential(
nn.Conv1d(self.total_features, base_channels, kernel_size=5, padding=2, groups=num_pairs),
nn.ReLU(),
nn.BatchNorm1d(base_channels),
nn.Dropout(0.2)
),
# Sensitivity parameters for high-leverage trading
self.confidence_threshold = 0.65 # Minimum confidence for trading actions
self.max_consecutive_same_action = 3 # Limit consecutive identical actions
self.last_actions = [] # Track recent actions
# Second convolutional layer - start mixing pair information
nn.Sequential(
nn.Conv1d(base_channels, base_channels*2, kernel_size=3, padding=1),
nn.ReLU(),
nn.BatchNorm1d(base_channels*2),
nn.Dropout(0.2)
),
def build_model(self):
"""Build the CNN model architecture"""
logger.info(f"Building PyTorch CNN model with window_size={self.window_size}, "
f"num_features={self.num_features}, output_size={self.output_size}")
# Third convolutional layer - deeper feature extraction
nn.Sequential(
nn.Conv1d(base_channels*2, base_channels*4, kernel_size=3, padding=1),
nn.ReLU(),
nn.BatchNorm1d(base_channels*4),
nn.Dropout(0.2)
),
# Ensure window size is not less than the actual input
input_window_size = max(self.window_size, 20) # Use at least 20 as minimum window size
# Global average pooling
nn.AdaptiveAvgPool1d(1),
self.model = CNNPyTorch(
input_shape=(input_window_size, self.num_features),
output_size=self.output_size
# Flatten
nn.Flatten(),
# Dense layers for action prediction with cross-pair attention
nn.Sequential(
nn.Linear(base_channels*4, base_channels*2),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(base_channels*2, base_channels),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(base_channels, output_size * num_pairs) # Output for each pair
)
).to(self.device)
# Initialize optimizer with higher learning rate for faster adaptation
self.optimizer = optim.Adam(self.model.parameters(), lr=0.002)
# Learning rate scheduler with faster decay
# Initialize optimizer and loss function
self.optimizer = optim.Adam(self.model.parameters(), lr=0.0005)
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, mode='max', factor=0.6, patience=6, verbose=True
self.optimizer, mode='max', factor=0.5, patience=5, verbose=True
)
self.criterion = nn.CrossEntropyLoss()
# Initialize loss function with higher weights for BUY/SELL
class_weights = torch.tensor([7.0, 1.0, 7.0]).to(self.device) # Even higher weights for BUY/SELL
self.criterion = nn.CrossEntropyLoss(weight=class_weights)
# Initialize metrics tracking
self.train_losses = []
self.val_losses = []
self.train_accuracies = []
self.val_accuracies = []
logger.info(f"Model built successfully with {sum(p.numel() for p in self.model.parameters())} parameters")
# Sensitivity parameters for high-leverage trading
self.confidence_threshold = 0.65
self.max_consecutive_same_action = 3
self.last_actions = [[] for _ in range(num_pairs)] # Track recent actions per pair
def compute_trading_loss(self, action_probs, price_pred, targets, future_prices=None):
"""
Custom loss function that prioritizes profitable trades
@ -644,12 +666,12 @@ class CNNModelPyTorch:
action_probs_np[:, 2] *= 1.3 # Boost BUY probabilities
# Implement signal filtering based on previous actions to avoid oscillation
if len(self.last_actions) >= self.max_consecutive_same_action:
if len(self.last_actions[0]) >= self.max_consecutive_same_action:
# Check for too many consecutive identical actions
if all(a == 0 for a in self.last_actions[-self.max_consecutive_same_action:]):
if all(a == 0 for a in self.last_actions[0][-self.max_consecutive_same_action:]):
# Too many consecutive SELL - reduce sell probability
action_probs_np[:, 0] *= 0.7
elif all(a == 2 for a in self.last_actions[-self.max_consecutive_same_action:]):
elif all(a == 2 for a in self.last_actions[0][-self.max_consecutive_same_action:]):
# Too many consecutive BUY - reduce buy probability
action_probs_np[:, 2] *= 0.7
@ -666,9 +688,9 @@ class CNNModelPyTorch:
# Store the predicted action for the most recent input
if action_probs_np.shape[0] > 0:
latest_action = np.argmax(action_probs_np[-1])
self.last_actions.append(int(latest_action))
self.last_actions[0].append(int(latest_action))
# Keep only the most recent actions
self.last_actions = self.last_actions[-10:] # Store last 10 actions
self.last_actions[0] = self.last_actions[0][-10:] # Store last 10 actions
# Update action counts for stats
actions = np.argmax(action_probs_np, axis=1)
@ -676,11 +698,11 @@ class CNNModelPyTorch:
action_dict = dict(zip(unique, counts))
if 0 in action_dict:
self.action_counts['SELL'] += action_dict[0]
self.action_counts['SELL'][0] += action_dict[0]
if 1 in action_dict:
self.action_counts['HOLD'] += action_dict[1]
self.action_counts['HOLD'][0] += action_dict[1]
if 2 in action_dict:
self.action_counts['BUY'] += action_dict[2]
self.action_counts['BUY'][0] += action_dict[2]
# Get the current close prices from the input
current_prices = X_tensor[:, -1, 3].cpu().numpy() if X_tensor.shape[2] > 3 else np.zeros(X_tensor.shape[0])
@ -838,20 +860,25 @@ class CNNModelPyTorch:
f"val_loss: {val_loss:.4f} - val_acc: {val_acc:.4f}")
# Update history
self.history['loss'].append(epoch_loss)
self.history['accuracy'].append(epoch_acc)
self.history['val_loss'].append(val_loss)
self.history['val_accuracy'].append(val_acc)
self.train_losses.append(epoch_loss)
self.train_accuracies.append(epoch_acc)
self.val_losses.append(val_loss)
self.val_accuracies.append(val_acc)
else:
logger.info(f"Epoch {epoch+1}/{epochs} - "
f"loss: {epoch_loss:.4f} - acc: {epoch_acc:.4f}")
# Update history without validation
self.history['loss'].append(epoch_loss)
self.history['accuracy'].append(epoch_acc)
self.train_losses.append(epoch_loss)
self.train_accuracies.append(epoch_acc)
logger.info("Training completed")
return self.history
return {
'loss': self.train_losses,
'accuracy': self.train_accuracies,
'val_loss': self.val_losses,
'val_accuracy': self.val_accuracies
}
def evaluate_metrics(self, X_test, y_test):
"""
@ -892,9 +919,14 @@ class CNNModelPyTorch:
model_state = {
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'history': self.history,
'history': {
'loss': self.train_losses,
'accuracy': self.train_accuracies,
'val_loss': self.val_losses,
'val_accuracy': self.val_accuracies
},
'window_size': self.window_size,
'num_features': self.num_features,
'num_features': self.total_features,
'output_size': self.output_size,
'timeframes': self.timeframes,
# Save trading configuration
@ -930,12 +962,12 @@ class CNNModelPyTorch:
# Update model parameters
self.window_size = model_state['window_size']
self.num_features = model_state['num_features']
self.total_features = model_state['num_features']
self.output_size = model_state['output_size']
self.timeframes = model_state.get('timeframes', ["1m"])
# Load model state dict
self.load_state_dict(model_state['model_state_dict'])
self.model.load_state_dict(model_state['model_state_dict'])
# Load optimizer state if available
if 'optimizer_state_dict' in model_state:

View File

@ -0,0 +1,287 @@
import logging
import threading
import time
from typing import Dict, Any, List, Optional, Callable, Tuple
import os
import numpy as np
import pandas as pd
from .trading_agent import TradingAgent
logger = logging.getLogger(__name__)
class NeuralNetworkOrchestrator:
"""Orchestrator for neural network models and trading operations.
This class coordinates between neural network models and trading agents,
ensuring that signals from the models are properly processed and trades
are executed according to the strategy.
"""
def __init__(self, model, data_interface, chart=None,
symbols: List[str] = None,
timeframes: List[str] = None,
window_size: int = 20,
num_features: int = 5,
output_size: int = 3,
models_dir: str = "NN/models/saved",
data_dir: str = "NN/data",
exchange_config: Dict[str, Any] = None):
"""Initialize the neural network orchestrator.
Args:
model: Neural network model instance
data_interface: Data interface for retrieving market data
chart: Real-time chart for visualization (optional)
symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT'])
timeframes: List of timeframes to monitor (e.g., ['1m', '5m', '1h'])
window_size: Window size for model input
num_features: Number of features per datapoint
output_size: Number of output classes (e.g., 3 for BUY/HOLD/SELL)
models_dir: Directory for saved models
data_dir: Directory for data storage
exchange_config: Configuration for trading agent (exchange, API keys, etc.)
"""
self.model = model
self.data_interface = data_interface
self.chart = chart
self.symbols = symbols or ["BTC/USDT"]
self.timeframes = timeframes or ["1m", "5m", "1h", "4h", "1d"]
self.window_size = window_size
self.num_features = num_features
self.output_size = output_size
self.models_dir = models_dir
self.data_dir = data_dir
# Initialize trading agent if configuration provided
self.trading_agent = None
if exchange_config:
self.init_trading_agent(exchange_config)
# Initialize inference state
self.is_running = False
self.inference_thread = None
self.stop_event = threading.Event()
self.last_inference_time = 0
self.inference_interval = int(os.environ.get("NN_INFERENCE_INTERVAL", "60"))
logger.info(f"Initializing NeuralNetworkOrchestrator with:")
logger.info(f"- Symbol: {self.symbols[0]}")
logger.info(f"- Timeframes: {', '.join(self.timeframes)}")
logger.info(f"- Window size: {window_size}")
logger.info(f"- Num features: {num_features}")
logger.info(f"- Output size: {output_size}")
logger.info(f"- Models dir: {models_dir}")
logger.info(f"- Data dir: {data_dir}")
logger.info(f"- Inference interval: {self.inference_interval} seconds")
def init_trading_agent(self, config: Dict[str, Any]):
"""Initialize the trading agent with the given configuration.
Args:
config: Configuration for the trading agent
"""
exchange_name = config.get("exchange", "binance")
api_key = config.get("api_key")
api_secret = config.get("api_secret")
test_mode = config.get("test_mode", True)
trade_symbols = config.get("trade_symbols", self.symbols)
position_size = config.get("position_size", 0.1)
max_trades_per_day = config.get("max_trades_per_day", 5)
trade_cooldown_minutes = config.get("trade_cooldown_minutes", 60)
self.trading_agent = TradingAgent(
exchange_name=exchange_name,
api_key=api_key,
api_secret=api_secret,
test_mode=test_mode,
trade_symbols=trade_symbols,
position_size=position_size,
max_trades_per_day=max_trades_per_day,
trade_cooldown_minutes=trade_cooldown_minutes
)
logger.info(f"Trading agent initialized for {exchange_name} exchange.")
def start_inference(self):
"""Start the inference thread."""
if self.is_running:
logger.warning("Neural network inference is already running.")
return
self.is_running = True
self.stop_event.clear()
# Start inference thread
self.inference_thread = threading.Thread(target=self._inference_loop)
self.inference_thread.daemon = True
self.inference_thread.start()
logger.info(f"Neural network inference thread started with {self.inference_interval}s interval.")
# Start trading agent if available
if self.trading_agent:
self.trading_agent.start(signal_callback=self._on_trade_executed)
def stop_inference(self):
"""Stop the inference thread."""
if not self.is_running:
logger.warning("Neural network inference is not running.")
return
logger.info("Stopping neural network inference...")
self.is_running = False
self.stop_event.set()
if self.inference_thread and self.inference_thread.is_alive():
self.inference_thread.join(timeout=10)
logger.info("Neural network inference stopped.")
# Stop trading agent if available
if self.trading_agent:
self.trading_agent.stop()
def _inference_loop(self):
"""Main inference loop that processes data and generates signals."""
logger.info("Inference loop started.")
try:
while self.is_running and not self.stop_event.is_set():
current_time = time.time()
# Check if we should run inference
if current_time - self.last_inference_time >= self.inference_interval:
try:
# Run inference for all symbols
for symbol in self.symbols:
prediction = self._run_inference(symbol)
if prediction:
self._process_prediction(symbol, prediction)
self.last_inference_time = current_time
except Exception as e:
logger.error(f"Error during inference: {str(e)}")
# Sleep for a short time to prevent CPU hogging
time.sleep(1)
except Exception as e:
logger.error(f"Error in inference loop: {str(e)}")
finally:
logger.info("Inference loop stopped.")
def _run_inference(self, symbol: str) -> Optional[Tuple[np.ndarray, float]]:
"""Run inference for a specific symbol.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
Returns:
tuple: (action probabilities, current price) or None if inference failed
"""
try:
# Get the model timeframe from environment
model_timeframe = os.environ.get("NN_TIMEFRAME", "1h")
if model_timeframe not in self.timeframes:
logger.warning(f"Model timeframe {model_timeframe} not in available timeframes. Using {self.timeframes[0]}.")
model_timeframe = self.timeframes[0]
# Load candles for the model timeframe
logger.info(f"Loading {1000} candles from cache for {symbol} at {model_timeframe} timeframe")
candles = self.data_interface.get_historical_data(
symbol=symbol,
timeframe=model_timeframe,
n_candles=1000
)
if candles is None or len(candles) < self.window_size:
logger.warning(f"Not enough data for {symbol} at {model_timeframe} timeframe. Need at least {self.window_size} candles.")
return None
# Prepare input data
X, timestamp = self.data_interface.prepare_model_input(
data=candles,
window_size=self.window_size,
symbol=symbol
)
if X is None:
logger.warning(f"Failed to prepare model input for {symbol}.")
return None
# Get current price
current_price = candles['close'].iloc[-1]
# Run model inference
action_probs, price_pred = self.model.predict(X)
return action_probs, current_price
except Exception as e:
logger.error(f"Error running inference for {symbol}: {str(e)}")
return None
def _process_prediction(self, symbol: str, prediction: Tuple[np.ndarray, float]):
"""Process a prediction and generate signals.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
prediction: Tuple of (action probabilities, current price)
"""
action_probs, current_price = prediction
# Get the best action (0=SELL, 1=HOLD, 2=BUY)
best_action = np.argmax(action_probs)
best_prob = float(action_probs[best_action])
# Convert to action name
action_names = ["SELL", "HOLD", "BUY"]
action_name = action_names[best_action]
# Log the prediction
logger.info(f"Inference result for {symbol}: Action={action_name}, Probability={best_prob:.2f}, Price={current_price:.2f}")
# Add signal to chart if available
if self.chart:
self.chart.add_nn_signal(symbol=symbol, signal=action_name, confidence=best_prob, timestamp=int(time.time()))
# Process signal with trading agent if available
if self.trading_agent:
self.trading_agent.process_signal(
symbol=symbol,
action=action_name,
confidence=best_prob,
timestamp=int(time.time())
)
def _on_trade_executed(self, trade_record: Dict[str, Any]):
"""Callback for when a trade is executed.
Args:
trade_record: Trade information
"""
if self.chart and trade_record:
# Add trade to chart
self.chart.add_trade(
action=trade_record['action'],
price=trade_record.get('price', 0),
timestamp=trade_record['timestamp'],
pnl=trade_record.get('pnl', 0)
)
logger.info(f"Trade added to chart: {trade_record['action']} at {trade_record.get('price', 0):.2f}")
def get_trading_agent_info(self) -> Dict[str, Any]:
"""Get information about the trading agent.
Returns:
dict: Trading agent information or None if no agent is available
"""
if self.trading_agent:
return {
'exchange_info': self.trading_agent.get_exchange_info(),
'positions': self.trading_agent.get_current_positions(),
'trades': len(self.trading_agent.get_trade_history())
}
return None

View File

@ -13,6 +13,7 @@ import argparse
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import time
# Configure logging
logger = logging.getLogger('NN')
@ -100,7 +101,7 @@ def main():
# Verify data interface by fetching initial data
logger.info("Verifying data interface...")
X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
if X_sample is None or y_sample is None:
if X_sample is None or y_sample is not None:
logger.error("Failed to prepare initial training data")
return
@ -369,12 +370,12 @@ def predict(data_interface, model, args):
except Exception as e:
logger.error(f"Error in prediction mode: {str(e)}")
def realtime(data_interface, model, args):
"""Run the model in real-time mode"""
logger.info("Starting real-time mode...")
def realtime(data_interface, model, args, chart=None, symbol=None):
"""Run real-time inference with the trained model"""
logger.info(f"Starting real-time inference mode for {symbol}...")
try:
from NN.utils.realtime_analyzer import RealtimeAnalyzer
from NN.utils.realtime_analyzer import RealtimeAnalyzer
# Load the latest model
model_dir = os.path.join('models')
@ -403,8 +404,104 @@ def realtime(data_interface, model, args):
logger.info("Starting real-time analysis...")
realtime_analyzer.start()
# Initialize variables for tracking performance
total_pnl = 0.0
trades = []
current_position = 0.0
last_action = None
last_price = None
# Get the pair index for this symbol
pair_index = args.symbols.index(symbol)
# Only execute trades if this is the main pair (BTC/USDT)
is_main_pair = symbol == "BTC/USDT"
while True:
# Get current market data for all pairs
all_pairs_data = []
for s in args.symbols:
X, timestamp = data_interface.prepare_realtime_input(
timeframe=args.timeframes[0], # Use shortest timeframe
n_candles=args.window_size + 10, # Extra candles for safety
window_size=args.window_size
)
if X is not None:
all_pairs_data.append(X)
else:
logger.warning(f"No data available for {s}")
time.sleep(1)
continue
if not all_pairs_data:
logger.warning("No data available for any pair")
time.sleep(1)
continue
# Stack data from all pairs for model input
X_combined = np.concatenate(all_pairs_data, axis=2)
# Get model predictions
action_probs, price_pred = model.predict(X_combined)
# Get predictions for this specific pair
action = np.argmax(action_probs[pair_index]) # 0=SELL, 1=HOLD, 2=BUY
# Get current price for the main pair
current_price = data_interface.get_historical_data(
timeframe=args.timeframes[0],
n_candles=1
)['close'].iloc[-1]
# Calculate PnL if we have a position (only for main pair)
pnl = 0.0
if is_main_pair and last_action is not None and last_price is not None:
if last_action == 2: # BUY
pnl = (current_price - last_price) / last_price
elif last_action == 0: # SELL
pnl = (last_price - current_price) / last_price
# Update total PnL (only for main pair)
if is_main_pair and pnl != 0:
total_pnl += pnl
# Log the prediction
action_name = "SELL" if action == 0 else "HOLD" if action == 1 else "BUY"
log_msg = f"Time: {timestamp}, Symbol: {symbol}, Action: {action_name}, "
if is_main_pair:
log_msg += f"Price: {current_price:.2f}, PnL: {pnl:.2%}, Total PnL: {total_pnl:.2%}"
else:
log_msg += f"Price: {current_price:.2f} (Context Only)"
logger.info(log_msg)
# Update the chart if provided (only for main pair)
if chart is not None and is_main_pair and action != 1: # Skip HOLD actions
chart.add_trade(
action=action_name,
price=current_price,
timestamp=timestamp,
pnl=pnl
)
# Update tracking variables (only for main pair)
if is_main_pair and action != 1: # If not HOLD
last_action = action
last_price = current_price
# Sleep for a short time
time.sleep(1)
except KeyboardInterrupt:
if is_main_pair:
logger.info(f"Real-time inference stopped by user for {symbol}")
logger.info(f"Final performance for {symbol} - Total PnL: {total_pnl:.2%}")
else:
logger.info(f"Real-time inference stopped by user for {symbol} (Context Only)")
except Exception as e:
logger.error(f"Error in real-time mode: {str(e)}")
logger.error(f"Error in real-time inference for {symbol}: {str(e)}")
raise
if __name__ == "__main__":
main()

310
NN/trading_agent.py Normal file
View File

@ -0,0 +1,310 @@
import logging
import time
import threading
from typing import Dict, Any, List, Optional, Callable, Tuple, Union
from .exchanges import ExchangeInterface, MEXCInterface, BinanceInterface
logger = logging.getLogger(__name__)
class TradingAgent:
"""Trading agent that executes trades based on neural network signals.
This agent interfaces with different exchanges and executes trades
based on the signals received from the neural network.
"""
def __init__(self,
exchange_name: str = 'binance',
api_key: str = None,
api_secret: str = None,
test_mode: bool = True,
trade_symbols: List[str] = None,
position_size: float = 0.1,
max_trades_per_day: int = 5,
trade_cooldown_minutes: int = 60):
"""Initialize the trading agent.
Args:
exchange_name: Name of the exchange to use ('binance', 'mexc')
api_key: API key for the exchange
api_secret: API secret for the exchange
test_mode: If True, use test/sandbox environment
trade_symbols: List of trading symbols to monitor (e.g., ['BTC/USDT'])
position_size: Size of each position as a fraction of total available balance (0.0-1.0)
max_trades_per_day: Maximum number of trades to execute per day
trade_cooldown_minutes: Minimum time between trades in minutes
"""
self.exchange_name = exchange_name.lower()
self.api_key = api_key
self.api_secret = api_secret
self.test_mode = test_mode
self.trade_symbols = trade_symbols or ['BTC/USDT']
self.position_size = min(max(position_size, 0.01), 1.0) # Ensure between 0.01 and 1.0
self.max_trades_per_day = max(1, max_trades_per_day)
self.trade_cooldown_seconds = max(60, trade_cooldown_minutes * 60)
# Initialize exchange interface
self.exchange = self._create_exchange()
# Trading state
self.active = False
self.current_positions = {} # Symbol -> quantity
self.trades_today = {} # Symbol -> count
self.last_trade_time = {} # Symbol -> timestamp
self.trade_history = [] # List of trade records
# Threading
self.trading_thread = None
self.stop_event = threading.Event()
# Signal callback
self.signal_callback = None
# Connect to exchange
if not self.exchange.connect():
logger.error(f"Failed to connect to {self.exchange_name} exchange. Trading agent disabled.")
else:
logger.info(f"Successfully connected to {self.exchange_name} exchange.")
self._load_current_positions()
def _create_exchange(self) -> ExchangeInterface:
"""Create an exchange interface based on the exchange name."""
if self.exchange_name == 'mexc':
return MEXCInterface(
api_key=self.api_key,
api_secret=self.api_secret,
test_mode=self.test_mode
)
elif self.exchange_name == 'binance':
return BinanceInterface(
api_key=self.api_key,
api_secret=self.api_secret,
test_mode=self.test_mode
)
else:
raise ValueError(f"Unsupported exchange: {self.exchange_name}")
def _load_current_positions(self):
"""Load current positions from the exchange."""
for symbol in self.trade_symbols:
try:
base_asset, quote_asset = symbol.split('/')
balance = self.exchange.get_balance(base_asset)
if balance > 0:
self.current_positions[symbol] = balance
logger.info(f"Loaded existing position for {symbol}: {balance} {base_asset}")
except Exception as e:
logger.error(f"Error loading position for {symbol}: {str(e)}")
def start(self, signal_callback: Callable = None):
"""Start the trading agent.
Args:
signal_callback: Optional callback function to receive trade signals
"""
if self.active:
logger.warning("Trading agent is already running.")
return
self.active = True
self.signal_callback = signal_callback
self.stop_event.clear()
logger.info(f"Starting trading agent for {self.exchange_name} exchange.")
logger.info(f"Trading symbols: {', '.join(self.trade_symbols)}")
logger.info(f"Position size: {self.position_size * 100:.1f}% of available balance")
logger.info(f"Max trades per day: {self.max_trades_per_day}")
logger.info(f"Trade cooldown: {self.trade_cooldown_seconds // 60} minutes")
# Reset trading state
self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
self.last_trade_time = {symbol: 0 for symbol in self.trade_symbols}
# Start trading thread
self.trading_thread = threading.Thread(target=self._trading_loop)
self.trading_thread.daemon = True
self.trading_thread.start()
def stop(self):
"""Stop the trading agent."""
if not self.active:
logger.warning("Trading agent is not running.")
return
logger.info("Stopping trading agent...")
self.active = False
self.stop_event.set()
if self.trading_thread and self.trading_thread.is_alive():
self.trading_thread.join(timeout=10)
logger.info("Trading agent stopped.")
def _trading_loop(self):
"""Main trading loop that monitors positions and executes trades."""
logger.info("Trading loop started.")
try:
while self.active and not self.stop_event.is_set():
# Check positions and update state
for symbol in self.trade_symbols:
try:
base_asset, _ = symbol.split('/')
current_balance = self.exchange.get_balance(base_asset)
# Update position if it has changed
if symbol in self.current_positions:
prev_balance = self.current_positions[symbol]
if abs(current_balance - prev_balance) > 0.001 * prev_balance:
logger.info(f"Position updated for {symbol}: {prev_balance} -> {current_balance} {base_asset}")
self.current_positions[symbol] = current_balance
except Exception as e:
logger.error(f"Error checking position for {symbol}: {str(e)}")
# Sleep for a while
time.sleep(10)
except Exception as e:
logger.error(f"Error in trading loop: {str(e)}")
finally:
logger.info("Trading loop stopped.")
def reset_daily_limits(self):
"""Reset daily trading limits. Call this at the start of each trading day."""
self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
logger.info("Daily trading limits reset.")
def process_signal(self, symbol: str, action: str,
confidence: float = None, timestamp: int = None) -> Optional[Dict[str, Any]]:
"""Process a trading signal and execute a trade if conditions are met.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
action: Trade action ('BUY', 'SELL', 'HOLD')
confidence: Confidence level of the signal (0.0-1.0)
timestamp: Timestamp of the signal (unix time)
Returns:
dict: Trade information if a trade was executed, None otherwise
"""
if not self.active:
logger.warning("Trading agent is not active. Signal ignored.")
return None
if symbol not in self.trade_symbols:
logger.warning(f"Symbol {symbol} is not in the trading symbols list. Signal ignored.")
return None
if action not in ['BUY', 'SELL', 'HOLD']:
logger.warning(f"Invalid action: {action}. Must be 'BUY', 'SELL', or 'HOLD'.")
return None
# Log the signal
confidence_str = f" (confidence: {confidence:.2f})" if confidence is not None else ""
logger.info(f"Received {action} signal for {symbol}{confidence_str}")
# Ignore HOLD signals for trading
if action == 'HOLD':
return None
# Check if we can trade based on limits
current_time = time.time()
# Check max trades per day
if self.trades_today.get(symbol, 0) >= self.max_trades_per_day:
logger.warning(f"Max trades per day reached for {symbol}. Signal ignored.")
return None
# Check trade cooldown
last_trade_time = self.last_trade_time.get(symbol, 0)
if current_time - last_trade_time < self.trade_cooldown_seconds:
cooldown_remaining = self.trade_cooldown_seconds - (current_time - last_trade_time)
logger.warning(f"Trade cooldown active for {symbol}. {cooldown_remaining:.1f} seconds remaining. Signal ignored.")
return None
# Check if the action makes sense based on current position
base_asset, _ = symbol.split('/')
current_position = self.current_positions.get(symbol, 0)
if action == 'BUY' and current_position > 0:
logger.warning(f"Already have a position in {symbol}. BUY signal ignored.")
return None
if action == 'SELL' and current_position <= 0:
logger.warning(f"No position in {symbol} to sell. SELL signal ignored.")
return None
# Execute the trade
try:
trade_result = self.exchange.execute_trade(
symbol=symbol,
action=action,
percent_of_balance=self.position_size
)
if trade_result:
# Update trading state
self.trades_today[symbol] = self.trades_today.get(symbol, 0) + 1
self.last_trade_time[symbol] = current_time
# Create trade record
trade_record = {
'symbol': symbol,
'action': action,
'timestamp': timestamp or int(current_time),
'confidence': confidence,
'order_id': trade_result.get('orderId') if isinstance(trade_result, dict) else None,
'status': 'executed'
}
# Add to trade history
self.trade_history.append(trade_record)
# Call signal callback if provided
if self.signal_callback:
self.signal_callback(trade_record)
logger.info(f"Successfully executed {action} trade for {symbol}")
return trade_record
else:
logger.error(f"Failed to execute {action} trade for {symbol}")
return None
except Exception as e:
logger.error(f"Error executing trade for {symbol}: {str(e)}")
return None
def get_current_positions(self) -> Dict[str, float]:
"""Get current positions.
Returns:
dict: Symbol -> position size
"""
return self.current_positions.copy()
def get_trade_history(self) -> List[Dict[str, Any]]:
"""Get trade history.
Returns:
list: List of trade records
"""
return self.trade_history.copy()
def get_exchange_info(self) -> Dict[str, Any]:
"""Get exchange information.
Returns:
dict: Exchange information
"""
return {
'name': self.exchange_name,
'test_mode': self.test_mode,
'active': self.active,
'trade_symbols': self.trade_symbols,
'position_size': self.position_size,
'max_trades_per_day': self.max_trades_per_day,
'trade_cooldown_seconds': self.trade_cooldown_seconds,
'trades_today': self.trades_today.copy()
}

View File

@ -247,9 +247,27 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo
)
# Get training data for each timeframe with more data
logger.info("Loading training data...")
features_1m = data_interface.get_training_data("1m", n_candles=5000)
if features_1m is not None:
logger.info(f"Loaded {len(features_1m)} 1m candles")
else:
logger.error("Failed to load 1m data")
return None
features_5m = data_interface.get_training_data("5m", n_candles=2500)
if features_5m is not None:
logger.info(f"Loaded {len(features_5m)} 5m candles")
else:
logger.error("Failed to load 5m data")
return None
features_15m = data_interface.get_training_data("15m", n_candles=2500)
if features_15m is not None:
logger.info(f"Loaded {len(features_15m)} 15m candles")
else:
logger.error("Failed to load 15m data")
return None
if features_1m is None or features_5m is None or features_15m is None:
logger.error("Failed to load training data")

View File

@ -1,5 +1,7 @@
https://github.com/mexcdevelop/mexc-api-sdk/blob/main/README.md#test-new-order
https://mexcdevelop.github.io/apidocs/spot_v3_en/#test-new-order
python mexc_tick_visualizer.py --symbol BTC/USDT --interval 1.0 --candle 60
python main.py --mode live --symbol ETH/USDT --timeframe 1m --use-websocket

124
launch_training.py Normal file
View File

@ -0,0 +1,124 @@
import os
import sys
import subprocess
import time
import logging
from datetime import datetime
import webbrowser
from threading import Thread
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('training_launch.log'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def start_tensorboard(port=6007):
"""Start TensorBoard on a specified port"""
try:
cmd = f"tensorboard --logdir=runs --port={port}"
process = subprocess.Popen(cmd, shell=True)
logger.info(f"Started TensorBoard on port {port}")
return process
except Exception as e:
logger.error(f"Failed to start TensorBoard: {str(e)}")
return None
def start_web_chart():
"""Start the web chart server"""
try:
cmd = "python main.py --symbols BTC/USDT ETH/USDT SOL/USDT --timeframes 1m 5m 15m --mode realtime"
process = subprocess.Popen(cmd, shell=True)
logger.info("Started web chart server")
return process
except Exception as e:
logger.error(f"Failed to start web chart server: {str(e)}")
return None
def start_training():
"""Start the RL training process"""
try:
cmd = "python NN/train_rl.py"
process = subprocess.Popen(cmd, shell=True)
logger.info("Started RL training process")
return process
except Exception as e:
logger.error(f"Failed to start training process: {str(e)}")
return None
def open_web_interfaces():
"""Open web browsers for TensorBoard and chart after a delay"""
time.sleep(5) # Wait for servers to start
try:
webbrowser.open('http://localhost:6007') # TensorBoard
webbrowser.open('http://localhost:8050') # Web chart
except Exception as e:
logger.error(f"Failed to open web interfaces: {str(e)}")
def monitor_processes(processes):
"""Monitor running processes and log any unexpected terminations"""
while True:
for name, process in processes.items():
if process and process.poll() is not None:
logger.error(f"{name} process terminated unexpectedly")
return False
time.sleep(1)
def main():
"""Main function to orchestrate the training environment"""
logger.info("Starting training environment setup...")
# Start TensorBoard
tensorboard_process = start_tensorboard(port=6007)
if not tensorboard_process:
logger.error("Failed to start TensorBoard")
return
# Start web chart
web_chart_process = start_web_chart()
if not web_chart_process:
tensorboard_process.terminate()
logger.error("Failed to start web chart")
return
# Start training
training_process = start_training()
if not training_process:
tensorboard_process.terminate()
web_chart_process.terminate()
logger.error("Failed to start training")
return
# Open web interfaces in a separate thread
Thread(target=open_web_interfaces).start()
# Monitor processes
processes = {
'tensorboard': tensorboard_process,
'web_chart': web_chart_process,
'training': training_process
}
try:
if not monitor_processes(processes):
raise Exception("One or more processes terminated unexpectedly")
except KeyboardInterrupt:
logger.info("Received shutdown signal")
except Exception as e:
logger.error(f"Error in monitoring: {str(e)}")
finally:
# Cleanup
logger.info("Shutting down training environment...")
for name, process in processes.items():
if process:
process.terminate()
logger.info(f"Terminated {name} process")
logger.info("Training environment shutdown complete")
if __name__ == "__main__":
main()

1107
main.py

File diff suppressed because it is too large Load Diff

155
trading_main.py Normal file
View File

@ -0,0 +1,155 @@
import os
import time
import logging
import sys
import argparse
import json
# Add the NN directory to the Python path
sys.path.append(os.path.abspath("NN"))
from NN.main import load_model
from NN.neural_network_orchestrator import NeuralNetworkOrchestrator
from NN.realtime_data_interface import RealtimeDataInterface
# Initialize logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("trading_bot.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def main():
"""Main function for the trading bot."""
# Parse command-line arguments
parser = argparse.ArgumentParser(description="Trading Bot with Neural Network Integration")
parser.add_argument('--symbols', nargs='+', default=["BTC/USDT", "ETH/USDT"],
help='Trading symbols to monitor')
parser.add_argument('--timeframes', nargs='+', default=["1m", "5m", "1h", "4h", "1d"],
help='Timeframes to monitor')
parser.add_argument('--window-size', type=int, default=20,
help='Window size for model input')
parser.add_argument('--output-size', type=int, default=3,
help='Output size of the model (3 for BUY/HOLD/SELL)')
parser.add_argument('--model-type', type=str, default="cnn", choices=["cnn", "lstm", "mlp"],
help='Type of neural network model')
parser.add_argument('--mode', type=str, default="realtime", choices=["realtime", "backtest"],
help='Trading mode')
parser.add_argument('--exchange', type=str, default="binance", choices=["binance", "mexc"],
help='Exchange to use for trading')
parser.add_argument('--api-key', type=str, default=None,
help='API key for the exchange')
parser.add_argument('--api-secret', type=str, default=None,
help='API secret for the exchange')
parser.add_argument('--test-mode', action='store_true',
help='Use test/sandbox exchange environment')
parser.add_argument('--position-size', type=float, default=0.1,
help='Position size as a fraction of total balance (0.0-1.0)')
parser.add_argument('--max-trades-per-day', type=int, default=5,
help='Maximum number of trades per day')
parser.add_argument('--trade-cooldown', type=int, default=60,
help='Trade cooldown period in minutes')
parser.add_argument('--config-file', type=str, default=None,
help='Path to configuration file')
args = parser.parse_args()
# Load configuration from file if provided
if args.config_file and os.path.exists(args.config_file):
with open(args.config_file, 'r') as f:
config = json.load(f)
# Override config with command-line args
for key, value in vars(args).items():
if key != 'config_file' and value is not None:
config[key] = value
else:
# Use command-line args as config
config = vars(args)
# Initialize real-time charts and data interfaces
try:
from realtime import RealTimeChart
# Create a real-time chart for each symbol
charts = {}
for symbol in config['symbols']:
charts[symbol] = RealTimeChart(symbol=symbol)
main_chart = charts[config['symbols'][0]]
# Create a data interface for retrieving market data
data_interface = RealtimeDataInterface(symbols=config['symbols'], chart=main_chart)
# Load trained model
model_type = os.environ.get("NN_MODEL_TYPE", config['model_type'])
model = load_model(
model_type=model_type,
input_shape=(config['window_size'], len(config['symbols']), 5), # 5 features (OHLCV)
output_size=config['output_size']
)
# Configure trading agent
exchange_config = {
"exchange": config['exchange'],
"api_key": config['api_key'],
"api_secret": config['api_secret'],
"test_mode": config['test_mode'],
"trade_symbols": config['symbols'],
"position_size": config['position_size'],
"max_trades_per_day": config['max_trades_per_day'],
"trade_cooldown_minutes": config['trade_cooldown']
}
# Initialize neural network orchestrator
orchestrator = NeuralNetworkOrchestrator(
model=model,
data_interface=data_interface,
chart=main_chart,
symbols=config['symbols'],
timeframes=config['timeframes'],
window_size=config['window_size'],
num_features=5, # OHLCV
output_size=config['output_size'],
exchange_config=exchange_config
)
# Start data collection
logger.info("Starting data collection threads...")
for symbol in config['symbols']:
charts[symbol].start()
# Start neural network inference
if os.environ.get("ENABLE_NN_MODELS", "0") == "1":
logger.info("Starting neural network inference...")
orchestrator.start_inference()
else:
logger.info("Neural network models disabled. Set ENABLE_NN_MODELS=1 to enable.")
# Start web servers for chart display
logger.info("Starting web servers for chart display...")
main_chart.start_server()
logger.info("Trading bot initialized successfully. Press Ctrl+C to exit.")
# Keep the main thread alive
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
logger.info("Keyboard interrupt received. Shutting down...")
# Stop all threads
for symbol in config['symbols']:
charts[symbol].stop()
orchestrator.stop_inference()
logger.info("Trading bot stopped.")
except Exception as e:
logger.error(f"Error in main function: {str(e)}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()