training fixes
This commit is contained in:
@@ -441,12 +441,19 @@ class RealTrainingAdapter:
|
|||||||
logger.debug(" No holding period, skipping HOLD samples")
|
logger.debug(" No holding period, skipping HOLD samples")
|
||||||
return hold_samples
|
return hold_samples
|
||||||
|
|
||||||
# Parse entry timestamp
|
# Parse entry timestamp - handle multiple formats
|
||||||
try:
|
try:
|
||||||
if 'T' in entry_timestamp:
|
if 'T' in entry_timestamp:
|
||||||
entry_time = datetime.fromisoformat(entry_timestamp.replace('Z', '+00:00'))
|
entry_time = datetime.fromisoformat(entry_timestamp.replace('Z', '+00:00'))
|
||||||
else:
|
else:
|
||||||
entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M:%S')
|
# Try with seconds first, then without
|
||||||
|
try:
|
||||||
|
entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M:%S')
|
||||||
|
except ValueError:
|
||||||
|
# Try without seconds
|
||||||
|
entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M')
|
||||||
|
|
||||||
|
# Make timezone-aware
|
||||||
if pytz:
|
if pytz:
|
||||||
entry_time = entry_time.replace(tzinfo=pytz.UTC)
|
entry_time = entry_time.replace(tzinfo=pytz.UTC)
|
||||||
else:
|
else:
|
||||||
@@ -466,7 +473,21 @@ class RealTrainingAdapter:
|
|||||||
|
|
||||||
# Find all candles between entry and exit
|
# Find all candles between entry and exit
|
||||||
for idx, ts_str in enumerate(timestamps):
|
for idx, ts_str in enumerate(timestamps):
|
||||||
ts = datetime.fromisoformat(ts_str.replace(' ', 'T'))
|
# Parse timestamp and ensure it's timezone-aware
|
||||||
|
try:
|
||||||
|
if 'T' in ts_str:
|
||||||
|
ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00'))
|
||||||
|
else:
|
||||||
|
ts = datetime.fromisoformat(ts_str.replace(' ', 'T'))
|
||||||
|
# Make timezone-aware if it's naive
|
||||||
|
if ts.tzinfo is None:
|
||||||
|
if pytz:
|
||||||
|
ts = ts.replace(tzinfo=pytz.UTC)
|
||||||
|
else:
|
||||||
|
ts = ts.replace(tzinfo=timezone.utc)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not parse timestamp '{ts_str}': {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
# If this candle is between entry and exit (exclusive)
|
# If this candle is between entry and exit (exclusive)
|
||||||
if entry_time < ts < exit_time:
|
if entry_time < ts < exit_time:
|
||||||
@@ -534,7 +555,14 @@ class RealTrainingAdapter:
|
|||||||
if 'T' in signal_timestamp:
|
if 'T' in signal_timestamp:
|
||||||
signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
|
signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
|
||||||
else:
|
else:
|
||||||
signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M:%S')
|
# Try with seconds first, then without
|
||||||
|
try:
|
||||||
|
signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M:%S')
|
||||||
|
except ValueError:
|
||||||
|
# Try without seconds
|
||||||
|
signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M')
|
||||||
|
|
||||||
|
# Make timezone-aware
|
||||||
if pytz:
|
if pytz:
|
||||||
signal_time = signal_time.replace(tzinfo=pytz.UTC)
|
signal_time = signal_time.replace(tzinfo=pytz.UTC)
|
||||||
else:
|
else:
|
||||||
@@ -546,15 +574,22 @@ class RealTrainingAdapter:
|
|||||||
signal_index = None
|
signal_index = None
|
||||||
for idx, ts_str in enumerate(timestamps):
|
for idx, ts_str in enumerate(timestamps):
|
||||||
try:
|
try:
|
||||||
# Parse timestamp from market data
|
# Parse timestamp from market data - handle multiple formats
|
||||||
if 'T' in ts_str:
|
if 'T' in ts_str:
|
||||||
ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00'))
|
ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00'))
|
||||||
else:
|
else:
|
||||||
ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
|
# Try with seconds first, then without
|
||||||
if pytz:
|
try:
|
||||||
ts = ts.replace(tzinfo=pytz.UTC)
|
ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
|
||||||
else:
|
except ValueError:
|
||||||
ts = ts.replace(tzinfo=timezone.utc)
|
ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M')
|
||||||
|
|
||||||
|
# Make timezone-aware if naive
|
||||||
|
if ts.tzinfo is None:
|
||||||
|
if pytz:
|
||||||
|
ts = ts.replace(tzinfo=pytz.UTC)
|
||||||
|
else:
|
||||||
|
ts = ts.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
# Match within 1 minute
|
# Match within 1 minute
|
||||||
if abs((ts - signal_time).total_seconds()) < 60:
|
if abs((ts - signal_time).total_seconds()) < 60:
|
||||||
@@ -1166,9 +1201,13 @@ class RealTrainingAdapter:
|
|||||||
batch = self._convert_annotation_to_transformer_batch(data)
|
batch = self._convert_annotation_to_transformer_batch(data)
|
||||||
if batch is not None:
|
if batch is not None:
|
||||||
# Repeat based on repetitions parameter
|
# Repeat based on repetitions parameter
|
||||||
|
# IMPORTANT: Clone each batch to avoid in-place operation issues when reusing tensors
|
||||||
repetitions = data.get('repetitions', 1)
|
repetitions = data.get('repetitions', 1)
|
||||||
for _ in range(repetitions):
|
for _ in range(repetitions):
|
||||||
converted_batches.append(batch)
|
# Clone all tensors in the batch to ensure independence
|
||||||
|
cloned_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
|
||||||
|
for k, v in batch.items()}
|
||||||
|
converted_batches.append(cloned_batch)
|
||||||
else:
|
else:
|
||||||
logger.warning(f" Failed to convert sample {i+1}")
|
logger.warning(f" Failed to convert sample {i+1}")
|
||||||
|
|
||||||
|
|||||||
@@ -950,8 +950,9 @@ class TradingTransformerTrainer:
|
|||||||
self.model.train()
|
self.model.train()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
# Move batch to device
|
# Clone and detach batch tensors before moving to device to avoid in-place operation issues
|
||||||
batch = {k: v.to(self.device) for k, v in batch.items()}
|
# This ensures each batch is independent and prevents gradient computation errors
|
||||||
|
batch = {k: v.detach().clone().to(self.device) for k, v in batch.items()}
|
||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
@@ -965,24 +966,29 @@ class TradingTransformerTrainer:
|
|||||||
action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
|
action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
|
||||||
price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])
|
price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])
|
||||||
|
|
||||||
|
# Start with base losses - avoid inplace operations on computation graph
|
||||||
total_loss = action_loss + 0.1 * price_loss # Weight auxiliary task
|
total_loss = action_loss + 0.1 * price_loss # Weight auxiliary task
|
||||||
|
|
||||||
# Add confidence loss if available
|
# Add confidence loss if available
|
||||||
if 'confidence' in outputs and 'trade_success' in batch:
|
if 'confidence' in outputs and 'trade_success' in batch:
|
||||||
# Ensure both tensors have compatible shapes
|
# Ensure both tensors have compatible shapes for BCELoss
|
||||||
# confidence: [batch_size, 1] -> squeeze last dim to [batch_size]
|
# BCELoss requires both inputs to have the same shape
|
||||||
# trade_success: [batch_size] - ensure same shape
|
confidence_pred = outputs['confidence'] # Keep as [batch_size, 1]
|
||||||
confidence_pred = outputs['confidence'].squeeze(-1) # Only remove last dimension
|
|
||||||
trade_target = batch['trade_success'].float()
|
trade_target = batch['trade_success'].float()
|
||||||
|
|
||||||
# Ensure shapes match (handle edge case where batch_size=1)
|
# Reshape target to match prediction shape [batch_size, 1]
|
||||||
if confidence_pred.dim() == 0: # scalar case
|
if trade_target.dim() == 1:
|
||||||
confidence_pred = confidence_pred.unsqueeze(0)
|
trade_target = trade_target.unsqueeze(-1)
|
||||||
if trade_target.dim() == 0: # scalar case
|
|
||||||
trade_target = trade_target.unsqueeze(0)
|
# Ensure both have same shape
|
||||||
|
if confidence_pred.shape != trade_target.shape:
|
||||||
|
# If shapes still don't match, flatten both to 1D
|
||||||
|
confidence_pred = confidence_pred.view(-1)
|
||||||
|
trade_target = trade_target.view(-1)
|
||||||
|
|
||||||
confidence_loss = self.confidence_criterion(confidence_pred, trade_target)
|
confidence_loss = self.confidence_criterion(confidence_pred, trade_target)
|
||||||
total_loss += 0.1 * confidence_loss
|
# Use addition instead of += to avoid inplace operation
|
||||||
|
total_loss = total_loss + 0.1 * confidence_loss
|
||||||
|
|
||||||
# Backward pass
|
# Backward pass
|
||||||
total_loss.backward()
|
total_loss.backward()
|
||||||
|
|||||||
Reference in New Issue
Block a user