#!/usr/bin/env python3
"""
Neural Network Trading System Main Module

This module serves as the main entry point for the NN trading system,
coordinating data flow between different components and implementing
training and inference pipelines.
"""

import os
import sys
import logging
import argparse
from datetime import datetime

# Directories (relative to the current working directory) for run logs and
# saved models.
LOG_DIR = 'logs'
MODEL_DIR = 'models'

# Create the logs directory BEFORE configuring logging: logging.FileHandler
# opens its file as soon as it is constructed, so a missing directory would
# raise FileNotFoundError at import time.
os.makedirs(LOG_DIR, exist_ok=True)

# Configure logging: console output plus one timestamped file per run.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(
            os.path.join(LOG_DIR, f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
        ),
    ],
)
logger = logging.getLogger('NN')


def parse_arguments(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='Neural Network Trading System')
    parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'], default='train',
                        help='Mode to run (train, predict, realtime)')
    parser.add_argument('--symbol', type=str, default='BTC/USD',
                        help='Main trading pair symbol (default: BTC/USD)')
    parser.add_argument('--context-pairs', type=str, nargs='*', default=[],
                        help='Additional context trading pairs')
    parser.add_argument('--timeframes', type=str, nargs='+', default=['5m', '15m', '1h'],
                        help='Timeframes to use (default: 5m,15m,1h)')
    parser.add_argument('--window-size', type=int, default=30,
                        help='Window size for input data (default: 30)')
    parser.add_argument('--output-size', type=int, default=5,
                        help='Output size (1=up/down, 3=BUY/HOLD/SELL, 5=with extrema)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs for training')
    parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'], default='cnn',
                        help='Model type to use')
    parser.add_argument('--framework', type=str, choices=['tensorflow', 'pytorch'], default='pytorch',
                        help='Deep learning framework to use')
    return parser.parse_args(argv)


def _import_components(framework, model_type):
    """Import the data-interface class and the model class for a framework.

    Args:
        framework: 'pytorch' or 'tensorflow'.
        model_type: 'cnn', 'transformer' or 'moe'.

    Returns:
        A ``(MultiDataInterface, Model)`` tuple of classes, or ``None`` if the
        framework/model type is unknown or an import failed (already logged).
    """
    if framework == 'pytorch':
        try:
            import torch
            logger.info(f"Using PyTorch {torch.__version__}")

            # Import PyTorch-based modules
            from NN.utils.multi_data_interface import MultiDataInterface

            if model_type == 'cnn':
                from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
            elif model_type == 'transformer':
                from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
            elif model_type == 'moe':
                from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
            else:
                logger.error(f"Unknown model type: {model_type}")
                return None
        except ImportError as e:
            logger.error(f"Failed to import PyTorch modules: {str(e)}")
            logger.error("Please make sure PyTorch is installed or use the TensorFlow framework.")
            return None
        return MultiDataInterface, Model

    if framework == 'tensorflow':
        try:
            import tensorflow as tf
            logger.info(f"Using TensorFlow {tf.__version__}")

            # Import TensorFlow-based modules
            from NN.utils.multi_data_interface import MultiDataInterface

            if model_type == 'cnn':
                from NN.models.cnn_model import CNNModel as Model
            elif model_type == 'transformer':
                from NN.models.transformer_model import TransformerModel as Model
            elif model_type == 'moe':
                from NN.models.transformer_model import MixtureOfExpertsModel as Model
            else:
                logger.error(f"Unknown model type: {model_type}")
                return None
        except ImportError as e:
            logger.error(f"Failed to import TensorFlow modules: {str(e)}")
            logger.error("Please make sure TensorFlow is installed or use the PyTorch framework.")
            return None
        return MultiDataInterface, Model

    logger.error(f"Unknown framework: {framework}")
    return None


