trying to fix training
This commit is contained in:
@ -151,35 +151,59 @@ def main():
|
||||
logger.info("Neural Network Trading System finished successfully")
|
||||
|
||||
def train(data_interface, model, args):
|
||||
"""Train the model using the data interface"""
|
||||
"""Enhanced training with performance tracking"""
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
logger.info("Starting training mode...")
|
||||
writer = SummaryWriter(log_dir=f"runs/{args.model_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
|
||||
|
||||
try:
|
||||
# Prepare training data
|
||||
logger.info("Preparing training data...")
|
||||
X_train, y_train, X_val, y_val = data_interface.prepare_training_data()
|
||||
best_val_acc = 0
|
||||
|
||||
# Train the model
|
||||
logger.info("Training model...")
|
||||
model.train(
|
||||
X_train, y_train,
|
||||
X_val, y_val,
|
||||
batch_size=args.batch_size,
|
||||
epochs=args.epochs
|
||||
)
|
||||
|
||||
# Save the model
|
||||
for epoch in range(args.epochs):
|
||||
# Refresh data every few epochs
|
||||
if epoch % 3 == 0:
|
||||
X_train, y_train, X_val, y_val = data_interface.prepare_training_data(refresh=True)
|
||||
else:
|
||||
X_train, y_train, X_val, y_val = data_interface.prepare_training_data()
|
||||
|
||||
# Train for one epoch
|
||||
train_loss, train_acc = model.train_epoch(
|
||||
X_train, y_train,
|
||||
batch_size=args.batch_size
|
||||
)
|
||||
|
||||
# Validate
|
||||
val_loss, val_acc = model.evaluate(X_val, y_val)
|
||||
|
||||
# Log metrics
|
||||
writer.add_scalar('Loss/Train', train_loss, epoch)
|
||||
writer.add_scalar('Accuracy/Train', train_acc, epoch)
|
||||
writer.add_scalar('Loss/Validation', val_loss, epoch)
|
||||
writer.add_scalar('Accuracy/Validation', val_acc, epoch)
|
||||
|
||||
# Save best model
|
||||
if val_acc > best_val_acc:
|
||||
best_val_acc = val_acc
|
||||
model_path = os.path.join(
|
||||
'models',
|
||||
f"{args.model_type}_best_{args.symbol.replace('/', '_')}.pt"
|
||||
)
|
||||
model.save(model_path)
|
||||
logger.info(f"New best model saved with val_acc: {val_acc:.2f}")
|
||||
|
||||
logger.info(f"Epoch {epoch+1}/{args.epochs} - "
|
||||
f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f} - "
|
||||
f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}")
|
||||
|
||||
# Save final model
|
||||
model_path = os.path.join(
|
||||
'models',
|
||||
f"{args.model_type}_{args.symbol.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
'models',
|
||||
f"{args.model_type}_final_{args.symbol.replace('/', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
|
||||
)
|
||||
logger.info(f"Saving model to {model_path}...")
|
||||
model.save(model_path)
|
||||
|
||||
# Evaluate the model
|
||||
logger.info("Evaluating model...")
|
||||
metrics = model.evaluate(X_val, y_val)
|
||||
logger.info(f"Evaluation metrics: {metrics}")
|
||||
logger.info(f"Training Complete - Best Val Accuracy: {best_val_acc:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in training mode: {str(e)}")
|
||||
|
Reference in New Issue
Block a user