massive cleanup

Author: Dobromir Popov
Date: 2025-05-24 10:32:00 +03:00
parent 310f3c5bf9
commit b5ad023b16
87 changed files with 1930 additions and 784568 deletions

tests/test_essential.py (new file, 115 lines)

@@ -0,0 +1,115 @@
#!/usr/bin/env python3
"""
Essential Test Suite - Core functionality tests
This file contains the most important tests to verify core functionality:
- Data loading and processing
- Basic model operations
- Trading signal generation
- Critical utility functions
"""
import sys
import os
import unittest
import logging
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
logger = logging.getLogger(__name__)
class TestEssentialFunctionality(unittest.TestCase):
"""Essential tests for core trading system functionality"""
def test_imports(self):
"""Test that all critical modules can be imported"""
try:
from core.config import get_config
from core.data_provider import DataProvider
from utils.model_utils import robust_save, robust_load
logger.info("✅ All critical imports successful")
except ImportError as e:
self.fail(f"Critical import failed: {e}")
def test_config_loading(self):
"""Test configuration loading"""
try:
from core.config import get_config
config = get_config()
self.assertIsNotNone(config, "Config should be loaded")
logger.info("✅ Configuration loading successful")
except Exception as e:
self.fail(f"Config loading failed: {e}")
def test_data_provider_initialization(self):
"""Test DataProvider can be initialized"""
try:
from core.data_provider import DataProvider
data_provider = DataProvider(['ETH/USDT'], ['1m'])
self.assertIsNotNone(data_provider, "DataProvider should initialize")
logger.info("✅ DataProvider initialization successful")
except Exception as e:
self.fail(f"DataProvider initialization failed: {e}")
def test_model_utils(self):
"""Test model utility functions"""
try:
from utils.model_utils import get_model_info
import tempfile
# Test with non-existent file
info = get_model_info("non_existent_file.pt")
self.assertFalse(info['exists'], "Should report file doesn't exist")
logger.info("✅ Model utils test successful")
except Exception as e:
self.fail(f"Model utils test failed: {e}")
def test_signal_generation_logic(self):
"""Test basic signal generation logic"""
import numpy as np
# Test signal distribution calculation
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0]) # SELL, HOLD, BUY
buy_count = np.sum(predictions == 2)
sell_count = np.sum(predictions == 0)
hold_count = np.sum(predictions == 1)
total = len(predictions)
distribution = {
"BUY": buy_count / total,
"SELL": sell_count / total,
"HOLD": hold_count / total
}
# Verify calculations
self.assertAlmostEqual(distribution["BUY"], 0.3, places=1)
self.assertAlmostEqual(distribution["SELL"], 0.3, places=1)
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=1)
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=1)
logger.info("✅ Signal generation logic test successful")
def run_essential_tests():
"""Run essential tests only"""
suite = unittest.TestLoader().loadTestsFromTestCase(TestEssentialFunctionality)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite)
return result.wasSuccessful()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
logger.info("Running essential functionality tests...")
success = run_essential_tests()
if success:
logger.info("✅ All essential tests passed!")
sys.exit(0)
else:
logger.error("❌ Essential tests failed!")
sys.exit(1)
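Note: the signal-distribution block in test_signal_generation_logic above is repeated verbatim in the consolidated test files further down. A shared helper (hypothetical, not part of this commit) that the tests could call instead might look like this:

import numpy as np

def signal_distribution(predictions: np.ndarray) -> dict:
    """Hypothetical shared helper; class indices (SELL=0, HOLD=1, BUY=2) mirror the tests above."""
    total = len(predictions)
    return {
        "BUY": float(np.sum(predictions == 2)) / total,
        "SELL": float(np.sum(predictions == 0)) / total,
        "HOLD": float(np.sum(predictions == 1)) / total,
    }

# The array used in the tests gives {"BUY": 0.3, "SELL": 0.3, "HOLD": 0.4}:
# signal_distribution(np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0]))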


@@ -0,0 +1,402 @@
#!/usr/bin/env python3
"""
Comprehensive Indicators and Signals Test Suite
This module consolidates testing functionality for:
- Technical indicators (from test_indicators.py)
- Signal interpretation and processing (from test_signal_interpreter.py)
- Market data analysis
- Trading signal validation
"""
import sys
import os
import unittest
import logging
import numpy as np
import tempfile
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from core.config import setup_logging
from core.data_provider import DataProvider
logger = logging.getLogger(__name__)
class TestTechnicalIndicators(unittest.TestCase):
"""Test suite for technical indicators functionality"""
def setUp(self):
"""Set up test fixtures"""
setup_logging()
self.data_provider = DataProvider(['ETH/USDT'], ['1h'])
def test_indicator_calculation(self):
"""Test that indicators are calculated correctly"""
logger.info("Testing technical indicators calculation...")
try:
# Fetch data with indicators
df = self.data_provider.get_historical_data('ETH/USDT', '1h', refresh=True, limit=100)
self.assertIsNotNone(df, "Should fetch data successfully")
self.assertGreater(len(df), 0, "Should have data rows")
# Check basic OHLCV columns
basic_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
for col in basic_cols:
self.assertIn(col, df.columns, f"Should have {col} column")
# Check that indicators are calculated
indicator_cols = [col for col in df.columns if col not in basic_cols]
self.assertGreater(len(indicator_cols), 0, "Should have technical indicators")
logger.info(f"✅ Successfully calculated {len(indicator_cols)} indicators")
except Exception as e:
logger.warning(f"Indicator test failed: {e}")
self.skipTest("Data or indicators not available")
def test_indicator_categorization(self):
"""Test categorization of different indicator types"""
logger.info("Testing indicator categorization...")
try:
df = self.data_provider.get_historical_data('ETH/USDT', '1h', refresh=True, limit=100)
if df is not None:
basic_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
indicator_cols = [col for col in df.columns if col not in basic_cols]
# Categorize indicators
trend_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['sma', 'ema', 'macd', 'adx', 'psar'])]
momentum_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['rsi', 'stoch', 'williams', 'cci'])]
volatility_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['bb_', 'atr', 'keltner'])]
volume_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['volume', 'obv', 'vpt', 'mfi', 'ad_line', 'vwap'])]
# Check we have indicators in each category
total_categorized = len(trend_indicators) + len(momentum_indicators) + len(volatility_indicators) + len(volume_indicators)
logger.info(f"Indicator categories:")
logger.info(f" Trend: {len(trend_indicators)}")
logger.info(f" Momentum: {len(momentum_indicators)}")
logger.info(f" Volatility: {len(volatility_indicators)}")
logger.info(f" Volume: {len(volume_indicators)}")
logger.info(f" Total categorized: {total_categorized}/{len(indicator_cols)}")
self.assertGreater(total_categorized, 0, "Should have categorized indicators")
else:
self.skipTest("Could not fetch data for categorization test")
except Exception as e:
logger.warning(f"Categorization test failed: {e}")
self.skipTest("Indicator categorization not available")
def test_feature_matrix_creation(self):
"""Test multi-timeframe feature matrix creation"""
logger.info("Testing feature matrix creation...")
try:
# Test feature matrix with multiple timeframes
feature_matrix = self.data_provider.get_feature_matrix('ETH/USDT', ['1h'], window_size=20)
if feature_matrix is not None:
self.assertEqual(len(feature_matrix.shape), 3, "Should be 3D matrix")
self.assertGreater(feature_matrix.shape[2], 0, "Should have features")
logger.info(f"✅ Feature matrix shape: {feature_matrix.shape}")
else:
self.skipTest("Could not create feature matrix")
except Exception as e:
logger.warning(f"Feature matrix test failed: {e}")
self.skipTest("Feature matrix creation not available")
class TestSignalProcessing(unittest.TestCase):
"""Test suite for signal interpretation and processing"""
def test_signal_distribution_calculation(self):
"""Test signal distribution calculation"""
logger.info("Testing signal distribution calculation...")
# Mock predictions (SELL=0, HOLD=1, BUY=2)
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0])
buy_count = np.sum(predictions == 2)
sell_count = np.sum(predictions == 0)
hold_count = np.sum(predictions == 1)
total = len(predictions)
distribution = {
"BUY": buy_count / total,
"SELL": sell_count / total,
"HOLD": hold_count / total
}
# Verify calculations
self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)
logger.info("✅ Signal distribution calculation test passed")
def test_basic_signal_interpretation(self):
"""Test basic signal interpretation logic"""
logger.info("Testing basic signal interpretation...")
# Test cases with different probability distributions
test_cases = [
{
'probs': [0.8, 0.1, 0.1], # Strong SELL
'expected_action': 'SELL',
'expected_confidence': 'high'
},
{
'probs': [0.1, 0.1, 0.8], # Strong BUY
'expected_action': 'BUY',
'expected_confidence': 'high'
},
{
'probs': [0.1, 0.8, 0.1], # Strong HOLD
'expected_action': 'HOLD',
'expected_confidence': 'high'
},
{
'probs': [0.4, 0.3, 0.3], # Uncertain - should prefer SELL (index 0)
'expected_action': 'SELL',
'expected_confidence': 'low'
},
{
'probs': [0.33, 0.33, 0.34], # Very uncertain - slight BUY preference
'expected_action': 'BUY',
'expected_confidence': 'low'
}
]
for i, test_case in enumerate(test_cases):
probs = np.array(test_case['probs'])
expected_action = test_case['expected_action']
# Simple signal interpretation (argmax)
predicted_action_idx = np.argmax(probs)
action_map = {0: 'SELL', 1: 'HOLD', 2: 'BUY'}
predicted_action = action_map[predicted_action_idx]
# Calculate confidence (max probability)
confidence = np.max(probs)
confidence_level = 'high' if confidence > 0.7 else 'medium' if confidence > 0.5 else 'low'
# Verify predictions
self.assertEqual(predicted_action, expected_action,
f"Test case {i+1}: Expected {expected_action}, got {predicted_action}")
logger.info(f"Test case {i+1}: {probs} -> {predicted_action} ({confidence_level} confidence)")
logger.info("✅ Basic signal interpretation test passed")
def test_signal_filtering_logic(self):
"""Test signal filtering and validation logic"""
logger.info("Testing signal filtering logic...")
# Test threshold-based filtering
buy_threshold = 0.6
sell_threshold = 0.6
hold_threshold = 0.7
test_signals = [
{
'probs': [0.8, 0.1, 0.1], # Strong SELL (above threshold)
'should_pass': True,
'expected': 'SELL'
},
{
'probs': [0.5, 0.3, 0.2], # Weak SELL (below threshold)
'should_pass': False,
'expected': 'HOLD'
},
{
'probs': [0.1, 0.2, 0.7], # Strong BUY (above threshold)
'should_pass': True,
'expected': 'BUY'
},
{
'probs': [0.2, 0.8, 0.0], # Strong HOLD (above threshold)
'should_pass': True,
'expected': 'HOLD'
}
]
for i, test in enumerate(test_signals):
probs = np.array(test['probs'])
sell_prob, hold_prob, buy_prob = probs
# Apply threshold filtering
if sell_prob >= sell_threshold:
filtered_action = 'SELL'
passed_filter = True
elif buy_prob >= buy_threshold:
filtered_action = 'BUY'
passed_filter = True
elif hold_prob >= hold_threshold:
filtered_action = 'HOLD'
passed_filter = True
else:
filtered_action = 'HOLD' # Default to HOLD if no threshold met
passed_filter = False
# Verify filtering
expected_pass = test['should_pass']
expected_action = test['expected']
self.assertEqual(passed_filter, expected_pass,
f"Test {i+1}: Filter pass expectation failed")
self.assertEqual(filtered_action, expected_action,
f"Test {i+1}: Expected {expected_action}, got {filtered_action}")
logger.info(f"Test {i+1}: {probs} -> {filtered_action} (passed: {passed_filter})")
logger.info("✅ Signal filtering logic test passed")
def test_signal_sequence_validation(self):
"""Test signal sequence validation and oscillation prevention"""
logger.info("Testing signal sequence validation...")
# Simulate a sequence of signals that might oscillate
signal_sequence = ['BUY', 'SELL', 'BUY', 'SELL', 'HOLD', 'BUY']
# Simple oscillation detection
oscillation_count = 0
for i in range(1, len(signal_sequence)):
if (signal_sequence[i-1] == 'BUY' and signal_sequence[i] == 'SELL') or \
(signal_sequence[i-1] == 'SELL' and signal_sequence[i] == 'BUY'):
oscillation_count += 1
# Count consecutive non-HOLD signals
consecutive_trades = 0
max_consecutive = 0
for signal in signal_sequence:
if signal != 'HOLD':
consecutive_trades += 1
max_consecutive = max(max_consecutive, consecutive_trades)
else:
consecutive_trades = 0
# Verify oscillation detection
self.assertGreater(oscillation_count, 0, "Should detect oscillations in test sequence")
self.assertGreater(max_consecutive, 1, "Should detect consecutive trades")
logger.info(f"Detected {oscillation_count} oscillations and max {max_consecutive} consecutive trades")
logger.info("✅ Signal sequence validation test passed")
class TestMarketDataAnalysis(unittest.TestCase):
"""Test suite for market data analysis functionality"""
def test_price_movement_calculation(self):
"""Test price movement and trend calculation"""
logger.info("Testing price movement calculation...")
# Mock price data
prices = np.array([100.0, 101.0, 102.5, 101.8, 103.2, 102.9, 104.1])
# Calculate price movements
price_changes = np.diff(prices)
percentage_changes = (price_changes / prices[:-1]) * 100
# Calculate simple trend
recent_trend = np.mean(percentage_changes[-3:]) # Last 3 changes
trend_direction = 'uptrend' if recent_trend > 0.1 else 'downtrend' if recent_trend < -0.1 else 'sideways'
# Verify calculations
self.assertEqual(len(price_changes), len(prices) - 1, "Should have n-1 price changes")
self.assertEqual(len(percentage_changes), len(prices) - 1, "Should have n-1 percentage changes")
# Verify trend detection makes sense
self.assertIn(trend_direction, ['uptrend', 'downtrend', 'sideways'], "Should detect valid trend")
logger.info(f"Price sequence: {prices}")
logger.info(f"Recent trend: {trend_direction} ({recent_trend:.2f}%)")
logger.info("✅ Price movement calculation test passed")
def test_volatility_measurement(self):
"""Test volatility measurement"""
logger.info("Testing volatility measurement...")
# Mock price data with different volatility
stable_prices = np.array([100.0, 100.1, 99.9, 100.2, 99.8, 100.0])
volatile_prices = np.array([100.0, 105.0, 95.0, 110.0, 90.0, 115.0])
# Calculate volatility (standard deviation of returns)
def calculate_volatility(prices):
returns = np.diff(prices) / prices[:-1]
return np.std(returns) * 100 # As percentage
stable_vol = calculate_volatility(stable_prices)
volatile_vol = calculate_volatility(volatile_prices)
# Verify volatility measurements
self.assertLess(stable_vol, volatile_vol, "Stable prices should have lower volatility")
self.assertGreater(volatile_vol, 5.0, "Volatile prices should have significant volatility")
logger.info(f"Stable volatility: {stable_vol:.2f}%")
logger.info(f"Volatile volatility: {volatile_vol:.2f}%")
logger.info("✅ Volatility measurement test passed")
def run_indicator_tests():
"""Run indicator tests only"""
suite = unittest.TestLoader().loadTestsFromTestCase(TestTechnicalIndicators)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite)
return result.wasSuccessful()
def run_signal_tests():
"""Run signal processing tests only"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestSignalProcessing),
unittest.TestLoader().loadTestsFromTestCase(TestMarketDataAnalysis),
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
def run_all_tests():
"""Run all indicator and signal tests"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestTechnicalIndicators),
unittest.TestLoader().loadTestsFromTestCase(TestSignalProcessing),
unittest.TestLoader().loadTestsFromTestCase(TestMarketDataAnalysis),
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
if __name__ == "__main__":
setup_logging()
logger.info("Running indicators and signals test suite...")
if len(sys.argv) > 1:
test_type = sys.argv[1]
if test_type == "indicators":
success = run_indicator_tests()
elif test_type == "signals":
success = run_signal_tests()
else:
success = run_all_tests()
else:
success = run_all_tests()
if success:
logger.info("✅ All indicator and signal tests passed!")
sys.exit(0)
else:
logger.error("❌ Some tests failed!")
sys.exit(1)
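For reference, the threshold filter exercised in test_signal_filtering_logic can be written as a standalone function; this is a sketch using the test's own thresholds, not production code from this commit:

import numpy as np

def filter_signal(probs, buy_threshold=0.6, sell_threshold=0.6, hold_threshold=0.7):
    """Return (action, passed_filter); probabilities are ordered SELL, HOLD, BUY as in the tests."""
    sell_prob, hold_prob, buy_prob = probs
    if sell_prob >= sell_threshold:
        return 'SELL', True
    if buy_prob >= buy_threshold:
        return 'BUY', True
    if hold_prob >= hold_threshold:
        return 'HOLD', True
    return 'HOLD', False  # default to HOLD when no class clears its threshold

# filter_signal(np.array([0.8, 0.1, 0.1]))  -> ('SELL', True)
# filter_signal(np.array([0.5, 0.3, 0.2]))  -> ('HOLD', False)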


@@ -0,0 +1,274 @@
#!/usr/bin/env python3
"""
Comprehensive test suite for model persistence and training functionality
"""
import os
import sys
import unittest
import tempfile
import logging
import torch
import numpy as np
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from utils.model_utils import robust_save, robust_load, get_model_info, verify_save_load_cycle
# Configure logging for tests
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class MockAgent:
"""Mock agent class for testing model persistence"""
def __init__(self, state_size=64, action_size=4, hidden_size=256):
self.state_size = state_size
self.action_size = action_size
self.hidden_size = hidden_size
self.epsilon = 0.1
# Create simple mock networks
self.policy_net = torch.nn.Sequential(
torch.nn.Linear(state_size, hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(hidden_size, action_size)
)
self.target_net = torch.nn.Sequential(
torch.nn.Linear(state_size, hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(hidden_size, action_size)
)
self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001)
class TestModelPersistence(unittest.TestCase):
"""Test suite for model saving and loading functionality"""
def setUp(self):
"""Set up test fixtures"""
self.temp_dir = tempfile.mkdtemp()
self.test_agent = MockAgent()
def tearDown(self):
"""Clean up test fixtures"""
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
def test_robust_save_basic(self):
"""Test basic robust save functionality"""
save_path = os.path.join(self.temp_dir, "test_model.pt")
success = robust_save(self.test_agent, save_path)
self.assertTrue(success, "Robust save should succeed")
self.assertTrue(os.path.exists(save_path), "Model file should exist")
self.assertGreater(os.path.getsize(save_path), 0, "Model file should not be empty")
def test_robust_save_without_optimizer(self):
"""Test robust save without optimizer state"""
save_path = os.path.join(self.temp_dir, "test_model_no_opt.pt")
success = robust_save(self.test_agent, save_path, include_optimizer=False)
self.assertTrue(success, "Robust save without optimizer should succeed")
# Verify that optimizer state is not included
checkpoint = torch.load(save_path, map_location='cpu')
self.assertNotIn('optimizer', checkpoint, "Optimizer state should not be saved")
self.assertIn('policy_net', checkpoint, "Policy network should be saved")
def test_robust_load_basic(self):
"""Test basic robust load functionality"""
save_path = os.path.join(self.temp_dir, "test_model.pt")
# Save first
success = robust_save(self.test_agent, save_path)
self.assertTrue(success, "Save should succeed")
# Create new agent and load
new_agent = MockAgent()
success = robust_load(new_agent, save_path)
self.assertTrue(success, "Load should succeed")
# Verify epsilon was loaded
self.assertEqual(new_agent.epsilon, self.test_agent.epsilon, "Epsilon should match")
def test_get_model_info(self):
"""Test model info extraction"""
save_path = os.path.join(self.temp_dir, "test_model.pt")
# Test non-existent file
info = get_model_info(save_path)
self.assertFalse(info['exists'], "Non-existent file should return exists=False")
# Save model and test info
robust_save(self.test_agent, save_path)
info = get_model_info(save_path)
self.assertTrue(info['exists'], "Existing file should return exists=True")
self.assertGreater(info['size_bytes'], 0, "File size should be greater than 0")
self.assertTrue(info['has_optimizer'], "Should detect optimizer in checkpoint")
self.assertEqual(info['parameters']['state_size'], self.test_agent.state_size)
self.assertEqual(info['parameters']['action_size'], self.test_agent.action_size)
def test_save_load_cycle_verification(self):
"""Test save/load cycle verification"""
test_path = os.path.join(self.temp_dir, "cycle_test.pt")
success = verify_save_load_cycle(self.test_agent, test_path)
self.assertTrue(success, "Save/load cycle should succeed")
# File should be cleaned up after verification
self.assertFalse(os.path.exists(test_path), "Test file should be cleaned up")
def test_multiple_save_methods(self):
"""Test that different save methods all work"""
methods = ['regular', 'no_optimizer', 'pickle2']
for method in methods:
with self.subTest(method=method):
save_path = os.path.join(self.temp_dir, f"test_{method}.pt")
if method == 'regular':
success = robust_save(self.test_agent, save_path)
elif method == 'no_optimizer':
success = robust_save(self.test_agent, save_path, include_optimizer=False)
elif method == 'pickle2':
# This would be tested by the robust_save fallback mechanism
success = robust_save(self.test_agent, save_path)
self.assertTrue(success, f"{method} save should succeed")
self.assertTrue(os.path.exists(save_path), f"{method} save should create file")
class TestTrainingMetrics(unittest.TestCase):
"""Test suite for training metrics and monitoring functionality"""
def test_signal_distribution_calculation(self):
"""Test signal distribution calculation"""
# Mock predictions
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0]) # SELL, HOLD, BUY
buy_count = np.sum(predictions == 2)
sell_count = np.sum(predictions == 0)
hold_count = np.sum(predictions == 1)
total = len(predictions)
distribution = {
"BUY": buy_count / total,
"SELL": sell_count / total,
"HOLD": hold_count / total
}
self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)
def test_metrics_tracking_structure(self):
"""Test metrics history structure for training monitoring"""
metrics_history = {
"epoch": [],
"train_loss": [],
"val_loss": [],
"train_acc": [],
"val_acc": [],
"train_pnl": [],
"val_pnl": [],
"train_win_rate": [],
"val_win_rate": [],
"signal_distribution": []
}
# Simulate adding metrics for one epoch
metrics_history["epoch"].append(1)
metrics_history["train_loss"].append(0.5)
metrics_history["val_loss"].append(0.6)
metrics_history["train_acc"].append(0.7)
metrics_history["val_acc"].append(0.65)
metrics_history["train_pnl"].append(0.1)
metrics_history["val_pnl"].append(0.08)
metrics_history["train_win_rate"].append(0.6)
metrics_history["val_win_rate"].append(0.55)
metrics_history["signal_distribution"].append({"BUY": 0.3, "SELL": 0.3, "HOLD": 0.4})
# Verify structure
self.assertEqual(len(metrics_history["epoch"]), 1)
self.assertEqual(metrics_history["epoch"][0], 1)
self.assertIsInstance(metrics_history["signal_distribution"][0], dict)
self.assertIn("BUY", metrics_history["signal_distribution"][0])
class TestModelArchitecture(unittest.TestCase):
"""Test suite for model architecture verification"""
def test_model_parameter_consistency(self):
"""Test that model parameters are consistent after save/load"""
agent = MockAgent(state_size=32, action_size=3, hidden_size=128)
with tempfile.TemporaryDirectory() as temp_dir:
save_path = os.path.join(temp_dir, "consistency_test.pt")
# Save model
robust_save(agent, save_path)
# Load into new model with same architecture
new_agent = MockAgent(state_size=32, action_size=3, hidden_size=128)
robust_load(new_agent, save_path)
# Verify parameters match
self.assertEqual(new_agent.state_size, agent.state_size)
self.assertEqual(new_agent.action_size, agent.action_size)
self.assertEqual(new_agent.hidden_size, agent.hidden_size)
self.assertEqual(new_agent.epsilon, agent.epsilon)
def test_model_forward_pass(self):
"""Test that model can perform forward pass after load"""
agent = MockAgent()
with tempfile.TemporaryDirectory() as temp_dir:
save_path = os.path.join(temp_dir, "forward_test.pt")
# Create test input
test_input = torch.randn(1, agent.state_size)
# Get original output
original_output = agent.policy_net(test_input)
# Save and load
robust_save(agent, save_path)
new_agent = MockAgent()
robust_load(new_agent, save_path)
# Test forward pass works
new_output = new_agent.policy_net(test_input)
self.assertEqual(new_output.shape, original_output.shape)
# Outputs should be identical since we loaded the same weights
torch.testing.assert_close(new_output, original_output)
def run_all_tests():
"""Run all test suites"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestModelPersistence),
unittest.TestLoader().loadTestsFromTestCase(TestTrainingMetrics),
unittest.TestLoader().loadTestsFromTestCase(TestModelArchitecture)
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
if __name__ == "__main__":
logger.info("Running comprehensive model persistence and training tests...")
success = run_all_tests()
if success:
logger.info("All tests passed!")
sys.exit(0)
else:
logger.error("Some tests failed!")
sys.exit(1)
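utils/model_utils.py itself is not shown in this diff. A minimal sketch of robust_save that is consistent with what these tests assert (optional optimizer state, a pickle-protocol-2 fallback) is given below; checkpoint keys other than 'policy_net' and 'optimizer' are assumptions:

import logging
import torch

logger = logging.getLogger(__name__)

def robust_save(agent, path: str, include_optimizer: bool = True) -> bool:
    """Sketch only: save an agent checkpoint, falling back to pickle protocol 2."""
    checkpoint = {
        'policy_net': agent.policy_net.state_dict(),
        'target_net': agent.target_net.state_dict(),  # key name assumed
        'epsilon': agent.epsilon,
        'state_size': agent.state_size,
        'action_size': agent.action_size,
        'hidden_size': agent.hidden_size,
    }
    if include_optimizer:
        checkpoint['optimizer'] = agent.optimizer.state_dict()
    try:
        torch.save(checkpoint, path)
        return True
    except Exception as e:
        logger.warning(f"Standard save failed ({e}); retrying with pickle protocol 2")
        try:
            torch.save(checkpoint, path, pickle_protocol=2)
            return True
        except Exception as e2:
            logger.error(f"All save attempts failed: {e2}")
            return False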


@@ -0,0 +1,395 @@
#!/usr/bin/env python3
"""
Comprehensive Training Integration Tests
This module consolidates and improves test functionality from multiple test files:
- CNN training tests (from test_cnn_only.py, test_training.py)
- Model testing (from test_model.py)
- Chart data testing (from test_chart_data.py)
- Integration testing between components
"""
import sys
import os
import logging
import time
import unittest
import tempfile
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from core.config import setup_logging, get_config
from core.data_provider import DataProvider
from training.cnn_trainer import CNNTrainer
from training.rl_trainer import RLTrainer
from dataprovider_realtime import RealTimeChart, TickStorage, BinanceHistoricalData
logger = logging.getLogger(__name__)
class TestDataProviders(unittest.TestCase):
"""Test suite for data provider functionality"""
def test_binance_historical_data(self):
"""Test Binance historical data fetching"""
logger.info("Testing Binance historical data fetch...")
try:
binance_data = BinanceHistoricalData()
df = binance_data.get_historical_candles("ETH/USDT", 60, 100)
self.assertIsNotNone(df, "Should fetch data successfully")
self.assertFalse(df.empty, "Data should not be empty")
self.assertGreater(len(df), 0, "Should have candles")
# Verify data structure
required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
for col in required_columns:
self.assertIn(col, df.columns, f"Should have {col} column")
logger.info(f"✅ Successfully fetched {len(df)} candles")
return True
except Exception as e:
logger.warning(f"Binance API test failed: {e}")
self.skipTest("Binance API not available")
def test_tick_storage(self):
"""Test TickStorage functionality"""
logger.info("Testing TickStorage data loading...")
try:
tick_storage = TickStorage("ETH/USDT", ["1m", "5m", "1h"])
success = tick_storage.load_historical_data("ETH/USDT", limit=100)
if success:
# Check timeframes
for tf in ["1m", "5m", "1h"]:
candles = tick_storage.get_candles(tf)
logger.info(f" {tf}: {len(candles)} candles")
logger.info("✅ TickStorage working correctly")
return True
else:
self.skipTest("Could not load tick storage data")
except Exception as e:
logger.warning(f"TickStorage test failed: {e}")
self.skipTest("TickStorage not available")
def test_chart_initialization(self):
"""Test RealTimeChart initialization"""
logger.info("Testing RealTimeChart initialization...")
try:
chart = RealTimeChart(app=None, symbol="ETH/USDT", standalone=False)
# Test getting candles
candles_1m = chart.get_candles(60)
self.assertIsInstance(candles_1m, list, "Should return list of candles")
logger.info(f"✅ Chart initialized with {len(candles_1m)} 1m candles")
except Exception as e:
logger.warning(f"Chart initialization failed: {e}")
self.skipTest("Chart initialization not available")
class TestCNNTraining(unittest.TestCase):
"""Test suite for CNN training functionality"""
def setUp(self):
"""Set up test fixtures"""
self.temp_dir = tempfile.mkdtemp()
setup_logging()
def tearDown(self):
"""Clean up test fixtures"""
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
def test_cnn_quick_training(self):
"""Test quick CNN training with small dataset"""
logger.info("Testing CNN quick training...")
try:
config = get_config()
# Test configuration
symbols = ['ETH/USDT']
timeframes = ['1m', '5m']
num_samples = 100 # Very small for testing
epochs = 1
batch_size = 16
# Override config for quick test
config._config['timeframes'] = timeframes
trainer = CNNTrainer(config)
trainer.batch_size = batch_size
trainer.epochs = epochs
# Train model
save_path = os.path.join(self.temp_dir, 'test_cnn.pt')
results = trainer.train(symbols, save_path=save_path, num_samples=num_samples)
# Verify results
self.assertIsInstance(results, dict, "Should return results dict")
self.assertIn('best_val_accuracy', results, "Should have accuracy metric")
self.assertIn('total_epochs', results, "Should have epoch count")
self.assertIn('training_time', results, "Should have training time")
# Verify model was saved
self.assertTrue(os.path.exists(save_path), "Model should be saved")
logger.info(f"✅ CNN training completed successfully")
logger.info(f" Best accuracy: {results['best_val_accuracy']:.4f}")
logger.info(f" Training time: {results['training_time']:.2f}s")
except Exception as e:
logger.error(f"CNN training test failed: {e}")
raise
finally:
if 'trainer' in locals() and hasattr(trainer, 'close_tensorboard'):
trainer.close_tensorboard()
class TestRLTraining(unittest.TestCase):
"""Test suite for RL training functionality"""
def setUp(self):
"""Set up test fixtures"""
self.temp_dir = tempfile.mkdtemp()
setup_logging()
def tearDown(self):
"""Clean up test fixtures"""
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
def test_rl_quick_training(self):
"""Test quick RL training with small dataset"""
logger.info("Testing RL quick training...")
try:
# Setup minimal configuration
data_provider = DataProvider(['ETH/USDT'], ['1m', '5m'])
trainer = RLTrainer(data_provider)
# Configure for very quick test
trainer.num_episodes = 5
trainer.max_steps_per_episode = 50
trainer.evaluation_frequency = 3
trainer.save_frequency = 10 # Don't save during test
# Train
save_path = os.path.join(self.temp_dir, 'test_rl.pt')
results = trainer.train(save_path=save_path)
# Verify results
self.assertIsInstance(results, dict, "Should return results dict")
self.assertIn('total_episodes', results, "Should have episode count")
self.assertIn('best_reward', results, "Should have best reward")
self.assertIn('final_evaluation', results, "Should have final evaluation")
logger.info(f"✅ RL training completed successfully")
logger.info(f" Total episodes: {results['total_episodes']}")
logger.info(f" Best reward: {results['best_reward']:.4f}")
except Exception as e:
logger.error(f"RL training test failed: {e}")
raise
class TestExtendedTraining(unittest.TestCase):
"""Test suite for extended training functionality (from test_model.py)"""
def test_metrics_tracking(self):
"""Test comprehensive metrics tracking functionality"""
logger.info("Testing extended metrics tracking...")
# Test metrics history structure
metrics_history = {
"epoch": [],
"train_loss": [],
"val_loss": [],
"train_acc": [],
"val_acc": [],
"train_pnl": [],
"val_pnl": [],
"train_win_rate": [],
"val_win_rate": [],
"signal_distribution": []
}
# Simulate adding metrics
for epoch in range(3):
metrics_history["epoch"].append(epoch + 1)
metrics_history["train_loss"].append(0.5 - epoch * 0.1)
metrics_history["val_loss"].append(0.6 - epoch * 0.1)
metrics_history["train_acc"].append(0.6 + epoch * 0.05)
metrics_history["val_acc"].append(0.55 + epoch * 0.05)
metrics_history["train_pnl"].append(epoch * 0.1)
metrics_history["val_pnl"].append(epoch * 0.08)
metrics_history["train_win_rate"].append(0.5 + epoch * 0.1)
metrics_history["val_win_rate"].append(0.45 + epoch * 0.1)
metrics_history["signal_distribution"].append({
"BUY": 0.3, "SELL": 0.3, "HOLD": 0.4
})
# Verify structure
self.assertEqual(len(metrics_history["epoch"]), 3)
self.assertEqual(len(metrics_history["train_loss"]), 3)
self.assertEqual(len(metrics_history["signal_distribution"]), 3)
# Verify improvement
self.assertLess(metrics_history["train_loss"][-1], metrics_history["train_loss"][0])
self.assertGreater(metrics_history["train_acc"][-1], metrics_history["train_acc"][0])
logger.info("✅ Metrics tracking test passed")
def test_signal_distribution_calculation(self):
"""Test signal distribution calculation"""
import numpy as np
# Mock predictions (SELL=0, HOLD=1, BUY=2)
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0])
buy_count = np.sum(predictions == 2)
sell_count = np.sum(predictions == 0)
hold_count = np.sum(predictions == 1)
total = len(predictions)
distribution = {
"BUY": buy_count / total,
"SELL": sell_count / total,
"HOLD": hold_count / total
}
# Verify calculations
self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)
logger.info("✅ Signal distribution calculation test passed")
class TestIntegration(unittest.TestCase):
"""Integration tests between components"""
def test_training_pipeline_integration(self):
"""Test that CNN and RL training can work together"""
logger.info("Testing training pipeline integration...")
with tempfile.TemporaryDirectory() as temp_dir:
try:
# Quick CNN training
config = get_config()
config._config['timeframes'] = ['1m']
cnn_trainer = CNNTrainer(config)
cnn_trainer.epochs = 1
cnn_trainer.batch_size = 8
cnn_path = os.path.join(temp_dir, 'test_cnn.pt')
cnn_results = cnn_trainer.train(['ETH/USDT'], save_path=cnn_path, num_samples=50)
# Quick RL training
data_provider = DataProvider(['ETH/USDT'], ['1m'])
rl_trainer = RLTrainer(data_provider)
rl_trainer.num_episodes = 3
rl_trainer.max_steps_per_episode = 25
rl_path = os.path.join(temp_dir, 'test_rl.pt')
rl_results = rl_trainer.train(save_path=rl_path)
# Verify both trained successfully
self.assertIsInstance(cnn_results, dict)
self.assertIsInstance(rl_results, dict)
self.assertTrue(os.path.exists(cnn_path))
self.assertTrue(os.path.exists(rl_path))
logger.info("✅ Training pipeline integration test passed")
except Exception as e:
logger.error(f"Integration test failed: {e}")
raise
finally:
if 'cnn_trainer' in locals():
cnn_trainer.close_tensorboard()
def run_quick_tests():
"""Run only the quickest tests for fast validation"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining),
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
def run_data_tests():
"""Run data provider tests"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestDataProviders),
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
def run_training_tests():
"""Run training tests (slower)"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining),
unittest.TestLoader().loadTestsFromTestCase(TestRLTraining),
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
def run_all_tests():
"""Run all test suites"""
test_suites = [
unittest.TestLoader().loadTestsFromTestCase(TestDataProviders),
unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining),
unittest.TestLoader().loadTestsFromTestCase(TestRLTraining),
unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining),
unittest.TestLoader().loadTestsFromTestCase(TestIntegration),
]
combined_suite = unittest.TestSuite(test_suites)
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(combined_suite)
return result.wasSuccessful()
if __name__ == "__main__":
setup_logging()
logger.info("Running comprehensive training integration tests...")
if len(sys.argv) > 1:
test_type = sys.argv[1]
if test_type == "quick":
success = run_quick_tests()
elif test_type == "data":
success = run_data_tests()
elif test_type == "training":
success = run_training_tests()
else:
success = run_all_tests()
else:
success = run_all_tests()
if success:
logger.info("✅ All tests passed!")
sys.exit(0)
else:
logger.error("❌ Some tests failed!")
sys.exit(1)
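The quick-training overrides used in TestCNNTraining and TestIntegration are applied by hand in each test; a hypothetical helper (not in this commit) collecting them could look like:

from core.config import get_config
from training.cnn_trainer import CNNTrainer

def make_quick_cnn_trainer(timeframes=('1m',), epochs=1, batch_size=8):
    """Hypothetical test helper; reaches into config._config the same way the tests above do."""
    config = get_config()
    config._config['timeframes'] = list(timeframes)
    trainer = CNNTrainer(config)
    trainer.epochs = epochs
    trainer.batch_size = batch_size
    return trainer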


@@ -1,128 +0,0 @@
import asyncio
import json
import logging
import unittest
from typing import Optional, Dict
import websockets
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
class TestMEXCWebSocket(unittest.TestCase):
async def test_websocket_connection(self):
"""Test basic WebSocket connection and subscription"""
uri = "wss://stream.mexc.com/ws"
symbol = "ethusdt"
async with websockets.connect(uri) as ws:
# Test subscription to deals
sub_msg = {
"op": "sub",
"id": "test1",
"topic": f"spot.deals.{symbol}"
}
# Send subscription
await ws.send(json.dumps(sub_msg))
# Wait for subscription confirmation and first message
messages_received = 0
trades_received = 0
while messages_received < 5: # Get at least 5 messages
try:
response = await asyncio.wait_for(ws.recv(), timeout=10)
messages_received += 1
logger.info(f"Received message: {response[:200]}...")
data = json.loads(response)
# Check message structure
if isinstance(data, dict):
if 'channel' in data:
if data['channel'] == 'spot.deals':
trades = data.get('data', [])
if trades:
trades_received += 1
logger.info(f"Received trade data: {trades[0]}")
# Verify trade data structure
trade = trades[0]
self.assertIn('t', trade) # timestamp
self.assertIn('p', trade) # price
self.assertIn('v', trade) # volume
self.assertIn('S', trade) # side
except asyncio.TimeoutError:
self.fail("Timeout waiting for WebSocket messages")
# Verify we received some trades
self.assertGreater(trades_received, 0, "No trades received")
# Test unsubscribe
unsub_msg = {
"op": "unsub",
"id": "test1",
"topic": f"spot.deals.{symbol}"
}
await ws.send(json.dumps(unsub_msg))
async def test_kline_subscription(self):
"""Test subscription to kline (candlestick) data"""
uri = "wss://stream.mexc.com/ws"
symbol = "ethusdt"
async with websockets.connect(uri) as ws:
# Subscribe to 1m klines
sub_msg = {
"op": "sub",
"id": "test2",
"topic": f"spot.klines.{symbol}_1m"
}
await ws.send(json.dumps(sub_msg))
messages_received = 0
klines_received = 0
while messages_received < 5:
try:
response = await asyncio.wait_for(ws.recv(), timeout=10)
messages_received += 1
logger.info(f"Received kline message: {response[:200]}...")
data = json.loads(response)
if isinstance(data, dict):
if data.get('channel') == 'spot.klines':
kline_data = data.get('data', [])
if kline_data:
klines_received += 1
logger.info(f"Received kline data: {kline_data[0]}")
# Verify kline data structure (should be an array)
kline = kline_data[0]
self.assertEqual(len(kline), 6) # Should have 6 elements
except asyncio.TimeoutError:
self.fail("Timeout waiting for kline data")
self.assertGreater(klines_received, 0, "No kline data received")
def run_tests():
"""Run the WebSocket tests"""
async def run_async_tests():
# Create test suite
suite = unittest.TestSuite()
suite.addTest(TestMEXCWebSocket('test_websocket_connection'))
suite.addTest(TestMEXCWebSocket('test_kline_subscription'))
# Run tests
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite)
# Run tests in asyncio loop
asyncio.run(run_async_tests())
if __name__ == "__main__":
run_tests()
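For reference, the async test methods in the deleted file would never actually be awaited under a plain unittest.TestCase runner. If these WebSocket checks are ever reinstated, a minimal sketch using unittest.IsolatedAsyncioTestCase (Python 3.8+) is shown below; the endpoint and topic are carried over from the deleted code and are not re-verified here:

import asyncio
import json
import unittest

import websockets

class MEXCDealsSmokeTest(unittest.IsolatedAsyncioTestCase):
    """Sketch only, not part of this commit."""

    async def test_deals_subscription(self):
        uri = "wss://stream.mexc.com/ws"  # endpoint from the deleted test
        sub_msg = {"op": "sub", "id": "test1", "topic": "spot.deals.ethusdt"}
        async with websockets.connect(uri) as ws:
            await ws.send(json.dumps(sub_msg))
            # Receiving any message within 10s is enough for a smoke test
            response = await asyncio.wait_for(ws.recv(), timeout=10)
            self.assertTrue(response)

if __name__ == "__main__":
    unittest.main()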