def main(argv=None):
    """Main entry point for the NN trading system.

    Args:
        argv: Optional argument list forwarded to :func:`parse_arguments`;
            ``None`` means "use sys.argv".

    Returns:
        Process exit status: 0 on success, 1 on any failure.
    """
    args = parse_arguments(argv)

    logger.info(f"Starting NN Trading System in {args.mode} mode")
    logger.info(f"Main Symbol: {args.symbol}")
    if args.context_pairs:
        logger.info(f"Context Pairs: {args.context_pairs}")
    logger.info(f"Timeframes: {args.timeframes}")
    logger.info(f"Window Size: {args.window_size}")
    logger.info(f"Output Size: {args.output_size} (1=up/down, 3=BUY/HOLD/SELL, 5=with extrema)")
    logger.info(f"Model Type: {args.model_type}")
    logger.info(f"Framework: {args.framework}")

    # Import the appropriate modules based on the framework
    components = _import_components(args.framework, args.model_type)
    if components is None:
        return 1
    MultiDataInterface, Model = components

    # Initialize data interface
    try:
        logger.info("Initializing data interface...")
        data_interface = MultiDataInterface(
            symbol=args.symbol,
            timeframes=args.timeframes,
            window_size=args.window_size,
            output_size=args.output_size,
        )
    except Exception as e:
        logger.exception(f"Failed to initialize data interface: {e}")
        return 1

    # Initialize model
    try:
        logger.info(f"Initializing {args.model_type.upper()} model...")
        # Calculate actual feature count (OHLCV per timeframe)
        num_features = 5 * len(args.timeframes)
        model = Model(
            window_size=args.window_size,
            num_features=num_features,
            output_size=args.output_size,
            timeframes=args.timeframes,
        )
    except Exception as e:
        logger.exception(f"Failed to initialize model: {e}")
        return 1

    # Execute the requested mode
    runners = {'train': train, 'predict': predict, 'realtime': realtime}
    runner = runners.get(args.mode)
    if runner is None:
        logger.error(f"Unknown mode: {args.mode}")
        return 1

    # Each mode function returns False when it caught and logged an error;
    # only report success when the mode actually completed.
    if not runner(data_interface, model, args):
        logger.error(f"Neural Network Trading System failed in {args.mode} mode")
        return 1

    logger.info("Neural Network Trading System finished successfully")
    return 0


def _model_name_prefix(model_type, symbol):
    """Return the filename prefix shared by saved models of this type and symbol."""
    return f"{model_type}_{symbol.replace('/', '_')}_"


def find_latest_model(model_type, symbol, model_dir=MODEL_DIR):
    """Return the path of the newest saved model for ``model_type``/``symbol``.

    :func:`train` names models ``{model_type}_{symbol with '/' -> '_'}_{YYYYmmdd_HHMMSS}``.
    Among entries sharing that prefix, the lexicographically largest is
    therefore the most recent. The filter uses the full type+symbol prefix
    rather than the model type alone. Otherwise names would sort by symbol
    before timestamp, and a model trained on another pair could be chosen.

    NOTE(review): if ``model.save`` writes several files per model, this may
    return an auxiliary file; confirm against the model implementations.

    Args:
        model_type: Model type, e.g. 'cnn'.
        symbol: Trading pair symbol, e.g. 'BTC/USD'.
        model_dir: Directory to search (default: ``MODEL_DIR``).

    Returns:
        The joined path of the newest match, or ``None`` if ``model_dir``
        does not exist or contains no matching entry.
    """
    prefix = _model_name_prefix(model_type, symbol)
    try:
        candidates = [name for name in os.listdir(model_dir) if name.startswith(prefix)]
    except FileNotFoundError:
        return None
    if not candidates:
        return None
    return os.path.join(model_dir, max(candidates))


def _load_latest_model(model, args):
    """Load the newest saved model matching ``args`` into ``model``.

    Returns:
        True if a model was found and loaded, False if none was found (logged).
    """
    model_path = find_latest_model(args.model_type, args.symbol)
    if model_path is None:
        logger.error(f"No saved model found for type {args.model_type} and symbol {args.symbol}")
        return False
    logger.info(f"Loading model from {model_path}...")
    model.load(model_path)
    return True


