scalping dash also works initially
This commit is contained in:
@ -143,6 +143,10 @@ class DQNAgent:
|
||||
self.last_hidden_features = None # Store last extracted features
|
||||
self.feature_history = [] # Store history of features for analysis
|
||||
|
||||
# Real-time tick features integration
|
||||
self.realtime_tick_features = None # Latest tick features from tick processor
|
||||
self.tick_feature_weight = 0.3 # Weight for tick features in decision making
|
||||
|
||||
# Check if mixed precision training should be used
|
||||
self.use_mixed_precision = False
|
||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
|
||||
@ -163,6 +167,7 @@ class DQNAgent:
|
||||
|
||||
logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
|
||||
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
|
||||
logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}")
|
||||
|
||||
# Log model parameters
|
||||
total_params = sum(p.numel() for p in self.policy_net.parameters())
|
||||
@ -291,8 +296,11 @@ class DQNAgent:
|
||||
return random.randrange(self.n_actions)
|
||||
|
||||
with torch.no_grad():
|
||||
# Enhance state with real-time tick features
|
||||
enhanced_state = self._enhance_state_with_tick_features(state)
|
||||
|
||||
# Ensure state is normalized before inference
|
||||
state_tensor = self._normalize_state(state)
|
||||
state_tensor = self._normalize_state(enhanced_state)
|
||||
state_tensor = torch.FloatTensor(state_tensor).unsqueeze(0).to(self.device)
|
||||
|
||||
# Get predictions using the policy network
|
||||
@ -764,11 +772,14 @@ class DQNAgent:
|
||||
# Calculate price change for different timeframes
|
||||
immediate_changes = (next_prices - current_prices) / current_prices
|
||||
|
||||
# Get the actual batch size for this calculation
|
||||
actual_batch_size = states.shape[0]
|
||||
|
||||
# Create price direction labels - simplified for training
|
||||
# 0 = down, 1 = sideways, 2 = up
|
||||
immediate_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1 # Default: sideways
|
||||
midterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1
|
||||
longterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1
|
||||
immediate_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device) * 1 # Default: sideways
|
||||
midterm_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device) * 1
|
||||
longterm_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device) * 1
|
||||
|
||||
# Immediate term direction (1s, 1m)
|
||||
immediate_up = (immediate_changes > 0.0005)
|
||||
@ -794,19 +805,19 @@ class DQNAgent:
|
||||
|
||||
# Generate target values for price change regression
|
||||
# For simplicity, we'll use the immediate change and scaled versions for longer timeframes
|
||||
price_value_targets = torch.zeros((min_size, 4), device=self.device)
|
||||
price_value_targets = torch.zeros((actual_batch_size, 4), device=self.device)
|
||||
price_value_targets[:, 0] = immediate_changes
|
||||
price_value_targets[:, 1] = immediate_changes * 2.0 # Approximate 1h change
|
||||
price_value_targets[:, 2] = immediate_changes * 4.0 # Approximate 1d change
|
||||
price_value_targets[:, 3] = immediate_changes * 6.0 # Approximate 1w change
|
||||
|
||||
# Calculate loss for price direction prediction (classification)
|
||||
if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= min_size:
|
||||
if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= actual_batch_size:
|
||||
# Slice predictions to match the adjusted batch size
|
||||
immediate_pred = current_price_pred['immediate'][:min_size]
|
||||
midterm_pred = current_price_pred['midterm'][:min_size]
|
||||
longterm_pred = current_price_pred['longterm'][:min_size]
|
||||
price_values_pred = current_price_pred['values'][:min_size]
|
||||
immediate_pred = current_price_pred['immediate'][:actual_batch_size]
|
||||
midterm_pred = current_price_pred['midterm'][:actual_batch_size]
|
||||
longterm_pred = current_price_pred['longterm'][:actual_batch_size]
|
||||
price_values_pred = current_price_pred['values'][:actual_batch_size]
|
||||
|
||||
# Compute losses for each task
|
||||
immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels)
|
||||
@ -820,7 +831,7 @@ class DQNAgent:
|
||||
price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss
|
||||
|
||||
# Create extrema labels (same as before)
|
||||
extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither
|
||||
extrema_labels = torch.ones(actual_batch_size, dtype=torch.long, device=self.device) * 2 # Default: neither
|
||||
|
||||
# Identify potential bottoms (significant negative change)
|
||||
bottoms = (immediate_changes < -0.003)
|
||||
@ -831,8 +842,8 @@ class DQNAgent:
|
||||
extrema_labels[tops] = 1
|
||||
|
||||
# Calculate extrema prediction loss
|
||||
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
|
||||
current_extrema_pred = current_extrema_pred[:min_size]
|
||||
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= actual_batch_size:
|
||||
current_extrema_pred = current_extrema_pred[:actual_batch_size]
|
||||
extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)
|
||||
|
||||
# Combined loss with all components
|
||||
@ -1017,6 +1028,71 @@ class DQNAgent:
|
||||
|
||||
return normalized_state
|
||||
|
||||
def update_realtime_tick_features(self, tick_features):
|
||||
"""Update with real-time tick features from tick processor"""
|
||||
try:
|
||||
if tick_features is not None:
|
||||
self.realtime_tick_features = tick_features
|
||||
|
||||
# Log high-confidence tick features
|
||||
if tick_features.get('confidence', 0) > 0.8:
|
||||
logger.debug(f"High-confidence tick features updated: confidence={tick_features['confidence']:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating real-time tick features: {e}")
|
||||
|
||||
def _enhance_state_with_tick_features(self, state: np.ndarray) -> np.ndarray:
|
||||
"""Enhance state with real-time tick features if available"""
|
||||
try:
|
||||
if self.realtime_tick_features is None:
|
||||
return state
|
||||
|
||||
# Extract neural features from tick processor
|
||||
neural_features = self.realtime_tick_features.get('neural_features', np.array([]))
|
||||
volume_features = self.realtime_tick_features.get('volume_features', np.array([]))
|
||||
microstructure_features = self.realtime_tick_features.get('microstructure_features', np.array([]))
|
||||
confidence = self.realtime_tick_features.get('confidence', 0.0)
|
||||
|
||||
# Combine tick features - make them compact to match state dimensions
|
||||
tick_features = np.concatenate([
|
||||
neural_features[:3] if len(neural_features) >= 3 else np.zeros(3), # Take first 3 neural features
|
||||
volume_features[:1] if len(volume_features) >= 1 else np.zeros(1), # Take first volume feature
|
||||
microstructure_features[:1] if len(microstructure_features) >= 1 else np.zeros(1), # Take first microstructure feature
|
||||
])
|
||||
|
||||
# Weight the tick features
|
||||
weighted_tick_features = tick_features * self.tick_feature_weight
|
||||
|
||||
# Enhance the state by adding tick features to each timeframe
|
||||
if len(state.shape) == 1:
|
||||
# 1D state - append tick features
|
||||
enhanced_state = np.concatenate([state, weighted_tick_features])
|
||||
else:
|
||||
# 2D state - add tick features to each timeframe row
|
||||
num_timeframes, num_features = state.shape
|
||||
|
||||
# Ensure tick features match the number of original features
|
||||
if len(weighted_tick_features) != num_features:
|
||||
# Pad or truncate tick features to match state feature dimension
|
||||
if len(weighted_tick_features) < num_features:
|
||||
# Pad with zeros
|
||||
padded_features = np.zeros(num_features)
|
||||
padded_features[:len(weighted_tick_features)] = weighted_tick_features
|
||||
weighted_tick_features = padded_features
|
||||
else:
|
||||
# Truncate to match
|
||||
weighted_tick_features = weighted_tick_features[:num_features]
|
||||
|
||||
# Add tick features to the last row (most recent timeframe)
|
||||
enhanced_state = state.copy()
|
||||
enhanced_state[-1, :] += weighted_tick_features # Add to last timeframe
|
||||
|
||||
return enhanced_state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error enhancing state with tick features: {e}")
|
||||
return state
|
||||
|
||||
def update_learning_metrics(self, episode_reward, best_reward_threshold=0.01):
|
||||
"""Update learning metrics and perform learning rate adjustments if needed"""
|
||||
# Update average reward with exponential moving average
|
||||
|
Reference in New Issue
Block a user