gogo2/train_cnn_with_realtime.py
Dobromir Popov c0872248ab misc
2025-05-13 17:19:52 +03:00

415 lines
18 KiB
Python

#!/usr/bin/env python
"""
Extended overnight training session for CNN model with real-time data updates
This script runs continuous model training, refreshing market data at regular intervals
"""
import os
import sys
import logging
import numpy as np
import torch
import time
import json
from datetime import datetime, timedelta
import signal
import threading
# Put the current working directory (expected to be the project root) on the
# import path so the project-local imports further down can be resolved.
sys.path.append(os.path.abspath('.'))

# Logging: every run writes to its own timestamped file under logs/ and the
# same records are echoed to the console.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
_run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = os.path.join(log_dir, f"realtime_training_{_run_stamp}.log")
_log_handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
logger = logging.getLogger('realtime_training')
# Import the model and data interfaces
from NN.models.cnn_model_pytorch import CNNModelPyTorch
from realtime import MultiTimeframeDataInterface
from NN.utils.signal_interpreter import SignalInterpreter
# Cooperative shutdown flag: the training loop keeps going while this is True
# and signal_handler() flips it to False on Ctrl+C.
running = True

# Session-wide statistics. run_overnight_training() updates the summary fields
# and appends one metrics dict to "epochs" per epoch; save_training_stats()
# dumps the whole structure to JSON.
training_stats = dict(
    epochs_completed=0,
    best_val_pnl=float('-inf'),
    best_epoch=0,
    best_win_rate=0,
    training_started=datetime.now().isoformat(),
    last_update=datetime.now().isoformat(),
    epochs=[],
)
def signal_handler(sig, frame):
    """Ask the training loop to stop after the epoch it is currently running.

    Installed below as the SIGINT (Ctrl+C) handler. It only clears the
    module-level ``running`` flag; the loop in run_overnight_training() checks
    that flag before starting each epoch and then saves the model.

    Args:
        sig: Signal number passed in by the ``signal`` machinery (unused).
        frame: Interrupted stack frame (unused).
    """
    global running
    running = False
    logger.info("Received interrupt signal. Finishing current epoch and saving model...")


# Route Ctrl+C (SIGINT) to the graceful-shutdown handler above
signal.signal(signal.SIGINT, signal_handler)
def save_training_stats(stats, filepath="NN/models/saved/realtime_training_stats.json"):
    """Save training statistics to a JSON file.

    Args:
        stats: JSON-serializable object holding the training statistics.
        filepath: Destination path. Missing parent directories are created;
            a bare filename (no directory part) is written relative to the
            current working directory.

    Raises:
        TypeError / ValueError: If ``stats`` cannot be serialized by ``json``.
        OSError: If the directory or file cannot be created or written.

    The JSON is first written to ``<filepath>.tmp`` and then moved into place
    with ``os.replace``, so a failure part-way through never truncates or
    destroys a previously saved statistics file.
    """
    directory = os.path.dirname(filepath)
    # os.makedirs('') raises FileNotFoundError, so skip it for bare filenames
    if directory:
        os.makedirs(directory, exist_ok=True)

    tmp_path = f"{filepath}.tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(stats, f, indent=2)
        os.replace(tmp_path, filepath)
    except Exception:
        # Best-effort cleanup of the partial temp file, then surface the error
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    logger.info(f"Training statistics saved to {filepath}")
def _count_actions(preds):
    """Return ``(buy, sell, hold)`` counts for predicted class ids 2, 0 and 1."""
    return np.sum(preds == 2), np.sum(preds == 0), np.sum(preds == 1)


def _action_fractions(buy, sell, hold, total):
    """Express action counts as fractions of ``total`` (all zeros when empty)."""
    if total <= 0:
        return {"BUY": 0, "SELL": 0, "HOLD": 0}
    return {
        "BUY": float(buy / total),
        "SELL": float(sell / total),
        "HOLD": float(hold / total),
    }


