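#!/usr/bin/env python3
"""
Debug script to test checkpoint loading.

For each model that is expected to have checkpoints, inspects what the
checkpoint database lists and what the checkpoint manager's
``load_best_checkpoint`` resolves to. It then builds a small DQN agent with
checkpointing enabled and loads its best checkpoint.
"""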
|
|
import os
import sys

# Make project-local packages (``utils``, ``NN``) importable relative to the
# current working directory. Presumably the script is meant to be run from the
# repository root; verify against how it is invoked.
sys.path.append(".")
|
|
def _path_exists(path):
    """Return whether *path* exists on disk, treating a falsy path as missing.

    ``os.path.exists(None)`` can raise ``TypeError`` rather than return False.
    A metadata record without a file path is therefore reported as
    "File exists: False" instead of aborting the report.
    """
    return bool(path) and os.path.exists(path)


def _run_isolated(check, *args):
    """Call ``check(*args)``, printing any exception and its traceback instead of raising."""
    import traceback

    try:
        check(*args)
    except Exception as e:
        print(f"Error: {e}")
        traceback.print_exc()


def _report_db_checkpoints(db, model):
    """Print the database's checkpoint count for *model* and, if any, its best checkpoint's metadata."""
    checkpoints = db.list_checkpoints(model)
    print(f"DB checkpoints: {len(checkpoints)}")
    if not checkpoints:
        return

    best = db.get_best_checkpoint_metadata(model)
    if not best:
        print("No best checkpoint found in DB")
        return

    print(f"Best checkpoint: {best.checkpoint_id}")
    print(f"File path: {best.file_path}")
    print(f"File exists: {_path_exists(best.file_path)}")
    print(f"Active: {best.is_active}")
    print(f"Performance metrics: {best.performance_metrics}")


def _report_filesystem_fallback(load_best_checkpoint, model):
    """Print the ``(file_path, metadata)`` pair ``load_best_checkpoint(model)`` returns, if any."""
    result = load_best_checkpoint(model)
    if not result:
        print("No filesystem fallback found")
        return

    file_path, metadata = result
    print(f"Filesystem fallback: {file_path}")
    print(f"File exists: {_path_exists(file_path)}")
    print(f"Metadata: {getattr(metadata, 'checkpoint_id', 'N/A')}")


def test_checkpoint_loading():
    """Test checkpoint loading for all models.

    For each model this prints what the checkpoint database lists (count and
    best-checkpoint metadata) and what ``load_best_checkpoint`` resolves to.
    Errors are printed with a traceback rather than raised. A failure while
    importing the checkpoint utilities or creating the database manager ends
    the run. A failure in one check for one model is reported, and the
    remaining checks and models still run.
    """
    import traceback

    print("=== Testing Checkpoint Loading ===")

    try:
        from utils.database_manager import get_database_manager
        from utils.checkpoint_manager import load_best_checkpoint

        db = get_database_manager()
    except Exception as e:
        # Nothing below can run without the checkpoint utilities and the DB.
        print(f"Error: {e}")
        traceback.print_exc()
        return

    # Test models that should have checkpoints
    models = ['dqn_agent', 'enhanced_cnn', 'cob_rl_model', 'extrema_trainer']

    for model in models:
        print(f"\n--- Testing {model} ---")
        # Each check runs on its own. A DB error for one model therefore does
        # not hide that model's filesystem result or the models after it.
        _run_isolated(_report_db_checkpoints, db, model)
        _run_isolated(_report_filesystem_fallback, load_best_checkpoint, model)
|
|
def test_model_initialization():
    """Create a small DQN agent with checkpointing enabled and load its best checkpoint.

    Prints the agent's ``enable_checkpoints`` flag and ``model_name``, then
    calls ``agent.load_best_checkpoint()`` explicitly. Any exception is printed
    with its traceback instead of being raised. This includes failing to
    import the agent class.
    """
    import traceback

    print("\n=== Testing Model Initialization ===")

    try:
        print("\n--- Testing DQN Agent ---")
        from NN.models.dqn_agent import DQNAgent

        # Deliberately tiny configuration: a length-10 state and 3 actions.
        agent_kwargs = dict(
            state_shape=(10,),
            n_actions=3,
            model_name="dqn_agent",
            enable_checkpoints=True,
        )
        agent = DQNAgent(**agent_kwargs)

        checkpoints_on = agent.enable_checkpoints
        print(f"Agent created with checkpoints enabled: {checkpoints_on}")
        print(f"Model name: {agent.model_name}")

        # Explicit load, independent of whatever the constructor may already do.
        print("Attempting manual checkpoint load...")
        agent.load_best_checkpoint()
        print("Manual checkpoint load completed")
    except Exception as err:
        print(f"Error in model initialization: {err}")
        traceback.print_exc()
|
|
if __name__ == "__main__":
    # Run the diagnostics in order; each one prints its own errors.
    for _diagnostic in (test_checkpoint_loading, test_model_initialization):
        _diagnostic()