gogo2/NN/main.py
Dobromir Popov 0042581275 new nn wip
2025-03-25 13:38:25 +02:00

265 lines
9.5 KiB
Python

#!/usr/bin/env python3
"""
Neural Network Trading System Main Module
This module serves as the main entry point for the NN trading system,
coordinating data flow between different components and implementing
training and inference pipelines.
"""
import os
import sys
import logging
import argparse
from datetime import datetime
# Directory that receives one timestamped log file per run.
LOG_DIR = 'logs'

# Create the logs directory BEFORE configuring logging: logging.FileHandler
# opens its file as soon as it is constructed, so building the handler first
# raised FileNotFoundError whenever 'logs/' did not exist yet.
os.makedirs(LOG_DIR, exist_ok=True)

# Configure the root logger to write both to the console and to this run's
# timestamped log file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(os.path.join(LOG_DIR, f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'))
    ]
)
logger = logging.getLogger('NN')
def parse_arguments():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options ``mode``, ``symbol``,
        ``timeframes``, ``window_size``, ``output_size``, ``batch_size``,
        ``epochs``, ``model_type`` and ``framework``.
    """
    cli = argparse.ArgumentParser(description='Neural Network Trading System')
    option = cli.add_argument

    # What to run and on which market.
    option('--mode', type=str, default='train',
           choices=['train', 'predict', 'realtime'],
           help='Mode to run (train, predict, realtime)')
    option('--symbol', type=str, default='BTC/USDT',
           help='Trading pair symbol')
    option('--timeframes', type=str, nargs='+', default=['1h', '4h'],
           help='Timeframes to use')

    # Input window and output layer size.
    option('--window-size', type=int, default=20,
           help='Window size for input data')
    option('--output-size', type=int, default=3,
           help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')

    # Training hyper-parameters.
    option('--batch-size', type=int, default=32,
           help='Batch size for training')
    option('--epochs', type=int, default=100,
           help='Number of epochs for training')

    # Model architecture and deep-learning backend.
    option('--model-type', type=str, default='cnn',
           choices=['cnn', 'transformer', 'moe'],
           help='Model type to use')
    option('--framework', type=str, default='pytorch',
           choices=['tensorflow', 'pytorch'],
           help='Deep learning framework to use')

    namespace = cli.parse_args()
    return namespace
def main():
    """Main entry point for the NN trading system.

    Parses the command line, imports the framework-specific components,
    builds the data interface and the model, then runs the requested mode.

    Returns:
        int: exit status -- 0 when every stage succeeded, 1 when any stage
        failed (the cause is logged).
    """
    args = parse_arguments()
    logger.info(f"Starting NN Trading System in {args.mode} mode")
    logger.info(f"Configuration: Symbol={args.symbol}, Timeframes={args.timeframes}, "
                f"Window Size={args.window_size}, Output Size={args.output_size}, "
                f"Model Type={args.model_type}, Framework={args.framework}")

    # Import the framework plus the matching DataInterface/model classes.
    components = _import_components(args.framework, args.model_type)
    if components is None:
        return 1
    DataInterface, Model = components

    # Initialize data interface
    try:
        logger.info("Initializing data interface...")
        data_interface = DataInterface(
            symbol=args.symbol,
            timeframes=args.timeframes,
            window_size=args.window_size,
            output_size=args.output_size
        )
    except Exception as e:
        # logger.exception also records the traceback, which str(e) alone lost.
        logger.exception(f"Failed to initialize data interface: {str(e)}")
        return 1

    # Initialize model
    try:
        logger.info(f"Initializing {args.model_type.upper()} model...")
        model = Model(
            window_size=args.window_size,
            num_features=data_interface.get_feature_count(),
            output_size=args.output_size,
            timeframes=args.timeframes
        )
    except Exception as e:
        logger.exception(f"Failed to initialize model: {str(e)}")
        return 1

    # Execute the requested mode.
    runners = {'train': train, 'predict': predict, 'realtime': realtime}
    runner = runners.get(args.mode)
    if runner is None:
        logger.error(f"Unknown mode: {args.mode}")
        return 1
    # The mode runners catch and log their own exceptions. A runner that
    # returns False is treated as failed; any other result (including None)
    # counts as success.
    if runner(data_interface, model, args) is False:
        logger.error(f"{args.mode} mode did not complete successfully")
        return 1

    logger.info("Neural Network Trading System finished successfully")
    return 0


def _import_components(framework, model_type):
    """Return ``(DataInterface, Model)`` for *framework*/*model_type*, or None.

    Every failure is logged here, so callers only need to check for None.
    """
    if framework == 'pytorch':
        return _import_pytorch_components(model_type)
    if framework == 'tensorflow':
        return _import_tensorflow_components(model_type)
    logger.error(f"Unknown framework: {framework}")
    return None


def _import_pytorch_components(model_type):
    """Import PyTorch, DataInterface and the PyTorch model class for *model_type*."""
    try:
        import torch
        logger.info(f"Using PyTorch {torch.__version__}")
        from NN.utils.data_interface import DataInterface
        if model_type == 'cnn':
            from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
        elif model_type == 'transformer':
            from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
        elif model_type == 'moe':
            from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
        else:
            logger.error(f"Unknown model type: {model_type}")
            return None
    except ImportError as e:
        logger.error(f"Failed to import PyTorch modules: {str(e)}")
        logger.error("Please make sure PyTorch is installed or use the TensorFlow framework.")
        return None
    return DataInterface, Model


def _import_tensorflow_components(model_type):
    """Import TensorFlow, DataInterface and the TensorFlow model class for *model_type*."""
    try:
        import tensorflow as tf
        logger.info(f"Using TensorFlow {tf.__version__}")
        from NN.utils.data_interface import DataInterface
        if model_type == 'cnn':
            from NN.models.cnn_model import CNNModel as Model
        elif model_type == 'transformer':
            from NN.models.transformer_model import TransformerModel as Model
        elif model_type == 'moe':
            from NN.models.transformer_model import MixtureOfExpertsModel as Model
        else:
            logger.error(f"Unknown model type: {model_type}")
            return None
    except ImportError as e:
        logger.error(f"Failed to import TensorFlow modules: {str(e)}")
        logger.error("Please make sure TensorFlow is installed or use the PyTorch framework.")
        return None
    return DataInterface, Model
def train(data_interface, model, args):
    """Train *model* on data from *data_interface*, save it, then evaluate it.

    The model is saved under ``models/`` as
    ``<model_type>_<symbol with '/' replaced by '_'>_<YYYYmmdd_HHMMSS>``,
    the directory and name prefix that the predict/realtime modes look up.

    Returns:
        bool: True if every step succeeded, False if an exception was raised
        (the error and its traceback are logged).
    """
    logger.info("Starting training mode...")
    try:
        # Prepare training data
        logger.info("Preparing training data...")
        X_train, y_train, X_val, y_val = data_interface.prepare_training_data()

        # Train the model
        logger.info("Training model...")
        model.train(
            X_train, y_train,
            X_val, y_val,
            batch_size=args.batch_size,
            epochs=args.epochs
        )

        # Ensure the target directory exists before saving; whether
        # model.save creates missing parent directories itself is not
        # visible here, and failing at this point would waste the run.
        os.makedirs('models', exist_ok=True)
        model_path = os.path.join(
            'models',
            f"{args.model_type}_{args.symbol.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        )
        logger.info(f"Saving model to {model_path}...")
        model.save(model_path)

        # Evaluate the model
        logger.info("Evaluating model...")
        metrics = model.evaluate(X_val, y_val)
        logger.info(f"Evaluation metrics: {metrics}")
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone discarded.
        logger.exception(f"Error in training mode: {str(e)}")
        return False
    return True
def predict(data_interface, model, args):
    """Load the newest saved model for this model type and symbol, then predict.

    Looks in ``models/`` for entries whose names start with
    ``<model_type>_<symbol with '/' replaced by '_'>_`` (the prefix used when
    saving in training mode) and loads the one whose name sorts last.

    Returns:
        bool: True on success, False if no model was found or an exception
        was raised (errors are logged).
    """
    logger.info("Starting prediction mode...")
    try:
        model_dir = os.path.join('models')
        # Match on model type AND symbol. Matching on the type alone could
        # load a model trained for another pair, and sorting mixed-symbol
        # names ordered them by symbol rather than age. With a shared prefix,
        # ordering is driven by the %Y%m%d_%H%M%S timestamp that follows it.
        prefix = f"{args.model_type}_{args.symbol.replace('/', '_')}_"
        if os.path.isdir(model_dir):
            model_files = [f for f in os.listdir(model_dir) if f.startswith(prefix)]
        else:
            model_files = []
        if not model_files:
            logger.error(f"No saved model found for type {args.model_type} and symbol {args.symbol}")
            return False
        model_path = os.path.join(model_dir, max(model_files))
        logger.info(f"Loading model from {model_path}...")
        model.load(model_path)

        # Prepare prediction data
        logger.info("Preparing prediction data...")
        X_pred = data_interface.prepare_prediction_data()

        # Make predictions
        logger.info("Making predictions...")
        predictions = model.predict(X_pred)

        # Process and display predictions
        logger.info("Processing predictions...")
        data_interface.process_predictions(predictions)
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone discarded.
        logger.exception(f"Error in prediction mode: {str(e)}")
        return False
    return True
def realtime(data_interface, model, args):
    """Load the newest saved model for this model type and symbol and start
    real-time analysis with ``NN.realtime.RealtimeAnalyzer``.

    Model lookup matches entries in ``models/`` whose names start with
    ``<model_type>_<symbol with '/' replaced by '_'>_`` (the prefix used when
    saving in training mode) and picks the one whose name sorts last.

    Returns:
        bool: True if the analyzer was started without raising, False if no
        model was found or an exception was raised (errors are logged).
    """
    logger.info("Starting real-time mode...")
    try:
        # Import realtime module
        from NN.realtime import RealtimeAnalyzer

        model_dir = os.path.join('models')
        # Match on model type AND symbol. Matching on the type alone could
        # load a model trained for another pair, and sorting mixed-symbol
        # names ordered them by symbol rather than age. With a shared prefix,
        # ordering is driven by the %Y%m%d_%H%M%S timestamp that follows it.
        prefix = f"{args.model_type}_{args.symbol.replace('/', '_')}_"
        if os.path.isdir(model_dir):
            model_files = [f for f in os.listdir(model_dir) if f.startswith(prefix)]
        else:
            model_files = []
        if not model_files:
            logger.error(f"No saved model found for type {args.model_type} and symbol {args.symbol}")
            return False
        model_path = os.path.join(model_dir, max(model_files))
        logger.info(f"Loading model from {model_path}...")
        model.load(model_path)

        # Initialize realtime analyzer
        logger.info("Initializing real-time analyzer...")
        realtime_analyzer = RealtimeAnalyzer(
            data_interface=data_interface,
            model=model,
            symbol=args.symbol,
            timeframes=args.timeframes
        )

        # Start real-time analysis
        logger.info("Starting real-time analysis...")
        realtime_analyzer.start()
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone discarded.
        logger.exception(f"Error in real-time mode: {str(e)}")
        return False
    return True
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status so shells,
    # schedulers and CI can detect failures; sys.exit(None) exits with 0.
    sys.exit(main())