rename
NN/realtime-main.py (new file, 265 lines)
@@ -0,0 +1,265 @@
#!/usr/bin/env python3
"""
Neural Network Trading System Main Module

This module serves as the main entry point for the NN trading system,
coordinating data flow between different components and implementing
training and inference pipelines.
"""

import os
import sys
import logging
import argparse
from datetime import datetime

# Create the logs directory if it doesn't exist, before configuring logging,
# so the FileHandler below does not fail on a missing path
os.makedirs('logs', exist_ok=True)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(os.path.join('logs', f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'))
    ]
)

logger = logging.getLogger('NN')

def parse_arguments():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description='Neural Network Trading System')

    parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'], default='train',
                        help='Mode to run (train, predict, realtime)')
    parser.add_argument('--symbol', type=str, default='BTC/USDT',
                        help='Trading pair symbol')
    parser.add_argument('--timeframes', type=str, nargs='+', default=['1h', '4h'],
                        help='Timeframes to use')
    parser.add_argument('--window-size', type=int, default=20,
                        help='Window size for input data')
    parser.add_argument('--output-size', type=int, default=3,
                        help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs for training')
    parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'], default='cnn',
                        help='Model type to use')
    parser.add_argument('--framework', type=str, choices=['tensorflow', 'pytorch'], default='pytorch',
                        help='Deep learning framework to use')

    return parser.parse_args()

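# Example invocation (illustrative sketch; assumes the repository root is on
# PYTHONPATH so the NN.* imports in main() resolve, and that logs/ and models/
# are created relative to the working directory):
#
#   python NN/realtime-main.py --mode train --symbol BTC/USDT --timeframes 1h 4h \
#       --window-size 20 --model-type cnn --framework pytorch --epochs 50
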
def main():
    """Main entry point for the NN trading system"""
    # Parse arguments
    args = parse_arguments()

    logger.info(f"Starting NN Trading System in {args.mode} mode")
    logger.info(f"Configuration: Symbol={args.symbol}, Timeframes={args.timeframes}, "
                f"Window Size={args.window_size}, Output Size={args.output_size}, "
                f"Model Type={args.model_type}, Framework={args.framework}")

    # Import the appropriate modules based on the framework
    if args.framework == 'pytorch':
        try:
            import torch
            logger.info(f"Using PyTorch {torch.__version__}")

            # Import PyTorch-based modules
            from NN.utils.data_interface import DataInterface

            if args.model_type == 'cnn':
                from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
            elif args.model_type == 'transformer':
                from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
            elif args.model_type == 'moe':
                from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
            else:
                logger.error(f"Unknown model type: {args.model_type}")
                return

        except ImportError as e:
            logger.error(f"Failed to import PyTorch modules: {str(e)}")
            logger.error("Please make sure PyTorch is installed or use the TensorFlow framework.")
            return

    elif args.framework == 'tensorflow':
        try:
            import tensorflow as tf
            logger.info(f"Using TensorFlow {tf.__version__}")

            # Import TensorFlow-based modules
            from NN.utils.data_interface import DataInterface

            if args.model_type == 'cnn':
                from NN.models.cnn_model import CNNModel as Model
            elif args.model_type == 'transformer':
                from NN.models.transformer_model import TransformerModel as Model
            elif args.model_type == 'moe':
                from NN.models.transformer_model import MixtureOfExpertsModel as Model
            else:
                logger.error(f"Unknown model type: {args.model_type}")
                return

        except ImportError as e:
            logger.error(f"Failed to import TensorFlow modules: {str(e)}")
            logger.error("Please make sure TensorFlow is installed or use the PyTorch framework.")
            return
    else:
        logger.error(f"Unknown framework: {args.framework}")
        return

    # Initialize data interface
    try:
        logger.info("Initializing data interface...")
        data_interface = DataInterface(
            symbol=args.symbol,
            timeframes=args.timeframes,
            window_size=args.window_size,
            output_size=args.output_size
        )
    except Exception as e:
        logger.error(f"Failed to initialize data interface: {str(e)}")
        return

    # Initialize model
    try:
        logger.info(f"Initializing {args.model_type.upper()} model...")
        model = Model(
            window_size=args.window_size,
            num_features=data_interface.get_feature_count(),
            output_size=args.output_size,
            timeframes=args.timeframes
        )
    except Exception as e:
        logger.error(f"Failed to initialize model: {str(e)}")
        return

    # Execute the requested mode
    if args.mode == 'train':
        train(data_interface, model, args)
    elif args.mode == 'predict':
        predict(data_interface, model, args)
    elif args.mode == 'realtime':
        realtime(data_interface, model, args)
    else:
        logger.error(f"Unknown mode: {args.mode}")
        return

    logger.info("Neural Network Trading System finished successfully")

def train(data_interface, model, args):
    """Train the model using the data interface"""
    logger.info("Starting training mode...")

    try:
        # Prepare training data
        logger.info("Preparing training data...")
        X_train, y_train, X_val, y_val = data_interface.prepare_training_data()

        # Train the model
        logger.info("Training model...")
        model.train(
            X_train, y_train,
            X_val, y_val,
            batch_size=args.batch_size,
            epochs=args.epochs
        )

        # Save the model (create the models directory first if it doesn't exist)
        os.makedirs('models', exist_ok=True)
        model_path = os.path.join(
            'models',
            f"{args.model_type}_{args.symbol.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        )
        logger.info(f"Saving model to {model_path}...")
        model.save(model_path)

        # Evaluate the model
        logger.info("Evaluating model...")
        metrics = model.evaluate(X_val, y_val)
        logger.info(f"Evaluation metrics: {metrics}")

    except Exception as e:
        logger.error(f"Error in training mode: {str(e)}")
        return

def predict(data_interface, model, args):
    """Make predictions using the trained model"""
    logger.info("Starting prediction mode...")

    try:
        # Load the latest model
        model_dir = os.path.join('models')
        model_files = [f for f in os.listdir(model_dir) if f.startswith(args.model_type)]

        if not model_files:
            logger.error(f"No saved model found for type {args.model_type}")
            return

        latest_model = sorted(model_files)[-1]
        model_path = os.path.join(model_dir, latest_model)

        logger.info(f"Loading model from {model_path}...")
        model.load(model_path)

        # Prepare prediction data
        logger.info("Preparing prediction data...")
        X_pred = data_interface.prepare_prediction_data()

        # Make predictions
        logger.info("Making predictions...")
        predictions = model.predict(X_pred)

        # Process and display predictions
        logger.info("Processing predictions...")
        data_interface.process_predictions(predictions)

    except Exception as e:
        logger.error(f"Error in prediction mode: {str(e)}")
        return

def realtime(data_interface, model, args):
    """Run the model in real-time mode"""
    logger.info("Starting real-time mode...")

    try:
        # Import the real-time analyzer lazily so the other modes do not depend on it.
        # NOTE: the import path below is an assumption; point it at wherever
        # RealtimeAnalyzer actually lives in this package.
        try:
            from NN.realtime_analyzer import RealtimeAnalyzer
        except ImportError:
            RealtimeAnalyzer = None

        if RealtimeAnalyzer is None:
            logger.error("RealtimeAnalyzer could not be imported; real-time mode is unavailable")
            return

        # Load the latest model
        model_dir = os.path.join('models')
        model_files = [f for f in os.listdir(model_dir) if f.startswith(args.model_type)]

        if not model_files:
            logger.error(f"No saved model found for type {args.model_type}")
            return

        latest_model = sorted(model_files)[-1]
        model_path = os.path.join(model_dir, latest_model)

        logger.info(f"Loading model from {model_path}...")
        model.load(model_path)

        # Initialize realtime analyzer
        logger.info("Initializing real-time analyzer...")
        realtime_analyzer = RealtimeAnalyzer(
            data_interface=data_interface,
            model=model,
            symbol=args.symbol,
            timeframes=args.timeframes
        )

        # Start real-time analysis
        logger.info("Starting real-time analysis...")
        realtime_analyzer.start()

    except Exception as e:
        logger.error(f"Error in real-time mode: {str(e)}")
        return

if __name__ == "__main__":
    main()