#!/usr/bin/env python3
"""
Neural Network Trading System Main Module - PyTorch Version

This module serves as the main entry point for the NN trading system,
using PyTorch exclusively for all model operations.
"""
|
|
|
|
import os
|
|
import sys
|
|
import logging
|
|
import argparse
|
|
from datetime import datetime
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
import numpy as np
|
|
|
|
# Configure logging
logger = logging.getLogger('NN')
logger.setLevel(logging.INFO)

# One formatter shared by every handler so console and file lines look alike.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Always setup console logging, and do it FIRST: it does not depend on the
# filesystem, and having it attached means the fallback warning below is
# formatted and shown on the console instead of going through logging's
# unformatted last-resort handler.
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(ch)

try:
    # Create logs directory if it doesn't exist
    os.makedirs('logs', exist_ok=True)

    # Try setting up file logging (one timestamped file per process start)
    log_file = os.path.join('logs', f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    logger.info(f"Logging to file: {log_file}")
except Exception as e:
    # File logging is best-effort: keep running with the console handler only.
    logger.warning(f"Failed to setup file logging: {str(e)}. Falling back to console logging only.")
|
|
|
|
def parse_arguments(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour existing callers rely on. Passing a list makes the
            parser usable programmatically.

    Returns:
        argparse.Namespace with the attributes ``mode``, ``symbol``,
        ``timeframes``, ``window_size``, ``output_size``, ``batch_size``,
        ``epochs`` and ``model_type``.
    """
    parser = argparse.ArgumentParser(description='Neural Network Trading System')

    parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'], default='train',
                        help='Mode to run (train, predict, realtime)')
    parser.add_argument('--symbol', type=str, default='BTC/USDT',
                        help='Trading pair symbol')
    parser.add_argument('--timeframes', type=str, nargs='+', default=['1s', '1m', '5m', '1h', '4h'],
                        help='Timeframes to use (include 1s for ticks)')
    parser.add_argument('--window-size', type=int, default=20,
                        help='Window size for input data')
    parser.add_argument('--output-size', type=int, default=3,
                        help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of epochs for training')
    parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'], default='cnn',
                        help='Model type to use')

    return parser.parse_args(argv)
|
|
|
|
def main():
    """Run the NN trading system from the command line.

    Steps, in order:
      1. Parse the CLI arguments.
      2. Lazily import torch, the data interface and the model class that
         matches ``--model-type``.
      3. Build the data interface and verify it by preparing training data once.
      4. Build the model and move its wrapped network to CUDA when available.
      5. Hand over to ``train``, ``predict`` or ``realtime`` according to
         ``--mode``.

    A failure in steps 2-4 is logged and ends the run early; the function
    returns ``None`` in every case.
    """
    args = parse_arguments()

    logger.info(f"Starting NN Trading System in {args.mode} mode")
    logger.info(f"Configuration: Symbol={args.symbol}, Timeframes={args.timeframes}")

    # Heavy dependencies are imported here rather than at module level so a
    # missing installation is reported through the logger.
    try:
        import torch
        from NN.utils.data_interface import DataInterface

        # Import appropriate PyTorch model
        model_type = args.model_type
        if model_type == 'cnn':
            from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
        elif model_type == 'transformer':
            from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
        elif model_type == 'moe':
            from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
        else:
            logger.error(f"Unknown model type: {args.model_type}")
            return
    except ImportError as import_error:
        logger.error(f"Failed to import PyTorch modules: {import_error}")
        logger.error("Please make sure PyTorch is installed")
        return

    # Initialize data interface
    try:
        data_interface = DataInterface(symbol=args.symbol, timeframes=args.timeframes)

        # Verify data interface by fetching initial data
        logger.info("Verifying data interface...")
        X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
        data_ready = X_sample is not None and y_sample is not None
        if not data_ready:
            logger.error("Failed to prepare initial training data")
            return

        logger.info(f"Data interface verified - X shape: {X_sample.shape}, y shape: {y_sample.shape}")
    except Exception as data_error:
        logger.error(f"Failed to initialize data interface: {data_error}")
        return

    # Initialize model
    try:
        # Feature count is reported by the data interface
        num_features = data_interface.get_feature_count()
        logger.info(f"Initializing model with {num_features} features")

        model_kwargs = {
            'window_size': args.window_size,
            'num_features': num_features,
            'output_size': args.output_size,
            'timeframes': args.timeframes,
        }
        model = Model(**model_kwargs)

        # Ensure model is on the correct device
        if torch.cuda.is_available():
            model.model = model.model.cuda()
            logger.info("Model moved to CUDA device")
    except Exception as model_error:
        logger.error(f"Failed to initialize model: {model_error}")
        return

    # Execute requested mode
    mode_handlers = {
        'train': train,
        'predict': predict,
        'realtime': realtime,
    }
    run_mode = mode_handlers.get(args.mode)
    if run_mode is not None:
        run_mode(data_interface, model, args)
|
|
|
|
# Candle duration in seconds for each timeframe label train() knows about.
_TIMEFRAME_SECONDS = {
    '1s': 1,
    '1m': 60,
    '5m': 300,
    '15m': 900,
    '1h': 3600,
    '4h': 14400,
    '1d': 86400,
}

# Duration assumed for a timeframe label that is not in _TIMEFRAME_SECONDS.
_DEFAULT_TIMEFRAME_SECONDS = 60

# Class index -> label used when reporting predicted actions.
_ACTION_LABELS = ('SELL', 'HOLD', 'BUY')


def _shortest_timeframe(timeframes):
    """Return ``(label, seconds)`` for the timeframe with the shortest candle.

    Labels are compared by duration, not as strings: a plain ``min()`` on the
    labels is lexicographic and would pick '1h' over '1s'.
    """
    def duration(label):
        return _TIMEFRAME_SECONDS.get(label, _DEFAULT_TIMEFRAME_SECONDS)

    shortest = min(timeframes, key=duration)
    return shortest, duration(shortest)


def _price_mae(price_preds, future_prices):
    """Return the mean absolute error between predicted and future prices.

    The two sequences are truncated to their common length. Returns
    ``float('inf')`` when either input is ``None`` or empty.
    """
    if price_preds is None or future_prices is None:
        return float('inf')

    preds = np.asarray(price_preds)
    targets = np.asarray(future_prices)
    if len(preds) == 0 or len(targets) == 0:
        return float('inf')

    common = min(len(preds), len(targets))
    return np.mean(np.abs(preds[:common] - targets[:common]))


def train(data_interface, model, args):
    """Enhanced training with performance tracking and retrospective fine-tuning.

    For each of ``args.epochs`` epochs the function re-prepares the data,
    trains and evaluates the model, derives PnL / win rate / price MAE from
    the model's predictions, writes all metrics to TensorBoard and the
    logger, and saves ``models/<model_type>_best.pt`` whenever validation PnL
    improves (ties broken by validation accuracy). A timestamped final
    checkpoint is saved after the last epoch.

    Error handling: a failure inside an epoch's train/evaluate/metrics step
    is logged and that epoch is skipped; any other failure is logged and
    ends training. Exceptions raised after the TensorBoard writer has been
    created are not propagated to the caller.

    Args:
        data_interface: Object providing ``prepare_training_data``,
            ``get_future_prices`` and ``calculate_pnl``.
        model: Object providing ``train_epoch``, ``evaluate``, ``predict``,
            ``predict_next_candles`` and ``save``.
        args: Namespace with ``timeframes``, ``epochs``, ``batch_size`` and
            ``model_type``.
    """
    logger.info("Starting training mode...")
    writer = SummaryWriter()

    try:
        best_val_acc = 0
        best_val_pnl = float('-inf')
        best_win_rate = 0
        best_price_mae = float('inf')

        # Checkpoints below are written under models/; make sure it exists
        # so the first model.save() cannot fail on a missing directory.
        os.makedirs('models', exist_ok=True)

        logger.info("Verifying data interface...")
        X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
        logger.info(f"Data validation - X shape: {X_sample.shape}, y shape: {y_sample.shape}")

        # Refresh interval follows the shortest (by duration) timeframe in use
        min_timeframe, refresh_interval = _shortest_timeframe(args.timeframes)
        logger.info(f"Using refresh interval of {refresh_interval} seconds based on {min_timeframe} timeframe")

        # Always refresh for tick data or when using multiple timeframes
        # (does not change between epochs, so it is computed once).
        refresh = '1s' in args.timeframes or len(args.timeframes) > 1

        for epoch in range(args.epochs):
            logger.info(f"\nStarting epoch {epoch+1}/{args.epochs}")
            X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
                refresh=refresh,
                refresh_interval=refresh_interval
            )
            logger.info(f"Training data - X shape: {X_train.shape}, y shape: {y_train.shape}")
            logger.info(f"Validation data - X shape: {X_val.shape}, y shape: {y_val.shape}")

            # Get future prices for retrospective training
            train_future_prices = data_interface.get_future_prices(train_prices, n_candles=3)
            val_future_prices = data_interface.get_future_prices(val_prices, n_candles=3)

            # Train and validate
            try:
                train_action_loss, train_price_loss, train_acc = model.train_epoch(
                    X_train, y_train, train_future_prices, args.batch_size
                )
                val_action_loss, val_price_loss, val_acc = model.evaluate(
                    X_val, y_val, val_future_prices
                )

                # Get predictions for PnL calculation
                train_action_probs, train_price_preds = model.predict(X_train)
                val_action_probs, val_price_preds = model.predict(X_val)

                # Convert probabilities to actions for PnL calculation
                train_preds = np.argmax(train_action_probs, axis=1)
                val_preds = np.argmax(val_action_probs, axis=1)

                # Calculate PnL and win rates; a failure here zeroes the PnL
                # metrics for both splits but does not abort the epoch.
                try:
                    if train_prices is not None:
                        train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
                            train_preds, train_prices, position_size=1.0
                        )
                    else:
                        train_pnl, train_win_rate, train_trades = 0, 0, []

                    if val_prices is not None:
                        val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
                            val_preds, val_prices, position_size=1.0
                        )
                    else:
                        val_pnl, val_win_rate, val_trades = 0, 0, []
                except Exception as e:
                    logger.error(f"Error calculating PnL: {str(e)}")
                    train_pnl, train_win_rate, val_pnl, val_win_rate = 0, 0, 0, 0
                    train_trades, val_trades = [], []

                # Calculate price prediction error
                train_price_mae = _price_mae(train_price_preds, train_future_prices)
                val_price_mae = _price_mae(val_price_preds, val_future_prices)

                # Monitor action distribution (count of each predicted class)
                train_actions = np.bincount(train_preds, minlength=3)
                val_actions = np.bincount(val_preds, minlength=3)

                # Log metrics
                writer.add_scalar('Loss/action_train', train_action_loss, epoch)
                writer.add_scalar('Loss/price_train', train_price_loss, epoch)
                writer.add_scalar('Loss/action_val', val_action_loss, epoch)
                writer.add_scalar('Loss/price_val', val_price_loss, epoch)
                writer.add_scalar('Accuracy/train', train_acc, epoch)
                writer.add_scalar('Accuracy/val', val_acc, epoch)
                writer.add_scalar('PnL/train', train_pnl, epoch)
                writer.add_scalar('PnL/val', val_pnl, epoch)
                writer.add_scalar('WinRate/train', train_win_rate, epoch)
                writer.add_scalar('WinRate/val', val_win_rate, epoch)
                writer.add_scalar('PriceMAE/train', train_price_mae, epoch)
                writer.add_scalar('PriceMAE/val', val_price_mae, epoch)

                # Log action distribution
                for action, train_count, val_count in zip(_ACTION_LABELS, train_actions, val_actions):
                    writer.add_scalar(f'Actions/train_{action}', train_count, epoch)
                    writer.add_scalar(f'Actions/val_{action}', val_count, epoch)

                # Save best model based on validation metrics:
                # higher PnL wins, equal PnL falls back to higher accuracy.
                if np.isscalar(val_pnl) and np.isscalar(best_val_pnl) and (
                    val_pnl > best_val_pnl
                    or (np.isclose(val_pnl, best_val_pnl) and val_acc > best_val_acc)
                ):
                    best_val_pnl = val_pnl
                    best_val_acc = val_acc
                    best_win_rate = val_win_rate
                    best_price_mae = val_price_mae
                    model.save(f"models/{args.model_type}_best.pt")
                    logger.info("Saved new best model based on validation metrics")

                # Log detailed metrics
                logger.info(f"Epoch {epoch+1}/{args.epochs}")
                logger.info("Training Metrics:")
                logger.info(f" Action Loss: {train_action_loss:.4f}")
                logger.info(f" Price Loss: {train_price_loss:.4f}")
                logger.info(f" Accuracy: {train_acc:.2f}")
                logger.info(f" PnL: {train_pnl:.2%}")
                logger.info(f" Win Rate: {train_win_rate:.2%}")
                logger.info(f" Price MAE: {train_price_mae:.2f}")

                logger.info("Validation Metrics:")
                logger.info(f" Action Loss: {val_action_loss:.4f}")
                logger.info(f" Price Loss: {val_price_loss:.4f}")
                logger.info(f" Accuracy: {val_acc:.2f}")
                logger.info(f" PnL: {val_pnl:.2%}")
                logger.info(f" Win Rate: {val_win_rate:.2%}")
                logger.info(f" Price MAE: {val_price_mae:.2f}")

                # Log action distribution
                logger.info("Action Distribution:")
                for action, train_count, val_count in zip(_ACTION_LABELS, train_actions, val_actions):
                    logger.info(f" {action}: Train={train_count}, Val={val_count}")

                # Log trade statistics
                logger.info("Trade Statistics:")
                logger.info(f" Training trades: {len(train_trades)}")
                logger.info(f" Validation trades: {len(val_trades)}")

                # Log next candle predictions
                if epoch % 10 == 0:  # Every 10 epochs, starting with the first
                    logger.info("\nNext Candle Predictions:")
                    next_candles = model.predict_next_candles(X_val[-1:], n_candles=3)
                    for tf in args.timeframes:
                        if tf in next_candles:
                            logger.info(f"\n{tf} timeframe predictions:")
                            for i, pred in enumerate(next_candles[tf]):
                                action = _ACTION_LABELS[np.argmax(pred)]
                                confidence = np.max(pred)
                                logger.info(f" Candle {i+1}: {action} (confidence: {confidence:.2f})")

            except Exception as e:
                # Skip this epoch but keep training; keep the traceback for diagnosis.
                logger.error(f"Error during epoch {epoch+1}: {str(e)}", exc_info=True)
                continue

        # Save final model
        model.save(f"models/{args.model_type}_final_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
        logger.info(f"\nTraining complete. Best validation metrics:")
        logger.info(f"Accuracy: {best_val_acc:.2f}")
        logger.info(f"PnL: {best_val_pnl:.2%}")
        logger.info(f"Win Rate: {best_win_rate:.2%}")
        logger.info(f"Price MAE: {best_price_mae:.2f}")

    except Exception as e:
        logger.error(f"Error in training: {str(e)}", exc_info=True)
    finally:
        # Flush pending events and release the TensorBoard event file.
        writer.close()
|
|
|
|
def predict(data_interface, model, args):
    """Run one prediction pass with a previously saved model.

    Looks in the ``models`` directory for files whose name starts with
    ``args.model_type``, loads the one that sorts last by filename, asks the
    data interface for prediction input, runs ``model.predict`` on it and
    passes the result to ``data_interface.process_predictions``.

    When no matching file exists an error is logged and the function returns.
    Any exception is logged and swallowed.
    """
    logger.info("Starting prediction mode...")

    try:
        # Load the latest model (last entry in filename order)
        model_dir = 'models'
        candidates = sorted(
            name for name in os.listdir(model_dir)
            if name.startswith(args.model_type)
        )
        if len(candidates) == 0:
            logger.error(f"No saved model found for type {args.model_type}")
            return

        model_path = os.path.join(model_dir, candidates[-1])
        logger.info(f"Loading model from {model_path}...")
        model.load(model_path)

        # Prepare prediction data
        logger.info("Preparing prediction data...")
        prediction_input = data_interface.prepare_prediction_data()

        # Make predictions
        logger.info("Making predictions...")
        model_output = model.predict(prediction_input)

        # Process and display predictions
        logger.info("Processing predictions...")
        data_interface.process_predictions(model_output)

    except Exception as exc:
        logger.error(f"Error in prediction mode: {exc}")
|
|
|
|
def realtime(data_interface, model, args):
    """Run the model in real-time mode.

    Imports ``RealtimeAnalyzer`` lazily, loads the saved model file for
    ``args.model_type`` that sorts last by filename in the ``models``
    directory, builds the analyzer around the data interface and model, and
    calls its ``start()`` method.

    When no matching file exists an error is logged and the function returns.
    Any exception (including a failed import) is logged and swallowed.
    """
    logger.info("Starting real-time mode...")

    try:
        from NN.utils.realtime_analyzer import RealtimeAnalyzer

        # Load the latest model (last entry in filename order)
        model_dir = 'models'
        candidates = sorted(
            name for name in os.listdir(model_dir)
            if name.startswith(args.model_type)
        )
        if len(candidates) == 0:
            logger.error(f"No saved model found for type {args.model_type}")
            return

        model_path = os.path.join(model_dir, candidates[-1])
        logger.info(f"Loading model from {model_path}...")
        model.load(model_path)

        # Initialize realtime analyzer
        logger.info("Initializing real-time analyzer...")
        analyzer_kwargs = {
            'data_interface': data_interface,
            'model': model,
            'symbol': args.symbol,
            'timeframes': args.timeframes,
        }
        analyzer = RealtimeAnalyzer(**analyzer_kwargs)

        # Start real-time analysis
        logger.info("Starting real-time analysis...")
        analyzer.start()

    except Exception as exc:
        logger.error(f"Error in real-time mode: {exc}")
|
|
|
|
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|