#!/usr/bin/env python
"""
Extended overnight training session for a CNN model with real-time data updates.

This script runs continuous model training, refreshing market data at regular intervals.
"""
|
|
|
|
import os
|
|
import sys
|
|
import logging
|
|
import numpy as np
|
|
import torch
|
|
import time
|
|
import json
|
|
from datetime import datetime, timedelta
|
|
import signal
|
|
import threading
|
|
|
|
# Make the current working directory importable; the script expects to be
# launched from the project root so the ``NN`` package can be resolved.
sys.path.append(os.path.abspath('.'))

# Each run writes to its own log file, named after the start time, and
# mirrors every record to the console.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(
    log_dir,
    "realtime_training_" + datetime.now().strftime('%Y%m%d_%H%M%S') + ".log",
)

logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
)
logger = logging.getLogger('realtime_training')
|
|
|
|
# Import the model and data interfaces
|
|
from NN.models.cnn_model_pytorch import CNNModelPyTorch
|
|
from NN.utils.data_interface import DataInterface
|
|
from NN.utils.signal_interpreter import SignalInterpreter
|
|
|
|
# Module-level state shared between the SIGINT handler and the training loop.
# ``running`` stays True until an interrupt asks the loop to stop;
# ``training_stats`` accumulates the metrics that are later dumped to JSON.
running = True
training_stats = dict(
    epochs_completed=0,
    best_val_pnl=-float('inf'),
    best_epoch=0,
    best_win_rate=0,
    training_started=datetime.now().isoformat(),
    last_update=datetime.now().isoformat(),
    epochs=[],
)
|
|
|
|
def signal_handler(sig, frame):
    """Ask the training loop to stop after the epoch currently in progress.

    Installed as the SIGINT handler below. It only logs a message and clears
    the module-level ``running`` flag; the training loop checks that flag at
    the start of every iteration, so the epoch in progress is not interrupted.

    Args:
        sig: Signal number passed in by the ``signal`` machinery (unused).
        frame: Interrupted stack frame passed in by the ``signal`` machinery
            (unused).
    """
    global running
    logger.info("Received interrupt signal. Finishing current epoch and saving model...")
    running = False


