RL training
This commit is contained in:
@ -475,7 +475,7 @@ class CNNModelPyTorch:
|
||||
diversity_weight * diversity_loss)
|
||||
|
||||
return total_loss, action_loss, price_loss
|
||||
|
||||
|
||||
def train_epoch(self, X_train, y_train, future_prices, batch_size):
|
||||
"""Train the model for one epoch with focus on short-term pattern recognition"""
|
||||
self.model.train()
|
||||
@ -919,13 +919,7 @@ class CNNModelPyTorch:
|
||||
logger.info(f"Backup saved to {backup_path}")
|
||||
|
||||
def load(self, filepath):
|
||||
"""
|
||||
Load the model from a file.
|
||||
|
||||
Args:
|
||||
filepath: Path to load the model from
|
||||
"""
|
||||
# Check if file exists
|
||||
"""Load model weights from file"""
|
||||
if not os.path.exists(f"{filepath}.pt"):
|
||||
logger.error(f"Model file {filepath}.pt not found")
|
||||
return False
|
||||
@ -938,27 +932,20 @@ class CNNModelPyTorch:
|
||||
self.window_size = model_state['window_size']
|
||||
self.num_features = model_state['num_features']
|
||||
self.output_size = model_state['output_size']
|
||||
self.timeframes = model_state['timeframes']
|
||||
self.timeframes = model_state.get('timeframes', ["1m"])
|
||||
|
||||
# Load model state dict
|
||||
self.load_state_dict(model_state['model_state_dict'])
|
||||
|
||||
# Load optimizer state if available
|
||||
if 'optimizer_state_dict' in model_state:
|
||||
self.optimizer.load_state_dict(model_state['optimizer_state_dict'])
|
||||
|
||||
# Load trading configuration if available
|
||||
if 'confidence_threshold' in model_state:
|
||||
self.confidence_threshold = model_state['confidence_threshold']
|
||||
if 'max_consecutive_same_action' in model_state:
|
||||
self.max_consecutive_same_action = model_state['max_consecutive_same_action']
|
||||
if 'action_counts' in model_state:
|
||||
self.action_counts = model_state['action_counts']
|
||||
if 'last_actions' in model_state:
|
||||
self.last_actions = model_state['last_actions']
|
||||
|
||||
# Rebuild the model
|
||||
self.build_model()
|
||||
|
||||
# Load the model state
|
||||
self.model.load_state_dict(model_state['model_state_dict'])
|
||||
self.optimizer.load_state_dict(model_state['optimizer_state_dict'])
|
||||
self.history = model_state['history']
|
||||
|
||||
logger.info(f"Model loaded from {filepath}.pt")
|
||||
|
||||
# Log model version information if available
|
||||
if 'model_version' in model_state:
|
||||
@ -973,7 +960,7 @@ class CNNModelPyTorch:
|
||||
|
||||
def plot_training_history(self, metrics_file="NN/models/saved/training_metrics.json"):
|
||||
"""
|
||||
Generate comprehensive performance visualization plots from training history
|
||||
Plot training history from saved metrics.
|
||||
|
||||
Args:
|
||||
metrics_file: Path to the saved metrics JSON file
|
||||
@ -983,253 +970,72 @@ class CNNModelPyTorch:
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
# Create directory for plots
|
||||
plots_dir = "NN/models/saved/performance_plots"
|
||||
os.makedirs(plots_dir, exist_ok=True)
|
||||
|
||||
# Load metrics
|
||||
with open(metrics_file, 'r') as f:
|
||||
metrics = json.load(f)
|
||||
|
||||
epochs = metrics["epoch"]
|
||||
|
||||
# Set default style for better visualization
|
||||
plt.style.use('seaborn-darkgrid')
|
||||
# Create plots directory
|
||||
plots_dir = os.path.join(os.path.dirname(metrics_file), 'plots')
|
||||
os.makedirs(plots_dir, exist_ok=True)
|
||||
|
||||
# Convert timestamps to datetime objects
|
||||
timestamps = [datetime.fromisoformat(ts) for ts in metrics['timestamps']]
|
||||
|
||||
# 1. Plot Loss and Accuracy
|
||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
|
||||
|
||||
# Loss plot
|
||||
ax1.plot(epochs, metrics["train_loss"], 'b-', label='Training Loss')
|
||||
ax1.plot(epochs, metrics["val_loss"], 'r-', label='Validation Loss')
|
||||
ax1.set_title('Model Loss over Epochs', fontsize=16)
|
||||
ax1.set_ylabel('Loss', fontsize=14)
|
||||
ax1.legend(loc='upper right', fontsize=12)
|
||||
ax1.plot(timestamps, metrics['train_loss'], 'b-', label='Training Loss')
|
||||
ax1.plot(timestamps, metrics['val_loss'], 'r-', label='Validation Loss')
|
||||
ax1.set_title('Model Loss Over Time')
|
||||
ax1.set_ylabel('Loss')
|
||||
ax1.legend()
|
||||
ax1.grid(True)
|
||||
|
||||
# Accuracy plot
|
||||
ax2.plot(epochs, metrics["train_acc"], 'b-', label='Training Accuracy')
|
||||
ax2.plot(epochs, metrics["val_acc"], 'r-', label='Validation Accuracy')
|
||||
ax2.set_title('Model Accuracy over Epochs', fontsize=16)
|
||||
ax2.set_xlabel('Epoch', fontsize=14)
|
||||
ax2.set_ylabel('Accuracy', fontsize=14)
|
||||
ax2.legend(loc='lower right', fontsize=12)
|
||||
ax2.plot(timestamps, metrics['train_acc'], 'g-', label='Training Accuracy')
|
||||
ax2.plot(timestamps, metrics['val_acc'], 'm-', label='Validation Accuracy')
|
||||
ax2.set_title('Model Accuracy Over Time')
|
||||
ax2.set_ylabel('Accuracy')
|
||||
ax2.legend()
|
||||
ax2.grid(True)
|
||||
|
||||
# Format x-axis
|
||||
ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
# Save the plot
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{plots_dir}/loss_accuracy.png", dpi=300)
|
||||
plt.savefig(os.path.join(plots_dir, 'loss_accuracy.png'))
|
||||
plt.close()
|
||||
|
||||
# 2. Plot PnL and Win Rate
|
||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
|
||||
|
||||
# PnL plot
|
||||
ax1.plot(epochs, metrics["train_pnl"], 'g-', label='Training PnL')
|
||||
ax1.plot(epochs, metrics["val_pnl"], 'm-', label='Validation PnL')
|
||||
ax1.set_title('Trading Profit and Loss over Epochs', fontsize=16)
|
||||
ax1.set_ylabel('PnL', fontsize=14)
|
||||
ax1.legend(loc='upper left', fontsize=12)
|
||||
ax1.plot(timestamps, metrics['train_pnl'], 'g-', label='Training PnL')
|
||||
ax1.plot(timestamps, metrics['val_pnl'], 'r-', label='Validation PnL')
|
||||
ax1.set_title('PnL Over Time')
|
||||
ax1.set_ylabel('PnL')
|
||||
ax1.legend()
|
||||
ax1.grid(True)
|
||||
|
||||
# Win Rate plot
|
||||
ax2.plot(epochs, metrics["train_win_rate"], 'g-', label='Training Win Rate')
|
||||
ax2.plot(epochs, metrics["val_win_rate"], 'm-', label='Validation Win Rate')
|
||||
ax2.set_title('Trading Win Rate over Epochs', fontsize=16)
|
||||
ax2.set_xlabel('Epoch', fontsize=14)
|
||||
ax2.set_ylabel('Win Rate', fontsize=14)
|
||||
ax2.axhline(y=0.5, color='r', linestyle='--', label='50% Threshold')
|
||||
ax2.legend(loc='lower right', fontsize=12)
|
||||
ax2.plot(timestamps, metrics['train_win_rate'], 'b-', label='Training Win Rate')
|
||||
ax2.plot(timestamps, metrics['val_win_rate'], 'm-', label='Validation Win Rate')
|
||||
ax2.set_title('Win Rate Over Time')
|
||||
ax2.set_ylabel('Win Rate')
|
||||
ax2.legend()
|
||||
ax2.grid(True)
|
||||
|
||||
# Format x-axis
|
||||
ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
# Save the plot
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{plots_dir}/pnl_winrate.png", dpi=300)
|
||||
plt.close()
|
||||
|
||||
# 3. Plot Signal Distribution over time
|
||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
|
||||
|
||||
# Training Signal Distribution
|
||||
buy_train = [epoch_dist["train"]["BUY"] for epoch_dist in metrics["signal_distribution"]]
|
||||
sell_train = [epoch_dist["train"]["SELL"] for epoch_dist in metrics["signal_distribution"]]
|
||||
hold_train = [epoch_dist["train"]["HOLD"] for epoch_dist in metrics["signal_distribution"]]
|
||||
|
||||
ax1.stackplot(epochs, buy_train, hold_train, sell_train,
|
||||
labels=['BUY', 'HOLD', 'SELL'],
|
||||
colors=['green', 'gray', 'red'], alpha=0.7)
|
||||
ax1.set_title('Training Signal Distribution over Epochs', fontsize=16)
|
||||
ax1.set_ylabel('Proportion', fontsize=14)
|
||||
ax1.legend(loc='upper right', fontsize=12)
|
||||
ax1.set_ylim(0, 1)
|
||||
ax1.grid(True)
|
||||
|
||||
# Validation Signal Distribution
|
||||
buy_val = [epoch_dist["val"]["BUY"] for epoch_dist in metrics["signal_distribution"]]
|
||||
sell_val = [epoch_dist["val"]["SELL"] for epoch_dist in metrics["signal_distribution"]]
|
||||
hold_val = [epoch_dist["val"]["HOLD"] for epoch_dist in metrics["signal_distribution"]]
|
||||
|
||||
ax2.stackplot(epochs, buy_val, hold_val, sell_val,
|
||||
labels=['BUY', 'HOLD', 'SELL'],
|
||||
colors=['green', 'gray', 'red'], alpha=0.7)
|
||||
ax2.set_title('Validation Signal Distribution over Epochs', fontsize=16)
|
||||
ax2.set_xlabel('Epoch', fontsize=14)
|
||||
ax2.set_ylabel('Proportion', fontsize=14)
|
||||
ax2.legend(loc='upper right', fontsize=12)
|
||||
ax2.set_ylim(0, 1)
|
||||
ax2.grid(True)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{plots_dir}/signal_distribution.png", dpi=300)
|
||||
plt.close()
|
||||
|
||||
# 4. Performance Correlation Matrix
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
# Extract key metrics for correlation
|
||||
corr_data = {}
|
||||
corr_data['Loss'] = metrics["train_loss"]
|
||||
corr_data['Accuracy'] = metrics["train_acc"]
|
||||
corr_data['PnL'] = metrics["train_pnl"]
|
||||
corr_data['Win Rate'] = metrics["train_win_rate"]
|
||||
corr_data['BUY %'] = buy_train
|
||||
corr_data['SELL %'] = sell_train
|
||||
corr_data['HOLD %'] = hold_train
|
||||
|
||||
# Convert to numpy array
|
||||
corr_matrix = np.zeros((len(corr_data), len(corr_data)))
|
||||
labels = list(corr_data.keys())
|
||||
|
||||
# Calculate correlation
|
||||
for i, key1 in enumerate(labels):
|
||||
for j, key2 in enumerate(labels):
|
||||
if i == j:
|
||||
corr_matrix[i, j] = 1.0
|
||||
else:
|
||||
corr = np.corrcoef(corr_data[key1], corr_data[key2])[0, 1]
|
||||
corr_matrix[i, j] = corr
|
||||
|
||||
# Plot heatmap
|
||||
im = ax.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
|
||||
|
||||
# Add colorbar
|
||||
cbar = fig.colorbar(im, ax=ax)
|
||||
cbar.set_label('Correlation', rotation=270, labelpad=20, fontsize=14)
|
||||
|
||||
# Add ticks and labels
|
||||
ax.set_xticks(np.arange(len(labels)))
|
||||
ax.set_yticks(np.arange(len(labels)))
|
||||
ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=12)
|
||||
ax.set_yticklabels(labels, fontsize=12)
|
||||
|
||||
# Add text annotations
|
||||
for i in range(len(labels)):
|
||||
for j in range(len(labels)):
|
||||
text = ax.text(j, i, f"{corr_matrix[i, j]:.2f}",
|
||||
ha="center", va="center", color="black" if abs(corr_matrix[i, j]) < 0.7 else "white")
|
||||
|
||||
ax.set_title('Correlation Matrix of Performance Metrics', fontsize=16)
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{plots_dir}/correlation_matrix.png", dpi=300)
|
||||
plt.close()
|
||||
|
||||
# 5. Combined Performance Dashboard
|
||||
fig = plt.figure(figsize=(16, 20))
|
||||
|
||||
# Define grid layout
|
||||
gs = fig.add_gridspec(4, 2, hspace=0.4, wspace=0.3)
|
||||
|
||||
# Plot 1: Loss curves
|
||||
ax1 = fig.add_subplot(gs[0, 0])
|
||||
ax1.plot(epochs, metrics["train_loss"], 'b-', label='Training')
|
||||
ax1.plot(epochs, metrics["val_loss"], 'r-', label='Validation')
|
||||
ax1.set_title('Loss', fontsize=14)
|
||||
ax1.set_xlabel('Epoch', fontsize=12)
|
||||
ax1.set_ylabel('Loss', fontsize=12)
|
||||
ax1.legend(fontsize=10)
|
||||
ax1.grid(True)
|
||||
|
||||
# Plot 2: Accuracy
|
||||
ax2 = fig.add_subplot(gs[0, 1])
|
||||
ax2.plot(epochs, metrics["train_acc"], 'b-', label='Training')
|
||||
ax2.plot(epochs, metrics["val_acc"], 'r-', label='Validation')
|
||||
ax2.set_title('Accuracy', fontsize=14)
|
||||
ax2.set_xlabel('Epoch', fontsize=12)
|
||||
ax2.set_ylabel('Accuracy', fontsize=12)
|
||||
ax2.legend(fontsize=10)
|
||||
ax2.grid(True)
|
||||
|
||||
# Plot 3: PnL
|
||||
ax3 = fig.add_subplot(gs[1, 0])
|
||||
ax3.plot(epochs, metrics["train_pnl"], 'g-', label='Training')
|
||||
ax3.plot(epochs, metrics["val_pnl"], 'm-', label='Validation')
|
||||
ax3.set_title('Profit and Loss', fontsize=14)
|
||||
ax3.set_xlabel('Epoch', fontsize=12)
|
||||
ax3.set_ylabel('PnL', fontsize=12)
|
||||
ax3.legend(fontsize=10)
|
||||
ax3.grid(True)
|
||||
|
||||
# Plot 4: Win Rate
|
||||
ax4 = fig.add_subplot(gs[1, 1])
|
||||
ax4.plot(epochs, metrics["train_win_rate"], 'g-', label='Training')
|
||||
ax4.plot(epochs, metrics["val_win_rate"], 'm-', label='Validation')
|
||||
ax4.axhline(y=0.5, color='r', linestyle='--', label='50% Threshold')
|
||||
ax4.set_title('Win Rate', fontsize=14)
|
||||
ax4.set_xlabel('Epoch', fontsize=12)
|
||||
ax4.set_ylabel('Win Rate', fontsize=12)
|
||||
ax4.legend(fontsize=10)
|
||||
ax4.grid(True)
|
||||
|
||||
# Plot 5: Training Signal Distribution
|
||||
ax5 = fig.add_subplot(gs[2, 0])
|
||||
ax5.stackplot(epochs, buy_train, hold_train, sell_train,
|
||||
labels=['BUY', 'HOLD', 'SELL'],
|
||||
colors=['green', 'gray', 'red'], alpha=0.7)
|
||||
ax5.set_title('Training Signal Distribution', fontsize=14)
|
||||
ax5.set_xlabel('Epoch', fontsize=12)
|
||||
ax5.set_ylabel('Proportion', fontsize=12)
|
||||
ax5.legend(fontsize=10)
|
||||
ax5.set_ylim(0, 1)
|
||||
ax5.grid(True)
|
||||
|
||||
# Plot 6: Validation Signal Distribution
|
||||
ax6 = fig.add_subplot(gs[2, 1])
|
||||
ax6.stackplot(epochs, buy_val, hold_val, sell_val,
|
||||
labels=['BUY', 'HOLD', 'SELL'],
|
||||
colors=['green', 'gray', 'red'], alpha=0.7)
|
||||
ax6.set_title('Validation Signal Distribution', fontsize=14)
|
||||
ax6.set_xlabel('Epoch', fontsize=12)
|
||||
ax6.set_ylabel('Proportion', fontsize=12)
|
||||
ax6.legend(fontsize=10)
|
||||
ax6.set_ylim(0, 1)
|
||||
ax6.grid(True)
|
||||
|
||||
# Plot 7: Performance Correlation Heatmap
|
||||
ax7 = fig.add_subplot(gs[3, :])
|
||||
im = ax7.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
|
||||
cbar = fig.colorbar(im, ax=ax7, fraction=0.025, pad=0.04)
|
||||
cbar.set_label('Correlation', rotation=270, labelpad=20, fontsize=12)
|
||||
|
||||
# Add ticks and labels
|
||||
ax7.set_xticks(np.arange(len(labels)))
|
||||
ax7.set_yticks(np.arange(len(labels)))
|
||||
ax7.set_xticklabels(labels, rotation=45, ha="right", fontsize=10)
|
||||
ax7.set_yticklabels(labels, fontsize=10)
|
||||
|
||||
# Add text annotations
|
||||
for i in range(len(labels)):
|
||||
for j in range(len(labels)):
|
||||
text = ax7.text(j, i, f"{corr_matrix[i, j]:.2f}",
|
||||
ha="center", va="center", color="black" if abs(corr_matrix[i, j]) < 0.7 else "white")
|
||||
|
||||
ax7.set_title('Correlation Matrix of Performance Metrics', fontsize=14)
|
||||
|
||||
# Add main title
|
||||
plt.suptitle('CNN Model Performance Dashboard', fontsize=20, y=0.98)
|
||||
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.97])
|
||||
plt.savefig(f"{plots_dir}/performance_dashboard.png", dpi=300)
|
||||
plt.savefig(os.path.join(plots_dir, 'pnl_winrate.png'))
|
||||
plt.close()
|
||||
|
||||
print(f"Performance visualizations saved to {plots_dir}")
|
||||
@ -1239,7 +1045,7 @@ class CNNModelPyTorch:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return False
|
||||
|
||||
|
||||
def extract_hidden_features(self, X):
|
||||
"""
|
||||
Extract hidden features from the model - outputs from last dense layer before output.
|
||||
|
Reference in New Issue
Block a user