def train(data_interface, model, args):
    """Train the model, save it under ``MODEL_DIR`` and evaluate it.

    Requests 1000 candles from ``data_interface.prepare_nn_input`` and splits
    the samples 80/20 by position into train/validation sets. This presumably
    assumes samples are time-ordered; confirm in the data interface.

    Returns:
        True on success, False if an error was caught and logged.
    """
    logger.info("Starting training mode...")

    try:
        # Prepare training data
        logger.info("Preparing training data...")
        X, y, _ = data_interface.prepare_nn_input(
            timeframes=args.timeframes,
            n_candles=1000,
            window_size=args.window_size,
        )

        # With fewer than 2 samples, one side of the 80/20 split would be empty.
        if X is None or len(X) < 2:
            logger.error("Not enough training samples to create a train/validation split")
            return False

        logger.info(f"Training data shape: {X.shape}")
        logger.info(f"Target data shape: {y.shape}")

        # Split into train/validation sets (80/20)
        split_idx = int(len(X) * 0.8)
        X_train, y_train = X[:split_idx], y[:split_idx]
        X_val, y_val = X[split_idx:], y[split_idx:]

        # Train the model
        logger.info("Training model...")
        model.train(
            X_train, y_train,
            X_val, y_val,
            batch_size=args.batch_size,
            epochs=args.epochs,
        )

        # Save the model (make sure the target directory exists first)
        os.makedirs(MODEL_DIR, exist_ok=True)
        model_path = os.path.join(
            MODEL_DIR,
            _model_name_prefix(args.model_type, args.symbol) + datetime.now().strftime('%Y%m%d_%H%M%S'),
        )
        logger.info(f"Saving model to {model_path}...")
        model.save(model_path)

        # Evaluate the model
        logger.info("Evaluating model...")
        metrics = model.evaluate(X_val, y_val)
        logger.info(f"Evaluation metrics: {metrics}")

    except Exception as e:
        logger.exception(f"Error in training mode: {e}")
        return False

    return True


def predict(data_interface, model, args):
    """Make predictions using the latest saved model for this type/symbol.

    Returns:
        True on success, False if no model was found or an error was logged.
    """
    logger.info("Starting prediction mode...")

    try:
        if not _load_latest_model(model, args):
            return False

        # Prepare prediction data
        logger.info("Preparing prediction data...")
        X_pred = data_interface.prepare_prediction_data()

        # Make predictions
        logger.info("Making predictions...")
        predictions = model.predict(X_pred)

        # Process and display predictions
        logger.info("Processing predictions...")
        data_interface.process_predictions(predictions)

    except Exception as e:
        logger.exception(f"Error in prediction mode: {e}")
        return False

    return True


def realtime(data_interface, model, args):
    """Run the latest saved model in real-time mode via ``RealtimeAnalyzer``.

    Returns:
        True once ``RealtimeAnalyzer.start()`` returns without raising, and
        False if no model was found or an error was logged.
    """
    logger.info("Starting real-time mode...")

    try:
        # Import realtime analyzer
        from NN.utils.realtime_analyzer import RealtimeAnalyzer

        if not _load_latest_model(model, args):
            return False

        # Initialize realtime analyzer
        logger.info("Initializing real-time analyzer...")
        realtime_analyzer = RealtimeAnalyzer(
            data_interface=data_interface,
            model=model,
            symbol=args.symbol,
            timeframes=args.timeframes,
        )

        # Start real-time analysis
        logger.info("Starting real-time analysis...")
        realtime_analyzer.start()

    except Exception as e:
        logger.exception(f"Error in real-time mode: {e}")
        return False

    return True


if __name__ == "__main__":
    # Propagate failure to the shell as a non-zero exit status.
    sys.exit(main())