training
@@ -148,19 +148,29 @@ def train(data_interface, model, args):
         best_val_acc = 0
         best_val_pnl = float('-inf')
         best_win_rate = 0
+        best_price_mae = float('inf')

         logger.info("Verifying data interface...")
         X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
         logger.info(f"Data validation - X shape: {X_sample.shape}, y shape: {y_sample.shape}")

+        # Calculate refresh intervals based on timeframes
+        min_timeframe = min(args.timeframes)
+        refresh_interval = {
+            '1s': 1,
+            '1m': 60,
+            '5m': 300,
+            '15m': 900,
+            '1h': 3600,
+            '4h': 14400,
+            '1d': 86400
+        }.get(min_timeframe, 60)
+
+        logger.info(f"Using refresh interval of {refresh_interval} seconds based on {min_timeframe} timeframe")
+
         for epoch in range(args.epochs):
-            # More frequent refresh for shorter timeframes
-            if '1s' in args.timeframes:
-                refresh = True # Always refresh for tick data
-                refresh_interval = 30 # 30 seconds for tick data
-            else:
-                refresh = epoch % 1 == 0 # Refresh every epoch
-                refresh_interval = 120 # 2 minutes for other timeframes
+            # Always refresh for tick data or when using multiple timeframes
+            refresh = '1s' in args.timeframes or len(args.timeframes) > 1

             logger.info(f"\nStarting epoch {epoch+1}/{args.epochs}")
             X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
@@ -170,86 +180,106 @@ def train(data_interface, model, args):
             logger.info(f"Training data - X shape: {X_train.shape}, y shape: {y_train.shape}")
             logger.info(f"Validation data - X shape: {X_val.shape}, y shape: {y_val.shape}")

+            # Get future prices for retrospective training
+            train_future_prices = data_interface.get_future_prices(train_prices, n_candles=3)
+            val_future_prices = data_interface.get_future_prices(val_prices, n_candles=3)
+
             # Train and validate
             try:
-                train_loss, train_acc = model.train_epoch(X_train, y_train, args.batch_size)
-                val_loss, val_acc = model.evaluate(X_val, y_val)
+                train_action_loss, train_price_loss, train_acc = model.train_epoch(
+                    X_train, y_train, train_future_prices, args.batch_size
+                )
+                val_action_loss, val_price_loss, val_acc = model.evaluate(
+                    X_val, y_val, val_future_prices
+                )

                 # Get predictions for PnL calculation
-                train_preds = model.predict(X_train)
-                val_preds = model.predict(X_val)
+                train_action_probs, train_price_preds = model.predict(X_train)
+                val_action_probs, val_price_preds = model.predict(X_val)

                 # Calculate PnL and win rates
                 train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
-                    train_preds, train_prices, position_size=1.0
+                    train_action_probs, train_prices, position_size=1.0
                 )
                 val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
-                    val_preds, val_prices, position_size=1.0
+                    val_action_probs, val_prices, position_size=1.0
                 )

+                # Calculate price prediction error
+                train_price_mae = np.mean(np.abs(train_price_preds - train_future_prices))
+                val_price_mae = np.mean(np.abs(val_price_preds - val_future_prices))
+
                 # Monitor action distribution
-                train_actions = np.bincount(train_preds, minlength=3)
-                val_actions = np.bincount(val_preds, minlength=3)
+                train_actions = np.bincount(np.argmax(train_action_probs, axis=1), minlength=3)
+                val_actions = np.bincount(np.argmax(val_action_probs, axis=1), minlength=3)

                 # Log metrics
-                writer.add_scalar('Loss/train', train_loss, epoch)
+                writer.add_scalar('Loss/action_train', train_action_loss, epoch)
+                writer.add_scalar('Loss/price_train', train_price_loss, epoch)
+                writer.add_scalar('Loss/action_val', val_action_loss, epoch)
+                writer.add_scalar('Loss/price_val', val_price_loss, epoch)
                 writer.add_scalar('Accuracy/train', train_acc, epoch)
-                writer.add_scalar('Loss/val', val_loss, epoch)
                 writer.add_scalar('Accuracy/val', val_acc, epoch)
                 writer.add_scalar('PnL/train', train_pnl, epoch)
                 writer.add_scalar('PnL/val', val_pnl, epoch)
                 writer.add_scalar('WinRate/train', train_win_rate, epoch)
                 writer.add_scalar('WinRate/val', val_win_rate, epoch)
+                writer.add_scalar('PriceMAE/train', train_price_mae, epoch)
+                writer.add_scalar('PriceMAE/val', val_price_mae, epoch)

                 # Log action distribution
                 for i, action in enumerate(['SELL', 'HOLD', 'BUY']):
                     writer.add_scalar(f'Actions/train_{action}', train_actions[i], epoch)
                     writer.add_scalar(f'Actions/val_{action}', val_actions[i], epoch)

-                # Save best model based on validation PnL
-                if val_pnl > best_val_pnl:
+                # Save best model based on validation metrics
+                if val_pnl > best_val_pnl or (val_pnl == best_val_pnl and val_acc > best_val_acc):
                     best_val_pnl = val_pnl
                     best_val_acc = val_acc
                     best_win_rate = val_win_rate
+                    best_price_mae = val_price_mae
                     model.save(f"models/{args.model_type}_best.pt")
                     logger.info("Saved new best model based on validation metrics")

                 # Log detailed metrics
-                logger.info(f"Epoch {epoch+1}/{args.epochs} - "
-                            f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}, "
-                            f"PnL: {train_pnl:.2%}, Win Rate: {train_win_rate:.2%} - "
-                            f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}, "
-                            f"PnL: {val_pnl:.2%}, Win Rate: {val_win_rate:.2%}")
+                logger.info(f"Epoch {epoch+1}/{args.epochs}")
+                logger.info("Training Metrics:")
+                logger.info(f" Action Loss: {train_action_loss:.4f}")
+                logger.info(f" Price Loss: {train_price_loss:.4f}")
+                logger.info(f" Accuracy: {train_acc:.2f}")
+                logger.info(f" PnL: {train_pnl:.2%}")
+                logger.info(f" Win Rate: {train_win_rate:.2%}")
+                logger.info(f" Price MAE: {train_price_mae:.2f}")
+
+                logger.info("Validation Metrics:")
+                logger.info(f" Action Loss: {val_action_loss:.4f}")
+                logger.info(f" Price Loss: {val_price_loss:.4f}")
+                logger.info(f" Accuracy: {val_acc:.2f}")
+                logger.info(f" PnL: {val_pnl:.2%}")
+                logger.info(f" Win Rate: {val_win_rate:.2%}")
+                logger.info(f" Price MAE: {val_price_mae:.2f}")

                 # Log action distribution
+                logger.info("Action Distribution:")
                 for i, action in enumerate(['SELL', 'HOLD', 'BUY']):
-                    logger.info(f"{action}: Train={train_actions[i]}, Val={val_actions[i]}")
+                    logger.info(f" {action}: Train={train_actions[i]}, Val={val_actions[i]}")

                 # Log trade statistics
                 if train_trades:
-                    logger.info(f"Training trades: {len(train_trades)}")
-                    logger.info(f"Validation trades: {len(val_trades)}")
+                    logger.info("Trade Statistics:")
+                    logger.info(f" Training trades: {len(train_trades)}")
+                    logger.info(f" Validation trades: {len(val_trades)}")

-                # Retrospective fine-tuning
-                if epoch > 0 and val_pnl > 0: # Only fine-tune if we're making profit
-                    logger.info("Performing retrospective fine-tuning...")
-
-                    # Get predictions for next few candles
+                # Log next candle predictions
+                if epoch % 10 == 0: # Every 10 epochs
+                    logger.info("\nNext Candle Predictions:")
                     next_candles = model.predict_next_candles(X_val[-1:], n_candles=3)

                     # Log predictions for each timeframe
-                    for tf, preds in next_candles.items():
-                        logger.info(f"Next 3 candles for {tf}:")
-                        for i, pred in enumerate(preds):
-                            action = ['SELL', 'HOLD', 'BUY'][np.argmax(pred)]
-                            confidence = np.max(pred)
-                            logger.info(f"Candle {i+1}: {action} (confidence: {confidence:.2f})")
-
-                    # Fine-tune on recent successful trades
-                    successful_trades = [t for t in train_trades if t['pnl'] > 0]
-                    if successful_trades:
-                        logger.info(f"Fine-tuning on {len(successful_trades)} successful trades")
-                        # TODO: Implement fine-tuning logic here
+                    for tf in args.timeframes:
+                        if tf in next_candles:
+                            logger.info(f"\n{tf} timeframe predictions:")
+                            for i, pred in enumerate(next_candles[tf]):
+                                action = ['SELL', 'HOLD', 'BUY'][np.argmax(pred)]
+                                confidence = np.max(pred)
+                                logger.info(f" Candle {i+1}: {action} (confidence: {confidence:.2f})")

             except Exception as e:
                 logger.error(f"Error during epoch {epoch+1}: {str(e)}")
@@ -257,10 +287,11 @@ def train(data_interface, model, args):

         # Save final model
         model.save(f"models/{args.model_type}_final_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
-        logger.info(f"Training complete. Best validation metrics:")
+        logger.info(f"\nTraining complete. Best validation metrics:")
         logger.info(f"Accuracy: {best_val_acc:.2f}")
         logger.info(f"PnL: {best_val_pnl:.2%}")
         logger.info(f"Win Rate: {best_win_rate:.2%}")
+        logger.info(f"Price MAE: {best_price_mae:.2f}")

     except Exception as e:
         logger.error(f"Error in training: {str(e)}")
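For context on the interface this loop relies on: data_interface.calculate_pnl(action_probs, prices, position_size=1.0) is expected to return a total PnL, a win rate, and a list of trade dicts carrying a 'pnl' key (the removed fine-tuning branch filtered on t['pnl'] > 0). That helper is not part of this commit, so the sketch below is only illustrative, under assumed conventions: probabilities ordered [SELL, HOLD, BUY], a one-candle holding period, and prices aligned one-to-one with predictions. The real implementation in the repository may differ.

import numpy as np

def calculate_pnl(action_probs, prices, position_size=1.0):
    """Hypothetical sketch of the PnL helper used by the training loop.

    Assumes action_probs has shape (N, 3) with columns [SELL, HOLD, BUY]
    and prices is a length-N array of closes aligned with the predictions.
    Returns (total_pnl, win_rate, trades) to match the calls in the diff.
    """
    actions = np.argmax(action_probs, axis=1)  # 0 = SELL, 1 = HOLD, 2 = BUY
    trades = []
    for i in range(len(actions) - 1):
        if actions[i] == 1:  # HOLD -> no position opened
            continue
        direction = 1.0 if actions[i] == 2 else -1.0  # long on BUY, short on SELL
        # Assumed one-candle holding period: enter at prices[i], exit at prices[i + 1]
        pnl = direction * (prices[i + 1] - prices[i]) / prices[i] * position_size
        trades.append({'entry': prices[i], 'exit': prices[i + 1], 'pnl': pnl})
    if not trades:
        return 0.0, 0.0, []
    pnls = np.array([t['pnl'] for t in trades])
    return float(pnls.sum()), float((pnls > 0).mean()), trades

With trade PnLs expressed as fractional returns, the loop's {train_pnl:.2%} and {val_pnl:.2%} formatting reads naturally as percentages.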