try to work with ROCKm (AMD) GPUs again

This commit is contained in:
Dobromir Popov
2025-11-12 18:12:47 +02:00
parent 8354aec830
commit 4a5c3fc943
6 changed files with 16 additions and 10 deletions

View File

@@ -821,7 +821,7 @@ class RealTrainingAdapter:
# If exit not found, estimate it # If exit not found, estimate it
if exit_index is None: if exit_index is None:
# Estimate: 1 minute per candle # Estimate: 1 minute per candle
candles_in_trade = holding_period_seconds // 60 candles_in_trade = int(holding_period_seconds // 60) # Ensure integer
exit_index = min(entry_index + candles_in_trade, len(timestamps) - 1) exit_index = min(entry_index + candles_in_trade, len(timestamps) - 1)
logger.debug(f" Estimated exit index: {exit_index} ({candles_in_trade} candles)") logger.debug(f" Estimated exit index: {exit_index} ({candles_in_trade} candles)")
@@ -1934,9 +1934,9 @@ class RealTrainingAdapter:
# Save metadata to database for easy retrieval # Save metadata to database for easy retrieval
try: try:
from utils.database_manager import DatabaseManager from utils.database_manager import get_database_manager
db_manager = DatabaseManager() db_manager = get_database_manager()
checkpoint_id = f"transformer_e{epoch+1}_{timestamp}" checkpoint_id = f"transformer_e{epoch+1}_{timestamp}"
# Create metadata object # Create metadata object

View File

@@ -204,8 +204,8 @@ class AnnotationDashboard:
try: try:
# Try to get from database first (has full metadata) # Try to get from database first (has full metadata)
try: try:
from utils.database_manager import DatabaseManager from utils.database_manager import get_database_manager
db_manager = DatabaseManager() db_manager = get_database_manager()
# Get active checkpoint for this model # Get active checkpoint for this model
with db_manager._get_connection() as conn: with db_manager._get_connection() as conn:

View File

@@ -1229,7 +1229,9 @@ class TradingTransformerTrainer:
for k, v in batch.items()} for k, v in batch.items()}
# Use automatic mixed precision (FP16) for memory efficiency # Use automatic mixed precision (FP16) for memory efficiency
with torch.amp.autocast('cuda', enabled=self.use_amp): # Support both CUDA and ROCm (AMD) devices
device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
with torch.amp.autocast(device_type, enabled=self.use_amp and device_type != 'cpu'):
# Forward pass with multi-timeframe data # Forward pass with multi-timeframe data
outputs = self.model( outputs = self.model(
price_data_1s=batch.get('price_data_1s'), price_data_1s=batch.get('price_data_1s'),

View File

@@ -323,7 +323,8 @@ class COBRLModelInterface(ModelInterface):
self.optimizer.zero_grad() self.optimizer.zero_grad()
if self.scaler: if self.scaler:
with torch.amp.autocast('cuda'): device_type = 'cuda' if next(self.model.parameters()).device.type == 'cuda' else 'cpu'
with torch.amp.autocast(device_type, enabled=device_type != 'cpu'):
outputs = self.model(features) outputs = self.model(features)
loss = self._calculate_loss(outputs, targets) loss = self._calculate_loss(outputs, targets)

View File

@@ -1436,7 +1436,8 @@ class DQNAgent:
import warnings import warnings
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning) warnings.simplefilter("ignore", FutureWarning)
with torch.amp.autocast('cuda'): device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
with torch.amp.autocast(device_type, enabled=device_type != 'cpu'):
# Get current Q values and predictions # Get current Q values and predictions
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states) current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1) current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

View File

@@ -536,7 +536,8 @@ class RealtimeRLCOBTrader:
features_tensor = torch.from_numpy(features).unsqueeze(0).to(self.device) features_tensor = torch.from_numpy(features).unsqueeze(0).to(self.device)
with torch.no_grad(): with torch.no_grad():
with torch.amp.autocast('cuda'): device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
with torch.amp.autocast(device_type, enabled=device_type != 'cpu'):
outputs = model(features_tensor) outputs = model(features_tensor)
# Extract predictions # Extract predictions
@@ -934,7 +935,8 @@ class RealtimeRLCOBTrader:
], dtype=torch.float32).to(self.device) ], dtype=torch.float32).to(self.device)
# Forward pass with mixed precision # Forward pass with mixed precision
with torch.amp.autocast('cuda'): device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
with torch.amp.autocast(device_type, enabled=device_type != 'cpu'):
outputs = model(features) outputs = model(features)
# Calculate losses # Calculate losses