refactoring
This commit is contained in:
@@ -4711,7 +4711,7 @@ class CleanTradingDashboard:
|
||||
stored_models = []
|
||||
|
||||
# Use unified model registry for saving
|
||||
from utils.model_registry import save_model
|
||||
from NN.training.model_manager import save_model
|
||||
|
||||
# 1. Store DQN model
|
||||
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||
@@ -6129,7 +6129,7 @@ class CleanTradingDashboard:
|
||||
# Save checkpoint after training
|
||||
if loss_count > 0:
|
||||
try:
|
||||
from utils.checkpoint_manager import save_checkpoint
|
||||
from NN.training.model_manager import save_checkpoint
|
||||
avg_loss = total_loss / loss_count
|
||||
|
||||
# Prepare checkpoint data
|
||||
@@ -6452,7 +6452,7 @@ class CleanTradingDashboard:
|
||||
# Try to load existing transformer checkpoint first
|
||||
if transformer_model is None or transformer_trainer is None:
|
||||
try:
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
from NN.training.model_manager import load_best_checkpoint
|
||||
|
||||
# Try to load the best transformer checkpoint
|
||||
checkpoint_metadata = load_best_checkpoint("transformer", "transformer")
|
||||
@@ -6687,7 +6687,7 @@ class CleanTradingDashboard:
|
||||
# Save checkpoint periodically with proper checkpoint management
|
||||
if transformer_trainer.training_history['train_loss']:
|
||||
try:
|
||||
from utils.checkpoint_manager import save_checkpoint
|
||||
from NN.training.model_manager import save_checkpoint
|
||||
|
||||
# Prepare checkpoint data
|
||||
checkpoint_data = {
|
||||
@@ -6740,7 +6740,7 @@ class CleanTradingDashboard:
|
||||
logger.error(f"Error saving transformer checkpoint: {e}")
|
||||
# Use unified registry for checkpoint
|
||||
try:
|
||||
from utils.model_registry import save_checkpoint as registry_save_checkpoint
|
||||
from NN.training.model_manager import save_checkpoint as registry_save_checkpoint
|
||||
|
||||
checkpoint_data = torch.load(checkpoint_path, map_location='cpu') if 'checkpoint_path' in locals() else checkpoint_data
|
||||
|
||||
|
Reference in New Issue
Block a user