#!/usr/bin/env python3
"""
Neural Network Trading System Main Module - PyTorch Version

This module serves as the main entry point for the NN trading system,
using PyTorch exclusively for all model operations.

Three run modes are exposed through the command line (see
``parse_arguments``): ``train``, ``predict`` and ``realtime``.
"""

import os
import sys
import logging
import argparse
from datetime import datetime

import numpy as np

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    # Keep the module importable when torch/tensorboard is missing so that
    # main() can report the problem through its own ImportError handling;
    # train() falls back to a no-op writer in that case.
    SummaryWriter = None

LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
LOG_DIR = 'logs'
MODEL_DIR = 'models'

# Label for each predicted class index (index 0, 1, 2).
ACTION_NAMES = ('SELL', 'HOLD', 'BUY')

# Configure logging
logger = logging.getLogger('NN')
logger.setLevel(logging.INFO)


def _configure_logging():
    """Attach a console handler and, when possible, a timestamped file handler.

    The console handler is attached first so that a failure to set up file
    logging is itself reported through a formatted handler.
    """
    formatter = logging.Formatter(LOG_FORMAT)

    # Always setup console logging
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    try:
        # Create logs directory if it doesn't exist
        os.makedirs(LOG_DIR, exist_ok=True)
        log_file = os.path.join(
            LOG_DIR, f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
        )
        file_handler = logging.FileHandler(log_file)
    except OSError as e:
        logger.warning(
            f"Failed to setup file logging: {str(e)}. "
            "Falling back to console logging only."
        )
        return

    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(f"Logging to file: {log_file}")


_configure_logging()


class _NullSummaryWriter:
    """No-op stand-in used by train() when TensorBoard cannot be imported."""

    def add_scalar(self, *args, **kwargs):
        """Discard the metric."""

    def close(self):
        """Nothing to release."""


