gogo2/test_model.py
Dobromir Popov 1610d5bd49 train works
2025-03-31 03:20:12 +03:00

254 lines
11 KiB
Python

#!/usr/bin/env python
"""
Extended training session for CNN model optimized for short-term high-leverage trading

When executed as a script this module calls run_extended_training(), which
trains CNNModelPyTorch on data supplied by DataInterface and writes model
files and a training_metrics.json file under NN/models/saved/.
"""
import os
import sys
import logging
import numpy as np
import torch
import time
# Add the project root to path.
# NOTE: os.path.abspath('.') resolves to the current working directory, so the
# NN.* imports below rely on the script being launched from the project root
# (unless the NN package is importable some other way).
sys.path.append(os.path.abspath('.'))
# Configure logging: INFO level, each record formatted as
# "<timestamp> - <logger name> - <level> - <message>".
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Logger used by every message emitted from this script.
logger = logging.getLogger('extended_training')
# Import the optimized model and its data source.
# These project-local imports come after the sys.path tweak above on purpose.
from NN.models.cnn_model_pytorch import CNNModelPyTorch
from NN.utils.data_interface import DataInterface
def _summarize_signals(preds):
    """Count predicted actions and return ``(counts, fractions)``.

    Class indices follow the encoding used by this script:
    0 = SELL, 1 = HOLD, 2 = BUY.

    Args:
        preds: NumPy array of predicted class indices (the result of
            ``np.argmax(..., axis=1)`` in the caller).

    Returns:
        Two dicts keyed by "BUY"/"SELL"/"HOLD": absolute counts as plain
        ints, and fractions of the total as plain floats (all 0.0 when
        ``preds`` is empty). Plain Python numbers keep the result
        JSON-serializable.
    """
    total = len(preds)
    counts = {
        "BUY": int(np.sum(preds == 2)),
        "SELL": int(np.sum(preds == 0)),
        "HOLD": int(np.sum(preds == 1)),
    }
    fractions = {
        action: (count / total if total > 0 else 0.0)
        for action, count in counts.items()
    }
    return counts, fractions


