gogo2/tests/test_model_persistence.py
2025-05-24 10:32:00 +03:00

274 lines
11 KiB
Python

#!/usr/bin/env python
"""
Comprehensive test suite for model persistence and training functionality
"""
import os
import sys
import unittest
import tempfile
import logging
import torch
import numpy as np
from pathlib import Path
# Make the project root (two directories above this file) importable so that
# `utils.model_utils` resolves regardless of the current working directory.
# resolve() turns a possibly relative or symlinked __file__ into an absolute,
# canonical path; without it a relative __file__ would yield a cwd-relative
# entry on sys.path.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
from utils.model_utils import robust_save, robust_load, get_model_info, verify_save_load_cycle
# Root logging for the test run: timestamped messages at INFO level and above.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Module-level logger used by the script entry point.
logger = logging.getLogger(__name__)
class MockAgent:
    """Minimal stand-in for an RL agent, used to exercise model persistence.

    Exposes the attributes the persistence tests read: ``state_size``,
    ``action_size``, ``hidden_size``, ``epsilon``, two identically shaped
    MLPs (``policy_net`` and ``target_net``) and an Adam ``optimizer`` over the
    policy network's parameters.
    """

    def __init__(self, state_size=64, action_size=4, hidden_size=256):
        self.state_size, self.action_size, self.hidden_size = state_size, action_size, hidden_size
        self.epsilon = 0.1
        # Both networks are Linear -> ReLU -> Linear but are built separately,
        # so each gets its own randomly initialized weights.
        self.policy_net = self._make_mlp(state_size, hidden_size, action_size)
        self.target_net = self._make_mlp(state_size, hidden_size, action_size)
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001)

    @staticmethod
    def _make_mlp(in_features, hidden_features, out_features):
        """Build a fresh two-layer MLP: in -> hidden (ReLU) -> out."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_features, hidden_features),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_features, out_features),
        )
class TestModelPersistence(unittest.TestCase):
    """Test suite for model saving and loading functionality.

    Each test gets its own scratch directory (``self.temp_dir``) and a freshly
    constructed ``MockAgent`` (``self.test_agent``).
    """

    def setUp(self):
        """Create a scratch directory and a fresh agent for the test."""
        self.temp_dir = tempfile.mkdtemp()
        self.test_agent = MockAgent()

    def tearDown(self):
        """Delete the scratch directory and anything written into it."""
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_robust_save_basic(self):
        """robust_save reports success and writes a non-empty file."""
        save_path = os.path.join(self.temp_dir, "test_model.pt")
        success = robust_save(self.test_agent, save_path)
        self.assertTrue(success, "Robust save should succeed")
        self.assertTrue(os.path.exists(save_path), "Model file should exist")
        self.assertGreater(os.path.getsize(save_path), 0, "Model file should not be empty")

    def test_robust_save_without_optimizer(self):
        """With include_optimizer=False the checkpoint omits the optimizer state."""
        save_path = os.path.join(self.temp_dir, "test_model_no_opt.pt")
        success = robust_save(self.test_agent, save_path, include_optimizer=False)
        self.assertTrue(success, "Robust save without optimizer should succeed")
        # Inspect the raw checkpoint. NOTE(review): assumes robust_save writes a
        # dict-like checkpoint keyed by component name, as these asserts expect.
        checkpoint = torch.load(save_path, map_location='cpu')
        self.assertNotIn('optimizer', checkpoint, "Optimizer state should not be saved")
        self.assertIn('policy_net', checkpoint, "Policy network should be saved")

    def test_robust_load_basic(self):
        """robust_load restores epsilon and policy-network weights into a fresh agent."""
        save_path = os.path.join(self.temp_dir, "test_model.pt")
        # Use a non-default epsilon: every MockAgent starts at 0.1, so comparing
        # defaults would pass even if robust_load restored nothing.
        self.test_agent.epsilon = 0.42
        success = robust_save(self.test_agent, save_path)
        self.assertTrue(success, "Save should succeed")
        # Create new agent and load
        new_agent = MockAgent()
        success = robust_load(new_agent, save_path)
        self.assertTrue(success, "Load should succeed")
        self.assertEqual(new_agent.epsilon, self.test_agent.epsilon, "Epsilon should match")
        # The new agent's layers were initialized independently, so exact
        # equality here means the saved weights were actually loaded.
        loaded_state = new_agent.policy_net.state_dict()
        for name, expected in self.test_agent.policy_net.state_dict().items():
            self.assertTrue(
                torch.equal(loaded_state[name], expected),
                f"policy_net.{name} should match after load",
            )

    def test_get_model_info(self):
        """get_model_info reports existence, size, optimizer presence and sizes."""
        save_path = os.path.join(self.temp_dir, "test_model.pt")
        # Test non-existent file
        info = get_model_info(save_path)
        self.assertFalse(info['exists'], "Non-existent file should return exists=False")
        # Save model and test info
        robust_save(self.test_agent, save_path)
        info = get_model_info(save_path)
        self.assertTrue(info['exists'], "Existing file should return exists=True")
        self.assertGreater(info['size_bytes'], 0, "File size should be greater than 0")
        self.assertTrue(info['has_optimizer'], "Should detect optimizer in checkpoint")
        self.assertEqual(info['parameters']['state_size'], self.test_agent.state_size)
        self.assertEqual(info['parameters']['action_size'], self.test_agent.action_size)

    def test_save_load_cycle_verification(self):
        """verify_save_load_cycle succeeds and removes its temporary file."""
        test_path = os.path.join(self.temp_dir, "cycle_test.pt")
        success = verify_save_load_cycle(self.test_agent, test_path)
        self.assertTrue(success, "Save/load cycle should succeed")
        # File should be cleaned up after verification
        self.assertFalse(os.path.exists(test_path), "Test file should be cleaned up")

    def test_multiple_save_methods(self):
        """Each supported robust_save variant succeeds and creates a file."""
        # Keyword arguments passed to robust_save for each variant.
        # NOTE: 'pickle2' cannot be selected through robust_save's arguments; it
        # makes the same call as 'regular' and only exercises the pickle
        # protocol 2 path if robust_save's internal fallback is triggered.
        method_kwargs = {
            'regular': {},
            'no_optimizer': {'include_optimizer': False},
            'pickle2': {},
        }
        for method, kwargs in method_kwargs.items():
            with self.subTest(method=method):
                save_path = os.path.join(self.temp_dir, f"test_{method}.pt")
                success = robust_save(self.test_agent, save_path, **kwargs)
                self.assertTrue(success, f"{method} save should succeed")
                self.assertTrue(os.path.exists(save_path), f"{method} save should create file")
class TestTrainingMetrics(unittest.TestCase):
    """Sanity checks for the bookkeeping used to monitor training.

    Covers the per-class signal distribution and the per-epoch metrics history.
    """

    def test_signal_distribution_calculation(self):
        """Class shares of SELL(0) / HOLD(1) / BUY(2) predictions are correct and sum to one."""
        predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0])  # SELL, HOLD, BUY
        class_ids = {"BUY": 2, "SELL": 0, "HOLD": 1}
        total = predictions.size
        distribution = {
            label: np.count_nonzero(predictions == class_id) / total
            for label, class_id in class_ids.items()
        }
        expected_shares = {"BUY": 0.3, "SELL": 0.3, "HOLD": 0.4}
        for label, share in expected_shares.items():
            self.assertAlmostEqual(distribution[label], share, places=2)
        self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)

    def test_metrics_tracking_structure(self):
        """A metrics history holds one list per metric; appending an epoch fills each list."""
        metric_names = (
            "epoch",
            "train_loss",
            "val_loss",
            "train_acc",
            "val_acc",
            "train_pnl",
            "val_pnl",
            "train_win_rate",
            "val_win_rate",
            "signal_distribution",
        )
        metrics_history = {name: [] for name in metric_names}

        # Values recorded for a single simulated epoch.
        epoch_record = {
            "epoch": 1,
            "train_loss": 0.5,
            "val_loss": 0.6,
            "train_acc": 0.7,
            "val_acc": 0.65,
            "train_pnl": 0.1,
            "val_pnl": 0.08,
            "train_win_rate": 0.6,
            "val_win_rate": 0.55,
            "signal_distribution": {"BUY": 0.3, "SELL": 0.3, "HOLD": 0.4},
        }
        for name, value in epoch_record.items():
            metrics_history[name].append(value)

        epochs = metrics_history["epoch"]
        first_distribution = metrics_history["signal_distribution"][0]
        self.assertEqual(len(epochs), 1)
        self.assertEqual(epochs[0], 1)
        self.assertIsInstance(first_distribution, dict)
        self.assertIn("BUY", first_distribution)
class TestModelArchitecture(unittest.TestCase):
    """Checks that a save/load round trip reproduces an agent's configuration and weights."""

    def test_model_parameter_consistency(self):
        """Configuration, epsilon and every policy tensor survive save/load."""
        agent = MockAgent(state_size=32, action_size=3, hidden_size=128)
        # Use a non-default epsilon: both agents would otherwise share 0.1 and
        # the equality check could not detect a load that restored nothing.
        agent.epsilon = 0.37
        with tempfile.TemporaryDirectory() as temp_dir:
            save_path = os.path.join(temp_dir, "consistency_test.pt")
            self.assertTrue(robust_save(agent, save_path), "Save should succeed")
            # Load into new model with same architecture
            new_agent = MockAgent(state_size=32, action_size=3, hidden_size=128)
            self.assertTrue(robust_load(new_agent, save_path), "Load should succeed")
        self.assertEqual(new_agent.state_size, agent.state_size)
        self.assertEqual(new_agent.action_size, agent.action_size)
        self.assertEqual(new_agent.hidden_size, agent.hidden_size)
        self.assertEqual(new_agent.epsilon, agent.epsilon)
        # new_agent's layers were initialized independently, so exact equality
        # shows the saved weights were loaded.
        loaded_state = new_agent.policy_net.state_dict()
        for name, expected in agent.policy_net.state_dict().items():
            self.assertTrue(
                torch.equal(loaded_state[name], expected),
                f"policy_net.{name} should match after load",
            )

    def test_model_forward_pass(self):
        """A loaded policy network produces the same output as the saved one."""
        agent = MockAgent()
        test_input = torch.randn(1, agent.state_size)
        # Inference only; no autograd graph is needed for the comparison.
        with torch.no_grad():
            original_output = agent.policy_net(test_input)
        with tempfile.TemporaryDirectory() as temp_dir:
            save_path = os.path.join(temp_dir, "forward_test.pt")
            self.assertTrue(robust_save(agent, save_path), "Save should succeed")
            new_agent = MockAgent()
            self.assertTrue(robust_load(new_agent, save_path), "Load should succeed")
        with torch.no_grad():
            new_output = new_agent.policy_net(test_input)
        self.assertEqual(new_output.shape, original_output.shape)
        # Outputs should be identical since we loaded the same weights
        torch.testing.assert_close(new_output, original_output)
def run_all_tests():
    """Run every test case in this module through a verbose text runner.

    Returns:
        True if all tests passed, False otherwise.
    """
    loader = unittest.TestLoader()
    test_cases = (TestModelPersistence, TestTrainingMetrics, TestModelArchitecture)
    combined_suite = unittest.TestSuite(
        loader.loadTestsFromTestCase(case) for case in test_cases
    )
    outcome = unittest.TextTestRunner(verbosity=2).run(combined_suite)
    return outcome.wasSuccessful()
if __name__ == "__main__":
    # Script entry point: run all suites and map the result to a process exit code.
    logger.info("Running comprehensive model persistence and training tests...")
    if run_all_tests():
        logger.info("All tests passed!")
        exit_code = 0
    else:
        logger.error("Some tests failed!")
        exit_code = 1
    sys.exit(exit_code)