fixes
This commit is contained in:
88
debug_checkpoint_loading.py
Normal file
88
debug_checkpoint_loading.py
Normal file
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Debug script to test checkpoint loading
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append('.')
|
||||
|
||||
def test_checkpoint_loading():
|
||||
"""Test checkpoint loading for all models"""
|
||||
print("=== Testing Checkpoint Loading ===")
|
||||
|
||||
try:
|
||||
from utils.database_manager import get_database_manager
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
|
||||
db = get_database_manager()
|
||||
|
||||
# Test models that should have checkpoints
|
||||
models = ['dqn_agent', 'enhanced_cnn', 'cob_rl_model', 'extrema_trainer']
|
||||
|
||||
for model in models:
|
||||
print(f"\n--- Testing {model} ---")
|
||||
|
||||
# Check database
|
||||
checkpoints = db.list_checkpoints(model)
|
||||
print(f"DB checkpoints: {len(checkpoints)}")
|
||||
|
||||
if checkpoints:
|
||||
best = db.get_best_checkpoint_metadata(model)
|
||||
if best:
|
||||
print(f"Best checkpoint: {best.checkpoint_id}")
|
||||
print(f"File path: {best.file_path}")
|
||||
print(f"File exists: {os.path.exists(best.file_path)}")
|
||||
print(f"Active: {best.is_active}")
|
||||
print(f"Performance metrics: {best.performance_metrics}")
|
||||
else:
|
||||
print("No best checkpoint found in DB")
|
||||
|
||||
# Test filesystem fallback
|
||||
result = load_best_checkpoint(model)
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
print(f"Filesystem fallback: {file_path}")
|
||||
print(f"File exists: {os.path.exists(file_path)}")
|
||||
print(f"Metadata: {getattr(metadata, 'checkpoint_id', 'N/A')}")
|
||||
else:
|
||||
print("No filesystem fallback found")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def test_model_initialization():
|
||||
"""Test model initialization with checkpoint loading"""
|
||||
print("\n=== Testing Model Initialization ===")
|
||||
|
||||
try:
|
||||
# Test DQN Agent initialization
|
||||
print("\n--- Testing DQN Agent ---")
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
|
||||
# Create a minimal DQN agent
|
||||
agent = DQNAgent(
|
||||
state_shape=(10,), # Simple state shape
|
||||
n_actions=3,
|
||||
model_name="dqn_agent",
|
||||
enable_checkpoints=True
|
||||
)
|
||||
|
||||
print(f"Agent created with checkpoints enabled: {agent.enable_checkpoints}")
|
||||
print(f"Model name: {agent.model_name}")
|
||||
|
||||
# Try to load checkpoint manually
|
||||
print("Attempting manual checkpoint load...")
|
||||
agent.load_best_checkpoint()
|
||||
print("Manual checkpoint load completed")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in model initialization: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_checkpoint_loading()
|
||||
test_model_initialization()
|
||||
Reference in New Issue
Block a user