refactoring
This commit is contained in:
25
NN/models/checkpoints/registry_metadata.json
Normal file
25
NN/models/checkpoints/registry_metadata.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"models": {
|
||||
"test_model": {
|
||||
"type": "cnn",
|
||||
"latest_path": "models/cnn/saved/test_model_latest.pt",
|
||||
"last_saved": "20250908_132919",
|
||||
"save_count": 1
|
||||
},
|
||||
"audit_test_model": {
|
||||
"type": "cnn",
|
||||
"latest_path": "models/cnn/saved/audit_test_model_latest.pt",
|
||||
"last_saved": "20250908_142204",
|
||||
"save_count": 2,
|
||||
"checkpoints": [
|
||||
{
|
||||
"id": "audit_test_model_20250908_142204_0.8500",
|
||||
"path": "models/cnn/checkpoints/audit_test_model_20250908_142204_0.8500.pt",
|
||||
"performance_score": 0.85,
|
||||
"timestamp": "20250908_142204"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"last_updated": "2025-09-08T14:22:04.917612"
|
||||
}
|
17
NN/models/checkpoints/saved/session_metadata.json
Normal file
17
NN/models/checkpoints/saved/session_metadata.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"timestamp": "2025-08-30T01:03:28.549034",
|
||||
"session_pnl": 0.9740795673949083,
|
||||
"trade_count": 44,
|
||||
"stored_models": [
|
||||
[
|
||||
"DQN",
|
||||
null
|
||||
],
|
||||
[
|
||||
"CNN",
|
||||
null
|
||||
]
|
||||
],
|
||||
"training_iterations": 0,
|
||||
"model_performance": {}
|
||||
}
|
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"model_name": "test_simple_model",
|
||||
"model_type": "test",
|
||||
"saved_at": "2025-09-02T15:30:36.295046",
|
||||
"save_method": "improved_model_saver",
|
||||
"test": true,
|
||||
"accuracy": 0.95
|
||||
}
|
@@ -20,8 +20,8 @@ import torch.nn.functional as F
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
# Import checkpoint management
|
||||
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
|
||||
from utils.model_registry import get_model_registry
|
||||
from NN.training.model_manager import save_checkpoint, load_best_checkpoint
|
||||
from NN.training.model_manager import create_model_manager
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -778,7 +778,7 @@ class CNNModelTrainer:
|
||||
def save_model(self, filepath: str = None, metadata: Optional[Dict] = None):
|
||||
"""Save model with metadata using unified registry"""
|
||||
try:
|
||||
from utils.model_registry import save_model
|
||||
from NN.training.model_manager import save_model
|
||||
|
||||
# Prepare model data
|
||||
model_data = {
|
||||
@@ -826,7 +826,7 @@ class CNNModelTrainer:
|
||||
def load_model(self, filepath: str = None) -> Dict:
|
||||
"""Load model from unified registry or file"""
|
||||
try:
|
||||
from utils.model_registry import load_model
|
||||
from NN.training.model_manager import load_model
|
||||
|
||||
# Use unified registry if no filepath or if it's a models/ path
|
||||
if filepath is None or filepath.startswith('models/'):
|
||||
|
@@ -15,8 +15,8 @@ import time
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
# Import checkpoint management
|
||||
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
|
||||
from utils.model_registry import get_model_registry
|
||||
from NN.training.model_manager import save_checkpoint, load_best_checkpoint
|
||||
from NN.training.model_manager import create_model_manager
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -1333,7 +1333,7 @@ class DQNAgent:
|
||||
def save(self, path: str = None):
|
||||
"""Save model and agent state using unified registry"""
|
||||
try:
|
||||
from utils.model_registry import save_model
|
||||
from NN.training.model_manager import save_model
|
||||
|
||||
# Use unified registry if no path or if it's a models/ path
|
||||
if path is None or path.startswith('models/'):
|
||||
@@ -1393,7 +1393,7 @@ class DQNAgent:
|
||||
def load(self, path: str = None):
|
||||
"""Load model and agent state from unified registry or file"""
|
||||
try:
|
||||
from utils.model_registry import load_model
|
||||
from NN.training.model_manager import load_model
|
||||
|
||||
# Use unified registry if no path or if it's a models/ path
|
||||
if path is None or path.startswith('models/'):
|
||||
|
Reference in New Issue
Block a user