gogo2/NN/trading_agent.py
2025-03-31 14:22:33 +03:00

310 lines
12 KiB
Python

import logging
import time
import threading
from typing import Dict, Any, List, Optional, Callable, Tuple, Union
from .exchanges import ExchangeInterface, MEXCInterface, BinanceInterface
# Module-level logger named after this module; used by every TradingAgent method.
logger: logging.Logger = logging.getLogger(name=__name__)
class TradingAgent:
    """Trading agent that executes trades based on neural network signals.

    This agent interfaces with different exchanges and executes trades
    based on the signals received from the neural network. While running,
    a background thread polls the exchange for base-asset balances so that
    :meth:`process_signal` can tell whether a position is currently held.
    """

    # Seconds the background loop waits between balance polls.
    POSITION_POLL_INTERVAL_SECONDS = 10

    def __init__(self,
                 exchange_name: str = 'binance',
                 api_key: Optional[str] = None,
                 api_secret: Optional[str] = None,
                 test_mode: bool = True,
                 trade_symbols: Optional[List[str]] = None,
                 position_size: float = 0.1,
                 max_trades_per_day: int = 5,
                 trade_cooldown_minutes: int = 60):
        """Initialize the trading agent and connect to the exchange.

        Args:
            exchange_name: Name of the exchange to use ('binance', 'mexc');
                matched case-insensitively.
            api_key: API key for the exchange.
            api_secret: API secret for the exchange.
            test_mode: If True, use test/sandbox environment.
            trade_symbols: List of 'BASE/QUOTE' symbols to monitor
                (defaults to ['BTC/USDT']).
            position_size: Size of each position as a fraction of available
                balance; clamped to the range [0.01, 1.0].
            max_trades_per_day: Maximum number of trades *per symbol* until
                reset_daily_limits() is called or the agent is restarted;
                at least 1.
            trade_cooldown_minutes: Minimum time between trades on the same
                symbol; never less than one minute.

        Raises:
            ValueError: If exchange_name is not a supported exchange.
        """
        self.exchange_name = exchange_name.lower()
        self.api_key = api_key
        self.api_secret = api_secret
        self.test_mode = test_mode
        self.trade_symbols = trade_symbols or ['BTC/USDT']
        self.position_size = min(max(position_size, 0.01), 1.0)  # Ensure between 0.01 and 1.0
        self.max_trades_per_day = max(1, max_trades_per_day)
        self.trade_cooldown_seconds = max(60, trade_cooldown_minutes * 60)

        # Initialize exchange interface (raises ValueError for unknown names).
        self.exchange = self._create_exchange()

        # Trading state
        self.active = False
        self.current_positions = {}  # Symbol -> base-asset quantity
        self.trades_today = {}  # Symbol -> number of trades since last reset
        self.last_trade_time = {}  # Symbol -> unix timestamp of last trade
        self.trade_history = []  # List of trade records

        # Threading
        self.trading_thread = None
        self.stop_event = threading.Event()

        # Signal callback, invoked with each executed trade record.
        self.signal_callback = None

        # Connect to exchange
        if not self.exchange.connect():
            # NOTE(review): despite this message, nothing here prevents start()
            # from running; confirm whether the agent should refuse to start.
            logger.error(f"Failed to connect to {self.exchange_name} exchange. Trading agent disabled.")
        else:
            logger.info(f"Successfully connected to {self.exchange_name} exchange.")
            self._load_current_positions()

    @staticmethod
    def _base_asset(symbol: str) -> str:
        """Return the base asset of a 'BASE/QUOTE' symbol.

        Raises:
            ValueError: If the symbol does not contain exactly one '/'.
        """
        base_asset, _quote_asset = symbol.split('/')
        return base_asset

    def _create_exchange(self) -> "ExchangeInterface":
        """Create an exchange interface based on the exchange name.

        Raises:
            ValueError: If ``self.exchange_name`` is not 'mexc' or 'binance'.
        """
        if self.exchange_name == 'mexc':
            return MEXCInterface(
                api_key=self.api_key,
                api_secret=self.api_secret,
                test_mode=self.test_mode
            )
        elif self.exchange_name == 'binance':
            return BinanceInterface(
                api_key=self.api_key,
                api_secret=self.api_secret,
                test_mode=self.test_mode
            )
        else:
            raise ValueError(f"Unsupported exchange: {self.exchange_name}")

    def _load_current_positions(self):
        """Load existing (positive) base-asset balances from the exchange.

        Errors for individual symbols are logged and skipped.
        """
        for symbol in self.trade_symbols:
            try:
                base_asset = self._base_asset(symbol)
                balance = self.exchange.get_balance(base_asset)
                if balance > 0:
                    self.current_positions[symbol] = balance
                    logger.info(f"Loaded existing position for {symbol}: {balance} {base_asset}")
            except Exception as e:
                logger.error(f"Error loading position for {symbol}: {str(e)}")

    def start(self, signal_callback: Optional[Callable[[Dict[str, Any]], Any]] = None):
        """Start the trading agent.

        Resets the per-symbol trade counters and cooldowns, then launches the
        background position-monitoring thread (as a daemon thread).

        Args:
            signal_callback: Optional callback invoked with each executed
                trade record.
        """
        if self.active:
            logger.warning("Trading agent is already running.")
            return

        self.active = True
        self.signal_callback = signal_callback
        self.stop_event.clear()

        logger.info(f"Starting trading agent for {self.exchange_name} exchange.")
        logger.info(f"Trading symbols: {', '.join(self.trade_symbols)}")
        logger.info(f"Position size: {self.position_size * 100:.1f}% of available balance")
        logger.info(f"Max trades per day: {self.max_trades_per_day}")
        logger.info(f"Trade cooldown: {self.trade_cooldown_seconds // 60} minutes")

        # Reset trading state
        self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
        self.last_trade_time = {symbol: 0 for symbol in self.trade_symbols}

        # Start trading thread
        self.trading_thread = threading.Thread(target=self._trading_loop)
        self.trading_thread.daemon = True
        self.trading_thread.start()

    def stop(self):
        """Stop the trading agent and wait up to 10 seconds for its thread."""
        if not self.active:
            logger.warning("Trading agent is not running.")
            return

        logger.info("Stopping trading agent...")
        self.active = False
        self.stop_event.set()

        if self.trading_thread and self.trading_thread.is_alive():
            self.trading_thread.join(timeout=10)

        logger.info("Trading agent stopped.")

    def _refresh_positions(self):
        """Sync ``current_positions`` with the live base-asset balances.

        Tracked symbols are updated when the balance moves by more than 0.1%
        of the previous value. Untracked symbols are added as soon as a
        positive balance appears. Errors for individual symbols are logged
        and skipped.
        """
        for symbol in self.trade_symbols:
            try:
                base_asset = self._base_asset(symbol)
                current_balance = self.exchange.get_balance(base_asset)

                if symbol in self.current_positions:
                    prev_balance = self.current_positions[symbol]
                    if abs(current_balance - prev_balance) > 0.001 * prev_balance:
                        logger.info(f"Position updated for {symbol}: {prev_balance} -> {current_balance} {base_asset}")
                        self.current_positions[symbol] = current_balance
                elif current_balance > 0:
                    # Positions opened after startup (e.g. by this agent's own
                    # BUY) must be tracked too, otherwise every subsequent SELL
                    # signal would be rejected as "no position".
                    logger.info(f"New position detected for {symbol}: {current_balance} {base_asset}")
                    self.current_positions[symbol] = current_balance
            except Exception as e:
                logger.error(f"Error checking position for {symbol}: {str(e)}")

    def _trading_loop(self):
        """Main trading loop that monitors positions until the agent is stopped."""
        logger.info("Trading loop started.")

        try:
            while self.active and not self.stop_event.is_set():
                self._refresh_positions()
                # Wait on the stop event instead of sleeping so stop() wakes
                # the loop immediately rather than after a full interval.
                self.stop_event.wait(self.POSITION_POLL_INTERVAL_SECONDS)
        except Exception as e:
            logger.error(f"Error in trading loop: {str(e)}")
        finally:
            logger.info("Trading loop stopped.")

    def reset_daily_limits(self):
        """Reset daily trading limits. Call this at the start of each trading day."""
        self.trades_today = {symbol: 0 for symbol in self.trade_symbols}
        logger.info("Daily trading limits reset.")

    def process_signal(self, symbol: str, action: str,
                       confidence: Optional[float] = None,
                       timestamp: Optional[int] = None) -> Optional[Dict[str, Any]]:
        """Process a trading signal and execute a trade if conditions are met.

        A trade is placed only if all of these hold:
        - the agent is active and the symbol is monitored;
        - the action is BUY or SELL;
        - the per-symbol trade limit and cooldown allow it;
        - the current position agrees with the action (BUY only with no
          position, SELL only with a position).

        Args:
            symbol: Trading symbol (e.g., 'BTC/USDT').
            action: Trade action ('BUY', 'SELL', 'HOLD').
            confidence: Confidence level of the signal (0.0-1.0).
            timestamp: Timestamp of the signal (unix time); defaults to now.

        Returns:
            dict: Trade record if a trade was executed, None otherwise. The
            record is returned even if the signal callback raises; such
            errors are logged.
        """
        if not self.active:
            logger.warning("Trading agent is not active. Signal ignored.")
            return None

        if symbol not in self.trade_symbols:
            logger.warning(f"Symbol {symbol} is not in the trading symbols list. Signal ignored.")
            return None

        if action not in ['BUY', 'SELL', 'HOLD']:
            logger.warning(f"Invalid action: {action}. Must be 'BUY', 'SELL', or 'HOLD'.")
            return None

        # Log the signal
        confidence_str = f" (confidence: {confidence:.2f})" if confidence is not None else ""
        logger.info(f"Received {action} signal for {symbol}{confidence_str}")

        # Ignore HOLD signals for trading
        if action == 'HOLD':
            return None

        current_time = time.time()

        # Check max trades per day (counted per symbol)
        if self.trades_today.get(symbol, 0) >= self.max_trades_per_day:
            logger.warning(f"Max trades per day reached for {symbol}. Signal ignored.")
            return None

        # Check trade cooldown
        last_trade_time = self.last_trade_time.get(symbol, 0)
        if current_time - last_trade_time < self.trade_cooldown_seconds:
            cooldown_remaining = self.trade_cooldown_seconds - (current_time - last_trade_time)
            logger.warning(f"Trade cooldown active for {symbol}. {cooldown_remaining:.1f} seconds remaining. Signal ignored.")
            return None

        # Check if the action makes sense based on current position.
        # NOTE(review): any positive balance (including possible dust) counts
        # as an open position and blocks BUY.
        current_position = self.current_positions.get(symbol, 0)
        if action == 'BUY' and current_position > 0:
            logger.warning(f"Already have a position in {symbol}. BUY signal ignored.")
            return None
        if action == 'SELL' and current_position <= 0:
            logger.warning(f"No position in {symbol} to sell. SELL signal ignored.")
            return None

        # Execute the trade
        try:
            trade_result = self.exchange.execute_trade(
                symbol=symbol,
                action=action,
                percent_of_balance=self.position_size
            )
        except Exception as e:
            logger.error(f"Error executing trade for {symbol}: {str(e)}")
            return None

        if not trade_result:
            logger.error(f"Failed to execute {action} trade for {symbol}")
            return None

        # Update trading state
        self.trades_today[symbol] = self.trades_today.get(symbol, 0) + 1
        self.last_trade_time[symbol] = current_time

        trade_record = {
            'symbol': symbol,
            'action': action,
            'timestamp': timestamp if timestamp is not None else int(current_time),
            'confidence': confidence,
            'order_id': trade_result.get('orderId') if isinstance(trade_result, dict) else None,
            'status': 'executed'
        }
        self.trade_history.append(trade_record)

        if self.signal_callback:
            try:
                self.signal_callback(trade_record)
            except Exception as e:
                # The order is already placed; a faulty callback must not make
                # the caller believe the trade failed.
                logger.error(f"Error in signal callback for {symbol}: {str(e)}")

        logger.info(f"Successfully executed {action} trade for {symbol}")
        return trade_record

    def get_current_positions(self) -> Dict[str, float]:
        """Get current positions.

        Returns:
            dict: Copy of the symbol -> base-asset quantity mapping.
        """
        return self.current_positions.copy()

    def get_trade_history(self) -> List[Dict[str, Any]]:
        """Get trade history.

        Returns:
            list: Shallow copy of the list of trade records.
        """
        return self.trade_history.copy()

    def get_exchange_info(self) -> Dict[str, Any]:
        """Get exchange information and current agent configuration.

        Returns:
            dict: Exchange information; containers are copies, so mutating
            them does not affect the agent.
        """
        return {
            'name': self.exchange_name,
            'test_mode': self.test_mode,
            'active': self.active,
            'trade_symbols': list(self.trade_symbols),
            'position_size': self.position_size,
            'max_trades_per_day': self.max_trades_per_day,
            'trade_cooldown_seconds': self.trade_cooldown_seconds,
            'trades_today': self.trades_today.copy()
        }