def run_overnight_training():
    """
    Run a continuous training session with real-time data updates.

    Workflow:
      1. Build a ``MultiTimeframeDataInterface`` for BTC/USDT (1m/5m/15m) and
         load the initial train/validation data.
      2. Create a ``CNNModelPyTorch`` and, if a saved model exists at
         ``NN/models/saved/optimized_short_term_model_best.pt``, load it to
         continue training.
      3. Loop until the module-level ``running`` flag is cleared (Ctrl+C) or
         12 hours have elapsed. Each epoch: re-fetch data if the last attempt
         is older than 5 minutes, train, evaluate, compute PnL for several
         position sizes, save the model when validation PnL at position size
         1.0 improves, and write a checkpoint plus statistics once per hour.
      4. Save the final model and statistics and try to plot the history.

    Any exception is logged together with its traceback, followed by a
    best-effort emergency save of the model (if one was created) and of the
    statistics collected so far.

    Returns:
        None.
    """
    global running, training_stats

    # Configuration parameters
    symbol = "BTC/USDT"
    timeframes = ["1m", "5m", "15m"]  # Multiple timeframes for better signal quality
    window_size = 24  # Larger window size for capturing more patterns
    output_size = 3  # BUY/HOLD/SELL
    batch_size = 64  # Batch size for training
    future_candles = 8  # Look-ahead passed to get_future_prices()

    # Real-time configuration
    data_refresh_interval = 300  # Refresh data every 5 minutes
    checkpoint_interval = 3600  # Save checkpoint every hour
    max_training_time = 12 * 3600  # 12 hours max runtime

    # Initialize training start time
    start_time = time.time()
    last_checkpoint_time = start_time
    # Two clocks on purpose: the attempt time throttles retries, while the
    # refresh time only advances on success so "data_age" stays truthful.
    last_refresh_attempt_time = start_time
    last_data_refresh_time = start_time

    logger.info(f"Starting overnight training session for CNN model with {symbol} real-time data")
    logger.info(f"Configuration: timeframes={timeframes}, window_size={window_size}, batch_size={batch_size}")
    logger.info(f"Data will refresh every {data_refresh_interval} seconds")
    logger.info(f"Checkpoints will be saved every {checkpoint_interval} seconds")
    logger.info(f"Maximum training time: {max_training_time/3600} hours")

    # Defined before the try block so the error handler can tell whether a
    # model exists without inspecting locals().
    model = None

    try:
        # Initialize data interface
        logger.info("Initializing MultiTimeframeDataInterface...")
        data_interface = MultiTimeframeDataInterface(
            symbol=symbol,
            timeframes=timeframes
        )

        # Prepare initial training data
        logger.info("Loading initial training data...")
        X_train_dict, y_train, X_val_dict, y_val, train_prices, val_prices = data_interface.prepare_training_data(
            window_size=window_size,
            refresh=True
        )
        if X_train_dict is None or y_train is None:
            logger.error("Failed to load training data")
            return

        # Get reference timeframe (lowest timeframe)
        reference_tf = min(timeframes, key=lambda x: data_interface.timeframe_to_seconds.get(x, 3600))
        logger.info(f"Using {reference_tf} as reference timeframe")

        # Log data shape information
        for tf, X in X_train_dict.items():
            logger.info(f"Training data for {tf} - X shape: {X.shape}")
        logger.info(f"Target labels shape: {y_train.shape}")
        logger.info(f"Validation data for {reference_tf} - X shape: {X_val_dict[reference_tf].shape}, y shape: {y_val.shape}")

        # Target distribution analysis (class ids: 0=SELL, 1=HOLD, 2=BUY)
        target_distribution = {
            "SELL": np.sum(y_train == 0),
            "HOLD": np.sum(y_train == 1),
            "BUY": np.sum(y_train == 2)
        }
        logger.info(f"Target distribution: SELL: {target_distribution['SELL']} ({target_distribution['SELL']/len(y_train):.2%}), "
                    f"HOLD: {target_distribution['HOLD']} ({target_distribution['HOLD']/len(y_train):.2%}), "
                    f"BUY: {target_distribution['BUY']} ({target_distribution['BUY']/len(y_train):.2%})")

        # Calculate future prices for profitability-focused loss function
        logger.info("Calculating future price changes...")
        train_future_prices = data_interface.get_future_prices(train_prices, n_candles=future_candles)
        val_future_prices = data_interface.get_future_prices(val_prices, n_candles=future_candles)

        # Initialize model
        num_features = X_train_dict[reference_tf].shape[2]  # Get feature count from the data
        logger.info(f"Initializing model with {num_features} features")

        # Use the same window size as the data
        actual_window_size = X_train_dict[reference_tf].shape[1]
        logger.info(f"Actual window size from data: {actual_window_size}")

        model_path = "NN/models/saved/optimized_short_term_model_best.pt"
        model = CNNModelPyTorch(
            window_size=actual_window_size,
            num_features=num_features,
            output_size=output_size,
            timeframes=timeframes
        )

        # Try to load existing model for continued training; a failed load is
        # not fatal, training simply starts from the freshly built model.
        try:
            if os.path.exists(model_path):
                logger.info(f"Loading existing model from {model_path}")
                model.load(model_path)
                logger.info("Model loaded successfully")
            else:
                logger.info("No existing model found. Starting with a new model.")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            logger.info("Starting with a new model.")

        # Initialize signal interpreter for testing predictions
        signal_interpreter = SignalInterpreter(config={
            'buy_threshold': 0.65,
            'sell_threshold': 0.65,
            'hold_threshold': 0.75,
            'trend_filter_enabled': True,
            'volume_filter_enabled': True
        })

        # Create checkpoint directory
        checkpoint_dir = "NN/models/saved/realtime_checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Track metrics
        epoch = 0
        best_val_pnl = -float('inf')
        best_win_rate = 0
        best_epoch = 0

        # Training loop
        while running and (time.time() - start_time < max_training_time):
            epoch += 1
            epoch_start = time.time()
            logger.info(f"Epoch {epoch} - Starting at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

            # Check if we need to refresh data
            if time.time() - last_refresh_attempt_time > data_refresh_interval:
                logger.info("Refreshing training data...")
                # Unpack into temporaries: a failed refresh must not overwrite
                # the data set we are currently training on.
                (new_X_train_dict, new_y_train, new_X_val_dict, new_y_val,
                 new_train_prices, new_val_prices) = data_interface.prepare_training_data(
                    window_size=window_size,
                    refresh=True
                )
                last_refresh_attempt_time = time.time()

                if new_X_train_dict is None or new_y_train is None:
                    logger.warning("Failed to refresh training data. Using previous data.")
                else:
                    X_train_dict, y_train = new_X_train_dict, new_y_train
                    X_val_dict, y_val = new_X_val_dict, new_y_val
                    train_prices, val_prices = new_train_prices, new_val_prices
                    logger.info(f"Refreshed training data for {reference_tf} - X shape: {X_train_dict[reference_tf].shape}, y shape: {y_train.shape}")

                    # Recalculate future prices
                    train_future_prices = data_interface.get_future_prices(train_prices, n_candles=future_candles)
                    val_future_prices = data_interface.get_future_prices(val_prices, n_candles=future_candles)
                    last_data_refresh_time = time.time()

            # Convert multi-timeframe dict to the format expected by the model
            # For now, we use only the reference timeframe, but in the future,
            # the model should be updated to handle multi-timeframe inputs
            X_train = X_train_dict[reference_tf]
            X_val = X_val_dict[reference_tf]

            # Train one epoch
            train_action_loss, _train_price_loss, train_acc = model.train_epoch(
                X_train, y_train, train_future_prices, batch_size
            )

            # Evaluate
            val_action_loss, _val_price_loss, val_acc = model.evaluate(
                X_val, y_val, val_future_prices
            )

            logger.info(f"Epoch {epoch} results:")
            logger.info(f" Train - Loss: {train_action_loss:.4f}, Accuracy: {train_acc:.4f}")
            logger.info(f" Valid - Loss: {val_action_loss:.4f}, Accuracy: {val_acc:.4f}")

            # Get predictions for PnL calculation
            train_action_probs, _train_price_preds = model.predict(X_train)
            val_action_probs, _val_price_preds = model.predict(X_val)

            # Convert probabilities to actions
            train_preds = np.argmax(train_action_probs, axis=1)
            val_preds = np.argmax(val_action_probs, axis=1)

            # Track signal distribution
            train_buy_count, train_sell_count, train_hold_count = _count_actions(train_preds)
            val_buy_count, val_sell_count, val_hold_count = _count_actions(val_preds)
            signal_dist = {
                "train": _action_fractions(train_buy_count, train_sell_count, train_hold_count, len(train_preds)),
                "val": _action_fractions(val_buy_count, val_sell_count, val_hold_count, len(val_preds)),
            }

            # Calculate PnL and win rates with different position sizes
            position_sizes = [0.1, 0.25, 0.5, 1.0, 2.0]  # Multiple position sizes for robustness
            best_position_train_pnl = -float('inf')
            best_position_val_pnl = -float('inf')
            best_position_train_wr = 0
            best_position_val_wr = 0
            best_position_size = 1.0

            for position_size in position_sizes:
                train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
                    train_preds, train_prices, position_size=position_size
                )
                val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
                    val_preds, val_prices, position_size=position_size
                )

                logger.info(f" Position Size: {position_size}")
                logger.info(f" Train - PnL: {train_pnl:.4f}, Win Rate: {train_win_rate:.4f}, Trades: {len(train_trades)}")
                logger.info(f" Valid - PnL: {val_pnl:.4f}, Win Rate: {val_win_rate:.4f}, Trades: {len(val_trades)}")

                # Track best position size for this epoch
                if val_pnl > best_position_val_pnl:
                    best_position_val_pnl = val_pnl
                    best_position_val_wr = val_win_rate
                    best_position_size = position_size
                if train_pnl > best_position_train_pnl:
                    best_position_train_pnl = train_pnl
                    best_position_train_wr = train_win_rate

                # Track best model overall (using position size 1.0 as reference)
                if val_pnl > best_val_pnl and position_size == 1.0:
                    best_val_pnl = val_pnl
                    best_win_rate = val_win_rate
                    best_epoch = epoch
                    logger.info(f" New best validation PnL: {best_val_pnl:.4f} at epoch {best_epoch}")
                    # Save the best model
                    model.save("NN/models/saved/optimized_short_term_model_realtime_best")

            # Store epoch metrics
            epoch_metrics = {
                "epoch": epoch,
                "train_loss": float(train_action_loss),
                "val_loss": float(val_action_loss),
                "train_acc": float(train_acc),
                "val_acc": float(val_acc),
                "train_pnl": float(best_position_train_pnl),
                "val_pnl": float(best_position_val_pnl),
                "train_win_rate": float(best_position_train_wr),
                "val_win_rate": float(best_position_val_wr),
                "best_position_size": float(best_position_size),
                "signal_distribution": signal_dist,
                "timestamp": datetime.now().isoformat(),
                "data_age": int(time.time() - last_data_refresh_time)
            }

            # Update training stats
            training_stats["epochs_completed"] = epoch
            training_stats["best_val_pnl"] = float(best_val_pnl)
            training_stats["best_epoch"] = best_epoch
            training_stats["best_win_rate"] = float(best_win_rate)
            training_stats["last_update"] = datetime.now().isoformat()
            training_stats["epochs"].append(epoch_metrics)

            # Check if we need to save checkpoint
            if time.time() - last_checkpoint_time > checkpoint_interval:
                logger.info(f"Saving checkpoint at epoch {epoch}")
                # Save model checkpoint
                model.save(f"{checkpoint_dir}/checkpoint_epoch_{epoch}")
                # Save training statistics
                save_training_stats(training_stats)
                last_checkpoint_time = time.time()

            # Test trade signal generation with a random sample
            if len(X_val) > 0:
                random_idx = np.random.randint(0, len(X_val))
                sample_X = X_val[random_idx:random_idx+1]
                sample_probs, sample_price_pred = model.predict(sample_X)

                # Accept either a nested (e.g. (N, 1)) or a flat (N,) price
                # prediction -- TODO confirm predict()'s actual output shape.
                first_price_pred = sample_price_pred[0]
                if np.ndim(first_price_pred) > 0:
                    first_price_pred = first_price_pred[0]

                # Process with signal interpreter. Named trade_signal so the
                # imported ``signal`` module is not shadowed.
                trade_signal = signal_interpreter.interpret_signal(
                    sample_probs[0],
                    float(first_price_pred),
                    market_data={'price': float(val_prices[random_idx]) if random_idx < len(val_prices) else 50000.0}
                )
                logger.info(f" Sample trade signal: {trade_signal['action']} with confidence {trade_signal['confidence']:.4f}")
            else:
                # np.random.randint(0, 0) would raise and abort the session
                logger.warning(" Validation set is empty; skipping sample trade signal")

            # Log trading statistics
            logger.info(f" Train - Actions: BUY={train_buy_count}, SELL={train_sell_count}, HOLD={train_hold_count}")
            logger.info(f" Valid - Actions: BUY={val_buy_count}, SELL={val_sell_count}, HOLD={val_hold_count}")

            # Log epoch timing
            epoch_time = time.time() - epoch_start
            total_elapsed = time.time() - start_time
            time_remaining = max_training_time - total_elapsed
            logger.info(f" Epoch completed in {epoch_time:.2f} seconds")
            logger.info(f" Training time: {total_elapsed/3600:.2f} hours / {max_training_time/3600:.2f} hours")
            logger.info(f" Estimated time remaining: {time_remaining/3600:.2f} hours")

        # Save final model and performance metrics
        logger.info("Saving final optimized model...")
        model.save("NN/models/saved/optimized_short_term_model_realtime_final")

        # Save performance metrics to file
        save_training_stats(training_stats)

        # Generate performance plots (failure here is logged, not fatal)
        try:
            model.plot_training_history("NN/models/saved/realtime_training_stats.json")
            logger.info("Performance plots generated successfully")
        except Exception as e:
            logger.error(f"Error generating plots: {str(e)}")

        # Calculate total training time
        total_time = time.time() - start_time
        hours, remainder = divmod(total_time, 3600)
        minutes, seconds = divmod(remainder, 60)
        logger.info(f"Overnight training completed in {int(hours)}h {int(minutes)}m {int(seconds)}s")
        logger.info(f"Best model performance - Epoch: {best_epoch}, PnL: {best_val_pnl:.4f}, Win Rate: {best_win_rate:.4f}")

    except Exception as e:
        logger.error(f"Error during overnight training: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())

        # Try to save the model and stats in case of error. The two saves are
        # independent so a failing model save cannot prevent the stats save.
        if model is not None:
            try:
                model.save("NN/models/saved/optimized_short_term_model_realtime_emergency")
                logger.info("Emergency model save completed")
            except Exception as e2:
                logger.error(f"Failed to save emergency checkpoint: {str(e2)}")

        # training_stats is a module-level global, so it is never in locals();
        # save it unconditionally.
        try:
            save_training_stats(training_stats, "NN/models/saved/realtime_training_stats_emergency.json")
        except Exception as e2:
            logger.error(f"Failed to save emergency checkpoint: {str(e2)}")
if __name__ == "__main__":
    # Startup banner, then hand control to the training session
    banner_rule = "=" * 80
    banner_lines = (
        banner_rule,
        "OVERNIGHT REALTIME TRAINING SESSION",
        "This script will continuously train the model using real-time market data",
        "Press Ctrl+C to safely stop training and save the model",
        banner_rule,
    )
    for banner_line in banner_lines:
        print(banner_line)

    run_overnight_training()