try to work with ROCKm (AMD) GPUs again
This commit is contained in:
@@ -821,7 +821,7 @@ class RealTrainingAdapter:
|
|||||||
# If exit not found, estimate it
|
# If exit not found, estimate it
|
||||||
if exit_index is None:
|
if exit_index is None:
|
||||||
# Estimate: 1 minute per candle
|
# Estimate: 1 minute per candle
|
||||||
candles_in_trade = holding_period_seconds // 60
|
candles_in_trade = int(holding_period_seconds // 60) # Ensure integer
|
||||||
exit_index = min(entry_index + candles_in_trade, len(timestamps) - 1)
|
exit_index = min(entry_index + candles_in_trade, len(timestamps) - 1)
|
||||||
logger.debug(f" Estimated exit index: {exit_index} ({candles_in_trade} candles)")
|
logger.debug(f" Estimated exit index: {exit_index} ({candles_in_trade} candles)")
|
||||||
|
|
||||||
@@ -1934,9 +1934,9 @@ class RealTrainingAdapter:
|
|||||||
|
|
||||||
# Save metadata to database for easy retrieval
|
# Save metadata to database for easy retrieval
|
||||||
try:
|
try:
|
||||||
from utils.database_manager import DatabaseManager
|
from utils.database_manager import get_database_manager
|
||||||
|
|
||||||
db_manager = DatabaseManager()
|
db_manager = get_database_manager()
|
||||||
checkpoint_id = f"transformer_e{epoch+1}_{timestamp}"
|
checkpoint_id = f"transformer_e{epoch+1}_{timestamp}"
|
||||||
|
|
||||||
# Create metadata object
|
# Create metadata object
|
||||||
|
|||||||
@@ -204,8 +204,8 @@ class AnnotationDashboard:
|
|||||||
try:
|
try:
|
||||||
# Try to get from database first (has full metadata)
|
# Try to get from database first (has full metadata)
|
||||||
try:
|
try:
|
||||||
from utils.database_manager import DatabaseManager
|
from utils.database_manager import get_database_manager
|
||||||
db_manager = DatabaseManager()
|
db_manager = get_database_manager()
|
||||||
|
|
||||||
# Get active checkpoint for this model
|
# Get active checkpoint for this model
|
||||||
with db_manager._get_connection() as conn:
|
with db_manager._get_connection() as conn:
|
||||||
|
|||||||
@@ -1229,7 +1229,9 @@ class TradingTransformerTrainer:
|
|||||||
for k, v in batch.items()}
|
for k, v in batch.items()}
|
||||||
|
|
||||||
# Use automatic mixed precision (FP16) for memory efficiency
|
# Use automatic mixed precision (FP16) for memory efficiency
|
||||||
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
# Support both CUDA and ROCm (AMD) devices
|
||||||
|
device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
|
||||||
|
with torch.amp.autocast(device_type, enabled=self.use_amp and device_type != 'cpu'):
|
||||||
# Forward pass with multi-timeframe data
|
# Forward pass with multi-timeframe data
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
price_data_1s=batch.get('price_data_1s'),
|
price_data_1s=batch.get('price_data_1s'),
|
||||||
|
|||||||
@@ -323,7 +323,8 @@ class COBRLModelInterface(ModelInterface):
|
|||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
if self.scaler:
|
if self.scaler:
|
||||||
with torch.amp.autocast('cuda'):
|
device_type = 'cuda' if next(self.model.parameters()).device.type == 'cuda' else 'cpu'
|
||||||
|
with torch.amp.autocast(device_type, enabled=device_type != 'cpu'):
|
||||||
outputs = self.model(features)
|
outputs = self.model(features)
|
||||||
loss = self._calculate_loss(outputs, targets)
|
loss = self._calculate_loss(outputs, targets)
|
||||||
|
|
||||||
|
|||||||
@@ -1436,7 +1436,8 @@ class DQNAgent:
|
|||||||
import warnings
|
import warnings
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore", FutureWarning)
|
warnings.simplefilter("ignore", FutureWarning)
|
||||||
with torch.amp.autocast('cuda'):
|
device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
|
||||||
|
with torch.amp.autocast(device_type, enabled=device_type != 'cpu'):
|
||||||
# Get current Q values and predictions
|
# Get current Q values and predictions
|
||||||
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
|
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
|
||||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||||
|
|||||||
@@ -536,7 +536,8 @@ class RealtimeRLCOBTrader:
|
|||||||
features_tensor = torch.from_numpy(features).unsqueeze(0).to(self.device)
|
features_tensor = torch.from_numpy(features).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with torch.amp.autocast('cuda'):
|
device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
|
||||||
|
with torch.amp.autocast(device_type, enabled=device_type != 'cpu'):
|
||||||
outputs = model(features_tensor)
|
outputs = model(features_tensor)
|
||||||
|
|
||||||
# Extract predictions
|
# Extract predictions
|
||||||
@@ -934,7 +935,8 @@ class RealtimeRLCOBTrader:
|
|||||||
], dtype=torch.float32).to(self.device)
|
], dtype=torch.float32).to(self.device)
|
||||||
|
|
||||||
# Forward pass with mixed precision
|
# Forward pass with mixed precision
|
||||||
with torch.amp.autocast('cuda'):
|
device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
|
||||||
|
with torch.amp.autocast(device_type, enabled=device_type != 'cpu'):
|
||||||
outputs = model(features)
|
outputs = model(features)
|
||||||
|
|
||||||
# Calculate losses
|
# Calculate losses
|
||||||
|
|||||||
Reference in New Issue
Block a user