310 lines
12 KiB
Python
310 lines
12 KiB
Python
import logging
|
|
import time
|
|
import threading
|
|
from typing import Dict, Any, List, Optional, Callable, Tuple, Union
|
|
|
|
from .exchanges import ExchangeInterface, MEXCInterface, BinanceInterface
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class TradingAgent:
    """Trading agent that executes trades based on neural network signals.

    The agent wraps one exchange interface, keeps per-symbol bookkeeping
    (last known balance, trades executed today, time of the last trade) and
    enforces a daily trade cap and a cooldown before forwarding BUY/SELL
    signals to the exchange. A background thread periodically refreshes the
    balances of the configured symbols while the agent is running.
    """

    # Actions accepted by process_signal().
    VALID_ACTIONS = ('BUY', 'SELL', 'HOLD')

    # Seconds between two balance refreshes in the background loop.
    POSITION_POLL_SECONDS = 10

    # Seconds stop() waits for the background thread to finish.
    STOP_JOIN_TIMEOUT_SECONDS = 10

    def __init__(self,
                 exchange_name: str = 'binance',
                 api_key: Optional[str] = None,
                 api_secret: Optional[str] = None,
                 test_mode: bool = True,
                 trade_symbols: Optional[List[str]] = None,
                 position_size: float = 0.1,
                 max_trades_per_day: int = 5,
                 trade_cooldown_minutes: int = 60):
        """Initialize the trading agent and connect to the exchange.

        Args:
            exchange_name: Name of the exchange to use ('binance', 'mexc');
                compared case-insensitively.
            api_key: API key for the exchange.
            api_secret: API secret for the exchange.
            test_mode: If True, use test/sandbox environment.
            trade_symbols: List of trading symbols to monitor
                (e.g., ['BTC/USDT']). Defaults to ['BTC/USDT'] when omitted
                or empty.
            position_size: Size of each position as a fraction of total
                available balance; clamped to the range 0.01-1.0.
            max_trades_per_day: Maximum number of trades to execute per day
                and symbol; values below 1 are raised to 1.
            trade_cooldown_minutes: Minimum time between trades in minutes;
                the effective cooldown is never shorter than 60 seconds.

        Raises:
            ValueError: If ``exchange_name`` is not a supported exchange.
        """
        self.exchange_name = exchange_name.lower()
        self.api_key = api_key
        self.api_secret = api_secret
        self.test_mode = test_mode
        # Copy so later changes to the caller's list do not alter the agent.
        self.trade_symbols = list(trade_symbols) if trade_symbols else ['BTC/USDT']
        self.position_size = min(max(position_size, 0.01), 1.0)  # Ensure between 0.01 and 1.0
        self.max_trades_per_day = max(1, max_trades_per_day)
        self.trade_cooldown_seconds = max(60, trade_cooldown_minutes * 60)

        # Initialize exchange interface
        self.exchange = self._create_exchange()

        # Trading state
        self.active = False
        self.current_positions = {}  # Symbol -> quantity
        self.trades_today = {}  # Symbol -> count
        self.last_trade_time = {}  # Symbol -> timestamp
        self.trade_history = []  # List of trade records

        # Threading
        self.trading_thread = None
        self.stop_event = threading.Event()

        # Callback invoked with the trade record after each executed trade
        self.signal_callback = None

        # Connect to exchange; a failed connection is logged, not raised.
        if not self.exchange.connect():
            logger.error(f"Failed to connect to {self.exchange_name} exchange. Trading agent disabled.")
        else:
            logger.info(f"Successfully connected to {self.exchange_name} exchange.")
            self._load_current_positions()

    def _create_exchange(self) -> "ExchangeInterface":
        """Create an exchange interface based on the exchange name.

        Returns:
            The interface instance matching ``self.exchange_name``.

        Raises:
            ValueError: If the exchange name is not 'mexc' or 'binance'.
        """
        if self.exchange_name == 'mexc':
            interface_cls = MEXCInterface
        elif self.exchange_name == 'binance':
            interface_cls = BinanceInterface
        else:
            raise ValueError(f"Unsupported exchange: {self.exchange_name}")

        return interface_cls(
            api_key=self.api_key,
            api_secret=self.api_secret,
            test_mode=self.test_mode,
        )

    def _load_current_positions(self):
        """Load current positions from the exchange.

        Only symbols with a positive base-asset balance are recorded.
        Failures for one symbol are logged and do not stop the others.
        """
        for symbol in self.trade_symbols:
            try:
                # Symbols are expected in 'BASE/QUOTE' form.
                base_asset, _quote_asset = symbol.split('/')
                balance = self.exchange.get_balance(base_asset)

                if balance > 0:
                    self.current_positions[symbol] = balance
                    logger.info(f"Loaded existing position for {symbol}: {balance} {base_asset}")
            except Exception as e:
                logger.error(f"Error loading position for {symbol}: {str(e)}")

    def start(self, signal_callback: Optional[Callable] = None):
        """Start the trading agent.

        Resets the per-symbol trade counters and cooldown timestamps and
        launches the background balance-monitoring thread. Does nothing
        (besides a warning) if the agent is already running.

        Args:
            signal_callback: Optional callable invoked with the trade record
                (a dict) each time ``process_signal`` executes a trade.
        """
        if self.active:
            logger.warning("Trading agent is already running.")
            return

        self.active = True
        self.signal_callback = signal_callback
        self.stop_event.clear()

        logger.info(f"Starting trading agent for {self.exchange_name} exchange.")
        logger.info(f"Trading symbols: {', '.join(self.trade_symbols)}")
        logger.info(f"Position size: {self.position_size * 100:.1f}% of available balance")
        logger.info(f"Max trades per day: {self.max_trades_per_day}")
        logger.info(f"Trade cooldown: {self.trade_cooldown_seconds // 60} minutes")

        # Reset trading state
        self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
        self.last_trade_time = {symbol: 0 for symbol in self.trade_symbols}

        # Start trading thread (daemon, so it never blocks interpreter exit)
        self.trading_thread = threading.Thread(target=self._trading_loop)
        self.trading_thread.daemon = True
        self.trading_thread.start()

    def stop(self):
        """Stop the trading agent.

        Signals the background thread to exit and waits up to
        ``STOP_JOIN_TIMEOUT_SECONDS`` for it to finish. Does nothing
        (besides a warning) if the agent is not running.
        """
        if not self.active:
            logger.warning("Trading agent is not running.")
            return

        logger.info("Stopping trading agent...")
        self.active = False
        self.stop_event.set()

        if self.trading_thread and self.trading_thread.is_alive():
            self.trading_thread.join(timeout=self.STOP_JOIN_TIMEOUT_SECONDS)
            if self.trading_thread.is_alive():
                # join() returned because of the timeout, not thread exit.
                logger.warning("Trading thread did not exit within the join timeout.")

        logger.info("Trading agent stopped.")

    def _refresh_position(self, symbol: str):
        """Fetch the base-asset balance for ``symbol`` and store it.

        Logs a message when the balance moved by more than 0.1% of the
        previously stored value. Errors are logged and swallowed so one
        failing symbol does not end the monitoring loop.
        """
        try:
            base_asset, _ = symbol.split('/')
            current_balance = self.exchange.get_balance(base_asset)

            # Update position if it has changed
            if symbol in self.current_positions:
                prev_balance = self.current_positions[symbol]
                if abs(current_balance - prev_balance) > 0.001 * prev_balance:
                    logger.info(f"Position updated for {symbol}: {prev_balance} -> {current_balance} {base_asset}")

            self.current_positions[symbol] = current_balance
        except Exception as e:
            logger.error(f"Error checking position for {symbol}: {str(e)}")

    def _trading_loop(self):
        """Main trading loop that monitors positions and updates state.

        Runs until ``self.active`` is cleared or ``self.stop_event`` is set.
        """
        logger.info("Trading loop started.")

        try:
            while self.active and not self.stop_event.is_set():
                # Check positions and update state
                for symbol in self.trade_symbols:
                    self._refresh_position(symbol)

                # Wait on the event instead of sleeping so that stop() wakes
                # the loop immediately rather than after a full poll interval.
                self.stop_event.wait(self.POSITION_POLL_SECONDS)
        except Exception as e:
            logger.error(f"Error in trading loop: {str(e)}")
        finally:
            logger.info("Trading loop stopped.")

    def reset_daily_limits(self):
        """Reset daily trading limits. Call this at the start of each trading day."""
        self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
        logger.info("Daily trading limits reset.")

    def process_signal(self, symbol: str, action: str,
                       confidence: Optional[float] = None,
                       timestamp: Optional[int] = None) -> Optional[Dict[str, Any]]:
        """Process a trading signal and execute a trade if conditions are met.

        A trade is attempted only when the agent is active, the symbol is
        configured, the action is BUY or SELL, the daily trade cap and the
        cooldown allow it, and the action fits the last known position
        (BUY requires no position, SELL requires one).

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT')
            action: Trade action ('BUY', 'SELL', 'HOLD')
            confidence: Confidence level of the signal (0.0-1.0)
            timestamp: Timestamp of the signal (unix time); the current time
                is used when omitted.

        Returns:
            dict: Trade information if a trade was executed, None otherwise
        """
        if not self.active:
            logger.warning("Trading agent is not active. Signal ignored.")
            return None

        if symbol not in self.trade_symbols:
            logger.warning(f"Symbol {symbol} is not in the trading symbols list. Signal ignored.")
            return None

        if action not in self.VALID_ACTIONS:
            logger.warning(f"Invalid action: {action}. Must be 'BUY', 'SELL', or 'HOLD'.")
            return None

        # Log the signal
        confidence_str = f" (confidence: {confidence:.2f})" if confidence is not None else ""
        logger.info(f"Received {action} signal for {symbol}{confidence_str}")

        # Ignore HOLD signals for trading
        if action == 'HOLD':
            return None

        # Check if we can trade based on limits
        current_time = time.time()

        # Check max trades per day
        if self.trades_today.get(symbol, 0) >= self.max_trades_per_day:
            logger.warning(f"Max trades per day reached for {symbol}. Signal ignored.")
            return None

        # Check trade cooldown
        elapsed = current_time - self.last_trade_time.get(symbol, 0)
        if elapsed < self.trade_cooldown_seconds:
            cooldown_remaining = self.trade_cooldown_seconds - elapsed
            logger.warning(f"Trade cooldown active for {symbol}. {cooldown_remaining:.1f} seconds remaining. Signal ignored.")
            return None

        # Check if the action makes sense based on current position
        current_position = self.current_positions.get(symbol, 0)

        if action == 'BUY' and current_position > 0:
            logger.warning(f"Already have a position in {symbol}. BUY signal ignored.")
            return None

        if action == 'SELL' and current_position <= 0:
            logger.warning(f"No position in {symbol} to sell. SELL signal ignored.")
            return None

        # Execute the trade; only the exchange call is guarded so that later
        # failures are not misreported as a failed order.
        try:
            trade_result = self.exchange.execute_trade(
                symbol=symbol,
                action=action,
                percent_of_balance=self.position_size
            )
        except Exception as e:
            logger.error(f"Error executing trade for {symbol}: {str(e)}")
            return None

        if not trade_result:
            logger.error(f"Failed to execute {action} trade for {symbol}")
            return None

        # Update trading state
        self.trades_today[symbol] = self.trades_today.get(symbol, 0) + 1
        self.last_trade_time[symbol] = current_time

        # Create trade record. An explicit None check keeps a caller-supplied
        # timestamp of 0 instead of replacing it with the current time.
        trade_record = {
            'symbol': symbol,
            'action': action,
            'timestamp': timestamp if timestamp is not None else int(current_time),
            'confidence': confidence,
            'order_id': trade_result.get('orderId') if isinstance(trade_result, dict) else None,
            'status': 'executed'
        }

        # Add to trade history
        self.trade_history.append(trade_record)

        # Call signal callback if provided. The order is already placed at
        # this point, so a failing callback must not hide the trade record.
        if self.signal_callback:
            try:
                self.signal_callback(trade_record)
            except Exception as e:
                logger.error(f"Signal callback failed for {symbol}: {str(e)}")

        logger.info(f"Successfully executed {action} trade for {symbol}")
        return trade_record

    def get_current_positions(self) -> Dict[str, float]:
        """Get current positions.

        Returns:
            dict: Symbol -> position size (a copy of the internal mapping)
        """
        return self.current_positions.copy()

    def get_trade_history(self) -> List[Dict[str, Any]]:
        """Get trade history.

        Returns:
            list: List of trade records (a shallow copy of the internal list)
        """
        return self.trade_history.copy()

    def get_exchange_info(self) -> Dict[str, Any]:
        """Get exchange information.

        Returns:
            dict: Exchange name, mode, running state and trading limits.
                Contained collections are copies, so modifying them does
                not affect the agent.
        """
        return {
            'name': self.exchange_name,
            'test_mode': self.test_mode,
            'active': self.active,
            'trade_symbols': list(self.trade_symbols),
            'position_size': self.position_size,
            'max_trades_per_day': self.max_trades_per_day,
            'trade_cooldown_seconds': self.trade_cooldown_seconds,
            'trades_today': self.trades_today.copy()
        }