fixed model training
This commit is contained in:
@@ -815,8 +815,8 @@ class RealTrainingAdapter:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if entry_index is None:
|
if entry_index is None:
|
||||||
logger.warning(f"Could not find entry timestamp in market data")
|
logger.debug(f"Could not find entry timestamp in market data - using first candle as entry")
|
||||||
return negative_samples
|
entry_index = 0 # Use first candle if exact match not found
|
||||||
|
|
||||||
# If exit not found, estimate it
|
# If exit not found, estimate it
|
||||||
if exit_index is None:
|
if exit_index is None:
|
||||||
@@ -1958,7 +1958,7 @@ class RealTrainingAdapter:
|
|||||||
'training_id': session.training_id
|
'training_id': session.training_id
|
||||||
},
|
},
|
||||||
file_path=checkpoint_path,
|
file_path=checkpoint_path,
|
||||||
performance_score=float(avg_accuracy), # Use accuracy as score
|
file_size_mb=os.path.getsize(checkpoint_path) / (1024 * 1024) if os.path.exists(checkpoint_path) else 0.0,
|
||||||
is_active=True
|
is_active=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ class AnnotationDashboard:
|
|||||||
SELECT checkpoint_id, performance_metrics, timestamp, file_path
|
SELECT checkpoint_id, performance_metrics, timestamp, file_path
|
||||||
FROM checkpoint_metadata
|
FROM checkpoint_metadata
|
||||||
WHERE model_name = ? AND is_active = TRUE
|
WHERE model_name = ? AND is_active = TRUE
|
||||||
ORDER BY performance_score DESC
|
ORDER BY timestamp DESC
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
""", (model_name.lower(),))
|
""", (model_name.lower(),))
|
||||||
|
|
||||||
|
|||||||
@@ -1229,7 +1229,7 @@ class TradingTransformerTrainer:
|
|||||||
for k, v in batch.items()}
|
for k, v in batch.items()}
|
||||||
|
|
||||||
# Use automatic mixed precision (FP16) for memory efficiency
|
# Use automatic mixed precision (FP16) for memory efficiency
|
||||||
with torch.cuda.amp.autocast(enabled=self.use_amp):
|
with torch.amp.autocast('cuda', enabled=self.use_amp):
|
||||||
# Forward pass with multi-timeframe data
|
# Forward pass with multi-timeframe data
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
price_data_1s=batch.get('price_data_1s'),
|
price_data_1s=batch.get('price_data_1s'),
|
||||||
|
|||||||
@@ -323,7 +323,7 @@ class COBRLModelInterface(ModelInterface):
|
|||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
if self.scaler:
|
if self.scaler:
|
||||||
with torch.cuda.amp.autocast():
|
with torch.amp.autocast('cuda'):
|
||||||
outputs = self.model(features)
|
outputs = self.model(features)
|
||||||
loss = self._calculate_loss(outputs, targets)
|
loss = self._calculate_loss(outputs, targets)
|
||||||
|
|
||||||
|
|||||||
@@ -1436,7 +1436,7 @@ class DQNAgent:
|
|||||||
import warnings
|
import warnings
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore", FutureWarning)
|
warnings.simplefilter("ignore", FutureWarning)
|
||||||
with torch.cuda.amp.autocast():
|
with torch.amp.autocast('cuda'):
|
||||||
# Get current Q values and predictions
|
# Get current Q values and predictions
|
||||||
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
|
current_q_values, current_extrema_pred, current_price_pred, hidden_features, current_advanced_pred = self._safe_cnn_forward(self.policy_net, states)
|
||||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||||
|
|||||||
@@ -536,7 +536,7 @@ class RealtimeRLCOBTrader:
|
|||||||
features_tensor = torch.from_numpy(features).unsqueeze(0).to(self.device)
|
features_tensor = torch.from_numpy(features).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with torch.cuda.amp.autocast():
|
with torch.amp.autocast('cuda'):
|
||||||
outputs = model(features_tensor)
|
outputs = model(features_tensor)
|
||||||
|
|
||||||
# Extract predictions
|
# Extract predictions
|
||||||
@@ -934,7 +934,7 @@ class RealtimeRLCOBTrader:
|
|||||||
], dtype=torch.float32).to(self.device)
|
], dtype=torch.float32).to(self.device)
|
||||||
|
|
||||||
# Forward pass with mixed precision
|
# Forward pass with mixed precision
|
||||||
with torch.cuda.amp.autocast():
|
with torch.amp.autocast('cuda'):
|
||||||
outputs = model(features)
|
outputs = model(features)
|
||||||
|
|
||||||
# Calculate losses
|
# Calculate losses
|
||||||
|
|||||||
Reference in New Issue
Block a user