RL training
This commit is contained in:
parent
1610d5bd49
commit
4eac14022c
@ -919,13 +919,7 @@ class CNNModelPyTorch:
|
|||||||
logger.info(f"Backup saved to {backup_path}")
|
logger.info(f"Backup saved to {backup_path}")
|
||||||
|
|
||||||
def load(self, filepath):
|
def load(self, filepath):
|
||||||
"""
|
"""Load model weights from file"""
|
||||||
Load the model from a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filepath: Path to load the model from
|
|
||||||
"""
|
|
||||||
# Check if file exists
|
|
||||||
if not os.path.exists(f"{filepath}.pt"):
|
if not os.path.exists(f"{filepath}.pt"):
|
||||||
logger.error(f"Model file {filepath}.pt not found")
|
logger.error(f"Model file {filepath}.pt not found")
|
||||||
return False
|
return False
|
||||||
@ -938,27 +932,20 @@ class CNNModelPyTorch:
|
|||||||
self.window_size = model_state['window_size']
|
self.window_size = model_state['window_size']
|
||||||
self.num_features = model_state['num_features']
|
self.num_features = model_state['num_features']
|
||||||
self.output_size = model_state['output_size']
|
self.output_size = model_state['output_size']
|
||||||
self.timeframes = model_state['timeframes']
|
self.timeframes = model_state.get('timeframes', ["1m"])
|
||||||
|
|
||||||
|
# Load model state dict
|
||||||
|
self.load_state_dict(model_state['model_state_dict'])
|
||||||
|
|
||||||
|
# Load optimizer state if available
|
||||||
|
if 'optimizer_state_dict' in model_state:
|
||||||
|
self.optimizer.load_state_dict(model_state['optimizer_state_dict'])
|
||||||
|
|
||||||
# Load trading configuration if available
|
# Load trading configuration if available
|
||||||
if 'confidence_threshold' in model_state:
|
if 'confidence_threshold' in model_state:
|
||||||
self.confidence_threshold = model_state['confidence_threshold']
|
self.confidence_threshold = model_state['confidence_threshold']
|
||||||
if 'max_consecutive_same_action' in model_state:
|
if 'max_consecutive_same_action' in model_state:
|
||||||
self.max_consecutive_same_action = model_state['max_consecutive_same_action']
|
self.max_consecutive_same_action = model_state['max_consecutive_same_action']
|
||||||
if 'action_counts' in model_state:
|
|
||||||
self.action_counts = model_state['action_counts']
|
|
||||||
if 'last_actions' in model_state:
|
|
||||||
self.last_actions = model_state['last_actions']
|
|
||||||
|
|
||||||
# Rebuild the model
|
|
||||||
self.build_model()
|
|
||||||
|
|
||||||
# Load the model state
|
|
||||||
self.model.load_state_dict(model_state['model_state_dict'])
|
|
||||||
self.optimizer.load_state_dict(model_state['optimizer_state_dict'])
|
|
||||||
self.history = model_state['history']
|
|
||||||
|
|
||||||
logger.info(f"Model loaded from {filepath}.pt")
|
|
||||||
|
|
||||||
# Log model version information if available
|
# Log model version information if available
|
||||||
if 'model_version' in model_state:
|
if 'model_version' in model_state:
|
||||||
@ -973,7 +960,7 @@ class CNNModelPyTorch:
|
|||||||
|
|
||||||
def plot_training_history(self, metrics_file="NN/models/saved/training_metrics.json"):
|
def plot_training_history(self, metrics_file="NN/models/saved/training_metrics.json"):
|
||||||
"""
|
"""
|
||||||
Generate comprehensive performance visualization plots from training history
|
Plot training history from saved metrics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
metrics_file: Path to the saved metrics JSON file
|
metrics_file: Path to the saved metrics JSON file
|
||||||
@ -983,253 +970,72 @@ class CNNModelPyTorch:
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import matplotlib.dates as mdates
|
import matplotlib.dates as mdates
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Create directory for plots
|
|
||||||
plots_dir = "NN/models/saved/performance_plots"
|
|
||||||
os.makedirs(plots_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Load metrics
|
# Load metrics
|
||||||
with open(metrics_file, 'r') as f:
|
with open(metrics_file, 'r') as f:
|
||||||
metrics = json.load(f)
|
metrics = json.load(f)
|
||||||
|
|
||||||
epochs = metrics["epoch"]
|
# Create plots directory
|
||||||
|
plots_dir = os.path.join(os.path.dirname(metrics_file), 'plots')
|
||||||
|
os.makedirs(plots_dir, exist_ok=True)
|
||||||
|
|
||||||
# Set default style for better visualization
|
# Convert timestamps to datetime objects
|
||||||
plt.style.use('seaborn-darkgrid')
|
timestamps = [datetime.fromisoformat(ts) for ts in metrics['timestamps']]
|
||||||
|
|
||||||
# 1. Plot Loss and Accuracy
|
# 1. Plot Loss and Accuracy
|
||||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
|
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
|
||||||
|
|
||||||
# Loss plot
|
# Loss plot
|
||||||
ax1.plot(epochs, metrics["train_loss"], 'b-', label='Training Loss')
|
ax1.plot(timestamps, metrics['train_loss'], 'b-', label='Training Loss')
|
||||||
ax1.plot(epochs, metrics["val_loss"], 'r-', label='Validation Loss')
|
ax1.plot(timestamps, metrics['val_loss'], 'r-', label='Validation Loss')
|
||||||
ax1.set_title('Model Loss over Epochs', fontsize=16)
|
ax1.set_title('Model Loss Over Time')
|
||||||
ax1.set_ylabel('Loss', fontsize=14)
|
ax1.set_ylabel('Loss')
|
||||||
ax1.legend(loc='upper right', fontsize=12)
|
ax1.legend()
|
||||||
ax1.grid(True)
|
ax1.grid(True)
|
||||||
|
|
||||||
# Accuracy plot
|
# Accuracy plot
|
||||||
ax2.plot(epochs, metrics["train_acc"], 'b-', label='Training Accuracy')
|
ax2.plot(timestamps, metrics['train_acc'], 'g-', label='Training Accuracy')
|
||||||
ax2.plot(epochs, metrics["val_acc"], 'r-', label='Validation Accuracy')
|
ax2.plot(timestamps, metrics['val_acc'], 'm-', label='Validation Accuracy')
|
||||||
ax2.set_title('Model Accuracy over Epochs', fontsize=16)
|
ax2.set_title('Model Accuracy Over Time')
|
||||||
ax2.set_xlabel('Epoch', fontsize=14)
|
ax2.set_ylabel('Accuracy')
|
||||||
ax2.set_ylabel('Accuracy', fontsize=14)
|
ax2.legend()
|
||||||
ax2.legend(loc='lower right', fontsize=12)
|
|
||||||
ax2.grid(True)
|
ax2.grid(True)
|
||||||
|
|
||||||
|
# Format x-axis
|
||||||
|
ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
|
||||||
|
plt.xticks(rotation=45)
|
||||||
|
|
||||||
|
# Save the plot
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig(f"{plots_dir}/loss_accuracy.png", dpi=300)
|
plt.savefig(os.path.join(plots_dir, 'loss_accuracy.png'))
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
# 2. Plot PnL and Win Rate
|
# 2. Plot PnL and Win Rate
|
||||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
|
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
|
||||||
|
|
||||||
# PnL plot
|
# PnL plot
|
||||||
ax1.plot(epochs, metrics["train_pnl"], 'g-', label='Training PnL')
|
ax1.plot(timestamps, metrics['train_pnl'], 'g-', label='Training PnL')
|
||||||
ax1.plot(epochs, metrics["val_pnl"], 'm-', label='Validation PnL')
|
ax1.plot(timestamps, metrics['val_pnl'], 'r-', label='Validation PnL')
|
||||||
ax1.set_title('Trading Profit and Loss over Epochs', fontsize=16)
|
ax1.set_title('PnL Over Time')
|
||||||
ax1.set_ylabel('PnL', fontsize=14)
|
ax1.set_ylabel('PnL')
|
||||||
ax1.legend(loc='upper left', fontsize=12)
|
ax1.legend()
|
||||||
ax1.grid(True)
|
ax1.grid(True)
|
||||||
|
|
||||||
# Win Rate plot
|
# Win Rate plot
|
||||||
ax2.plot(epochs, metrics["train_win_rate"], 'g-', label='Training Win Rate')
|
ax2.plot(timestamps, metrics['train_win_rate'], 'b-', label='Training Win Rate')
|
||||||
ax2.plot(epochs, metrics["val_win_rate"], 'm-', label='Validation Win Rate')
|
ax2.plot(timestamps, metrics['val_win_rate'], 'm-', label='Validation Win Rate')
|
||||||
ax2.set_title('Trading Win Rate over Epochs', fontsize=16)
|
ax2.set_title('Win Rate Over Time')
|
||||||
ax2.set_xlabel('Epoch', fontsize=14)
|
ax2.set_ylabel('Win Rate')
|
||||||
ax2.set_ylabel('Win Rate', fontsize=14)
|
ax2.legend()
|
||||||
ax2.axhline(y=0.5, color='r', linestyle='--', label='50% Threshold')
|
|
||||||
ax2.legend(loc='lower right', fontsize=12)
|
|
||||||
ax2.grid(True)
|
ax2.grid(True)
|
||||||
|
|
||||||
|
# Format x-axis
|
||||||
|
ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
|
||||||
|
plt.xticks(rotation=45)
|
||||||
|
|
||||||
|
# Save the plot
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig(f"{plots_dir}/pnl_winrate.png", dpi=300)
|
plt.savefig(os.path.join(plots_dir, 'pnl_winrate.png'))
|
||||||
plt.close()
|
|
||||||
|
|
||||||
# 3. Plot Signal Distribution over time
|
|
||||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
|
|
||||||
|
|
||||||
# Training Signal Distribution
|
|
||||||
buy_train = [epoch_dist["train"]["BUY"] for epoch_dist in metrics["signal_distribution"]]
|
|
||||||
sell_train = [epoch_dist["train"]["SELL"] for epoch_dist in metrics["signal_distribution"]]
|
|
||||||
hold_train = [epoch_dist["train"]["HOLD"] for epoch_dist in metrics["signal_distribution"]]
|
|
||||||
|
|
||||||
ax1.stackplot(epochs, buy_train, hold_train, sell_train,
|
|
||||||
labels=['BUY', 'HOLD', 'SELL'],
|
|
||||||
colors=['green', 'gray', 'red'], alpha=0.7)
|
|
||||||
ax1.set_title('Training Signal Distribution over Epochs', fontsize=16)
|
|
||||||
ax1.set_ylabel('Proportion', fontsize=14)
|
|
||||||
ax1.legend(loc='upper right', fontsize=12)
|
|
||||||
ax1.set_ylim(0, 1)
|
|
||||||
ax1.grid(True)
|
|
||||||
|
|
||||||
# Validation Signal Distribution
|
|
||||||
buy_val = [epoch_dist["val"]["BUY"] for epoch_dist in metrics["signal_distribution"]]
|
|
||||||
sell_val = [epoch_dist["val"]["SELL"] for epoch_dist in metrics["signal_distribution"]]
|
|
||||||
hold_val = [epoch_dist["val"]["HOLD"] for epoch_dist in metrics["signal_distribution"]]
|
|
||||||
|
|
||||||
ax2.stackplot(epochs, buy_val, hold_val, sell_val,
|
|
||||||
labels=['BUY', 'HOLD', 'SELL'],
|
|
||||||
colors=['green', 'gray', 'red'], alpha=0.7)
|
|
||||||
ax2.set_title('Validation Signal Distribution over Epochs', fontsize=16)
|
|
||||||
ax2.set_xlabel('Epoch', fontsize=14)
|
|
||||||
ax2.set_ylabel('Proportion', fontsize=14)
|
|
||||||
ax2.legend(loc='upper right', fontsize=12)
|
|
||||||
ax2.set_ylim(0, 1)
|
|
||||||
ax2.grid(True)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig(f"{plots_dir}/signal_distribution.png", dpi=300)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
# 4. Performance Correlation Matrix
|
|
||||||
fig, ax = plt.subplots(figsize=(10, 8))
|
|
||||||
|
|
||||||
# Extract key metrics for correlation
|
|
||||||
corr_data = {}
|
|
||||||
corr_data['Loss'] = metrics["train_loss"]
|
|
||||||
corr_data['Accuracy'] = metrics["train_acc"]
|
|
||||||
corr_data['PnL'] = metrics["train_pnl"]
|
|
||||||
corr_data['Win Rate'] = metrics["train_win_rate"]
|
|
||||||
corr_data['BUY %'] = buy_train
|
|
||||||
corr_data['SELL %'] = sell_train
|
|
||||||
corr_data['HOLD %'] = hold_train
|
|
||||||
|
|
||||||
# Convert to numpy array
|
|
||||||
corr_matrix = np.zeros((len(corr_data), len(corr_data)))
|
|
||||||
labels = list(corr_data.keys())
|
|
||||||
|
|
||||||
# Calculate correlation
|
|
||||||
for i, key1 in enumerate(labels):
|
|
||||||
for j, key2 in enumerate(labels):
|
|
||||||
if i == j:
|
|
||||||
corr_matrix[i, j] = 1.0
|
|
||||||
else:
|
|
||||||
corr = np.corrcoef(corr_data[key1], corr_data[key2])[0, 1]
|
|
||||||
corr_matrix[i, j] = corr
|
|
||||||
|
|
||||||
# Plot heatmap
|
|
||||||
im = ax.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
|
|
||||||
|
|
||||||
# Add colorbar
|
|
||||||
cbar = fig.colorbar(im, ax=ax)
|
|
||||||
cbar.set_label('Correlation', rotation=270, labelpad=20, fontsize=14)
|
|
||||||
|
|
||||||
# Add ticks and labels
|
|
||||||
ax.set_xticks(np.arange(len(labels)))
|
|
||||||
ax.set_yticks(np.arange(len(labels)))
|
|
||||||
ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=12)
|
|
||||||
ax.set_yticklabels(labels, fontsize=12)
|
|
||||||
|
|
||||||
# Add text annotations
|
|
||||||
for i in range(len(labels)):
|
|
||||||
for j in range(len(labels)):
|
|
||||||
text = ax.text(j, i, f"{corr_matrix[i, j]:.2f}",
|
|
||||||
ha="center", va="center", color="black" if abs(corr_matrix[i, j]) < 0.7 else "white")
|
|
||||||
|
|
||||||
ax.set_title('Correlation Matrix of Performance Metrics', fontsize=16)
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig(f"{plots_dir}/correlation_matrix.png", dpi=300)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
# 5. Combined Performance Dashboard
|
|
||||||
fig = plt.figure(figsize=(16, 20))
|
|
||||||
|
|
||||||
# Define grid layout
|
|
||||||
gs = fig.add_gridspec(4, 2, hspace=0.4, wspace=0.3)
|
|
||||||
|
|
||||||
# Plot 1: Loss curves
|
|
||||||
ax1 = fig.add_subplot(gs[0, 0])
|
|
||||||
ax1.plot(epochs, metrics["train_loss"], 'b-', label='Training')
|
|
||||||
ax1.plot(epochs, metrics["val_loss"], 'r-', label='Validation')
|
|
||||||
ax1.set_title('Loss', fontsize=14)
|
|
||||||
ax1.set_xlabel('Epoch', fontsize=12)
|
|
||||||
ax1.set_ylabel('Loss', fontsize=12)
|
|
||||||
ax1.legend(fontsize=10)
|
|
||||||
ax1.grid(True)
|
|
||||||
|
|
||||||
# Plot 2: Accuracy
|
|
||||||
ax2 = fig.add_subplot(gs[0, 1])
|
|
||||||
ax2.plot(epochs, metrics["train_acc"], 'b-', label='Training')
|
|
||||||
ax2.plot(epochs, metrics["val_acc"], 'r-', label='Validation')
|
|
||||||
ax2.set_title('Accuracy', fontsize=14)
|
|
||||||
ax2.set_xlabel('Epoch', fontsize=12)
|
|
||||||
ax2.set_ylabel('Accuracy', fontsize=12)
|
|
||||||
ax2.legend(fontsize=10)
|
|
||||||
ax2.grid(True)
|
|
||||||
|
|
||||||
# Plot 3: PnL
|
|
||||||
ax3 = fig.add_subplot(gs[1, 0])
|
|
||||||
ax3.plot(epochs, metrics["train_pnl"], 'g-', label='Training')
|
|
||||||
ax3.plot(epochs, metrics["val_pnl"], 'm-', label='Validation')
|
|
||||||
ax3.set_title('Profit and Loss', fontsize=14)
|
|
||||||
ax3.set_xlabel('Epoch', fontsize=12)
|
|
||||||
ax3.set_ylabel('PnL', fontsize=12)
|
|
||||||
ax3.legend(fontsize=10)
|
|
||||||
ax3.grid(True)
|
|
||||||
|
|
||||||
# Plot 4: Win Rate
|
|
||||||
ax4 = fig.add_subplot(gs[1, 1])
|
|
||||||
ax4.plot(epochs, metrics["train_win_rate"], 'g-', label='Training')
|
|
||||||
ax4.plot(epochs, metrics["val_win_rate"], 'm-', label='Validation')
|
|
||||||
ax4.axhline(y=0.5, color='r', linestyle='--', label='50% Threshold')
|
|
||||||
ax4.set_title('Win Rate', fontsize=14)
|
|
||||||
ax4.set_xlabel('Epoch', fontsize=12)
|
|
||||||
ax4.set_ylabel('Win Rate', fontsize=12)
|
|
||||||
ax4.legend(fontsize=10)
|
|
||||||
ax4.grid(True)
|
|
||||||
|
|
||||||
# Plot 5: Training Signal Distribution
|
|
||||||
ax5 = fig.add_subplot(gs[2, 0])
|
|
||||||
ax5.stackplot(epochs, buy_train, hold_train, sell_train,
|
|
||||||
labels=['BUY', 'HOLD', 'SELL'],
|
|
||||||
colors=['green', 'gray', 'red'], alpha=0.7)
|
|
||||||
ax5.set_title('Training Signal Distribution', fontsize=14)
|
|
||||||
ax5.set_xlabel('Epoch', fontsize=12)
|
|
||||||
ax5.set_ylabel('Proportion', fontsize=12)
|
|
||||||
ax5.legend(fontsize=10)
|
|
||||||
ax5.set_ylim(0, 1)
|
|
||||||
ax5.grid(True)
|
|
||||||
|
|
||||||
# Plot 6: Validation Signal Distribution
|
|
||||||
ax6 = fig.add_subplot(gs[2, 1])
|
|
||||||
ax6.stackplot(epochs, buy_val, hold_val, sell_val,
|
|
||||||
labels=['BUY', 'HOLD', 'SELL'],
|
|
||||||
colors=['green', 'gray', 'red'], alpha=0.7)
|
|
||||||
ax6.set_title('Validation Signal Distribution', fontsize=14)
|
|
||||||
ax6.set_xlabel('Epoch', fontsize=12)
|
|
||||||
ax6.set_ylabel('Proportion', fontsize=12)
|
|
||||||
ax6.legend(fontsize=10)
|
|
||||||
ax6.set_ylim(0, 1)
|
|
||||||
ax6.grid(True)
|
|
||||||
|
|
||||||
# Plot 7: Performance Correlation Heatmap
|
|
||||||
ax7 = fig.add_subplot(gs[3, :])
|
|
||||||
im = ax7.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
|
|
||||||
cbar = fig.colorbar(im, ax=ax7, fraction=0.025, pad=0.04)
|
|
||||||
cbar.set_label('Correlation', rotation=270, labelpad=20, fontsize=12)
|
|
||||||
|
|
||||||
# Add ticks and labels
|
|
||||||
ax7.set_xticks(np.arange(len(labels)))
|
|
||||||
ax7.set_yticks(np.arange(len(labels)))
|
|
||||||
ax7.set_xticklabels(labels, rotation=45, ha="right", fontsize=10)
|
|
||||||
ax7.set_yticklabels(labels, fontsize=10)
|
|
||||||
|
|
||||||
# Add text annotations
|
|
||||||
for i in range(len(labels)):
|
|
||||||
for j in range(len(labels)):
|
|
||||||
text = ax7.text(j, i, f"{corr_matrix[i, j]:.2f}",
|
|
||||||
ha="center", va="center", color="black" if abs(corr_matrix[i, j]) < 0.7 else "white")
|
|
||||||
|
|
||||||
ax7.set_title('Correlation Matrix of Performance Metrics', fontsize=14)
|
|
||||||
|
|
||||||
# Add main title
|
|
||||||
plt.suptitle('CNN Model Performance Dashboard', fontsize=20, y=0.98)
|
|
||||||
|
|
||||||
plt.tight_layout(rect=[0, 0, 1, 0.97])
|
|
||||||
plt.savefig(f"{plots_dir}/performance_dashboard.png", dpi=300)
|
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
print(f"Performance visualizations saved to {plots_dir}")
|
print(f"Performance visualizations saved to {plots_dir}")
|
||||||
|
170
NN/models/dqn_agent.py
Normal file
170
NN/models/dqn_agent.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import numpy as np
|
||||||
|
from collections import deque
|
||||||
|
import random
|
||||||
|
from typing import Tuple, List
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Add parent directory to path
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||||
|
|
||||||
|
from NN.models.simple_cnn import CNNModelPyTorch
|
||||||
|
|
||||||
|
class DQNAgent:
    """
    Deep Q-Network agent for trading.

    Holds two ``CNNModelPyTorch`` networks: a policy network that is
    optimised in ``replay`` and a target network that is overwritten with the
    policy weights every ``target_update`` training steps.
    """

    def __init__(self,
                 state_size: int,
                 action_size: int,
                 window_size: int,
                 num_features: int,
                 timeframes: List[str],
                 learning_rate: float = 0.001,
                 gamma: float = 0.99,
                 epsilon: float = 1.0,
                 epsilon_min: float = 0.01,
                 epsilon_decay: float = 0.995,
                 memory_size: int = 10000,
                 batch_size: int = 64,
                 target_update: int = 10):
        """
        Build the networks, optimizer and replay memory.

        Args:
            state_size: Stored on the agent; not otherwise used by this class.
            action_size: Number of discrete actions (network output size).
            window_size: Forwarded to both CNN networks.
            num_features: Forwarded to both CNN networks.
            timeframes: Forwarded to both CNN networks.
            learning_rate: Learning rate of the Adam optimizer.
            gamma: Discount factor applied to next-state Q-values.
            epsilon: Initial probability of picking a random action.
            epsilon_min: Lower bound epsilon is clamped to while decaying.
            epsilon_decay: Factor epsilon is multiplied by after each
                training step.
            memory_size: Maximum number of stored experiences.
            batch_size: Number of experiences sampled per training step.
            target_update: Number of training steps between target-network
                synchronisations.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.window_size = window_size
        self.num_features = num_features
        self.timeframes = timeframes
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.target_update = target_update

        # Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize networks
        self.policy_net = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=action_size,
            timeframes=timeframes
        ).to(self.device)

        self.target_net = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=action_size,
            timeframes=timeframes
        ).to(self.device)
        # Start both networks from identical weights
        self.target_net.load_state_dict(self.policy_net.state_dict())

        # Initialize optimizer (only the policy network is trained directly)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)

        # Initialize memory; the deque drops the oldest entry once full
        self.memory = deque(maxlen=memory_size)

        # Training metrics
        self.update_count = 0
        self.losses = []

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool):
        """Store one (state, action, reward, next_state, done) experience."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state: np.ndarray) -> int:
        """
        Choose an action using an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the index of the largest policy-network output is returned.
        """
        if random.random() < self.epsilon:
            return random.randrange(self.action_size)

        # NOTE(review): the policy network is evaluated in whatever
        # train/eval mode it is currently in -- confirm this is intended.
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values, _ = self.policy_net(state_tensor)
            return q_values.argmax().item()

    def replay(self) -> float:
        """
        Train the policy network on one random batch of stored experiences.

        Returns:
            The batch loss as a float, or ``0.0`` when fewer than
            ``batch_size`` experiences are stored (no training happens).
        """
        if len(self.memory) < self.batch_size:
            return 0.0

        # Sample batch
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert to tensors and move to device
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        # Go through a float array so bool flags become 0.0 / 1.0 explicitly
        dones = torch.FloatTensor(np.array(dones, dtype=np.float32)).to(self.device)

        # Q-values of the actions that were actually taken
        current_q_values, _ = self.policy_net(states)
        # squeeze(1) keeps the batch dimension even when batch_size == 1
        current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrapped targets from the target network; terminal transitions
        # (done == 1) contribute only their immediate reward
        with torch.no_grad():
            next_q_values, _ = self.target_net(next_states)
            next_q_values = next_q_values.max(1)[0]
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Compute loss
        loss = nn.MSELoss()(current_q_values, target_q_values)

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update target network if needed
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        # Decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        # Record the loss so the history persisted by save() is populated
        loss_value = loss.item()
        self.losses.append(loss_value)
        return loss_value

    def save(self, path: str):
        """
        Save both networks and the agent state.

        Writes ``{path}_policy`` and ``{path}_target`` through the networks'
        own ``save`` methods and ``{path}_agent_state.pt`` with epsilon,
        update count, loss history and optimizer state.
        """
        # A bare file name has no directory part; os.makedirs('') would raise
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Save policy network
        self.policy_net.save(f"{path}_policy")

        # Save target network
        self.target_net.save(f"{path}_target")

        # Save agent state
        state = {
            'epsilon': self.epsilon,
            'update_count': self.update_count,
            'losses': self.losses,
            'optimizer_state': self.optimizer.state_dict()
        }
        torch.save(state, f"{path}_agent_state.pt")

    def load(self, path: str):
        """Load both networks and the agent state written by ``save``."""
        # Load policy network
        self.policy_net.load(f"{path}_policy")

        # Load target network
        self.target_net.load(f"{path}_target")

        # Load agent state; map_location lets a checkpoint saved on a GPU
        # machine be restored on a CPU-only one
        state = torch.load(f"{path}_agent_state.pt", map_location=self.device)
        self.epsilon = state['epsilon']
        self.update_count = state['update_count']
        self.losses = state['losses']
        self.optimizer.load_state_dict(state['optimizer_state'])
