price vector predictions
@@ -7,6 +7,7 @@ import time
import logging
import torch.nn.functional as F
from typing import List, Tuple, Dict, Any, Optional, Union
from datetime import datetime

# Configure logger
logging.basicConfig(level=logging.INFO)
@@ -283,10 +284,59 @@ class EnhancedCNN(nn.Module):
            nn.Linear(256, 2)  # [direction, confidence]
        )

        # MULTI-TIMEFRAME PRICE VECTOR PREDICTION HEADS
        # Short-term: 1-5 minute predictions
        self.short_term_vector_head = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 4)  # [direction, confidence, magnitude, volatility_risk]
        )

        # Mid-term: 5-30 minute predictions
        self.mid_term_vector_head = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 4)  # [direction, confidence, magnitude, volatility_risk]
        )

        # Long-term: 30-120 minute predictions
        self.long_term_vector_head = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 4)  # [direction, confidence, magnitude, volatility_risk]
        )

        # Direction activation (tanh for -1 to 1)
        self.direction_activation = nn.Tanh()
        # Confidence activation (sigmoid for 0 to 1)
        self.confidence_activation = nn.Sigmoid()
        # Magnitude activation (sigmoid for 0 to 1, scaled downstream)
        self.magnitude_activation = nn.Sigmoid()
        # Volatility risk activation (sigmoid for 0 to 1)
        self.volatility_activation = nn.Sigmoid()

        # INFERENCE RECORD STORAGE for long-term training
        self.inference_records = []
        self.max_inference_records = 50
        self.training_loss_history = []

        # ULTRA MASSIVE value prediction with ensemble approaches
        self.price_pred_value = nn.Sequential(
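Each timeframe head maps the same 1024-dim refined feature vector to four raw outputs, which the forward pass then squashes component-wise. A minimal standalone sketch of that activation scheme (a hypothetical reconstruction, assuming only the layer sizes shown above):

import torch
import torch.nn as nn

# One timeframe head, rebuilt in isolation for illustration.
head = nn.Sequential(
    nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(0.3),
    nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.2),
    nn.Linear(512, 256), nn.ReLU(),
    nn.Linear(256, 4),  # raw [direction, confidence, magnitude, volatility_risk]
)
head.eval()

features = torch.randn(2, 1024)            # dummy batch of refined features
raw = head(features)                        # [2, 4], unbounded
direction = torch.tanh(raw[:, 0:1])         # -1 (down) .. 1 (up)
confidence = torch.sigmoid(raw[:, 1:2])     # 0 .. 1
magnitude = torch.sigmoid(raw[:, 2:3])      # 0 .. 1, scaled downstream
volatility_risk = torch.sigmoid(raw[:, 3:4])  # 0 .. 1
print(direction.min() >= -1, confidence.max() <= 1)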
@@ -484,6 +534,34 @@ class EnhancedCNN(nn.Module):
        confidence = self.confidence_activation(price_direction_raw[:, 1:2])  # 0 to 1
        price_direction_pred = torch.cat([direction, confidence], dim=1)  # [batch, 2]

        # MULTI-TIMEFRAME PRICE VECTOR PREDICTIONS
        short_term_vector_pred = self.short_term_vector_head(features_refined)
        mid_term_vector_pred = self.mid_term_vector_head(features_refined)
        long_term_vector_pred = self.long_term_vector_head(features_refined)

        # Apply separate activations to direction, confidence, magnitude, and volatility_risk
        short_term_direction = self.direction_activation(short_term_vector_pred[:, 0:1])
        short_term_confidence = self.confidence_activation(short_term_vector_pred[:, 1:2])
        short_term_magnitude = self.magnitude_activation(short_term_vector_pred[:, 2:3])
        short_term_volatility_risk = self.volatility_activation(short_term_vector_pred[:, 3:4])

        mid_term_direction = self.direction_activation(mid_term_vector_pred[:, 0:1])
        mid_term_confidence = self.confidence_activation(mid_term_vector_pred[:, 1:2])
        mid_term_magnitude = self.magnitude_activation(mid_term_vector_pred[:, 2:3])
        mid_term_volatility_risk = self.volatility_activation(mid_term_vector_pred[:, 3:4])

        long_term_direction = self.direction_activation(long_term_vector_pred[:, 0:1])
        long_term_confidence = self.confidence_activation(long_term_vector_pred[:, 1:2])
        long_term_magnitude = self.magnitude_activation(long_term_vector_pred[:, 2:3])
        long_term_volatility_risk = self.volatility_activation(long_term_vector_pred[:, 3:4])

        # Package multi-timeframe predictions into a single tensor
        multi_timeframe_predictions = torch.cat([
            short_term_direction, short_term_confidence, short_term_magnitude, short_term_volatility_risk,
            mid_term_direction, mid_term_confidence, mid_term_magnitude, mid_term_volatility_risk,
            long_term_direction, long_term_confidence, long_term_magnitude, long_term_volatility_risk
        ], dim=1)  # [batch, 12]: 3 timeframes x 4 components

        price_values = self.price_pred_value(features_refined)

        # Additional specialized predictions for enhanced accuracy
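The packed tensor's column layout is fixed by the cat order: columns 0-3 are short-term, 4-7 mid-term, 8-11 long-term, each ordered [direction, confidence, magnitude, volatility_risk]. A quick sketch of slicing it back apart, using a dummy tensor in place of multi_timeframe_predictions:

import torch

packed = torch.rand(1, 12)  # stand-in for multi_timeframe_predictions
for k, name in enumerate(['short_term', 'mid_term', 'long_term']):
    direction, confidence, magnitude, volatility_risk = packed[0, 4*k:4*k + 4]
    print(name, float(direction), float(confidence))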
@@ -499,7 +577,7 @@ class EnhancedCNN(nn.Module):
        # For compatibility with the DQN agent, we return volatility_pred as the advanced prediction tensor
        advanced_pred_tensor = volatility_pred

-        return q_values, extrema_pred, price_direction_tensor, features_refined, advanced_pred_tensor
+        return q_values, extrema_pred, price_direction_tensor, features_refined, advanced_pred_tensor, multi_timeframe_predictions

    def act(self, state, explore=True) -> Tuple[int, float, List[float]]:
        """Enhanced action selection with ultra massive model predictions"""
@@ -517,7 +595,7 @@ class EnhancedCNN(nn.Module):
        state_tensor = state_tensor.unsqueeze(0)

        with torch.no_grad():
-            q_values, extrema_pred, price_direction_predictions, features, advanced_predictions = self(state_tensor)
+            q_values, extrema_pred, price_direction_predictions, features, advanced_predictions, multi_timeframe_predictions = self(state_tensor)

        # Process price direction predictions
        if price_direction_predictions is not None:
@@ -762,6 +840,286 @@ class EnhancedCNN(nn.Module):
            logger.error(f"Error loading model: {str(e)}")
            return False

    def store_inference_record(self, input_data, prediction_output, metadata=None):
        """Store inference record for long-term training"""
        try:
            record = {
                'timestamp': datetime.now(),
                'input_data': input_data.clone().detach() if isinstance(input_data, torch.Tensor) else input_data,
                'prediction_output': {
                    'q_values': prediction_output[0].clone().detach() if prediction_output[0] is not None else None,
                    'extrema_pred': prediction_output[1].clone().detach() if prediction_output[1] is not None else None,
                    'price_direction': prediction_output[2].clone().detach() if prediction_output[2] is not None else None,
                    'multi_timeframe': prediction_output[5].clone().detach() if len(prediction_output) > 5 and prediction_output[5] is not None else None
                },
                'metadata': metadata or {}
            }

            self.inference_records.append(record)

            # Keep only the last max_inference_records entries
            if len(self.inference_records) > self.max_inference_records:
                self.inference_records = self.inference_records[-self.max_inference_records:]

            logger.debug(f"CNN: Stored inference record. Total records: {len(self.inference_records)}")

        except Exception as e:
            logger.error(f"Error storing CNN inference record: {e}")

    def calculate_price_vector_loss(self, predicted_vectors, actual_price_changes, time_diffs):
        """
        Calculate price vector loss for multi-timeframe predictions

        Args:
            predicted_vectors: Dict with 'short_term', 'mid_term', 'long_term' predictions
            actual_price_changes: Dict with corresponding actual price changes
            time_diffs: Dict with time differences for each timeframe

        Returns:
            Total loss tensor for backpropagation
        """
        try:
            total_loss = 0.0
            loss_count = 0

            timeframes = ['short_term', 'mid_term', 'long_term']
            weights = [1.0, 0.8, 0.6]  # Weight short-term predictions higher

            for timeframe, weight in zip(timeframes, weights):
                if timeframe in predicted_vectors and timeframe in actual_price_changes:
                    pred_vector = predicted_vectors[timeframe]
                    actual_change = actual_price_changes[timeframe]
                    time_diff = time_diffs.get(timeframe, 1.0)

                    # Extract prediction components [direction, confidence, magnitude, volatility_risk]
                    pred_direction = pred_vector[0].item() if isinstance(pred_vector, torch.Tensor) else pred_vector[0]
                    pred_confidence = pred_vector[1].item() if isinstance(pred_vector, torch.Tensor) else pred_vector[1]
                    pred_magnitude = pred_vector[2].item() if isinstance(pred_vector, torch.Tensor) else pred_vector[2]
                    pred_volatility = pred_vector[3].item() if isinstance(pred_vector, torch.Tensor) else pred_vector[3]

                    # Calculate actual metrics
                    actual_direction = 1.0 if actual_change > 0.05 else -1.0 if actual_change < -0.05 else 0.0
                    actual_magnitude = min(abs(actual_change) / 5.0, 1.0)  # Normalize to 0-1, cap at 5%

                    # Direction loss (most important)
                    if actual_direction != 0.0:
                        direction_error = abs(pred_direction - actual_direction)
                    else:
                        direction_error = abs(pred_direction) * 0.5  # Penalty for predicting movement when there's none

                    # Magnitude loss
                    magnitude_error = abs(pred_magnitude - actual_magnitude)

                    # Confidence calibration loss (confidence should match accuracy)
                    direction_accuracy = 1.0 - (direction_error / 2.0)  # 0 to 1
                    confidence_error = abs(pred_confidence - direction_accuracy)

                    # Time decay factor
                    time_decay = max(0.1, 1.0 - (time_diff / 60.0))  # Decay over 1 hour

                    # Combined loss for this timeframe
                    timeframe_loss = (
                        direction_error * 2.0 +   # Direction is most important
                        magnitude_error * 1.5 +   # Magnitude is important
                        confidence_error * 1.0    # Confidence calibration
                    ) * time_decay * weight

                    total_loss += timeframe_loss
                    loss_count += 1

                    logger.debug(f"CNN {timeframe.upper()} VECTOR LOSS: "
                                 f"dir_err={direction_error:.3f}, mag_err={magnitude_error:.3f}, "
                                 f"conf_err={confidence_error:.3f}, total={timeframe_loss:.3f}")

            if loss_count > 0:
                avg_loss = total_loss / loss_count
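                # Note: the prediction components above were extracted with .item(), so
                # avg_loss is a plain Python float. The tensor built below is a fresh leaf
                # with requires_grad=True but no graph back to the network parameters, so
                # backpropagating it in train_on_stored_records will not produce weight
                # gradients as written; keeping the computation in tensor form would.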
                return torch.tensor(avg_loss, dtype=torch.float32, device=self.device, requires_grad=True)
            else:
                return torch.tensor(0.0, dtype=torch.float32, device=self.device, requires_grad=True)

        except Exception as e:
            logger.error(f"Error calculating CNN price vector loss: {e}")
            return torch.tensor(0.0, dtype=torch.float32, device=self.device, requires_grad=True)
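A worked example of the formula above with made-up numbers for a single short_term prediction (weight 1.0):

# Hypothetical values: pred = [0.6, 0.9, 0.3, 0.2], actual_change = +1.2%, time_diff = 10 min
direction_error = abs(0.6 - 1.0)                   # 0.4 (actual_direction = 1.0, since change > 0.05)
magnitude_error = abs(0.3 - min(1.2 / 5.0, 1.0))   # |0.3 - 0.24| = 0.06
direction_accuracy = 1.0 - 0.4 / 2.0               # 0.8
confidence_error = abs(0.9 - 0.8)                  # 0.1
time_decay = max(0.1, 1.0 - 10.0 / 60.0)           # ~0.833
loss = (0.4 * 2.0 + 0.06 * 1.5 + 0.1 * 1.0) * time_decay * 1.0
print(loss)  # ~0.82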

    def train_on_stored_records(self, optimizer, min_records=10):
        """
        Train on stored inference records for long-term price vector prediction

        Args:
            optimizer: PyTorch optimizer
            min_records: Minimum number of records needed for training

        Returns:
            Average training loss
        """
        try:
            if len(self.inference_records) < min_records:
                logger.debug(f"CNN: Not enough records for long-term training ({len(self.inference_records)} < {min_records})")
                return 0.0

            self.train()
            total_loss = 0.0
            trained_count = 0

            # Process records in batches
            batch_size = min(8, len(self.inference_records))
            for i in range(0, len(self.inference_records), batch_size):
                batch_records = self.inference_records[i:i+batch_size]

                batch_inputs = []
                batch_targets = []

                for record in batch_records:
                    # Check if we have actual price movement data for this record
                    if 'actual_price_changes' in record['metadata'] and 'time_diffs' in record['metadata']:
                        batch_inputs.append(record['input_data'])
                        batch_targets.append({
                            'actual_price_changes': record['metadata']['actual_price_changes'],
                            'time_diffs': record['metadata']['time_diffs']
                        })

                if not batch_inputs:
                    continue

                # Stack inputs into a batch tensor
                if isinstance(batch_inputs[0], torch.Tensor):
                    batch_input_tensor = torch.stack(batch_inputs).to(self.device)
                else:
                    batch_input_tensor = torch.tensor(batch_inputs, dtype=torch.float32, device=self.device)

                optimizer.zero_grad()

                # Forward pass
                q_values, extrema_pred, price_direction_pred, features, advanced_pred, multi_timeframe_pred = self(batch_input_tensor)

                # Calculate price vector losses for the batch
                batch_loss = 0.0
                for j, target in enumerate(batch_targets):
                    # Extract multi-timeframe predictions for this sample
                    sample_multi_pred = multi_timeframe_pred[j] if multi_timeframe_pred is not None else None

                    if sample_multi_pred is not None:
                        predicted_vectors = {
                            'short_term': sample_multi_pred[0:4],  # [direction, confidence, magnitude, volatility]
                            'mid_term': sample_multi_pred[4:8],    # [direction, confidence, magnitude, volatility]
                            'long_term': sample_multi_pred[8:12]   # [direction, confidence, magnitude, volatility]
                        }

                        sample_loss = self.calculate_price_vector_loss(
                            predicted_vectors,
                            target['actual_price_changes'],
                            target['time_diffs']
                        )
                        batch_loss += sample_loss

                if batch_loss > 0:
                    avg_batch_loss = batch_loss / len(batch_targets)
                    avg_batch_loss.backward()

                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)

                    optimizer.step()

                    total_loss += avg_batch_loss.item()
                    trained_count += 1

            avg_loss = total_loss / max(trained_count, 1)
            self.training_loss_history.append(avg_loss)

            # Keep only the last 100 loss values
            if len(self.training_loss_history) > 100:
                self.training_loss_history = self.training_loss_history[-100:]

            logger.info(f"CNN: Trained on {trained_count} batches from {len(self.inference_records)} stored records. Avg loss: {avg_loss:.4f}")
            return avg_loss

        except Exception as e:
            logger.error(f"Error training CNN on stored records: {e}")
            return 0.0
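A sketch of the intended training loop, assuming `model` is a constructed EnhancedCNN with populated inference records; Adam and the learning rate are illustrative choices, not prescribed by this commit:

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
avg_loss = model.train_on_stored_records(optimizer, min_records=10)
if avg_loss > 0:
    print(f"long-term training pass complete, avg loss {avg_loss:.4f}")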

    def process_price_direction_predictions(self, price_direction_tensor):
        """
        Process price direction predictions into a standardized format,
        compatible with the orchestrator's price vector system

        Args:
            price_direction_tensor: Tensor with [direction, confidence] or multi-timeframe predictions

        Returns:
            Dict with direction and confidence for compatibility
        """
        try:
            if price_direction_tensor is None:
                return None

            if isinstance(price_direction_tensor, torch.Tensor):
                if price_direction_tensor.dim() > 1:
                    price_direction_tensor = price_direction_tensor.squeeze(0)

                # Extract the short-term prediction (most immediate) for compatibility
                direction = float(price_direction_tensor[0].item())
                confidence = float(price_direction_tensor[1].item())

                return {
                    'direction': direction,
                    'confidence': confidence
                }

            return None

        except Exception as e:
            logger.debug(f"Error processing CNN price direction predictions: {e}")
            return None

    def get_multi_timeframe_predictions(self, multi_timeframe_tensor):
        """
        Extract multi-timeframe price vector predictions

        Args:
            multi_timeframe_tensor: Tensor with all timeframe predictions

        Returns:
            Dict with short_term, mid_term, long_term predictions
        """
        try:
            if multi_timeframe_tensor is None:
                return {}

            if isinstance(multi_timeframe_tensor, torch.Tensor):
                if multi_timeframe_tensor.dim() > 1:
                    multi_timeframe_tensor = multi_timeframe_tensor.squeeze(0)

                predictions = {
                    'short_term': {
                        'direction': float(multi_timeframe_tensor[0].item()),
                        'confidence': float(multi_timeframe_tensor[1].item()),
                        'magnitude': float(multi_timeframe_tensor[2].item()),
                        'volatility_risk': float(multi_timeframe_tensor[3].item())
                    },
                    'mid_term': {
                        'direction': float(multi_timeframe_tensor[4].item()),
                        'confidence': float(multi_timeframe_tensor[5].item()),
                        'magnitude': float(multi_timeframe_tensor[6].item()),
                        'volatility_risk': float(multi_timeframe_tensor[7].item())
                    },
                    'long_term': {
                        'direction': float(multi_timeframe_tensor[8].item()),
                        'confidence': float(multi_timeframe_tensor[9].item()),
                        'magnitude': float(multi_timeframe_tensor[10].item()),
                        'volatility_risk': float(multi_timeframe_tensor[11].item())
                    }
                }

                return predictions

            return {}

        except Exception as e:
            logger.debug(f"Error extracting multi-timeframe predictions: {e}")
            return {}
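Downstream consumption might look like the following sketch (`model` and `state_tensor` assumed as before); the sixth element of the forward tuple feeds the extractor:

with torch.no_grad():
    *_, multi_tf = model(state_tensor)
vectors = model.get_multi_timeframe_predictions(multi_tf)
if vectors:
    st = vectors['short_term']
    print(f"short-term: dir={st['direction']:+.2f} conf={st['confidence']:.2f}")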

# Additional utility for example sifting
class ExampleSiftingDataset:
    """