265 lines
9.5 KiB
Python
265 lines
9.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Neural Network Trading System Main Module
|
|
|
|
This module serves as the main entry point for the NN trading system,
|
|
coordinating data flow between different components and implementing
|
|
training and inference pipelines.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import logging
|
|
import argparse
|
|
from datetime import datetime
|
|
|
|
# Configure logging.
# The log directory must exist BEFORE the FileHandler below is constructed:
# FileHandler opens its file as soon as it is created (delay=False by default),
# and the handler list is evaluated before basicConfig() even runs.
os.makedirs('logs', exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        # One log file per run, named by start time.
        logging.FileHandler(os.path.join('logs', f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')),
    ],
)

logger = logging.getLogger('NN')
|
|
|
|
def parse_arguments(argv=None):
    """Parse command-line arguments for the NN trading system.

    Args:
        argv: Optional list of argument strings to parse. The default ``None``
            makes argparse read ``sys.argv[1:]``, i.e. the normal CLI behavior;
            passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with the attributes ``mode``, ``symbol``,
        ``timeframes``, ``window_size``, ``output_size``, ``batch_size``,
        ``epochs``, ``model_type`` and ``framework``.
    """
    parser = argparse.ArgumentParser(description='Neural Network Trading System')

    parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'], default='train',
                        help='Mode to run (train, predict, realtime)')
    parser.add_argument('--symbol', type=str, default='BTC/USDT',
                        help='Trading pair symbol')
    parser.add_argument('--timeframes', type=str, nargs='+', default=['1h', '4h'],
                        help='Timeframes to use')
    parser.add_argument('--window-size', type=int, default=20,
                        help='Window size for input data')
    parser.add_argument('--output-size', type=int, default=3,
                        help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs for training')
    parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'], default='cnn',
                        help='Model type to use')
    parser.add_argument('--framework', type=str, choices=['tensorflow', 'pytorch'], default='pytorch',
                        help='Deep learning framework to use')

    return parser.parse_args(argv)
|
|
|
|
# Per --framework value: (backend module to import, display name used in log
# messages, the other framework suggested when the backend is missing).
_FRAMEWORKS = {
    'pytorch': ('torch', 'PyTorch', 'TensorFlow'),
    'tensorflow': ('tensorflow', 'TensorFlow', 'PyTorch'),
}

# Per --framework and --model-type: (module path, class name) of the model.
_MODEL_CLASSES = {
    'pytorch': {
        'cnn': ('NN.models.cnn_model_pytorch', 'CNNModelPyTorch'),
        'transformer': ('NN.models.transformer_model_pytorch', 'TransformerModelPyTorchWrapper'),
        'moe': ('NN.models.transformer_model_pytorch', 'MixtureOfExpertsModelPyTorch'),
    },
    'tensorflow': {
        'cnn': ('NN.models.cnn_model', 'CNNModel'),
        'transformer': ('NN.models.transformer_model', 'TransformerModel'),
        'moe': ('NN.models.transformer_model', 'MixtureOfExpertsModel'),
    },
}


def _import_components(framework, model_type):
    """Import the data-interface and model classes for the chosen backend.

    Args:
        framework: key of ``_FRAMEWORKS`` ('pytorch' or 'tensorflow').
        model_type: key of ``_MODEL_CLASSES[framework]`` ('cnn', 'transformer', 'moe').

    Returns:
        ``(DataInterface, ModelClass)`` on success, or ``None`` after logging an
        error if the framework/model type is unknown or an import fails.
    """
    import importlib

    if framework not in _FRAMEWORKS:
        logger.error(f"Unknown framework: {framework}")
        return None
    backend_module, display_name, alternative = _FRAMEWORKS[framework]

    try:
        backend = importlib.import_module(backend_module)
        logger.info(f"Using {display_name} {backend.__version__}")

        from NN.utils.data_interface import DataInterface

        model_spec = _MODEL_CLASSES[framework].get(model_type)
        if model_spec is None:
            logger.error(f"Unknown model type: {model_type}")
            return None
        module_name, class_name = model_spec
        model_class = getattr(importlib.import_module(module_name), class_name)
    # A missing class surfaces as AttributeError from getattr(), where the
    # equivalent ``from module import Name`` would raise ImportError.
    except (ImportError, AttributeError) as e:
        logger.error(f"Failed to import {display_name} modules: {str(e)}")
        logger.error(f"Please make sure {display_name} is installed or use the {alternative} framework.")
        return None

    return DataInterface, model_class