# Route Ctrl+C (SIGINT) to the cooperative shutdown handler defined above.
signal.signal(signal.SIGINT, signal_handler)
|
|
|
|
def save_training_stats(stats, filepath="NN/models/saved/realtime_training_stats.json"):
    """Save training statistics to ``filepath`` as indented JSON.

    Args:
        stats: JSON-serialisable object (the script passes the module-level
            ``training_stats`` dictionary).
        filepath: Destination file. Its parent directory is created when the
            path has a directory component; a bare file name is written to the
            current working directory.

    Raises:
        OSError: If the directory cannot be created or the file cannot be
            written.
        TypeError: If ``stats`` contains a value ``json.dump`` cannot encode.
    """
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the path actually names one (e.g. not for "stats.json").
    parent_dir = os.path.dirname(filepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(stats, f, indent=2)

    logger.info(f"Training statistics saved to {filepath}")
|
|
|
|
def _count_actions(preds):
    """Count predicted actions using the label encoding 0=SELL, 1=HOLD, 2=BUY."""
    return {
        "BUY": np.sum(preds == 2),
        "SELL": np.sum(preds == 0),
        "HOLD": np.sum(preds == 1),
    }


def _action_fractions(counts, total):
    """Turn action counts into fractions of ``total`` (0 when ``total`` is 0)."""
    return {
        action: float(count / total) if total > 0 else 0
        for action, count in counts.items()
    }


def _refresh_training_data(data_interface, refresh_interval):
    """Fetch a refreshed dataset plus future prices, or ``None`` if unusable.

    Nothing is returned piecemeal: the caller either gets a complete 8-tuple
    ``(X_train, y_train, X_val, y_val, train_prices, val_prices,
    train_future_prices, val_future_prices)`` or ``None``, so a failed refresh
    can never leave the training loop with half-replaced data.
    """
    try:
        X_train, y_train, X_val, y_val, train_prices, val_prices = (
            data_interface.prepare_training_data(
                refresh=True,
                refresh_interval=refresh_interval
            )
        )
        if X_train is None or y_train is None:
            return None

        train_future_prices = data_interface.get_future_prices(train_prices, n_candles=8)
        val_future_prices = data_interface.get_future_prices(val_prices, n_candles=8)
    except Exception as e:
        # A refresh is best-effort: one failed fetch should not end a
        # multi-hour session while usable data is still in memory.
        logger.warning(f"Error while refreshing training data: {str(e)}")
        return None

    return (
        X_train, y_train, X_val, y_val, train_prices, val_prices,
        train_future_prices, val_future_prices,
    )


def run_overnight_training():
    """
    Run a continuous training session with real-time data updates.

    Loads training and validation data through ``DataInterface``, builds a
    ``CNNModelPyTorch`` (loading a previously saved model file when one
    exists) and then trains one epoch per loop iteration until the
    module-level ``running`` flag is cleared or ``max_training_time`` seconds
    have elapsed.

    Every iteration:
      * re-fetches the data when the last refresh is older than
        ``data_refresh_interval`` seconds, keeping the previous data if the
        refresh fails;
      * trains and evaluates the model, then scores its predictions with
        ``calculate_pnl`` for several position sizes;
      * saves the model whenever the validation PnL at position size 1.0
        beats the best value seen so far in this run;
      * appends the epoch metrics to the module-level ``training_stats``;
      * writes a checkpoint and the statistics file when more than
        ``checkpoint_interval`` seconds have passed since the last one;
      * passes one random validation sample through ``SignalInterpreter`` and
        logs the resulting signal.

    After the loop the final model and the statistics are saved and the
    model is asked to plot the training history.

    Returns:
        None. Any exception is logged with its traceback and followed by a
        best-effort emergency save of the model and statistics; it is not
        re-raised.
    """
    global running, training_stats

    # Configuration parameters
    symbol = "BTC/USDT"
    timeframes = ["1m", "5m", "15m"]  # Multiple timeframes for better signal quality
    window_size = 24  # Larger window size for capturing more patterns
    output_size = 3  # BUY/HOLD/SELL
    batch_size = 64  # Batch size for training

    # Real-time configuration
    data_refresh_interval = 300  # Refresh data every 5 minutes
    checkpoint_interval = 3600  # Save checkpoint every hour
    max_training_time = 12 * 3600  # 12 hours max runtime

    # Initialize training start time
    start_time = time.time()
    last_checkpoint_time = start_time
    last_data_refresh_time = start_time

    logger.info(f"Starting overnight training session for CNN model with {symbol} real-time data")
    logger.info(f"Configuration: timeframes={timeframes}, window_size={window_size}, batch_size={batch_size}")
    logger.info(f"Data will refresh every {data_refresh_interval} seconds")
    logger.info(f"Checkpoints will be saved every {checkpoint_interval} seconds")
    logger.info(f"Maximum training time: {max_training_time/3600} hours")

    # Bound before the try block so the emergency handler can tell whether a
    # model exists without inspecting locals().
    model = None

    try:
        # Initialize data interface
        logger.info("Initializing data interface...")
        data_interface = DataInterface(
            symbol=symbol,
            timeframes=timeframes
        )

        # Prepare initial training data
        logger.info("Loading initial training data...")
        X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
            refresh=True,
            refresh_interval=data_refresh_interval
        )

        if X_train is None or y_train is None:
            logger.error("Failed to load training data")
            return

        logger.info(f"Training data loaded - X shape: {X_train.shape}, y shape: {y_train.shape}")
        logger.info(f"Validation data - X shape: {X_val.shape}, y shape: {y_val.shape}")

        # Target distribution analysis (labels: 0=SELL, 1=HOLD, 2=BUY)
        target_distribution = _count_actions(y_train)

        logger.info(f"Target distribution: SELL: {target_distribution['SELL']} ({target_distribution['SELL']/len(y_train):.2%}), "
                    f"HOLD: {target_distribution['HOLD']} ({target_distribution['HOLD']/len(y_train):.2%}), "
                    f"BUY: {target_distribution['BUY']} ({target_distribution['BUY']/len(y_train):.2%})")

        # Calculate future prices for profitability-focused loss function
        logger.info("Calculating future price changes...")
        train_future_prices = data_interface.get_future_prices(train_prices, n_candles=8)
        val_future_prices = data_interface.get_future_prices(val_prices, n_candles=8)

        # Initialize model
        num_features = data_interface.get_feature_count()
        logger.info(f"Initializing model with {num_features} features")

        # Use the window size the data actually has rather than the
        # configured ``window_size``.
        actual_window_size = X_train.shape[1]
        logger.info(f"Actual window size from data: {actual_window_size}")

        model_path = "NN/models/saved/optimized_short_term_model_best.pt"
        model = CNNModelPyTorch(
            window_size=actual_window_size,
            num_features=num_features,
            output_size=output_size,
            timeframes=timeframes
        )

        # Try to load existing model for continued training; a failed load
        # is not fatal, training simply starts from the fresh model.
        try:
            if os.path.exists(model_path):
                logger.info(f"Loading existing model from {model_path}")
                model.load(model_path)
                logger.info("Model loaded successfully")
            else:
                logger.info("No existing model found. Starting with a new model.")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            logger.info("Starting with a new model.")

        # Initialize signal interpreter for testing predictions
        signal_interpreter = SignalInterpreter(config={
            'buy_threshold': 0.65,
            'sell_threshold': 0.65,
            'hold_threshold': 0.75,
            'trend_filter_enabled': True,
            'volume_filter_enabled': True
        })

        # Create checkpoint directory
        checkpoint_dir = "NN/models/saved/realtime_checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Track metrics
        epoch = 0
        best_val_pnl = -float('inf')
        best_win_rate = 0
        best_epoch = 0

        # Training loop
        while running and (time.time() - start_time < max_training_time):
            epoch += 1
            epoch_start = time.time()

            logger.info(f"Epoch {epoch} - Starting at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

            # Check if we need to refresh data
            if time.time() - last_data_refresh_time > data_refresh_interval:
                logger.info("Refreshing training data...")
                refreshed = _refresh_training_data(data_interface, data_refresh_interval)

                if refreshed is None:
                    # The current arrays are left untouched, so training
                    # really does continue on the previous data.
                    logger.warning("Failed to refresh training data. Using previous data.")
                else:
                    (X_train, y_train, X_val, y_val, train_prices, val_prices,
                     train_future_prices, val_future_prices) = refreshed
                    logger.info(f"Refreshed training data - X shape: {X_train.shape}, y shape: {y_train.shape}")

                # Reset the timer even after a failure so a broken data
                # source is retried once per interval, not on every epoch.
                last_data_refresh_time = time.time()

            # Train one epoch
            train_action_loss, _, train_acc = model.train_epoch(
                X_train, y_train, train_future_prices, batch_size
            )

            # Evaluate
            val_action_loss, _, val_acc = model.evaluate(
                X_val, y_val, val_future_prices
            )

            logger.info(f"Epoch {epoch} results:")
            logger.info(f" Train - Loss: {train_action_loss:.4f}, Accuracy: {train_acc:.4f}")
            logger.info(f" Valid - Loss: {val_action_loss:.4f}, Accuracy: {val_acc:.4f}")

            # Get predictions for PnL calculation
            train_action_probs, _ = model.predict(X_train)
            val_action_probs, _ = model.predict(X_val)

            # Convert probabilities to actions
            train_preds = np.argmax(train_action_probs, axis=1)
            val_preds = np.argmax(val_action_probs, axis=1)

            # Track signal distribution
            train_counts = _count_actions(train_preds)
            val_counts = _count_actions(val_preds)

            signal_dist = {
                "train": _action_fractions(train_counts, len(train_preds)),
                "val": _action_fractions(val_counts, len(val_preds))
            }

            # Calculate PnL and win rates with different position sizes
            position_sizes = [0.1, 0.25, 0.5, 1.0, 2.0]  # Multiple position sizes for robustness
            best_position_train_pnl = -float('inf')
            best_position_val_pnl = -float('inf')
            best_position_train_wr = 0
            best_position_val_wr = 0
            best_position_size = 1.0

            for position_size in position_sizes:
                train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
                    train_preds, train_prices, position_size=position_size
                )
                val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
                    val_preds, val_prices, position_size=position_size
                )

                logger.info(f" Position Size: {position_size}")
                logger.info(f" Train - PnL: {train_pnl:.4f}, Win Rate: {train_win_rate:.4f}, Trades: {len(train_trades)}")
                logger.info(f" Valid - PnL: {val_pnl:.4f}, Win Rate: {val_win_rate:.4f}, Trades: {len(val_trades)}")

                # Track best position size for this epoch
                if val_pnl > best_position_val_pnl:
                    best_position_val_pnl = val_pnl
                    best_position_val_wr = val_win_rate
                    best_position_size = position_size

                if train_pnl > best_position_train_pnl:
                    best_position_train_pnl = train_pnl
                    best_position_train_wr = train_win_rate

                # Track best model overall (using position size 1.0 as reference)
                if val_pnl > best_val_pnl and position_size == 1.0:
                    best_val_pnl = val_pnl
                    best_win_rate = val_win_rate
                    best_epoch = epoch
                    logger.info(f" New best validation PnL: {best_val_pnl:.4f} at epoch {best_epoch}")

                    # Save the best model
                    model.save("NN/models/saved/optimized_short_term_model_realtime_best")

            # Store epoch metrics
            epoch_metrics = {
                "epoch": epoch,
                "train_loss": float(train_action_loss),
                "val_loss": float(val_action_loss),
                "train_acc": float(train_acc),
                "val_acc": float(val_acc),
                "train_pnl": float(best_position_train_pnl),
                "val_pnl": float(best_position_val_pnl),
                "train_win_rate": float(best_position_train_wr),
                "val_win_rate": float(best_position_val_wr),
                "best_position_size": float(best_position_size),
                "signal_distribution": signal_dist,
                "timestamp": datetime.now().isoformat(),
                "data_age": int(time.time() - last_data_refresh_time)
            }

            # Update training stats
            training_stats["epochs_completed"] = epoch
            training_stats["best_val_pnl"] = float(best_val_pnl)
            training_stats["best_epoch"] = best_epoch
            training_stats["best_win_rate"] = float(best_win_rate)
            training_stats["last_update"] = datetime.now().isoformat()
            training_stats["epochs"].append(epoch_metrics)

            # Check if we need to save checkpoint
            if time.time() - last_checkpoint_time > checkpoint_interval:
                logger.info(f"Saving checkpoint at epoch {epoch}")
                # Save model checkpoint
                model.save(f"{checkpoint_dir}/checkpoint_epoch_{epoch}")
                # Save training statistics
                save_training_stats(training_stats)
                last_checkpoint_time = time.time()

            # Test trade signal generation with a random sample. Skipped for
            # an empty validation set, where np.random.randint(0, 0) would
            # raise ValueError.
            if len(X_val) > 0:
                random_idx = np.random.randint(0, len(X_val))
                sample_X = X_val[random_idx:random_idx+1]
                sample_probs, sample_price_pred = model.predict(sample_X)

                # Process with signal interpreter.
                # NOTE(review): the fallback branch of the price expression
                # still indexes an object that has no __getitem__; kept as-is
                # because the return type of model.predict is not visible here.
                trade_signal = signal_interpreter.interpret_signal(
                    sample_probs[0],
                    float(sample_price_pred[0][0]) if hasattr(sample_price_pred, "__getitem__") else float(sample_price_pred[0]),
                    market_data={'price': float(val_prices[random_idx]) if random_idx < len(val_prices) else 50000.0}
                )

                logger.info(f" Sample trade signal: {trade_signal['action']} with confidence {trade_signal['confidence']:.4f}")

            # Log trading statistics
            logger.info(f" Train - Actions: BUY={train_counts['BUY']}, SELL={train_counts['SELL']}, HOLD={train_counts['HOLD']}")
            logger.info(f" Valid - Actions: BUY={val_counts['BUY']}, SELL={val_counts['SELL']}, HOLD={val_counts['HOLD']}")

            # Log epoch timing
            epoch_time = time.time() - epoch_start
            total_elapsed = time.time() - start_time
            time_remaining = max_training_time - total_elapsed

            logger.info(f" Epoch completed in {epoch_time:.2f} seconds")
            logger.info(f" Training time: {total_elapsed/3600:.2f} hours / {max_training_time/3600:.2f} hours")
            logger.info(f" Estimated time remaining: {time_remaining/3600:.2f} hours")

        # Save final model and performance metrics
        logger.info("Saving final optimized model...")
        model.save("NN/models/saved/optimized_short_term_model_realtime_final")

        # Save performance metrics to file
        save_training_stats(training_stats)

        # Generate performance plots; plotting problems must not hide the
        # fact that training itself completed.
        try:
            model.plot_training_history("NN/models/saved/realtime_training_stats.json")
            logger.info("Performance plots generated successfully")
        except Exception as e:
            logger.error(f"Error generating plots: {str(e)}")

        # Calculate total training time
        total_time = time.time() - start_time
        hours, remainder = divmod(total_time, 3600)
        minutes, seconds = divmod(remainder, 60)

        logger.info(f"Overnight training completed in {int(hours)}h {int(minutes)}m {int(seconds)}s")
        logger.info(f"Best model performance - Epoch: {best_epoch}, PnL: {best_val_pnl:.4f}, Win Rate: {best_win_rate:.4f}")

    except Exception as e:
        logger.error(f"Error during overnight training: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())

        # Try to save the model and stats in case of error
        try:
            if model is not None:
                model.save("NN/models/saved/optimized_short_term_model_realtime_emergency")
                logger.info("Emergency model save completed")
            # ``training_stats`` is a module-level global, so it never appears
            # in locals(); save it unconditionally.
            save_training_stats(training_stats, "NN/models/saved/realtime_training_stats_emergency.json")
        except Exception as e2:
            logger.error(f"Failed to save emergency checkpoint: {str(e2)}")
|
|
|
|
if __name__ == "__main__":
    # Show a short banner explaining what the script does and how to stop
    # it, then hand over to the training session.
    banner_rule = "=" * 80
    for banner_line in (
        banner_rule,
        "OVERNIGHT REALTIME TRAINING SESSION",
        "This script will continuously train the model using real-time market data",
        "Press Ctrl+C to safely stop training and save the model",
        banner_rule,
    ):
        print(banner_line)

    run_overnight_training()