added leverage slider
@@ -14,15 +14,19 @@ class TradingEnvironment(gym.Env):
    """
    Trading environment implementing gym interface for reinforcement learning

    Actions:
    - 0: Buy
    - 1: Sell
    - 2: Hold
    2-Action System:
    - 0: SELL (or close long position)
    - 1: BUY (or close short position)

    Intelligent Position Management:
    - When neutral: Actions enter positions
    - When positioned: Actions can close or flip positions
    - Different thresholds for entry vs exit decisions

    State:
    - OHLCV data from multiple timeframes
    - Technical indicators
    - Position data
    - Position data and unrealized PnL
    """

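For orientation, a minimal sketch of the 2-action semantics described above, mirroring the step() logic later in this diff (the helper name is illustrative, not part of the commit; position flips are noted there as not yet implemented):

# Illustrative only -- not code from this commit.
def describe_transition(position: float, action: int) -> str:
    """Action 0 = SELL, action 1 = BUY; the effect depends on the current position."""
    if action == 1:  # BUY
        if position == 0:
            return "enter long"
        if position < 0:
            return "close short"
        return "hold long"   # already long; flips are left as future work
    else:  # SELL
        if position == 0:
            return "enter short"
        if position > 0:
            return "close long"
        return "hold short"  # already short

assert describe_transition(0, 1) == "enter long"
assert describe_transition(1.0, 0) == "close long"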
    def __init__(
@@ -33,9 +37,11 @@ class TradingEnvironment(gym.Env):
        window_size: int = 20,
        max_position: float = 1.0,
        reward_scaling: float = 1.0,
        entry_threshold: float = 0.6,  # Higher threshold for entering positions
        exit_threshold: float = 0.3,  # Lower threshold for exiting positions
    ):
        """
        Initialize the trading environment.
        Initialize the trading environment with 2-action system.

        Args:
            data_interface: DataInterface instance to get market data
@@ -44,6 +50,8 @@ class TradingEnvironment(gym.Env):
            window_size: Number of candles in the observation window
            max_position: Maximum position size as a fraction of balance
            reward_scaling: Scale factor for rewards
            entry_threshold: Confidence threshold for entering new positions
            exit_threshold: Confidence threshold for exiting positions
        """
        super().__init__()

@@ -53,21 +61,23 @@ class TradingEnvironment(gym.Env):
        self.window_size = window_size
        self.max_position = max_position
        self.reward_scaling = reward_scaling
        self.entry_threshold = entry_threshold
        self.exit_threshold = exit_threshold

        # Load data for primary timeframe (assuming the first one is primary)
        self.timeframe = self.data_interface.timeframes[0]
        self.reset_data()

        # Define action and observation spaces
        self.action_space = spaces.Discrete(3)  # Buy, Sell, Hold
        # Define action and observation spaces for 2-action system
        self.action_space = spaces.Discrete(2)  # 0=SELL, 1=BUY

        # For observation space, we consider multiple timeframes with OHLCV data
        # and additional features like technical indicators, position info, etc.
        n_timeframes = len(self.data_interface.timeframes)
        n_features = 5  # OHLCV data by default

        # Add additional features for position, balance, etc.
        additional_features = 3  # position, balance, unrealized_pnl
        # Add additional features for position, balance, unrealized_pnl, etc.
        additional_features = 5  # position, balance, unrealized_pnl, entry_price, position_duration

        # Calculate total feature dimension
        total_features = (n_timeframes * n_features * self.window_size) + additional_features
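A worked example of the observation size implied by the formula above (numbers hypothetical): with 3 timeframes, 5 OHLCV features, a 20-candle window and the 5 additional position features, the flat state has 3 * 5 * 20 + 5 = 305 dimensions.

# Hypothetical sizes plugged into the total_features formula above.
n_timeframes, n_features, window_size, additional_features = 3, 5, 20, 5
total_features = n_timeframes * n_features * window_size + additional_features
assert total_features == 305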
@@ -79,6 +89,11 @@ class TradingEnvironment(gym.Env):
        # Use tuple for state_shape that EnhancedCNN expects
        self.state_shape = (total_features,)

        # Position tracking for 2-action system
        self.position = 0.0  # -1 (short), 0 (neutral), 1 (long)
        self.entry_price = 0.0  # Price at which position was entered
        self.entry_step = 0  # Step at which position was entered

        # Initialize state
        self.reset()

@@ -103,9 +118,6 @@ class TradingEnvironment(gym.Env):
        """Reset the environment to initial state"""
        # Reset trading variables
        self.balance = self.initial_balance
        self.position = 0.0  # No position initially
        self.entry_price = 0.0
        self.total_pnl = 0.0
        self.trades = []
        self.rewards = []

@@ -119,10 +131,10 @@ class TradingEnvironment(gym.Env):
    def step(self, action):
        """
        Take a step in the environment.
        Take a step in the environment using 2-action system with intelligent position management.

        Args:
            action: Action to take (0: Buy, 1: Sell, 2: Hold)
            action: Action to take (0: SELL, 1: BUY)

        Returns:
            tuple: (observation, reward, done, info)
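A sketch of the interaction loop the (observation, reward, done, info) contract above implies; the stub environment here is a stand-in assuming nothing beyond the docstring (real construction elided):

import random

class _StubEnv:
    """Toy stand-in for TradingEnvironment; only the reset/step contract matters."""
    def __init__(self, n_steps=5):
        self.n_steps, self.t = n_steps, 0
    def reset(self):
        self.t = 0
        return [0.0]                                   # placeholder observation
    def step(self, action):
        self.t += 1
        return [0.0], 0.0, self.t >= self.n_steps, {}

env = _StubEnv()
obs, done = env.reset(), False
while not done:
    action = random.choice([0, 1])                     # 0=SELL, 1=BUY
    obs, reward, done, info = env.step(action)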
@@ -132,7 +144,7 @@ class TradingEnvironment(gym.Env):
        prev_position = self.position
        prev_price = self.prices[self.current_step]

        # Take action
        # Take action with intelligent position management
        info = {}
        reward = 0
        last_position_info = None
@@ -141,43 +153,50 @@ class TradingEnvironment(gym.Env):
        current_price = self.prices[self.current_step]
        next_price = self.prices[self.current_step + 1] if self.current_step + 1 < len(self.prices) else current_price

        # Process the action
        if action == 0:  # Buy
            if self.position <= 0:  # Only buy if not already long
                # Close any existing short position
                if self.position < 0:
                    close_pnl, last_position_info = self._close_position(current_price)
                    reward += close_pnl * self.reward_scaling

                # Open new long position
                self._open_position(1.0 * self.max_position, current_price)
                logger.info(f"Buy at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")

        elif action == 1:  # Sell
            if self.position >= 0:  # Only sell if not already short
                # Close any existing long position
                if self.position > 0:
                    close_pnl, last_position_info = self._close_position(current_price)
                    reward += close_pnl * self.reward_scaling

                # Open new short position
        # Implement 2-action system with position management
        if action == 0:  # SELL action
            if self.position == 0:  # No position - enter short
                self._open_position(-1.0 * self.max_position, current_price)
                logger.info(f"Sell at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")
                logger.info(f"ENTER SHORT at step {self.current_step}, price: {current_price:.4f}")
                reward = -self.transaction_fee  # Entry cost

            elif self.position > 0:  # Long position - close it
                close_pnl, last_position_info = self._close_position(current_price)
                reward += close_pnl * self.reward_scaling
                logger.info(f"CLOSE LONG at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")

            elif self.position < 0:  # Already short - potentially flip to long if very strong signal
                # For now, just hold the short position (no action)
                pass

        elif action == 2:  # Hold
            # No action, but still calculate unrealized PnL for reward
            pass
        elif action == 1:  # BUY action
            if self.position == 0:  # No position - enter long
                self._open_position(1.0 * self.max_position, current_price)
                logger.info(f"ENTER LONG at step {self.current_step}, price: {current_price:.4f}")
                reward = -self.transaction_fee  # Entry cost

            elif self.position < 0:  # Short position - close it
                close_pnl, last_position_info = self._close_position(current_price)
                reward += close_pnl * self.reward_scaling
                logger.info(f"CLOSE SHORT at step {self.current_step}, price: {current_price:.4f}, PnL: {close_pnl:.4f}")

            elif self.position > 0:  # Already long - potentially flip to short if very strong signal
                # For now, just hold the long position (no action)
                pass

        # Calculate unrealized PnL and add to reward
        # Calculate unrealized PnL and add to reward if holding position
        if self.position != 0:
            unrealized_pnl = self._calculate_unrealized_pnl(next_price)
            reward += unrealized_pnl * self.reward_scaling * 0.1  # Scale down unrealized PnL

            # Apply time-based holding penalty to encourage decisive actions
            position_duration = self.current_step - self.entry_step
            holding_penalty = min(position_duration * 0.0001, 0.01)  # Max 1% penalty
            reward -= holding_penalty

        # Apply penalties for holding a position
        if self.position != 0:
            # Small holding fee/interest
            holding_penalty = abs(self.position) * 0.0001  # 0.01% per step
            reward -= holding_penalty * self.reward_scaling
        # Reward staying neutral when uncertain (no clear setup)
        else:
            reward += 0.0001  # Small reward for not trading without clear signals

        # Move to next step
        self.current_step += 1
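The shaping terms above can be sanity-checked with small numbers (all values hypothetical; reward_scaling taken as 1.0):

# Hypothetical check of the unrealized-PnL bonus and capped holding penalty above.
reward_scaling = 1.0
unrealized_pnl = 0.02                            # +2% move in our favor
position_duration = 50                           # steps since entry
reward = unrealized_pnl * reward_scaling * 0.1   # scaled-down unrealized PnL -> 0.002
reward -= min(position_duration * 0.0001, 0.01)  # capped holding penalty -> 0.005
assert abs(reward - (-0.003)) < 1e-12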
@@ -215,7 +234,7 @@ class TradingEnvironment(gym.Env):
                'step': self.current_step,
                'timestamp': self.timestamps[self.current_step],
                'action': action,
                'action_name': ['BUY', 'SELL', 'HOLD'][action],
                'action_name': ['SELL', 'BUY'][action],
                'price': current_price,
                'position_changed': prev_position != self.position,
                'prev_position': prev_position,
@@ -234,7 +253,7 @@ class TradingEnvironment(gym.Env):
            self.trades.append(trade_result)

            # Log trade details
            logger.info(f"Trade executed - Action: {['BUY', 'SELL', 'HOLD'][action]}, "
            logger.info(f"Trade executed - Action: {['SELL', 'BUY'][action]}, "
                        f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
                        f"Balance: {self.balance:.4f}")

@@ -268,42 +287,71 @@ class TradingEnvironment(gym.Env):
        else:  # Short position
            return -self.position * (1.0 - current_price / self.entry_price)

    def _open_position(self, position_size, price):
    def _open_position(self, position_size: float, entry_price: float):
        """Open a new position"""
        self.position = position_size
        self.entry_price = price
        self.entry_price = entry_price
        self.entry_step = self.current_step

    def _close_position(self, price):
        """Close the current position and return PnL"""
        pnl = self._calculate_unrealized_pnl(price)
        # Calculate position value
        position_value = abs(position_size) * entry_price

        # Apply transaction fee
        fee = abs(self.position) * price * self.transaction_fee
        pnl -= fee
        fee = position_value * self.transaction_fee
        self.balance -= fee

        logger.info(f"Opened position: {position_size:.4f} at {entry_price:.4f}, fee: {fee:.4f}")

    def _close_position(self, exit_price: float) -> Tuple[float, Dict]:
        """Close current position and return PnL"""
        if self.position == 0:
            return 0.0, {}

        # Calculate PnL
        if self.position > 0:  # Long position
            pnl = (exit_price - self.entry_price) / self.entry_price
        else:  # Short position
            pnl = (self.entry_price - exit_price) / self.entry_price

        # Apply transaction fees (entry + exit)
        position_value = abs(self.position) * exit_price
        exit_fee = position_value * self.transaction_fee
        total_fees = exit_fee  # Entry fee already applied when opening

        # Net PnL after fees
        net_pnl = pnl - (total_fees / (abs(self.position) * self.entry_price))

        # Update balance
        self.balance += pnl
        self.total_pnl += pnl
        self.balance *= (1 + net_pnl)
        self.total_pnl += net_pnl

        # Store position details before resetting
        last_position = {
        # Track trade
        position_info = {
            'position_size': self.position,
            'entry_price': self.entry_price,
            'exit_price': price,
            'pnl': pnl,
            'fee': fee
            'exit_price': exit_price,
            'pnl': net_pnl,
            'duration': self.current_step - self.entry_step,
            'entry_step': self.entry_step,
            'exit_step': self.current_step
        }

        self.trades.append(position_info)

        # Update trade statistics
        if net_pnl > 0:
            self.winning_trades += 1
        else:
            self.losing_trades += 1

        logger.info(f"Closed position: {self.position:.4f}, PnL: {net_pnl:.4f}, Duration: {position_info['duration']} steps")

        # Reset position
        self.position = 0.0
        self.entry_price = 0.0
        self.entry_step = 0

        # Log position closure
        logger.info(f"Closed position - Size: {last_position['position_size']:.4f}, "
                    f"Entry: {last_position['entry_price']:.4f}, Exit: {last_position['exit_price']:.4f}, "
                    f"PnL: {last_position['pnl']:.4f}, Fee: {last_position['fee']:.4f}")

        return pnl, last_position
        return net_pnl, position_info

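A worked example of the fee-adjusted PnL computed in _close_position above (numbers hypothetical; transaction_fee assumed to be 0.001):

# Hypothetical long trade: entry 100.0, exit 102.0, position size 1.0.
entry_price, exit_price, position, transaction_fee = 100.0, 102.0, 1.0, 0.001
pnl = (exit_price - entry_price) / entry_price             # 0.02 fractional gain
exit_fee = abs(position) * exit_price * transaction_fee    # 0.102 (entry fee was charged on open)
net_pnl = pnl - exit_fee / (abs(position) * entry_price)   # 0.02 - 0.00102
assert abs(net_pnl - 0.01898) < 1e-9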
    def _get_observation(self):
        """
@@ -411,7 +459,7 @@ class TradingEnvironment(gym.Env):
        for trade in last_n_trades:
            position_info = {
                'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
                'action': trade.get('action_name', ['BUY', 'SELL', 'HOLD'][trade['action']]),
                'action': trade.get('action_name', ['SELL', 'BUY'][trade['action']]),
                'entry_price': trade.get('entry_price', 0.0),
                'exit_price': trade.get('exit_price', trade['price']),
                'position_size': trade.get('position_size', self.max_position),

@@ -1,560 +0,0 @@
"""
Convolutional Neural Network for timeseries analysis

This module implements a deep CNN model for cryptocurrency price analysis.
The model uses multiple parallel convolutional pathways and LSTM layers
to detect patterns at different time scales.
"""

import os
import logging
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (
    Input, Conv1D, MaxPooling1D, Dense, Dropout, BatchNormalization,
    LSTM, Bidirectional, Flatten, Concatenate, GlobalAveragePooling1D,
    LeakyReLU, Attention
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.metrics import AUC
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import datetime
import json

logger = logging.getLogger(__name__)

class CNNModel:
    """
    Convolutional Neural Network for time series analysis.

    This model uses a multi-pathway architecture with different filter sizes
    to detect patterns at different time scales, combined with LSTM layers
    for temporal dependencies.
    """

    def __init__(self, input_shape=(20, 5), output_size=1, model_dir="NN/models/saved"):
        """
        Initialize the CNN model.

        Args:
            input_shape (tuple): Shape of input data (sequence_length, features)
            output_size (int): Number of output classes (1 for binary, 3 for buy/hold/sell)
            model_dir (str): Directory to save trained models
        """
        self.input_shape = input_shape
        self.output_size = output_size
        self.model_dir = model_dir
        self.model = None
        self.history = None

        # Create model directory if it doesn't exist
        os.makedirs(self.model_dir, exist_ok=True)

        logger.info(f"Initialized CNN model with input shape {input_shape} and output size {output_size}")

    def build_model(self, filters=(32, 64, 128), kernel_sizes=(3, 5, 7),
                    dropout_rate=0.3, learning_rate=0.001):
        """
        Build the CNN model architecture.

        Args:
            filters (tuple): Number of filters for each convolutional pathway
            kernel_sizes (tuple): Kernel sizes for each convolutional pathway
            dropout_rate (float): Dropout rate for regularization
            learning_rate (float): Learning rate for Adam optimizer

        Returns:
            The compiled model
        """
        # Input layer
        inputs = Input(shape=self.input_shape)

        # Multiple parallel convolutional pathways with different kernel sizes
        # to capture patterns at different time scales
        conv_layers = []

        for i, (filter_size, kernel_size) in enumerate(zip(filters, kernel_sizes)):
            conv_path = Conv1D(
                filters=filter_size,
                kernel_size=kernel_size,
                padding='same',
                name=f'conv1d_{i+1}'
            )(inputs)
            conv_path = BatchNormalization()(conv_path)
            conv_path = LeakyReLU(alpha=0.1)(conv_path)
            conv_path = MaxPooling1D(pool_size=2, padding='same')(conv_path)
            conv_path = Dropout(dropout_rate)(conv_path)
            conv_layers.append(conv_path)

        # Merge convolutional pathways
        if len(conv_layers) > 1:
            merged = Concatenate()(conv_layers)
        else:
            merged = conv_layers[0]

        # Add another Conv1D layer after merging
        x = Conv1D(filters=filters[-1], kernel_size=3, padding='same')(merged)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.1)(x)
        x = MaxPooling1D(pool_size=2, padding='same')(x)
        x = Dropout(dropout_rate)(x)

        # Bidirectional LSTM for temporal dependencies
        x = Bidirectional(LSTM(128, return_sequences=True))(x)
        x = Dropout(dropout_rate)(x)

        # Attention mechanism to focus on important time steps
        x = Bidirectional(LSTM(64, return_sequences=True))(x)

        # Global average pooling to reduce parameters
        x = GlobalAveragePooling1D()(x)
        x = Dropout(dropout_rate)(x)

        # Dense layers for final classification/regression
        x = Dense(64, activation='relu')(x)
        x = BatchNormalization()(x)
        x = Dropout(dropout_rate)(x)

        # Output layer
        if self.output_size == 1:
            # Binary classification (up/down)
            outputs = Dense(1, activation='sigmoid', name='output')(x)
            loss = 'binary_crossentropy'
            metrics = ['accuracy', AUC()]
        elif self.output_size == 3:
            # Multi-class classification (buy/hold/sell)
            outputs = Dense(3, activation='softmax', name='output')(x)
            loss = 'categorical_crossentropy'
            metrics = ['accuracy']
        else:
            # Regression
            outputs = Dense(self.output_size, activation='linear', name='output')(x)
            loss = 'mse'
            metrics = ['mae']

        # Create and compile model
        self.model = Model(inputs=inputs, outputs=outputs)

        # Compile with Adam optimizer
        self.model.compile(
            optimizer=Adam(learning_rate=learning_rate),
            loss=loss,
            metrics=metrics
        )

        # Log model summary
        self.model.summary(print_fn=lambda x: logger.info(x))

        return self.model

    def train(self, X_train, y_train, batch_size=32, epochs=100, validation_split=0.2,
              callbacks=None, class_weights=None):
        """
        Train the CNN model on the provided data.

        Args:
            X_train (numpy.ndarray): Training features
            y_train (numpy.ndarray): Training targets
            batch_size (int): Batch size
            epochs (int): Number of epochs
            validation_split (float): Fraction of data to use for validation
            callbacks (list): List of Keras callbacks
            class_weights (dict): Class weights for imbalanced datasets

        Returns:
            History object containing training metrics
        """
        if self.model is None:
            self.build_model()

        # Default callbacks if none provided
        if callbacks is None:
            # Create a timestamp for model checkpoints
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

            callbacks = [
                EarlyStopping(
                    monitor='val_loss',
                    patience=10,
                    restore_best_weights=True
                ),
                ReduceLROnPlateau(
                    monitor='val_loss',
                    factor=0.5,
                    patience=5,
                    min_lr=1e-6
                ),
                ModelCheckpoint(
                    filepath=os.path.join(self.model_dir, f"cnn_model_{timestamp}.h5"),
                    monitor='val_loss',
                    save_best_only=True
                )
            ]

        # Check if y_train needs to be one-hot encoded for multi-class
        if self.output_size == 3 and len(y_train.shape) == 1:
            y_train = tf.keras.utils.to_categorical(y_train, num_classes=3)

        # Train the model
        logger.info(f"Training CNN model with {len(X_train)} samples, batch size {batch_size}, epochs {epochs}")
        self.history = self.model.fit(
            X_train, y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=validation_split,
            callbacks=callbacks,
            class_weight=class_weights,
            verbose=2
        )

        # Save the trained model
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        model_path = os.path.join(self.model_dir, f"cnn_model_final_{timestamp}.h5")
        self.model.save(model_path)
        logger.info(f"Model saved to {model_path}")

        # Save training history
        history_path = os.path.join(self.model_dir, f"cnn_model_history_{timestamp}.json")
        with open(history_path, 'w') as f:
            # Convert numpy values to Python native types for JSON serialization
            history_dict = {key: [float(value) for value in values] for key, values in self.history.history.items()}
            json.dump(history_dict, f, indent=2)

        return self.history

    def evaluate(self, X_test, y_test, plot_results=False):
        """
        Evaluate the model on test data.

        Args:
            X_test (numpy.ndarray): Test features
            y_test (numpy.ndarray): Test targets
            plot_results (bool): Whether to plot evaluation results

        Returns:
            dict: Evaluation metrics
        """
        if self.model is None:
            raise ValueError("Model has not been built or trained yet")

        # Convert y_test to one-hot encoding for multi-class
        y_test_original = y_test.copy()
        if self.output_size == 3 and len(y_test.shape) == 1:
            y_test = tf.keras.utils.to_categorical(y_test, num_classes=3)

        # Evaluate model
        logger.info(f"Evaluating CNN model on {len(X_test)} samples")
        eval_results = self.model.evaluate(X_test, y_test, verbose=0)

        metrics = {}
        for metric, value in zip(self.model.metrics_names, eval_results):
            metrics[metric] = value
            logger.info(f"{metric}: {value:.4f}")

        # Get predictions
        y_pred_prob = self.model.predict(X_test)

        # Different processing based on output type
        if self.output_size == 1:
            # Binary classification
            y_pred = (y_pred_prob > 0.5).astype(int).flatten()

            # Classification report
            report = classification_report(y_test, y_pred)
            logger.info(f"Classification Report:\n{report}")

            # Confusion matrix
            cm = confusion_matrix(y_test, y_pred)
            logger.info(f"Confusion Matrix:\n{cm}")

            # ROC curve and AUC
            fpr, tpr, _ = roc_curve(y_test, y_pred_prob)
            roc_auc = auc(fpr, tpr)
            metrics['auc'] = roc_auc

            if plot_results:
                self._plot_binary_results(y_test, y_pred, y_pred_prob, fpr, tpr, roc_auc)

        elif self.output_size == 3:
            # Multi-class classification
            y_pred = np.argmax(y_pred_prob, axis=1)

            # Classification report
            report = classification_report(y_test_original, y_pred)
            logger.info(f"Classification Report:\n{report}")

            # Confusion matrix
            cm = confusion_matrix(y_test_original, y_pred)
            logger.info(f"Confusion Matrix:\n{cm}")

            if plot_results:
                self._plot_multiclass_results(y_test_original, y_pred, y_pred_prob)

        return metrics

    def predict(self, X):
        """
        Make predictions on new data.

        Args:
            X (numpy.ndarray): Input features

        Returns:
            tuple: (y_pred, y_proba) where:
                y_pred is the predicted class (0/1 for binary, 0/1/2 for multi-class)
                y_proba is the class probability
        """
        if self.model is None:
            raise ValueError("Model has not been built or trained yet")

        # Ensure X has the right shape
        if len(X.shape) == 2:
            # Single sample, add batch dimension
            X = np.expand_dims(X, axis=0)

        # Get predictions
        y_proba = self.model.predict(X)

        # Process based on output type
        if self.output_size == 1:
            # Binary classification
            y_pred = (y_proba > 0.5).astype(int).flatten()
            return y_pred, y_proba.flatten()
        elif self.output_size == 3:
            # Multi-class classification
            y_pred = np.argmax(y_proba, axis=1)
            return y_pred, y_proba
        else:
            # Regression
            return y_proba, y_proba

    def save(self, filepath=None):
        """
        Save the model to disk.

        Args:
            filepath (str): Path to save the model

        Returns:
            str: Path where the model was saved
        """
        if self.model is None:
            raise ValueError("Model has not been built yet")

        if filepath is None:
            # Create a default filepath with timestamp
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            filepath = os.path.join(self.model_dir, f"cnn_model_{timestamp}.h5")

        self.model.save(filepath)
        logger.info(f"Model saved to {filepath}")
        return filepath

    def load(self, filepath):
        """
        Load a saved model from disk.

        Args:
            filepath (str): Path to the saved model

        Returns:
            The loaded model
        """
        self.model = load_model(filepath)
        logger.info(f"Model loaded from {filepath}")
        return self.model

    def extract_hidden_features(self, X):
        """
        Extract features from the last hidden layer of the CNN for transfer learning.

        Args:
            X (numpy.ndarray): Input data

        Returns:
            numpy.ndarray: Extracted features
        """
        if self.model is None:
            raise ValueError("Model has not been built or trained yet")

        # Create a new model that outputs the features from the layer before the output
        feature_layer_name = self.model.layers[-2].name
        feature_extractor = Model(
            inputs=self.model.input,
            outputs=self.model.get_layer(feature_layer_name).output
        )

        # Extract features
        features = feature_extractor.predict(X)

        return features

    def _plot_binary_results(self, y_true, y_pred, y_proba, fpr, tpr, roc_auc):
        """
        Plot evaluation results for binary classification.

        Args:
            y_true (numpy.ndarray): True labels
            y_pred (numpy.ndarray): Predicted labels
            y_proba (numpy.ndarray): Prediction probabilities
            fpr (numpy.ndarray): False positive rates for ROC curve
            tpr (numpy.ndarray): True positive rates for ROC curve
            roc_auc (float): Area under ROC curve
        """
        plt.figure(figsize=(15, 5))

        # Confusion Matrix
        plt.subplot(1, 3, 1)
        cm = confusion_matrix(y_true, y_pred)
        plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.title('Confusion Matrix')
        plt.colorbar()
        tick_marks = [0, 1]
        plt.xticks(tick_marks, ['0', '1'])
        plt.yticks(tick_marks, ['0', '1'])
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')

        # Add text annotations to confusion matrix
        thresh = cm.max() / 2.
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                plt.text(j, i, format(cm[i, j], 'd'),
                         horizontalalignment="center",
                         color="white" if cm[i, j] > thresh else "black")

        # Histogram of prediction probabilities
        plt.subplot(1, 3, 2)
        plt.hist(y_proba[y_true == 0], alpha=0.5, label='Class 0')
        plt.hist(y_proba[y_true == 1], alpha=0.5, label='Class 1')
        plt.title('Prediction Probabilities')
        plt.xlabel('Probability of Class 1')
        plt.ylabel('Count')
        plt.legend()

        # ROC Curve
        plt.subplot(1, 3, 3)
        plt.plot(fpr, tpr, label=f'ROC Curve (AUC = {roc_auc:.3f})')
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic')
        plt.legend(loc="lower right")

        plt.tight_layout()

        # Save figure
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        fig_path = os.path.join(self.model_dir, f"cnn_evaluation_{timestamp}.png")
        plt.savefig(fig_path)
        plt.close()

        logger.info(f"Evaluation plots saved to {fig_path}")

    def _plot_multiclass_results(self, y_true, y_pred, y_proba):
        """
        Plot evaluation results for multi-class classification.

        Args:
            y_true (numpy.ndarray): True labels
            y_pred (numpy.ndarray): Predicted labels
            y_proba (numpy.ndarray): Prediction probabilities
        """
        plt.figure(figsize=(12, 5))

        # Confusion Matrix
        plt.subplot(1, 2, 1)
        cm = confusion_matrix(y_true, y_pred)
        plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.title('Confusion Matrix')
        plt.colorbar()
        classes = ['BUY', 'HOLD', 'SELL']  # Assumes classes are 0, 1, 2
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes)
        plt.yticks(tick_marks, classes)
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')

        # Add text annotations to confusion matrix
        thresh = cm.max() / 2.
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                plt.text(j, i, format(cm[i, j], 'd'),
                         horizontalalignment="center",
                         color="white" if cm[i, j] > thresh else "black")

        # Class probability distributions
        plt.subplot(1, 2, 2)
        for i, cls in enumerate(classes):
            plt.hist(y_proba[y_true == i, i], alpha=0.5, label=f'Class {cls}')
        plt.title('Class Probability Distributions')
        plt.xlabel('Probability')
        plt.ylabel('Count')
        plt.legend()

        plt.tight_layout()

        # Save figure
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        fig_path = os.path.join(self.model_dir, f"cnn_multiclass_evaluation_{timestamp}.png")
        plt.savefig(fig_path)
        plt.close()

        logger.info(f"Multiclass evaluation plots saved to {fig_path}")

    def plot_training_history(self):
        """
        Plot training history (loss and metrics).

        Returns:
            str: Path to the saved plot
        """
        if self.history is None:
            raise ValueError("Model has not been trained yet")

        plt.figure(figsize=(12, 5))

        # Plot loss
        plt.subplot(1, 2, 1)
        plt.plot(self.history.history['loss'], label='Training Loss')
        if 'val_loss' in self.history.history:
            plt.plot(self.history.history['val_loss'], label='Validation Loss')
        plt.title('Model Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        # Plot accuracy
        plt.subplot(1, 2, 2)

        if 'accuracy' in self.history.history:
            plt.plot(self.history.history['accuracy'], label='Training Accuracy')
            if 'val_accuracy' in self.history.history:
                plt.plot(self.history.history['val_accuracy'], label='Validation Accuracy')
            plt.title('Model Accuracy')
            plt.ylabel('Accuracy')
        elif 'mae' in self.history.history:
            plt.plot(self.history.history['mae'], label='Training MAE')
            if 'val_mae' in self.history.history:
                plt.plot(self.history.history['val_mae'], label='Validation MAE')
            plt.title('Model MAE')
            plt.ylabel('MAE')

        plt.xlabel('Epoch')
        plt.legend()

        plt.tight_layout()

        # Save figure
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        fig_path = os.path.join(self.model_dir, f"cnn_training_history_{timestamp}.png")
        plt.savefig(fig_path)
        plt.close()

        logger.info(f"Training history plot saved to {fig_path}")
        return fig_path
File diff suppressed because it is too large
@@ -9,6 +9,7 @@ import os
import sys
import logging
import torch.nn.functional as F
import time

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
@@ -23,16 +24,16 @@ class DQNAgent:
    """
    def __init__(self,
                 state_shape: Tuple[int, ...],
                 n_actions: int,
                 learning_rate: float = 0.0005,  # Reduced learning rate for more stability
                 gamma: float = 0.97,  # Slightly reduced discount factor
                 n_actions: int = 2,
                 learning_rate: float = 0.001,
                 epsilon: float = 1.0,
                 epsilon_min: float = 0.05,  # Increased minimum epsilon for more exploration
                 epsilon_decay: float = 0.9975,  # Slower decay rate
                 buffer_size: int = 20000,  # Increased memory size
                 batch_size: int = 128,  # Larger batch size
                 target_update: int = 5,  # More frequent target updates
                 device=None):  # Device for computations
                 epsilon_min: float = 0.01,
                 epsilon_decay: float = 0.995,
                 buffer_size: int = 10000,
                 batch_size: int = 32,
                 target_update: int = 100,
                 priority_memory: bool = True,
                 device=None):

        # Extract state dimensions
        if isinstance(state_shape, tuple) and len(state_shape) > 1:
@@ -48,11 +49,9 @@ class DQNAgent:
        # Store parameters
        self.n_actions = n_actions
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.epsilon_start = epsilon  # Store initial epsilon value for resets/bumps
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.target_update = target_update
@@ -127,10 +126,41 @@ class DQNAgent:
        self.max_confidence = 0.0
        self.min_confidence = 1.0

        # Enhanced features from EnhancedDQNAgent
        # Market adaptation capabilities
        self.market_regime_weights = {
            'trending': 1.2,  # Higher confidence in trending markets
            'ranging': 0.8,  # Lower confidence in ranging markets
            'volatile': 0.6  # Much lower confidence in volatile markets
        }

        # Dueling network support (requires enhanced network architecture)
        self.use_dueling = True

        # Prioritized experience replay parameters
        self.use_prioritized_replay = priority_memory
        self.alpha = 0.6  # Priority exponent
        self.beta = 0.4  # Importance sampling exponent
        self.beta_increment = 0.001

        # Double DQN support
        self.use_double_dqn = True

        # Enhanced training features from EnhancedDQNAgent
        self.target_update_freq = target_update  # More descriptive name
        self.training_steps = 0
        self.gradient_clip_norm = 1.0  # Gradient clipping

        # Enhanced statistics tracking
        self.epsilon_history = []
        self.td_errors = []  # Track TD errors for analysis

        # Trade action fee and confidence thresholds
        self.trade_action_fee = 0.0005  # Small fee to discourage unnecessary trading
        self.minimum_action_confidence = 0.3  # Minimum confidence to consider trading (lowered from 0.5)
        self.recent_actions = []  # Track recent actions to avoid oscillations
        self.recent_actions = deque(maxlen=10)
        self.recent_prices = deque(maxlen=20)
        self.recent_rewards = deque(maxlen=100)

        # Violent move detection
        self.price_history = []
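The alpha/beta constants above follow the standard prioritized-replay scheme (Schaul et al.); a minimal sketch of how they are conventionally used, not code from this commit:

import numpy as np

alpha, beta = 0.6, 0.4
td_errors = np.array([0.5, 0.1, 0.9])           # |TD error| per stored transition
priorities = (np.abs(td_errors) + 1e-5) ** alpha
probs = priorities / priorities.sum()            # sampling distribution over the buffer
weights = (len(td_errors) * probs) ** (-beta)    # importance-sampling correction
weights /= weights.max()                         # normalize for stability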
@@ -173,6 +203,16 @@ class DQNAgent:
        total_params = sum(p.numel() for p in self.policy_net.parameters())
        logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")

        # Position management for 2-action system
        self.current_position = 0.0  # -1 (short), 0 (neutral), 1 (long)
        self.position_entry_price = 0.0
        self.position_entry_time = None

        # Different thresholds for entry vs exit decisions
        self.entry_confidence_threshold = 0.7  # High threshold for new positions
        self.exit_confidence_threshold = 0.3  # Lower threshold for closing positions
        self.uncertainty_threshold = 0.1  # When to stay neutral
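The asymmetric thresholds above act as hysteresis: entering a position requires strong evidence (0.7) while exiting requires much less (0.3), which damps churn. A toy check (function name hypothetical):

# Illustrative hysteresis check using the defaults above.
entry_th, exit_th = 0.7, 0.3

def crosses_threshold(confidence: float, in_position: bool) -> bool:
    return confidence >= (exit_th if in_position else entry_th)

assert not crosses_threshold(0.5, in_position=False)  # too weak to enter
assert crosses_threshold(0.5, in_position=True)       # strong enough to exit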
    def move_models_to_device(self, device=None):
        """Move models to the specified device (GPU/CPU)"""
        if device is not None:
@@ -290,247 +330,148 @@ class DQNAgent:
        if len(self.price_movement_memory) > self.buffer_size // 4:
            self.price_movement_memory = self.price_movement_memory[-(self.buffer_size // 4):]

    def act(self, state: np.ndarray, explore=True) -> int:
        """Choose action using epsilon-greedy policy with explore flag"""
        if explore and random.random() < self.epsilon:
            return random.randrange(self.n_actions)
    def act(self, state: np.ndarray, explore=True, current_price=None, market_context=None) -> int:
        """
        Choose action based on current state using 2-action system with intelligent position management

        with torch.no_grad():
            # Enhance state with real-time tick features
            enhanced_state = self._enhance_state_with_tick_features(state)
        Args:
            state: Current market state
            explore: Whether to use epsilon-greedy exploration
            current_price: Current market price for position management
            market_context: Additional market context for decision making

            # Ensure state is normalized before inference
            state_tensor = self._normalize_state(enhanced_state)
            state_tensor = torch.FloatTensor(state_tensor).unsqueeze(0).to(self.device)

            # Get predictions using the policy network
            self.policy_net.eval()  # Set to evaluation mode for inference
            action_probs, extrema_pred, price_predictions, hidden_features, advanced_predictions = self.policy_net(state_tensor)
            self.policy_net.train()  # Back to training mode

            # Store hidden features for integration
            self.last_hidden_features = hidden_features.cpu().numpy()

            # Track feature history (limited size)
            self.feature_history.append(hidden_features.cpu().numpy())
            if len(self.feature_history) > 100:
                self.feature_history = self.feature_history[-100:]

            # Get the predicted extrema class (0=bottom, 1=top, 2=neither)
            extrema_class = extrema_pred.argmax(dim=1).item()
            extrema_confidence = torch.softmax(extrema_pred, dim=1)[0, extrema_class].item()

            # Log extrema prediction for significant signals
            if extrema_confidence > 0.7 and extrema_class != 2:  # Only log strong top/bottom signals
                extrema_type = "BOTTOM" if extrema_class == 0 else "TOP" if extrema_class == 1 else "NEITHER"
                logger.info(f"High confidence {extrema_type} detected! Confidence: {extrema_confidence:.4f}")

            # Process price predictions
            price_immediate = torch.softmax(price_predictions['immediate'], dim=1)
            price_midterm = torch.softmax(price_predictions['midterm'], dim=1)
            price_longterm = torch.softmax(price_predictions['longterm'], dim=1)
            price_values = price_predictions['values']

            # Get predicted direction for each timeframe (0=down, 1=sideways, 2=up)
            immediate_direction = price_immediate.argmax(dim=1).item()
            midterm_direction = price_midterm.argmax(dim=1).item()
            longterm_direction = price_longterm.argmax(dim=1).item()

            # Get confidence levels
            immediate_conf = price_immediate[0, immediate_direction].item()
            midterm_conf = price_midterm[0, midterm_direction].item()
            longterm_conf = price_longterm[0, longterm_direction].item()

            # Get predicted price change percentages
            price_changes = price_values[0].tolist()

            # Log significant price movement predictions
            timeframes = ["1s/1m", "1h", "1d", "1w"]
            directions = ["DOWN", "SIDEWAYS", "UP"]

            for i, (direction, conf) in enumerate([
                (immediate_direction, immediate_conf),
                (midterm_direction, midterm_conf),
                (longterm_direction, longterm_conf)
            ]):
                if conf > 0.7 and direction != 1:  # Only log high confidence non-sideways predictions
                    logger.info(f"Price prediction: {timeframes[i]} -> {directions[direction]}, "
                                f"Confidence: {conf:.4f}, Expected change: {price_changes[i]:.2f}%")

            # Store predictions for environment to use
            self.last_extrema_pred = {
                'class': extrema_class,
                'confidence': extrema_confidence,
                'raw': extrema_pred.cpu().numpy()
            }

            self.last_price_pred = {
                'immediate': {
                    'direction': immediate_direction,
                    'confidence': immediate_conf,
                    'change': price_changes[0]
                },
                'midterm': {
                    'direction': midterm_direction,
                    'confidence': midterm_conf,
                    'change': price_changes[1]
                },
                'longterm': {
                    'direction': longterm_direction,
                    'confidence': longterm_conf,
                    'change': price_changes[2]
                }
            }

            # Get the action with highest Q-value
            action = action_probs.argmax().item()

            # Calculate overall confidence in the action
            q_values_softmax = F.softmax(action_probs, dim=1)[0]
            action_confidence = q_values_softmax[action].item()

            # Track confidence metrics
            self.confidence_history.append(action_confidence)
            if len(self.confidence_history) > 100:
                self.confidence_history = self.confidence_history[-100:]

            # Update confidence metrics
            self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
            self.max_confidence = max(self.max_confidence, action_confidence)
            self.min_confidence = min(self.min_confidence, action_confidence)

            # Log average confidence occasionally
            if random.random() < 0.01:  # 1% of the time
                logger.info(f"Confidence metrics - Current: {action_confidence:.4f}, Avg: {self.avg_confidence:.4f}, " +
                            f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")

            # Track price for violent move detection
            try:
                # Extract current price from state (assuming it's in the last position)
                if len(state.shape) > 1:  # For 2D state
                    current_price = state[-1, -1]
                else:  # For 1D state
                    current_price = state[-1]

                self.price_history.append(current_price)
                if len(self.price_history) > self.volatility_window:
                    self.price_history = self.price_history[-self.volatility_window:]

                # Detect violent price moves if we have enough price history
                if len(self.price_history) >= 5:
                    # Calculate short-term volatility
                    recent_prices = self.price_history[-5:]

                    # Make sure we're working with scalar values, not arrays
                    if isinstance(recent_prices[0], np.ndarray):
                        # If prices are arrays, extract the last value (current price)
                        recent_prices = [p[-1] if isinstance(p, np.ndarray) and p.size > 0 else p for p in recent_prices]

                    # Calculate price changes with protection against division by zero
                    price_changes = []
                    for i in range(1, len(recent_prices)):
                        if recent_prices[i-1] != 0 and not np.isnan(recent_prices[i-1]) and not np.isnan(recent_prices[i]):
                            change = (recent_prices[i] - recent_prices[i-1]) / recent_prices[i-1]
                            price_changes.append(change)
                        else:
                            price_changes.append(0.0)

                    # Calculate volatility as sum of absolute price changes
                    volatility = sum([abs(change) for change in price_changes])

                    # Check if we've had a violent move
                    if volatility > self.volatility_threshold:
                        logger.info(f"Violent price move detected! Volatility: {volatility:.6f}")
                        self.post_violent_move = True
                        self.violent_move_cooldown = 10  # Set cooldown period

                # Handle post-violent move period
                if self.post_violent_move:
                    if self.violent_move_cooldown > 0:
                        self.violent_move_cooldown -= 1
                        # Increase confidence threshold temporarily after violent moves
                        effective_threshold = self.minimum_action_confidence * 1.1
                        logger.info(f"Post-violent move period: {self.violent_move_cooldown} steps remaining. " +
                                    f"Using higher confidence threshold: {effective_threshold:.4f}")
                    else:
                        self.post_violent_move = False
                        logger.info("Post-violent move period ended")
            except Exception as e:
                logger.warning(f"Error in violent move detection: {str(e)}")

            # Apply trade action fee to buy/sell actions but not to hold
            # This creates a threshold that must be exceeded to justify a trade
            action_values = action_probs.clone()

            # If BUY or SELL, apply fee by reducing the Q-value
            if action == 0 or action == 1:  # BUY or SELL
                # Check if confidence is above minimum threshold
                effective_threshold = self.minimum_action_confidence
                if self.post_violent_move:
                    effective_threshold *= 1.1  # Higher threshold after violent moves

                if action_confidence < effective_threshold:
                    # If confidence is below threshold, force HOLD action
                    logger.info(f"Action {action} confidence {action_confidence:.4f} below threshold {effective_threshold}, forcing HOLD")
                    action = 2  # HOLD
                else:
                    # Apply trade action fee to ensure we only trade when there's clear benefit
                    fee_adjusted_action_values = action_values.clone()
                    fee_adjusted_action_values[0, 0] -= self.trade_action_fee  # Reduce BUY value
                    fee_adjusted_action_values[0, 1] -= self.trade_action_fee  # Reduce SELL value
                    # Hold value remains unchanged

                    # Re-determine the action based on fee-adjusted values
                    fee_adjusted_action = fee_adjusted_action_values.argmax().item()

                    # If the fee changes our decision, log this
                    if fee_adjusted_action != action:
                        logger.info(f"Trade action fee changed decision from {action} to {fee_adjusted_action}")
                        action = fee_adjusted_action

            # Adjust action based on extrema and price predictions
            # Prioritize short-term movement for trading decisions
            if immediate_conf > 0.8:  # Only adjust for strong signals
                if immediate_direction == 2:  # UP prediction
                    # Bias toward BUY for strong up predictions
                    if action != 0 and action != 2 and random.random() < 0.3 * immediate_conf:
                        logger.info(f"Adjusting action to BUY based on immediate UP prediction")
                        action = 0  # BUY
                elif immediate_direction == 0:  # DOWN prediction
                    # Bias toward SELL for strong down predictions
                    if action != 1 and action != 2 and random.random() < 0.3 * immediate_conf:
                        logger.info(f"Adjusting action to SELL based on immediate DOWN prediction")
                        action = 1  # SELL

            # Also consider extrema detection for action adjustment
            if extrema_confidence > 0.8:  # Only adjust for strong signals
                if extrema_class == 0:  # Bottom detected
                    # Bias toward BUY at bottoms
                    if action != 0 and action != 2 and random.random() < 0.3 * extrema_confidence:
                        logger.info(f"Adjusting action to BUY based on bottom detection")
                        action = 0  # BUY
                elif extrema_class == 1:  # Top detected
                    # Bias toward SELL at tops
                    if action != 1 and action != 2 and random.random() < 0.3 * extrema_confidence:
                        logger.info(f"Adjusting action to SELL based on top detection")
                        action = 1  # SELL

            # Finally, avoid action oscillation by checking recent history
            if len(self.recent_actions) >= 2:
                last_action = self.recent_actions[-1]
                if action != last_action and action != 2 and last_action != 2:
                    # We're switching between BUY and SELL too quickly
                    # Only allow this if we have very high confidence
                    if action_confidence < 0.85:
                        logger.info(f"Preventing oscillation from {last_action} to {action}, forcing HOLD")
                        action = 2  # HOLD

            # Update recent actions list
        Returns:
            int: Action (0=SELL, 1=BUY) or None if should hold position
        """

        # Convert state to tensor
        if isinstance(state, np.ndarray):
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        else:
            state_tensor = state.unsqueeze(0).to(self.device)

        # Get Q-values
        q_values = self.policy_net(state_tensor)
        action_values = q_values.cpu().data.numpy()[0]

        # Calculate confidence scores
        sell_confidence = torch.softmax(q_values, dim=1)[0, 0].item()
        buy_confidence = torch.softmax(q_values, dim=1)[0, 1].item()

        # Determine action based on current position and confidence thresholds
        action = self._determine_action_with_position_management(
            sell_confidence, buy_confidence, current_price, market_context, explore
        )

        # Update tracking
        if current_price:
            self.recent_prices.append(current_price)

        if action is not None:
            self.recent_actions.append(action)
            if len(self.recent_actions) > 5:
                self.recent_actions = self.recent_actions[-5:]

            return action
        else:
            # Return None to indicate HOLD (don't change position)
            return None

    def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
        """Choose action with confidence score adapted to market regime (from Enhanced DQN)"""
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.policy_net(state_tensor)

            # Convert Q-values to probabilities
            action_probs = torch.softmax(q_values, dim=1)
            action = q_values.argmax().item()
            base_confidence = action_probs[0, action].item()

            # Adapt confidence based on market regime
            regime_weight = self.market_regime_weights.get(market_regime, 1.0)
            adapted_confidence = min(base_confidence * regime_weight, 1.0)

            return action, adapted_confidence

    def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
        """
        Determine action based on current position and confidence thresholds

        This implements the intelligent position management where:
        - When neutral: Need high confidence to enter position
        - When in position: Need lower confidence to exit
        - Different thresholds for entry vs exit
        """

        # Apply epsilon-greedy exploration
        if explore and np.random.random() <= self.epsilon:
            return np.random.choice([0, 1])

        # Get the dominant signal
        dominant_action = 0 if sell_conf > buy_conf else 1
        dominant_confidence = max(sell_conf, buy_conf)

        # Decision logic based on current position
        if self.current_position == 0:  # No position - need high confidence to enter
            if dominant_confidence >= self.entry_confidence_threshold:
                # Strong enough signal to enter position
                if dominant_action == 1:  # BUY signal
                    self.current_position = 1.0
                    self.position_entry_price = current_price
                    self.position_entry_time = time.time()
                    logger.info(f"ENTERING LONG position at {current_price:.4f} with confidence {dominant_confidence:.4f}")
                    return 1
                else:  # SELL signal
                    self.current_position = -1.0
                    self.position_entry_price = current_price
                    self.position_entry_time = time.time()
                    logger.info(f"ENTERING SHORT position at {current_price:.4f} with confidence {dominant_confidence:.4f}")
                    return 0
            else:
                # Not confident enough to enter position
                return None

        elif self.current_position > 0:  # Long position
            if dominant_action == 0 and dominant_confidence >= self.exit_confidence_threshold:
                # SELL signal with enough confidence to close long position
                pnl = (current_price - self.position_entry_price) / self.position_entry_price if current_price and self.position_entry_price else 0
                logger.info(f"CLOSING LONG position at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
                self.current_position = 0.0
                self.position_entry_price = 0.0
                self.position_entry_time = None
                return 0
            elif dominant_action == 0 and dominant_confidence >= self.entry_confidence_threshold:
                # Very strong SELL signal - close long and enter short
                pnl = (current_price - self.position_entry_price) / self.position_entry_price if current_price and self.position_entry_price else 0
                logger.info(f"FLIPPING from LONG to SHORT at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
                self.current_position = -1.0
                self.position_entry_price = current_price
                self.position_entry_time = time.time()
                return 0
            else:
                # Hold the long position
                return None

        elif self.current_position < 0:  # Short position
            if dominant_action == 1 and dominant_confidence >= self.exit_confidence_threshold:
                # BUY signal with enough confidence to close short position
                pnl = (self.position_entry_price - current_price) / self.position_entry_price if current_price and self.position_entry_price else 0
                logger.info(f"CLOSING SHORT position at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
                self.current_position = 0.0
                self.position_entry_price = 0.0
                self.position_entry_time = None
                return 1
            elif dominant_action == 1 and dominant_confidence >= self.entry_confidence_threshold:
                # Very strong BUY signal - close short and enter long
                pnl = (self.position_entry_price - current_price) / self.position_entry_price if current_price and self.position_entry_price else 0
                logger.info(f"FLIPPING from SHORT to LONG at {current_price:.4f} with confidence {dominant_confidence:.4f}, PnL: {pnl:.4f}")
                self.current_position = 1.0
                self.position_entry_price = current_price
                self.position_entry_time = time.time()
                return 1
            else:
                # Hold the short position
                return None

        return None

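Because the reworked act() above returns None to signal "hold", callers must guard before treating the result as a trade; a sketch of that contract (handler and names hypothetical):

def handle(action, price, log):
    """Sketch of a caller honoring the None-means-hold contract of act()."""
    if action is None:
        log.append("hold")            # keep the current position
    elif action == 1:
        log.append(f"buy@{price}")    # route to execution (elided)
    else:
        log.append(f"sell@{price}")

events = []
for result in (None, 1, 0):
    handle(result, 100.0, events)
assert events == ["hold", "buy@100.0", "sell@100.0"]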
    def replay(self, experiences=None):
        """Train the model using experiences from memory"""
@@ -658,10 +599,18 @@ class DQNAgent:
            current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self.policy_net(states)
            current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

            # Get next Q values with target network
            # Enhanced Double DQN implementation
            with torch.no_grad():
                next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
                next_q_values = next_q_values.max(1)[0]
                if self.use_double_dqn:
                    # Double DQN: Use policy network to select actions, target network to evaluate
                    policy_q_values, _, _, _, _ = self.policy_net(next_states)
                    next_actions = policy_q_values.argmax(1)
                    target_q_values_all, _, _, _, _ = self.target_net(next_states)
                    next_q_values = target_q_values_all.gather(1, next_actions.unsqueeze(1)).squeeze(1)
                else:
                    # Standard DQN: Use target network for both selection and evaluation
                    next_q_values, next_extrema_pred, next_price_pred, next_hidden_features, next_advanced_pred = self.target_net(next_states)
                    next_q_values = next_q_values.max(1)[0]

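In symbols, the Double DQN branch above uses the target r + gamma * Q_target(s', argmax_a Q_policy(s', a)) rather than r + gamma * max_a Q_target(s', a), which curbs overestimation. A self-contained numeric check (toy tensors):

import torch

policy_q = torch.tensor([[1.0, 2.0]])    # policy net prefers action 1
target_q = torch.tensor([[5.0, 3.0]])    # target net would prefer action 0
next_actions = policy_q.argmax(1)        # -> action 1, chosen by the policy net
double_q = target_q.gather(1, next_actions.unsqueeze(1)).squeeze(1)
assert double_q.item() == 3.0            # evaluated by the target net
assert target_q.max(1)[0].item() == 5.0  # vanilla max would take the larger estimate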
            # Check for dimension mismatch between rewards and next_q_values
            if rewards.shape[0] != next_q_values.shape[0]:
@ -699,16 +648,25 @@ class DQNAgent:
        # Backward pass
        total_loss.backward()

        # Enhanced gradient clipping with configurable norm (previously a fixed 1.0)
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.gradient_clip_norm)

        # Update weights
        self.optimizer.step()

        # Enhanced target network update tracking (counter renamed from
        # update_count / target_update to training_steps / target_update_freq)
        self.training_steps += 1
        if self.training_steps % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
            logger.debug(f"Target network updated at step {self.training_steps}")

        # Enhanced statistics tracking
        self.epsilon_history.append(self.epsilon)

        # Calculate and store TD error for analysis
        with torch.no_grad():
            td_error = torch.abs(current_q_values - target_q_values).mean().item()
            self.td_errors.append(td_error)

        # Return loss
        return total_loss.item()
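        # Editor's note (not part of the original diff): this is a *hard* target update -
        # the full policy_net state_dict is copied every target_update_freq training steps.
        # A soft/Polyak update (an alternative, not used here) would instead blend
        # theta_target <- tau * theta_policy + (1 - tau) * theta_target on every step,
        # with a small tau such as 0.005.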
@ -1168,4 +1126,40 @@ class DQNAgent:

            logger.info(f"Agent state loaded from {path}_agent_state.pt")
        except FileNotFoundError:
            logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")

    def get_position_info(self):
        """Get current position information"""
        return {
            'position': self.current_position,
            'entry_price': self.position_entry_price,
            'entry_time': self.position_entry_time,
            'entry_threshold': self.entry_confidence_threshold,
            'exit_threshold': self.exit_confidence_threshold
        }

    def get_enhanced_training_stats(self):
        """Get enhanced RL training statistics with detailed metrics (from EnhancedDQNAgent)"""
        return {
            'buffer_size': len(self.memory),
            'epsilon': self.epsilon,
            'avg_reward': self.avg_reward,
            'best_reward': self.best_reward,
            'recent_rewards': list(self.recent_rewards) if hasattr(self, 'recent_rewards') else [],
            'no_improvement_count': self.no_improvement_count,
            # Enhanced statistics from EnhancedDQNAgent
            'training_steps': self.training_steps,
            'avg_td_error': np.mean(self.td_errors[-100:]) if self.td_errors else 0.0,
            'recent_losses': self.losses[-10:] if self.losses else [],
            'epsilon_trend': self.epsilon_history[-20:] if self.epsilon_history else [],
            'specialized_buffers': {
                'extrema_memory': len(self.extrema_memory),
                'positive_memory': len(self.positive_memory),
                'price_movement_memory': len(self.price_movement_memory)
            },
            'market_regime_weights': self.market_regime_weights,
            'use_double_dqn': self.use_double_dqn,
            'use_prioritized_replay': self.use_prioritized_replay,
            'gradient_clip_norm': self.gradient_clip_norm,
            'target_update_frequency': self.target_update_freq
        }
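Editor's note: a quick sketch of how these statistics might be consumed during training. This is illustrative only; `agent`, the step count, and the logging cadence are assumptions, not part of this commit.

    # Hypothetical monitoring loop (names and cadence are assumptions)
    for step in range(10_000):
        loss = agent.replay()
        if step % 500 == 0:
            stats = agent.get_enhanced_training_stats()
            print(f"step={step} loss={loss:.4f} eps={stats['epsilon']:.3f} "
                  f"td_err={stats['avg_td_error']:.4f} buffer={stats['buffer_size']}")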
@ -1,329 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from typing import Tuple, List
import os
import sys
import logging
import torch.nn.functional as F

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# Import the EnhancedCNN model
from NN.models.enhanced_cnn import EnhancedCNN, ExampleSiftingDataset

# Configure logger
logger = logging.getLogger(__name__)

class EnhancedDQNAgent:
    """
    Enhanced Deep Q-Network agent for trading
    Uses the improved EnhancedCNN model with residual connections and attention mechanisms
    """
    def __init__(self,
                 state_shape: Tuple[int, ...],
                 n_actions: int,
                 learning_rate: float = 0.0003,  # Slightly reduced learning rate for stability
                 gamma: float = 0.95,  # Discount factor
                 epsilon: float = 1.0,
                 epsilon_min: float = 0.05,
                 epsilon_decay: float = 0.995,  # Slower decay for more exploration
                 buffer_size: int = 50000,  # Larger memory buffer
                 batch_size: int = 128,  # Larger batch size
                 target_update: int = 10,  # More frequent target updates
                 confidence_threshold: float = 0.4,  # Lower confidence threshold
                 device=None):

        # Extract state dimensions
        if isinstance(state_shape, tuple) and len(state_shape) > 1:
            # Multi-dimensional state (like image or sequence)
            self.state_dim = state_shape
        else:
            # 1D state
            if isinstance(state_shape, tuple):
                self.state_dim = state_shape[0]
            else:
                self.state_dim = state_shape

        # Store parameters
        self.n_actions = n_actions
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.target_update = target_update
        self.confidence_threshold = confidence_threshold

        # Set device for computation
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # Initialize models with the enhanced CNN
        self.policy_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)
        self.target_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)

        # Initialize the target network with the same weights as the policy network
        self.target_net.load_state_dict(self.policy_net.state_dict())

        # Set models to eval mode (important for batch norm, dropout)
        self.target_net.eval()

        # Optimization components
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

        # Experience replay memory with example sifting
        self.memory = ExampleSiftingDataset(max_examples=buffer_size)
        self.update_count = 0

        # Confidence tracking
        self.confidence_history = []
        self.avg_confidence = 0.0
        self.max_confidence = 0.0
        self.min_confidence = 1.0

        # Performance tracking
        self.losses = []
        self.rewards = []
        self.avg_reward = 0.0

        # Check if mixed precision training should be used
        self.use_mixed_precision = False
        if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
            self.use_mixed_precision = True
            self.scaler = torch.cuda.amp.GradScaler()
            logger.info("Mixed precision training enabled")
        else:
            logger.info("Mixed precision training disabled")

        # For compatibility with old code
        self.action_size = n_actions

        logger.info(f"Enhanced DQN Agent using device: {self.device}")
        logger.info(f"Confidence threshold set to {self.confidence_threshold}")

    def move_models_to_device(self, device=None):
        """Move models to the specified device (GPU/CPU)"""
        if device is not None:
            self.device = device

        try:
            self.policy_net = self.policy_net.to(self.device)
            self.target_net = self.target_net.to(self.device)
            logger.info(f"Moved models to {self.device}")
            return True
        except Exception as e:
            logger.error(f"Failed to move models to {self.device}: {str(e)}")
            return False

    def _normalize_state(self, state):
        """Normalize state for better training stability"""
        try:
            # Convert to numpy array if needed
            if isinstance(state, list):
                state = np.array(state, dtype=np.float32)

            # Apply normalization based on state shape
            if len(state.shape) > 1:
                # Multi-dimensional state - normalize each feature dimension separately
                for i in range(state.shape[0]):
                    # Skip if all zeros (to avoid division by zero)
                    if np.sum(np.abs(state[i])) > 0:
                        # Standardize each feature dimension
                        mean = np.mean(state[i])
                        std = np.std(state[i])
                        if std > 0:
                            state[i] = (state[i] - mean) / std
            else:
                # 1D state vector - skip if all zeros
                if np.sum(np.abs(state)) > 0:
                    mean = np.mean(state)
                    std = np.std(state)
                    if std > 0:
                        state = (state - mean) / std

            return state
        except Exception as e:
            logger.warning(f"Error normalizing state: {str(e)}")
            return state
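    # Editor's note (not part of the original diff): the per-row z-scoring above gives
    # each feature dimension zero mean and unit variance. Illustrative values:
    #   row [100.0, 101.0, 102.0] -> mean 101.0, population std ~0.8165 (np.std default)
    #   normalized ~ [-1.2247, 0.0, 1.2247]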
    def remember(self, state, action, reward, next_state, done):
        """Store experience in memory with example sifting"""
        self.memory.add_example(state, action, reward, next_state, done)

        # Also track rewards for monitoring
        self.rewards.append(reward)
        if len(self.rewards) > 100:
            self.rewards = self.rewards[-100:]
        self.avg_reward = np.mean(self.rewards)

    def act(self, state, explore=True):
        """Choose action using epsilon-greedy policy with built-in confidence thresholding"""
        if explore and random.random() < self.epsilon:
            return random.randrange(self.n_actions), 0.0  # Return action and zero confidence

        # Normalize state before inference
        normalized_state = self._normalize_state(state)

        # Use the EnhancedCNN's act method, which includes confidence thresholding
        action, confidence = self.policy_net.act(normalized_state, explore=explore)

        # Track confidence metrics
        self.confidence_history.append(confidence)
        if len(self.confidence_history) > 100:
            self.confidence_history = self.confidence_history[-100:]

        # Update confidence metrics
        self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
        self.max_confidence = max(self.max_confidence, confidence)
        self.min_confidence = min(self.min_confidence, confidence)

        # Log average confidence occasionally (about 1% of calls)
        if random.random() < 0.01:
            logger.info(f"Confidence metrics - Current: {confidence:.4f}, Avg: {self.avg_confidence:.4f}, " +
                        f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")

        return action, confidence
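    # Editor's note (not part of the original diff): random exploratory actions return
    # confidence 0.0, so any downstream confidence-threshold gating naturally treats
    # exploration moves as low-conviction signals rather than tradeable ones.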
    def replay(self):
        """Train the model using experience replay with high-quality examples"""
        # Check if enough samples in memory
        if len(self.memory) < self.batch_size:
            return 0.0

        # Get batch of experiences
        batch = self.memory.get_batch(self.batch_size)
        if batch is None:
            return 0.0

        states = torch.FloatTensor(batch['states']).to(self.device)
        actions = torch.LongTensor(batch['actions']).to(self.device)
        rewards = torch.FloatTensor(batch['rewards']).to(self.device)
        next_states = torch.FloatTensor(batch['next_states']).to(self.device)
        dones = torch.FloatTensor(batch['dones']).to(self.device)

        # Compute Q values
        self.policy_net.train()  # Set to training mode

        if self.use_mixed_precision:
            with torch.cuda.amp.autocast():
                # Get current Q values
                q_values, _, _, _ = self.policy_net(states)
                current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

                # Compute target Q values
                with torch.no_grad():
                    self.target_net.eval()
                    next_q_values, _, _, _ = self.target_net(next_states)
                    next_q = next_q_values.max(1)[0]
                    target_q = rewards + (1 - dones) * self.gamma * next_q

                # Compute loss
                loss = self.criterion(current_q, target_q)

            # Perform backpropagation with mixed precision
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            # Standard precision training
            # Get current Q values
            q_values, _, _, _ = self.policy_net(states)
            current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

            # Compute target Q values
            with torch.no_grad():
                self.target_net.eval()
                next_q_values, _, _, _ = self.target_net(next_states)
                next_q = next_q_values.max(1)[0]
                target_q = rewards + (1 - dones) * self.gamma * next_q

            # Compute loss
            loss = self.criterion(current_q, target_q)

            # Perform backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
            self.optimizer.step()

        # Track loss
        loss_value = loss.item()
        self.losses.append(loss_value)
        if len(self.losses) > 100:
            self.losses = self.losses[-100:]

        # Update target network
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
            logger.info(f"Updated target network (step {self.update_count})")

        # Decay epsilon
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        return loss_value
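    # Editor's note (not part of the original diff): the mixed-precision branch follows
    # the standard GradScaler recipe: scale(loss).backward(), then unscale_(optimizer)
    # so that clip_grad_norm_ sees true (unscaled) gradients, then scaler.step() and
    # scaler.update(). Clipping before unscale_ would compare loss-scaled gradients
    # against max_norm=1.0 and effectively disable the clip.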
    def save(self, path):
        """Save agent state and models"""
        self.policy_net.save(f"{path}_policy")
        self.target_net.save(f"{path}_target")

        # Save agent state
        torch.save({
            'epsilon': self.epsilon,
            'confidence_threshold': self.confidence_threshold,
            'losses': self.losses,
            'rewards': self.rewards,
            'avg_reward': self.avg_reward,
            'confidence_history': self.confidence_history,
            'avg_confidence': self.avg_confidence,
            'max_confidence': self.max_confidence,
            'min_confidence': self.min_confidence,
            'update_count': self.update_count
        }, f"{path}_agent_state.pt")

        logger.info(f"Agent state saved to {path}_agent_state.pt")

    def load(self, path):
        """Load agent state and models"""
        policy_loaded = self.policy_net.load(f"{path}_policy")
        target_loaded = self.target_net.load(f"{path}_target")

        # Load agent state if available
        agent_state_path = f"{path}_agent_state.pt"
        if os.path.exists(agent_state_path):
            try:
                state = torch.load(agent_state_path)
                self.epsilon = state.get('epsilon', self.epsilon)
                self.confidence_threshold = state.get('confidence_threshold', self.confidence_threshold)
                self.policy_net.confidence_threshold = self.confidence_threshold
                self.target_net.confidence_threshold = self.confidence_threshold
                self.losses = state.get('losses', [])
                self.rewards = state.get('rewards', [])
                self.avg_reward = state.get('avg_reward', 0.0)
                self.confidence_history = state.get('confidence_history', [])
                self.avg_confidence = state.get('avg_confidence', 0.0)
                self.max_confidence = state.get('max_confidence', 0.0)
                self.min_confidence = state.get('min_confidence', 1.0)
                self.update_count = state.get('update_count', 0)
                logger.info(f"Agent state loaded from {agent_state_path}")
            except Exception as e:
                logger.error(f"Error loading agent state: {str(e)}")

        return policy_loaded and target_loaded
@ -110,96 +110,119 @@ class EnhancedCNN(nn.Module):
        logger.info(f"EnhancedCNN initialized with input shape: {input_shape}, actions: {n_actions}")

    def _build_network(self):
        """Build the ULTRA MASSIVE enhanced neural network for maximum learning capacity"""

        # ULTRA MASSIVE SCALED ARCHITECTURE for maximum learning (up to ~100M parameters,
        # scaled up from the previous ~50M / 4GB VRAM budget)
        if self.channels > 1:
            # Ultra massive convolutional backbone with much deeper residual blocks
            self.conv_layers = nn.Sequential(
                # Initial ultra large conv block (widened from 256 to 512 channels)
                nn.Conv1d(self.channels, 512, kernel_size=7, padding=3),
                nn.BatchNorm1d(512),
                nn.ReLU(),
                nn.Dropout(0.1),

                # First residual stage - 512 to 768 channels
                ResidualBlock(512, 768),
                ResidualBlock(768, 768),
                ResidualBlock(768, 768),
                ResidualBlock(768, 768),  # Additional layer
                nn.MaxPool1d(kernel_size=2, stride=2),
                nn.Dropout(0.2),

                # Second residual stage - 768 to 1024 channels
                ResidualBlock(768, 1024),
                ResidualBlock(1024, 1024),
                ResidualBlock(1024, 1024),
                ResidualBlock(1024, 1024),  # Additional layer
                nn.MaxPool1d(kernel_size=2, stride=2),
                nn.Dropout(0.25),

                # Third residual stage - 1024 to 1536 channels
                ResidualBlock(1024, 1536),
                ResidualBlock(1536, 1536),
                ResidualBlock(1536, 1536),
                ResidualBlock(1536, 1536),  # Additional layer
                nn.MaxPool1d(kernel_size=2, stride=2),
                nn.Dropout(0.3),

                # Fourth residual stage - 1536 to 2048 channels
                ResidualBlock(1536, 2048),
                ResidualBlock(2048, 2048),
                ResidualBlock(2048, 2048),
                ResidualBlock(2048, 2048),  # Additional layer
                nn.MaxPool1d(kernel_size=2, stride=2),
                nn.Dropout(0.3),

                # Fifth residual stage - ultra massive 2048 to 3072 channels
                ResidualBlock(2048, 3072),
                ResidualBlock(3072, 3072),
                ResidualBlock(3072, 3072),
                ResidualBlock(3072, 3072),
                nn.AdaptiveAvgPool1d(1)  # Global average pooling
            )
            # Ultra massive feature dimension after conv layers (was 2048)
            self.conv_features = 3072
        else:
            # For 1D vectors, use ultra massive dense preprocessing
            self.conv_layers = None
            self.conv_features = 0

        # ULTRA MASSIVE fully connected feature extraction layers
        if self.conv_layers is None:
            # For 1D inputs - ultra massive feature extraction
            self.fc1 = nn.Linear(self.feature_dim, 3072)
            self.features_dim = 3072
        else:
            # For data processed by ultra massive conv layers
            self.fc1 = nn.Linear(self.conv_features, 3072)
            self.features_dim = 3072

        # ULTRA MASSIVE common feature extraction with multiple deep layers
        self.fc_layers = nn.Sequential(
            self.fc1,
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(3072, 3072),  # Keep ultra massive width
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(3072, 2560),  # Ultra wide hidden layer
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(2560, 2048),  # Still very wide
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(2048, 1536),  # Large hidden layer
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1536, 1024),  # Final feature representation
            nn.ReLU()
        )

        # Multiple attention mechanisms for different aspects (larger capacity)
        self.price_attention = SelfAttention(1024)  # Increased from 768
        self.volume_attention = SelfAttention(1024)
        self.trend_attention = SelfAttention(1024)
        self.volatility_attention = SelfAttention(1024)
        self.momentum_attention = SelfAttention(1024)  # Additional attention
        self.microstructure_attention = SelfAttention(1024)  # Additional attention

        # Ultra massive attention fusion layer
        self.attention_fusion = nn.Sequential(
            nn.Linear(1024 * 6, 2048),  # Combine all 6 attention outputs
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(2048, 1536),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1536, 1024)
        )

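        # Editor's note (not part of the original diff): rough parameter arithmetic for
        # the fusion input layer, weights plus biases for nn.Linear(6144, 2048):
        #   6144 * 2048 + 2048 = 12,584,960 parameters (~12.6M) in this layer alone,
        # which illustrates where much of the "~100M parameter" budget is spent.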
        # ULTRA MASSIVE dueling architecture with much deeper networks
        self.advantage_stream = nn.Sequential(
            nn.Linear(1024, 768),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
@ -212,6 +235,9 @@ class EnhancedCNN(nn.Module):
        )

        self.value_stream = nn.Sequential(
            nn.Linear(1024, 768),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
@ -223,8 +249,11 @@ class EnhancedCNN(nn.Module):
            nn.Linear(128, 1)
        )

        # ULTRA MASSIVE extrema detection head with deeper ensemble predictions
        self.extrema_head = nn.Sequential(
            nn.Linear(1024, 768),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
@ -236,9 +265,12 @@ class EnhancedCNN(nn.Module):
            nn.Linear(128, 3)  # 0=bottom, 1=top, 2=neither
        )

        # ULTRA MASSIVE multi-timeframe price prediction heads (inputs widened from 768 to 1024)
        self.price_pred_immediate = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
@ -247,7 +279,10 @@ class EnhancedCNN(nn.Module):
        )

        self.price_pred_midterm = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
@ -256,7 +291,10 @@ class EnhancedCNN(nn.Module):
        )

        self.price_pred_longterm = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
@ -264,8 +302,11 @@ class EnhancedCNN(nn.Module):
            nn.Linear(128, 3)  # Up, Down, Sideways
        )

        # ULTRA MASSIVE value prediction with ensemble approaches
        self.price_pred_value = nn.Sequential(
            nn.Linear(1024, 768),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
@ -280,7 +321,10 @@ class EnhancedCNN(nn.Module):
        # Additional specialized prediction heads for better accuracy
        # Volatility prediction head
        self.volatility_head = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
@ -290,7 +334,10 @@ class EnhancedCNN(nn.Module):

        # Support/Resistance level detection head
        self.support_resistance_head = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
@ -300,7 +347,10 @@ class EnhancedCNN(nn.Module):

        # Market regime classification head
        self.market_regime_head = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
@ -310,7 +360,10 @@ class EnhancedCNN(nn.Module):

        # Risk assessment head
        self.risk_head = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
@ -330,7 +383,7 @@ class EnhancedCNN(nn.Module):
        return False

    def forward(self, x):
        """Forward pass through the ULTRA MASSIVE network"""
        batch_size = x.size(0)

        # Process different input shapes
@ -349,7 +402,7 @@ class EnhancedCNN(nn.Module):
            total_features = x_reshaped.size(1) * x_reshaped.size(2)
            self._check_rebuild_network(total_features)

            # Apply ultra massive convolutions
            x_conv = self.conv_layers(x_reshaped)
            # Flatten: [batch, channels, 1] -> [batch, channels]
            x_flat = x_conv.view(batch_size, -1)
@ -364,33 +417,40 @@ class EnhancedCNN(nn.Module):
        if x_flat.size(1) != self.feature_dim:
            self._check_rebuild_network(x_flat.size(1))

        # Apply ULTRA MASSIVE FC layers to get base features
        features = self.fc_layers(x_flat)  # [batch, 1024]

        # Apply multiple specialized attention mechanisms
        features_3d = features.unsqueeze(1)  # [batch, 1, 1024]

        # Get attention-refined features for different aspects
        price_features, _ = self.price_attention(features_3d)
        price_features = price_features.squeeze(1)  # [batch, 1024]

        volume_features, _ = self.volume_attention(features_3d)
        volume_features = volume_features.squeeze(1)  # [batch, 1024]

        trend_features, _ = self.trend_attention(features_3d)
        trend_features = trend_features.squeeze(1)  # [batch, 1024]

        volatility_features, _ = self.volatility_attention(features_3d)
        volatility_features = volatility_features.squeeze(1)  # [batch, 1024]

        momentum_features, _ = self.momentum_attention(features_3d)
        momentum_features = momentum_features.squeeze(1)  # [batch, 1024]

        microstructure_features, _ = self.microstructure_attention(features_3d)
        microstructure_features = microstructure_features.squeeze(1)  # [batch, 1024]

        # Fuse all attention outputs
        combined_attention = torch.cat([
            price_features, volume_features,
            trend_features, volatility_features,
            momentum_features, microstructure_features
        ], dim=1)  # [batch, 1024*6]

        # Apply attention fusion to get final refined features
        features_refined = self.attention_fusion(combined_attention)  # [batch, 1024]

        # Calculate advantage and value (Dueling DQN architecture)
        advantage = self.advantage_stream(features_refined)
@ -399,7 +459,7 @@ class EnhancedCNN(nn.Module):
        # Combine for Q-values (Dueling architecture)
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
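        # Editor's note (not part of the original diff): this is the standard dueling
        # decomposition, Q(s,a) = V(s) + A(s,a) - mean_a' A(s,a'). Subtracting the mean
        # keeps V and A identifiable; otherwise adding a constant to A and subtracting
        # it from V would leave Q unchanged.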

        # Get ultra massive ensemble of predictions

        # Extrema predictions (bottom/top/neither detection)
        extrema_pred = self.extrema_head(features_refined)
@ -435,7 +495,7 @@ class EnhancedCNN(nn.Module):
        return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions

    def act(self, state, explore=True):
        """Enhanced action selection with ultra massive model predictions"""
        if explore and np.random.random() < 0.1:  # 10% random exploration
            return np.random.choice(self.n_actions)

@ -471,7 +531,7 @@ class EnhancedCNN(nn.Module):
        risk_class = torch.argmax(risk, dim=1).item()
        risk_labels = ['Low Risk', 'Medium Risk', 'High Risk', 'Extreme Risk']

        logger.info(f"ULTRA MASSIVE Model Predictions:")
        logger.info(f"  Volatility: {volatility_labels[volatility_class]} ({volatility[0, volatility_class]:.3f})")
        logger.info(f"  Support/Resistance: {sr_labels[sr_class]} ({sr[0, sr_class]:.3f})")
        logger.info(f"  Market Regime: {regime_labels[regime_class]} ({regime[0, regime_class]:.3f})")