155 lines
6.2 KiB
Python
155 lines
6.2 KiB
Python
import os
|
|
import time
|
|
import logging
|
|
import sys
|
|
import argparse
|
|
import json
|
|
|
|
# Add the NN directory to the Python path
|
|
sys.path.append(os.path.abspath("NN"))
|
|
|
|
from NN.main import load_model
|
|
from NN.neural_network_orchestrator import NeuralNetworkOrchestrator
|
|
from NN.realtime_data_interface import RealtimeDataInterface
|
|
|
|
# Configure the root logger once at import time: INFO and above is written
# both to the file "trading_bot.log" and to the default stream handler.
logging.basicConfig(
    handlers=[
        logging.FileHandler("trading_bot.log"),
        logging.StreamHandler(),
    ],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by everything in this file.
logger = logging.getLogger(__name__)
|
|
|
|
def _build_parser():
    """Return the argparse parser describing every trading-bot option."""
    parser = argparse.ArgumentParser(description="Trading Bot with Neural Network Integration")
    parser.add_argument('--symbols', nargs='+', default=["BTC/USDT", "ETH/USDT"],
                        help='Trading symbols to monitor')
    parser.add_argument('--timeframes', nargs='+', default=["1m", "5m", "1h", "4h", "1d"],
                        help='Timeframes to monitor')
    parser.add_argument('--window-size', type=int, default=20,
                        help='Window size for model input')
    parser.add_argument('--output-size', type=int, default=3,
                        help='Output size of the model (3 for BUY/HOLD/SELL)')
    parser.add_argument('--model-type', type=str, default="cnn", choices=["cnn", "lstm", "mlp"],
                        help='Type of neural network model')
    # NOTE(review): `mode` is parsed but never read in this function --
    # confirm whether "backtest" is handled elsewhere.
    parser.add_argument('--mode', type=str, default="realtime", choices=["realtime", "backtest"],
                        help='Trading mode')
    parser.add_argument('--exchange', type=str, default="binance", choices=["binance", "mexc"],
                        help='Exchange to use for trading')
    parser.add_argument('--api-key', type=str, default=None,
                        help='API key for the exchange')
    parser.add_argument('--api-secret', type=str, default=None,
                        help='API secret for the exchange')
    parser.add_argument('--test-mode', action='store_true',
                        help='Use test/sandbox exchange environment')
    parser.add_argument('--position-size', type=float, default=0.1,
                        help='Position size as a fraction of total balance (0.0-1.0)')
    parser.add_argument('--max-trades-per-day', type=int, default=5,
                        help='Maximum number of trades per day')
    parser.add_argument('--trade-cooldown', type=int, default=60,
                        help='Trade cooldown period in minutes')
    parser.add_argument('--config-file', type=str, default=None,
                        help='Path to configuration file')
    return parser


def _load_config(argv=None):
    """Resolve the effective configuration as a plain dict.

    Precedence, lowest to highest: built-in option defaults, values from the
    JSON file named by ``--config-file``, options given explicitly on the
    command line.  Keys in the file use the argparse destination names
    (e.g. ``window_size``); keys that match no option are kept in the result.

    Args:
        argv: Argument list to parse.  ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        dict mapping option names to their resolved values.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    config_path = args.config_file
    if not config_path:
        return vars(args)
    if not os.path.exists(config_path):
        # Previously a missing file was ignored without any trace.
        logger.warning("Config file %s not found; using command-line arguments only.",
                       config_path)
        return vars(args)

    with open(config_path, 'r', encoding='utf-8') as f:
        file_config = json.load(f)
    if not isinstance(file_config, dict):
        parser.error(f"Configuration file {config_path} must contain a JSON object")

    # The file must not redirect the path it was loaded from.
    file_config.pop('config_file', None)

    # Install the file values as defaults and parse again.  Overlaying the
    # first parse result on top of the file would let every built-in default
    # (including test_mode=False) clobber the file's value; re-parsing lets
    # only options actually present on the command line win.
    parser.set_defaults(**file_config)
    return vars(parser.parse_args(argv))


def main(argv=None):
    """Main function for the trading bot.

    Resolves the configuration, creates one RealTimeChart per symbol, loads
    the model, wires up the NeuralNetworkOrchestrator, starts the chart data
    threads and the chart server, then blocks until Ctrl+C.

    Environment variables read here:
        NN_MODEL_TYPE: overrides the configured model type when set.
        ENABLE_NN_MODELS: inference is started only when this equals "1".

    Args:
        argv: Argument list to parse.  ``None`` (the default) lets argparse
            read ``sys.argv[1:]``.

    Exits with status 1 if any step after configuration loading raises.
    """
    config = _load_config(argv)

    # Initialize real-time charts and data interfaces
    try:
        from dataprovider_realtime import RealTimeChart

        symbols = config['symbols']
        if not symbols:
            # Only reachable through a config file: the CLI option uses nargs='+'.
            raise ValueError("At least one trading symbol is required")

        # Create a real-time chart for each symbol
        charts = {symbol: RealTimeChart(symbol=symbol) for symbol in symbols}

        # The first symbol's chart is the one handed to the data interface,
        # the orchestrator, and used for the chart server.
        main_chart = charts[symbols[0]]

        # Create a data interface for retrieving market data
        data_interface = RealtimeDataInterface(symbols=symbols, chart=main_chart)

        # Load trained model
        model_type = os.environ.get("NN_MODEL_TYPE", config['model_type'])
        model = load_model(
            model_type=model_type,
            input_shape=(config['window_size'], len(symbols), 5),  # 5 features (OHLCV)
            output_size=config['output_size'],
        )

        # Configure trading agent
        exchange_config = {
            "exchange": config['exchange'],
            "api_key": config['api_key'],
            "api_secret": config['api_secret'],
            "test_mode": config['test_mode'],
            "trade_symbols": symbols,
            "position_size": config['position_size'],
            "max_trades_per_day": config['max_trades_per_day'],
            "trade_cooldown_minutes": config['trade_cooldown'],
        }

        # Initialize neural network orchestrator
        orchestrator = NeuralNetworkOrchestrator(
            model=model,
            data_interface=data_interface,
            chart=main_chart,
            symbols=symbols,
            timeframes=config['timeframes'],
            window_size=config['window_size'],
            num_features=5,  # OHLCV
            output_size=config['output_size'],
            exchange_config=exchange_config,
        )

        # Start data collection
        logger.info("Starting data collection threads...")
        for symbol in symbols:
            charts[symbol].start()

        # Start neural network inference
        if os.environ.get("ENABLE_NN_MODELS", "0") == "1":
            logger.info("Starting neural network inference...")
            orchestrator.start_inference()
        else:
            logger.info("Neural network models disabled. Set ENABLE_NN_MODELS=1 to enable.")

        # Start the chart server (only the main chart's server is started)
        logger.info("Starting web servers for chart display...")
        main_chart.start_server()

        logger.info("Trading bot initialized successfully. Press Ctrl+C to exit.")

        # Keep the main thread alive
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            logger.info("Keyboard interrupt received. Shutting down...")
            # Stop all threads
            for symbol in symbols:
                charts[symbol].stop()
            # Called even when inference was never started above.
            orchestrator.stop_inference()
            logger.info("Trading bot stopped.")

    except Exception as e:
        # Top-level boundary: log with traceback and signal failure to the shell.
        logger.exception("Error in main function: %s", e)
        sys.exit(1)
|
|
|
|
# Script entry point: run the bot only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()