improved data structure
This commit is contained in:
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader, TensorDataset
|
||||
import numpy as np
|
||||
import math
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Tuple, List
|
||||
from typing import Dict, Any, Optional, Tuple, List, Callable
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import json
|
||||
@@ -421,6 +421,48 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
# NEW: Next candle OHLCV prediction heads for each timeframe (1s, 1m, 1h, 1d)
|
||||
# Each timeframe predicts: [open, high, low, close, volume] = 5 values
|
||||
self.timeframes = ['1s', '1m', '1h', '1d']
|
||||
self.next_candle_heads = nn.ModuleDict({
|
||||
tf: nn.Sequential(
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 2, config.d_model // 4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 4, 5) # OHLCV: [open, high, low, close, volume]
|
||||
) for tf in self.timeframes
|
||||
})
|
||||
|
||||
# NEW: Next pivot point prediction heads for L1-L5 levels
|
||||
# Each level predicts: [price, type_prob_high, type_prob_low, confidence]
|
||||
# type_prob_high + type_prob_low = 1 (softmax), but we output separately for clarity
|
||||
self.pivot_levels = [1, 2, 3, 4, 5] # L1 to L5
|
||||
self.pivot_heads = nn.ModuleDict({
|
||||
f'L{level}': nn.Sequential(
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 2, config.d_model // 4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 4, 4) # [price, type_prob_high, type_prob_low, confidence]
|
||||
) for level in self.pivot_levels
|
||||
})
|
||||
|
||||
# NEW: Trend vector analysis head (calculates trend from pivot predictions)
|
||||
self.trend_analysis_head = nn.Sequential(
|
||||
nn.Linear(config.d_model, config.d_model // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 2, config.d_model // 4),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.d_model // 4, 3) # [angle_radians, steepness, direction]
|
||||
)
|
||||
|
||||
# Initialize weights
|
||||
self._init_weights()
|
||||
|
||||
@@ -522,11 +564,341 @@ class AdvancedTradingTransformer(nn.Module):
|
||||
trend_strength_pred = self.trend_strength_head(pooled)
|
||||
outputs['trend_strength_prediction'] = trend_strength_pred
|
||||
|
||||
# NEW: Next candle OHLCV predictions for each timeframe
|
||||
next_candles = {}
|
||||
for tf in self.timeframes:
|
||||
candle_pred = self.next_candle_heads[tf](pooled) # (batch, 5)
|
||||
next_candles[tf] = candle_pred
|
||||
outputs['next_candles'] = next_candles
|
||||
|
||||
# NEW: Next pivot point predictions for L1-L5
|
||||
next_pivots = {}
|
||||
for level in self.pivot_levels:
|
||||
pivot_pred = self.pivot_heads[f'L{level}'](pooled) # (batch, 4)
|
||||
# Extract components: [price, type_logit_high, type_logit_low, confidence]
|
||||
# Use softmax to ensure type probabilities sum to 1
|
||||
type_logits = pivot_pred[:, 1:3] # (batch, 2) - [high, low]
|
||||
type_probs = F.softmax(type_logits, dim=-1) # (batch, 2)
|
||||
|
||||
next_pivots[f'L{level}'] = {
|
||||
'price': pivot_pred[:, 0:1], # Keep as (batch, 1)
|
||||
'type_prob_high': type_probs[:, 0:1], # Probability of high pivot
|
||||
'type_prob_low': type_probs[:, 1:2], # Probability of low pivot
|
||||
'pivot_type': torch.argmax(type_probs, dim=-1, keepdim=True), # 0=high, 1=low
|
||||
'confidence': torch.sigmoid(pivot_pred[:, 3:4]) # Prediction confidence
|
||||
}
|
||||
outputs['next_pivots'] = next_pivots
|
||||
|
||||
# NEW: Trend vector analysis from pivot predictions
|
||||
trend_analysis = self.trend_analysis_head(pooled) # (batch, 3)
|
||||
outputs['trend_analysis'] = {
|
||||
'angle_radians': trend_analysis[:, 0:1], # Trend angle in radians
|
||||
'steepness': F.softplus(trend_analysis[:, 1:2]), # Always positive steepness
|
||||
'direction': torch.tanh(trend_analysis[:, 2:3]) # -1 to 1 (down to up)
|
||||
}
|
||||
|
||||
# NEW: Calculate trend vector from pivot predictions
|
||||
# Extract pivot prices and create trend vector
|
||||
pivot_prices = torch.stack([next_pivots[f'L{level}']['price'] for level in self.pivot_levels], dim=1) # (batch, 5, 1)
|
||||
pivot_prices = pivot_prices.squeeze(-1) # (batch, 5)
|
||||
|
||||
# Calculate trend vector: (price_change, time_change)
|
||||
# Assume equal time spacing between pivot levels
|
||||
time_points = torch.arange(1, len(self.pivot_levels) + 1, dtype=torch.float32, device=pooled.device).unsqueeze(0) # (1, 5)
|
||||
|
||||
# Calculate trend line slope using linear regression on pivot prices
|
||||
# Trend vector = (delta_price, delta_time) normalized
|
||||
if batch_size > 0:
|
||||
# For each sample, calculate trend from L1 to L5
|
||||
price_deltas = pivot_prices[:, -1:] - pivot_prices[:, :1] # L5 - L1 price change
|
||||
time_deltas = time_points[:, -1:] - time_points[:, :1] # Time change (should be 4)
|
||||
|
||||
# Calculate angle and steepness
|
||||
trend_angles = torch.atan2(price_deltas.squeeze(), time_deltas.squeeze()) # (batch,)
|
||||
trend_steepness = torch.sqrt(price_deltas.squeeze() ** 2 + time_deltas.squeeze() ** 2) # (batch,)
|
||||
trend_direction = torch.sign(price_deltas.squeeze()) # (batch,)
|
||||
|
||||
outputs['trend_vector'] = {
|
||||
'pivot_prices': pivot_prices, # (batch, 5) - prices for L1-L5
|
||||
'price_delta': price_deltas.squeeze(), # (batch,) - price change from L1 to L5
|
||||
'time_delta': time_deltas.squeeze(), # (batch,) - time change
|
||||
'calculated_angle': trend_angles.unsqueeze(-1), # (batch, 1)
|
||||
'calculated_steepness': trend_steepness.unsqueeze(-1), # (batch, 1)
|
||||
'calculated_direction': trend_direction.unsqueeze(-1), # (batch, 1)
|
||||
'vector': torch.stack([price_deltas.squeeze(), time_deltas.squeeze()], dim=1) # (batch, 2) - [price_delta, time_delta]
|
||||
}
|
||||
else:
|
||||
outputs['trend_vector'] = {
|
||||
'pivot_prices': pivot_prices,
|
||||
'price_delta': torch.zeros(batch_size, device=pooled.device),
|
||||
'time_delta': torch.zeros(batch_size, device=pooled.device),
|
||||
'calculated_angle': torch.zeros(batch_size, 1, device=pooled.device),
|
||||
'calculated_steepness': torch.zeros(batch_size, 1, device=pooled.device),
|
||||
'calculated_direction': torch.zeros(batch_size, 1, device=pooled.device),
|
||||
'vector': torch.zeros(batch_size, 2, device=pooled.device)
|
||||
}
|
||||
|
||||
# NEW: Trade action based on trend steepness and angle
|
||||
# Combine predicted trend analysis with calculated trend vector
|
||||
predicted_angle = outputs['trend_analysis']['angle_radians'].squeeze() # (batch,)
|
||||
predicted_steepness = outputs['trend_analysis']['steepness'].squeeze() # (batch,)
|
||||
predicted_direction = outputs['trend_analysis']['direction'].squeeze() # (batch,)
|
||||
|
||||
# Use calculated trend if available, otherwise use predicted
|
||||
if 'calculated_angle' in outputs['trend_vector']:
|
||||
trend_angle = outputs['trend_vector']['calculated_angle'].squeeze() # (batch,)
|
||||
trend_steepness_val = outputs['trend_vector']['calculated_steepness'].squeeze() # (batch,)
|
||||
else:
|
||||
trend_angle = predicted_angle
|
||||
trend_steepness_val = predicted_steepness
|
||||
|
||||
# Trade action logic based on trend steepness and angle
|
||||
# Steep upward trend (> 45 degrees) -> BUY
|
||||
# Steep downward trend (< -45 degrees) -> SELL
|
||||
# Shallow trend -> HOLD
|
||||
angle_threshold = math.pi / 4 # 45 degrees
|
||||
|
||||
# Determine action from trend angle
|
||||
trend_action_logits = torch.zeros(batch_size, 3, device=pooled.device) # [BUY, SELL, HOLD]
|
||||
|
||||
# Calculate action probabilities based on trend
|
||||
for i in range(batch_size):
|
||||
angle = trend_angle[i].item() if batch_size > 0 else 0.0
|
||||
steep = trend_steepness_val[i].item() if batch_size > 0 else 0.0
|
||||
|
||||
# Normalize steepness to [0, 1] range (assuming max steepness of 10 units)
|
||||
normalized_steepness = min(steep / 10.0, 1.0) if steep > 0 else 0.0
|
||||
|
||||
if angle > angle_threshold: # Steep upward trend
|
||||
trend_action_logits[i, 0] = normalized_steepness * 2.0 # BUY
|
||||
trend_action_logits[i, 2] = (1.0 - normalized_steepness) * 0.5 # HOLD
|
||||
elif angle < -angle_threshold: # Steep downward trend
|
||||
trend_action_logits[i, 1] = normalized_steepness * 2.0 # SELL
|
||||
trend_action_logits[i, 2] = (1.0 - normalized_steepness) * 0.5 # HOLD
|
||||
else: # Shallow trend
|
||||
trend_action_logits[i, 2] = 1.0 # HOLD
|
||||
|
||||
# Combine trend-based action with main action prediction
|
||||
trend_action_probs = F.softmax(trend_action_logits, dim=-1)
|
||||
outputs['trend_based_action'] = {
|
||||
'logits': trend_action_logits,
|
||||
'probabilities': trend_action_probs,
|
||||
'action_idx': torch.argmax(trend_action_probs, dim=-1),
|
||||
'trend_angle_degrees': trend_angle * 180.0 / math.pi, # Convert to degrees
|
||||
'trend_steepness': trend_steepness_val
|
||||
}
|
||||
|
||||
# Market regime information
|
||||
if regime_probs_history:
|
||||
outputs['regime_probs'] = torch.stack(regime_probs_history, dim=1)
|
||||
|
||||
return outputs
|
||||
|
||||
def extract_predictions(self, outputs: Dict[str, torch.Tensor], denormalize_prices: Optional[Callable] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract predictions from model outputs in a user-friendly format
|
||||
|
||||
Args:
|
||||
outputs: Raw model outputs from forward() method
|
||||
denormalize_prices: Optional function to denormalize predicted prices
|
||||
|
||||
Returns:
|
||||
Dictionary with formatted predictions including:
|
||||
- next_candles: Dict[str, Dict] - OHLCV predictions for each timeframe
|
||||
- next_pivots: Dict[str, Dict] - Pivot predictions for L1-L5
|
||||
- trend_vector: Dict - Trend vector analysis
|
||||
- trend_based_action: Dict - Trading action based on trend
|
||||
"""
|
||||
self.eval()
|
||||
device = next(self.parameters()).device
|
||||
|
||||
predictions = {}
|
||||
|
||||
# Extract next candle predictions for each timeframe
|
||||
if 'next_candles' in outputs:
|
||||
next_candles = {}
|
||||
for tf in self.timeframes:
|
||||
candle_tensor = outputs['next_candles'][tf]
|
||||
if candle_tensor.dim() > 1:
|
||||
candle_tensor = candle_tensor[0] # Take first batch item
|
||||
|
||||
candle_values = candle_tensor.cpu().detach().numpy() if hasattr(candle_tensor, 'cpu') else candle_tensor
|
||||
if isinstance(candle_values, np.ndarray):
|
||||
candle_values = candle_values.tolist()
|
||||
|
||||
next_candles[tf] = {
|
||||
'open': float(candle_values[0]) if len(candle_values) > 0 else 0.0,
|
||||
'high': float(candle_values[1]) if len(candle_values) > 1 else 0.0,
|
||||
'low': float(candle_values[2]) if len(candle_values) > 2 else 0.0,
|
||||
'close': float(candle_values[3]) if len(candle_values) > 3 else 0.0,
|
||||
'volume': float(candle_values[4]) if len(candle_values) > 4 else 0.0
|
||||
}
|
||||
|
||||
# Denormalize if function provided
|
||||
if denormalize_prices and callable(denormalize_prices):
|
||||
for key in ['open', 'high', 'low', 'close']:
|
||||
next_candles[tf][key] = denormalize_prices(next_candles[tf][key])
|
||||
|
||||
predictions['next_candles'] = next_candles
|
||||
|
||||
# Extract pivot point predictions
|
||||
if 'next_pivots' in outputs:
|
||||
next_pivots = {}
|
||||
for level in self.pivot_levels:
|
||||
pivot_data = outputs['next_pivots'][f'L{level}']
|
||||
|
||||
# Extract values
|
||||
price = pivot_data['price']
|
||||
if price.dim() > 1:
|
||||
price = price[0, 0] if price.shape[0] > 0 else torch.tensor(0.0, device=device)
|
||||
price_val = float(price.cpu().detach().item() if hasattr(price, 'cpu') else price)
|
||||
|
||||
type_prob_high = pivot_data['type_prob_high']
|
||||
if type_prob_high.dim() > 1:
|
||||
type_prob_high = type_prob_high[0, 0] if type_prob_high.shape[0] > 0 else torch.tensor(0.0, device=device)
|
||||
prob_high = float(type_prob_high.cpu().detach().item() if hasattr(type_prob_high, 'cpu') else type_prob_high)
|
||||
|
||||
type_prob_low = pivot_data['type_prob_low']
|
||||
if type_prob_low.dim() > 1:
|
||||
type_prob_low = type_prob_low[0, 0] if type_prob_low.shape[0] > 0 else torch.tensor(0.0, device=device)
|
||||
prob_low = float(type_prob_low.cpu().detach().item() if hasattr(type_prob_low, 'cpu') else type_prob_low)
|
||||
|
||||
confidence = pivot_data['confidence']
|
||||
if confidence.dim() > 1:
|
||||
confidence = confidence[0, 0] if confidence.shape[0] > 0 else torch.tensor(0.0, device=device)
|
||||
conf_val = float(confidence.cpu().detach().item() if hasattr(confidence, 'cpu') else confidence)
|
||||
|
||||
pivot_type = pivot_data.get('pivot_type', torch.tensor(0))
|
||||
if isinstance(pivot_type, torch.Tensor):
|
||||
if pivot_type.dim() > 1:
|
||||
pivot_type = pivot_type[0, 0] if pivot_type.shape[0] > 0 else torch.tensor(0, device=device)
|
||||
pivot_type_val = int(pivot_type.cpu().detach().item() if hasattr(pivot_type, 'cpu') else pivot_type)
|
||||
else:
|
||||
pivot_type_val = int(pivot_type)
|
||||
|
||||
# Denormalize price if function provided
|
||||
if denormalize_prices and callable(denormalize_prices):
|
||||
price_val = denormalize_prices(price_val)
|
||||
|
||||
next_pivots[f'L{level}'] = {
|
||||
'price': price_val,
|
||||
'type': 'high' if pivot_type_val == 0 else 'low',
|
||||
'type_prob_high': prob_high,
|
||||
'type_prob_low': prob_low,
|
||||
'confidence': conf_val
|
||||
}
|
||||
|
||||
predictions['next_pivots'] = next_pivots
|
||||
|
||||
# Extract trend vector
|
||||
if 'trend_vector' in outputs:
|
||||
trend_vec = outputs['trend_vector']
|
||||
|
||||
# Extract pivot prices
|
||||
pivot_prices = trend_vec.get('pivot_prices', torch.zeros(5, device=device))
|
||||
if isinstance(pivot_prices, torch.Tensor):
|
||||
if pivot_prices.dim() > 1:
|
||||
pivot_prices = pivot_prices[0]
|
||||
pivot_prices_list = pivot_prices.cpu().detach().numpy().tolist() if hasattr(pivot_prices, 'cpu') else pivot_prices.tolist()
|
||||
else:
|
||||
pivot_prices_list = pivot_prices
|
||||
|
||||
# Denormalize pivot prices if function provided
|
||||
if denormalize_prices and callable(denormalize_prices):
|
||||
pivot_prices_list = [denormalize_prices(p) for p in pivot_prices_list]
|
||||
|
||||
angle = trend_vec.get('calculated_angle', torch.tensor(0.0, device=device))
|
||||
if isinstance(angle, torch.Tensor):
|
||||
if angle.dim() > 1:
|
||||
angle = angle[0, 0] if angle.shape[0] > 0 else torch.tensor(0.0, device=device)
|
||||
angle_val = float(angle.cpu().detach().item() if hasattr(angle, 'cpu') else angle)
|
||||
else:
|
||||
angle_val = float(angle)
|
||||
|
||||
steepness = trend_vec.get('calculated_steepness', torch.tensor(0.0, device=device))
|
||||
if isinstance(steepness, torch.Tensor):
|
||||
if steepness.dim() > 1:
|
||||
steepness = steepness[0, 0] if steepness.shape[0] > 0 else torch.tensor(0.0, device=device)
|
||||
steepness_val = float(steepness.cpu().detach().item() if hasattr(steepness, 'cpu') else steepness)
|
||||
else:
|
||||
steepness_val = float(steepness)
|
||||
|
||||
direction = trend_vec.get('calculated_direction', torch.tensor(0.0, device=device))
|
||||
if isinstance(direction, torch.Tensor):
|
||||
if direction.dim() > 1:
|
||||
direction = direction[0, 0] if direction.shape[0] > 0 else torch.tensor(0.0, device=device)
|
||||
direction_val = float(direction.cpu().detach().item() if hasattr(direction, 'cpu') else direction)
|
||||
else:
|
||||
direction_val = float(direction)
|
||||
|
||||
price_delta = trend_vec.get('price_delta', torch.tensor(0.0, device=device))
|
||||
if isinstance(price_delta, torch.Tensor):
|
||||
if price_delta.dim() > 0:
|
||||
price_delta = price_delta[0] if price_delta.shape[0] > 0 else torch.tensor(0.0, device=device)
|
||||
price_delta_val = float(price_delta.cpu().detach().item() if hasattr(price_delta, 'cpu') else price_delta)
|
||||
else:
|
||||
price_delta_val = float(price_delta)
|
||||
|
||||
predictions['trend_vector'] = {
|
||||
'pivot_prices': pivot_prices_list, # [L1, L2, L3, L4, L5]
|
||||
'angle_radians': angle_val,
|
||||
'angle_degrees': angle_val * 180.0 / math.pi,
|
||||
'steepness': steepness_val,
|
||||
'direction': 'up' if direction_val > 0 else 'down' if direction_val < 0 else 'sideways',
|
||||
'price_delta': price_delta_val
|
||||
}
|
||||
|
||||
# Extract trend-based action
|
||||
if 'trend_based_action' in outputs:
|
||||
trend_action = outputs['trend_based_action']
|
||||
|
||||
action_probs = trend_action.get('probabilities', torch.zeros(3, device=device))
|
||||
if isinstance(action_probs, torch.Tensor):
|
||||
if action_probs.dim() > 1:
|
||||
action_probs = action_probs[0]
|
||||
action_probs_list = action_probs.cpu().detach().numpy().tolist() if hasattr(action_probs, 'cpu') else action_probs.tolist()
|
||||
else:
|
||||
action_probs_list = action_probs
|
||||
|
||||
action_idx = trend_action.get('action_idx', torch.tensor(2, device=device))
|
||||
if isinstance(action_idx, torch.Tensor):
|
||||
if action_idx.dim() > 0:
|
||||
action_idx = action_idx[0] if action_idx.shape[0] > 0 else torch.tensor(2, device=device)
|
||||
action_idx_val = int(action_idx.cpu().detach().item() if hasattr(action_idx, 'cpu') else action_idx)
|
||||
else:
|
||||
action_idx_val = int(action_idx)
|
||||
|
||||
angle_degrees = trend_action.get('trend_angle_degrees', torch.tensor(0.0, device=device))
|
||||
if isinstance(angle_degrees, torch.Tensor):
|
||||
if angle_degrees.dim() > 0:
|
||||
angle_degrees = angle_degrees[0] if angle_degrees.shape[0] > 0 else torch.tensor(0.0, device=device)
|
||||
angle_degrees_val = float(angle_degrees.cpu().detach().item() if hasattr(angle_degrees, 'cpu') else angle_degrees)
|
||||
else:
|
||||
angle_degrees_val = float(angle_degrees)
|
||||
|
||||
steepness = trend_action.get('trend_steepness', torch.tensor(0.0, device=device))
|
||||
if isinstance(steepness, torch.Tensor):
|
||||
if steepness.dim() > 0:
|
||||
steepness = steepness[0] if steepness.shape[0] > 0 else torch.tensor(0.0, device=device)
|
||||
steepness_val = float(steepness.cpu().detach().item() if hasattr(steepness, 'cpu') else steepness)
|
||||
else:
|
||||
steepness_val = float(steepness)
|
||||
|
||||
action_names = ['BUY', 'SELL', 'HOLD']
|
||||
|
||||
predictions['trend_based_action'] = {
|
||||
'action': action_names[action_idx_val] if 0 <= action_idx_val < len(action_names) else 'HOLD',
|
||||
'action_idx': action_idx_val,
|
||||
'probabilities': {
|
||||
'BUY': float(action_probs_list[0]) if len(action_probs_list) > 0 else 0.0,
|
||||
'SELL': float(action_probs_list[1]) if len(action_probs_list) > 1 else 0.0,
|
||||
'HOLD': float(action_probs_list[2]) if len(action_probs_list) > 2 else 0.0
|
||||
},
|
||||
'trend_angle_degrees': angle_degrees_val,
|
||||
'trend_steepness': steepness_val
|
||||
}
|
||||
|
||||
return predictions
|
||||
|
||||
class TradingTransformerTrainer:
|
||||
"""Trainer for the advanced trading transformer"""
|
||||
|
||||
Reference in New Issue
Block a user