def run_extended_training():
    """
    Run an extended training session for CNN model with comprehensive performance tracking.

    Loads BTC/USDT data for the 1m/5m/15m timeframes through DataInterface,
    trains CNNModelPyTorch for 30 epochs and, after every epoch, scores the
    predicted actions with DataInterface.calculate_pnl at several position
    sizes.

    Side effects:
        * Saves the model to NN/models/saved/optimized_short_term_model_best
          whenever the validation PnL at position size 1.0 improves.
        * Saves a checkpoint every 5 epochs under
          NN/models/saved/training_checkpoints/.
        * Saves the final model to
          NN/models/saved/optimized_short_term_model_extended.
        * Writes per-epoch metrics to NN/models/saved/training_metrics.json.
        * Calls model.plot_training_history().

    Returns:
        None. Returns early (after logging an error) when the data interface
        yields no training or validation data. Any exception raised during
        the session is logged with its traceback and not re-raised.
    """
    # Extended configuration parameters
    symbol = "BTC/USDT"
    timeframes = ["1m", "5m", "15m"]  # Multiple timeframes for better signal quality
    window_size = 24  # Larger window size to capture more context
    output_size = 3  # BUY/HOLD/SELL
    batch_size = 64  # Increased batch size for more stable gradients
    epochs = 30  # Extended training session
    logger.info(f"Starting extended training session for CNN model with {symbol} data")
    logger.info(f"Configuration: timeframes={timeframes}, window_size={window_size}, epochs={epochs}, batch_size={batch_size}")
    start_time = time.time()
    try:
        # Initialize data interface with more data
        logger.info("Initializing data interface...")
        data_interface = DataInterface(
            symbol=symbol,
            timeframes=timeframes
        )
        # Prepare training data with more history
        logger.info("Loading extended training data...")
        X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
            refresh=True,
            # Increase data size for better training
            test_size=0.15,  # Smaller test size to have more training data
            max_samples=1000  # More samples for training
        )
        # Validation arrays are dereferenced right below, so guard them too.
        if X_train is None or y_train is None or X_val is None or y_val is None:
            logger.error("Failed to load training data")
            return
        logger.info(f"Training data loaded - X shape: {X_train.shape}, y shape: {y_train.shape}")
        logger.info(f"Validation data - X shape: {X_val.shape}, y shape: {y_val.shape}")
        # Get future prices for longer-term prediction
        logger.info("Calculating future price changes...")
        train_future_prices = data_interface.get_future_prices(train_prices, n_candles=8)  # Look further ahead
        val_future_prices = data_interface.get_future_prices(val_prices, n_candles=8)
        # Initialize model
        num_features = data_interface.get_feature_count()
        logger.info(f"Initializing model with {num_features} features")
        # Use the window size the data interface actually produced, which may
        # differ from the configured window_size above.
        actual_window_size = X_train.shape[1]
        logger.info(f"Actual window size from data: {actual_window_size}")
        model = CNNModelPyTorch(
            window_size=actual_window_size,
            num_features=num_features,
            output_size=output_size,
            timeframes=timeframes
        )
        # Track metrics over time
        best_val_pnl = -float('inf')
        best_win_rate = 0
        best_epoch = 0
        # Create checkpoint directory
        checkpoint_dir = "NN/models/saved/training_checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Performance tracking
        metrics_history = {
            "epoch": [],
            "train_loss": [],
            "val_loss": [],
            "train_acc": [],
            "val_acc": [],
            "train_pnl": [],
            "val_pnl": [],
            "train_win_rate": [],
            "val_win_rate": [],
            "signal_distribution": []
        }
        # Position sizes evaluated every epoch (values above 1.0 model leverage).
        position_sizes = [0.1, 0.25, 0.5, 1.0, 2.0]
        logger.info("Starting extended training...")
        for epoch in range(epochs):
            logger.info(f"Epoch {epoch+1}/{epochs}")
            epoch_start = time.time()
            # Train one epoch (the price-head loss is not tracked here)
            train_action_loss, _, train_acc = model.train_epoch(
                X_train, y_train, train_future_prices, batch_size
            )
            # Evaluate
            val_action_loss, _, val_acc = model.evaluate(
                X_val, y_val, val_future_prices
            )
            logger.info(f"Epoch {epoch+1} results:")
            logger.info(f" Train - Loss: {train_action_loss:.4f}, Accuracy: {train_acc:.4f}")
            logger.info(f" Valid - Loss: {val_action_loss:.4f}, Accuracy: {val_acc:.4f}")
            # Get predictions for PnL calculation (price predictions unused)
            train_action_probs, _ = model.predict(X_train)
            val_action_probs, _ = model.predict(X_val)
            # Convert probabilities to actions
            train_preds = np.argmax(train_action_probs, axis=1)
            val_preds = np.argmax(val_action_probs, axis=1)
            # Track signal distribution
            train_counts, train_fractions = _summarize_signals(train_preds)
            val_counts, val_fractions = _summarize_signals(val_preds)
            signal_dist = {"train": train_fractions, "val": val_fractions}
            # Calculate PnL and win rates with different position sizes
            best_position_train_pnl = -float('inf')
            best_position_val_pnl = -float('inf')
            best_position_train_wr = 0
            best_position_val_wr = 0
            for position_size in position_sizes:
                train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
                    train_preds, train_prices, position_size=position_size
                )
                val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
                    val_preds, val_prices, position_size=position_size
                )
                logger.info(f" Position Size: {position_size}")
                logger.info(f" Train - PnL: {train_pnl:.4f}, Win Rate: {train_win_rate:.4f}, Trades: {len(train_trades)}")
                logger.info(f" Valid - PnL: {val_pnl:.4f}, Win Rate: {val_win_rate:.4f}, Trades: {len(val_trades)}")
                # Track best position size for this epoch
                if val_pnl > best_position_val_pnl:
                    best_position_val_pnl = val_pnl
                    best_position_val_wr = val_win_rate
                if train_pnl > best_position_train_pnl:
                    best_position_train_pnl = train_pnl
                    best_position_train_wr = train_win_rate
                # Track best model overall (using position size 1.0 as reference)
                if val_pnl > best_val_pnl and position_size == 1.0:
                    best_val_pnl = val_pnl
                    best_win_rate = val_win_rate
                    best_epoch = epoch + 1
                    logger.info(f" New best validation PnL: {best_val_pnl:.4f} at epoch {best_epoch}")
                    # Save the best model
                    model.save("NN/models/saved/optimized_short_term_model_best")
            # Track metrics for this epoch. Values are coerced to plain floats
            # so NumPy/tensor scalars cannot break the JSON dump later on.
            metrics_history["epoch"].append(epoch + 1)
            metrics_history["train_loss"].append(float(train_action_loss))
            metrics_history["val_loss"].append(float(val_action_loss))
            metrics_history["train_acc"].append(float(train_acc))
            metrics_history["val_acc"].append(float(val_acc))
            metrics_history["train_pnl"].append(float(best_position_train_pnl))
            metrics_history["val_pnl"].append(float(best_position_val_pnl))
            metrics_history["train_win_rate"].append(float(best_position_train_wr))
            metrics_history["val_win_rate"].append(float(best_position_val_wr))
            metrics_history["signal_distribution"].append(signal_dist)
            # Save checkpoint every 5 epochs
            if (epoch + 1) % 5 == 0:
                model.save(f"{checkpoint_dir}/checkpoint_epoch_{epoch+1}")
            # Log trading statistics
            logger.info(f" Train - Actions: BUY={train_counts['BUY']}, SELL={train_counts['SELL']}, HOLD={train_counts['HOLD']}")
            logger.info(f" Valid - Actions: BUY={val_counts['BUY']}, SELL={val_counts['SELL']}, HOLD={val_counts['HOLD']}")
            # Log epoch timing
            epoch_time = time.time() - epoch_start
            logger.info(f" Epoch completed in {epoch_time:.2f} seconds")
        # Save final model and performance metrics
        logger.info("Saving final optimized model...")
        model.save("NN/models/saved/optimized_short_term_model_extended")
        # Save performance metrics to file (best effort: failures are logged)
        try:
            import json
            metrics_file = "NN/models/saved/training_metrics.json"
            # Serialize before opening the file so a serialization error
            # cannot leave a truncated metrics file behind.
            payload = json.dumps(metrics_history, indent=2)
            with open(metrics_file, 'w', encoding="utf-8") as f:
                f.write(payload)
            logger.info(f"Training metrics saved to {metrics_file}")
        except Exception as e:
            logger.error(f"Error saving metrics: {str(e)}")
        # Generate performance plots (best effort: failures are logged)
        try:
            model.plot_training_history()
        except Exception as e:
            logger.error(f"Error generating plots: {str(e)}")
        # Calculate total training time
        total_time = time.time() - start_time
        hours, remainder = divmod(total_time, 3600)
        minutes, seconds = divmod(remainder, 60)
        logger.info(f"Extended training completed in {int(hours)}h {int(minutes)}m {int(seconds)}s")
        logger.info(f"Best model performance - Epoch: {best_epoch}, PnL: {best_val_pnl:.4f}, Win Rate: {best_win_rate:.4f}")
    except Exception as e:
        # Top-level boundary of the script: log message plus traceback.
        logger.exception(f"Error during extended training: {e}")
# Run the training session only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_extended_training()