change to CPU

This commit is contained in:
Dobromir Popov
2025-12-08 16:56:26 +02:00
parent 1aded7325f
commit 81e7e6bfe6
5 changed files with 178 additions and 41 deletions
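The change makes the trainer's compute device configurable through config.yaml instead of always auto-detecting CUDA. As a rough sketch only, assuming the gpu: section uses exactly the keys the new _get_device_from_config() reads (enabled, device, fallback_to_cpu) — the real config.yaml in this repo may differ:

import yaml
import torch

example_yaml = """
gpu:
  enabled: true         # set to false to force CPU
  device: auto          # 'auto', 'cuda', or 'cpu'
  fallback_to_cpu: true # drop to CPU when CUDA is unavailable
"""

gpu_cfg = yaml.safe_load(example_yaml)['gpu']
if not gpu_cfg.get('enabled', True) or gpu_cfg.get('device') == 'cpu':
    device = torch.device('cpu')
elif torch.cuda.is_available():
    device = torch.device('cuda')
elif gpu_cfg.get('fallback_to_cpu', True):
    device = torch.device('cpu')
else:
    raise RuntimeError("CUDA not available and fallback_to_cpu is False")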


@@ -1049,20 +1049,24 @@ class TradingTransformerTrainer:
def __init__(self, model: AdvancedTradingTransformer, config: TradingTransformerConfig):
self.model = model
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Determine device from config or auto-detect
self.device = self._get_device_from_config()
# Move model to device
self.model.to(self.device)
logger.info(f"Model moved to device: {self.device}")
# Log GPU info if available
if torch.cuda.is_available():
if self.device.type == 'cuda' and torch.cuda.is_available():
logger.info(f" GPU: {torch.cuda.get_device_name(0)}")
logger.info(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
elif self.device.type == 'cpu':
logger.info(" Using CPU (GPU disabled or unavailable)")
# MEMORY OPTIMIZATION: Enable gradient checkpointing if configured
# This trades roughly 20% extra compute for 30-40% memory savings
if config.use_gradient_checkpointing:
if self.config.use_gradient_checkpointing:
logger.info("Enabling gradient checkpointing for memory efficiency")
self._enable_gradient_checkpointing()
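The new _enable_gradient_checkpointing() further down is a no-op because checkpointing is handled inside the model itself. For orientation, a minimal sketch of what that usually looks like in a transformer forward pass, using torch.utils.checkpoint; this is an assumption about the model's internals, not code from this commit:

import torch
from torch.utils.checkpoint import checkpoint

def forward_layers(layers, x, use_gradient_checkpointing: bool):
    # Recompute each layer's activations during backward instead of storing them,
    # trading extra compute for lower peak memory.
    for layer in layers:
        if use_gradient_checkpointing and torch.is_grad_enabled():
            x = checkpoint(layer, x, use_reentrant=False)
        else:
            x = layer(x)
    return x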
@@ -1077,11 +1081,82 @@ class TradingTransformerTrainer:
# Optimizer with warmup
self.optimizer = optim.AdamW(
model.parameters(),
lr=config.learning_rate,
weight_decay=config.weight_decay
self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
# Learning rate scheduler
self.scheduler = optim.lr_scheduler.OneCycleLR(
self.optimizer,
max_lr=self.config.learning_rate,
total_steps=10000, # Will be updated based on training data
pct_start=0.1
)
# Loss functions with class weights
# Pivot-based training: BUY at L pivots, SELL at H pivots (naturally balanced)
# Weights indexed as [HOLD=0, BUY=1, SELL=2]: BUY and SELL weighted equally, HOLD down-weighted
class_weights = torch.tensor([0.5, 1.0, 1.0], dtype=torch.float32, device=self.device)
self.action_criterion = nn.CrossEntropyLoss(weight=class_weights)
self.price_criterion = nn.MSELoss()
self.confidence_criterion = nn.BCELoss()
# Training history
self.training_history = {
'train_loss': [],
'val_loss': [],
'train_accuracy': [],
'val_accuracy': [],
'epochs': []
}
def _get_device_from_config(self) -> torch.device:
"""Get device from config.yaml or auto-detect"""
try:
# Try to load config
from core.config import get_config
config = get_config()
gpu_config = config._config.get('gpu', {})
device_setting = gpu_config.get('device', 'auto')
fallback_to_cpu = gpu_config.get('fallback_to_cpu', True)
gpu_enabled = gpu_config.get('enabled', True)
# If GPU is disabled in config, use CPU
if not gpu_enabled:
logger.info("GPU disabled in config.yaml, using CPU")
return torch.device('cpu')
# Handle device selection
if device_setting == 'cpu':
logger.info("Device set to CPU in config.yaml")
return torch.device('cpu')
elif device_setting == 'cuda' or device_setting == 'auto':
# Try GPU first
if torch.cuda.is_available():
logger.info("Using GPU (CUDA available)")
return torch.device('cuda')
else:
if fallback_to_cpu:
logger.warning("CUDA not available, falling back to CPU")
return torch.device('cpu')
else:
raise RuntimeError("CUDA not available and fallback_to_cpu is False")
else:
logger.warning(f"Unknown device setting '{device_setting}', using auto-detection")
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
except Exception as e:
logger.warning(f"Error reading device config: {e}, using auto-detection")
# Fallback to auto-detection
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def _enable_gradient_checkpointing(self):
"""Enable gradient checkpointing for memory efficiency"""
# This is handled by the model itself if use_gradient_checkpointing is True
pass
# Learning rate scheduler
self.scheduler = optim.lr_scheduler.OneCycleLR(
self.optimizer,
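The scheduler above is created with a placeholder total_steps=10000. One way to honour the "will be updated based on training data" comment is to rebuild the scheduler once the real batch and epoch counts are known; the helper below is hypothetical, not part of this commit:

from torch import optim

def recompute_scheduler(trainer, num_batches: int, num_epochs: int):
    # OneCycleLR cannot change total_steps after construction, so rebuild it.
    total_steps = max(1, num_batches * num_epochs)
    trainer.scheduler = optim.lr_scheduler.OneCycleLR(
        trainer.optimizer,
        max_lr=trainer.config.learning_rate,
        total_steps=total_steps,
        pct_start=0.1
    )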
@@ -1229,8 +1304,8 @@ class TradingTransformerTrainer:
# Enable anomaly detection temporarily to debug inplace operation issues
# NOTE: This significantly slows down training (2-3x slower), use only for debugging
# Set to False once the issue is resolved
enable_anomaly_detection = False # Set to True to debug gradient issues
# Set to True to find the exact inplace operation causing errors
enable_anomaly_detection = True # TEMPORARILY ENABLED to find inplace operations
if enable_anomaly_detection:
torch.autograd.set_detect_anomaly(True)
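set_detect_anomaly(True) is process-global and slows training 2-3x, so it is easy to forget to turn off. A scoped alternative, sketched here with a hypothetical environment-variable gate (the repo simply flips the boolean in code), uses the context-manager form so only the suspect forward/backward pays the cost:

import contextlib
import os
import torch

debug_anomaly = os.environ.get('TORCH_DEBUG_ANOMALY', '0') == '1'  # hypothetical gate
ctx = torch.autograd.detect_anomaly() if debug_anomaly else contextlib.nullcontext()

x = torch.randn(4, 3, requires_grad=True)
with ctx:
    loss = (x * 2).sum()
    loss.backward()  # with anomaly mode on, errors here include the forward op's traceback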
@@ -1276,30 +1351,54 @@ class TradingTransformerTrainer:
del batch_gpu
# else: batch is already on GPU, use it directly!
# Ensure all batch tensors are on the same device as the model
# This is critical to avoid device mismatch errors
model_device = next(self.model.parameters()).device
batch_on_device = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
# Move tensor to model's device if it's not already there
if v.device != model_device:
batch_on_device[k] = v.to(model_device, non_blocking=True)
else:
batch_on_device[k] = v
else:
batch_on_device[k] = v
# Also ensure model is on the correct device (in case it was moved elsewhere)
if model_device != self.device:
logger.warning(f"Model device ({model_device}) doesn't match trainer device ({self.device}). Moving model to {self.device}")
self.model.to(self.device)
model_device = self.device
# Re-move batch to correct device
for k, v in batch_on_device.items():
if isinstance(v, torch.Tensor):
batch_on_device[k] = v.to(self.device, non_blocking=True)
# Use automatic mixed precision (FP16) for memory efficiency
# Support both CUDA and ROCm (AMD) devices
device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
with torch.amp.autocast(device_type, enabled=self.use_amp and device_type != 'cpu'):
# Forward pass with multi-timeframe data
outputs = self.model(
price_data_1s=batch.get('price_data_1s'),
price_data_1m=batch.get('price_data_1m'),
price_data_1h=batch.get('price_data_1h'),
price_data_1d=batch.get('price_data_1d'),
btc_data_1m=batch.get('btc_data_1m'),
cob_data=batch.get('cob_data'), # Use .get() to handle missing key
tech_data=batch.get('tech_data'),
market_data=batch.get('market_data'),
position_state=batch.get('position_state'),
price_data=batch.get('price_data') # Legacy fallback
price_data_1s=batch_on_device.get('price_data_1s'),
price_data_1m=batch_on_device.get('price_data_1m'),
price_data_1h=batch_on_device.get('price_data_1h'),
price_data_1d=batch_on_device.get('price_data_1d'),
btc_data_1m=batch_on_device.get('btc_data_1m'),
cob_data=batch_on_device.get('cob_data'), # Use .get() to handle missing key
tech_data=batch_on_device.get('tech_data'),
market_data=batch_on_device.get('market_data'),
position_state=batch_on_device.get('position_state'),
price_data=batch_on_device.get('price_data') # Legacy fallback
)
# Calculate losses
action_loss = self.action_criterion(outputs['action_logits'], batch['actions'])
# Calculate losses (use batch_on_device for consistency)
action_loss = self.action_criterion(outputs['action_logits'], batch_on_device['actions'])
# FIXED: Ensure shapes match for MSELoss
price_pred = outputs['price_prediction']
price_target = batch['future_prices']
price_target = batch_on_device['future_prices']
# Both should be [batch, 1], but ensure they match
if price_pred.shape != price_target.shape:
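The tensor-moving loop above can be factored into a small helper; the following is a sketch of the same pattern plus the matching autocast guard, not code from the repo:

import torch

def move_batch_to_device(batch: dict, device: torch.device) -> dict:
    # Move every tensor in the batch dict to `device`; leave non-tensor values untouched.
    return {
        k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
        for k, v in batch.items()
    }

# Mixed precision then only kicks in when the model actually sits on CUDA:
# with torch.amp.autocast('cuda', enabled=use_amp and device.type == 'cuda'):
#     outputs = model(...)  # forward pass as in the hunk above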
@@ -1310,14 +1409,14 @@ class TradingTransformerTrainer:
# NEW: Trend analysis loss (if trend_target provided)
trend_loss = torch.tensor(0.0, device=self.device)
if 'trend_target' in batch and 'trend_analysis' in outputs:
if 'trend_target' in batch_on_device and 'trend_analysis' in outputs:
trend_pred = torch.cat([
outputs['trend_analysis']['angle_radians'],
outputs['trend_analysis']['steepness'],
outputs['trend_analysis']['direction']
], dim=1) # [batch, 3]
trend_target = batch['trend_target']
trend_target = batch_on_device['trend_target']
if trend_pred.shape == trend_target.shape:
trend_loss = self.price_criterion(trend_pred, trend_target)
logger.debug(f"Trend loss: {trend_loss.item():.6f} (pred={trend_pred[0].tolist()}, target={trend_target[0].tolist()})")
@@ -1333,7 +1432,7 @@ class TradingTransformerTrainer:
# Get normalization parameters if available
# norm_params may be a dict or a list of dicts (one per sample in batch)
norm_params_raw = batch.get('norm_params', {})
norm_params_raw = batch_on_device.get('norm_params', {})
if isinstance(norm_params_raw, list) and len(norm_params_raw) > 0:
# If it's a list, use the first one (batch size is typically 1)
norm_params = norm_params_raw[0]
@@ -1344,9 +1443,9 @@ class TradingTransformerTrainer:
for tf in ['1s', '1m', '1h', '1d']:
future_key = f'future_candle_{tf}'
if tf in outputs['next_candles'] and future_key in batch:
if tf in outputs['next_candles'] and future_key in batch_on_device:
pred_candle = outputs['next_candles'][tf] # [batch, 5] - predicted OHLCV (normalized)
target_candle = batch[future_key] # [batch, 5] - actual OHLCV (normalized)
target_candle = batch_on_device[future_key] # [batch, 5] - actual OHLCV (normalized)
if target_candle is not None and pred_candle.shape == target_candle.shape:
# MSE loss on normalized values (used for backprop)
@@ -1386,10 +1485,10 @@ class TradingTransformerTrainer:
total_loss = total_loss / 5.0
# Add confidence loss if available
if 'confidence' in outputs and 'trade_success' in batch:
if 'confidence' in outputs and 'trade_success' in batch_on_device:
# Both tensors should have shape [batch_size, 1] for BCELoss
confidence_pred = outputs['confidence']
trade_target = batch['trade_success'].float()
trade_target = batch_on_device['trade_success'].float()
# FIXED: Ensure both are 2D tensors [batch_size, 1]
# Handle different input shapes robustly
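BCELoss refuses mismatched shapes, so both the confidence prediction and the trade_success target need to end up as [batch_size, 1]. A sketch of that normalization, assuming the elided code below does something equivalent:

import torch

def as_batch_column(t: torch.Tensor) -> torch.Tensor:
    # Scalar -> [1, 1], [batch] -> [batch, 1], anything else -> [batch, 1].
    if t.dim() == 0:
        return t.view(1, 1)
    if t.dim() == 1:
        return t.unsqueeze(1)
    return t.reshape(t.shape[0], 1)

confidence_pred = as_batch_column(torch.rand(8))                  # values already in [0, 1]
trade_target = as_batch_column(torch.randint(0, 2, (8,)).float())
confidence_loss = torch.nn.BCELoss()(confidence_pred, trade_target)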
@@ -1471,6 +1570,31 @@ class TradingTransformerTrainer:
# Note: We need to recompute loss and backward pass, but for now just skip this step
logger.warning("Skipping optimizer step after reset - gradients need to be recomputed")
# Don't raise - allow training to continue with next batch
except RuntimeError as gpu_error:
# Check if it's a GPU-related error and fallback to CPU if configured
if "cuda" in str(gpu_error).lower() or "gpu" in str(gpu_error).lower():
logger.error(f"GPU error during optimizer step: {gpu_error}")
# Try to fallback to CPU if configured
try:
from core.config import get_config
config = get_config()
fallback_to_cpu = config._config.get('gpu', {}).get('fallback_to_cpu', True)
if fallback_to_cpu and self.device.type == 'cuda':
logger.warning("Falling back to CPU due to GPU errors")
self.device = torch.device('cpu')
self.model.to(self.device)
# Recreate optimizer for CPU
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
logger.info("Model moved to CPU, training will continue on CPU")
# Skip this step, continue with next batch
return result
except Exception as fallback_error:
logger.error(f"Failed to fallback to CPU: {fallback_error}")
raise
self.scheduler.step()
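One caveat with the CPU fallback above: OneCycleLR holds a reference to the optimizer it was constructed with, so after the optimizer is recreated on CPU the existing scheduler keeps stepping the discarded optimizer. A sketch of rebuilding both together (the remaining-step bookkeeping is an assumption, not something this commit does):

import torch

def rebuild_optimizer_and_scheduler(model, learning_rate, weight_decay, remaining_steps):
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate, total_steps=max(1, remaining_steps), pct_start=0.1
    )
    return optimizer, scheduler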
@@ -1486,12 +1610,12 @@ class TradingTransformerTrainer:
if 'next_candles' in outputs:
# Use 1s or 1m timeframe as primary metric (try 1s first)
if '1s' in outputs['next_candles'] and 'future_candle_1s' in batch:
if '1s' in outputs['next_candles'] and 'future_candle_1s' in batch_on_device:
pred_candle = outputs['next_candles']['1s'] # [batch, 5]
actual_candle = batch['future_candle_1s'] # [batch, 5]
elif '1m' in outputs['next_candles'] and 'future_candle_1m' in batch:
actual_candle = batch_on_device['future_candle_1s'] # [batch, 5]
elif '1m' in outputs['next_candles'] and 'future_candle_1m' in batch_on_device:
pred_candle = outputs['next_candles']['1m'] # [batch, 5]
actual_candle = batch['future_candle_1m'] # [batch, 5]
actual_candle = batch_on_device['future_candle_1m'] # [batch, 5]
else:
pred_candle = None
actual_candle = None
@@ -1521,12 +1645,12 @@ class TradingTransformerTrainer:
# SECONDARY: Trend vector prediction accuracy
trend_accuracy = 0.0
if 'trend_analysis' in outputs and 'trend_target' in batch:
if 'trend_analysis' in outputs and 'trend_target' in batch_on_device:
pred_angle = outputs['trend_analysis']['angle_radians']
pred_steepness = outputs['trend_analysis']['steepness']
actual_angle = batch['trend_target'][:, 0:1]
actual_steepness = batch['trend_target'][:, 1:2]
actual_angle = batch_on_device['trend_target'][:, 0:1]
actual_steepness = batch_on_device['trend_target'][:, 1:2]
# Angle error (degrees)
angle_error_rad = torch.abs(pred_angle - actual_angle)
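The hunk is cut off right after the radian error; turning that into a degree error and a tolerance-based accuracy could look like the sketch below (the 10-degree tolerance is an assumption, not a value taken from the repo):

import math
import torch

angle_error_rad = torch.abs(torch.tensor([[0.05], [0.30]]))       # example errors in radians
angle_error_deg = angle_error_rad * (180.0 / math.pi)
trend_accuracy = (angle_error_deg < 10.0).float().mean().item()   # fraction within 10 degrees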
@@ -1541,7 +1665,7 @@ class TradingTransformerTrainer:
# LEGACY: Action accuracy (for comparison)
action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
action_accuracy = (action_predictions == batch['actions']).float().mean().item()
action_accuracy = (action_predictions == batch_on_device['actions']).float().mean().item()
# Extract values and delete tensors to free memory
result = {
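The result dict is truncated here. The "extract values and delete tensors" pattern the comment refers to usually means converting losses to Python floats before dropping the references that keep the autograd graph (and GPU memory) alive; the keys below are illustrative, not the repo's actual result schema:

result = {
    'total_loss': total_loss.item(),      # .item() copies to a plain Python float
    'action_loss': action_loss.item(),
    'action_accuracy': action_accuracy,
}
del outputs, batch_on_device, total_loss, action_loss  # release graph/GPU tensors
# Optionally (not necessarily what the repo does):
# if torch.cuda.is_available():
#     torch.cuda.empty_cache()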