improve training and model data
This commit is contained in:
@ -18,6 +18,9 @@ import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from models import ModelInterface
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -221,12 +224,13 @@ class MassiveRLNetwork(nn.Module):
|
||||
}
|
||||
|
||||
|
||||
class COBRLModelInterface:
|
||||
class COBRLModelInterface(ModelInterface):
|
||||
"""
|
||||
Interface for the COB RL model that handles model management, training, and inference
|
||||
"""
|
||||
|
||||
def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None):
|
||||
super().__init__(name="cob_rl_model") # Initialize ModelInterface with a name
|
||||
self.model_checkpoint_dir = model_checkpoint_dir
|
||||
self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
|
||||
|
||||
@ -368,4 +372,23 @@ class COBRLModelInterface:
|
||||
|
||||
def get_model_stats(self) -> Dict[str, Any]:
    """Get model statistics.

    Returns:
        The dictionary produced by ``self.model.get_model_info()``,
        passed through unchanged.
    """
    # A second, identical return used to follow this one; it could never
    # execute, so only a single return is kept.
    return self.model.get_model_info()
|
||||
|
||||
def get_memory_usage(self) -> float:
    """Estimate the memory taken by the COB RL model's parameters, in MB.

    Sums ``numel() * element_size()`` over ``self.model.parameters()``, so
    the figure follows each parameter's actual dtype (e.g. 2 bytes for
    float16) instead of assuming 4-byte float32 for everything. Only
    parameters are counted; buffers, gradients, optimizer state and
    activations are not included.

    Returns:
        Estimated size in MB (bytes / 1024**2). If the calculation raises,
        the failure is logged at debug level and 1600.0 is returned as a
        rough placeholder (about 1.6 GB, i.e. a 400M-parameter float32
        network).
    """
    fallback_mb = 1600.0  # Default to 1.6 GB as an estimate if calculation fails
    try:
        memory_bytes = sum(
            p.numel() * p.element_size() for p in self.model.parameters()
        )
    except Exception as e:
        # Best-effort estimate: never let a stats call break the caller.
        logger.debug("Could not estimate COBRLModel memory usage: %s", e)
        return fallback_mb
    return memory_bytes / (1024 * 1024)
|
Reference in New Issue
Block a user