buy/sell signals training - wip
This commit is contained in:
@ -193,9 +193,16 @@ class CNNModelPyTorch:
|
||||
|
||||
def train_epoch(self, X_train, y_train, future_prices=None, batch_size=32):
|
||||
"""Train for one epoch and return loss and accuracy"""
|
||||
# Convert to PyTorch tensors
|
||||
X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(self.device)
|
||||
y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(self.device)
|
||||
# Convert to PyTorch tensors if they aren't already
|
||||
if not isinstance(X_train, torch.Tensor):
|
||||
X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(self.device)
|
||||
else:
|
||||
X_train_tensor = X_train.to(self.device)
|
||||
|
||||
if not isinstance(y_train, torch.Tensor):
|
||||
y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(self.device)
|
||||
else:
|
||||
y_train_tensor = y_train.to(self.device)
|
||||
|
||||
# Create DataLoader
|
||||
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
|
||||
@ -206,16 +213,70 @@ class CNNModelPyTorch:
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
for inputs, targets in train_loader:
|
||||
# Initialize retrospective training metrics
|
||||
retrospective_correct = 0
|
||||
retrospective_total = 0
|
||||
|
||||
for batch_idx, (inputs, targets) in enumerate(train_loader):
|
||||
# Zero gradients
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# Forward pass
|
||||
outputs = self.model(inputs)
|
||||
|
||||
# Calculate loss
|
||||
# Calculate base loss
|
||||
loss = self.criterion(outputs, targets)
|
||||
|
||||
# Retrospective training if future prices are available
|
||||
if future_prices is not None:
|
||||
# Get the corresponding future prices for this batch
|
||||
batch_start = batch_idx * batch_size
|
||||
batch_end = min((batch_idx + 1) * batch_size, len(future_prices))
|
||||
|
||||
if not isinstance(future_prices, torch.Tensor):
|
||||
batch_future_prices = torch.tensor(
|
||||
future_prices[batch_start:batch_end],
|
||||
dtype=torch.float32
|
||||
).to(self.device)
|
||||
else:
|
||||
batch_future_prices = future_prices[batch_start:batch_end].to(self.device)
|
||||
|
||||
# Ensure batch_future_prices matches the batch size
|
||||
if len(batch_future_prices) < len(inputs):
|
||||
# Pad with the last value if needed
|
||||
pad_size = len(inputs) - len(batch_future_prices)
|
||||
last_value = batch_future_prices[-1].item()
|
||||
batch_future_prices = torch.cat([
|
||||
batch_future_prices,
|
||||
torch.full((pad_size,), last_value, device=self.device)
|
||||
])
|
||||
|
||||
# Calculate price changes for the next n candles
|
||||
current_prices = inputs[:, -1, 3] # Using close prices
|
||||
price_changes = (batch_future_prices - current_prices) / current_prices
|
||||
|
||||
# Create retrospective targets based on future price movements
|
||||
retrospective_targets = torch.ones_like(targets) # Default to HOLD (1)
|
||||
|
||||
# Create masks for local extrema
|
||||
local_max_mask = (price_changes > 0.001).to(torch.bool) # 0.1% threshold for local maximum
|
||||
local_min_mask = (price_changes < -0.001).to(torch.bool) # -0.1% threshold for local minimum
|
||||
|
||||
# Apply masks to set retrospective targets using torch.where
|
||||
retrospective_targets = torch.where(local_max_mask, torch.zeros_like(targets), retrospective_targets) # SELL at local max
|
||||
retrospective_targets = torch.where(local_min_mask, 2 * torch.ones_like(targets), retrospective_targets) # BUY at local min
|
||||
|
||||
# Calculate retrospective loss with higher weight for profitable signals
|
||||
retrospective_loss = self.criterion(outputs, retrospective_targets)
|
||||
|
||||
# Combine losses with higher weight for retrospective loss
|
||||
loss = 0.3 * loss + 0.7 * retrospective_loss
|
||||
|
||||
# Update retrospective metrics
|
||||
_, predicted = torch.max(outputs, 1)
|
||||
retrospective_correct += (predicted == retrospective_targets).sum().item()
|
||||
retrospective_total += targets.size(0)
|
||||
|
||||
# Backward pass and optimize
|
||||
loss.backward()
|
||||
|
||||
@ -233,18 +294,26 @@ class CNNModelPyTorch:
|
||||
epoch_loss = running_loss / len(train_loader)
|
||||
epoch_acc = correct / total if total > 0 else 0
|
||||
|
||||
# Update learning rate scheduler
|
||||
self.scheduler.step(epoch_acc)
|
||||
# Calculate retrospective metrics
|
||||
retrospective_acc = retrospective_correct / retrospective_total if retrospective_total > 0 else 0
|
||||
|
||||
# To maintain compatibility with the updated training code, we'll return 3 values
|
||||
# But the price_loss will be zero since we're not using that in this model
|
||||
return epoch_loss, 0.0, epoch_acc
|
||||
# Update learning rate scheduler based on retrospective accuracy
|
||||
self.scheduler.step(retrospective_acc)
|
||||
|
||||
return epoch_loss, retrospective_acc, epoch_acc
|
||||
|
||||
def evaluate(self, X_val, y_val, future_prices=None):
|
||||
"""Evaluate on validation data and return loss and accuracy"""
|
||||
# Convert to PyTorch tensors
|
||||
X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(self.device)
|
||||
y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(self.device)
|
||||
if not isinstance(X_val, torch.Tensor):
|
||||
X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(self.device)
|
||||
else:
|
||||
X_val_tensor = X_val.to(self.device)
|
||||
|
||||
if not isinstance(y_val, torch.Tensor):
|
||||
y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(self.device)
|
||||
else:
|
||||
y_val_tensor = y_val.to(self.device)
|
||||
|
||||
# Create DataLoader
|
||||
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
|
||||
@ -255,16 +324,70 @@ class CNNModelPyTorch:
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
# Initialize retrospective metrics
|
||||
retrospective_correct = 0
|
||||
retrospective_total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for inputs, targets in val_loader:
|
||||
for batch_idx, (inputs, targets) in enumerate(val_loader):
|
||||
# Forward pass
|
||||
outputs = self.model(inputs)
|
||||
|
||||
# Calculate loss
|
||||
# Calculate base loss
|
||||
loss = self.criterion(outputs, targets)
|
||||
running_loss += loss.item()
|
||||
|
||||
# Calculate accuracy
|
||||
# Retrospective evaluation if future prices are available
|
||||
if future_prices is not None:
|
||||
# Get the corresponding future prices for this batch
|
||||
batch_start = batch_idx * 32
|
||||
batch_end = min((batch_idx + 1) * 32, len(future_prices))
|
||||
|
||||
if not isinstance(future_prices, torch.Tensor):
|
||||
batch_future_prices = torch.tensor(
|
||||
future_prices[batch_start:batch_end],
|
||||
dtype=torch.float32
|
||||
).to(self.device)
|
||||
else:
|
||||
batch_future_prices = future_prices[batch_start:batch_end].to(self.device)
|
||||
|
||||
# Ensure batch_future_prices matches the batch size
|
||||
if len(batch_future_prices) < len(inputs):
|
||||
# Pad with the last value if needed
|
||||
pad_size = len(inputs) - len(batch_future_prices)
|
||||
last_value = batch_future_prices[-1].item()
|
||||
batch_future_prices = torch.cat([
|
||||
batch_future_prices,
|
||||
torch.full((pad_size,), last_value, device=self.device)
|
||||
])
|
||||
|
||||
# Calculate price changes for the next n candles
|
||||
current_prices = inputs[:, -1, 3] # Using close prices
|
||||
price_changes = (batch_future_prices - current_prices) / current_prices
|
||||
|
||||
# Create retrospective targets based on future price movements
|
||||
retrospective_targets = torch.ones_like(targets) # Default to HOLD (1)
|
||||
|
||||
# Create masks for local extrema
|
||||
local_max_mask = (price_changes > 0.001).to(torch.bool) # 0.1% threshold for local maximum
|
||||
local_min_mask = (price_changes < -0.001).to(torch.bool) # -0.1% threshold for local minimum
|
||||
|
||||
# Apply masks to set retrospective targets using torch.where
|
||||
retrospective_targets = torch.where(local_max_mask, torch.zeros_like(targets), retrospective_targets) # SELL at local max
|
||||
retrospective_targets = torch.where(local_min_mask, 2 * torch.ones_like(targets), retrospective_targets) # BUY at local min
|
||||
|
||||
# Calculate retrospective loss with higher weight for profitable signals
|
||||
retrospective_loss = self.criterion(outputs, retrospective_targets)
|
||||
|
||||
# Combine losses with higher weight for retrospective loss
|
||||
loss = 0.3 * loss + 0.7 * retrospective_loss
|
||||
|
||||
# Update retrospective metrics
|
||||
_, predicted = torch.max(outputs, 1)
|
||||
retrospective_correct += (predicted == retrospective_targets).sum().item()
|
||||
retrospective_total += targets.size(0)
|
||||
|
||||
# Update metrics
|
||||
running_loss += loss.item()
|
||||
_, predicted = torch.max(outputs, 1)
|
||||
total += targets.size(0)
|
||||
correct += (predicted == targets).sum().item()
|
||||
@ -272,9 +395,10 @@ class CNNModelPyTorch:
|
||||
val_loss = running_loss / len(val_loader)
|
||||
val_acc = correct / total if total > 0 else 0
|
||||
|
||||
# To maintain compatibility with the updated training code, we'll return 3 values
|
||||
# But the price_loss will be zero since we're not using that in this model
|
||||
return val_loss, 0.0, val_acc
|
||||
# Calculate retrospective metrics
|
||||
retrospective_acc = retrospective_correct / retrospective_total if retrospective_total > 0 else 0
|
||||
|
||||
return val_loss, val_acc, retrospective_acc
|
||||
|
||||
def predict(self, X):
|
||||
"""Make predictions on input data"""
|
||||
|
Reference in New Issue
Block a user