training fixes

Author: Dobromir Popov
Date: 2025-10-31 01:29:05 +02:00
Parent: ba91740e4c
Commit: cefd30d2bd
2 changed files with 68 additions and 23 deletions

Changed file 1 of 2 (hunks below are in class RealTrainingAdapter)

@@ -441,12 +441,19 @@ class RealTrainingAdapter:
             logger.debug("  No holding period, skipping HOLD samples")
             return hold_samples
 
-        # Parse entry timestamp
+        # Parse entry timestamp - handle multiple formats
         try:
             if 'T' in entry_timestamp:
                 entry_time = datetime.fromisoformat(entry_timestamp.replace('Z', '+00:00'))
             else:
-                entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M:%S')
+                # Try with seconds first, then without
+                try:
+                    entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M:%S')
+                except ValueError:
+                    # Try without seconds
+                    entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M')
+
+            # Make timezone-aware
             if pytz:
                 entry_time = entry_time.replace(tzinfo=pytz.UTC)
             else:
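For context, the fallback added here accepts ISO 8601 strings as well as space-separated timestamps with or without seconds. A minimal standalone sketch of that chain (the helper name is illustrative, not part of the commit):

from datetime import datetime, timezone

def parse_entry_timestamp(entry_timestamp: str) -> datetime:
    """Hypothetical helper mirroring the fallback chain in the hunk above."""
    if 'T' in entry_timestamp:
        # ISO 8601, possibly with a trailing 'Z' for UTC
        dt = datetime.fromisoformat(entry_timestamp.replace('Z', '+00:00'))
    else:
        try:
            # Space-separated format with seconds
            dt = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M:%S')
        except ValueError:
            # Fall back to minute precision (no seconds)
            dt = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M')
    # Normalize to timezone-aware UTC
    return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)

print(parse_entry_timestamp('2025-10-30T14:05:00Z'))   # 2025-10-30 14:05:00+00:00
print(parse_entry_timestamp('2025-10-30 14:05:00'))    # 2025-10-30 14:05:00+00:00
print(parse_entry_timestamp('2025-10-30 14:05'))       # 2025-10-30 14:05:00+00:00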
@@ -466,7 +473,21 @@ class RealTrainingAdapter:
         # Find all candles between entry and exit
         for idx, ts_str in enumerate(timestamps):
-            ts = datetime.fromisoformat(ts_str.replace(' ', 'T'))
+            # Parse timestamp and ensure it's timezone-aware
+            try:
+                if 'T' in ts_str:
+                    ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00'))
+                else:
+                    ts = datetime.fromisoformat(ts_str.replace(' ', 'T'))
+
+                # Make timezone-aware if it's naive
+                if ts.tzinfo is None:
+                    if pytz:
+                        ts = ts.replace(tzinfo=pytz.UTC)
+                    else:
+                        ts = ts.replace(tzinfo=timezone.utc)
+            except Exception as e:
+                logger.debug(f"Could not parse timestamp '{ts_str}': {e}")
+                continue
 
             # If this candle is between entry and exit (exclusive)
             if entry_time < ts < exit_time:
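The tzinfo normalization is what keeps the entry/exit range check from blowing up: Python refuses to order naive and aware datetimes. A quick illustration with made-up values:

from datetime import datetime, timezone

entry_time = datetime(2025, 10, 30, 14, 0, tzinfo=timezone.utc)   # timezone-aware
ts = datetime(2025, 10, 30, 14, 5)                                 # naive candle timestamp

try:
    entry_time < ts
except TypeError as e:
    print(f"Comparison failed: {e}")   # can't compare offset-naive and offset-aware datetimes

# After normalizing the naive timestamp to UTC, the range check works
ts = ts.replace(tzinfo=timezone.utc)
print(entry_time < ts)   # True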
@@ -534,7 +555,14 @@ class RealTrainingAdapter:
             if 'T' in signal_timestamp:
                 signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
             else:
-                signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M:%S')
+                # Try with seconds first, then without
+                try:
+                    signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M:%S')
+                except ValueError:
+                    # Try without seconds
+                    signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M')
+
+            # Make timezone-aware
             if pytz:
                 signal_time = signal_time.replace(tzinfo=pytz.UTC)
             else:
@@ -546,15 +574,22 @@ class RealTrainingAdapter:
         signal_index = None
         for idx, ts_str in enumerate(timestamps):
             try:
-                # Parse timestamp from market data
+                # Parse timestamp from market data - handle multiple formats
                 if 'T' in ts_str:
                     ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00'))
                 else:
-                    ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
-                if pytz:
-                    ts = ts.replace(tzinfo=pytz.UTC)
-                else:
-                    ts = ts.replace(tzinfo=timezone.utc)
+                    # Try with seconds first, then without
+                    try:
+                        ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
+                    except ValueError:
+                        ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M')
+
+                # Make timezone-aware if naive
+                if ts.tzinfo is None:
+                    if pytz:
+                        ts = ts.replace(tzinfo=pytz.UTC)
+                    else:
+                        ts = ts.replace(tzinfo=timezone.utc)
 
                 # Match within 1 minute
                 if abs((ts - signal_time).total_seconds()) < 60:
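The same parse-then-normalize pattern now appears at three call sites in RealTrainingAdapter. A possible follow-up, not part of this commit, would be to factor it into one helper; a sketch under the assumption that pytz may be unavailable:

from datetime import datetime, timezone

try:
    import pytz
except ImportError:
    pytz = None

def parse_market_timestamp(ts_str: str) -> datetime:
    """Hypothetical shared helper covering the three call sites above."""
    if 'T' in ts_str:
        ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00'))
    else:
        try:
            ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
        except ValueError:
            ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M')
    if ts.tzinfo is None:
        ts = ts.replace(tzinfo=pytz.UTC if pytz else timezone.utc)
    return ts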
@@ -1166,9 +1201,13 @@ class RealTrainingAdapter:
             batch = self._convert_annotation_to_transformer_batch(data)
 
             if batch is not None:
                 # Repeat based on repetitions parameter
+                # IMPORTANT: Clone each batch to avoid in-place operation issues when reusing tensors
                 repetitions = data.get('repetitions', 1)
                 for _ in range(repetitions):
-                    converted_batches.append(batch)
+                    # Clone all tensors in the batch to ensure independence
+                    cloned_batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
+                                    for k, v in batch.items()}
+                    converted_batches.append(cloned_batch)
             else:
                 logger.warning(f"  Failed to convert sample {i+1}")
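The cloning matters because appending the same dict N times makes every repetition alias the same tensor storage, so any in-place edit during training touches all of them. A toy demonstration (batch contents are made up):

import torch

batch = {'features': torch.zeros(2, 3), 'meta': 'example'}   # illustrative batch

# Without cloning, every repetition references the same tensor object
repeated = [batch for _ in range(3)]
repeated[0]['features'].add_(1.0)            # in-place edit on one "copy"...
print(repeated[2]['features'][0, 0].item())  # ...is visible in all of them: 1.0

# Cloning tensor values (non-tensors pass through) gives each repetition its own storage
cloned = [{k: v.clone() if isinstance(v, torch.Tensor) else v
           for k, v in batch.items()}
          for _ in range(3)]
cloned[0]['features'].add_(1.0)
print(cloned[0]['features'][0, 0].item())    # 2.0
print(cloned[2]['features'][0, 0].item())    # still 1.0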

Changed file 2 of 2 (hunks below are in class TradingTransformerTrainer)

@@ -950,8 +950,9 @@ class TradingTransformerTrainer:
         self.model.train()
         self.optimizer.zero_grad()
 
-        # Move batch to device
-        batch = {k: v.to(self.device) for k, v in batch.items()}
+        # Clone and detach batch tensors before moving to device to avoid in-place operation issues
+        # This ensures each batch is independent and prevents gradient computation errors
+        batch = {k: v.detach().clone().to(self.device) for k, v in batch.items()}
 
         # Forward pass
         outputs = self.model(
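Why detach().clone() rather than a plain .to(device): a small sketch, independent of the trainer code, showing that the copy has its own storage and no autograd history, so in-place edits on it cannot disturb tensors still referenced elsewhere:

import torch

x = torch.randn(4, 3, requires_grad=True)
y = x * 2                                   # y lives on x's autograd graph

safe = y.detach().clone()                   # fresh storage, no graph history
safe.zero_()                                # in-place edit on the copy...
print(torch.equal(y, x * 2))                # ...leaves y intact: True
print(safe.requires_grad)                   # False: detached from the graph
print(y.data_ptr() == safe.data_ptr())      # False: clone() allocated new memory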
@@ -965,24 +966,29 @@ class TradingTransformerTrainer:
         action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
         price_loss = self.price_criterion(outputs['price_prediction'], batch['future_prices'])
 
+        # Start with base losses - avoid inplace operations on computation graph
         total_loss = action_loss + 0.1 * price_loss  # Weight auxiliary task
 
         # Add confidence loss if available
         if 'confidence' in outputs and 'trade_success' in batch:
-            # Ensure both tensors have compatible shapes
-            # confidence: [batch_size, 1] -> squeeze last dim to [batch_size]
-            # trade_success: [batch_size] - ensure same shape
-            confidence_pred = outputs['confidence'].squeeze(-1)  # Only remove last dimension
+            # Ensure both tensors have compatible shapes for BCELoss
+            # BCELoss requires both inputs to have the same shape
+            confidence_pred = outputs['confidence']  # Keep as [batch_size, 1]
             trade_target = batch['trade_success'].float()
 
-            # Ensure shapes match (handle edge case where batch_size=1)
-            if confidence_pred.dim() == 0:  # scalar case
-                confidence_pred = confidence_pred.unsqueeze(0)
-            if trade_target.dim() == 0:  # scalar case
-                trade_target = trade_target.unsqueeze(0)
+            # Reshape target to match prediction shape [batch_size, 1]
+            if trade_target.dim() == 1:
+                trade_target = trade_target.unsqueeze(-1)
+
+            # Ensure both have same shape
+            if confidence_pred.shape != trade_target.shape:
+                # If shapes still don't match, squeeze both to 1D
+                confidence_pred = confidence_pred.view(-1)
+                trade_target = trade_target.view(-1)
 
             confidence_loss = self.confidence_criterion(confidence_pred, trade_target)
-            total_loss += 0.1 * confidence_loss
+            # Use addition instead of += to avoid inplace operation
+            total_loss = total_loss + 0.1 * confidence_loss
 
         # Backward pass
         total_loss.backward()
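The shape handling matters because nn.BCELoss rejects an input/target pair whose sizes differ, which is why the hunk keeps the prediction as [batch_size, 1] and reshapes the target to match; the final change replaces the in-place += on the loss tensor with an out-of-place add, per the comment in the hunk. A short sketch of the BCELoss constraint (values are made up):

import torch
import torch.nn as nn

criterion = nn.BCELoss()
confidence_pred = torch.rand(4, 1)               # sigmoid outputs, shape [batch, 1]
trade_success = torch.tensor([1., 0., 1., 1.])   # labels, shape [batch]

try:
    criterion(confidence_pred, trade_success)
except ValueError as e:
    # Recent PyTorch raises ValueError when input and target shapes differ
    print(f"Rejected: {e}")

# Reshaping the target to [batch, 1], as the hunk does, makes the call valid
loss = criterion(confidence_pred, trade_success.unsqueeze(-1))
print(loss.item())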