RL training
This commit is contained in:
@ -475,7 +475,7 @@ class CNNModelPyTorch:
|
||||
diversity_weight * diversity_loss)
|
||||
|
||||
return total_loss, action_loss, price_loss
|
||||
|
||||
|
||||
def train_epoch(self, X_train, y_train, future_prices, batch_size):
|
||||
"""Train the model for one epoch with focus on short-term pattern recognition"""
|
||||
self.model.train()
|
||||
@ -919,13 +919,7 @@ class CNNModelPyTorch:
|
||||
logger.info(f"Backup saved to {backup_path}")
|
||||
|
||||
def load(self, filepath):
|
||||
"""
|
||||
Load the model from a file.
|
||||
|
||||
Args:
|
||||
filepath: Path to load the model from
|
||||
"""
|
||||
# Check if file exists
|
||||
"""Load model weights from file"""
|
||||
if not os.path.exists(f"{filepath}.pt"):
|
||||
logger.error(f"Model file {filepath}.pt not found")
|
||||
return False
|
||||
@ -938,27 +932,20 @@ class CNNModelPyTorch:
|
||||
self.window_size = model_state['window_size']
|
||||
self.num_features = model_state['num_features']
|
||||
self.output_size = model_state['output_size']
|
||||
self.timeframes = model_state['timeframes']
|
||||
self.timeframes = model_state.get('timeframes', ["1m"])
|
||||
|
||||
# Load model state dict
|
||||
self.load_state_dict(model_state['model_state_dict'])
|
||||
|
||||
# Load optimizer state if available
|
||||
if 'optimizer_state_dict' in model_state:
|
||||
self.optimizer.load_state_dict(model_state['optimizer_state_dict'])
|
||||
|
||||
# Load trading configuration if available
|
||||
if 'confidence_threshold' in model_state:
|
||||
self.confidence_threshold = model_state['confidence_threshold']
|
||||
if 'max_consecutive_same_action' in model_state:
|
||||
self.max_consecutive_same_action = model_state['max_consecutive_same_action']
|
||||
if 'action_counts' in model_state:
|
||||
self.action_counts = model_state['action_counts']
|
||||
if 'last_actions' in model_state:
|
||||
self.last_actions = model_state['last_actions']
|
||||
|
||||
# Rebuild the model
|
||||
self.build_model()
|
||||
|
||||
# Load the model state
|
||||
self.model.load_state_dict(model_state['model_state_dict'])
|
||||
self.optimizer.load_state_dict(model_state['optimizer_state_dict'])
|
||||
self.history = model_state['history']
|
||||
|
||||
logger.info(f"Model loaded from {filepath}.pt")
|
||||
|
||||
# Log model version information if available
|
||||
if 'model_version' in model_state:
|
||||
@ -973,7 +960,7 @@ class CNNModelPyTorch:
|
||||
|
||||
def plot_training_history(self, metrics_file="NN/models/saved/training_metrics.json"):
|
||||
"""
|
||||
Generate comprehensive performance visualization plots from training history
|
||||
Plot training history from saved metrics.
|
||||
|
||||
Args:
|
||||
metrics_file: Path to the saved metrics JSON file
|
||||
@ -983,253 +970,72 @@ class CNNModelPyTorch:
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
# Create directory for plots
|
||||
plots_dir = "NN/models/saved/performance_plots"
|
||||
os.makedirs(plots_dir, exist_ok=True)
|
||||
|
||||
# Load metrics
|
||||
with open(metrics_file, 'r') as f:
|
||||
metrics = json.load(f)
|
||||
|
||||
epochs = metrics["epoch"]
|
||||
|
||||
# Set default style for better visualization
|
||||
plt.style.use('seaborn-darkgrid')
|
||||
# Create plots directory
|
||||
plots_dir = os.path.join(os.path.dirname(metrics_file), 'plots')
|
||||
os.makedirs(plots_dir, exist_ok=True)
|
||||
|
||||
# Convert timestamps to datetime objects
|
||||
timestamps = [datetime.fromisoformat(ts) for ts in metrics['timestamps']]
|
||||
|
||||
# 1. Plot Loss and Accuracy
|
||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
|
||||
|
||||
# Loss plot
|
||||
ax1.plot(epochs, metrics["train_loss"], 'b-', label='Training Loss')
|
||||
ax1.plot(epochs, metrics["val_loss"], 'r-', label='Validation Loss')
|
||||
ax1.set_title('Model Loss over Epochs', fontsize=16)
|
||||
ax1.set_ylabel('Loss', fontsize=14)
|
||||
ax1.legend(loc='upper right', fontsize=12)
|
||||
ax1.plot(timestamps, metrics['train_loss'], 'b-', label='Training Loss')
|
||||
ax1.plot(timestamps, metrics['val_loss'], 'r-', label='Validation Loss')
|
||||
ax1.set_title('Model Loss Over Time')
|
||||
ax1.set_ylabel('Loss')
|
||||
ax1.legend()
|
||||
ax1.grid(True)
|
||||
|
||||
# Accuracy plot
|
||||
ax2.plot(epochs, metrics["train_acc"], 'b-', label='Training Accuracy')
|
||||
ax2.plot(epochs, metrics["val_acc"], 'r-', label='Validation Accuracy')
|
||||
ax2.set_title('Model Accuracy over Epochs', fontsize=16)
|
||||
ax2.set_xlabel('Epoch', fontsize=14)
|
||||
ax2.set_ylabel('Accuracy', fontsize=14)
|
||||
ax2.legend(loc='lower right', fontsize=12)
|
||||
ax2.plot(timestamps, metrics['train_acc'], 'g-', label='Training Accuracy')
|
||||
ax2.plot(timestamps, metrics['val_acc'], 'm-', label='Validation Accuracy')
|
||||
ax2.set_title('Model Accuracy Over Time')
|
||||
ax2.set_ylabel('Accuracy')
|
||||
ax2.legend()
|
||||
ax2.grid(True)
|
||||
|
||||
# Format x-axis
|
||||
ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
# Save the plot
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{plots_dir}/loss_accuracy.png", dpi=300)
|
||||
plt.savefig(os.path.join(plots_dir, 'loss_accuracy.png'))
|
||||
plt.close()
|
||||
|
||||
# 2. Plot PnL and Win Rate
|
||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
|
||||
|
||||
# PnL plot
|
||||
ax1.plot(epochs, metrics["train_pnl"], 'g-', label='Training PnL')
|
||||
ax1.plot(epochs, metrics["val_pnl"], 'm-', label='Validation PnL')
|
||||
ax1.set_title('Trading Profit and Loss over Epochs', fontsize=16)
|
||||
ax1.set_ylabel('PnL', fontsize=14)
|
||||
ax1.legend(loc='upper left', fontsize=12)
|
||||
ax1.plot(timestamps, metrics['train_pnl'], 'g-', label='Training PnL')
|
||||
ax1.plot(timestamps, metrics['val_pnl'], 'r-', label='Validation PnL')
|
||||
ax1.set_title('PnL Over Time')
|
||||
ax1.set_ylabel('PnL')
|
||||
ax1.legend()
|
||||
ax1.grid(True)
|
||||
|
||||
# Win Rate plot
|
||||
ax2.plot(epochs, metrics["train_win_rate"], 'g-', label='Training Win Rate')
|
||||
ax2.plot(epochs, metrics["val_win_rate"], 'm-', label='Validation Win Rate')
|
||||
ax2.set_title('Trading Win Rate over Epochs', fontsize=16)
|
||||
ax2.set_xlabel('Epoch', fontsize=14)
|
||||
ax2.set_ylabel('Win Rate', fontsize=14)
|
||||
ax2.axhline(y=0.5, color='r', linestyle='--', label='50% Threshold')
|
||||
ax2.legend(loc='lower right', fontsize=12)
|
||||
ax2.plot(timestamps, metrics['train_win_rate'], 'b-', label='Training Win Rate')
|
||||
ax2.plot(timestamps, metrics['val_win_rate'], 'm-', label='Validation Win Rate')
|
||||
ax2.set_title('Win Rate Over Time')
|
||||
ax2.set_ylabel('Win Rate')
|
||||
ax2.legend()
|
||||
ax2.grid(True)
|
||||
|
||||
# Format x-axis
|
||||
ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
# Save the plot
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{plots_dir}/pnl_winrate.png", dpi=300)
|
||||
plt.close()
|
||||
|
||||
# 3. Plot Signal Distribution over time
|
||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
|
||||
|
||||
# Training Signal Distribution
|
||||
buy_train = [epoch_dist["train"]["BUY"] for epoch_dist in metrics["signal_distribution"]]
|
||||
sell_train = [epoch_dist["train"]["SELL"] for epoch_dist in metrics["signal_distribution"]]
|
||||
hold_train = [epoch_dist["train"]["HOLD"] for epoch_dist in metrics["signal_distribution"]]
|
||||
|
||||
ax1.stackplot(epochs, buy_train, hold_train, sell_train,
|
||||
labels=['BUY', 'HOLD', 'SELL'],
|
||||
colors=['green', 'gray', 'red'], alpha=0.7)
|
||||
ax1.set_title('Training Signal Distribution over Epochs', fontsize=16)
|
||||
ax1.set_ylabel('Proportion', fontsize=14)
|
||||
ax1.legend(loc='upper right', fontsize=12)
|
||||
ax1.set_ylim(0, 1)
|
||||
ax1.grid(True)
|
||||
|
||||
# Validation Signal Distribution
|
||||
buy_val = [epoch_dist["val"]["BUY"] for epoch_dist in metrics["signal_distribution"]]
|
||||
sell_val = [epoch_dist["val"]["SELL"] for epoch_dist in metrics["signal_distribution"]]
|
||||
hold_val = [epoch_dist["val"]["HOLD"] for epoch_dist in metrics["signal_distribution"]]
|
||||
|
||||
ax2.stackplot(epochs, buy_val, hold_val, sell_val,
|
||||
labels=['BUY', 'HOLD', 'SELL'],
|
||||
colors=['green', 'gray', 'red'], alpha=0.7)
|
||||
ax2.set_title('Validation Signal Distribution over Epochs', fontsize=16)
|
||||
ax2.set_xlabel('Epoch', fontsize=14)
|
||||
ax2.set_ylabel('Proportion', fontsize=14)
|
||||
ax2.legend(loc='upper right', fontsize=12)
|
||||
ax2.set_ylim(0, 1)
|
||||
ax2.grid(True)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{plots_dir}/signal_distribution.png", dpi=300)
|
||||
plt.close()
|
||||
|
||||
# 4. Performance Correlation Matrix
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
# Extract key metrics for correlation
|
||||
corr_data = {}
|
||||
corr_data['Loss'] = metrics["train_loss"]
|
||||
corr_data['Accuracy'] = metrics["train_acc"]
|
||||
corr_data['PnL'] = metrics["train_pnl"]
|
||||
corr_data['Win Rate'] = metrics["train_win_rate"]
|
||||
corr_data['BUY %'] = buy_train
|
||||
corr_data['SELL %'] = sell_train
|
||||
corr_data['HOLD %'] = hold_train
|
||||
|
||||
# Convert to numpy array
|
||||
corr_matrix = np.zeros((len(corr_data), len(corr_data)))
|
||||
labels = list(corr_data.keys())
|
||||
|
||||
# Calculate correlation
|
||||
for i, key1 in enumerate(labels):
|
||||
for j, key2 in enumerate(labels):
|
||||
if i == j:
|
||||
corr_matrix[i, j] = 1.0
|
||||
else:
|
||||
corr = np.corrcoef(corr_data[key1], corr_data[key2])[0, 1]
|
||||
corr_matrix[i, j] = corr
|
||||
|
||||
# Plot heatmap
|
||||
im = ax.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
|
||||
|
||||
# Add colorbar
|
||||
cbar = fig.colorbar(im, ax=ax)
|
||||
cbar.set_label('Correlation', rotation=270, labelpad=20, fontsize=14)
|
||||
|
||||
# Add ticks and labels
|
||||
ax.set_xticks(np.arange(len(labels)))
|
||||
ax.set_yticks(np.arange(len(labels)))
|
||||
ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=12)
|
||||
ax.set_yticklabels(labels, fontsize=12)
|
||||
|
||||
# Add text annotations
|
||||
for i in range(len(labels)):
|
||||
for j in range(len(labels)):
|
||||
text = ax.text(j, i, f"{corr_matrix[i, j]:.2f}",
|
||||
ha="center", va="center", color="black" if abs(corr_matrix[i, j]) < 0.7 else "white")
|
||||
|
||||
ax.set_title('Correlation Matrix of Performance Metrics', fontsize=16)
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"{plots_dir}/correlation_matrix.png", dpi=300)
|
||||
plt.close()
|
||||
|
||||
# 5. Combined Performance Dashboard
|
||||
fig = plt.figure(figsize=(16, 20))
|
||||
|
||||
# Define grid layout
|
||||
gs = fig.add_gridspec(4, 2, hspace=0.4, wspace=0.3)
|
||||
|
||||
# Plot 1: Loss curves
|
||||
ax1 = fig.add_subplot(gs[0, 0])
|
||||
ax1.plot(epochs, metrics["train_loss"], 'b-', label='Training')
|
||||
ax1.plot(epochs, metrics["val_loss"], 'r-', label='Validation')
|
||||
ax1.set_title('Loss', fontsize=14)
|
||||
ax1.set_xlabel('Epoch', fontsize=12)
|
||||
ax1.set_ylabel('Loss', fontsize=12)
|
||||
ax1.legend(fontsize=10)
|
||||
ax1.grid(True)
|
||||
|
||||
# Plot 2: Accuracy
|
||||
ax2 = fig.add_subplot(gs[0, 1])
|
||||
ax2.plot(epochs, metrics["train_acc"], 'b-', label='Training')
|
||||
ax2.plot(epochs, metrics["val_acc"], 'r-', label='Validation')
|
||||
ax2.set_title('Accuracy', fontsize=14)
|
||||
ax2.set_xlabel('Epoch', fontsize=12)
|
||||
ax2.set_ylabel('Accuracy', fontsize=12)
|
||||
ax2.legend(fontsize=10)
|
||||
ax2.grid(True)
|
||||
|
||||
# Plot 3: PnL
|
||||
ax3 = fig.add_subplot(gs[1, 0])
|
||||
ax3.plot(epochs, metrics["train_pnl"], 'g-', label='Training')
|
||||
ax3.plot(epochs, metrics["val_pnl"], 'm-', label='Validation')
|
||||
ax3.set_title('Profit and Loss', fontsize=14)
|
||||
ax3.set_xlabel('Epoch', fontsize=12)
|
||||
ax3.set_ylabel('PnL', fontsize=12)
|
||||
ax3.legend(fontsize=10)
|
||||
ax3.grid(True)
|
||||
|
||||
# Plot 4: Win Rate
|
||||
ax4 = fig.add_subplot(gs[1, 1])
|
||||
ax4.plot(epochs, metrics["train_win_rate"], 'g-', label='Training')
|
||||
ax4.plot(epochs, metrics["val_win_rate"], 'm-', label='Validation')
|
||||
ax4.axhline(y=0.5, color='r', linestyle='--', label='50% Threshold')
|
||||
ax4.set_title('Win Rate', fontsize=14)
|
||||
ax4.set_xlabel('Epoch', fontsize=12)
|
||||
ax4.set_ylabel('Win Rate', fontsize=12)
|
||||
ax4.legend(fontsize=10)
|
||||
ax4.grid(True)
|
||||
|
||||
# Plot 5: Training Signal Distribution
|
||||
ax5 = fig.add_subplot(gs[2, 0])
|
||||
ax5.stackplot(epochs, buy_train, hold_train, sell_train,
|
||||
labels=['BUY', 'HOLD', 'SELL'],
|
||||
colors=['green', 'gray', 'red'], alpha=0.7)
|
||||
ax5.set_title('Training Signal Distribution', fontsize=14)
|
||||
ax5.set_xlabel('Epoch', fontsize=12)
|
||||
ax5.set_ylabel('Proportion', fontsize=12)
|
||||
ax5.legend(fontsize=10)
|
||||
ax5.set_ylim(0, 1)
|
||||
ax5.grid(True)
|
||||
|
||||
# Plot 6: Validation Signal Distribution
|
||||
ax6 = fig.add_subplot(gs[2, 1])
|
||||
ax6.stackplot(epochs, buy_val, hold_val, sell_val,
|
||||
labels=['BUY', 'HOLD', 'SELL'],
|
||||
colors=['green', 'gray', 'red'], alpha=0.7)
|
||||
ax6.set_title('Validation Signal Distribution', fontsize=14)
|
||||
ax6.set_xlabel('Epoch', fontsize=12)
|
||||
ax6.set_ylabel('Proportion', fontsize=12)
|
||||
ax6.legend(fontsize=10)
|
||||
ax6.set_ylim(0, 1)
|
||||
ax6.grid(True)
|
||||
|
||||
# Plot 7: Performance Correlation Heatmap
|
||||
ax7 = fig.add_subplot(gs[3, :])
|
||||
im = ax7.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
|
||||
cbar = fig.colorbar(im, ax=ax7, fraction=0.025, pad=0.04)
|
||||
cbar.set_label('Correlation', rotation=270, labelpad=20, fontsize=12)
|
||||
|
||||
# Add ticks and labels
|
||||
ax7.set_xticks(np.arange(len(labels)))
|
||||
ax7.set_yticks(np.arange(len(labels)))
|
||||
ax7.set_xticklabels(labels, rotation=45, ha="right", fontsize=10)
|
||||
ax7.set_yticklabels(labels, fontsize=10)
|
||||
|
||||
# Add text annotations
|
||||
for i in range(len(labels)):
|
||||
for j in range(len(labels)):
|
||||
text = ax7.text(j, i, f"{corr_matrix[i, j]:.2f}",
|
||||
ha="center", va="center", color="black" if abs(corr_matrix[i, j]) < 0.7 else "white")
|
||||
|
||||
ax7.set_title('Correlation Matrix of Performance Metrics', fontsize=14)
|
||||
|
||||
# Add main title
|
||||
plt.suptitle('CNN Model Performance Dashboard', fontsize=20, y=0.98)
|
||||
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.97])
|
||||
plt.savefig(f"{plots_dir}/performance_dashboard.png", dpi=300)
|
||||
plt.savefig(os.path.join(plots_dir, 'pnl_winrate.png'))
|
||||
plt.close()
|
||||
|
||||
print(f"Performance visualizations saved to {plots_dir}")
|
||||
@ -1239,7 +1045,7 @@ class CNNModelPyTorch:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return False
|
||||
|
||||
|
||||
def extract_hidden_features(self, X):
|
||||
"""
|
||||
Extract hidden features from the model - outputs from last dense layer before output.
|
||||
|
170
NN/models/dqn_agent.py
Normal file
170
NN/models/dqn_agent.py
Normal file
@ -0,0 +1,170 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
import random
|
||||
from typing import Tuple, List
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
from NN.models.simple_cnn import CNNModelPyTorch
|
||||
|
||||
class DQNAgent:
|
||||
"""
|
||||
Deep Q-Network agent for trading
|
||||
Uses CNN model as the base network
|
||||
"""
|
||||
def __init__(self,
|
||||
state_size: int,
|
||||
action_size: int,
|
||||
window_size: int,
|
||||
num_features: int,
|
||||
timeframes: List[str],
|
||||
learning_rate: float = 0.001,
|
||||
gamma: float = 0.99,
|
||||
epsilon: float = 1.0,
|
||||
epsilon_min: float = 0.01,
|
||||
epsilon_decay: float = 0.995,
|
||||
memory_size: int = 10000,
|
||||
batch_size: int = 64,
|
||||
target_update: int = 10):
|
||||
|
||||
self.state_size = state_size
|
||||
self.action_size = action_size
|
||||
self.window_size = window_size
|
||||
self.num_features = num_features
|
||||
self.timeframes = timeframes
|
||||
self.learning_rate = learning_rate
|
||||
self.gamma = gamma
|
||||
self.epsilon = epsilon
|
||||
self.epsilon_min = epsilon_min
|
||||
self.epsilon_decay = epsilon_decay
|
||||
self.memory_size = memory_size
|
||||
self.batch_size = batch_size
|
||||
self.target_update = target_update
|
||||
|
||||
# Device configuration
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Initialize networks
|
||||
self.policy_net = CNNModelPyTorch(
|
||||
window_size=window_size,
|
||||
num_features=num_features,
|
||||
output_size=action_size,
|
||||
timeframes=timeframes
|
||||
).to(self.device)
|
||||
|
||||
self.target_net = CNNModelPyTorch(
|
||||
window_size=window_size,
|
||||
num_features=num_features,
|
||||
output_size=action_size,
|
||||
timeframes=timeframes
|
||||
).to(self.device)
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
|
||||
# Initialize optimizer
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
|
||||
|
||||
# Initialize memory
|
||||
self.memory = deque(maxlen=memory_size)
|
||||
|
||||
# Training metrics
|
||||
self.update_count = 0
|
||||
self.losses = []
|
||||
|
||||
def remember(self, state: np.ndarray, action: int, reward: float,
|
||||
next_state: np.ndarray, done: bool):
|
||||
"""Store experience in memory"""
|
||||
self.memory.append((state, action, reward, next_state, done))
|
||||
|
||||
def act(self, state: np.ndarray) -> int:
|
||||
"""Choose action using epsilon-greedy policy"""
|
||||
if random.random() < self.epsilon:
|
||||
return random.randrange(self.action_size)
|
||||
|
||||
with torch.no_grad():
|
||||
state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||
action_probs, _ = self.policy_net(state)
|
||||
return action_probs.argmax().item()
|
||||
|
||||
def replay(self) -> float:
|
||||
"""Train on a batch of experiences"""
|
||||
if len(self.memory) < self.batch_size:
|
||||
return 0.0
|
||||
|
||||
# Sample batch
|
||||
batch = random.sample(self.memory, self.batch_size)
|
||||
states, actions, rewards, next_states, dones = zip(*batch)
|
||||
|
||||
# Convert to tensors and move to device
|
||||
states = torch.FloatTensor(np.array(states)).to(self.device)
|
||||
actions = torch.LongTensor(actions).to(self.device)
|
||||
rewards = torch.FloatTensor(rewards).to(self.device)
|
||||
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
|
||||
dones = torch.FloatTensor(dones).to(self.device)
|
||||
|
||||
# Get current Q values
|
||||
current_q_values, _ = self.policy_net(states)
|
||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1))
|
||||
|
||||
# Get next Q values from target network
|
||||
with torch.no_grad():
|
||||
next_q_values, _ = self.target_net(next_states)
|
||||
next_q_values = next_q_values.max(1)[0]
|
||||
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
|
||||
|
||||
# Compute loss
|
||||
loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
|
||||
|
||||
# Optimize
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
# Update target network if needed
|
||||
self.update_count += 1
|
||||
if self.update_count % self.target_update == 0:
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
|
||||
# Decay epsilon
|
||||
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
|
||||
|
||||
return loss.item()
|
||||
|
||||
def save(self, path: str):
|
||||
"""Save model and agent state"""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
|
||||
# Save policy network
|
||||
self.policy_net.save(f"{path}_policy")
|
||||
|
||||
# Save target network
|
||||
self.target_net.save(f"{path}_target")
|
||||
|
||||
# Save agent state
|
||||
state = {
|
||||
'epsilon': self.epsilon,
|
||||
'update_count': self.update_count,
|
||||
'losses': self.losses,
|
||||
'optimizer_state': self.optimizer.state_dict()
|
||||
}
|
||||
torch.save(state, f"{path}_agent_state.pt")
|
||||
|
||||
def load(self, path: str):
|
||||
"""Load model and agent state"""
|
||||
# Load policy network
|
||||
self.policy_net.load(f"{path}_policy")
|
||||
|
||||
# Load target network
|
||||
self.target_net.load(f"{path}_target")
|
||||
|
||||
# Load agent state
|
||||
state = torch.load(f"{path}_agent_state.pt")
|
||||
self.epsilon = state['epsilon']
|
||||
self.update_count = state['update_count']
|
||||
self.losses = state['losses']
|
||||
self.optimizer.load_state_dict(state['optimizer_state'])
|
130
NN/models/simple_cnn.py
Normal file
130
NN/models/simple_cnn.py
Normal file
@ -0,0 +1,130 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
import os
|
||||
import logging
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Tuple
|
||||
|
||||
# Configure logger
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CNNModelPyTorch(nn.Module):
|
||||
"""
|
||||
CNN model for trading signals
|
||||
Simplified version for RL training
|
||||
"""
|
||||
def __init__(self, window_size: int, num_features: int, output_size: int, timeframes: List[str]):
|
||||
super(CNNModelPyTorch, self).__init__()
|
||||
|
||||
self.window_size = window_size
|
||||
self.num_features = num_features
|
||||
self.output_size = output_size
|
||||
self.timeframes = timeframes
|
||||
|
||||
# Device configuration
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
logger.info(f"Using device: {self.device}")
|
||||
|
||||
# Build model
|
||||
self.build_model()
|
||||
|
||||
# Initialize optimizer and scheduler
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=0.001)
|
||||
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.optimizer, mode='max', factor=0.5, patience=5, verbose=True
|
||||
)
|
||||
|
||||
# Move model to device
|
||||
self.to(self.device)
|
||||
|
||||
def build_model(self):
|
||||
"""Build the CNN architecture"""
|
||||
# First Convolutional Layer
|
||||
self.conv1 = nn.Conv1d(
|
||||
in_channels=self.num_features * len(self.timeframes),
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
padding=1
|
||||
)
|
||||
self.bn1 = nn.BatchNorm1d(32)
|
||||
|
||||
# Second Convolutional Layer
|
||||
self.conv2 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
|
||||
self.bn2 = nn.BatchNorm1d(64)
|
||||
|
||||
# Third Convolutional Layer
|
||||
self.conv3 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
|
||||
self.bn3 = nn.BatchNorm1d(128)
|
||||
|
||||
# Calculate size after convolutions
|
||||
conv_out_size = self.window_size * 128
|
||||
|
||||
# Fully connected layers
|
||||
self.fc1 = nn.Linear(conv_out_size, 512)
|
||||
self.fc2 = nn.Linear(512, 256)
|
||||
self.fc3 = nn.Linear(256, self.output_size)
|
||||
|
||||
# Additional output for value estimation
|
||||
self.value_fc = nn.Linear(256, 1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass through the network"""
|
||||
# Ensure input is on the correct device
|
||||
x = x.to(self.device)
|
||||
|
||||
# Reshape input: [batch, window_size, features] -> [batch, channels, window_size]
|
||||
batch_size = x.size(0)
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
# Convolutional layers
|
||||
x = F.relu(self.bn1(self.conv1(x)))
|
||||
x = F.relu(self.bn2(self.conv2(x)))
|
||||
x = F.relu(self.bn3(self.conv3(x)))
|
||||
|
||||
# Flatten
|
||||
x = x.view(batch_size, -1)
|
||||
|
||||
# Fully connected layers
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.relu(self.fc2(x))
|
||||
|
||||
# Split into advantage and value streams
|
||||
advantage = self.fc3(x)
|
||||
value = self.value_fc(x)
|
||||
|
||||
# Combine value and advantage
|
||||
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
|
||||
|
||||
return q_values, value
|
||||
|
||||
def predict(self, X):
|
||||
"""Make predictions"""
|
||||
self.eval()
|
||||
|
||||
# Convert to tensor if not already
|
||||
if not isinstance(X, torch.Tensor):
|
||||
X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
|
||||
else:
|
||||
X_tensor = X.to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
q_values, value = self(X_tensor)
|
||||
q_values_np = q_values.cpu().numpy()
|
||||
actions = np.argmax(q_values_np, axis=1)
|
||||
|
||||
return actions, q_values_np
|
||||
|
||||
def save(self, path: str):
|
||||
"""Save model weights"""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
torch.save(self.state_dict(), f"{path}.pt")
|
||||
logger.info(f"Model saved to {path}.pt")
|
||||
|
||||
def load(self, path: str):
|
||||
"""Load model weights"""
|
||||
self.load_state_dict(torch.load(f"{path}.pt", map_location=self.device))
|
||||
self.eval()
|
||||
logger.info(f"Model loaded from {path}.pt")
|
Reference in New Issue
Block a user