#!/usr/bin/env python3
"""
Neural Network Training Runner Script

This script runs the Neural Network Trading System (``NN.main``) with an existing
conda environment. By default it writes a launcher script that activates the
conda environment and runs the network. With ``--no-conda`` it instead detects
which deep learning framework (PyTorch or TensorFlow) is available in the
current environment and runs the network directly with it.
"""

import argparse
import logging
import os
import shlex
import subprocess
import sys
from pathlib import Path

# Module-wide logging setup: INFO level and a timestamped record format.
# Everything in this script logs through the "nn_runner" logger.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger('nn_runner')

def detect_framework():
    """Report which deep learning framework can be imported in this interpreter.

    PyTorch is probed first; TensorFlow is only tried when PyTorch is missing.

    Returns:
        A ``(name, version)`` tuple: ``("pytorch", torch.__version__)`` or
        ``("tensorflow", tf.__version__)``, or ``(None, None)`` when neither
        import succeeds.
    """
    try:
        import torch
    except ImportError:
        logger.warning("PyTorch not found in environment")
    else:
        version = torch.__version__
        logger.info(f"PyTorch {version} detected")
        return "pytorch", version

    try:
        import tensorflow as tf
    except ImportError:
        logger.error("Neither PyTorch nor TensorFlow is available in the environment")
        return None, None

    version = tf.__version__
    logger.info(f"TensorFlow {version} detected")
    return "tensorflow", version


def check_dependencies(required=None):
    """Check that the required third-party packages can be imported.

    Args:
        required: Optional mapping of distribution (pip) name -> importable
            module name. When ``None`` the runner's default requirements are
            used: numpy, pandas, matplotlib and scikit-learn.

    Returns:
        True when every module imports successfully, otherwise False. Missing
        packages are logged by their distribution name so the message can be
        fed straight to ``pip install``.
    """
    if required is None:
        # The distribution name and the import name differ for scikit-learn
        # (it is imported as ``sklearn``); importing "scikit-learn" can never
        # succeed, which previously made this check fail unconditionally.
        required = {
            "numpy": "numpy",
            "pandas": "pandas",
            "matplotlib": "matplotlib",
            "scikit-learn": "sklearn",
        }

    missing_packages = []
    for dist_name, module_name in required.items():
        try:
            __import__(module_name)
        except ImportError:
            missing_packages.append(dist_name)

    if missing_packages:
        logger.warning(f"Missing required packages: {', '.join(missing_packages)}")
        return False

    return True


def create_run_command(args, framework):
    """Build the argv list that launches ``NN.main`` in a child process.

    Args:
        args: Namespace produced by ``parse_arguments``.
        framework: Framework name forwarded via ``--framework``
            (e.g. ``"pytorch"`` or ``"tensorflow"``).

    Returns:
        A list of strings suitable for ``subprocess.run`` without a shell.
        Options whose value is falsy (``None``, ``0``, empty) are omitted so
        ``NN.main`` falls back to its own defaults.
    """
    # Use the interpreter running this script: that is the environment
    # detect_framework()/check_dependencies() just inspected. A bare "python"
    # is resolved via PATH and may be a different interpreter entirely.
    cmd = [sys.executable, "-m", "NN.main", "--mode", args.mode]

    if args.symbol:
        cmd.extend(["--symbol", args.symbol])

    if args.timeframes:
        cmd.extend(["--timeframes", *args.timeframes])

    # Single-valued options, in the order NN.main has always received them.
    single_valued = (
        ("--window-size", args.window_size),
        ("--output-size", args.output_size),
        ("--batch-size", args.batch_size),
        ("--epochs", args.epochs),
        ("--model-type", args.model_type),
    )
    for flag, value in single_valued:
        if value:
            cmd.extend([flag, str(value)])

    cmd.extend(["--framework", framework])
    return cmd


def parse_arguments(argv=None):
    """Parse the runner's command-line options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses the real command line (``sys.argv[1:]``), exactly
            as before; passing a list lets tests and programmatic callers
            supply arguments explicitly.

    Returns:
        argparse.Namespace with ``mode``, ``symbol``, ``timeframes``,
        ``window_size``, ``output_size``, ``batch_size``, ``epochs``,
        ``model_type``, ``conda_env``, ``no_conda`` and ``framework``.
    """
    parser = argparse.ArgumentParser(description='Neural Network Trading System Runner')

    parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'], default='train',
                        help='Mode to run (train, predict, realtime)')
    parser.add_argument('--symbol', type=str, default='BTC/USDT',
                        help='Trading pair symbol')
    parser.add_argument('--timeframes', type=str, nargs='+', default=['1h', '4h'],
                        help='Timeframes to use')
    parser.add_argument('--window-size', type=int, default=20,
                        help='Window size for input data')
    parser.add_argument('--output-size', type=int, default=3,
                        help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs for training')
    parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'], default='cnn',
                        help='Model type to use')
    parser.add_argument('--conda-env', type=str, default='gpt-gpu',
                        help='Name of conda environment to use')
    parser.add_argument('--no-conda', action='store_true',
                        help='Do not use conda environment activation')
    parser.add_argument('--framework', type=str, choices=['tensorflow', 'pytorch'], default='pytorch',
                        help='Deep learning framework to use (default: pytorch)')

    return parser.parse_args(argv)


