This commit is contained in:
Dobromir Popov
2025-03-29 04:09:03 +02:00
parent 43803caaf1
commit 8b3db10a85
3 changed files with 307 additions and 267 deletions

View File

@ -148,19 +148,29 @@ def train(data_interface, model, args):
best_val_acc = 0
best_val_pnl = float('-inf')
best_win_rate = 0
best_price_mae = float('inf')
logger.info("Verifying data interface...")
X_sample, y_sample, _, _, _, _ = data_interface.prepare_training_data(refresh=True)
logger.info(f"Data validation - X shape: {X_sample.shape}, y shape: {y_sample.shape}")
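# Calculate the refresh interval (in seconds) from the smallest configured timeframe.
# NOTE(review): min() on timeframe strings compares lexicographically
# (e.g. '15m' < '1m' and '1h' < '1m'), so it may not select the shortest
# timeframe; confirm and compare by duration in seconds if that is the intent.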
min_timeframe = min(args.timeframes)
refresh_interval = {
'1s': 1,
'1m': 60,
'5m': 300,
'15m': 900,
'1h': 3600,
'4h': 14400,
'1d': 86400
}.get(min_timeframe, 60)
logger.info(f"Using refresh interval of {refresh_interval} seconds based on {min_timeframe} timeframe")
for epoch in range(args.epochs):
# More frequent refresh for shorter timeframes
if '1s' in args.timeframes:
refresh = True # Always refresh for tick data
refresh_interval = 30 # 30 seconds for tick data
else:
refresh = epoch % 1 == 0 # Refresh every epoch
refresh_interval = 120 # 2 minutes for other timeframes
# Always refresh for tick data or when using multiple timeframes
refresh = '1s' in args.timeframes or len(args.timeframes) > 1
logger.info(f"\nStarting epoch {epoch+1}/{args.epochs}")
X_train, y_train, X_val, y_val, train_prices, val_prices = data_interface.prepare_training_data(
@ -170,86 +180,106 @@ def train(data_interface, model, args):
logger.info(f"Training data - X shape: {X_train.shape}, y shape: {y_train.shape}")
logger.info(f"Validation data - X shape: {X_val.shape}, y shape: {y_val.shape}")
# Get future prices for retrospective training
train_future_prices = data_interface.get_future_prices(train_prices, n_candles=3)
val_future_prices = data_interface.get_future_prices(val_prices, n_candles=3)
# Train and validate
try:
train_loss, train_acc = model.train_epoch(X_train, y_train, args.batch_size)
val_loss, val_acc = model.evaluate(X_val, y_val)
train_action_loss, train_price_loss, train_acc = model.train_epoch(
X_train, y_train, train_future_prices, args.batch_size
)
val_action_loss, val_price_loss, val_acc = model.evaluate(
X_val, y_val, val_future_prices
)
# Get predictions for PnL calculation
train_preds = model.predict(X_train)
val_preds = model.predict(X_val)
train_action_probs, train_price_preds = model.predict(X_train)
val_action_probs, val_price_preds = model.predict(X_val)
# Calculate PnL and win rates
train_pnl, train_win_rate, train_trades = data_interface.calculate_pnl(
train_preds, train_prices, position_size=1.0
train_action_probs, train_prices, position_size=1.0
)
val_pnl, val_win_rate, val_trades = data_interface.calculate_pnl(
val_preds, val_prices, position_size=1.0
val_action_probs, val_prices, position_size=1.0
)
# Calculate price prediction error
train_price_mae = np.mean(np.abs(train_price_preds - train_future_prices))
val_price_mae = np.mean(np.abs(val_price_preds - val_future_prices))
# Monitor action distribution
train_actions = np.bincount(train_preds, minlength=3)
val_actions = np.bincount(val_preds, minlength=3)
train_actions = np.bincount(np.argmax(train_action_probs, axis=1), minlength=3)
val_actions = np.bincount(np.argmax(val_action_probs, axis=1), minlength=3)
# Log metrics
writer.add_scalar('Loss/train', train_loss, epoch)
writer.add_scalar('Loss/action_train', train_action_loss, epoch)
writer.add_scalar('Loss/price_train', train_price_loss, epoch)
writer.add_scalar('Loss/action_val', val_action_loss, epoch)
writer.add_scalar('Loss/price_val', val_price_loss, epoch)
writer.add_scalar('Accuracy/train', train_acc, epoch)
writer.add_scalar('Loss/val', val_loss, epoch)
writer.add_scalar('Accuracy/val', val_acc, epoch)
writer.add_scalar('PnL/train', train_pnl, epoch)
writer.add_scalar('PnL/val', val_pnl, epoch)
writer.add_scalar('WinRate/train', train_win_rate, epoch)
writer.add_scalar('WinRate/val', val_win_rate, epoch)
writer.add_scalar('PriceMAE/train', train_price_mae, epoch)
writer.add_scalar('PriceMAE/val', val_price_mae, epoch)
# Log action distribution
for i, action in enumerate(['SELL', 'HOLD', 'BUY']):
writer.add_scalar(f'Actions/train_{action}', train_actions[i], epoch)
writer.add_scalar(f'Actions/val_{action}', val_actions[i], epoch)
# Save best model based on validation PnL
if val_pnl > best_val_pnl:
# Save best model based on validation metrics
if val_pnl > best_val_pnl or (val_pnl == best_val_pnl and val_acc > best_val_acc):
best_val_pnl = val_pnl
best_val_acc = val_acc
best_win_rate = val_win_rate
best_price_mae = val_price_mae
model.save(f"models/{args.model_type}_best.pt")
logger.info("Saved new best model based on validation metrics")
# Log detailed metrics
logger.info(f"Epoch {epoch+1}/{args.epochs} - "
f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}, "
f"PnL: {train_pnl:.2%}, Win Rate: {train_win_rate:.2%} - "
f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}, "
f"PnL: {val_pnl:.2%}, Win Rate: {val_win_rate:.2%}")
logger.info(f"Epoch {epoch+1}/{args.epochs}")
logger.info("Training Metrics:")
logger.info(f" Action Loss: {train_action_loss:.4f}")
logger.info(f" Price Loss: {train_price_loss:.4f}")
logger.info(f" Accuracy: {train_acc:.2f}")
logger.info(f" PnL: {train_pnl:.2%}")
logger.info(f" Win Rate: {train_win_rate:.2%}")
logger.info(f" Price MAE: {train_price_mae:.2f}")
logger.info("Validation Metrics:")
logger.info(f" Action Loss: {val_action_loss:.4f}")
logger.info(f" Price Loss: {val_price_loss:.4f}")
logger.info(f" Accuracy: {val_acc:.2f}")
logger.info(f" PnL: {val_pnl:.2%}")
logger.info(f" Win Rate: {val_win_rate:.2%}")
logger.info(f" Price MAE: {val_price_mae:.2f}")
# Log action distribution
logger.info("Action Distribution:")
for i, action in enumerate(['SELL', 'HOLD', 'BUY']):
logger.info(f"{action}: Train={train_actions[i]}, Val={val_actions[i]}")
logger.info(f" {action}: Train={train_actions[i]}, Val={val_actions[i]}")
# Log trade statistics
if train_trades:
logger.info(f"Training trades: {len(train_trades)}")
logger.info(f"Validation trades: {len(val_trades)}")
logger.info("Trade Statistics:")
logger.info(f" Training trades: {len(train_trades)}")
logger.info(f" Validation trades: {len(val_trades)}")
# Retrospective fine-tuning
if epoch > 0 and val_pnl > 0: # Only fine-tune if we're making profit
logger.info("Performing retrospective fine-tuning...")
# Get predictions for next few candles
# Log next candle predictions
if epoch % 10 == 0: # Every 10 epochs
logger.info("\nNext Candle Predictions:")
next_candles = model.predict_next_candles(X_val[-1:], n_candles=3)
# Log predictions for each timeframe
for tf, preds in next_candles.items():
logger.info(f"Next 3 candles for {tf}:")
for i, pred in enumerate(preds):
action = ['SELL', 'HOLD', 'BUY'][np.argmax(pred)]
confidence = np.max(pred)
logger.info(f"Candle {i+1}: {action} (confidence: {confidence:.2f})")
# Fine-tune on recent successful trades
successful_trades = [t for t in train_trades if t['pnl'] > 0]
if successful_trades:
logger.info(f"Fine-tuning on {len(successful_trades)} successful trades")
# TODO: Implement fine-tuning logic here
for tf in args.timeframes:
if tf in next_candles:
logger.info(f"\n{tf} timeframe predictions:")
for i, pred in enumerate(next_candles[tf]):
action = ['SELL', 'HOLD', 'BUY'][np.argmax(pred)]
confidence = np.max(pred)
logger.info(f" Candle {i+1}: {action} (confidence: {confidence:.2f})")
except Exception as e:
logger.error(f"Error during epoch {epoch+1}: {str(e)}")
@ -257,10 +287,11 @@ def train(data_interface, model, args):
# Save final model
model.save(f"models/{args.model_type}_final_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
logger.info(f"Training complete. Best validation metrics:")
logger.info(f"\nTraining complete. Best validation metrics:")
logger.info(f"Accuracy: {best_val_acc:.2f}")
logger.info(f"PnL: {best_val_pnl:.2%}")
logger.info(f"Win Rate: {best_win_rate:.2%}")
logger.info(f"Price MAE: {best_price_mae:.2f}")
except Exception as e:
logger.error(f"Error in training: {str(e)}")