trying to fix training
This commit is contained in:
@ -1,10 +1,9 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Neural Network Trading System Main Module
|
||||
Neural Network Trading System Main Module - PyTorch Version
|
||||
|
||||
This module serves as the main entry point for the NN trading system,
|
||||
coordinating data flow between different components and implementing
|
||||
training and inference pipelines.
|
||||
using PyTorch exclusively for all model operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -12,200 +11,259 @@ import sys
|
||||
import logging
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import numpy as np
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler(os.path.join('logs', f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'))
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger('NN')
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
os.makedirs('logs', exist_ok=True)
|
||||
try:
|
||||
# Create logs directory if it doesn't exist
|
||||
os.makedirs('logs', exist_ok=True)
|
||||
|
||||
# Try setting up file logging
|
||||
log_file = os.path.join('logs', f'nn_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
|
||||
fh = logging.FileHandler(log_file)
|
||||
fh.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
fh.setFormatter(formatter)
|
||||
logger.addHandler(fh)
|
||||
logger.info(f"Logging to file: {log_file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to setup file logging: {str(e)}. Falling back to console logging only.")
|
||||
|
||||
# Always setup console logging
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
ch.setFormatter(formatter)
|
||||
logger.addHandler(ch)
|
||||
|
||||
def parse_arguments():
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(description='Neural Network Trading System')
|
||||
|
||||
parser.add_argument('--mode', type=str, choices=['train', 'predict', 'realtime'], default='train',
|
||||
help='Mode to run (train, predict, realtime)')
|
||||
parser.add_argument('--symbol', type=str, default='BTC/USD',
|
||||
help='Main trading pair symbol (default: BTC/USD)')
|
||||
parser.add_argument('--context-pairs', type=str, nargs='*', default=[],
|
||||
help='Additional context trading pairs')
|
||||
parser.add_argument('--timeframes', type=str, nargs='+', default=['5m', '15m', '1h'],
|
||||
help='Timeframes to use (default: 5m,15m,1h)')
|
||||
parser.add_argument('--window-size', type=int, default=30,
|
||||
help='Window size for input data (default: 30)')
|
||||
parser.add_argument('--output-size', type=int, default=5,
|
||||
help='Output size (1=up/down, 3=BUY/HOLD/SELL, 5=with extrema)')
|
||||
help='Mode to run (train, predict, realtime)')
|
||||
parser.add_argument('--symbol', type=str, default='BTC/USDT',
|
||||
help='Trading pair symbol')
|
||||
parser.add_argument('--timeframes', type=str, nargs='+', default=['1s', '1m', '5m', '1h', '4h'],
|
||||
help='Timeframes to use (include 1s for ticks)')
|
||||
parser.add_argument('--window-size', type=int, default=20,
|
||||
help='Window size for input data')
|
||||
parser.add_argument('--output-size', type=int, default=3,
|
||||
help='Output size (1 for binary, 3 for BUY/HOLD/SELL)')
|
||||
parser.add_argument('--batch-size', type=int, default=32,
|
||||
help='Batch size for training')
|
||||
parser.add_argument('--epochs', type=int, default=100,
|
||||
help='Number of epochs for training')
|
||||
help='Batch size for training')
|
||||
parser.add_argument('--epochs', type=int, default=10,
|
||||
help='Number of epochs for training')
|
||||
parser.add_argument('--model-type', type=str, choices=['cnn', 'transformer', 'moe'], default='cnn',
|
||||
help='Model type to use')
|
||||
parser.add_argument('--framework', type=str, choices=['tensorflow', 'pytorch'], default='pytorch',
|
||||
help='Deep learning framework to use')
|
||||
help='Model type to use')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
def main():
|
||||
"""Main entry point for the NN trading system"""
|
||||
# Parse arguments
|
||||
args = parse_arguments()
|
||||
|
||||
logger.info(f"Starting NN Trading System in {args.mode} mode")
|
||||
logger.info(f"Main Symbol: {args.symbol}")
|
||||
if args.context_pairs:
|
||||
logger.info(f"Context Pairs: {args.context_pairs}")
|
||||
logger.info(f"Timeframes: {args.timeframes}")
|
||||
logger.info(f"Window Size: {args.window_size}")
|
||||
logger.info(f"Output Size: {args.output_size} (1=up/down, 3=BUY/HOLD/SELL, 5=with extrema)")
|
||||
logger.info(f"Model Type: {args.model_type}")
|
||||
logger.info(f"Framework: {args.framework}")
|
||||
logger.info(f"Configuration: Symbol={args.symbol}, Timeframes={args.timeframes}")
|
||||
|
||||
# Import the appropriate modules based on the framework
|
||||
if args.framework == 'pytorch':
|
||||
try:
|
||||
import torch
|
||||
logger.info(f"Using PyTorch {torch.__version__}")
|
||||
|
||||
# Import PyTorch-based modules
|
||||
from NN.utils.multi_data_interface import MultiDataInterface
|
||||
|
||||
if args.model_type == 'cnn':
|
||||
from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
|
||||
elif args.model_type == 'transformer':
|
||||
from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
|
||||
elif args.model_type == 'moe':
|
||||
from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
|
||||
else:
|
||||
logger.error(f"Unknown model type: {args.model_type}")
|
||||
return
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import PyTorch modules: {str(e)}")
|
||||
logger.error("Please make sure PyTorch is installed or use the TensorFlow framework.")
|
||||
try:
|
||||
import torch
|
||||
from NN.utils.data_interface import DataInterface
|
||||
|
||||
# Import appropriate PyTorch model
|
||||
if args.model_type == 'cnn':
|
||||
from NN.models.cnn_model_pytorch import CNNModelPyTorch as Model
|
||||
elif args.model_type == 'transformer':
|
||||
from NN.models.transformer_model_pytorch import TransformerModelPyTorchWrapper as Model
|
||||
elif args.model_type == 'moe':
|
||||
from NN.models.transformer_model_pytorch import MixtureOfExpertsModelPyTorch as Model
|
||||
else:
|
||||
logger.error(f"Unknown model type: {args.model_type}")
|
||||
return
|
||||
|
||||
elif args.framework == 'tensorflow':
|
||||
try:
|
||||
import tensorflow as tf
|
||||
logger.info(f"Using TensorFlow {tf.__version__}")
|
||||
|
||||
# Import TensorFlow-based modules
|
||||
from NN.utils.multi_data_interface import MultiDataInterface
|
||||
|
||||
if args.model_type == 'cnn':
|
||||
from NN.models.cnn_model import CNNModel as Model
|
||||
elif args.model_type == 'transformer':
|
||||
from NN.models.transformer_model import TransformerModel as Model
|
||||
elif args.model_type == 'moe':
|
||||
from NN.models.transformer_model import MixtureOfExpertsModel as Model
|
||||
else:
|
||||
logger.error(f"Unknown model type: {args.model_type}")
|
||||
return
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import TensorFlow modules: {str(e)}")
|
||||
logger.error("Please make sure TensorFlow is installed or use the PyTorch framework.")
|
||||
return
|
||||
else:
|
||||
logger.error(f"Unknown framework: {args.framework}")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import PyTorch modules: {str(e)}")
|
||||
logger.error("Please make sure PyTorch is installed")
|
||||
return
|
||||
|
||||
|
||||
# Initialize data interface
|
||||
try:
|
||||
logger.info("Initializing data interface...")
|
||||
data_interface = MultiDataInterface(
|
||||
data_interface = DataInterface(
|
||||
symbol=args.symbol,
|
||||
timeframes=args.timeframes,
|
||||
window_size=args.window_size,
|
||||
output_size=args.output_size
|
||||
timeframes=args.timeframes
|
||||
)
|
||||
|
||||
# Verify data interface by fetching initial data
|
||||
logger.info("Verifying data interface...")
|
||||
X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
|
||||
if X_sample is None or y_sample is None:
|
||||
logger.error("Failed to prepare initial training data")
|
||||
return
|
||||
|
||||
logger.info(f"Data interface verified - X shape: {X_sample.shape}, y shape: {y_sample.shape}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize data interface: {str(e)}")
|
||||
return
|
||||
|
||||
|
||||
# Initialize model
|
||||
try:
|
||||
logger.info(f"Initializing {args.model_type.upper()} model...")
|
||||
# Calculate actual feature count (OHLCV per timeframe)
|
||||
num_features = 5 * len(args.timeframes)
|
||||
# Calculate total number of features across all timeframes
|
||||
num_features = data_interface.get_feature_count()
|
||||
logger.info(f"Initializing model with {num_features} features")
|
||||
|
||||
model = Model(
|
||||
window_size=args.window_size,
|
||||
num_features=num_features,
|
||||
output_size=args.output_size,
|
||||
timeframes=args.timeframes
|
||||
)
|
||||
|
||||
# Ensure model is on the correct device
|
||||
if torch.cuda.is_available():
|
||||
model.model = model.model.cuda()
|
||||
logger.info("Model moved to CUDA device")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize model: {str(e)}")
|
||||
return
|
||||
|
||||
# Execute the requested mode
|
||||
|
||||
# Execute requested mode
|
||||
if args.mode == 'train':
|
||||
train(data_interface, model, args)
|
||||
elif args.mode == 'predict':
|
||||
predict(data_interface, model, args)
|
||||
elif args.mode == 'realtime':
|
||||
realtime(data_interface, model, args)
|
||||
else:
|
||||
logger.error(f"Unknown mode: {args.mode}")
|
||||
return
|
||||
|
||||
logger.info("Neural Network Trading System finished successfully")
|
||||
|
||||
def train(data_interface, model, args):
|
||||
"""Train the model using the data interface"""
|
||||
"""Enhanced training with performance tracking and retrospective fine-tuning"""
|
||||
logger.info("Starting training mode...")
|
||||
writer = SummaryWriter()
|
||||
|
||||
try:
|
||||
# Prepare training data
|
||||
logger.info("Preparing training data...")
|
||||
X, y, _ = data_interface.prepare_nn_input(
|
||||
timeframes=args.timeframes,
|
||||
n_candles=1000,
|
||||
window_size=args.window_size
|
||||
)
|
||||
logger.info(f"Training data shape: {X.shape}")
|
||||
logger.info(f"Target data shape: {y.shape}")
|
||||
best_val_acc = 0
|
||||
best_val_pnl = float('-inf')
|
||||
best_win_rate = 0
|
||||
|
||||
# Split into train/validation sets (80/20)
|
||||
split_idx = int(len(X) * 0.8)
|
||||
X_train, y_train = X[:split_idx], y[:split_idx]
|
||||
X_val, y_val = X[split_idx:], y[split_idx:]
|
||||
logger.info("Verifying data interface...")
|
||||
X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
|
||||
logger.info(f"Data validation - X shape: {X_sample.shape}, y shape: {y_sample.shape}")
|
||||
|
||||
# Train the model
|
||||
logger.info("Training model...")
|
||||
model.train(
|
||||
X_train, y_train,
|
||||
X_val, y_val,
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs
|
||||
)
|
||||
|
||||
# Save the model
|
||||
model_path = os.path.join(
|
||||
'models',
|
||||
f"{args.model_type}_{args.symbol.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
)
|
||||
logger.info(f"Saving model to {model_path}...")
|
||||
model.save(model_path)
|
||||
|
||||
# Evaluate the model
|
||||
logger.info("Evaluating model...")
|
||||
metrics = model.evaluate(X_val, y_val)
|
||||
logger.info(f"Evaluation metrics: {metrics}")
|
||||
for epoch in range(args.epochs):
|
||||
# More frequent refresh for shorter timeframes
|
||||
if '1s' in args.timeframes:
|
||||
refresh = True # Always refresh for tick data
|
||||
refresh_interval = 30 # 30 seconds for tick data
|
||||
else:
|
||||
refresh = epoch % 1 == 0 # Refresh every epoch
|
||||
refresh_interval = 120 # 2 minutes for other timeframes
|
||||
|
||||
logger.info(f"\nStarting epoch {epoch+1}/{args.epochs}")
|
||||
X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
|
||||
refresh=refresh,
|
||||
refresh_interval=refresh_interval
|
||||
)
|
||||
logger.info(f"Training data - X shape: {X_train.shape}, y shape: {y_train.shape}")
|
||||
logger.info(f"Validation data - X shape: {X_val.shape}, y shape: {y_val.shape}")
|
||||
|
||||
# Train and validate
|
||||
try:
|
||||
train_loss, train_acc = model.train_epoch(X_train, y_train, args.batch_size)
|
||||
val_loss, val_acc = model.evaluate(X_val, y_val)
|
||||
|
||||
# Get predictions for PnL calculation
|
||||
train_preds = model.predict(X_train)
|
||||
val_preds = model.predict(X_val)
|
||||
|
||||
# Calculate PnL and win rates
|
||||
train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
|
||||
train_preds, train_prices, position_size=1.0
|
||||
)
|
||||
val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
|
||||
val_preds, val_prices, position_size=1.0
|
||||
)
|
||||
|
||||
# Monitor action distribution
|
||||
train_actions = np.bincount(train_preds, minlength=3)
|
||||
val_actions = np.bincount(val_preds, minlength=3)
|
||||
|
||||
# Log metrics
|
||||
writer.add_scalar('Loss/train', train_loss, epoch)
|
||||
writer.add_scalar('Accuracy/train', train_acc, epoch)
|
||||
writer.add_scalar('Loss/val', val_loss, epoch)
|
||||
writer.add_scalar('Accuracy/val', val_acc, epoch)
|
||||
writer.add_scalar('PnL/train', train_pnl, epoch)
|
||||
writer.add_scalar('PnL/val', val_pnl, epoch)
|
||||
writer.add_scalar('WinRate/train', train_win_rate, epoch)
|
||||
writer.add_scalar('WinRate/val', val_win_rate, epoch)
|
||||
|
||||
# Log action distribution
|
||||
for i, action in enumerate(['SELL', 'HOLD', 'BUY']):
|
||||
writer.add_scalar(f'Actions/train_{action}', train_actions[i], epoch)
|
||||
writer.add_scalar(f'Actions/val_{action}', val_actions[i], epoch)
|
||||
|
||||
# Save best model based on validation PnL
|
||||
if val_pnl > best_val_pnl:
|
||||
best_val_pnl = val_pnl
|
||||
best_val_acc = val_acc
|
||||
best_win_rate = val_win_rate
|
||||
model.save(f"models/{args.model_type}_best.pt")
|
||||
|
||||
# Log detailed metrics
|
||||
logger.info(f"Epoch {epoch+1}/{args.epochs} - "
|
||||
f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}, "
|
||||
f"PnL: {train_pnl:.2%}, Win Rate: {train_win_rate:.2%} - "
|
||||
f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}, "
|
||||
f"PnL: {val_pnl:.2%}, Win Rate: {val_win_rate:.2%}")
|
||||
|
||||
# Log action distribution
|
||||
logger.info("Action Distribution:")
|
||||
for i, action in enumerate(['SELL', 'HOLD', 'BUY']):
|
||||
logger.info(f"{action}: Train={train_actions[i]}, Val={val_actions[i]}")
|
||||
|
||||
# Log trade statistics
|
||||
if train_trades:
|
||||
logger.info(f"Training trades: {len(train_trades)}")
|
||||
logger.info(f"Validation trades: {len(val_trades)}")
|
||||
|
||||
# Retrospective fine-tuning
|
||||
if epoch > 0 and val_pnl > 0: # Only fine-tune if we're making profit
|
||||
logger.info("Performing retrospective fine-tuning...")
|
||||
|
||||
# Get predictions for next few candles
|
||||
next_candles = model.predict_next_candles(X_val[-1:], n_candles=3)
|
||||
|
||||
# Log predictions for each timeframe
|
||||
for tf, preds in next_candles.items():
|
||||
logger.info(f"Next 3 candles for {tf}:")
|
||||
for i, pred in enumerate(preds):
|
||||
action = ['SELL', 'HOLD', 'BUY'][np.argmax(pred)]
|
||||
confidence = np.max(pred)
|
||||
logger.info(f"Candle {i+1}: {action} (confidence: {confidence:.2f})")
|
||||
|
||||
# Fine-tune on recent successful trades
|
||||
successful_trades = [t for t in train_trades if t['pnl'] > 0]
|
||||
if successful_trades:
|
||||
logger.info(f"Fine-tuning on {len(successful_trades)} successful trades")
|
||||
# TODO: Implement fine-tuning logic here
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during epoch {epoch+1}: {str(e)}")
|
||||
continue
|
||||
|
||||
# Save final model
|
||||
model.save(f"models/{args.model_type}_final_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
|
||||
logger.info(f"Training complete. Best validation metrics:")
|
||||
logger.info(f"Accuracy: {best_val_acc:.2f}")
|
||||
logger.info(f"PnL: {best_val_pnl:.2%}")
|
||||
logger.info(f"Win Rate: {best_win_rate:.2%}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in training mode: {str(e)}")
|
||||
return
|
||||
logger.error(f"Error in training: {str(e)}")
|
||||
|
||||
def predict(data_interface, model, args):
|
||||
"""Make predictions using the trained model"""
|
||||
@ -240,14 +298,12 @@ def predict(data_interface, model, args):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in prediction mode: {str(e)}")
|
||||
return
|
||||
|
||||
def realtime(data_interface, model, args):
|
||||
"""Run the model in real-time mode"""
|
||||
logger.info("Starting real-time mode...")
|
||||
|
||||
try:
|
||||
# Import realtime analyzer
|
||||
from NN.utils.realtime_analyzer import RealtimeAnalyzer
|
||||
|
||||
# Load the latest model
|
||||
@ -279,7 +335,6 @@ def realtime(data_interface, model, args):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in real-time mode: {str(e)}")
|
||||
return
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user