#!/usr/bin/env python3
"""
Neural Network Trading System Main Module - PyTorch Version

This module serves as the main entry point for the NN trading system,
using PyTorch exclusively for all model operations.
"""

import argparse
import logging
import os
import sys
import time
from datetime import datetime

import numpy as np

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    # Keep the module importable (e.g. for --help) when torch/tensorboard is
    # missing; train() reports the problem instead of crashing at import time.
    SummaryWriter = None

# Configure logging
logger = logging.getLogger('NN')
logger.setLevel(logging.INFO)

# Console logging is set up first so that a failure while configuring the
# file handler below is reported through a properly formatted handler.
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)

try:
    # Create logs directory if it doesn't exist
    os.makedirs('logs', exist_ok=True)

    # Try setting up file logging
    log_file = os.path.join('logs', f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    logger.info(f"Logging to file: {log_file}")
except OSError as e:
    logger.warning(f"Failed to setup file logging: {str(e)}. Falling back to console logging only.")

# Seconds per timeframe unit suffix ('1M' / months is deliberately not mapped).
_UNIT_SECONDS = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400, 'w': 604800}

# Index -> label mapping used when logging the action distribution.
_ACTION_NAMES = ['SELL', 'HOLD', 'BUY']


def timeframe_to_seconds(timeframe):
    """Return the duration of a timeframe string such as '5m' in seconds.

    Returns None when the string is not '<positive integer><s|m|h|d|w>'.
    """
    text = str(timeframe).strip()
    unit = text[-1:]
    if unit not in _UNIT_SECONDS:
        return None
    try:
        count = int(text[:-1])
    except ValueError:
        return None
    if count <= 0:
        return None
    return count * _UNIT_SECONDS[unit]


def select_refresh_interval(timeframes, default=60):
    """Return ``(shortest_timeframe, seconds)`` for a list of timeframes.

    The shortest timeframe is chosen by duration, not by string ordering
    (lexicographically '1h' < '1s', which is not what is wanted here).
    If no timeframe can be parsed, the lexicographic minimum (or None for an
    empty list) is returned together with ``default``.
    """
    best = None
    for tf in timeframes:
        seconds = timeframe_to_seconds(tf)
        if seconds is not None and (best is None or seconds < best[1]):
            best = (tf, seconds)
    if best is not None:
        return best
    return (min(timeframes) if timeframes else None, default)


def find_latest_model(model_type, model_dir='models'):
    """Return the path of the last saved model file for ``model_type``.

    Files in ``model_dir`` whose names start with ``model_type`` are sorted by
    name and the last one is returned. Returns None when the directory does
    not exist or holds no matching file.
    """
    try:
        candidates = [f for f in os.listdir(model_dir) if f.startswith(model_type)]
    except FileNotFoundError:
        return None
    if not candidates:
        return None
    return os.path.join(model_dir, sorted(candidates)[-1])


def _mean_abs_price_error(predicted, actual):
    """Return the MAE over the overlapping prefix, or inf if it is not computable."""
    if predicted is None or actual is None:
        return float('inf')
    predicted_np = np.asarray(predicted)
    actual_np = np.asarray(actual)
    if len(predicted_np) == 0 or len(actual_np) == 0:
        return float('inf')
    # The two series may differ in length; compare only the common prefix.
    min_len = min(len(predicted_np), len(actual_np))
    return np.mean(np.abs(predicted_np[:min_len] - actual_np[:min_len]))


def _log_metrics(title, action_loss, price_loss, acc, pnl, win_rate, price_mae):
    """Log one block of per-epoch metrics under the given title."""
    logger.info(title)
    logger.info(f" Action Loss: {action_loss:.4f}")
    logger.info(f" Price Loss: {price_loss:.4f}")
    logger.info(f" Accuracy: {acc:.2f}")
    logger.info(f" PnL: {pnl:.2%}")
    logger.info(f" Win Rate: {win_rate:.2%}")
    logger.info(f" Price MAE: {price_mae:.2f}")


