buy/sell signals training - wip
This commit is contained in:
parent
8b3db10a85
commit
1b9f471076
@ -193,9 +193,16 @@ class CNNModelPyTorch:
|
||||
|
||||
def train_epoch(self, X_train, y_train, future_prices=None, batch_size=32):
|
||||
"""Train for one epoch and return loss and accuracy"""
|
||||
# Convert to PyTorch tensors
|
||||
X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(self.device)
|
||||
y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(self.device)
|
||||
# Convert to PyTorch tensors if they aren't already
|
||||
if not isinstance(X_train, torch.Tensor):
|
||||
X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(self.device)
|
||||
else:
|
||||
X_train_tensor = X_train.to(self.device)
|
||||
|
||||
if not isinstance(y_train, torch.Tensor):
|
||||
y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(self.device)
|
||||
else:
|
||||
y_train_tensor = y_train.to(self.device)
|
||||
|
||||
# Create DataLoader
|
||||
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
|
||||
@ -206,16 +213,70 @@ class CNNModelPyTorch:
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
for inputs, targets in train_loader:
|
||||
# Initialize retrospective training metrics
|
||||
retrospective_correct = 0
|
||||
retrospective_total = 0
|
||||
|
||||
for batch_idx, (inputs, targets) in enumerate(train_loader):
|
||||
# Zero gradients
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# Forward pass
|
||||
outputs = self.model(inputs)
|
||||
|
||||
# Calculate loss
|
||||
# Calculate base loss
|
||||
loss = self.criterion(outputs, targets)
|
||||
|
||||
# Retrospective training if future prices are available
|
||||
if future_prices is not None:
|
||||
# Get the corresponding future prices for this batch
|
||||
batch_start = batch_idx * batch_size
|
||||
batch_end = min((batch_idx + 1) * batch_size, len(future_prices))
|
||||
|
||||
if not isinstance(future_prices, torch.Tensor):
|
||||
batch_future_prices = torch.tensor(
|
||||
future_prices[batch_start:batch_end],
|
||||
dtype=torch.float32
|
||||
).to(self.device)
|
||||
else:
|
||||
batch_future_prices = future_prices[batch_start:batch_end].to(self.device)
|
||||
|
||||
# Ensure batch_future_prices matches the batch size
|
||||
if len(batch_future_prices) < len(inputs):
|
||||
# Pad with the last value if needed
|
||||
pad_size = len(inputs) - len(batch_future_prices)
|
||||
last_value = batch_future_prices[-1].item()
|
||||
batch_future_prices = torch.cat([
|
||||
batch_future_prices,
|
||||
torch.full((pad_size,), last_value, device=self.device)
|
||||
])
|
||||
|
||||
# Calculate price changes for the next n candles
|
||||
current_prices = inputs[:, -1, 3] # Using close prices
|
||||
price_changes = (batch_future_prices - current_prices) / current_prices
|
||||
|
||||
# Create retrospective targets based on future price movements
|
||||
retrospective_targets = torch.ones_like(targets) # Default to HOLD (1)
|
||||
|
||||
# Create masks for local extrema
|
||||
local_max_mask = (price_changes > 0.001).to(torch.bool) # 0.1% threshold for local maximum
|
||||
local_min_mask = (price_changes < -0.001).to(torch.bool) # -0.1% threshold for local minimum
|
||||
|
||||
# Apply masks to set retrospective targets using torch.where
|
||||
retrospective_targets = torch.where(local_max_mask, torch.zeros_like(targets), retrospective_targets) # SELL at local max
|
||||
retrospective_targets = torch.where(local_min_mask, 2 * torch.ones_like(targets), retrospective_targets) # BUY at local min
|
||||
|
||||
# Calculate retrospective loss with higher weight for profitable signals
|
||||
retrospective_loss = self.criterion(outputs, retrospective_targets)
|
||||
|
||||
# Combine losses with higher weight for retrospective loss
|
||||
loss = 0.3 * loss + 0.7 * retrospective_loss
|
||||
|
||||
# Update retrospective metrics
|
||||
_, predicted = torch.max(outputs, 1)
|
||||
retrospective_correct += (predicted == retrospective_targets).sum().item()
|
||||
retrospective_total += targets.size(0)
|
||||
|
||||
# Backward pass and optimize
|
||||
loss.backward()
|
||||
|
||||
@ -233,18 +294,26 @@ class CNNModelPyTorch:
|
||||
epoch_loss = running_loss / len(train_loader)
|
||||
epoch_acc = correct / total if total > 0 else 0
|
||||
|
||||
# Update learning rate scheduler
|
||||
self.scheduler.step(epoch_acc)
|
||||
# Calculate retrospective metrics
|
||||
retrospective_acc = retrospective_correct / retrospective_total if retrospective_total > 0 else 0
|
||||
|
||||
# To maintain compatibility with the updated training code, we'll return 3 values
|
||||
# But the price_loss will be zero since we're not using that in this model
|
||||
return epoch_loss, 0.0, epoch_acc
|
||||
# Update learning rate scheduler based on retrospective accuracy
|
||||
self.scheduler.step(retrospective_acc)
|
||||
|
||||
return epoch_loss, retrospective_acc, epoch_acc
|
||||
|
||||
def evaluate(self, X_val, y_val, future_prices=None):
|
||||
"""Evaluate on validation data and return loss and accuracy"""
|
||||
# Convert to PyTorch tensors
|
||||
X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(self.device)
|
||||
y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(self.device)
|
||||
if not isinstance(X_val, torch.Tensor):
|
||||
X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(self.device)
|
||||
else:
|
||||
X_val_tensor = X_val.to(self.device)
|
||||
|
||||
if not isinstance(y_val, torch.Tensor):
|
||||
y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(self.device)
|
||||
else:
|
||||
y_val_tensor = y_val.to(self.device)
|
||||
|
||||
# Create DataLoader
|
||||
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
|
||||
@ -255,16 +324,70 @@ class CNNModelPyTorch:
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
# Initialize retrospective metrics
|
||||
retrospective_correct = 0
|
||||
retrospective_total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for inputs, targets in val_loader:
|
||||
for batch_idx, (inputs, targets) in enumerate(val_loader):
|
||||
# Forward pass
|
||||
outputs = self.model(inputs)
|
||||
|
||||
# Calculate loss
|
||||
# Calculate base loss
|
||||
loss = self.criterion(outputs, targets)
|
||||
running_loss += loss.item()
|
||||
|
||||
# Calculate accuracy
|
||||
# Retrospective evaluation if future prices are available
|
||||
if future_prices is not None:
|
||||
# Get the corresponding future prices for this batch
|
||||
batch_start = batch_idx * 32
|
||||
batch_end = min((batch_idx + 1) * 32, len(future_prices))
|
||||
|
||||
if not isinstance(future_prices, torch.Tensor):
|
||||
batch_future_prices = torch.tensor(
|
||||
future_prices[batch_start:batch_end],
|
||||
dtype=torch.float32
|
||||
).to(self.device)
|
||||
else:
|
||||
batch_future_prices = future_prices[batch_start:batch_end].to(self.device)
|
||||
|
||||
# Ensure batch_future_prices matches the batch size
|
||||
if len(batch_future_prices) < len(inputs):
|
||||
# Pad with the last value if needed
|
||||
pad_size = len(inputs) - len(batch_future_prices)
|
||||
last_value = batch_future_prices[-1].item()
|
||||
batch_future_prices = torch.cat([
|
||||
batch_future_prices,
|
||||
torch.full((pad_size,), last_value, device=self.device)
|
||||
])
|
||||
|
||||
# Calculate price changes for the next n candles
|
||||
current_prices = inputs[:, -1, 3] # Using close prices
|
||||
price_changes = (batch_future_prices - current_prices) / current_prices
|
||||
|
||||
# Create retrospective targets based on future price movements
|
||||
retrospective_targets = torch.ones_like(targets) # Default to HOLD (1)
|
||||
|
||||
# Create masks for local extrema
|
||||
local_max_mask = (price_changes > 0.001).to(torch.bool) # 0.1% threshold for local maximum
|
||||
local_min_mask = (price_changes < -0.001).to(torch.bool) # -0.1% threshold for local minimum
|
||||
|
||||
# Apply masks to set retrospective targets using torch.where
|
||||
retrospective_targets = torch.where(local_max_mask, torch.zeros_like(targets), retrospective_targets) # SELL at local max
|
||||
retrospective_targets = torch.where(local_min_mask, 2 * torch.ones_like(targets), retrospective_targets) # BUY at local min
|
||||
|
||||
# Calculate retrospective loss with higher weight for profitable signals
|
||||
retrospective_loss = self.criterion(outputs, retrospective_targets)
|
||||
|
||||
# Combine losses with higher weight for retrospective loss
|
||||
loss = 0.3 * loss + 0.7 * retrospective_loss
|
||||
|
||||
# Update retrospective metrics
|
||||
_, predicted = torch.max(outputs, 1)
|
||||
retrospective_correct += (predicted == retrospective_targets).sum().item()
|
||||
retrospective_total += targets.size(0)
|
||||
|
||||
# Update metrics
|
||||
running_loss += loss.item()
|
||||
_, predicted = torch.max(outputs, 1)
|
||||
total += targets.size(0)
|
||||
correct += (predicted == targets).sum().item()
|
||||
@ -272,9 +395,10 @@ class CNNModelPyTorch:
|
||||
val_loss = running_loss / len(val_loader)
|
||||
val_acc = correct / total if total > 0 else 0
|
||||
|
||||
# To maintain compatibility with the updated training code, we'll return 3 values
|
||||
# But the price_loss will be zero since we're not using that in this model
|
||||
return val_loss, 0.0, val_acc
|
||||
# Calculate retrospective metrics
|
||||
retrospective_acc = retrospective_correct / retrospective_total if retrospective_total > 0 else 0
|
||||
|
||||
return val_loss, val_acc, retrospective_acc
|
||||
|
||||
def predict(self, X):
|
||||
"""Make predictions on input data"""
|
||||
|
@ -395,10 +395,19 @@ class DataInterface:
|
||||
win_rate is the ratio of winning trades
|
||||
trades is a list of trade dictionaries
|
||||
"""
|
||||
if len(predictions) != len(actual_prices) - 1:
|
||||
logger.error("Predictions and prices length mismatch")
|
||||
# Ensure we have enough prices for the predictions
|
||||
if len(actual_prices) <= 1:
|
||||
logger.error("Not enough price data for PnL calculation")
|
||||
return 0.0, 0.0, []
|
||||
|
||||
# Adjust predictions length to match available price data
|
||||
n_prices = len(actual_prices) - 1 # We need current and next price for each prediction
|
||||
if len(predictions) > n_prices:
|
||||
predictions = predictions[:n_prices]
|
||||
elif len(predictions) < n_prices:
|
||||
n_prices = len(predictions)
|
||||
actual_prices = actual_prices[:n_prices + 1] # +1 to include the next price
|
||||
|
||||
pnl = 0.0
|
||||
trades = 0
|
||||
wins = 0
|
||||
@ -422,7 +431,7 @@ class DataInterface:
|
||||
'type': 'buy',
|
||||
'price': current_price,
|
||||
'pnl': trade_pnl,
|
||||
'timestamp': self.dataframes[self.timeframes[0]]['timestamp'].iloc[i]
|
||||
'timestamp': self.dataframes[self.timeframes[0]]['timestamp'].iloc[i] if self.dataframes[self.timeframes[0]] is not None else None
|
||||
})
|
||||
elif pred == 0: # Sell
|
||||
trade_pnl = -price_change * position_size
|
||||
@ -433,7 +442,7 @@ class DataInterface:
|
||||
'type': 'sell',
|
||||
'price': current_price,
|
||||
'pnl': trade_pnl,
|
||||
'timestamp': self.dataframes[self.timeframes[0]]['timestamp'].iloc[i]
|
||||
'timestamp': self.dataframes[self.timeframes[0]]['timestamp'].iloc[i] if self.dataframes[self.timeframes[0]] is not None else None
|
||||
})
|
||||
|
||||
pnl += trade_pnl if pred in [0, 2] else 0
|
||||
@ -443,29 +452,29 @@ class DataInterface:
|
||||
|
||||
def get_future_prices(self, prices, n_candles=3):
|
||||
"""
|
||||
Extract future prices for use in retrospective training.
|
||||
Extract future prices for retrospective training.
|
||||
|
||||
Args:
|
||||
prices: Array of close prices
|
||||
n_candles: Number of future candles to predict
|
||||
prices (np.ndarray): Array of prices
|
||||
n_candles (int): Number of future candles to look at
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Array of future prices for each sample
|
||||
np.ndarray: Future prices array
|
||||
"""
|
||||
if prices is None or len(prices) <= n_candles:
|
||||
logger.warning(f"Not enough price data for future prediction: {len(prices) if prices is not None else 0} prices")
|
||||
# Return zeros if not enough data
|
||||
return np.zeros((len(prices) if prices is not None else 0, 1))
|
||||
if len(prices) < n_candles + 1:
|
||||
return None
|
||||
|
||||
# For each price point, get the maximum price in the next n_candles
|
||||
future_prices = np.zeros(len(prices))
|
||||
|
||||
# For each price point i, get the price at i+n_candles
|
||||
future_prices = np.zeros((len(prices), 1))
|
||||
for i in range(len(prices) - n_candles):
|
||||
future_prices[i, 0] = prices[i + n_candles]
|
||||
# Get the next n candles
|
||||
next_candles = prices[i+1:i+n_candles+1]
|
||||
# Use the maximum price as the future price
|
||||
future_prices[i] = np.max(next_candles)
|
||||
|
||||
# For the last n_candles positions, we don't have future data
|
||||
# We'll use the last known price as a placeholder
|
||||
for i in range(len(prices) - n_candles, len(prices)):
|
||||
future_prices[i, 0] = prices[-1]
|
||||
# For the last n_candles points, use the last available price
|
||||
future_prices[-n_candles:] = prices[-1]
|
||||
|
||||
return future_prices
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user