From 4a5c3fc94387beaab693b5ed26fe1f005cfbb3d2 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Wed, 12 Nov 2025 18:12:47 +0200 Subject: [PATCH] try to work with ROCKm (AMD) GPUs again --- ANNOTATE/core/real_training_adapter.py | 6 +++--- ANNOTATE/web/app.py | 4 ++-- NN/models/advanced_transformer_trading.py | 4 +++- NN/models/cob_rl_model.py | 3 ++- NN/models/dqn_agent.py | 3 ++- core/realtime_rl_cob_trader.py | 6 ++++-- 6 files changed, 16 insertions(+), 10 deletions(-) diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py index 6d883aa..a22962a 100644 --- a/ANNOTATE/core/real_training_adapter.py +++ b/ANNOTATE/core/real_training_adapter.py @@ -821,7 +821,7 @@ class RealTrainingAdapter: # If exit not found, estimate it if exit_index is None: # Estimate: 1 minute per candle - candles_in_trade = holding_period_seconds // 60 + candles_in_trade = int(holding_period_seconds // 60) # Ensure integer exit_index = min(entry_index + candles_in_trade, len(timestamps) - 1) logger.debug(f" Estimated exit index: {exit_index} ({candles_in_trade} candles)") @@ -1934,9 +1934,9 @@ class RealTrainingAdapter: # Save metadata to database for easy retrieval try: - from utils.database_manager import DatabaseManager + from utils.database_manager import get_database_manager - db_manager = DatabaseManager() + db_manager = get_database_manager() checkpoint_id = f"transformer_e{epoch+1}_{timestamp}" # Create metadata object diff --git a/ANNOTATE/web/app.py b/ANNOTATE/web/app.py index 41a4d65..f651de3 100644 --- a/ANNOTATE/web/app.py +++ b/ANNOTATE/web/app.py @@ -204,8 +204,8 @@ class AnnotationDashboard: try: # Try to get from database first (has full metadata) try: - from utils.database_manager import DatabaseManager - db_manager = DatabaseManager() + from utils.database_manager import get_database_manager + db_manager = get_database_manager() # Get active checkpoint for this model with db_manager._get_connection() as conn: diff --git a/NN/models/advanced_transformer_trading.py b/NN/models/advanced_transformer_trading.py index 402888b..665ea37 100644 --- a/NN/models/advanced_transformer_trading.py +++ b/NN/models/advanced_transformer_trading.py @@ -1229,7 +1229,9 @@ class TradingTransformerTrainer: for k, v in batch.items()} # Use automatic mixed precision (FP16) for memory efficiency - with torch.amp.autocast('cuda', enabled=self.use_amp): + # Support both CUDA and ROCm (AMD) devices + device_type = 'cuda' if self.device.type == 'cuda' else 'cpu' + with torch.amp.autocast(device_type, enabled=self.use_amp and device_type != 'cpu'): # Forward pass with multi-timeframe data outputs = self.model( price_data_1s=batch.get('price_data_1s'), diff --git a/NN/models/cob_rl_model.py b/NN/models/cob_rl_model.py index 6abaf11..b862520 100644 --- a/NN/models/cob_rl_model.py +++ b/NN/models/cob_rl_model.py @@ -323,7 +323,8 @@ class COBRLModelInterface(ModelInterface): self.optimizer.zero_grad() if self.scaler: - with torch.amp.autocast('cuda'): + device_type = 'cuda' if next(self.model.parameters()).device.type == 'cuda' else 'cpu' + with torch.amp.autocast(device_type, enabled=device_type != 'cpu'): outputs = self.model(features) loss = self._calculate_loss(outputs, targets) diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py index 84d4678..d43e933 100644 --- a/NN/models/dqn_agent.py +++ b/NN/models/dqn_agent.py @@ -1436,7 +1436,8 @@ class DQNAgent: import warnings with warnings.catch_warnings(): warnings.simplefilter("ignore", FutureWarning) - with torch.amp.autocast('cuda'): + device_type = 'cuda' if self.device.type == 'cuda' else 'cpu' + with torch.amp.autocast(device_type, enabled=device_type != 'cpu'): # Get current Q values and predictions current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states) current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1) diff --git a/core/realtime_rl_cob_trader.py b/core/realtime_rl_cob_trader.py index 8715051..2d6446f 100644 --- a/core/realtime_rl_cob_trader.py +++ b/core/realtime_rl_cob_trader.py @@ -536,7 +536,8 @@ class RealtimeRLCOBTrader: features_tensor = torch.from_numpy(features).unsqueeze(0).to(self.device) with torch.no_grad(): - with torch.amp.autocast('cuda'): + device_type = 'cuda' if self.device.type == 'cuda' else 'cpu' + with torch.amp.autocast(device_type, enabled=device_type != 'cpu'): outputs = model(features_tensor) # Extract predictions @@ -934,7 +935,8 @@ class RealtimeRLCOBTrader: ], dtype=torch.float32).to(self.device) # Forward pass with mixed precision - with torch.amp.autocast('cuda'): + device_type = 'cuda' if self.device.type == 'cuda' else 'cpu' + with torch.amp.autocast(device_type, enabled=device_type != 'cpu'): outputs = model(features) # Calculate losses