def parse_arguments(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings. When None (the default),
            ``sys.argv[1:]`` is parsed, as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Neural Network Trading System')

    parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'],
                        default='train', help='Mode to run (train, predict, realtime)')
    parser.add_argument('--symbol', type=str, default='BTC/USDT',
                        help='Trading pair symbol')
    parser.add_argument('--timeframes', type=str, nargs='+',
                        default=['1s', '1m', '5m', '1h', '4h'],
                        help='Timeframes to use (include 1s for ticks)')
    parser.add_argument('--window-size', type=int, default=20,
                        help='Window size for input data')
    parser.add_argument('--output-size', type=int, default=3,
                        help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of epochs for training')
    parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'],
                        default='cnn', help='Model type to use')

    return parser.parse_args(argv)


def main():
    """Main entry point for the NN trading system.

    Parses the command line, imports the requested PyTorch model wrapper,
    verifies the data interface, builds the model and dispatches to
    train / predict / realtime. Failures are logged and end the run.
    """
    args = parse_arguments()

    logger.info(f"Starting NN Trading System in {args.mode} mode")
    logger.info(f"Configuration: Symbol={args.symbol}, Timeframes={args.timeframes}")

    try:
        import torch
        from NN.utils.data_interface import DataInterface

        # Import appropriate PyTorch model
        if args.model_type == 'cnn':
            from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
        elif args.model_type == 'transformer':
            from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
        elif args.model_type == 'moe':
            from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
        else:
            logger.error(f"Unknown model type: {args.model_type}")
            return
    except ImportError as e:
        logger.error(f"Failed to import PyTorch modules: {str(e)}")
        logger.error("Please make sure PyTorch is installed")
        return

    # Initialize data interface
    try:
        data_interface = DataInterface(
            symbol=args.symbol,
            timeframes=args.timeframes
        )

        # Verify data interface by fetching initial data
        logger.info("Verifying data interface...")
        X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
        # Abort only when either array is missing (the previous check was
        # inverted and rejected every successful fetch).
        if X_sample is None or y_sample is None:
            logger.error("Failed to prepare initial training data")
            return
        logger.info(f"Data interface verified - X shape: {X_sample.shape}, y shape: {y_sample.shape}")
    except Exception as e:
        logger.exception(f"Failed to initialize data interface: {str(e)}")
        return

    # Initialize model
    try:
        # Calculate total number of features across all timeframes
        num_features = data_interface.get_feature_count()
        logger.info(f"Initializing model with {num_features} features")

        model = Model(
            window_size=args.window_size,
            num_features=num_features,
            output_size=args.output_size,
            timeframes=args.timeframes
        )

        # Ensure model is on the correct device
        if torch.cuda.is_available():
            model.model = model.model.cuda()
            logger.info("Model moved to CUDA device")
    except Exception as e:
        logger.exception(f"Failed to initialize model: {str(e)}")
        return

    # Execute requested mode
    if args.mode == 'train':
        train(data_interface, model, args)
    elif args.mode == 'predict':
        predict(data_interface, model, args)
    elif args.mode == 'realtime':
        realtime(data_interface, model, args)


