added leverage slider

This commit is contained in:
Dobromir Popov
2025-05-30 22:33:41 +03:00
parent d870f74d0c
commit 7d8eca995e
21 changed files with 3205 additions and 2923 deletions

View File

@ -14,15 +14,19 @@ class TradingEnvironment(gym.Env):
"""
Trading environment implementing gym interface for reinforcement learning
Actions:
- 0: Buy
- 1: Sell
- 2: Hold
2-Action System:
- 0: SELL (or close long position)
- 1: BUY (or close short position)
Intelligent Position Management:
- When neutral: Actions enter positions
- When positioned: Actions can close or flip positions
- Different thresholds for entry vs exit decisions
State:
- OHLCV data from multiple timeframes
- Technical indicators
- Position data
- Position data and unrealized PnL
"""
def __init__(
@ -33,9 +37,11 @@ class TradingEnvironment(gym.Env):
window_size: int = 20,
max_position: float = 1.0,
reward_scaling: float = 1.0,
entry_threshold: float = 0.6, # Higher threshold for entering positions
exit_threshold: float = 0.3, # Lower threshold for exiting positions
):
"""
Initialize the trading environment.
Initialize the trading environment with 2-action system.
Args:
data_interface: DataInterface instance to get market data
@ -44,6 +50,8 @@ class TradingEnvironment(gym.Env):
window_size: Number of candles in the observation window
max_position: Maximum position size as a fraction of balance
reward_scaling: Scale factor for rewards
entry_threshold: Confidence threshold for entering new positions
exit_threshold: Confidence threshold for exiting positions
"""
super().__init__()
@ -53,21 +61,23 @@ class TradingEnvironment(gym.Env):
self.window_size = window_size
self.max_position = max_position
self.reward_scaling = reward_scaling
self.entry_threshold = entry_threshold
self.exit_threshold = exit_threshold
# Load data for primary timeframe (assuming the first one is primary)
self.timeframe = self.data_interface.timeframes[0]
self.reset_data()
# Define action and observation spaces
self.action_space = spaces.Discrete(3) # Buy, Sell, Hold
# Define action and observation spaces for 2-action system
self.action_space = spaces.Discrete(2) # 0=SELL, 1=BUY
# For observation space, we consider multiple timeframes with OHLCV data
# and additional features like technical indicators, position info, etc.
n_timeframes = len(self.data_interface.timeframes)
n_features = 5 # OHLCV data by default
# Add additional features for position, balance, etc.
additional_features = 3 # position, balance, unrealized_pnl
# Add additional features for position, balance, unrealized_pnl, etc.
additional_features = 5 # position, balance, unrealized_pnl, entry_price, position_duration
# Calculate total feature dimension
total_features = (n_timeframes * n_features * self.window_size) + additional_features
@ -79,6 +89,11 @@ class TradingEnvironment(gym.Env):
# Use tuple for state_shape that EnhancedCNN expects
self.state_shape = (total_features,)
# Position tracking for 2-action system
self.position = 0.0 # -1 (short), 0 (neutral), 1 (long)
self.entry_price = 0.0 # Price at which position was entered
self.entry_step = 0 # Step at which position was entered
# Initialize state
self.reset()
@ -103,9 +118,6 @@ class TradingEnvironment(gym.Env):
"""Reset the environment to initial state"""
# Reset trading variables
self.balance = self.initial_balance
self.position = 0.0 # No position initially
self.entry_price = 0.0
self.total_pnl = 0.0
self.trades = []
self.rewards = []
@ -119,10 +131,10 @@ class TradingEnvironment(gym.Env):
def step(self, action):
"""
Take a step in the environment.
Take a step in the environment using 2-action system with intelligent position management.
Args:
action: Action to take (0: Buy, 1: Sell, 2: Hold)
action: Action to take (0: SELL, 1: BUY)
Returns:
tuple: (observation, reward, done, info)
@ -132,7 +144,7 @@ class TradingEnvironment(gym.Env):
prev_position = self.position
prev_price = self.prices[self.current_step]
# Take action
# Take action with intelligent position management
info = {}
reward = 0
last_position_info = None
@ -141,43 +153,50 @@ class TradingEnvironment(gym.Env):
current_price = self.prices[self.current_step]
next_price = self.prices[self.current_step + 1] if self.current_step + 1 < len(self.prices) else current_price
# Process the action
if action == 0: # Buy
if self.position <= 0: # Only buy if not already long
# Close any existing short position
if self.position < 0:
close_pnl, last_position_info = self._close_position(current_price)
reward += close_pnl * self.reward_scaling
# Open new long position
self._open_position(1.0 * self.max_position, current_price)
logger.info(f"Buy at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")
elif action == 1: # Sell
if self.position >= 0: # Only sell if not already short
# Close any existing long position
if self.position > 0:
close_pnl, last_position_info = self._close_position(current_price)
reward += close_pnl * self.reward_scaling
# Open new short position
# Implement 2-action system with position management
if action == 0: # SELL action
if self.position == 0: # No position - enter short
self._open_position(-1.0 * self.max_position, current_price)
logger.info(f"Sell at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")
logger.info(f"ENTER SHORT at step {self.current_step}, price: {current_price:.4f}")
reward = -self.transaction_fee # Entry cost
elif self.position > 0: # Long position - close it
close_pnl, last_position_info = self._close_position(current_price)
reward += close_pnl * self.reward_scaling
logger.info(f"CLOSE LONG at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")
elif self.position < 0: # Already short - potentially flip to long if very strong signal
# For now, just hold the short position (no action)
pass
elif action == 2: # Hold
# No action, but still calculate unrealized PnL for reward
pass
elif action == 1: # BUY action
if self.position == 0: # No position - enter long
self._open_position(1.0 * self.max_position, current_price)
logger.info(f"ENTER LONG at step {self.current_step}, price: {current_price:.4f}")
reward = -self.transaction_fee # Entry cost
elif self.position < 0: # Short position - close it
close_pnl, last_position_info = self._close_position(current_price)
reward += close_pnl * self.reward_scaling
logger.info(f"CLOSE SHORT at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")
elif self.position > 0: # Already long - potentially flip to short if very strong signal
# For now, just hold the long position (no action)
pass
# Calculate unrealized PnL and add to reward
# Calculate unrealized PnL and add to reward if holding position
if self.position != 0:
unrealized_pnl = self._calculate_unrealized_pnl(next_price)
reward += unrealized_pnl * self.reward_scaling * 0.1 # Scale down unrealized PnL
# Apply time-based holding penalty to encourage decisive actions
position_duration = self.current_step - self.entry_step
holding_penalty = min(position_duration * 0.0001, 0.01) # Max 1% penalty
reward -= holding_penalty
# Apply penalties for holding a position
if self.position != 0:
# Small holding fee/interest
holding_penalty = abs(self.position) * 0.0001 # 0.01% per step
reward -= holding_penalty * self.reward_scaling
# Reward staying neutral when uncertain (no clear setup)
else:
reward += 0.0001 # Small reward for not trading without clear signals
# Move to next step
self.current_step += 1
@ -215,7 +234,7 @@ class TradingEnvironment(gym.Env):
'step': self.current_step,
'timestamp': self.timestamps[self.current_step],
'action': action,
'action_name': ['BUY', 'SELL', 'HOLD'][action],
'action_name': ['SELL', 'BUY'][action],
'price': current_price,
'position_changed': prev_position != self.position,
'prev_position': prev_position,
@ -234,7 +253,7 @@ class TradingEnvironment(gym.Env):
self.trades.append(trade_result)
# Log trade details
logger.info(f"Trade executed - Action: {['BUY', 'SELL', 'HOLD'][action]}, "
logger.info(f"Trade executed - Action: {['SELL', 'BUY'][action]}, "
f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
f"Balance: {self.balance:.4f}")
@ -268,42 +287,71 @@ class TradingEnvironment(gym.Env):
else: # Short position
return -self.position * (1.0 - current_price / self.entry_price)
def _open_position(self, position_size, price):
def _open_position(self, position_size: float, entry_price: float):
"""Open a new position"""
self.position = position_size
self.entry_price = price
self.entry_price = entry_price
self.entry_step = self.current_step
def _close_position(self, price):
"""Close the current position and return PnL"""
pnl = self._calculate_unrealized_pnl(price)
# Calculate position value
position_value = abs(position_size) * entry_price
# Apply transaction fee
fee = abs(self.position) * price * self.transaction_fee
pnl -= fee
fee = position_value * self.transaction_fee
self.balance -= fee
logger.info(f"Opened position: {position_size:.4f} at {entry_price:.4f}, fee: {fee:.4f}")
def _close_position(self, exit_price: float) -> Tuple[float, Dict]:
"""Close current position and return PnL"""
if self.position == 0:
return 0.0, {}
# Calculate PnL
if self.position > 0: # Long position
pnl = (exit_price - self.entry_price) / self.entry_price
else: # Short position
pnl = (self.entry_price - exit_price) / self.entry_price
# Apply transaction fees (entry + exit)
position_value = abs(self.position) * exit_price
exit_fee = position_value * self.transaction_fee
total_fees = exit_fee # Entry fee already applied when opening
# Net PnL after fees
net_pnl = pnl - (total_fees / (abs(self.position) * self.entry_price))
# Update balance
self.balance += pnl
self.total_pnl += pnl
self.balance *= (1 + net_pnl)
self.total_pnl += net_pnl
# Store position details before resetting
last_position = {
# Track trade
position_info = {
'position_size': self.position,
'entry_price': self.entry_price,
'exit_price': price,
'pnl': pnl,
'fee': fee
'exit_price': exit_price,
'pnl': net_pnl,
'duration': self.current_step - self.entry_step,
'entry_step': self.entry_step,
'exit_step': self.current_step
}
self.trades.append(position_info)
# Update trade statistics
if net_pnl > 0:
self.winning_trades += 1
else:
self.losing_trades += 1
logger.info(f"Closed position: {self.position:.4f}, PnL: {net_pnl:.4f}, Duration: {position_info['duration']} steps")
# Reset position
self.position = 0.0
self.entry_price = 0.0
self.entry_step = 0
# Log position closure
logger.info(f"Closed position - Size: {last_position['position_size']:.4f}, "
f"Entry: {last_position['entry_price']:.4f}, Exit: {last_position['exit_price']:.4f}, "
f"PnL: {last_position['pnl']:.4f}, Fee: {last_position['fee']:.4f}")
return pnl, last_position
return net_pnl, position_info
def _get_observation(self):
"""
@ -411,7 +459,7 @@ class TradingEnvironment(gym.Env):
for trade in last_n_trades:
position_info = {
'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
'action': trade.get('action_name', ['BUY', 'SELL', 'HOLD'][trade['action']]),
'action': trade.get('action_name', ['SELL', 'BUY'][trade['action']]),
'entry_price': trade.get('entry_price', 0.0),
'exit_price': trade.get('exit_price', trade['price']),
'position_size': trade.get('position_size', self.max_position),

View File

@ -1,560 +0,0 @@
"""
Convolutional Neural Network for timeseries analysis
This module implements a deep CNN model for cryptocurrency price analysis.
The model uses multiple parallel convolutional pathways and LSTM layers
to detect patterns at different time scales.
"""
import os
import logging
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (
Input, Conv1D, MaxPooling1D, Dense, Dropout, BatchNormalization,
LSTM, Bidirectional, Flatten, Concatenate, GlobalAveragePooling1D,
LeakyReLU, Attention
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.metrics import AUC
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import datetime
import json
logger = logging.getLogger(__name__)
class CNNModel:
"""
Convolutional Neural Network for time series analysis.
This model uses a multi-pathway architecture with different filter sizes
to detect patterns at different time scales, combined with LSTM layers
for temporal dependencies.
"""
def __init__(self, input_shape=(20, 5), output_size=1, model_dir="NN/models/saved"):
"""
Initialize the CNN model.
Args:
input_shape (tuple): Shape of input data (sequence_length, features)
output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
model_dir (str): Directory to save trained models
"""
self.input_shape = input_shape
self.output_size = output_size
self.model_dir = model_dir
self.model = None
self.history = None
# Create model directory if it doesn't exist
os.makedirs(self.model_dir, exist_ok=True)
logger.info(f"Initialized CNN model with input shape {input_shape} and output size {output_size}")
def build_model(self, filters=(32, 64, 128), kernel_sizes=(3, 5, 7),
dropout_rate=0.3, learning_rate=0.001):
"""
Build the CNN model architecture.
Args:
filters (tuple): Number of filters for each convolutional pathway
kernel_sizes (tuple): Kernel sizes for each convolutional pathway
dropout_rate (float): Dropout rate for regularization
learning_rate (float): Learning rate for Adam optimizer
Returns:
The compiled model
"""
# Input layer
inputs = Input(shape=self.input_shape)
# Multiple parallel convolutional pathways with different kernel sizes
# to capture patterns at different time scales
conv_layers = []
for i, (filter_size, kernel_size) in enumerate(zip(filters, kernel_sizes)):
conv_path = Conv1D(
filters=filter_size,
kernel_size=kernel_size,
padding='same',
name=f'conv1d_{i+1}'
)(inputs)
conv_path = BatchNormalization()(conv_path)
conv_path = LeakyReLU(alpha=0.1)(conv_path)
conv_path = MaxPooling1D(pool_size=2, padding='same')(conv_path)
conv_path = Dropout(dropout_rate)(conv_path)
conv_layers.append(conv_path)
# Merge convolutional pathways
if len(conv_layers) > 1:
merged = Concatenate()(conv_layers)
else:
merged = conv_layers[0]
# Add another Conv1D layer after merging
x = Conv1D(filters=filters[-1], kernel_size=3, padding='same')(merged)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.1)(x)
x = MaxPooling1D(pool_size=2, padding='same')(x)
x = Dropout(dropout_rate)(x)
# Bidirectional LSTM for temporal dependencies
x = Bidirectional(LSTM(128, return_sequences=True))(x)
x = Dropout(dropout_rate)(x)
# Attention mechanism to focus on important time steps
x = Bidirectional(LSTM(64, return_sequences=True))(x)
# Global average pooling to reduce parameters
x = GlobalAveragePooling1D()(x)
x = Dropout(dropout_rate)(x)
# Dense layers for final classification/regression
x = Dense(64, activation='relu')(x)
x = BatchNormalization()(x)
x = Dropout(dropout_rate)(x)
# Output layer
if self.output_size == 1:
# Binary classification (up/down)
outputs = Dense(1, activation='sigmoid', name='output')(x)
loss = 'binary_crossentropy'
metrics = ['accuracy', AUC()]
elif self.output_size == 3:
# Multi-class classification (buy/hold/sell)
outputs = Dense(3, activation='softmax', name='output')(x)
loss = 'categorical_crossentropy'
metrics = ['accuracy']
else:
# Regression
outputs = Dense(self.output_size, activation='linear', name='output')(x)
loss = 'mse'
metrics = ['mae']
# Create and compile model
self.model = Model(inputs=inputs, outputs=outputs)
# Compile with Adam optimizer
self.model.compile(
optimizer=Adam(learning_rate=learning_rate),
loss=loss,
metrics=metrics
)
# Log model summary
self.model.summary(print_fn=lambda x: logger.info(x))
return self.model
def train(self, X_train, y_train, batch_size=32, epochs=100, validation_split=0.2,
callbacks=None, class_weights=None):
"""
Train the CNN model on the provided data.
Args:
X_train (numpy.ndarray): Training features
y_train (numpy.ndarray): Training targets
batch_size (int): Batch size
epochs (int): Number of epochs
validation_split (float): Fraction of data to use for validation
callbacks (list): List of Keras callbacks
class_weights (dict): Class weights for imbalanced datasets
Returns:
History object containing training metrics
"""
if self.model is None:
self.build_model()
# Default callbacks if none provided
if callbacks is None:
# Create a timestamp for model checkpoints
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
callbacks = [
EarlyStopping(
monitor='val_loss',
patience=10,
restore_best_weights=True
),
ReduceLROnPlateau(
monitor='val_loss',
factor=0.5,
patience=5,
min_lr=1e-6
),
ModelCheckpoint(
filepath=os.path.join(self.model_dir, f"cnn_model_{timestamp}.h5"),
monitor='val_loss',
save_best_only=True
)
]
# Check if y_train needs to be one-hot encoded for multi-class
if self.output_size == 3 and len(y_train.shape) == 1:
y_train = tf.keras.utils.to_categorical(y_train, num_classes=3)
# Train the model
logger.info(f"Training CNN model with {len(X_train)} samples, batch size {batch_size}, epochs {epochs}")
self.history = self.model.fit(
X_train, y_train,
batch_size=batch_size,
epochs=epochs,
validation_split=validation_split,
callbacks=callbacks,
class_weight=class_weights,
verbose=2
)
# Save the trained model
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
model_path = os.path.join(self.model_dir, f"cnn_model_final_{timestamp}.h5")
self.model.save(model_path)
logger.info(f"Model saved to {model_path}")
# Save training history
history_path = os.path.join(self.model_dir, f"cnn_model_history_{timestamp}.json")
with open(history_path, 'w') as f:
# Convert numpy values to Python native types for JSON serialization
history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
json.dump(history_dict, f, indent=2)
return self.history
def evaluate(self, X_test, y_test, plot_results=False):
"""
Evaluate the model on test data.
Args:
X_test (numpy.ndarray): Test features
y_test (numpy.ndarray): Test targets
plot_results (bool): Whether to plot evaluation results
Returns:
dict: Evaluation metrics
"""
if self.model is None:
raise ValueError("Model has not been built or trained yet")
# Convert y_test to one-hot encoding for multi-class
y_test_original = y_test.copy()
if self.output_size == 3 and len(y_test.shape) == 1:
y_test = tf.keras.utils.to_categorical(y_test, num_classes=3)
# Evaluate model
logger.info(f"Evaluating CNN model on {len(X_test)} samples")
eval_results = self.model.evaluate(X_test, y_test, verbose=0)
metrics = {}
for metric, value in zip(self.model.metrics_names, eval_results):
metrics[metric] = value
logger.info(f"{metric}: {value:.4f}")
# Get predictions
y_pred_prob = self.model.predict(X_test)
# Different processing based on output type
if self.output_size == 1:
# Binary classification
y_pred = (y_pred_prob > 0.5).astype(int).flatten()
# Classification report
report = classification_report(y_test, y_pred)
logger.info(f"Classification Report:\n{report}")
# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
logger.info(f"Confusion Matrix:\n{cm}")
# ROC curve and AUC
fpr, tpr, _ = roc_curve(y_test, y_pred_prob)
roc_auc = auc(fpr, tpr)
metrics['auc'] = roc_auc
if plot_results:
self._plot_binary_results(y_test, y_pred, y_pred_prob, fpr, tpr, roc_auc)
elif self.output_size == 3:
# Multi-class classification
y_pred = np.argmax(y_pred_prob, axis=1)
# Classification report
report = classification_report(y_test_original, y_pred)
logger.info(f"Classification Report:\n{report}")
# Confusion matrix
cm = confusion_matrix(y_test_original, y_pred)
logger.info(f"Confusion Matrix:\n{cm}")
if plot_results:
self._plot_multiclass_results(y_test_original, y_pred, y_pred_prob)
return metrics
def predict(self, X):
"""
Make predictions on new data.
Args:
X (numpy.ndarray): Input features
Returns:
tuple: (y_pred, y_proba) where:
y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
y_proba is the class probability
"""
if self.model is None:
raise ValueError("Model has not been built or trained yet")
# Ensure X has the right shape
if len(X.shape) == 2:
# Single sample, add batch dimension
X = np.expand_dims(X, axis=0)
# Get predictions
y_proba = self.model.predict(X)
# Process based on output type
if self.output_size == 1:
# Binary classification
y_pred = (y_proba > 0.5).astype(int).flatten()
return y_pred, y_proba.flatten()
elif self.output_size == 3:
# Multi-class classification
y_pred = np.argmax(y_proba, axis=1)
return y_pred, y_proba
else:
# Regression
return y_proba, y_proba
def save(self, filepath=None):
"""
Save the model to disk.
Args:
filepath (str): Path to save the model
Returns:
str: Path where the model was saved
"""
if self.model is None:
raise ValueError("Model has not been built yet")
if filepath is None:
# Create a default filepath with timestamp
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filepath = os.path.join(self.model_dir, f"cnn_model_{timestamp}.h5")
self.model.save(filepath)
logger.info(f"Model saved to {filepath}")
return filepath
def load(self, filepath):
"""
Load a saved model from disk.
Args:
filepath (str): Path to the saved model
Returns:
The loaded model
"""
self.model = load_model(filepath)
logger.info(f"Model loaded from {filepath}")
return self.model
def extract_hidden_features(self, X):
"""
Extract features from the last hidden layer of the CNN for transfer learning.
Args:
X (numpy.ndarray): Input data
Returns:
numpy.ndarray: Extracted features
"""
if self.model is None:
raise ValueError("Model has not been built or trained yet")
# Create a new model that outputs the features from the layer before the output
feature_layer_name = self.model.layers[-2].name
feature_extractor = Model(
inputs=self.model.input,
outputs=self.model.get_layer(feature_layer_name).output
)
# Extract features
features = feature_extractor.predict(X)
return features
def _plot_binary_results(self, y_true, y_pred, y_proba, fpr, tpr, roc_auc):
"""
Plot evaluation results for binary classification.
Args:
y_true (numpy.ndarray): True labels
y_pred (numpy.ndarray): Predicted labels
y_proba (numpy.ndarray): Prediction probabilities
fpr (numpy.ndarray): False positive rates for ROC curve
tpr (numpy.ndarray): True positive rates for ROC curve
roc_auc (float): Area under ROC curve
"""
plt.figure(figsize=(15, 5))
# Confusion Matrix
plt.subplot(1, 3, 1)
cm = confusion_matrix(y_true, y_pred)
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = [0, 1]
plt.xticks(tick_marks, ['0', '1'])
plt.yticks(tick_marks, ['0', '1'])
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
# Add text annotations to confusion matrix
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
plt.text(j, i, format(cm[i, j], 'd'),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
# Histogram of prediction probabilities
plt.subplot(1, 3, 2)
plt.hist(y_proba[y_true == 0], alpha=0.5, label='Class 0')
plt.hist(y_proba[y_true == 1], alpha=0.5, label='Class 1')
plt.title('Prediction Probabilities')
plt.xlabel('Probability of Class 1')
plt.ylabel('Count')
plt.legend()
# ROC Curve
plt.subplot(1, 3, 3)
plt.plot(fpr, tpr, label=f'ROC Curve (AUC = {roc_auc:.3f})')
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc="lower right")
plt.tight_layout()
# Save figure
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
fig_path = os.path.join(self.model_dir, f"cnn_evaluation_{timestamp}.png")
plt.savefig(fig_path)
plt.close()
logger.info(f"Evaluation plots saved to {fig_path}")
def _plot_multiclass_results(self, y_true, y_pred, y_proba):
"""
Plot evaluation results for multi-class classification.
Args:
y_true (numpy.ndarray): True labels
y_pred (numpy.ndarray): Predicted labels
y_proba (numpy.ndarray): Prediction probabilities
"""
plt.figure(figsize=(12, 5))
# Confusion Matrix
plt.subplot(1, 2, 1)
cm = confusion_matrix(y_true, y_pred)
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
classes = ['BUY', 'HOLD', 'SELL'] # Assumes classes are 0, 1, 2
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes)
plt.yticks(tick_marks, classes)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
# Add text annotations to confusion matrix
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
plt.text(j, i, format(cm[i, j], 'd'),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
# Class probability distributions
plt.subplot(1, 2, 2)
for i, cls in enumerate(classes):
plt.hist(y_proba[y_true == i, i], alpha=0.5, label=f'Class {cls}')
plt.title('Class Probability Distributions')
plt.xlabel('Probability')
plt.ylabel('Count')
plt.legend()
plt.tight_layout()
# Save figure
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
fig_path = os.path.join(self.model_dir, f"cnn_multiclass_evaluation_{timestamp}.png")
plt.savefig(fig_path)
plt.close()
logger.info(f"Multiclass evaluation plots saved to {fig_path}")
def plot_training_history(self):
"""
Plot training history (loss and metrics).
Returns:
str: Path to the saved plot
"""
if self.history is None:
raise ValueError("Model has not been trained yet")
plt.figure(figsize=(12, 5))
# Plot loss
plt.subplot(1, 2, 1)
plt.plot(self.history.history['loss'], label='Training Loss')
if 'val_loss' in self.history.history:
plt.plot(self.history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
# Plot accuracy
plt.subplot(1, 2, 2)
if 'accuracy' in self.history.history:
plt.plot(self.history.history['accuracy'], label='Training Accuracy')
if 'val_accuracy' in self.history.history:
plt.plot(self.history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
elif 'mae' in self.history.history:
plt.plot(self.history.history['mae'], label='Training MAE')
if 'val_mae' in self.history.history:
plt.plot(self.history.history['val_mae'], label='Validation MAE')
plt.title('Model MAE')
plt.ylabel('MAE')
plt.xlabel('Epoch')
plt.legend()
plt.tight_layout()
# Save figure
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
fig_path = os.path.join(self.model_dir, f"cnn_training_history_{timestamp}.png")
plt.savefig(fig_path)
plt.close()
logger.info(f"Training history plot saved to {fig_path}")
return fig_path

File diff suppressed because it is too large Load Diff

View File

@ -9,6 +9,7 @@ import os
import sys
import logging
import torch.nn.functional as F
import time
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
@ -23,16 +24,16 @@ class DQNAgent:
"""
def __init__(self,
state_shape: Tuple[int, ...],
n_actions: int,
learning_rate: float = 0.0005, # Reduced learning rate for more stability
gamma: float = 0.97, # Slightly reduced discount factor
n_actions: int = 2,
learning_rate: float = 0.001,
epsilon: float = 1.0,
epsilon_min: float = 0.05, # Increased minimum epsilon for more exploration
epsilon_decay: float = 0.9975, # Slower decay rate
buffer_size: int = 20000, # Increased memory size
batch_size: int = 128, # Larger batch size
target_update: int = 5, # More frequent target updates
device=None): # Device for computations
epsilon_min: float = 0.01,
epsilon_decay: float = 0.995,
buffer_size: int = 10000,
batch_size: int = 32,
target_update: int = 100,
priority_memory: bool = True,
device=None):
# Extract state dimensions
if isinstance(state_shape, tuple) and len(state_shape) > 1:
@ -48,11 +49,9 @@ class DQNAgent:
# Store parameters
self.n_actions = n_actions
self.learning_rate = learning_rate
self.gamma = gamma
self.epsilon = epsilon
self.epsilon_min = epsilon_min
self.epsilon_decay = epsilon_decay
self.epsilon_start = epsilon # Store initial epsilon value for resets/bumps
self.buffer_size = buffer_size
self.batch_size = batch_size
self.target_update = target_update
@ -127,10 +126,41 @@ class DQNAgent:
self.max_confidence = 0.0
self.min_confidence = 1.0
# Enhanced features from EnhancedDQNAgent
# Market adaptation capabilities
self.market_regime_weights = {
'trending': 1.2, # Higher confidence in trending markets
'ranging': 0.8, # Lower confidence in ranging markets
'volatile': 0.6 # Much lower confidence in volatile markets
}
# Dueling network support (requires enhanced network architecture)
self.use_dueling = True
# Prioritized experience replay parameters
self.use_prioritized_replay = priority_memory
self.alpha = 0.6 # Priority exponent
self.beta = 0.4 # Importance sampling exponent
self.beta_increment = 0.001
# Double DQN support
self.use_double_dqn = True
# Enhanced training features from EnhancedDQNAgent
self.target_update_freq = target_update # More descriptive name
self.training_steps = 0
self.gradient_clip_norm = 1.0 # Gradient clipping
# Enhanced statistics tracking
self.epsilon_history = []
self.td_errors = [] # Track TD errors for analysis
# Trade action fee and confidence thresholds
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
self.recent_actions = [] # Track recent actions to avoid oscillations
self.recent_actions = deque(maxlen=10)
self.recent_prices = deque(maxlen=20)
self.recent_rewards = deque(maxlen=100)
# Violent move detection
self.price_history = []
@ -173,6 +203,16 @@ class DQNAgent:
total_params = sum(p.numel() for p in self.policy_net.parameters())
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
# Position management for 2-action system
self.current_position = 0.0 # -1 (short), 0 (neutral), 1 (long)
self.position_entry_price = 0.0
self.position_entry_time = None
# Different thresholds for entry vs exit decisions
self.entry_confidence_threshold = 0.7 # High threshold for new positions
self.exit_confidence_threshold = 0.3 # Lower threshold for closing positions
self.uncertainty_threshold = 0.1 # When to stay neutral
def move_models_to_device(self, device=None):
"""Move models to the specified device (GPU/CPU)"""
if device is not None:
@ -290,247 +330,148 @@ class DQNAgent:
if len(self.price_movement_memory) > self.buffer_size // 4:
self.price_movement_memory = self.price_movement_memory[-(self.buffer_size // 4):]
def act(self, state: np.ndarray, explore=True) -> int:
"""Choose action using epsilon-greedy policy with explore flag"""
if explore and random.random() < self.epsilon:
return random.randrange(self.n_actions)
def act(self, state: np.ndarray, explore=True, current_price=None, market_context=None) -> int:
"""
Choose action based on current state using 2-action system with intelligent position management
with torch.no_grad():
# Enhance state with real-time tick features
enhanced_state = self._enhance_state_with_tick_features(state)
Args:
state: Current market state
explore: Whether to use epsilon-greedy exploration
current_price: Current market price for position management
market_context: Additional market context for decision making
# Ensure state is normalized before inference
state_tensor = self._normalize_state(enhanced_state)
state_tensor = torch.FloatTensor(state_tensor).unsqueeze(0).to(self.device)
# Get predictions using the policy network
self.policy_net.eval() # Set to evaluation mode for inference
action_probs, extrema_pred, price_predictions, hidden_features, advanced_predictions = self.policy_net(state_tensor)
self.policy_net.train() # Back to training mode
# Store hidden features for integration
self.last_hidden_features = hidden_features.cpu().numpy()
# Track feature history (limited size)
self.feature_history.append(hidden_features.cpu().numpy())
if len(self.feature_history) > 100:
self.feature_history = self.feature_history[-100:]
# Get the predicted extrema class (0=bottom, 1=top, 2=neither)
extrema_class = extrema_pred.argmax(dim=1).item()
extrema_confidence = torch.softmax(extrema_pred, dim=1)[0, extrema_class].item()
# Log extrema prediction for significant signals
if extrema_confidence > 0.7 and extrema_class != 2: # Only log strong top/bottom signals
extrema_type = "BOTTOM" if extrema_class == 0 else "TOP" if extrema_class == 1 else "NEITHER"
logger.info(f"High confidence {extrema_type} detected! Confidence: {extrema_confidence:.4f}")
# Process price predictions
price_immediate = torch.softmax(price_predictions['immediate'], dim=1)
price_midterm = torch.softmax(price_predictions['midterm'], dim=1)
price_longterm = torch.softmax(price_predictions['longterm'], dim=1)
price_values = price_predictions['values']
# Get predicted direction for each timeframe (0=down, 1=sideways, 2=up)
immediate_direction = price_immediate.argmax(dim=1).item()
midterm_direction = price_midterm.argmax(dim=1).item()
longterm_direction = price_longterm.argmax(dim=1).item()
# Get confidence levels
immediate_conf = price_immediate[0, immediate_direction].item()
midterm_conf = price_midterm[0, midterm_direction].item()
longterm_conf = price_longterm[0, longterm_direction].item()
# Get predicted price change percentages
price_changes = price_values[0].tolist()
# Log significant price movement predictions
timeframes = ["1s/1m", "1h", "1d", "1w"]
directions = ["DOWN", "SIDEWAYS", "UP"]
for i, (direction, conf) in enumerate([
(immediate_direction, immediate_conf),
(midterm_direction, midterm_conf),
(longterm_direction, longterm_conf)
]):
if conf > 0.7 and direction != 1: # Only log high confidence non-sideways predictions
logger.info(f"Price prediction: {timeframes[i]} -> {directions[direction]}, "
f"Confidence: {conf:.4f}, Expected change: {price_changes[i]:.2f}%")
# Store predictions for environment to use
self.last_extrema_pred = {
'class': extrema_class,
'confidence': extrema_confidence,
'raw': extrema_pred.cpu().numpy()
}
self.last_price_pred = {
'immediate': {
'direction': immediate_direction,
'confidence': immediate_conf,
'change': price_changes[0]
},
'midterm': {
'direction': midterm_direction,
'confidence': midterm_conf,
'change': price_changes[1]
},
'longterm': {
'direction': longterm_direction,
'confidence': longterm_conf,
'change': price_changes[2]
}
}
# Get the action with highest Q-value
action = action_probs.argmax().item()
# Calculate overall confidence in the action
q_values_softmax = F.softmax(action_probs, dim=1)[0]
action_confidence = q_values_softmax[action].item()
# Track confidence metrics
self.confidence_history.append(action_confidence)
if len(self.confidence_history) > 100:
self.confidence_history = self.confidence_history[-100:]
# Update confidence metrics
self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
self.max_confidence = max(self.max_confidence, action_confidence)
self.min_confidence = min(self.min_confidence, action_confidence)
# Log average confidence occasionally
if random.random() < 0.01: # 1% of the time
logger.info(f"Confidence metrics - Current: {action_confidence:.4f}, Avg: {self.avg_confidence:.4f}, " +
f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")
# Track price for violent move detection
try:
# Extract current price from state (assuming it's in the last position)
if len(state.shape) > 1: # For 2D state
current_price = state[-1, -1]
else: # For 1D state
current_price = state[-1]
self.price_history.append(current_price)
if len(self.price_history) > self.volatility_window:
self.price_history = self.price_history[-self.volatility_window:]
# Detect violent price moves if we have enough price history
if len(self.price_history) >= 5:
# Calculate short-term volatility
recent_prices = self.price_history[-5:]
# Make sure we're working with scalar values, not arrays
if isinstance(recent_prices[0], np.ndarray):
# If prices are arrays, extract the last value (current price)
recent_prices = [p[-1] if isinstance(p, np.ndarray) and p.size > 0 else p for p in recent_prices]
# Calculate price changes with protection against division by zero
price_changes = []
for i in range(1, len(recent_prices)):
if recent_prices[i-1] != 0 and not np.isnan(recent_prices[i-1]) and not np.isnan(recent_prices[i]):
change = (recent_prices[i] - recent_prices[i-1]) / recent_prices[i-1]
price_changes.append(change)
else:
price_changes.append(0.0)
# Calculate volatility as sum of absolute price changes
volatility = sum([abs(change) for change in price_changes])
# Check if we've had a violent move
if volatility > self.volatility_threshold:
logger.info(f"Violent price move detected! Volatility: {volatility:.6f}")
self.post_violent_move = True
self.violent_move_cooldown = 10 # Set cooldown period
# Handle post-violent move period
if self.post_violent_move:
if self.violent_move_cooldown > 0:
self.violent_move_cooldown -= 1
# Increase confidence threshold temporarily after violent moves
effective_threshold = self.minimum_action_confidence * 1.1
logger.info(f"Post-violent move period: {self.violent_move_cooldown} steps remaining. " +
f"Using higher confidence threshold: {effective_threshold:.4f}")
else:
self.post_violent_move = False
logger.info("Post-violent move period ended")
except Exception as e:
logger.warning(f"Error in violent move detection: {str(e)}")
# Apply trade action fee to buy/sell actions but not to hold
# This creates a threshold that must be exceeded to justify a trade
action_values = action_probs.clone()
# If BUY or SELL, apply fee by reducing the Q-value
if action == 0 or action == 1: # BUY or SELL
# Check if confidence is above minimum threshold
effective_threshold = self.minimum_action_confidence
if self.post_violent_move:
effective_threshold *= 1.1 # Higher threshold after violent moves
if action_confidence < effective_threshold:
# If confidence is below threshold, force HOLD action
logger.info(f"Action {action} confidence {action_confidence:.4f} below threshold {effective_threshold}, forcing HOLD")
action = 2 # HOLD
else:
# Apply trade action fee to ensure we only trade when there's clear benefit
fee_adjusted_action_values = action_values.clone()
fee_adjusted_action_values[0, 0] -= self.trade_action_fee # Reduce BUY value
fee_adjusted_action_values[0, 1] -= self.trade_action_fee # Reduce SELL value
# Hold value remains unchanged
# Re-determine the action based on fee-adjusted values
fee_adjusted_action = fee_adjusted_action_values.argmax().item()
# If the fee changes our decision, log this
if fee_adjusted_action != action:
logger.info(f"Trade action fee changed decision from {action} to {fee_adjusted_action}")
action = fee_adjusted_action
# Adjust action based on extrema and price predictions
# Prioritize short-term movement for trading decisions
if immediate_conf > 0.8: # Only adjust for strong signals
if immediate_direction == 2: # UP prediction
# Bias toward BUY for strong up predictions
if action != 0 and action != 2 and random.random() < 0.3 * immediate_conf:
logger.info(f"Adjusting action to BUY based on immediate UP prediction")
action = 0 # BUY
elif immediate_direction == 0: # DOWN prediction
# Bias toward SELL for strong down predictions
if action != 1 and action != 2 and random.random() < 0.3 * immediate_conf:
logger.info(f"Adjusting action to SELL based on immediate DOWN prediction")
action = 1 # SELL
# Also consider extrema detection for action adjustment
if extrema_confidence > 0.8: # Only adjust for strong signals
if extrema_class == 0: # Bottom detected
# Bias toward BUY at bottoms
if action != 0 and action != 2 and random.random() < 0.3 * extrema_confidence:
logger.info(f"Adjusting action to BUY based on bottom detection")
action = 0 # BUY
elif extrema_class == 1: # Top detected
# Bias toward SELL at tops
if action != 1 and action != 2 and random.random() < 0.3 * extrema_confidence:
logger.info(f"Adjusting action to SELL based on top detection")
action = 1 # SELL
# Finally, avoid action oscillation by checking recent history
if len(self.recent_actions) >= 2:
last_action = self.recent_actions[-1]
if action != last_action and action != 2 and last_action != 2:
# We're switching between BUY and SELL too quickly
# Only allow this if we have very high confidence
if action_confidence < 0.85:
logger.info(f"Preventing oscillation from {last_action} to {action}, forcing HOLD")
action = 2 # HOLD
# Update recent actions list
Returns:
int: Action (0=SELL, 1=BUY) or None if should hold position
"""
# Convert state to tensor
if isinstance(state, np.ndarray):
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
else:
state_tensor = state.unsqueeze(0).to(self.device)
# Get Q-values
q_values = self.policy_net(state_tensor)
action_values = q_values.cpu().data.numpy()[0]
# Calculate confidence scores
sell_confidence = torch.softmax(q_values, dim=1)[0, 0].item()
buy_confidence = torch.softmax(q_values, dim=1)[0, 1].item()
# Determine action based on current position and confidence thresholds
action = self._determine_action_with_position_management(
sell_confidence, buy_confidence, current_price, market_context, explore
)
# Update tracking
if current_price:
self.recent_prices.append(current_price)
if action is not None:
self.recent_actions.append(action)
if len(self.recent_actions) > 5:
self.recent_actions = self.recent_actions[-5:]
return action
else:
# Return None to indicate HOLD (don't change position)
return None
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
"""Choose action with confidence score adapted to market regime (from Enhanced DQN)"""
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
q_values = self.policy_net(state_tensor)
# Convert Q-values to probabilities
action_probs = torch.softmax(q_values, dim=1)
action = q_values.argmax().item()
base_confidence = action_probs[0, action].item()
# Adapt confidence based on market regime
regime_weight = self.market_regime_weights.get(market_regime, 1.0)
adapted_confidence = min(base_confidence * regime_weight, 1.0)
return action, adapted_confidence
def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
"""
Determine action based on current position and confidence thresholds
This implements the intelligent position management where:
- When neutral: Need high confidence to enter position
- When in position: Need lower confidence to exit
- Different thresholds for entry vs exit
"""
# Apply epsilon-greedy exploration
if explore and np.random.random() <= self.epsilon:
return np.random.choice([0, 1])
# Get the dominant signal
dominant_action = 0 if sell_conf > buy_conf else 1
dominant_confidence = max(sell_conf, buy_conf)
# Decision logic based on current position
if self.current_position == 0: # No position - need high confidence to enter
if dominant_confidence >= self.entry_confidence_threshold:
# Strong enough signal to enter position
if dominant_action == 1: # BUY signal
self.current_position = 1.0
self.position_entry_price = current_price
self.position_entry_time = time.time()
logger.info(f"ENTERING LONG position at {current_price:.4f} with confidence {dominant_confidence:.4f}")
return 1
else: # SELL signal
self.current_position = -1.0
self.position_entry_price = current_price
self.position_entry_time = time.time()
logger.info(f"ENTERING SHORT position at {current_price:.4f} with confidence {dominant_confidence:.4f}")
return 0
else:
# Not confident enough to enter position
return None
elif self.current_position > 0: # Long position
if dominant_action == 0 and dominant_confidence >= self.exit_confidence_threshold:
# SELL signal with enough confidence to close long position
pnl = (current_price - self.position_entry_price) / self.position_entry_price if current_price and self.position_entry_price else 0
logger.info(f"CLOSING LONG position at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
self.current_position = 0.0
self.position_entry_price = 0.0
self.position_entry_time = None
return 0
elif dominant_action == 0 and dominant_confidence >= self.entry_confidence_threshold:
# Very strong SELL signal - close long and enter short
pnl = (current_price - self.position_entry_price) / self.position_entry_price if current_price and self.position_entry_price else 0
logger.info(f"FLIPPING from LONG to SHORT at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
self.current_position = -1.0
self.position_entry_price = current_price
self.position_entry_time = time.time()
return 0
else:
# Hold the long position
return None
elif self.current_position < 0: # Short position
if dominant_action == 1 and dominant_confidence >= self.exit_confidence_threshold:
# BUY signal with enough confidence to close short position
pnl = (self.position_entry_price - current_price) / self.position_entry_price if current_price and self.position_entry_price else 0
logger.info(f"CLOSING SHORT position at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
self.current_position = 0.0
self.position_entry_price = 0.0
self.position_entry_time = None
return 1
elif dominant_action == 1 and dominant_confidence >= self.entry_confidence_threshold:
# Very strong BUY signal - close short and enter long
pnl = (self.position_entry_price - current_price) / self.position_entry_price if current_price and self.position_entry_price else 0
logger.info(f"FLIPPING from SHORT to LONG at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
self.current_position = 1.0
self.position_entry_price = current_price
self.position_entry_time = time.time()
return 1
else:
# Hold the short position
return None
return None
def replay(self, experiences=None):
"""Train the model using experiences from memory"""
@ -658,10 +599,18 @@ class DQNAgent:
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Get next Q values with target network
# Enhanced Double DQN implementation
with torch.no_grad():
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]
if self.use_double_dqn:
# Double DQN: Use policy network to select actions, target network to evaluate
policy_q_values, _, _, _, _ = self.policy_net(next_states)
next_actions = policy_q_values.argmax(1)
target_q_values_all, _, _, _, _ = self.target_net(next_states)
next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
else:
# Standard DQN: Use target network for both selection and evaluation
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]
# Check for dimension mismatch between rewards and next_q_values
if rewards.shape[0] != next_q_values.shape[0]:
@ -699,16 +648,25 @@ class DQNAgent:
# Backward pass
total_loss.backward()
# Clip gradients to avoid exploding gradients
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
# Enhanced gradient clipping with configurable norm
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.gradient_clip_norm)
# Update weights
self.optimizer.step()
# Update target network if needed
self.update_count += 1
if self.update_count % self.target_update == 0:
# Enhanced target network update tracking
self.training_steps += 1
if self.training_steps % self.target_update_freq == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
logger.debug(f"Target network updated at step {self.training_steps}")
# Enhanced statistics tracking
self.epsilon_history.append(self.epsilon)
# Calculate and store TD error for analysis
with torch.no_grad():
td_error = torch.abs(current_q_values - target_q_values).mean().item()
self.td_errors.append(td_error)
# Return loss
return total_loss.item()
@ -1168,4 +1126,40 @@ class DQNAgent:
logger.info(f"Agent state loaded from {path}_agent_state.pt")
except FileNotFoundError:
logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
def get_position_info(self):
"""Get current position information"""
return {
'position': self.current_position,
'entry_price': self.position_entry_price,
'entry_time': self.position_entry_time,
'entry_threshold': self.entry_confidence_threshold,
'exit_threshold': self.exit_confidence_threshold
}
def get_enhanced_training_stats(self):
"""Get enhanced RL training statistics with detailed metrics (from EnhancedDQNAgent)"""
return {
'buffer_size': len(self.memory),
'epsilon': self.epsilon,
'avg_reward': self.avg_reward,
'best_reward': self.best_reward,
'recent_rewards': list(self.recent_rewards) if hasattr(self, 'recent_rewards') else [],
'no_improvement_count': self.no_improvement_count,
# Enhanced statistics from EnhancedDQNAgent
'training_steps': self.training_steps,
'avg_td_error': np.mean(self.td_errors[-100:]) if self.td_errors else 0.0,
'recent_losses': self.losses[-10:] if self.losses else [],
'epsilon_trend': self.epsilon_history[-20:] if self.epsilon_history else [],
'specialized_buffers': {
'extrema_memory': len(self.extrema_memory),
'positive_memory': len(self.positive_memory),
'price_movement_memory': len(self.price_movement_memory)
},
'market_regime_weights': self.market_regime_weights,
'use_double_dqn': self.use_double_dqn,
'use_prioritized_replay': self.use_prioritized_replay,
'gradient_clip_norm': self.gradient_clip_norm,
'target_update_frequency': self.target_update_freq
}

View File

@ -1,329 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List
import os
import sys
import logging
import torch.nn.functional as F
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
# Import the EnhancedCNN model
from NN.models.enhanced_cnn import EnhancedCNN, ExampleSiftingDataset
# Configure logger
logger = logging.getLogger(__name__)
class EnhancedDQNAgent:
"""
Enhanced Deep Q-Network agent for trading
Uses the improved EnhancedCNN model with residual connections and attention mechanisms
"""
def __init__(self,
state_shape: Tuple[int, ...],
n_actions: int,
learning_rate: float = 0.0003, # Slightly reduced learning rate for stability
gamma: float = 0.95, # Discount factor
epsilon: float = 1.0,
epsilon_min: float = 0.05,
epsilon_decay: float = 0.995, # Slower decay for more exploration
buffer_size: int = 50000, # Larger memory buffer
batch_size: int = 128, # Larger batch size
target_update: int = 10, # More frequent target updates
confidence_threshold: float = 0.4, # Lower confidence threshold
device=None):
# Extract state dimensions
if isinstance(state_shape, tuple) and len(state_shape) > 1:
# Multi-dimensional state (like image or sequence)
self.state_dim = state_shape
else:
# 1D state
if isinstance(state_shape, tuple):
self.state_dim = state_shape[0]
else:
self.state_dim = state_shape
# Store parameters
self.n_actions = n_actions
self.learning_rate = learning_rate
self.gamma = gamma
self.epsilon = epsilon
self.epsilon_min = epsilon_min
self.epsilon_decay = epsilon_decay
self.buffer_size = buffer_size
self.batch_size = batch_size
self.target_update = target_update
self.confidence_threshold = confidence_threshold
# Set device for computation
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = device
# Initialize models with the enhanced CNN
self.policy_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)
self.target_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)
# Initialize the target network with the same weights as the policy network
self.target_net.load_state_dict(self.policy_net.state_dict())
# Set models to eval mode (important for batch norm, dropout)
self.target_net.eval()
# Optimization components
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
self.criterion = nn.MSELoss()
# Experience replay memory with example sifting
self.memory = ExampleSiftingDataset(max_examples=buffer_size)
self.update_count = 0
# Confidence tracking
self.confidence_history = []
self.avg_confidence = 0.0
self.max_confidence = 0.0
self.min_confidence = 1.0
# Performance tracking
self.losses = []
self.rewards = []
self.avg_reward = 0.0
# Check if mixed precision training should be used
self.use_mixed_precision = False
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
self.use_mixed_precision = True
self.scaler = torch.cuda.amp.GradScaler()
logger.info("Mixed precision training enabled")
else:
logger.info("Mixed precision training disabled")
# For compatibility with old code
self.action_size = n_actions
logger.info(f"Enhanced DQN Agent using device: {self.device}")
logger.info(f"Confidence threshold set to {self.confidence_threshold}")
def move_models_to_device(self, device=None):
"""Move models to the specified device (GPU/CPU)"""
if device is not None:
self.device = device
try:
self.policy_net = self.policy_net.to(self.device)
self.target_net = self.target_net.to(self.device)
logger.info(f"Moved models to {self.device}")
return True
except Exception as e:
logger.error(f"Failed to move models to {self.device}: {str(e)}")
return False
def _normalize_state(self, state):
"""Normalize state for better training stability"""
try:
# Convert to numpy array if needed
if isinstance(state, list):
state = np.array(state, dtype=np.float32)
# Apply normalization based on state shape
if len(state.shape) > 1:
# Multi-dimensional state - normalize each feature dimension separately
for i in range(state.shape[0]):
# Skip if all zeros (to avoid division by zero)
if np.sum(np.abs(state[i])) > 0:
# Standardize each feature dimension
mean = np.mean(state[i])
std = np.std(state[i])
if std > 0:
state[i] = (state[i] - mean) / std
else:
# 1D state vector
# Skip if all zeros
if np.sum(np.abs(state)) > 0:
mean = np.mean(state)
std = np.std(state)
if std > 0:
state = (state - mean) / std
return state
except Exception as e:
logger.warning(f"Error normalizing state: {str(e)}")
return state
def remember(self, state, action, reward, next_state, done):
"""Store experience in memory with example sifting"""
self.memory.add_example(state, action, reward, next_state, done)
# Also track rewards for monitoring
self.rewards.append(reward)
if len(self.rewards) > 100:
self.rewards = self.rewards[-100:]
self.avg_reward = np.mean(self.rewards)
def act(self, state, explore=True):
"""Choose action using epsilon-greedy policy with built-in confidence thresholding"""
if explore and random.random() < self.epsilon:
return random.randrange(self.n_actions), 0.0 # Return action and zero confidence
# Normalize state before inference
normalized_state = self._normalize_state(state)
# Use the EnhancedCNN's act method which includes confidence thresholding
action, confidence = self.policy_net.act(normalized_state, explore=explore)
# Track confidence metrics
self.confidence_history.append(confidence)
if len(self.confidence_history) > 100:
self.confidence_history = self.confidence_history[-100:]
# Update confidence metrics
self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
self.max_confidence = max(self.max_confidence, confidence)
self.min_confidence = min(self.min_confidence, confidence)
# Log average confidence occasionally
if random.random() < 0.01: # 1% of the time
logger.info(f"Confidence metrics - Current: {confidence:.4f}, Avg: {self.avg_confidence:.4f}, " +
f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")
return action, confidence
def replay(self):
"""Train the model using experience replay with high-quality examples"""
# Check if enough samples in memory
if len(self.memory) < self.batch_size:
return 0.0
# Get batch of experiences
batch = self.memory.get_batch(self.batch_size)
if batch is None:
return 0.0
states = torch.FloatTensor(batch['states']).to(self.device)
actions = torch.LongTensor(batch['actions']).to(self.device)
rewards = torch.FloatTensor(batch['rewards']).to(self.device)
next_states = torch.FloatTensor(batch['next_states']).to(self.device)
dones = torch.FloatTensor(batch['dones']).to(self.device)
# Compute Q values
self.policy_net.train() # Set to training mode
# Get current Q values
if self.use_mixed_precision:
with torch.cuda.amp.autocast():
# Get current Q values
q_values, _, _, _ = self.policy_net(states)
current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Compute target Q values
with torch.no_grad():
self.target_net.eval()
next_q_values, _, _, _ = self.target_net(next_states)
next_q = next_q_values.max(1)[0]
target_q = rewards + (1 - dones) * self.gamma * next_q
# Compute loss
loss = self.criterion(current_q, target_q)
# Perform backpropagation with mixed precision
self.optimizer.zero_grad()
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# Standard precision training
# Get current Q values
q_values, _, _, _ = self.policy_net(states)
current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Compute target Q values
with torch.no_grad():
self.target_net.eval()
next_q_values, _, _, _ = self.target_net(next_states)
next_q = next_q_values.max(1)[0]
target_q = rewards + (1 - dones) * self.gamma * next_q
# Compute loss
loss = self.criterion(current_q, target_q)
# Perform backpropagation
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
self.optimizer.step()
# Track loss
loss_value = loss.item()
self.losses.append(loss_value)
if len(self.losses) > 100:
self.losses = self.losses[-100:]
# Update target network
self.update_count += 1
if self.update_count % self.target_update == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
logger.info(f"Updated target network (step {self.update_count})")
# Decay epsilon
if self.epsilon > self.epsilon_min:
self.epsilon *= self.epsilon_decay
return loss_value
def save(self, path):
"""Save agent state and models"""
self.policy_net.save(f"{path}_policy")
self.target_net.save(f"{path}_target")
# Save agent state
torch.save({
'epsilon': self.epsilon,
'confidence_threshold': self.confidence_threshold,
'losses': self.losses,
'rewards': self.rewards,
'avg_reward': self.avg_reward,
'confidence_history': self.confidence_history,
'avg_confidence': self.avg_confidence,
'max_confidence': self.max_confidence,
'min_confidence': self.min_confidence,
'update_count': self.update_count
}, f"{path}_agent_state.pt")
logger.info(f"Agent state saved to {path}_agent_state.pt")
def load(self, path):
"""Load agent state and models"""
policy_loaded = self.policy_net.load(f"{path}_policy")
target_loaded = self.target_net.load(f"{path}_target")
# Load agent state if available
agent_state_path = f"{path}_agent_state.pt"
if os.path.exists(agent_state_path):
try:
state = torch.load(agent_state_path)
self.epsilon = state.get('epsilon', self.epsilon)
self.confidence_threshold = state.get('confidence_threshold', self.confidence_threshold)
self.policy_net.confidence_threshold = self.confidence_threshold
self.target_net.confidence_threshold = self.confidence_threshold
self.losses = state.get('losses', [])
self.rewards = state.get('rewards', [])
self.avg_reward = state.get('avg_reward', 0.0)
self.confidence_history = state.get('confidence_history', [])
self.avg_confidence = state.get('avg_confidence', 0.0)
self.max_confidence = state.get('max_confidence', 0.0)
self.min_confidence = state.get('min_confidence', 1.0)
self.update_count = state.get('update_count', 0)
logger.info(f"Agent state loaded from {agent_state_path}")
except Exception as e:
logger.error(f"Error loading agent state: {str(e)}")
return policy_loaded and target_loaded

View File

@ -110,96 +110,119 @@ class EnhancedCNN(nn.Module):
logger.info(f"EnhancedCNN initialized with input shape: {input_shape}, actions: {n_actions}")
def _build_network(self):
"""Build the MASSIVELY enhanced neural network for 4GB VRAM budget"""
"""Build the ULTRA MASSIVE enhanced neural network for maximum learning capacity"""
# MASSIVELY SCALED ARCHITECTURE for 4GB VRAM (up to ~50M parameters)
# ULTRA MASSIVE SCALED ARCHITECTURE for maximum learning (up to ~100M parameters)
if self.channels > 1:
# Massive convolutional backbone with deeper residual blocks
# Ultra massive convolutional backbone with much deeper residual blocks
self.conv_layers = nn.Sequential(
# Initial large conv block
nn.Conv1d(self.channels, 256, kernel_size=7, padding=3), # Much wider initial layer
nn.BatchNorm1d(256),
# Initial ultra large conv block
nn.Conv1d(self.channels, 512, kernel_size=7, padding=3), # Ultra wide initial layer
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Dropout(0.1),
# First residual stage - 256 channels
ResidualBlock(256, 512),
ResidualBlock(512, 512),
ResidualBlock(512, 512),
# First residual stage - 512 channels
ResidualBlock(512, 768),
ResidualBlock(768, 768),
ResidualBlock(768, 768),
ResidualBlock(768, 768), # Additional layer
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(0.2),
# Second residual stage - 512 channels
ResidualBlock(512, 1024),
# Second residual stage - 768 to 1024 channels
ResidualBlock(768, 1024),
ResidualBlock(1024, 1024),
ResidualBlock(1024, 1024),
ResidualBlock(1024, 1024), # Additional layer
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(0.25),
# Third residual stage - 1024 channels
# Third residual stage - 1024 to 1536 channels
ResidualBlock(1024, 1536),
ResidualBlock(1536, 1536),
ResidualBlock(1536, 1536),
ResidualBlock(1536, 1536), # Additional layer
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(0.3),
# Fourth residual stage - 1536 channels (MASSIVE)
# Fourth residual stage - 1536 to 2048 channels
ResidualBlock(1536, 2048),
ResidualBlock(2048, 2048),
ResidualBlock(2048, 2048),
ResidualBlock(2048, 2048), # Additional layer
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Dropout(0.3),
# Fifth residual stage - ULTRA MASSIVE 2048 to 3072 channels
ResidualBlock(2048, 3072),
ResidualBlock(3072, 3072),
ResidualBlock(3072, 3072),
ResidualBlock(3072, 3072),
nn.AdaptiveAvgPool1d(1) # Global average pooling
)
# Massive feature dimension after conv layers
self.conv_features = 2048
# Ultra massive feature dimension after conv layers
self.conv_features = 3072
else:
# For 1D vectors, use massive dense preprocessing
# For 1D vectors, use ultra massive dense preprocessing
self.conv_layers = None
self.conv_features = 0
# MASSIVE fully connected feature extraction layers
# ULTRA MASSIVE fully connected feature extraction layers
if self.conv_layers is None:
# For 1D inputs - massive feature extraction
self.fc1 = nn.Linear(self.feature_dim, 2048)
self.features_dim = 2048
# For 1D inputs - ultra massive feature extraction
self.fc1 = nn.Linear(self.feature_dim, 3072)
self.features_dim = 3072
else:
# For data processed by massive conv layers
self.fc1 = nn.Linear(self.conv_features, 2048)
self.features_dim = 2048
# For data processed by ultra massive conv layers
self.fc1 = nn.Linear(self.conv_features, 3072)
self.features_dim = 3072
# MASSIVE common feature extraction with multiple attention layers
# ULTRA MASSIVE common feature extraction with multiple deep layers
self.fc_layers = nn.Sequential(
self.fc1,
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(2048, 2048), # Keep massive width
nn.Linear(3072, 3072), # Keep ultra massive width
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(2048, 1536), # Still very wide
nn.Linear(3072, 2560), # Ultra wide hidden layer
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1536, 1024), # Large hidden layer
nn.Linear(2560, 2048), # Still very wide
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1024, 768), # Final feature representation
nn.Linear(2048, 1536), # Large hidden layer
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1536, 1024), # Final feature representation
nn.ReLU()
)
# Multiple attention mechanisms for different aspects
self.price_attention = SelfAttention(768)
self.volume_attention = SelfAttention(768)
self.trend_attention = SelfAttention(768)
self.volatility_attention = SelfAttention(768)
# Multiple attention mechanisms for different aspects (larger capacity)
self.price_attention = SelfAttention(1024) # Increased from 768
self.volume_attention = SelfAttention(1024)
self.trend_attention = SelfAttention(1024)
self.volatility_attention = SelfAttention(1024)
self.momentum_attention = SelfAttention(1024) # Additional attention
self.microstructure_attention = SelfAttention(1024) # Additional attention
# Attention fusion layer
# Ultra massive attention fusion layer
self.attention_fusion = nn.Sequential(
nn.Linear(768 * 4, 1024), # Combine all attention outputs
nn.Linear(1024 * 6, 2048), # Combine all 6 attention outputs
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1024, 768)
nn.Linear(2048, 1536),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1536, 1024)
)
# MASSIVE dueling architecture with deeper networks
# ULTRA MASSIVE dueling architecture with much deeper networks
self.advantage_stream = nn.Sequential(
nn.Linear(1024, 768),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(768, 512),
nn.ReLU(),
nn.Dropout(0.3),
@ -212,6 +235,9 @@ class EnhancedCNN(nn.Module):
)
self.value_stream = nn.Sequential(
nn.Linear(1024, 768),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(768, 512),
nn.ReLU(),
nn.Dropout(0.3),
@ -223,8 +249,11 @@ class EnhancedCNN(nn.Module):
nn.Linear(128, 1)
)
# MASSIVE extrema detection head with ensemble predictions
# ULTRA MASSIVE extrema detection head with deeper ensemble predictions
self.extrema_head = nn.Sequential(
nn.Linear(1024, 768),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(768, 512),
nn.ReLU(),
nn.Dropout(0.3),
@ -236,9 +265,12 @@ class EnhancedCNN(nn.Module):
nn.Linear(128, 3) # 0=bottom, 1=top, 2=neither
)
# MASSIVE multi-timeframe price prediction heads
# ULTRA MASSIVE multi-timeframe price prediction heads
self.price_pred_immediate = nn.Sequential(
nn.Linear(768, 256),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 128),
@ -247,7 +279,10 @@ class EnhancedCNN(nn.Module):
)
self.price_pred_midterm = nn.Sequential(
nn.Linear(768, 256),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 128),
@ -256,7 +291,10 @@ class EnhancedCNN(nn.Module):
)
self.price_pred_longterm = nn.Sequential(
nn.Linear(768, 256),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 128),
@ -264,8 +302,11 @@ class EnhancedCNN(nn.Module):
nn.Linear(128, 3) # Up, Down, Sideways
)
# MASSIVE value prediction with ensemble approaches
# ULTRA MASSIVE value prediction with ensemble approaches
self.price_pred_value = nn.Sequential(
nn.Linear(1024, 768),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(768, 512),
nn.ReLU(),
nn.Dropout(0.3),
@ -280,7 +321,10 @@ class EnhancedCNN(nn.Module):
# Additional specialized prediction heads for better accuracy
# Volatility prediction head
self.volatility_head = nn.Sequential(
nn.Linear(768, 256),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 128),
@ -290,7 +334,10 @@ class EnhancedCNN(nn.Module):
# Support/Resistance level detection head
self.support_resistance_head = nn.Sequential(
nn.Linear(768, 256),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 128),
@ -300,7 +347,10 @@ class EnhancedCNN(nn.Module):
# Market regime classification head
self.market_regime_head = nn.Sequential(
nn.Linear(768, 256),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 128),
@ -310,7 +360,10 @@ class EnhancedCNN(nn.Module):
# Risk assessment head
self.risk_head = nn.Sequential(
nn.Linear(768, 256),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 128),
@ -330,7 +383,7 @@ class EnhancedCNN(nn.Module):
return False
def forward(self, x):
"""Forward pass through the MASSIVE network"""
"""Forward pass through the ULTRA MASSIVE network"""
batch_size = x.size(0)
# Process different input shapes
@ -349,7 +402,7 @@ class EnhancedCNN(nn.Module):
total_features = x_reshaped.size(1) * x_reshaped.size(2)
self._check_rebuild_network(total_features)
# Apply massive convolutions
# Apply ultra massive convolutions
x_conv = self.conv_layers(x_reshaped)
# Flatten: [batch, channels, 1] -> [batch, channels]
x_flat = x_conv.view(batch_size, -1)
@ -364,33 +417,40 @@ class EnhancedCNN(nn.Module):
if x_flat.size(1) != self.feature_dim:
self._check_rebuild_network(x_flat.size(1))
# Apply MASSIVE FC layers to get base features
features = self.fc_layers(x_flat) # [batch, 768]
# Apply ULTRA MASSIVE FC layers to get base features
features = self.fc_layers(x_flat) # [batch, 1024]
# Apply multiple specialized attention mechanisms
features_3d = features.unsqueeze(1) # [batch, 1, 768]
features_3d = features.unsqueeze(1) # [batch, 1, 1024]
# Get attention-refined features for different aspects
price_features, _ = self.price_attention(features_3d)
price_features = price_features.squeeze(1) # [batch, 768]
price_features = price_features.squeeze(1) # [batch, 1024]
volume_features, _ = self.volume_attention(features_3d)
volume_features = volume_features.squeeze(1) # [batch, 768]
volume_features = volume_features.squeeze(1) # [batch, 1024]
trend_features, _ = self.trend_attention(features_3d)
trend_features = trend_features.squeeze(1) # [batch, 768]
trend_features = trend_features.squeeze(1) # [batch, 1024]
volatility_features, _ = self.volatility_attention(features_3d)
volatility_features = volatility_features.squeeze(1) # [batch, 768]
volatility_features = volatility_features.squeeze(1) # [batch, 1024]
momentum_features, _ = self.momentum_attention(features_3d)
momentum_features = momentum_features.squeeze(1) # [batch, 1024]
microstructure_features, _ = self.microstructure_attention(features_3d)
microstructure_features = microstructure_features.squeeze(1) # [batch, 1024]
# Fuse all attention outputs
combined_attention = torch.cat([
price_features, volume_features,
trend_features, volatility_features
], dim=1) # [batch, 768*4]
trend_features, volatility_features,
momentum_features, microstructure_features
], dim=1) # [batch, 1024*6]
# Apply attention fusion to get final refined features
features_refined = self.attention_fusion(combined_attention) # [batch, 768]
features_refined = self.attention_fusion(combined_attention) # [batch, 1024]
# Calculate advantage and value (Dueling DQN architecture)
advantage = self.advantage_stream(features_refined)
@ -399,7 +459,7 @@ class EnhancedCNN(nn.Module):
# Combine for Q-values (Dueling architecture)
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
# Get massive ensemble of predictions
# Get ultra massive ensemble of predictions
# Extrema predictions (bottom/top/neither detection)
extrema_pred = self.extrema_head(features_refined)
@ -435,7 +495,7 @@ class EnhancedCNN(nn.Module):
return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions
def act(self, state, explore=True):
"""Enhanced action selection with massive model predictions"""
"""Enhanced action selection with ultra massive model predictions"""
if explore and np.random.random() < 0.1: # 10% random exploration
return np.random.choice(self.n_actions)
@ -471,7 +531,7 @@ class EnhancedCNN(nn.Module):
risk_class = torch.argmax(risk, dim=1).item()
risk_labels = ['Low Risk', 'Medium Risk', 'High Risk', 'Extreme Risk']
logger.info(f"MASSIVE Model Predictions:")
logger.info(f"ULTRA MASSIVE Model Predictions:")
logger.info(f" Volatility: {volatility_labels[volatility_class]} ({volatility[0, volatility_class]:.3f})")
logger.info(f" Support/Resistance: {sr_labels[sr_class]} ({sr[0, sr_class]:.3f})")
logger.info(f" Market Regime: {regime_labels[regime_class]} ({regime[0, regime_class]:.3f})")