training fixes
@@ -160,6 +160,12 @@ class RealTrainingAdapter:
         self.data_provider = data_provider
         self.training_sessions: Dict[str, TrainingSession] = {}
 
+        # CRITICAL: Training lock to prevent concurrent model access
+        # Multiple threads (batch training + per-candle training) can corrupt
+        # the computation graph if they access the model simultaneously
+        import threading
+        self._training_lock = threading.Lock()
+
         # Real-time training tracking
         self.realtime_training_metrics = {
             'total_steps': 0,
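Why this hunk matters: autograd state on a single PyTorch model is not thread-safe, so two threads running forward/backward at the same time can corrupt the computation graph. Below is a minimal, self-contained sketch of the pattern this hunk sets up; the toy model and locked_train_step are illustrative stand-ins for the adapter's trainer, not code from this repository.

import threading

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
training_lock = threading.Lock()  # plays the role of self._training_lock

def locked_train_step(batch: torch.Tensor) -> float:
    # Hold the lock across forward, backward, and optimizer.step() so a
    # concurrent thread cannot mutate weights mid-backward.
    with training_lock:
        optimizer.zero_grad()
        loss = model(batch).pow(2).mean()
        loss.backward()
        optimizer.step()
        return loss.item()

# Two "training loops" (e.g. batch and per-candle) sharing one model:
threads = [threading.Thread(target=locked_train_step, args=(torch.randn(4, 8),))
           for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()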
@@ -2614,9 +2620,12 @@ class RealTrainingAdapter:
                     symbol = batch.get('metadata', {}).get('symbol', 'ETH/USDT')
                     self._store_training_prediction(batch, trainer, symbol)
 
-                # Call the trainer's train_step method with mini-batch
-                # Batch is already on GPU and contains multiple samples
-                result = trainer.train_step(batch, accumulate_gradients=False)
+                # CRITICAL: Acquire training lock to prevent concurrent model access
+                # This prevents "inplace operation" errors when per-candle training runs simultaneously
+                with self._training_lock:
+                    # Call the trainer's train_step method with mini-batch
+                    # Batch is already on GPU and contains multiple samples
+                    result = trainer.train_step(batch, accumulate_gradients=False)
 
                 if result is not None:
                     # MEMORY FIX: Detach all tensor values to break computation graph
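The "MEMORY FIX" context line above refers to detaching tensors from the result before storing them: a loss tensor that still references the autograd graph keeps the whole graph alive. A sketch of that pattern follows; the helper name detach_result is illustrative, not from the repository.

import torch

def detach_result(result: dict) -> dict:
    # Illustrative helper: convert any tensors in a train_step result into
    # graph-free values so storing the dict does not retain the autograd graph.
    detached = {}
    for key, value in result.items():
        if torch.is_tensor(value):
            detached[key] = value.detach().item() if value.numel() == 1 else value.detach()
        else:
            detached[key] = value
    return detached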
@@ -3574,11 +3583,13 @@ class RealTrainingAdapter:
                 logger.warning(f"Per-candle training skipped: Missing required keys: {missing_keys}")
                 return
 
             # Train on this batch
+            # CRITICAL: Acquire training lock to prevent concurrent model access
+            # This prevents "inplace operation" errors when batch training runs simultaneously
             import torch
-            with torch.enable_grad():
-                trainer.model.train()
-                result = trainer.train_step(batch, accumulate_gradients=False)
+            with self._training_lock:
+                with torch.enable_grad():
+                    trainer.model.train()
+                    result = trainer.train_step(batch, accumulate_gradients=False)
 
             if result:
                 loss = result.get('total_loss', 0)
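Note the nesting order in this hunk: the lock is taken before torch.enable_grad(), so the entire grad-enabled region, where the graph is built and consumed, sits inside the critical section. Condensed, the two call sites patched in this commit both funnel through the same lock. The sketch below is an illustrative reduction; the method names are not the adapter's real ones.

import threading

import torch

class TrainingAdapterSketch:
    # Condensed, illustrative view of the two patched call sites; the real
    # adapter creates this lock in __init__ (first hunk of this commit).
    def __init__(self):
        self._training_lock = threading.Lock()

    def run_batch_step(self, trainer, batch):
        with self._training_lock:  # batch-training path
            return trainer.train_step(batch, accumulate_gradients=False)

    def run_per_candle_step(self, trainer, batch):
        with self._training_lock:  # per-candle path: lock first,
            with torch.enable_grad():  # then enable grad inside the critical section
                trainer.model.train()
                return trainer.train_step(batch, accumulate_gradients=False)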