def _quote_for_cmd(arg):
    """Quote one argument for a command line inside a Windows ``.bat`` file.

    ``%`` is doubled because batch files expand ``%name%`` sequences. An
    argument that is empty or contains whitespace or a cmd.exe metacharacter
    is wrapped in double quotes. Embedded double quotes are not escaped.
    """
    arg = arg.replace("%", "%%")
    if not arg or any(ch in arg for ch in ' \t&|<>^()'):
        return f'"{arg}"'
    return arg


def _nn_main_arguments(args):
    """Return the ``NN.main`` command-line arguments selected by *args*.

    Options whose value is falsy are omitted so ``NN.main`` applies its own
    defaults; ``--framework`` is not included (see the note in ``main``).
    """
    argv = ["--mode", args.mode]
    if args.symbol:
        argv += ["--symbol", args.symbol]
    if args.timeframes:
        argv += ["--timeframes", *args.timeframes]
    for flag, value in (
        ("--window-size", args.window_size),
        ("--output-size", args.output_size),
        ("--batch-size", args.batch_size),
        ("--epochs", args.epochs),
        ("--model-type", args.model_type),
    ):
        if value:
            argv += [flag, str(value)]
    return argv


def _write_conda_script(args):
    """Write a launcher that activates ``args.conda_env`` and runs ``NN.main``.

    The script is written to the current working directory: a ``.bat`` file
    on Windows, otherwise an executable bash script.

    Returns:
        The ``Path`` of the generated script.
    """
    is_windows = sys.platform == 'win32'
    if is_windows:
        script_path = Path("run_nn_in_conda.bat")
        quote = _quote_for_cmd
        lines = ["@echo off", f"call conda activate {quote(args.conda_env)}"]
    else:
        script_path = Path("run_nn_in_conda.sh")
        quote = shlex.quote
        # NOTE(review): `source activate` is conda's legacy activation form;
        # newer installs may need `eval "$(conda shell.bash hook)"` followed
        # by `conda activate` in a non-interactive shell — confirm.
        lines = ["#!/bin/bash", f"source activate {quote(args.conda_env)}"]

    # Plain "python" is intended here: after activation it resolves to the
    # conda environment's interpreter, not the one running this runner.
    command = ["python", "-m", "NN.main", *_nn_main_arguments(args)]
    # Quote every token so user-supplied values (symbol, timeframes) with
    # spaces or shell metacharacters cannot break or alter the script.
    lines.append(" ".join(quote(token) for token in command))

    with open(script_path, 'w') as f:
        f.write("\n".join(lines) + "\n")

    # Make script executable on Unix
    if not is_windows:
        os.chmod(script_path, 0o755)

    return script_path


def main():
    """Entry point: generate a conda launcher script or run ``NN.main`` directly.

    With conda enabled (the default), a platform-specific launcher script is
    written and the command to run it is printed; nothing is executed. With
    ``--no-conda`` (or an empty ``--conda-env``), the available framework and
    dependencies are checked in this interpreter and ``NN.main`` is run as a
    child process.

    Returns:
        Process exit status: 0 on success, non-zero when prerequisites are
        missing, the child could not be started, or ``NN.main`` failed.
    """
    args = parse_arguments()

    if not args.no_conda and args.conda_env:
        logger.info(f"Running with conda environment: {args.conda_env}")
        script_path = _write_conda_script(args)

        logger.info(f"Created script: {script_path}")
        logger.info("Run this script to execute the neural network with the conda environment")

        launch = f"{script_path}" if sys.platform == 'win32' else f"./{script_path}"
        print("\nTo run the neural network, execute the following command:")
        print(f"  {launch}")
        return 0

    # Run directly in the current interpreter's environment.
    # NOTE(review): args.framework is parsed but not consulted here; NN.main
    # receives whichever framework detect_framework() finds first. Confirm
    # whether --framework should take precedence.
    framework, version = detect_framework()
    if framework is None:
        logger.error("Cannot run Neural Network - no deep learning framework available")
        return 1

    if not check_dependencies():
        logger.error("Missing required dependencies - please install them first")
        return 1

    cmd = create_run_command(args, framework)
    logger.info(f"Running command: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        logger.error(f"Error running neural network: {str(e)}")
        return e.returncode
    except OSError as e:
        # Launch failures, e.g. the interpreter executable cannot be found.
        logger.error(f"Error: {str(e)}")
        return 1

    return 0


if __name__ == "__main__":
    # Propagate main()'s status so callers/CI can detect failures.
    sys.exit(main())