trying to fix training

This commit is contained in:
Dobromir Popov
2025-03-29 03:53:38 +02:00
parent 2255a8363a
commit ebbc0ed2d7
7 changed files with 533 additions and 304 deletions

View File

@ -151,35 +151,59 @@ def main():
logger.info("Neural Network Trading System finished successfully")
def train(data_interface, model, args):
"""Train the model using the data interface"""
"""Enhanced training with performance tracking"""
from torch.utils.tensorboard import SummaryWriter
logger.info("Starting training mode...")
writer = SummaryWriter(log_dir=f"runs/{args.model_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
try:
# Prepare training data
logger.info("Preparing training data...")
X_train, y_train, X_val, y_val = data_interface.prepare_training_data()
best_val_acc = 0
# Train the model
logger.info("Training model...")
model.train(
X_train, y_train,
X_val, y_val,
batch_size=args.batch_size,
epochs=args.epochs
)
# Save the model
for epoch in range(args.epochs):
# Refresh data every few epochs
if epoch % 3 == 0:
X_train, y_train, X_val, y_val = data_interface.prepare_training_data(refresh=True)
else:
X_train, y_train, X_val, y_val = data_interface.prepare_training_data()
# Train for one epoch
train_loss, train_acc = model.train_epoch(
X_train, y_train,
batch_size=args.batch_size
)
# Validate
val_loss, val_acc = model.evaluate(X_val, y_val)
# Log metrics
writer.add_scalar('Loss/Train', train_loss, epoch)
writer.add_scalar('Accuracy/Train', train_acc, epoch)
writer.add_scalar('Loss/Validation', val_loss, epoch)
writer.add_scalar('Accuracy/Validation', val_acc, epoch)
# Save best model
if val_acc > best_val_acc:
best_val_acc = val_acc
model_path = os.path.join(
'models',
f"{args.model_type}_best_{args.symbol.replace('/', '_')}.pt"
)
model.save(model_path)
logger.info(f"New best model saved with val_acc: {val_acc:.2f}")
logger.info(f"Epoch {epoch+1}/{args.epochs} - "
f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f} - "
f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}")
# Save final model
model_path = os.path.join(
'models',
f"{args.model_type}_{args.symbol.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
'models',
f"{args.model_type}_final_{args.symbol.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
)
logger.info(f"Saving model to {model_path}...")
model.save(model_path)
# Evaluate the model
logger.info("Evaluating model...")
metrics = model.evaluate(X_val, y_val)
logger.info(f"Evaluation metrics: {metrics}")
logger.info(f"Training Complete - Best Val Accuracy: {best_val_acc:.2f}")
except Exception as e:
logger.error(f"Error in training mode: {str(e)}")