reduce cob model to 400m
This commit is contained in:
@ -112,11 +112,11 @@ class RealtimeRLTester:
|
||||
raise
|
||||
|
||||
async def test_model_parameter_count(self):
|
||||
"""Test that model has approximately 1B parameters"""
|
||||
"""Test that model has approximately 400M parameters"""
|
||||
logger.info("🔢 Testing Model Parameter Count...")
|
||||
|
||||
try:
|
||||
model = MassiveRLNetwork(input_size=2000, hidden_size=4096, num_layers=12)
|
||||
model = MassiveRLNetwork(input_size=2000, hidden_size=2048, num_layers=8)
|
||||
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
@ -124,15 +124,23 @@ class RealtimeRLTester:
|
||||
logger.info(f"Total parameters: {total_params:,}")
|
||||
logger.info(f"Trainable parameters: {trainable_params:,}")
|
||||
|
||||
# Check if parameters are approximately 400M (350M - 450M range)
|
||||
target_400m = total_params >= 350_000_000 and total_params <= 450_000_000
|
||||
|
||||
self.test_results['test_model_parameter_count'] = {
|
||||
'status': 'PASSED',
|
||||
'status': 'PASSED' if target_400m else 'WARNING',
|
||||
'total_parameters': total_params,
|
||||
'trainable_parameters': trainable_params,
|
||||
'parameter_size_gb': (total_params * 4) / (1024**3), # 4 bytes per float32
|
||||
'is_massive': total_params > 100_000_000 # At least 100M parameters
|
||||
'is_optimized': target_400m, # Around 400M parameters for faster startup
|
||||
'target_range': '350M - 450M parameters'
|
||||
}
|
||||
|
||||
logger.info(f"✅ Model has {total_params:,} parameters ({total_params/1e9:.2f}B)")
|
||||
logger.info(f"✅ Model has {total_params:,} parameters ({total_params/1e6:.0f}M)")
|
||||
if target_400m:
|
||||
logger.info("✅ Parameter count within 400M target range for fast startup")
|
||||
else:
|
||||
logger.warning(f"⚠️ Parameter count outside 400M target range: {total_params/1e6:.0f}M")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results['test_model_parameter_count'] = {'status': 'FAILED', 'error': str(e)}
|
||||
|
Reference in New Issue
Block a user