dash and training wip
@@ -14,11 +14,7 @@ from collections import defaultdict
 import torch
 import random
 
-try:
-    import wandb
-    WANDB_AVAILABLE = True
-except ImportError:
-    WANDB_AVAILABLE = False
+WANDB_AVAILABLE = False
 
 logger = logging.getLogger(__name__)
 
@@ -58,13 +54,13 @@ class CheckpointManager:
                  base_checkpoint_dir: str = "NN/models/saved",
                  max_checkpoints_per_model: int = 5,
                  metadata_file: str = "checkpoint_metadata.json",
-                 enable_wandb: bool = True):
+                 enable_wandb: bool = False):
         self.base_dir = Path(base_checkpoint_dir)
         self.base_dir.mkdir(parents=True, exist_ok=True)
 
         self.max_checkpoints = max_checkpoints_per_model
         self.metadata_file = self.base_dir / metadata_file
-        self.enable_wandb = enable_wandb and WANDB_AVAILABLE
+        self.enable_wandb = False
 
         self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
         self._load_metadata()
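
Note: with this hunk the `enable_wandb` constructor argument becomes inert. The default flips to False, and the body then overwrites the attribute with a hard-coded False regardless of what the caller passed. A minimal sketch of the resulting behavior (the import path is an assumption for illustration, not taken from this diff):

    # Hypothetical usage; the module path is assumed.
    from utils.checkpoint_manager import CheckpointManager

    mgr = CheckpointManager(enable_wandb=True)  # caller opts in...
    assert mgr.enable_wandb is False            # ...but W&B stays force-disabled
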
@@ -115,10 +111,7 @@ class CheckpointManager:
             total_parameters=training_metadata.get('total_parameters') if training_metadata else None
         )
 
-        if self.enable_wandb and wandb.run is not None:
-            artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
-            metadata.wandb_run_id = wandb.run.id
-            metadata.wandb_artifact_name = artifact_name
+        # W&B disabled
 
         self.checkpoints[model_name].append(metadata)
         self._rotate_checkpoints(model_name)
@@ -273,19 +266,7 @@ class CheckpointManager:
                 logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")
 
     def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
-        try:
-            if not self.enable_wandb or wandb.run is None:
-                return None
-
-            artifact_name = f"{metadata.model_name}_checkpoint"
-            artifact = wandb.Artifact(artifact_name, type="model")
-            artifact.add_file(str(file_path))
-            wandb.log_artifact(artifact)
-
-            return artifact_name
-        except Exception as e:
-            logger.error(f"Error uploading to W&B: {e}")
-            return None
+        return None
 
     def _load_metadata(self):
         try:
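
`_upload_to_wandb` is kept as a stub rather than deleted, so existing call sites and the `Optional[str]` return contract survive without any wandb import. A standalone sketch of that contract (the function name here is illustrative, not from the diff):

    # Illustrative stand-in for the stubbed method: callers can still treat the
    # result as "artifact name or None" even though uploading never happens.
    from typing import Optional

    def upload_checkpoint_stub() -> Optional[str]:
        return None  # uploading disabled; there is no artifact name to report

    artifact_name = upload_checkpoint_stub()
    print("upload skipped" if artifact_name is None else artifact_name)
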
@@ -404,6 +385,56 @@ class CheckpointManager:
                     logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
                     return candidate_path
 
+        # Extended search: scan common project model directories for best checkpoints
+        try:
+            # Attempt to infer project root from base_dir (NN/models/saved -> root)
+            project_root = base_dir.resolve().parent.parent.parent
+        except Exception:
+            project_root = Path(".").resolve()
+        additional_dirs = [
+            project_root / "models",
+            project_root / "models" / "archive",
+            project_root / "models" / "backtest",
+        ]
+
+        def _match_legacy_name(candidate: Path, model: str) -> bool:
+            name = candidate.name.lower()
+            model_keys = {
+                'dqn_agent': ['dqn', 'agent', 'policy'],
+                'enhanced_cnn': ['cnn', 'optimized_short_term'],
+                'extrema_trainer': ['supervised', 'extrema'],
+                'cob_rl': ['cob', 'rl', 'policy'],
+                'decision': ['decision', 'transformer']
+            }.get(model, [model])
+            return any(k in name for k in model_keys)
+
+        candidates: List[Path] = []
+        for adir in additional_dirs:
+            if not adir.exists():
+                continue
+            try:
+                for pt in adir.rglob('*.pt'):
+                    # Prefer files that indicate "best" and match model hints
+                    lname = pt.name.lower()
+                    if 'best' in lname and _match_legacy_name(pt, model_name):
+                        candidates.append(pt)
+                # Do not add generic fallbacks to avoid mismatched model types
+            except Exception:
+                # Ignore directory traversal issues
+                pass
+
+        if candidates:
+            # Pick the most recently modified candidate
+            try:
+                best = max(candidates, key=lambda p: p.stat().st_mtime)
+                logger.debug(f"Found legacy model file in project models dir: {best}")
+                return best
+            except Exception:
+                # If stat fails, just return the first one deterministically
+                candidates.sort()
+                logger.debug(f"Found legacy model file in project models dir: {candidates[0]}")
+                return candidates[0]
+
         return None
 
     def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:
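
The new extended search keys off filename substrings, so it is worth checking how the heuristic classifies typical names. The function body below is copied from the hunk above; the filenames are made up for illustration:

    from pathlib import Path

    def _match_legacy_name(candidate: Path, model: str) -> bool:
        # Same keyword map as in the diff: model name -> filename hints.
        name = candidate.name.lower()
        model_keys = {
            'dqn_agent': ['dqn', 'agent', 'policy'],
            'enhanced_cnn': ['cnn', 'optimized_short_term'],
            'extrema_trainer': ['supervised', 'extrema'],
            'cob_rl': ['cob', 'rl', 'policy'],
            'decision': ['decision', 'transformer']
        }.get(model, [model])
        return any(k in name for k in model_keys)

    print(_match_legacy_name(Path('best_dqn_2024.pt'), 'dqn_agent'))  # True
    print(_match_legacy_name(Path('best_cnn_v2.pt'), 'dqn_agent'))    # False
    print(_match_legacy_name(Path('cob_best_snapshot.pt'), 'cob_rl')) # True

Note that the hints are plain substring checks, so a short key like 'rl' also matches inside longer words such as 'early'; combined with the 'best' filename filter this stays loose but cheap.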