|
130
NN/models/simple_cnn.py
Normal file
130
NN/models/simple_cnn.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
# Configure logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class CNNModelPyTorch(nn.Module):
    """
    CNN model for trading signals.

    Simplified version for RL training: three 1-D convolution layers
    followed by fully connected layers that produce per-action Q-values
    and a scalar state value.
    """

    def __init__(self, window_size: int, num_features: int, output_size: int, timeframes: List[str]):
        """
        Build the network, its optimizer and its LR scheduler.

        Args:
            window_size: Number of time steps per input sample; sizes the
                first fully connected layer.
            num_features: Features per timeframe; the first convolution
                expects ``num_features * len(timeframes)`` input channels.
            output_size: Number of Q-value outputs.
            timeframes: Timeframe labels; only the list length is used by
                this class.
        """
        super(CNNModelPyTorch, self).__init__()

        self.window_size = window_size
        self.num_features = num_features
        self.output_size = output_size
        self.timeframes = timeframes

        # Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        # Build model
        self.build_model()

        # Initialize optimizer and scheduler.
        # The scheduler's `verbose` flag is deprecated in recent PyTorch
        # releases, so it is not passed.
        self.optimizer = optim.Adam(self.parameters(), lr=0.001)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.5, patience=5
        )

        # Move model to device
        self.to(self.device)

    def build_model(self):
        """Build the CNN architecture"""
        # First Convolutional Layer
        self.conv1 = nn.Conv1d(
            in_channels=self.num_features * len(self.timeframes),
            out_channels=32,
            kernel_size=3,
            padding=1
        )
        self.bn1 = nn.BatchNorm1d(32)

        # Second Convolutional Layer
        self.conv2 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(64)

        # Third Convolutional Layer
        self.conv3 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(128)

        # kernel_size=3 with padding=1 keeps the sequence length, so the
        # flattened conv output is window_size positions x 128 channels
        conv_out_size = self.window_size * 128

        # Fully connected layers
        self.fc1 = nn.Linear(conv_out_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, self.output_size)

        # Additional output for value estimation
        self.value_fc = nn.Linear(256, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the network.

        Args:
            x: Input laid out as [batch, window_size, channels]; it is
                permuted to [batch, channels, window_size] for Conv1d.

        Returns:
            ``(q_values, value)``: the combined per-action Q-values and the
            scalar value-head output.
        """
        # Ensure input is on the correct device
        x = x.to(self.device)

        # Reshape input: [batch, window_size, features] -> [batch, channels, window_size]
        x = x.permute(0, 2, 1)

        # Convolutional layers
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        # Flatten everything except the batch dimension
        x = x.flatten(start_dim=1)

        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        # Split into advantage and value streams
        advantage = self.fc3(x)
        value = self.value_fc(x)

        # Combine value and mean-centred advantage
        q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))

        return q_values, value

    def predict(self, X):
        """
        Make predictions.

        Switches the module to eval mode (it is not switched back) and runs
        a forward pass without gradients.

        Args:
            X: A tensor, or anything ``torch.tensor`` can convert to float32.

        Returns:
            ``(actions, q_values)`` as numpy arrays, where ``actions`` is the
            argmax over the Q-values of each sample.
        """
        self.eval()

        # Convert to tensor if not already
        if not isinstance(X, torch.Tensor):
            X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
        else:
            X_tensor = X.to(self.device)

        with torch.no_grad():
            q_values, value = self(X_tensor)
            q_values_np = q_values.cpu().numpy()
            actions = np.argmax(q_values_np, axis=1)

        return actions, q_values_np

    def save(self, path: str):
        """Save model weights to ``{path}.pt``, creating the directory if needed."""
        # A bare file name has no directory part; os.makedirs('') would raise
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.state_dict(), f"{path}.pt")
        logger.info(f"Model saved to {path}.pt")

    def load(self, path: str):
        """Load model weights from ``{path}.pt`` and switch to eval mode."""
        self.load_state_dict(torch.load(f"{path}.pt", map_location=self.device))
        self.eval()
        logger.info(f"Model loaded from {path}.pt")
|
192
NN/train_rl.py
Normal file
192
NN/train_rl.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import pandas as pd
|
||||||
|
import gym
|
||||||
|
|
||||||
|
# Add parent directory to path
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from NN.utils.data_interface import DataInterface
|
||||||
|
from NN.utils.trading_env import TradingEnvironment
|
||||||
|
from NN.models.dqn_agent import DQNAgent
|
||||||
|
from NN.utils.signal_interpreter import SignalInterpreter
|
||||||
|
|
||||||
|
# Configure logging
logger = logging.getLogger(__name__)
# The NN.* imports above already call logging.basicConfig() (simple_cnn does
# so at import time), which installs a root handler. A second basicConfig()
# is then a silent no-op, so the file handler below would never be attached.
# force=True (Python 3.8+) replaces the existing root handlers so that
# rl_training.log is actually written.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('rl_training.log'),
        logging.StreamHandler()
    ],
    force=True
)
|
||||||
|
|
||||||
|
class RLTradingEnvironment(TradingEnvironment):
    """Extended trading environment that reshapes state for CNN"""

    def __init__(self, data, window_size, num_features, num_timeframes, **kwargs):
        """
        Args:
            data: Passed through to the parent environment.
            window_size: Number of rows in each reshaped observation.
            num_features: Features per timeframe.
            num_timeframes: Number of timeframes; each observation row has
                ``num_features * num_timeframes`` columns.
            **kwargs: Passed through to the parent environment.
        """
        # Set attributes before parent initialization
        # NOTE(review): presumably the parent constructor may call
        # _get_observation(), which reads these -- confirm.
        self.window_size = window_size
        self.num_features = num_features
        self.num_timeframes = num_timeframes
        self.feature_dim = num_features * num_timeframes

        # Initialize parent class
        super().__init__(data=data, **kwargs)

        # Update observation space for CNN
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.window_size, self.feature_dim),
            dtype=np.float32
        )

    def _get_observation(self):
        """
        Get current observation reshaped for CNN.

        Returns:
            A float32 array of shape ``(window_size, feature_dim)`` built
            from the most recent complete rows of the parent observation,
            zero-padded at the top when fewer than ``window_size`` rows
            are available.
        """
        # Get flattened observation from parent class
        # (assumes a 1-D numpy array whose last element is the close price)
        flat_obs = super()._get_observation()

        # Extract features (exclude close price)
        features = flat_obs[:-1]

        # Number of complete rows available, capped at the window size
        n_windows = len(features) // self.feature_dim
        rows = min(n_windows, self.window_size)

        # Always take the trailing values. Slicing with an explicit start
        # index avoids features[-0:], which returns the whole array, and
        # also copes with a feature count that is not an exact multiple of
        # feature_dim.
        usable = rows * self.feature_dim
        recent = features[len(features) - usable:].reshape(rows, self.feature_dim)

        if rows < self.window_size:
            # Pad with zeros if not enough data
            padding = np.zeros((self.window_size - rows, self.feature_dim))
            reshaped = np.vstack([padding, recent])
        else:
            reshaped = recent

        return reshaped.astype(np.float32)
|
||||||
|
|
||||||
|
def train_rl():
    """
    Train the RL model using the DQN agent

    Builds a flat feature frame from the prepared training windows, wraps it
    in an ``RLTradingEnvironment`` and runs an epsilon-greedy DQN training
    loop.  The best model (by total episode reward) and a checkpoint every
    100 episodes are saved next to ``best_model_path``.

    Returns:
        None.  Returns early (after logging an error) when no training data
        is available.
    """
    # Initialize data interface with BTC/USDT and multiple timeframes
    timeframes = ['1m', '5m', '15m']
    window_size = 20
    data_interface = DataInterface(symbol="BTC/USDT", timeframes=timeframes)

    # Get training data
    X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data()
    if X_train is None:
        logger.error("Failed to get training data")
        return

    # Calculate feature dimensions
    # NOTE(review): the environment groups features in chunks of
    # num_features * len(timeframes), while each window step below holds
    # X_train.shape[2] values — confirm whether X_train.shape[2] is the
    # per-timeframe or the combined feature count.
    num_features = X_train.shape[2]  # Number of features per timeframe

    # Flatten features for environment
    n_samples = X_train.shape[0]
    flattened_features = X_train.reshape(n_samples, window_size, -1)  # (batch, window, features)

    # Create DataFrame with features as separate columns; the per-step column
    # names repeat once for every step of the window.
    feature_columns = [f'feature_{i}' for i in range(flattened_features.shape[2])]
    df = pd.DataFrame(flattened_features.reshape(n_samples, -1), columns=feature_columns * window_size)
    df['close'] = train_prices

    # Create environment
    env = RLTradingEnvironment(
        data=df,
        window_size=window_size,
        num_features=num_features,
        num_timeframes=len(timeframes),
        initial_balance=10000,
        fee_rate=0.001,
        max_steps=1000,
    )

    # Create DQN agent
    agent = DQNAgent(
        state_size=window_size,  # First dimension of observation space
        action_size=env.action_space.n,
        window_size=window_size,
        num_features=num_features,
        timeframes=timeframes,
        learning_rate=0.001,
        gamma=0.99,
        epsilon=1.0,
        epsilon_min=0.01,
        epsilon_decay=0.995,
        memory_size=10000,
        batch_size=32,
        target_update=10,
    )

    # Training parameters
    episodes = 1000
    max_steps = 1000
    best_reward = float('-inf')
    best_model_path = 'NN/models/saved/best_rl_model.pth'
    # Checkpoints go to the same directory as the best model.
    save_dir = os.path.dirname(best_model_path)

    # Create models directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Training loop
    for episode in range(episodes):
        state = env.reset()
        total_reward = 0

        for step in range(max_steps):
            # Get action from agent
            action = agent.act(state)

            # Take action in environment
            next_state, reward, done, info = env.step(action)

            # Store experience in agent's memory
            agent.remember(state, action, reward, next_state, done)

            # Train agent once the replay memory holds more than one batch
            if len(agent.memory) > agent.batch_size:
                loss = agent.replay()
                if loss is not None:
                    logger.debug(f"Training loss: {loss:.4f}")

            # Update state and reward
            state = next_state
            total_reward += reward

            if done:
                break

        # Update epsilon
        agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)

        # Log episode results (info comes from the last step of the episode)
        logger.info(f"Episode: {episode + 1}/{episodes}")
        logger.info(f"Total Reward: {total_reward:.2f}")
        logger.info(f"Final Balance: {info['balance']:.2f}")
        logger.info(f"Max Drawdown: {info['max_drawdown']:.2%}")
        logger.info(f"Win Rate: {info['win_rate']:.2%}")
        logger.info(f"Epsilon: {agent.epsilon:.4f}")

        # Save best model
        if total_reward > best_reward:
            best_reward = total_reward
            agent.save(best_model_path)
            logger.info(f"New best model saved with reward: {best_reward:.2f}")

        # Save checkpoint every 100 episodes
        if (episode + 1) % 100 == 0:
            checkpoint_path = f'{save_dir}/rl_model_episode_{episode + 1}.pth'
            agent.save(checkpoint_path)
            logger.info(f"Checkpoint saved at episode {episode + 1}")


if __name__ == "__main__":
    train_rl()
|
@ -6,6 +6,8 @@ This package contains utility functions and classes used in the neural network t
|
|||||||
- Data Interface: Connects to realtime trading data and processes it for the neural network models
|
- Data Interface: Connects to realtime trading data and processes it for the neural network models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from NN.utils.data_interface import DataInterface
|
from .data_interface import DataInterface
|
||||||
|
from .trading_env import TradingEnvironment
|
||||||
|
from .signal_interpreter import SignalInterpreter
|
||||||
|
|
||||||
__all__ = ['DataInterface']
|
__all__ = ['DataInterface', 'TradingEnvironment', 'SignalInterpreter']
|
@ -13,6 +13,7 @@ import json
|
|||||||
import pickle
|
import pickle
|
||||||
from sklearn.preprocessing import MinMaxScaler
|
from sklearn.preprocessing import MinMaxScaler
|
||||||
import sys
|
import sys
|
||||||
|
import ta
|
||||||
|
|
||||||
# Add project root to sys.path
|
# Add project root to sys.path
|
||||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
@ -534,3 +535,77 @@ class DataInterface:
|
|||||||
timestamp = df['timestamp'].iloc[-1]
|
timestamp = df['timestamp'].iloc[-1]
|
||||||
|
|
||||||
return X, timestamp
|
return X, timestamp
|
||||||
|
|
||||||
|
def get_training_data(self, timeframe='1m', n_candles=5000):
    """
    Get a consolidated dataframe for RL training with OHLCV and technical indicators

    Args:
        timeframe (str): Timeframe to use
        n_candles (int): Number of candles to fetch

    Returns:
        DataFrame: Combined dataframe with price data and technical indicators,
        with every row containing a NaN removed.  None when fewer than 100
        candles are available or when an indicator calculation fails.
    """
    # Get historical data
    df = self.get_historical_data(timeframe=timeframe, n_candles=n_candles, use_cache=True)

    if df is None or len(df) < 100:  # Minimum required for indicators
        logger.error("Not enough data for RL training (need at least 100 candles)")
        return None

    # Work on a copy so the indicator columns are not written into the frame
    # handed out by get_historical_data (requested with use_cache=True, so it
    # may be an object shared with other callers).
    df = df.copy()

    # Calculate technical indicators
    # NOTE(review): the call style below (ta.rsi(..., length=14), result
    # columns such as 'MACD_12_26_9' / 'BBU_20_2.0') looks like the pandas_ta
    # API — confirm that the module imported as `ta` provides it; if not, the
    # except branch turns every call into a logged error and a None return.
    try:
        # Add RSI (14)
        df['rsi'] = ta.rsi(df['close'], length=14)

        # Add MACD
        macd = ta.macd(df['close'])
        df['macd'] = macd['MACD_12_26_9']
        df['macd_signal'] = macd['MACDs_12_26_9']
        df['macd_hist'] = macd['MACDh_12_26_9']

        # Add Bollinger Bands
        bbands = ta.bbands(df['close'], length=20)
        df['bb_upper'] = bbands['BBU_20_2.0']
        df['bb_middle'] = bbands['BBM_20_2.0']
        df['bb_lower'] = bbands['BBL_20_2.0']

        # Add ATR (Average True Range)
        df['atr'] = ta.atr(df['high'], df['low'], df['close'], length=14)

        # Add moving averages
        df['sma_20'] = ta.sma(df['close'], length=20)
        df['sma_50'] = ta.sma(df['close'], length=50)
        df['ema_20'] = ta.ema(df['close'], length=20)

        # Add OBV (On-Balance Volume)
        df['obv'] = ta.obv(df['close'], df['volume'])

        # Add momentum indicators
        df['mom'] = ta.mom(df['close'], length=10)

        # Normalize price to previous close
        prev_close = df['close'].shift(1)
        df['close_norm'] = df['close'] / prev_close - 1
        df['high_norm'] = df['high'] / prev_close - 1
        df['low_norm'] = df['low'] / prev_close - 1

        # Volatility features
        df['volatility'] = df['high'] / df['low'] - 1

        # Volume features (relative to the 20-candle rolling mean)
        df['volume_norm'] = df['volume'] / df['volume'].rolling(20).mean()

        # Calculate returns
        df['returns_1'] = df['close'].pct_change(1)
        df['returns_5'] = df['close'].pct_change(5)
        df['returns_10'] = df['close'].pct_change(10)

    except Exception as e:
        logger.error(f"Error calculating technical indicators: {str(e)}")
        return None

    # Drop NaN values (the warm-up rows of the rolling/shifted indicators)
    df = df.dropna()

    return df
|
||||||
|
162
NN/utils/trading_env.py
Normal file
162
NN/utils/trading_env.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
import numpy as np
|
||||||
|
import gym
|
||||||
|
from gym import spaces
|
||||||
|
from typing import Dict, Tuple, List
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
class TradingEnvironment(gym.Env):
    """
    Custom trading environment for reinforcement learning

    Each step exposes one row of ``data`` as the observation.  The agent
    chooses between SELL (0), HOLD (1) and BUY (2); a position is opened by
    BUY/SELL while flat and closed by the opposite action.
    """

    def __init__(self,
                 data: pd.DataFrame,
                 initial_balance: float = 100.0,
                 fee_rate: float = 0.0002,
                 max_steps: int = 1000):
        """
        Args:
            data: Frame with one row per step; must contain a 'close' column.
            initial_balance: Balance at the start of every episode.
            fee_rate: Fee per trade side, as a fraction.
            max_steps: Upper bound on the number of steps per episode.
        """
        super(TradingEnvironment, self).__init__()

        self.data = data
        self.initial_balance = initial_balance
        self.fee_rate = fee_rate
        self.max_steps = max_steps

        # Action space: 0 (SELL), 1 (HOLD), 2 (BUY)
        self.action_space = spaces.Discrete(3)

        # Observation space: one value per column of `data`
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(data.shape[1],),  # Number of features
            dtype=np.float32
        )

        # Initialize state
        self.reset()

    def reset(self) -> np.ndarray:
        """Reset the environment to initial state and return the first observation"""
        self.current_step = 0
        self.balance = self.initial_balance
        self.position = 0  # 0: no position, 1: long position, -1: short position
        self.entry_price = 0
        self.total_trades = 0
        self.winning_trades = 0
        self.total_pnl = 0
        self.balance_history = [self.initial_balance]
        self.max_balance = self.initial_balance

        return self._get_observation()

    def _get_observation(self) -> np.ndarray:
        """Get current observation state (the data row at ``current_step``)"""
        return self.data.iloc[self.current_step].values

    def _calculate_reward(self, action: int) -> float:
        """Calculate reward based on action and outcome.

        Also applies the action: opens or closes the position and updates the
        balance and trade statistics accordingly.
        """
        current_price = self.data.iloc[self.current_step]['close']

        # If we have an open position
        if self.position != 0:
            # Calculate PnL as a fraction of the entry price (sign follows the position)
            pnl = self.position * (current_price - self.entry_price) / self.entry_price
            fees = self.fee_rate * 2  # Entry and exit fees

            # Close position: SELL closes a long, BUY closes a short
            if (action == 0 and self.position > 0) or (action == 2 and self.position < 0):
                net_pnl = pnl - fees
                self.total_pnl += net_pnl
                self.balance *= (1 + net_pnl)
                self.balance_history.append(self.balance)
                self.max_balance = max(self.max_balance, self.balance)

                self.total_trades += 1
                if net_pnl > 0:
                    self.winning_trades += 1

                # Reward based on PnL
                reward = net_pnl * 100  # Scale up for better learning

                # Additional reward for win rate
                win_rate = self.winning_trades / max(1, self.total_trades)
                reward += win_rate * 0.1

                self.position = 0
                return reward

            # Hold position
            return pnl * 0.1  # Small reward for holding profitable positions

        # No position
        if action == 1:  # HOLD
            return 0

        # Open new position
        if action in [0, 2]:  # SELL or BUY
            self.position = -1 if action == 0 else 1
            self.entry_price = current_price
            return -self.fee_rate  # Small penalty for trading

        return 0

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
        """Execute one step in the environment.

        Returns:
            (observation, reward, done, info) where ``info`` carries balance,
            position, total_trades, win_rate, total_pnl and max_drawdown.
        """
        # Calculate reward
        reward = self._calculate_reward(action)

        # Move to next step
        self.current_step += 1

        # Check if episode is done
        done = self.current_step >= min(self.max_steps - 1, len(self.data) - 1)

        # Get next observation
        observation = self._get_observation()

        # Additional info
        info = {
            'balance': self.balance,
            'position': self.position,
            'total_trades': self.total_trades,
            'win_rate': self.winning_trades / max(1, self.total_trades),
            'total_pnl': self.total_pnl,
            # Shared helper instead of a second copy of the drawdown loop
            'max_drawdown': self._calculate_max_drawdown()
        }

        return observation, reward, done, info

    def render(self, mode='human'):
        """Render the environment (prints a text summary in 'human' mode)"""
        if mode == 'human':
            print(f"Step: {self.current_step}")
            print(f"Balance: ${self.balance:.2f}")
            print(f"Position: {self.position}")
            print(f"Total Trades: {self.total_trades}")
            print(f"Win Rate: {self.winning_trades/max(1, self.total_trades):.2%}")
            print(f"Total PnL: ${self.total_pnl:.2f}")
            print(f"Max Drawdown: {self._calculate_max_drawdown():.2%}")
            print("-" * 50)

    def _calculate_max_drawdown(self):
        """Calculate maximum drawdown from balance history (as a fraction of the peak)"""
        if len(self.balance_history) <= 1:
            return 0.0

        peak = self.balance_history[0]
        max_drawdown = 0.0

        for balance in self.balance_history:
            peak = max(peak, balance)
            # A non-positive peak has no meaningful relative drawdown and
            # would otherwise divide by zero.
            if peak > 0:
                drawdown = (peak - balance) / peak
                max_drawdown = max(max_drawdown, drawdown)

        return max_drawdown
|
@ -49,3 +49,7 @@ python NN/realtime-main.py --mode train --model-type cnn --framework pytorch --s
|
|||||||
----------
|
----------
|
||||||
$ python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --epochs 10
|
$ python -c "import sys; sys.path.append('f:/projects/gogo2'); from NN.realtime_main import main; main()" --mode train --model-type cnn --epochs 10
|
||||||
python test_model.py
|
python test_model.py
|
||||||
|
|
||||||
|
|
||||||
|
python train_with_realtime_ticks.py
|
||||||
|
python NN/train_rl.py
|
704
train_with_realtime_ticks.py
Normal file
704
train_with_realtime_ticks.py
Normal file
@ -0,0 +1,704 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Real-time training with tick data and multiple timeframes for context
|
||||||
|
This script uses streaming tick data for fast adaptation while maintaining higher timeframe context
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import signal
|
||||||
|
import threading
|
||||||
|
import asyncio
|
||||||
|
import websockets
|
||||||
|
from collections import deque
|
||||||
|
import pandas as pd
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import io
|
||||||
|
|
||||||
|
# Make the current working directory importable (the script is expected to
# be started from the project root).
sys.path.append(os.path.abspath('.'))

# One log file per run: the start time is embedded in the file name.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(
    log_dir,
    "realtime_ticks_training_" + datetime.now().strftime('%Y%m%d_%H%M%S') + ".log",
)

# Log to the run file and to the console at INFO level.
logging.basicConfig(
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('realtime_ticks_training')
|
||||||
|
|
||||||
|
# Import the model and data interfaces
|
||||||
|
from NN.models.cnn_model_pytorch import CNNModelPyTorch
|
||||||
|
from NN.utils.data_interface import DataInterface
|
||||||
|
from NN.utils.signal_interpreter import SignalInterpreter
|
||||||
|
|
||||||
|
# Module-level state shared by the training loop and the SIGINT handler.
running = True  # set to False to request a graceful shutdown

# Running statistics of the training session; written to disk as JSON.
training_stats = {
    "epochs_completed": 0,
    "best_val_pnl": float('-inf'),
    "best_epoch": 0,
    "best_win_rate": 0,
    "training_started": datetime.now().isoformat(),
    "last_update": datetime.now().isoformat(),
    "epochs": [],
    # Per-split accumulators
    "cumulative_pnl": dict(train=0.0, val=0.0),
    "total_trades": dict(train=0, val=0),
    "total_wins": dict(train=0, val=0),
}
|
||||||
|
|
||||||
|
class TickDataProcessor:
    """Process and store real-time tick data.

    Keeps a bounded buffer of raw ticks plus bounded per-timeframe candle
    buffers ('1s', '1m', '5m', '15m').  Ticks received over the WebSocket are
    appended to ``ticks`` and pushed onto ``data_queue``; candles are built by
    calling ``process_tick``.
    """

    # Seconds per timeframe unit suffix.
    _UNIT_SECONDS = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}

    def __init__(self, symbol: str, max_ticks: int = 10000):
        """
        Args:
            symbol: Trading pair such as "BTC/USDT"; the slash is removed and
                the result lower-cased to build the stream URL.
            max_ticks: Maximum number of raw ticks kept in memory.
        """
        self.symbol = symbol
        self.ticks = deque(maxlen=max_ticks)
        self.candle_cache = {
            '1s': deque(maxlen=5000),
            '1m': deque(maxlen=5000),
            '5m': deque(maxlen=5000),
            '15m': deque(maxlen=5000)
        }
        self.last_tick = None
        self.ws_url = f"wss://stream.binance.com:9443/ws/{symbol.replace('/', '').lower()}@trade"
        self.ws = None
        self.running = False
        self.data_queue = asyncio.Queue()

    async def start_websocket(self):
        """Start WebSocket connection and receive tick data.

        Reconnects (after a 5 second pause on connection errors) for as long
        as ``self.running`` is true.
        """
        while self.running:
            try:
                async with websockets.connect(self.ws_url) as ws:
                    self.ws = ws
                    logger.info("WebSocket connected")

                    while self.running:
                        try:
                            message = await ws.recv()
                            data = json.loads(message)

                            if 'e' in data and data['e'] == 'trade':
                                tick = {
                                    'timestamp': data['T'],
                                    'price': float(data['p']),
                                    'volume': float(data['q']),
                                    'symbol': self.symbol
                                }
                                self.ticks.append(tick)
                                await self.data_queue.put(tick)

                        except websockets.exceptions.ConnectionClosed:
                            # Leave the receive loop; the outer loop reconnects.
                            logger.warning("WebSocket connection closed")
                            break
                        except Exception as e:
                            # Best effort: a bad message must not kill the stream.
                            logger.error(f"Error receiving WebSocket message: {str(e)}")
                            await asyncio.sleep(1)

            except Exception as e:
                logger.error(f"WebSocket connection error: {str(e)}")
                await asyncio.sleep(5)

    @staticmethod
    def _new_candle(candle_ts: int, price: float, volume: float) -> Dict:
        """Build a candle that so far consists of a single tick."""
        return {
            'timestamp': candle_ts,
            'open': price,
            'high': price,
            'low': price,
            'close': price,
            'volume': volume
        }

    def process_tick(self, tick: Dict):
        """Process a single tick into candles for all timeframes.

        Args:
            tick: Mapping with 'timestamp' (milliseconds), 'price' and 'volume'.
        """
        timestamp = tick['timestamp']
        price = tick['price']
        volume = tick['volume']

        for timeframe, candles in self.candle_cache.items():
            interval = self._get_interval_seconds(timeframe)
            if interval is None:
                continue

            # Floor the timestamp to the start of its candle interval (ms)
            interval_ms = interval * 1000
            candle_ts = int(timestamp // interval_ms) * interval_ms

            if candles and candles[-1]['timestamp'] == candle_ts:
                # Tick belongs to the newest candle: update it in place
                last_candle = candles[-1]
                last_candle['high'] = max(last_candle['high'], price)
                last_candle['low'] = min(last_candle['low'], price)
                last_candle['close'] = price
                last_candle['volume'] += volume
            else:
                # First candle for this timeframe, or a new interval started
                candles.append(self._new_candle(candle_ts, price, volume))

    def _get_interval_seconds(self, timeframe: str) -> Optional[int]:
        """Convert timeframe string to seconds.

        Returns:
            The interval length in seconds, or None when the string cannot be
            parsed or uses an unknown unit.
        """
        try:
            value = int(timeframe[:-1])
            unit = timeframe[-1]
        except (ValueError, TypeError, IndexError):
            return None

        seconds_per_unit = self._UNIT_SECONDS.get(unit)
        if seconds_per_unit is None:
            return None
        return value * seconds_per_unit

    def get_candles(self, timeframe: str) -> pd.DataFrame:
        """Get candles for a specific timeframe.

        Returns:
            A frame with one row per candle and 'timestamp' converted from
            milliseconds to datetime; an empty frame for an unknown timeframe
            or when no candle exists yet.
        """
        if timeframe not in self.candle_cache:
            return pd.DataFrame()

        candles = list(self.candle_cache[timeframe])
        if not candles:
            return pd.DataFrame()

        df = pd.DataFrame(candles)
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        return df
|
||||||
|
|
||||||
|
def signal_handler(sig, frame):
    """Request a graceful stop of training when an interrupt (CTRL+C) arrives.

    Only clears the module-level ``running`` flag and logs the request; both
    arguments supplied by the signal machinery are ignored.
    """
    global running
    running = False
    logger.info("Received interrupt signal. Finishing current epoch and saving model...")


# Install the handler for SIGINT at import time.
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
|
def save_training_stats(stats, filepath="NN/models/saved/realtime_ticks_training_stats.json"):
    """Save training statistics to file.

    Args:
        stats: JSON-serializable statistics object.
        filepath: Destination file; its parent directory is created when
            missing.
    """
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when the path actually has a directory component.
    target_dir = os.path.dirname(filepath)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    with open(filepath, 'w') as f:
        json.dump(stats, f, indent=2)

    logger.info(f"Training statistics saved to {filepath}")
|
||||||
|
|
||||||
|
def calculate_pnl_with_fees(predictions, prices, position_size=1.0, fee_rate=0.0002, initial_balance=100.0):
    """
    Simulate long round trips from class predictions, net of trading fees.

    Every BUY (2) prediction is paired with the first SELL (0) prediction
    found among the following positions (at most 7 ahead, bounded by the
    length of ``prices``).  BUYs without such a SELL produce no trade.

    predictions: Sequence of class labels (0 = SELL, 1 = HOLD, 2 = BUY)
    prices: Sequence of prices aligned with ``predictions``
    position_size: Fraction of the current balance committed per trade
    fee_rate: Fee per side as a fraction (default 0.0002 = 0.02%), charged
        on both entry and exit
    initial_balance: Starting balance in USD (default: 100.0)

    Returns a tuple ``(pnl, win_rate, trades, balance_history, total_return)``:
    ``pnl`` is the sum of net trade results divided by ``initial_balance``,
    ``win_rate`` the share of trades with positive net result, ``trades`` a
    list of per-trade dicts, ``balance_history`` the balance after each trade
    (starting with ``initial_balance``) and ``total_return`` the overall
    return in percent.
    """
    lookahead_span = 8
    trades = []
    balance_history = [initial_balance]
    current_balance = initial_balance
    pnl = 0

    # Index-based access is kept on purpose: it matches how the inputs are
    # addressed everywhere else in this function.
    for entry_idx in range(len(predictions)):
        if predictions[entry_idx] != 2:  # only BUY opens a trade
            continue

        entry_price = prices[entry_idx]

        # First SELL inside the look-ahead window, if any
        window_end = min(entry_idx + lookahead_span, len(prices))
        exit_idx = next(
            (k for k in range(entry_idx + 1, window_end) if predictions[k] == 0),
            None,
        )
        if exit_idx is None:
            continue

        exit_price = prices[exit_idx]

        # Position size and raw result in USD
        position_usd = current_balance * position_size
        raw_pnl_usd = position_usd * ((exit_price - entry_price) / entry_price)

        # Fees in USD: the same amount on entry and on exit
        fee_per_side_usd = position_usd * fee_rate
        total_fees_usd = fee_per_side_usd + fee_per_side_usd

        # Net result after fees, applied to the balance
        net_pnl_usd = raw_pnl_usd - total_fees_usd
        current_balance += net_pnl_usd
        balance_history.append(current_balance)

        trades.append({
            'entry_idx': entry_idx,
            'exit_idx': exit_idx,
            'entry_price': entry_price,
            'exit_price': exit_price,
            'position_size_usd': position_usd,
            'raw_pnl_usd': raw_pnl_usd,
            'fees_usd': total_fees_usd,
            'net_pnl_usd': net_pnl_usd,
            'balance': current_balance
        })

        pnl += net_pnl_usd / initial_balance  # relative to the starting balance

    total_trades = len(trades)
    win_count = sum(1 for trade in trades if trade['net_pnl_usd'] > 0)
    win_rate = win_count / total_trades if total_trades > 0 else 0

    final_balance = current_balance
    total_return = (final_balance - initial_balance) / initial_balance * 100  # Percentage return

    return pnl, win_rate, trades, balance_history, total_return
|
||||||
|
|
||||||
|
def calculate_max_drawdown(balance_history):
    """Calculate maximum drawdown from balance history.

    Args:
        balance_history: Sequence of balances in chronological order.

    Returns:
        The largest decline from a running peak, in percent of that peak
        (e.g. 25.0 for a drop from 120 to 90).  0.0 for an empty history.
    """
    if not balance_history:
        return 0.0

    peak = balance_history[0]
    max_drawdown = 0.0

    for balance in balance_history:
        if balance > peak:
            peak = balance
        # A relative drawdown is undefined for a non-positive peak; skipping
        # it also avoids dividing by zero when the history starts at 0.
        if peak > 0:
            drawdown = (peak - balance) / peak * 100
            max_drawdown = max(max_drawdown, drawdown)

    return max_drawdown
|
||||||
|
|
||||||
|
async def run_realtime_training():
    """
    Run continuous training with real-time tick data and multiple timeframes.

    Starts a ``TickDataProcessor`` WebSocket task in the background, builds a
    ``DataInterface`` and a ``CNNModelPyTorch`` for BTC/USDT, then loops while
    the module-level ``running`` flag is truthy. Each epoch it:

    * drains queued ticks into the tick processor,
    * prepares training/validation data and trains/evaluates one epoch,
    * simulates PnL (with fees) for several position sizes and logs the
      results to the logger and to TensorBoard,
    * saves the model whenever validation PnL at position size 1.0 improves,
    * saves a checkpoint plus the training stats once per checkpoint interval.

    The loop stops after ``max_consecutive_failures`` consecutive failed
    epochs. On normal exit the final model, the training stats and the
    performance plots are saved. On an unexpected error an emergency
    model/stats save is attempted. The WebSocket task is cancelled and the
    TensorBoard writer is closed on every exit path.

    Side effects: mutates the module-level ``training_stats`` dict and writes
    model files, stats files and TensorBoard logs to disk.
    """
    import traceback

    global running, training_stats

    # Configuration parameters
    symbol = "BTC/USDT"
    timeframes = ["1s", "1m", "5m", "15m"]  # Include 1s for tick-based training
    window_size = 24  # Larger window size for capturing more patterns
    output_size = 3  # BUY/HOLD/SELL
    batch_size = 64  # Batch size for training

    # Real-time configuration
    data_refresh_interval = 60  # Refresh data every minute
    checkpoint_interval = 3600  # Save checkpoint every hour

    # Initialize TensorBoard writer
    tensorboard_dir = "runs/realtime_ticks_training"
    os.makedirs(tensorboard_dir, exist_ok=True)
    writer = SummaryWriter(tensorboard_dir)

    # Initialize training start time
    start_time = time.time()
    last_checkpoint_time = start_time
    last_data_refresh_time = start_time

    logger.info(f"Starting continuous real-time training with tick data for {symbol}")
    logger.info(f"Configuration: timeframes={timeframes}, window_size={window_size}, batch_size={batch_size}")
    logger.info(f"Data will refresh every {data_refresh_interval} seconds")
    logger.info(f"Checkpoints will be saved every {checkpoint_interval} seconds")
    logger.info(f"TensorBoard logs will be saved to {tensorboard_dir}")

    # Pre-declared so the error/cleanup paths can tell how far setup got.
    tick_processor = None
    websocket_task = None
    model = None

    try:
        # Initialize tick data processor
        tick_processor = TickDataProcessor(symbol)
        tick_processor.running = True

        # Start WebSocket connection in background
        websocket_task = asyncio.create_task(tick_processor.start_websocket())

        # Initialize data interface
        logger.info("Initializing data interface...")
        data_interface = DataInterface(
            symbol=symbol,
            timeframes=timeframes
        )

        # Initialize model
        num_features = data_interface.get_feature_count()
        logger.info(f"Initializing model with {num_features} features")

        model = CNNModelPyTorch(
            window_size=window_size,
            num_features=num_features,
            output_size=output_size,
            timeframes=timeframes
        )

        # Try to load existing model.
        # CNNModelPyTorch.load() appends ".pt" to the path itself, so it must be
        # given the base path; passing "..._best.pt" made it look for
        # "..._best.pt.pt" and silently fail.
        model_base_path = "NN/models/saved/optimized_short_term_model_best"
        model_path = f"{model_base_path}.pt"
        try:
            if os.path.exists(model_path):
                logger.info(f"Loading existing model from {model_path}")
                # load() returns False when it cannot find the file; only an
                # explicit False is treated as a failure here.
                if model.load(model_base_path) is False:
                    logger.error(f"Error loading model: load() failed for {model_path}")
                    logger.info("Starting with a new model.")
                else:
                    logger.info("Model loaded successfully")
            else:
                logger.info("No existing model found. Starting with a new model.")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            logger.info("Starting with a new model.")

        # Initialize signal interpreter
        signal_interpreter = SignalInterpreter(config={
            'buy_threshold': 0.55,  # Lower threshold to catch more opportunities
            'sell_threshold': 0.55,  # Lower threshold to catch more opportunities
            'hold_threshold': 0.65,  # Lower threshold to reduce missed trades
            'trend_filter_enabled': True,
            'volume_filter_enabled': True,
            'min_confidence': 0.45  # Minimum confidence to consider a trade
        })

        # Create checkpoint directory
        checkpoint_dir = "NN/models/saved/realtime_ticks_checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Track metrics
        epoch = 0
        best_val_pnl = -float('inf')
        best_win_rate = 0
        best_epoch = 0
        consecutive_failures = 0
        max_consecutive_failures = 5

        # Training loop
        while running:
            try:
                epoch += 1
                epoch_start = time.time()

                logger.info(f"Epoch {epoch} - Starting at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

                # Process any new ticks
                while not tick_processor.data_queue.empty():
                    tick = await tick_processor.data_queue.get()
                    tick_processor.process_tick(tick)

                # Check if we need to refresh data
                if time.time() - last_data_refresh_time > data_refresh_interval:
                    logger.info("Refreshing training data...")
                    last_data_refresh_time = time.time()

                    # Get candles for all timeframes
                    # NOTE(review): candles_data is collected but not consumed
                    # anywhere in this function - confirm whether it should be
                    # passed to the data interface.
                    candles_data = {}
                    for timeframe in timeframes:
                        df = tick_processor.get_candles(timeframe)
                        if not df.empty:
                            candles_data[timeframe] = df

                # Prepare training data with multiple timeframes.
                # NOTE(review): this runs every epoch (not only on refresh);
                # otherwise X_train would be undefined on the first epoch,
                # because last_data_refresh_time starts at start_time.
                X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
                    refresh=True,
                    refresh_interval=data_refresh_interval
                )

                if X_train is None or y_train is None:
                    logger.warning("Failed to prepare training data. Using previous data.")
                    consecutive_failures += 1
                    if consecutive_failures >= max_consecutive_failures:
                        logger.error("Too many consecutive failures. Stopping training.")
                        break
                    await asyncio.sleep(5)  # Wait before retrying
                    continue

                consecutive_failures = 0  # Reset failure counter on success
                logger.info(f"Training data prepared - X shape: {X_train.shape}, y shape: {y_train.shape}")

                # Calculate future prices for profitability-focused loss function
                train_future_prices = data_interface.get_future_prices(train_prices, n_candles=8)
                val_future_prices = data_interface.get_future_prices(val_prices, n_candles=8)

                # Train one epoch
                train_action_loss, train_price_loss, train_acc = model.train_epoch(
                    X_train, y_train, train_future_prices, batch_size
                )

                # Evaluate
                val_action_loss, val_price_loss, val_acc = model.evaluate(
                    X_val, y_val, val_future_prices
                )

                logger.info(f"Epoch {epoch} results:")
                logger.info(f" Train - Loss: {train_action_loss:.4f}, Accuracy: {train_acc:.4f}")
                logger.info(f" Valid - Loss: {val_action_loss:.4f}, Accuracy: {val_acc:.4f}")

                # Get predictions for PnL calculation
                train_action_probs, train_price_preds = model.predict(X_train)
                val_action_probs, val_price_preds = model.predict(X_val)

                # Convert probabilities to actions
                train_preds = np.argmax(train_action_probs, axis=1)
                val_preds = np.argmax(val_action_probs, axis=1)

                # Track signal distribution (class ids as used below:
                # 0 = SELL, 1 = HOLD, 2 = BUY)
                train_buy_count = np.sum(train_preds == 2)
                train_sell_count = np.sum(train_preds == 0)
                train_hold_count = np.sum(train_preds == 1)

                val_buy_count = np.sum(val_preds == 2)
                val_sell_count = np.sum(val_preds == 0)
                val_hold_count = np.sum(val_preds == 1)

                n_train = len(train_preds)
                n_val = len(val_preds)
                signal_dist = {
                    "train": {
                        "BUY": float(train_buy_count / n_train) if n_train > 0 else 0,
                        "SELL": float(train_sell_count / n_train) if n_train > 0 else 0,
                        "HOLD": float(train_hold_count / n_train) if n_train > 0 else 0
                    },
                    "val": {
                        "BUY": float(val_buy_count / n_val) if n_val > 0 else 0,
                        "SELL": float(val_sell_count / n_val) if n_val > 0 else 0,
                        "HOLD": float(val_hold_count / n_val) if n_val > 0 else 0
                    }
                }

                # Calculate PnL and win rates with different position sizes
                position_sizes = [0.1, 0.25, 0.5, 1.0, 2.0]
                best_position_train_pnl = -float('inf')
                best_position_val_pnl = -float('inf')
                best_position_train_wr = 0
                best_position_val_wr = 0
                best_position_size = 1.0

                for position_size in position_sizes:
                    train_pnl, train_win_rate, train_trades, train_balance_history, train_total_return = calculate_pnl_with_fees(
                        train_preds, train_prices, position_size=position_size
                    )
                    val_pnl, val_win_rate, val_trades, val_balance_history, val_total_return = calculate_pnl_with_fees(
                        val_preds, val_prices, position_size=position_size
                    )

                    # Update cumulative PnL and trade statistics
                    # NOTE(review): these totals are accumulated once per
                    # position size, so every epoch contributes several times -
                    # confirm this is intended.
                    training_stats["cumulative_pnl"]["train"] += train_pnl
                    training_stats["cumulative_pnl"]["val"] += val_pnl
                    training_stats["total_trades"]["train"] += len(train_trades)
                    training_stats["total_trades"]["val"] += len(val_trades)
                    training_stats["total_wins"]["train"] += sum(1 for t in train_trades if t['net_pnl_usd'] > 0)
                    training_stats["total_wins"]["val"] += sum(1 for t in val_trades if t['net_pnl_usd'] > 0)

                    # Calculate average fees per trade
                    avg_train_fees = np.mean([t['fees_usd'] for t in train_trades]) if train_trades else 0
                    avg_val_fees = np.mean([t['fees_usd'] for t in val_trades]) if val_trades else 0

                    # Calculate max drawdown
                    train_drawdown = calculate_max_drawdown(train_balance_history)
                    val_drawdown = calculate_max_drawdown(val_balance_history)

                    # Calculate overall win rate
                    overall_train_wr = training_stats["total_wins"]["train"] / training_stats["total_trades"]["train"] if training_stats["total_trades"]["train"] > 0 else 0
                    overall_val_wr = training_stats["total_wins"]["val"] / training_stats["total_trades"]["val"] if training_stats["total_trades"]["val"] > 0 else 0

                    logger.info(f" Position Size: {position_size}")
                    logger.info(f" Train - PnL: {train_pnl:.4f}, Win Rate: {train_win_rate:.4f}, Trades: {len(train_trades)}")
                    logger.info(f" Train - Total Return: {train_total_return:.2f}%, Max Drawdown: {train_drawdown:.2f}%")
                    logger.info(f" Train - Avg Fees: ${avg_train_fees:.2f} per trade")
                    logger.info(f" Train - Cumulative PnL: {training_stats['cumulative_pnl']['train']:.4f}, Overall WR: {overall_train_wr:.4f}")
                    logger.info(f" Valid - PnL: {val_pnl:.4f}, Win Rate: {val_win_rate:.4f}, Trades: {len(val_trades)}")
                    logger.info(f" Valid - Total Return: {val_total_return:.2f}%, Max Drawdown: {val_drawdown:.2f}%")
                    logger.info(f" Valid - Avg Fees: ${avg_val_fees:.2f} per trade")
                    logger.info(f" Valid - Cumulative PnL: {training_stats['cumulative_pnl']['val']:.4f}, Overall WR: {overall_val_wr:.4f}")

                    # Log to TensorBoard
                    writer.add_scalar(f'Balance/train/position_{position_size}', train_balance_history[-1], epoch)
                    writer.add_scalar(f'Balance/validation/position_{position_size}', val_balance_history[-1], epoch)
                    writer.add_scalar(f'Return/train/position_{position_size}', train_total_return, epoch)
                    writer.add_scalar(f'Return/validation/position_{position_size}', val_total_return, epoch)
                    writer.add_scalar(f'Drawdown/train/position_{position_size}', train_drawdown, epoch)
                    writer.add_scalar(f'Drawdown/validation/position_{position_size}', val_drawdown, epoch)
                    writer.add_scalar(f'CumulativePnL/train/position_{position_size}', training_stats["cumulative_pnl"]["train"], epoch)
                    writer.add_scalar(f'CumulativePnL/validation/position_{position_size}', training_stats["cumulative_pnl"]["val"], epoch)
                    writer.add_scalar(f'OverallWinRate/train/position_{position_size}', overall_train_wr, epoch)
                    writer.add_scalar(f'OverallWinRate/validation/position_{position_size}', overall_val_wr, epoch)

                    # Track best position size for this epoch
                    if val_pnl > best_position_val_pnl:
                        best_position_val_pnl = val_pnl
                        best_position_val_wr = val_win_rate
                        best_position_size = position_size

                    if train_pnl > best_position_train_pnl:
                        best_position_train_pnl = train_pnl
                        best_position_train_wr = train_win_rate

                    # Track best model overall (using position size 1.0 as reference)
                    if val_pnl > best_val_pnl and position_size == 1.0:
                        best_val_pnl = val_pnl
                        best_win_rate = val_win_rate
                        best_epoch = epoch
                        logger.info(f" New best validation PnL: {best_val_pnl:.4f} at epoch {best_epoch}")

                        # Save the best model
                        model.save("NN/models/saved/optimized_short_term_model_ticks_best")

                # Store epoch metrics with cumulative statistics
                epoch_metrics = {
                    "epoch": epoch,
                    "train_loss": float(train_action_loss),
                    "val_loss": float(val_action_loss),
                    "train_acc": float(train_acc),
                    "val_acc": float(val_acc),
                    "train_pnl": float(best_position_train_pnl),
                    "val_pnl": float(best_position_val_pnl),
                    "train_win_rate": float(best_position_train_wr),
                    "val_win_rate": float(best_position_val_wr),
                    "best_position_size": float(best_position_size),
                    "signal_distribution": signal_dist,
                    "timestamp": datetime.now().isoformat(),
                    "data_age": int(time.time() - last_data_refresh_time),
                    "cumulative_pnl": {
                        "train": float(training_stats["cumulative_pnl"]["train"]),
                        "val": float(training_stats["cumulative_pnl"]["val"])
                    },
                    "total_trades": {
                        "train": int(training_stats["total_trades"]["train"]),
                        "val": int(training_stats["total_trades"]["val"])
                    },
                    "overall_win_rate": {
                        "train": float(overall_train_wr),
                        "val": float(overall_val_wr)
                    }
                }

                # Update training stats
                training_stats["epochs_completed"] = epoch
                training_stats["best_val_pnl"] = float(best_val_pnl)
                training_stats["best_epoch"] = best_epoch
                training_stats["best_win_rate"] = float(best_win_rate)
                training_stats["last_update"] = datetime.now().isoformat()
                training_stats["epochs"].append(epoch_metrics)

                # Check if we need to save checkpoint
                if time.time() - last_checkpoint_time > checkpoint_interval:
                    logger.info(f"Saving checkpoint at epoch {epoch}")
                    # Save model checkpoint
                    model.save(f"{checkpoint_dir}/checkpoint_epoch_{epoch}")
                    # Save training statistics
                    save_training_stats(training_stats)
                    last_checkpoint_time = time.time()

                # Test trade signal generation with a random sample.
                # Guarded: np.random.randint(0, 0) raises on an empty
                # validation set, which would fail the whole epoch.
                if len(X_val) > 0:
                    random_idx = np.random.randint(0, len(X_val))
                    sample_X = X_val[random_idx:random_idx+1]
                    sample_probs, sample_price_pred = model.predict(sample_X)

                    # Process with signal interpreter
                    # NOTE(review): the else-branch below still indexes the
                    # prediction, so it cannot succeed for an object without
                    # __getitem__ - confirm the intended 1-D vs 2-D handling.
                    signal = signal_interpreter.interpret_signal(
                        sample_probs[0],
                        float(sample_price_pred[0][0]) if hasattr(sample_price_pred, "__getitem__") else float(sample_price_pred[0]),
                        market_data={'price': float(val_prices[random_idx]) if random_idx < len(val_prices) else 50000.0}
                    )

                    logger.info(f" Sample trade signal: {signal['action']} with confidence {signal['confidence']:.4f}")

                # Log trading statistics
                logger.info(f" Train - Actions: BUY={train_buy_count}, SELL={train_sell_count}, HOLD={train_hold_count}")
                logger.info(f" Valid - Actions: BUY={val_buy_count}, SELL={val_sell_count}, HOLD={val_hold_count}")

                # Log epoch timing
                epoch_time = time.time() - epoch_start
                total_elapsed = time.time() - start_time

                logger.info(f" Epoch completed in {epoch_time:.2f} seconds")
                logger.info(f" Total training time: {total_elapsed/3600:.2f} hours")

                # Small delay to prevent CPU overload
                await asyncio.sleep(0.1)

            except Exception as e:
                logger.error(f"Error during epoch {epoch}: {str(e)}")
                logger.error(traceback.format_exc())
                consecutive_failures += 1
                if consecutive_failures >= max_consecutive_failures:
                    logger.error("Too many consecutive failures. Stopping training.")
                    break
                await asyncio.sleep(5)  # Wait before retrying
                continue

        # Cleanup: stop the tick stream before the final save
        tick_processor.running = False
        websocket_task.cancel()

        # Save final model and performance metrics
        logger.info("Saving final optimized model...")
        model.save("NN/models/saved/optimized_short_term_model_ticks_final")

        # Save performance metrics to file
        save_training_stats(training_stats)

        # Generate performance plots
        try:
            model.plot_training_history("NN/models/saved/realtime_ticks_training_stats.json")
            logger.info("Performance plots generated successfully")
        except Exception as e:
            logger.error(f"Error generating plots: {str(e)}")

        # Calculate total training time
        total_time = time.time() - start_time
        hours, remainder = divmod(total_time, 3600)
        minutes, seconds = divmod(remainder, 60)

        logger.info(f"Continuous training completed in {int(hours)}h {int(minutes)}m {int(seconds)}s")
        logger.info(f"Best model performance - Epoch: {best_epoch}, PnL: {best_val_pnl:.4f}, Win Rate: {best_win_rate:.4f}")

    except Exception as e:
        logger.error(f"Error during real-time training: {str(e)}")
        logger.error(traceback.format_exc())

        # Try to save the model and stats in case of error
        try:
            if model is not None:
                model.save("NN/models/saved/optimized_short_term_model_ticks_emergency")
                logger.info("Emergency model save completed")
            # training_stats is a module-level global, so it never appears in
            # locals(); save it unconditionally instead of testing locals().
            save_training_stats(training_stats, "NN/models/saved/realtime_ticks_training_stats_emergency.json")
        except Exception as e2:
            logger.error(f"Failed to save emergency checkpoint: {str(e2)}")

    finally:
        # Runs on every exit path so the background WebSocket task and the
        # TensorBoard writer are released even after an error.
        if tick_processor is not None:
            tick_processor.running = False
        if websocket_task is not None:
            websocket_task.cancel()

        # Close TensorBoard writer
        writer.close()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Show the startup banner line by line, framed by 80-character rules.
    rule = "=" * 80
    banner = (
        rule,
        "CONTINUOUS REALTIME TICKS TRAINING SESSION",
        "This script will continuously train the model using real-time tick data",
        "Press Ctrl+C to safely stop training and save the model",
        "TensorBoard logs will be saved to runs/realtime_ticks_training",
        "To view TensorBoard, run: tensorboard --logdir=runs/realtime_ticks_training",
        rule,
    )
    for banner_line in banner:
        print(banner_line)

    # Hand control to the async training loop until it returns.
    asyncio.run(run_realtime_training())
|
Loading…
x
Reference in New Issue
Block a user