train works
This commit is contained in:
@@ -197,14 +197,25 @@ def train(data_interface, model, args):
|
||||
train_action_probs, train_price_preds = model.predict(X_train)
|
||||
val_action_probs, val_price_preds = model.predict(X_val)
|
||||
|
||||
# Convert probabilities to actions for PnL calculation
|
||||
train_preds = np.argmax(train_action_probs, axis=1)
|
||||
val_preds = np.argmax(val_action_probs, axis=1)
|
||||
|
||||
# Calculate PnL and win rates
|
||||
try:
|
||||
train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
|
||||
train_preds, train_prices, position_size=1.0
|
||||
)
|
||||
val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
|
||||
val_preds, val_prices, position_size=1.0
|
||||
)
|
||||
if train_preds is not None and train_prices is not None:
|
||||
train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
|
||||
train_preds, train_prices, position_size=1.0
|
||||
)
|
||||
else:
|
||||
train_pnl, train_win_rate, train_trades = 0, 0, []
|
||||
|
||||
if val_preds is not None and val_prices is not None:
|
||||
val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
|
||||
val_preds, val_prices, position_size=1.0
|
||||
)
|
||||
else:
|
||||
val_pnl, val_win_rate, val_trades = 0, 0, []
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating PnL: {str(e)}")
|
||||
train_pnl, train_win_rate, val_pnl, val_win_rate = 0, 0, 0, 0
|
||||
|
||||
Reference in New Issue
Block a user