This commit is contained in:
Dobromir Popov
2025-03-31 02:22:51 +03:00
parent 1b9f471076
commit 8981ad0691
5 changed files with 124 additions and 147 deletions

View File

@ -198,16 +198,44 @@ def train(data_interface, model, args):
val_action_probs, val_price_preds = model.predict(X_val)
# Calculate PnL and win rates
train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
train_action_probs, train_prices, position_size=1.0
)
val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
val_action_probs, val_prices, position_size=1.0
)
try:
train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
train_preds, train_prices, position_size=1.0
)
val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
val_preds, val_prices, position_size=1.0
)
except Exception as e:
logger.error(f"Error calculating PnL: {str(e)}")
train_pnl, train_win_rate, val_pnl, val_win_rate = 0, 0, 0, 0
train_trades, val_trades = [], []
# Calculate price prediction error
train_price_mae = np.mean(np.abs(train_price_preds - train_future_prices))
val_price_mae = np.mean(np.abs(val_price_preds - val_future_prices))
if train_future_prices is not None and train_price_preds is not None:
# Ensure arrays have the same shape and are numpy arrays
train_future_prices_np = np.array(train_future_prices) if not isinstance(train_future_prices, np.ndarray) else train_future_prices
train_price_preds_np = np.array(train_price_preds) if not isinstance(train_price_preds, np.ndarray) else train_price_preds
if len(train_price_preds_np) > 0 and len(train_future_prices_np) > 0:
min_len = min(len(train_price_preds_np), len(train_future_prices_np))
train_price_mae = np.mean(np.abs(train_price_preds_np[:min_len] - train_future_prices_np[:min_len]))
else:
train_price_mae = float('inf')
else:
train_price_mae = float('inf')
if val_future_prices is not None and val_price_preds is not None:
# Ensure arrays have the same shape and are numpy arrays
val_future_prices_np = np.array(val_future_prices) if not isinstance(val_future_prices, np.ndarray) else val_future_prices
val_price_preds_np = np.array(val_price_preds) if not isinstance(val_price_preds, np.ndarray) else val_price_preds
if len(val_price_preds_np) > 0 and len(val_future_prices_np) > 0:
min_len = min(len(val_price_preds_np), len(val_future_prices_np))
val_price_mae = np.mean(np.abs(val_price_preds_np[:min_len] - val_future_prices_np[:min_len]))
else:
val_price_mae = float('inf')
else:
val_price_mae = float('inf')
# Monitor action distribution
train_actions = np.bincount(np.argmax(train_action_probs, axis=1), minlength=3)
@ -233,7 +261,7 @@ def train(data_interface, model, args):
writer.add_scalar(f'Actions/val_{action}', val_actions[i], epoch)
# Save best model based on validation metrics
if val_pnl > best_val_pnl or (val_pnl == best_val_pnl and val_acc > best_val_acc):
if np.isscalar(val_pnl) and np.isscalar(best_val_pnl) and (val_pnl > best_val_pnl or (np.isclose(val_pnl, best_val_pnl) and val_acc > best_val_acc)):
best_val_pnl = val_pnl
best_val_acc = val_acc
best_win_rate = val_win_rate