def main():
    """Main entry point for the NN trading system.

    Parses the CLI, imports the framework-specific data interface and model
    classes, builds both, then dispatches to ``train``, ``predict`` or
    ``realtime`` according to ``--mode``.

    Returns:
        int: 0 on success, 1 if any stage failed (suitable for ``sys.exit``).
    """
    args = parse_arguments()

    logger.info(f"Starting NN Trading System in {args.mode} mode")
    logger.info(f"Configuration: Symbol={args.symbol}, Timeframes={args.timeframes}, "
                f"Window Size={args.window_size}, Output Size={args.output_size}, "
                f"Model Type={args.model_type}, Framework={args.framework}")

    components = _import_components(args.framework, args.model_type)
    if components is None:
        return 1
    DataInterface, Model = components

    try:
        logger.info("Initializing data interface...")
        data_interface = DataInterface(
            symbol=args.symbol,
            timeframes=args.timeframes,
            window_size=args.window_size,
            output_size=args.output_size,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone loses.
        logger.exception(f"Failed to initialize data interface: {str(e)}")
        return 1

    try:
        logger.info(f"Initializing {args.model_type.upper()} model...")
        model = Model(
            window_size=args.window_size,
            num_features=data_interface.get_feature_count(),
            output_size=args.output_size,
            timeframes=args.timeframes,
        )
    except Exception as e:
        logger.exception(f"Failed to initialize model: {str(e)}")
        return 1

    mode_handlers = {'train': train, 'predict': predict, 'realtime': realtime}
    handler = mode_handlers.get(args.mode)
    if handler is None:
        logger.error(f"Unknown mode: {args.mode}")
        return 1

    # A handler may return False to report a failure it has already logged;
    # any other result (including None) is treated as success.
    if handler(data_interface, model, args) is False:
        logger.error(f"{args.mode} mode did not complete successfully")
        return 1

    logger.info("Neural Network Trading System finished successfully")
    return 0
|
|
|
|
def train(data_interface, model, args):
    """Train, save and evaluate ``model`` using data from ``data_interface``.

    Args:
        data_interface: provides ``prepare_training_data()``, which must return
            ``(X_train, y_train, X_val, y_val)``.
        model: object exposing ``train``, ``save`` and ``evaluate``.
        args: parsed CLI namespace; ``batch_size``, ``epochs``, ``model_type``
            and ``symbol`` are read.

    Returns:
        bool: True when training, saving and evaluation all complete; False if
        any step raised (the error is logged together with its traceback).
    """
    logger.info("Starting training mode...")

    try:
        logger.info("Preparing training data...")
        X_train, y_train, X_val, y_val = data_interface.prepare_training_data()

        logger.info("Training model...")
        model.train(
            X_train, y_train,
            X_val, y_val,
            batch_size=args.batch_size,
            epochs=args.epochs
        )

        # Ensure the target directory exists; whether model.save() creates
        # parent directories itself is not guaranteed.
        model_dir = 'models'
        os.makedirs(model_dir, exist_ok=True)
        # '/' in symbols such as 'BTC/USDT' would otherwise act as a path separator.
        # The fixed-width timestamp suffix sorts chronologically.
        model_path = os.path.join(
            model_dir,
            f"{args.model_type}_{args.symbol.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        )
        logger.info(f"Saving model to {model_path}...")
        model.save(model_path)

        logger.info("Evaluating model...")
        metrics = model.evaluate(X_val, y_val)
        logger.info(f"Evaluation metrics: {metrics}")
    except Exception as e:
        logger.exception(f"Error in training mode: {str(e)}")
        return False

    return True
|
|
|
|
def _find_latest_model_path(args, model_dir='models'):
    """Return the path of the most recently saved model matching ``args``.

    ``train`` saves models as
    ``{model_type}_{symbol with '/' replaced by '_'}_{YYYYmmdd_HHMMSS}``.
    Entries for the requested symbol are preferred; among them the fixed-width
    timestamp suffix makes the lexicographically greatest name the newest.
    If there are none for this symbol, any entry starting with ``model_type``
    is accepted and the newest is chosen by modification time, because names
    of different symbols do not sort by date.

    Returns:
        str | None: path inside ``model_dir``, or None (after logging an error)
        when the directory is missing or contains no matching model.
    """
    if not os.path.isdir(model_dir):
        logger.error(f"Model directory not found: {model_dir}")
        return None

    entries = os.listdir(model_dir)

    symbol_prefix = f"{args.model_type}_{args.symbol.replace('/', '_')}_"
    symbol_matches = [f for f in entries if f.startswith(symbol_prefix)]
    if symbol_matches:
        return os.path.join(model_dir, max(symbol_matches))

    type_matches = [f for f in entries if f.startswith(args.model_type)]
    if not type_matches:
        logger.error(f"No saved model found for type {args.model_type}")
        return None

    logger.warning(f"No saved {args.model_type} model for symbol {args.symbol}; "
                   f"using the most recently modified {args.model_type} model instead")
    latest = max(type_matches, key=lambda f: os.path.getmtime(os.path.join(model_dir, f)))
    return os.path.join(model_dir, latest)


def predict(data_interface, model, args):
    """Load the newest saved model for ``args`` and run one prediction pass.

    Args:
        data_interface: provides ``prepare_prediction_data()`` and
            ``process_predictions(predictions)``.
        model: object exposing ``load`` and ``predict``.
        args: parsed CLI namespace; ``model_type`` and ``symbol`` select the
            saved model to load.

    Returns:
        bool: True on success; False if no model was found or a step raised.
    """
    logger.info("Starting prediction mode...")

    try:
        model_path = _find_latest_model_path(args)
        if model_path is None:
            return False

        logger.info(f"Loading model from {model_path}...")
        model.load(model_path)

        logger.info("Preparing prediction data...")
        X_pred = data_interface.prepare_prediction_data()

        logger.info("Making predictions...")
        predictions = model.predict(X_pred)

        logger.info("Processing predictions...")
        data_interface.process_predictions(predictions)
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone loses.
        logger.exception(f"Error in prediction mode: {str(e)}")
        return False

    return True


def realtime(data_interface, model, args):
    """Load the newest saved model for ``args`` and start real-time analysis.

    Args:
        data_interface: passed through to ``RealtimeAnalyzer``.
        model: object exposing ``load``; passed through to ``RealtimeAnalyzer``.
        args: parsed CLI namespace; ``model_type`` and ``symbol`` select the
            saved model, and ``symbol``/``timeframes`` configure the analyzer.

    Returns:
        bool: True once ``RealtimeAnalyzer.start()`` returns; False if the
        analyzer could not be imported, no model was found, or a step raised.
    """
    logger.info("Starting real-time mode...")

    try:
        from NN.realtime import RealtimeAnalyzer

        model_path = _find_latest_model_path(args)
        if model_path is None:
            return False

        logger.info(f"Loading model from {model_path}...")
        model.load(model_path)

        logger.info("Initializing real-time analyzer...")
        realtime_analyzer = RealtimeAnalyzer(
            data_interface=data_interface,
            model=model,
            symbol=args.symbol,
            timeframes=args.timeframes
        )

        logger.info("Starting real-time analysis...")
        realtime_analyzer.start()
    except Exception as e:
        logger.exception(f"Error in real-time mode: {str(e)}")
        return False

    return True
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s status as the process exit code; sys.exit(None)
    # exits with status 0.
    sys.exit(main())