def parse_arguments():
    """Parse command line arguments.

    Returns:
        argparse.Namespace with ``mode``, ``symbol``, ``timeframes``,
        ``window_size``, ``output_size``, ``batch_size``, ``epochs`` and
        ``model_type``.
    """
    parser = argparse.ArgumentParser(description='Neural Network Trading System')

    parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'],
                        default='train',
                        help='Mode to run (train, predict, realtime)')
    parser.add_argument('--symbol', type=str, default='BTC/USDT',
                        help='Trading pair symbol')
    parser.add_argument('--timeframes', type=str, nargs='+',
                        default=['1s', '1m', '5m', '1h', '4h'],
                        help='Timeframes to use (include 1s for ticks)')
    parser.add_argument('--window-size', type=int, default=20,
                        help='Window size for input data')
    parser.add_argument('--output-size', type=int, default=3,
                        help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of epochs for training')
    parser.add_argument('--model-type', type=str,
                        choices=['cnn', 'transformer', 'moe'], default='cnn',
                        help='Model type to use')

    return parser.parse_args()


def main():
    """Main entry point for the NN trading system.

    Parses the command line, builds the data interface and the model, then
    dispatches to ``train``, ``predict`` or ``realtime``. Every setup failure
    is logged and causes an early return; nothing is raised to the caller.
    """
    args = parse_arguments()

    logger.info(f"Starting NN Trading System in {args.mode} mode")
    logger.info(f"Configuration: Symbol={args.symbol}, Timeframes={args.timeframes}")

    try:
        import torch
        from NN.utils.data_interface import DataInterface

        # Import appropriate PyTorch model
        if args.model_type == 'cnn':
            from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
        elif args.model_type == 'transformer':
            from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
        elif args.model_type == 'moe':
            from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
        else:
            # Defensive only: argparse `choices` already rejects other values.
            logger.error(f"Unknown model type: {args.model_type}")
            return
    except ImportError as e:
        logger.error(f"Failed to import PyTorch modules: {str(e)}")
        logger.error("Please make sure PyTorch is installed")
        return

    # Initialize data interface
    try:
        data_interface = DataInterface(
            symbol=args.symbol,
            timeframes=args.timeframes
        )

        # Verify data interface by fetching initial data
        logger.info("Verifying data interface...")
        X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
        if X_sample is None or y_sample is None:
            logger.error("Failed to prepare initial training data")
            return
        logger.info(f"Data interface verified - X shape: {X_sample.shape}, y shape: {y_sample.shape}")
    except Exception as e:
        logger.exception(f"Failed to initialize data interface: {str(e)}")
        return

    # Initialize model
    try:
        # Calculate total number of features across all timeframes
        num_features = data_interface.get_feature_count()
        logger.info(f"Initializing model with {num_features} features")

        model = Model(
            window_size=args.window_size,
            num_features=num_features,
            output_size=args.output_size,
            timeframes=args.timeframes
        )

        # Ensure model is on the correct device
        if torch.cuda.is_available():
            model.model = model.model.cuda()
            logger.info("Model moved to CUDA device")
    except Exception as e:
        logger.exception(f"Failed to initialize model: {str(e)}")
        return

    # Execute requested mode
    if args.mode == 'train':
        train(data_interface, model, args)
    elif args.mode == 'predict':
        predict(data_interface, model, args)
    elif args.mode == 'realtime':
        realtime(data_interface, model, args)


def _log_epoch_scalars(writer, epoch, train_metrics, val_metrics,
                       train_actions, val_actions):
    """Write one epoch's loss/accuracy/PnL/win-rate and action counts to the writer.

    ``train_metrics`` and ``val_metrics`` are ``(loss, acc, pnl, win_rate)``
    tuples.
    """
    tags = ('Loss', 'Accuracy', 'PnL', 'WinRate')
    # Original tag order is kept: train/val for Loss and Accuracy first,
    # then train/val for PnL and WinRate.
    for tag, train_value, val_value in zip(tags, train_metrics, val_metrics):
        writer.add_scalar(f'{tag}/train', train_value, epoch)
        writer.add_scalar(f'{tag}/val', val_value, epoch)

    # Log action distribution
    for i, action in enumerate(ACTION_NAMES):
        writer.add_scalar(f'Actions/train_{action}', train_actions[i], epoch)
        writer.add_scalar(f'Actions/val_{action}', val_actions[i], epoch)


def _retrospective_fine_tune(model, X_val, train_trades):
    """Log the model's next-candle predictions and count profitable trades.

    The actual fine-tuning step is not implemented yet; this only reports
    what it would operate on.
    """
    logger.info("Performing retrospective fine-tuning...")

    # Get predictions for next few candles
    next_candles = model.predict_next_candles(X_val[-1:], n_candles=3)

    # Log predictions for each timeframe
    for tf, preds in next_candles.items():
        logger.info(f"Next 3 candles for {tf}:")
        for i, pred in enumerate(preds):
            action = ACTION_NAMES[np.argmax(pred)]
            confidence = np.max(pred)
            logger.info(f"Candle {i+1}: {action} (confidence: {confidence:.2f})")

    # Fine-tune on recent successful trades
    successful_trades = [t for t in (train_trades or []) if t['pnl'] > 0]
    if successful_trades:
        logger.info(f"Fine-tuning on {len(successful_trades)} successful trades")
        # TODO: Implement fine-tuning logic here


def train(data_interface, model, args):
    """Enhanced training with performance tracking and retrospective fine-tuning.

    Runs ``args.epochs`` epochs. Each epoch refreshes the data, trains and
    evaluates the model, logs metrics to TensorBoard (when available) and
    saves the model whenever validation PnL improves. A timestamped final
    model is saved after the last epoch.

    Errors inside a single epoch are logged and the epoch is skipped; any
    other error aborts training and is logged. Nothing is raised.
    """
    logger.info("Starting training mode...")

    if SummaryWriter is None:
        logger.warning("TensorBoard is not available; metrics will not be written")
        writer = _NullSummaryWriter()
    else:
        writer = SummaryWriter()

    try:
        best_val_acc = 0
        best_val_pnl = float('-inf')
        best_win_rate = 0

        # model.save() targets this directory; make sure it exists up front.
        os.makedirs(MODEL_DIR, exist_ok=True)

        logger.info("Verifying data interface...")
        X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
        if X_sample is None or y_sample is None:
            logger.error("Failed to prepare initial training data")
            return
        logger.info(f"Data validation - X shape: {X_sample.shape}, y shape: {y_sample.shape}")

        # Data is refreshed on every epoch; only the interval depends on the
        # timeframes: 30 seconds when tick data ('1s') is used, else 2 minutes.
        refresh_interval = 30 if '1s' in args.timeframes else 120

        for epoch in range(args.epochs):
            logger.info(f"\nStarting epoch {epoch+1}/{args.epochs}")
            X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
                refresh=True,
                refresh_interval=refresh_interval
            )
            logger.info(f"Training data - X shape: {X_train.shape}, y shape: {y_train.shape}")
            logger.info(f"Validation data - X shape: {X_val.shape}, y shape: {y_val.shape}")

            # Train and validate
            try:
                train_loss, train_acc = model.train_epoch(X_train, y_train, args.batch_size)
                val_loss, val_acc = model.evaluate(X_val, y_val)

                # Get predictions for PnL calculation
                train_preds = model.predict(X_train)
                val_preds = model.predict(X_val)

                # Calculate PnL and win rates
                train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
                    train_preds, train_prices, position_size=1.0
                )
                val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
                    val_preds, val_prices, position_size=1.0
                )

                # Monitor action distribution
                train_actions = np.bincount(train_preds, minlength=3)
                val_actions = np.bincount(val_preds, minlength=3)

                # Log metrics
                _log_epoch_scalars(
                    writer, epoch,
                    (train_loss, train_acc, train_pnl, train_win_rate),
                    (val_loss, val_acc, val_pnl, val_win_rate),
                    train_actions, val_actions,
                )

                # Save best model based on validation PnL
                if val_pnl > best_val_pnl:
                    best_val_pnl = val_pnl
                    best_val_acc = val_acc
                    best_win_rate = val_win_rate
                    model.save(f"{MODEL_DIR}/{args.model_type}_best.pt")

                # Log detailed metrics
                logger.info(f"Epoch {epoch+1}/{args.epochs} - "
                            f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}, "
                            f"PnL: {train_pnl:.2%}, Win Rate: {train_win_rate:.2%} - "
                            f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}, "
                            f"PnL: {val_pnl:.2%}, Win Rate: {val_win_rate:.2%}")

                # Log action distribution
                logger.info("Action Distribution:")
                for i, action in enumerate(ACTION_NAMES):
                    logger.info(f"{action}: Train={train_actions[i]}, Val={val_actions[i]}")

                # Log trade statistics
                if train_trades:
                    logger.info(f"Training trades: {len(train_trades)}")
                    logger.info(f"Validation trades: {len(val_trades)}")

                # Retrospective fine-tuning
                if epoch > 0 and val_pnl > 0:  # Only fine-tune if we're making profit
                    _retrospective_fine_tune(model, X_val, train_trades)

            except Exception as e:
                # One bad epoch should not abort the whole run.
                logger.exception(f"Error during epoch {epoch+1}: {str(e)}")
                continue

        # Save final model
        model.save(f"{MODEL_DIR}/{args.model_type}_final_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
        logger.info(f"Training complete. Best validation metrics:")
        logger.info(f"Accuracy: {best_val_acc:.2f}")
        logger.info(f"PnL: {best_val_pnl:.2%}")
        logger.info(f"Win Rate: {best_win_rate:.2%}")

    except Exception as e:
        logger.exception(f"Error in training: {str(e)}")
    finally:
        # Flush and release the TensorBoard event file on every exit path.
        writer.close()


def _load_latest_model(model, model_type, model_dir=MODEL_DIR):
    """Load the most recent saved model of ``model_type`` into ``model``.

    "Most recent" means the lexicographically last file name in
    ``model_dir`` that starts with ``model_type``; the timestamped
    ``*_final_*`` files written by train() sort after ``*_best.pt``.

    Returns:
        True when a model file was found and ``model.load`` was called,
        False (after logging an error) when none exists.
    """
    try:
        model_files = [f for f in os.listdir(model_dir) if f.startswith(model_type)]
    except FileNotFoundError:
        # No models directory at all is the same situation as an empty one.
        model_files = []

    if not model_files:
        logger.error(f"No saved model found for type {model_type}")
        return False

    model_path = os.path.join(model_dir, max(model_files))
    logger.info(f"Loading model from {model_path}...")
    model.load(model_path)
    return True


def predict(data_interface, model, args):
    """Make predictions using the trained model.

    Loads the latest saved model for ``args.model_type``, runs it on the
    data returned by ``data_interface.prepare_prediction_data()`` and hands
    the result to ``data_interface.process_predictions``. Errors are logged,
    not raised.
    """
    logger.info("Starting prediction mode...")
    try:
        # Load the latest model
        if not _load_latest_model(model, args.model_type):
            return

        # Prepare prediction data
        logger.info("Preparing prediction data...")
        X_pred = data_interface.prepare_prediction_data()

        # Make predictions
        logger.info("Making predictions...")
        predictions = model.predict(X_pred)

        # Process and display predictions
        logger.info("Processing predictions...")
        data_interface.process_predictions(predictions)

    except Exception as e:
        logger.exception(f"Error in prediction mode: {str(e)}")


def realtime(data_interface, model, args):
    """Run the model in real-time mode.

    Loads the latest saved model for ``args.model_type`` and starts a
    ``RealtimeAnalyzer`` on it. Errors (including a failed import of the
    analyzer) are logged, not raised.
    """
    logger.info("Starting real-time mode...")
    try:
        from NN.utils.realtime_analyzer import RealtimeAnalyzer

        # Load the latest model
        if not _load_latest_model(model, args.model_type):
            return

        # Initialize realtime analyzer
        logger.info("Initializing real-time analyzer...")
        realtime_analyzer = RealtimeAnalyzer(
            data_interface=data_interface,
            model=model,
            symbol=args.symbol,
            timeframes=args.timeframes
        )

        # Start real-time analysis
        logger.info("Starting real-time analysis...")
        realtime_analyzer.start()

    except Exception as e:
        logger.exception(f"Error in real-time mode: {str(e)}")


if __name__ == "__main__":
    main()