training fixes

Dobromir Popov
2025-12-08 22:09:43 +02:00
parent 08ee2b6a3a
commit 600bee98f3
2 changed files with 69 additions and 33 deletions


@@ -160,6 +160,12 @@ class RealTrainingAdapter:
         self.data_provider = data_provider
         self.training_sessions: Dict[str, TrainingSession] = {}
 
+        # CRITICAL: Training lock to prevent concurrent model access
+        # Multiple threads (batch training + per-candle training) can corrupt
+        # the computation graph if they access the model simultaneously
+        import threading
+        self._training_lock = threading.Lock()
+
         # Real-time training tracking
         self.realtime_training_metrics = {
             'total_steps': 0,
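
The lock created here is the mechanism both hunks below rely on: every call into trainer.train_step is wrapped in the same threading.Lock, so the batch-training path and the per-candle path can never run a training step at the same time. A minimal sketch of the pattern follows; the class and method names are hypothetical stand-ins, and only _training_lock and train_step(batch, accumulate_gradients=False) come from the diff.

import threading

class TrainingAdapterSketch:
    """Hypothetical stand-in showing how one lock serializes both training paths."""

    def __init__(self, trainer):
        self.trainer = trainer
        self._training_lock = threading.Lock()  # shared by all training threads

    def batch_training_step(self, batch):
        # Batch-training thread: only one thread may run a train step at a time
        with self._training_lock:
            return self.trainer.train_step(batch, accumulate_gradients=False)

    def per_candle_training_step(self, batch):
        # Per-candle thread: blocks until any in-flight batch step releases the lock
        with self._training_lock:
            return self.trainer.train_step(batch, accumulate_gradients=False)

A single coarse lock is the simplest fix here: serializing the two paths is enough to avoid the autograd "inplace operation" errors the diff comments describe, at the cost of one path occasionally waiting for the other.
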
@@ -2614,9 +2620,12 @@ class RealTrainingAdapter:
                 symbol = batch.get('metadata', {}).get('symbol', 'ETH/USDT')
                 self._store_training_prediction(batch, trainer, symbol)
 
-                # Call the trainer's train_step method with mini-batch
-                # Batch is already on GPU and contains multiple samples
-                result = trainer.train_step(batch, accumulate_gradients=False)
+                # CRITICAL: Acquire training lock to prevent concurrent model access
+                # This prevents "inplace operation" errors when per-candle training runs simultaneously
+                with self._training_lock:
+                    # Call the trainer's train_step method with mini-batch
+                    # Batch is already on GPU and contains multiple samples
+                    result = trainer.train_step(batch, accumulate_gradients=False)
 
                 if result is not None:
                     # MEMORY FIX: Detach all tensor values to break computation graph
@@ -3574,11 +3583,13 @@ class RealTrainingAdapter:
                 logger.warning(f"Per-candle training skipped: Missing required keys: {missing_keys}")
                 return
 
             # Train on this batch
+            # CRITICAL: Acquire training lock to prevent concurrent model access
+            # This prevents "inplace operation" errors when batch training runs simultaneously
             import torch
-            with torch.enable_grad():
-                trainer.model.train()
-                result = trainer.train_step(batch, accumulate_gradients=False)
+            with self._training_lock:
+                with torch.enable_grad():
+                    trainer.model.train()
+                    result = trainer.train_step(batch, accumulate_gradients=False)
 
             if result:
                 loss = result.get('total_loss', 0)
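
The batch-training hunk above also references a "MEMORY FIX: Detach all tensor values to break computation graph" after each step. That detach code is not part of this diff; a hedged sketch of what such a helper could look like is below, assuming train_step returns a dict that may contain tensors (the name detach_result is hypothetical).

import torch

def detach_result(result: dict) -> dict:
    """Return a copy of a train_step result dict with no autograd references:
    scalar tensors become Python numbers, larger tensors are detached to CPU."""
    cleaned = {}
    for key, value in result.items():
        if isinstance(value, torch.Tensor):
            cleaned[key] = value.item() if value.numel() == 1 else value.detach().cpu()
        else:
            cleaned[key] = value
    return cleaned
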