wip
This commit is contained in:
@ -198,16 +198,44 @@ def train(data_interface, model, args):
|
||||
val_action_probs, val_price_preds = model.predict(X_val)
|
||||
|
||||
# Calculate PnL and win rates
|
||||
train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
|
||||
train_action_probs, train_prices, position_size=1.0
|
||||
)
|
||||
val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
|
||||
val_action_probs, val_prices, position_size=1.0
|
||||
)
|
||||
try:
|
||||
train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
|
||||
train_preds, train_prices, position_size=1.0
|
||||
)
|
||||
val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
|
||||
val_preds, val_prices, position_size=1.0
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating PnL: {str(e)}")
|
||||
train_pnl, train_win_rate, val_pnl, val_win_rate = 0, 0, 0, 0
|
||||
train_trades, val_trades = [], []
|
||||
|
||||
# Calculate price prediction error
|
||||
train_price_mae = np.mean(np.abs(train_price_preds - train_future_prices))
|
||||
val_price_mae = np.mean(np.abs(val_price_preds - val_future_prices))
|
||||
if train_future_prices is not None and train_price_preds is not None:
|
||||
# Ensure arrays have the same shape and are numpy arrays
|
||||
train_future_prices_np = np.array(train_future_prices) if not isinstance(train_future_prices, np.ndarray) else train_future_prices
|
||||
train_price_preds_np = np.array(train_price_preds) if not isinstance(train_price_preds, np.ndarray) else train_price_preds
|
||||
|
||||
if len(train_price_preds_np) > 0 and len(train_future_prices_np) > 0:
|
||||
min_len = min(len(train_price_preds_np), len(train_future_prices_np))
|
||||
train_price_mae = np.mean(np.abs(train_price_preds_np[:min_len] - train_future_prices_np[:min_len]))
|
||||
else:
|
||||
train_price_mae = float('inf')
|
||||
else:
|
||||
train_price_mae = float('inf')
|
||||
|
||||
if val_future_prices is not None and val_price_preds is not None:
|
||||
# Ensure arrays have the same shape and are numpy arrays
|
||||
val_future_prices_np = np.array(val_future_prices) if not isinstance(val_future_prices, np.ndarray) else val_future_prices
|
||||
val_price_preds_np = np.array(val_price_preds) if not isinstance(val_price_preds, np.ndarray) else val_price_preds
|
||||
|
||||
if len(val_price_preds_np) > 0 and len(val_future_prices_np) > 0:
|
||||
min_len = min(len(val_price_preds_np), len(val_future_prices_np))
|
||||
val_price_mae = np.mean(np.abs(val_price_preds_np[:min_len] - val_future_prices_np[:min_len]))
|
||||
else:
|
||||
val_price_mae = float('inf')
|
||||
else:
|
||||
val_price_mae = float('inf')
|
||||
|
||||
# Monitor action distribution
|
||||
train_actions = np.bincount(np.argmax(train_action_probs, axis=1), minlength=3)
|
||||
@ -233,7 +261,7 @@ def train(data_interface, model, args):
|
||||
writer.add_scalar(f'Actions/val_{action}', val_actions[i], epoch)
|
||||
|
||||
# Save best model based on validation metrics
|
||||
if val_pnl > best_val_pnl or (val_pnl == best_val_pnl and val_acc > best_val_acc):
|
||||
if np.isscalar(val_pnl) and np.isscalar(best_val_pnl) and (val_pnl > best_val_pnl or (np.isclose(val_pnl, best_val_pnl) and val_acc > best_val_acc)):
|
||||
best_val_pnl = val_pnl
|
||||
best_val_acc = val_acc
|
||||
best_win_rate = val_win_rate
|
||||
|
Reference in New Issue
Block a user