def train(data_interface, model, args):
    """Enhanced training with performance tracking and retrospective fine-tuning.

    Runs ``args.epochs`` epochs. Each epoch re-prepares the data, trains and
    evaluates the model, computes PnL / win rate / price MAE, writes the
    metrics to TensorBoard and saves the model whenever validation PnL
    improves (ties broken by validation accuracy). A failing epoch is logged
    and skipped; any other failure is logged and ends training.

    Args:
        data_interface: Object providing ``prepare_training_data``,
            ``get_future_prices`` and ``calculate_pnl``.
        model: Object providing ``train_epoch``, ``evaluate``, ``predict``,
            ``predict_next_candles`` and ``save``.
        args: Namespace with ``timeframes``, ``epochs``, ``batch_size`` and
            ``model_type``.
    """
    logger.info("Starting training mode...")

    if SummaryWriter is None:
        logger.error("torch.utils.tensorboard is not available; cannot start training")
        return

    writer = SummaryWriter()

    try:
        # model.save() targets 'models/...'; make sure the directory exists.
        os.makedirs('models', exist_ok=True)

        best_val_acc = 0
        best_val_pnl = float('-inf')
        best_win_rate = 0
        best_price_mae = float('inf')

        logger.info("Verifying data interface...")
        X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
        logger.info(f"Data validation - X shape: {X_sample.shape}, y shape: {y_sample.shape}")

        # Calculate refresh interval from the shortest timeframe (by duration)
        min_timeframe, refresh_interval = select_refresh_interval(args.timeframes)
        logger.info(f"Using refresh interval of {refresh_interval} seconds based on {min_timeframe} timeframe")

        # Always refresh for tick data or when using multiple timeframes
        refresh = '1s' in args.timeframes or len(args.timeframes) > 1

        for epoch in range(args.epochs):
            logger.info(f"\nStarting epoch {epoch+1}/{args.epochs}")
            X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
                refresh=refresh,
                refresh_interval=refresh_interval
            )
            logger.info(f"Training data - X shape: {X_train.shape}, y shape: {y_train.shape}")
            logger.info(f"Validation data - X shape: {X_val.shape}, y shape: {y_val.shape}")

            # Get future prices for retrospective training
            train_future_prices = data_interface.get_future_prices(train_prices, n_candles=3)
            val_future_prices = data_interface.get_future_prices(val_prices, n_candles=3)

            # Train and validate
            try:
                train_action_loss, train_price_loss, train_acc = model.train_epoch(
                    X_train, y_train, train_future_prices, args.batch_size
                )
                val_action_loss, val_price_loss, val_acc = model.evaluate(
                    X_val, y_val, val_future_prices
                )

                # Get predictions for PnL calculation
                train_action_probs, train_price_preds = model.predict(X_train)
                val_action_probs, val_price_preds = model.predict(X_val)

                # Convert probabilities to actions for PnL calculation
                train_preds = np.argmax(train_action_probs, axis=1)
                val_preds = np.argmax(val_action_probs, axis=1)

                # Calculate PnL and win rates; a failure here zeroes the
                # trading metrics but does not abort the epoch.
                try:
                    if train_prices is not None:
                        train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
                            train_preds, train_prices, position_size=1.0
                        )
                    else:
                        train_pnl, train_win_rate, train_trades = 0, 0, []

                    if val_prices is not None:
                        val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
                            val_preds, val_prices, position_size=1.0
                        )
                    else:
                        val_pnl, val_win_rate, val_trades = 0, 0, []
                except Exception as e:
                    logger.error(f"Error calculating PnL: {str(e)}")
                    train_pnl, train_win_rate, val_pnl, val_win_rate = 0, 0, 0, 0
                    train_trades, val_trades = [], []

                # Calculate price prediction error
                train_price_mae = _mean_abs_price_error(train_price_preds, train_future_prices)
                val_price_mae = _mean_abs_price_error(val_price_preds, val_future_prices)

                # Monitor action distribution
                train_actions = np.bincount(train_preds, minlength=3)
                val_actions = np.bincount(val_preds, minlength=3)

                # Log metrics
                writer.add_scalar('Loss/action_train', train_action_loss, epoch)
                writer.add_scalar('Loss/price_train', train_price_loss, epoch)
                writer.add_scalar('Loss/action_val', val_action_loss, epoch)
                writer.add_scalar('Loss/price_val', val_price_loss, epoch)
                writer.add_scalar('Accuracy/train', train_acc, epoch)
                writer.add_scalar('Accuracy/val', val_acc, epoch)
                writer.add_scalar('PnL/train', train_pnl, epoch)
                writer.add_scalar('PnL/val', val_pnl, epoch)
                writer.add_scalar('WinRate/train', train_win_rate, epoch)
                writer.add_scalar('WinRate/val', val_win_rate, epoch)
                writer.add_scalar('PriceMAE/train', train_price_mae, epoch)
                writer.add_scalar('PriceMAE/val', val_price_mae, epoch)

                # Log action distribution
                for i, action in enumerate(_ACTION_NAMES):
                    writer.add_scalar(f'Actions/train_{action}', train_actions[i], epoch)
                    writer.add_scalar(f'Actions/val_{action}', val_actions[i], epoch)

                # Save best model based on validation metrics
                # (higher PnL wins; on a PnL tie, higher accuracy wins)
                if np.isscalar(val_pnl) and np.isscalar(best_val_pnl) and (
                    val_pnl > best_val_pnl
                    or (np.isclose(val_pnl, best_val_pnl) and val_acc > best_val_acc)
                ):
                    best_val_pnl = val_pnl
                    best_val_acc = val_acc
                    best_win_rate = val_win_rate
                    best_price_mae = val_price_mae
                    model.save(f"models/{args.model_type}_best.pt")
                    logger.info("Saved new best model based on validation metrics")

                # Log detailed metrics
                logger.info(f"Epoch {epoch+1}/{args.epochs}")
                _log_metrics("Training Metrics:", train_action_loss, train_price_loss,
                             train_acc, train_pnl, train_win_rate, train_price_mae)
                _log_metrics("Validation Metrics:", val_action_loss, val_price_loss,
                             val_acc, val_pnl, val_win_rate, val_price_mae)

                # Log action distribution
                logger.info("Action Distribution:")
                for i, action in enumerate(_ACTION_NAMES):
                    logger.info(f" {action}: Train={train_actions[i]}, Val={val_actions[i]}")

                # Log trade statistics
                logger.info("Trade Statistics:")
                logger.info(f" Training trades: {len(train_trades)}")
                logger.info(f" Validation trades: {len(val_trades)}")

                # Log next candle predictions
                if epoch % 10 == 0:  # Every 10 epochs
                    logger.info("\nNext Candle Predictions:")
                    next_candles = model.predict_next_candles(X_val[-1:], n_candles=3)
                    for tf in args.timeframes:
                        if tf in next_candles:
                            logger.info(f"\n{tf} timeframe predictions:")
                            for i, pred in enumerate(next_candles[tf]):
                                action = _ACTION_NAMES[np.argmax(pred)]
                                confidence = np.max(pred)
                                logger.info(f" Candle {i+1}: {action} (confidence: {confidence:.2f})")

            except Exception as e:
                logger.exception(f"Error during epoch {epoch+1}: {str(e)}")
                continue

        # Save final model
        model.save(f"models/{args.model_type}_final_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
        logger.info("\nTraining complete. Best validation metrics:")
        logger.info(f"Accuracy: {best_val_acc:.2f}")
        logger.info(f"PnL: {best_val_pnl:.2%}")
        logger.info(f"Win Rate: {best_win_rate:.2%}")
        logger.info(f"Price MAE: {best_price_mae:.2f}")

    except Exception as e:
        logger.exception(f"Error in training: {str(e)}")
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()


def predict(data_interface, model, args):
    """Make predictions using the trained model.

    Loads the last saved model file for ``args.model_type`` from 'models',
    prepares prediction input through ``data_interface``, runs
    ``model.predict`` and hands the result to
    ``data_interface.process_predictions``. Errors are logged, not raised.
    """
    logger.info("Starting prediction mode...")

    try:
        # Load the latest model
        model_path = find_latest_model(args.model_type)
        if model_path is None:
            logger.error(f"No saved model found for type {args.model_type}")
            return

        logger.info(f"Loading model from {model_path}...")
        model.load(model_path)

        # Prepare prediction data
        logger.info("Preparing prediction data...")
        X_pred = data_interface.prepare_prediction_data()

        # Make predictions
        logger.info("Making predictions...")
        predictions = model.predict(X_pred)

        # Process and display predictions
        logger.info("Processing predictions...")
        data_interface.process_predictions(predictions)

    except Exception as e:
        logger.exception(f"Error in prediction mode: {str(e)}")


def realtime(data_interface, model, args, chart=None, symbol=None):
    """Run real-time inference with the trained model.

    Loads the last saved model, starts a ``RealtimeAnalyzer`` and then loops
    once per second: fetch input, predict an action, log it, and (for
    BTC/USDT only) track PnL and forward non-HOLD actions to ``chart``.
    The loop runs until interrupted.

    Args:
        data_interface: Object providing ``prepare_realtime_input`` and
            ``get_historical_data``.
        model: Object providing ``load`` and ``predict``.
        args: Namespace with ``model_type``, ``timeframes``, ``window_size``
            and ``symbol``; an optional ``symbols`` list is honoured if present.
        chart: Optional object with an ``add_trade`` method.
        symbol: Pair to report on. Defaults to ``args.symbol``.

    Raises:
        Exception: Any error other than KeyboardInterrupt is logged and re-raised.
    """
    # main() calls this without `symbol`, and the CLI namespace only has a
    # singular `symbol`; fall back so that call path works.
    if symbol is None:
        symbol = getattr(args, 'symbol', None)
    symbols = getattr(args, 'symbols', None) or [symbol]

    # Initialised before the try block so the KeyboardInterrupt handler can
    # always read them, even if the interrupt arrives during start-up.
    # Only execute trades if this is the main pair (BTC/USDT)
    is_main_pair = symbol == "BTC/USDT"
    total_pnl = 0.0

    logger.info(f"Starting real-time inference mode for {symbol}...")

    try:
        from NN.utils.realtime_analyzer import RealtimeAnalyzer

        # Load the latest model
        model_path = find_latest_model(args.model_type)
        if model_path is None:
            logger.error(f"No saved model found for type {args.model_type}")
            return

        logger.info(f"Loading model from {model_path}...")
        model.load(model_path)

        # Initialize realtime analyzer
        logger.info("Initializing real-time analyzer...")
        realtime_analyzer = RealtimeAnalyzer(
            data_interface=data_interface,
            model=model,
            symbol=getattr(args, 'symbol', symbol),
            timeframes=args.timeframes
        )

        # Start real-time analysis
        logger.info("Starting real-time analysis...")
        realtime_analyzer.start()

        # Initialize variables for tracking performance
        last_action = None
        last_price = None

        # Get the pair index for this symbol
        pair_index = symbols.index(symbol)

        while True:
            # Get current market data for all pairs
            # NOTE(review): the symbol is not passed to prepare_realtime_input,
            # so every iteration appears to fetch the same data - confirm
            # against DataInterface.
            all_pairs_data = []
            timestamp = None
            for s in symbols:
                X, timestamp = data_interface.prepare_realtime_input(
                    timeframe=args.timeframes[0],  # Use shortest timeframe
                    n_candles=args.window_size + 10,  # Extra candles for safety
                    window_size=args.window_size
                )
                if X is not None:
                    all_pairs_data.append(X)
                else:
                    logger.warning(f"No data available for {s}")
                    time.sleep(1)

            if not all_pairs_data:
                logger.warning("No data available for any pair")
                time.sleep(1)
                continue

            # Stack data from all pairs for model input
            X_combined = np.concatenate(all_pairs_data, axis=2)

            # Get model predictions
            action_probs, price_pred = model.predict(X_combined)

            # Get predictions for this specific pair
            action = np.argmax(action_probs[pair_index])  # 0=SELL, 1=HOLD, 2=BUY

            # Get current price for the main pair
            current_price = data_interface.get_historical_data(
                timeframe=args.timeframes[0],
                n_candles=1
            )['close'].iloc[-1]

            # Calculate PnL if we have a position (only for main pair)
            pnl = 0.0
            if is_main_pair and last_action is not None and last_price is not None:
                if last_action == 2:  # BUY
                    pnl = (current_price - last_price) / last_price
                elif last_action == 0:  # SELL
                    pnl = (last_price - current_price) / last_price

            # Update total PnL (only for main pair)
            # NOTE(review): pnl is measured against the last non-HOLD price
            # and added on every tick, so an open position is counted
            # repeatedly - confirm this is intended.
            if is_main_pair and pnl != 0:
                total_pnl += pnl

            # Log the prediction
            action_name = "SELL" if action == 0 else "HOLD" if action == 1 else "BUY"
            log_msg = f"Time: {timestamp}, Symbol: {symbol}, Action: {action_name}, "
            if is_main_pair:
                log_msg += f"Price: {current_price:.2f}, PnL: {pnl:.2%}, Total PnL: {total_pnl:.2%}"
            else:
                log_msg += f"Price: {current_price:.2f} (Context Only)"
            logger.info(log_msg)

            # Update the chart if provided (only for main pair)
            if chart is not None and is_main_pair and action != 1:  # Skip HOLD actions
                chart.add_trade(
                    action=action_name,
                    price=current_price,
                    timestamp=timestamp,
                    pnl=pnl
                )

            # Update tracking variables (only for main pair)
            if is_main_pair and action != 1:  # If not HOLD
                last_action = action
                last_price = current_price

            # Sleep for a short time
            time.sleep(1)

    except KeyboardInterrupt:
        if is_main_pair:
            logger.info(f"Real-time inference stopped by user for {symbol}")
            logger.info(f"Final performance for {symbol} - Total PnL: {total_pnl:.2%}")
        else:
            logger.info(f"Real-time inference stopped by user for {symbol} (Context Only)")
    except Exception as e:
        logger.exception(f"Error in real-time inference for {symbol}: {str(e)}")
        raise


if __name__ == "__main__":
    main()