massive clenup
This commit is contained in:
115
tests/test_essential.py
Normal file
115
tests/test_essential.py
Normal file
@ -0,0 +1,115 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Essential Test Suite - Core functionality tests
|
||||
|
||||
This file contains the most important tests to verify core functionality:
|
||||
- Data loading and processing
|
||||
- Basic model operations
|
||||
- Trading signal generation
|
||||
- Critical utility functions
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import unittest
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TestEssentialFunctionality(unittest.TestCase):
|
||||
"""Essential tests for core trading system functionality"""
|
||||
|
||||
def test_imports(self):
|
||||
"""Test that all critical modules can be imported"""
|
||||
try:
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from utils.model_utils import robust_save, robust_load
|
||||
logger.info("✅ All critical imports successful")
|
||||
except ImportError as e:
|
||||
self.fail(f"Critical import failed: {e}")
|
||||
|
||||
def test_config_loading(self):
|
||||
"""Test configuration loading"""
|
||||
try:
|
||||
from core.config import get_config
|
||||
config = get_config()
|
||||
self.assertIsNotNone(config, "Config should be loaded")
|
||||
logger.info("✅ Configuration loading successful")
|
||||
except Exception as e:
|
||||
self.fail(f"Config loading failed: {e}")
|
||||
|
||||
def test_data_provider_initialization(self):
|
||||
"""Test DataProvider can be initialized"""
|
||||
try:
|
||||
from core.data_provider import DataProvider
|
||||
data_provider = DataProvider(['ETH/USDT'], ['1m'])
|
||||
self.assertIsNotNone(data_provider, "DataProvider should initialize")
|
||||
logger.info("✅ DataProvider initialization successful")
|
||||
except Exception as e:
|
||||
self.fail(f"DataProvider initialization failed: {e}")
|
||||
|
||||
def test_model_utils(self):
|
||||
"""Test model utility functions"""
|
||||
try:
|
||||
from utils.model_utils import get_model_info
|
||||
import tempfile
|
||||
|
||||
# Test with non-existent file
|
||||
info = get_model_info("non_existent_file.pt")
|
||||
self.assertFalse(info['exists'], "Should report file doesn't exist")
|
||||
|
||||
logger.info("✅ Model utils test successful")
|
||||
except Exception as e:
|
||||
self.fail(f"Model utils test failed: {e}")
|
||||
|
||||
def test_signal_generation_logic(self):
|
||||
"""Test basic signal generation logic"""
|
||||
import numpy as np
|
||||
|
||||
# Test signal distribution calculation
|
||||
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0]) # SELL, HOLD, BUY
|
||||
|
||||
buy_count = np.sum(predictions == 2)
|
||||
sell_count = np.sum(predictions == 0)
|
||||
hold_count = np.sum(predictions == 1)
|
||||
total = len(predictions)
|
||||
|
||||
distribution = {
|
||||
"BUY": buy_count / total,
|
||||
"SELL": sell_count / total,
|
||||
"HOLD": hold_count / total
|
||||
}
|
||||
|
||||
# Verify calculations
|
||||
self.assertAlmostEqual(distribution["BUY"], 0.3, places=1)
|
||||
self.assertAlmostEqual(distribution["SELL"], 0.3, places=1)
|
||||
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=1)
|
||||
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=1)
|
||||
|
||||
logger.info("✅ Signal generation logic test successful")
|
||||
|
||||
def run_essential_tests():
|
||||
"""Run essential tests only"""
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestEssentialFunctionality)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(suite)
|
||||
return result.wasSuccessful()
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger.info("Running essential functionality tests...")
|
||||
|
||||
success = run_essential_tests()
|
||||
|
||||
if success:
|
||||
logger.info("✅ All essential tests passed!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("❌ Essential tests failed!")
|
||||
sys.exit(1)
|
402
tests/test_indicators_and_signals.py
Normal file
402
tests/test_indicators_and_signals.py
Normal file
@ -0,0 +1,402 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive Indicators and Signals Test Suite
|
||||
|
||||
This module consolidates testing functionality for:
|
||||
- Technical indicators (from test_indicators.py)
|
||||
- Signal interpretation and processing (from test_signal_interpreter.py)
|
||||
- Market data analysis
|
||||
- Trading signal validation
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import unittest
|
||||
import logging
|
||||
import numpy as np
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import setup_logging
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TestTechnicalIndicators(unittest.TestCase):
|
||||
"""Test suite for technical indicators functionality"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
setup_logging()
|
||||
self.data_provider = DataProvider(['ETH/USDT'], ['1h'])
|
||||
|
||||
def test_indicator_calculation(self):
|
||||
"""Test that indicators are calculated correctly"""
|
||||
logger.info("Testing technical indicators calculation...")
|
||||
|
||||
try:
|
||||
# Fetch data with indicators
|
||||
df = self.data_provider.get_historical_data('ETH/USDT', '1h', refresh=True, limit=100)
|
||||
|
||||
self.assertIsNotNone(df, "Should fetch data successfully")
|
||||
self.assertGreater(len(df), 0, "Should have data rows")
|
||||
|
||||
# Check basic OHLCV columns
|
||||
basic_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
|
||||
for col in basic_cols:
|
||||
self.assertIn(col, df.columns, f"Should have {col} column")
|
||||
|
||||
# Check that indicators are calculated
|
||||
indicator_cols = [col for col in df.columns if col not in basic_cols]
|
||||
self.assertGreater(len(indicator_cols), 0, "Should have technical indicators")
|
||||
|
||||
logger.info(f"✅ Successfully calculated {len(indicator_cols)} indicators")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Indicator test failed: {e}")
|
||||
self.skipTest("Data or indicators not available")
|
||||
|
||||
def test_indicator_categorization(self):
|
||||
"""Test categorization of different indicator types"""
|
||||
logger.info("Testing indicator categorization...")
|
||||
|
||||
try:
|
||||
df = self.data_provider.get_historical_data('ETH/USDT', '1h', refresh=True, limit=100)
|
||||
|
||||
if df is not None:
|
||||
basic_cols = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
|
||||
indicator_cols = [col for col in df.columns if col not in basic_cols]
|
||||
|
||||
# Categorize indicators
|
||||
trend_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['sma', 'ema', 'macd', 'adx', 'psar'])]
|
||||
momentum_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['rsi', 'stoch', 'williams', 'cci'])]
|
||||
volatility_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['bb_', 'atr', 'keltner'])]
|
||||
volume_indicators = [col for col in indicator_cols if any(x in col.lower() for x in ['volume', 'obv', 'vpt', 'mfi', 'ad_line', 'vwap'])]
|
||||
|
||||
# Check we have indicators in each category
|
||||
total_categorized = len(trend_indicators) + len(momentum_indicators) + len(volatility_indicators) + len(volume_indicators)
|
||||
|
||||
logger.info(f"Indicator categories:")
|
||||
logger.info(f" Trend: {len(trend_indicators)}")
|
||||
logger.info(f" Momentum: {len(momentum_indicators)}")
|
||||
logger.info(f" Volatility: {len(volatility_indicators)}")
|
||||
logger.info(f" Volume: {len(volume_indicators)}")
|
||||
logger.info(f" Total categorized: {total_categorized}/{len(indicator_cols)}")
|
||||
|
||||
self.assertGreater(total_categorized, 0, "Should have categorized indicators")
|
||||
|
||||
else:
|
||||
self.skipTest("Could not fetch data for categorization test")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Categorization test failed: {e}")
|
||||
self.skipTest("Indicator categorization not available")
|
||||
|
||||
def test_feature_matrix_creation(self):
|
||||
"""Test multi-timeframe feature matrix creation"""
|
||||
logger.info("Testing feature matrix creation...")
|
||||
|
||||
try:
|
||||
# Test feature matrix with multiple timeframes
|
||||
feature_matrix = self.data_provider.get_feature_matrix('ETH/USDT', ['1h'], window_size=20)
|
||||
|
||||
if feature_matrix is not None:
|
||||
self.assertEqual(len(feature_matrix.shape), 3, "Should be 3D matrix")
|
||||
self.assertGreater(feature_matrix.shape[2], 0, "Should have features")
|
||||
|
||||
logger.info(f"✅ Feature matrix shape: {feature_matrix.shape}")
|
||||
|
||||
else:
|
||||
self.skipTest("Could not create feature matrix")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Feature matrix test failed: {e}")
|
||||
self.skipTest("Feature matrix creation not available")
|
||||
|
||||
class TestSignalProcessing(unittest.TestCase):
|
||||
"""Test suite for signal interpretation and processing"""
|
||||
|
||||
def test_signal_distribution_calculation(self):
|
||||
"""Test signal distribution calculation"""
|
||||
logger.info("Testing signal distribution calculation...")
|
||||
|
||||
# Mock predictions (SELL=0, HOLD=1, BUY=2)
|
||||
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0])
|
||||
|
||||
buy_count = np.sum(predictions == 2)
|
||||
sell_count = np.sum(predictions == 0)
|
||||
hold_count = np.sum(predictions == 1)
|
||||
total = len(predictions)
|
||||
|
||||
distribution = {
|
||||
"BUY": buy_count / total,
|
||||
"SELL": sell_count / total,
|
||||
"HOLD": hold_count / total
|
||||
}
|
||||
|
||||
# Verify calculations
|
||||
self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
|
||||
self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
|
||||
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
|
||||
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)
|
||||
|
||||
logger.info("✅ Signal distribution calculation test passed")
|
||||
|
||||
def test_basic_signal_interpretation(self):
|
||||
"""Test basic signal interpretation logic"""
|
||||
logger.info("Testing basic signal interpretation...")
|
||||
|
||||
# Test cases with different probability distributions
|
||||
test_cases = [
|
||||
{
|
||||
'probs': [0.8, 0.1, 0.1], # Strong SELL
|
||||
'expected_action': 'SELL',
|
||||
'expected_confidence': 'high'
|
||||
},
|
||||
{
|
||||
'probs': [0.1, 0.1, 0.8], # Strong BUY
|
||||
'expected_action': 'BUY',
|
||||
'expected_confidence': 'high'
|
||||
},
|
||||
{
|
||||
'probs': [0.1, 0.8, 0.1], # Strong HOLD
|
||||
'expected_action': 'HOLD',
|
||||
'expected_confidence': 'high'
|
||||
},
|
||||
{
|
||||
'probs': [0.4, 0.3, 0.3], # Uncertain - should prefer SELL (index 0)
|
||||
'expected_action': 'SELL',
|
||||
'expected_confidence': 'low'
|
||||
},
|
||||
{
|
||||
'probs': [0.33, 0.33, 0.34], # Very uncertain - slight BUY preference
|
||||
'expected_action': 'BUY',
|
||||
'expected_confidence': 'low'
|
||||
}
|
||||
]
|
||||
|
||||
for i, test_case in enumerate(test_cases):
|
||||
probs = np.array(test_case['probs'])
|
||||
expected_action = test_case['expected_action']
|
||||
|
||||
# Simple signal interpretation (argmax)
|
||||
predicted_action_idx = np.argmax(probs)
|
||||
action_map = {0: 'SELL', 1: 'HOLD', 2: 'BUY'}
|
||||
predicted_action = action_map[predicted_action_idx]
|
||||
|
||||
# Calculate confidence (max probability)
|
||||
confidence = np.max(probs)
|
||||
confidence_level = 'high' if confidence > 0.7 else 'medium' if confidence > 0.5 else 'low'
|
||||
|
||||
# Verify predictions
|
||||
self.assertEqual(predicted_action, expected_action,
|
||||
f"Test case {i+1}: Expected {expected_action}, got {predicted_action}")
|
||||
|
||||
logger.info(f"Test case {i+1}: {probs} -> {predicted_action} ({confidence_level} confidence)")
|
||||
|
||||
logger.info("✅ Basic signal interpretation test passed")
|
||||
|
||||
def test_signal_filtering_logic(self):
|
||||
"""Test signal filtering and validation logic"""
|
||||
logger.info("Testing signal filtering logic...")
|
||||
|
||||
# Test threshold-based filtering
|
||||
buy_threshold = 0.6
|
||||
sell_threshold = 0.6
|
||||
hold_threshold = 0.7
|
||||
|
||||
test_signals = [
|
||||
{
|
||||
'probs': [0.8, 0.1, 0.1], # Strong SELL (above threshold)
|
||||
'should_pass': True,
|
||||
'expected': 'SELL'
|
||||
},
|
||||
{
|
||||
'probs': [0.5, 0.3, 0.2], # Weak SELL (below threshold)
|
||||
'should_pass': False,
|
||||
'expected': 'HOLD'
|
||||
},
|
||||
{
|
||||
'probs': [0.1, 0.2, 0.7], # Strong BUY (above threshold)
|
||||
'should_pass': True,
|
||||
'expected': 'BUY'
|
||||
},
|
||||
{
|
||||
'probs': [0.2, 0.8, 0.0], # Strong HOLD (above threshold)
|
||||
'should_pass': True,
|
||||
'expected': 'HOLD'
|
||||
}
|
||||
]
|
||||
|
||||
for i, test in enumerate(test_signals):
|
||||
probs = np.array(test['probs'])
|
||||
sell_prob, hold_prob, buy_prob = probs
|
||||
|
||||
# Apply threshold filtering
|
||||
if sell_prob >= sell_threshold:
|
||||
filtered_action = 'SELL'
|
||||
passed_filter = True
|
||||
elif buy_prob >= buy_threshold:
|
||||
filtered_action = 'BUY'
|
||||
passed_filter = True
|
||||
elif hold_prob >= hold_threshold:
|
||||
filtered_action = 'HOLD'
|
||||
passed_filter = True
|
||||
else:
|
||||
filtered_action = 'HOLD' # Default to HOLD if no threshold met
|
||||
passed_filter = False
|
||||
|
||||
# Verify filtering
|
||||
expected_pass = test['should_pass']
|
||||
expected_action = test['expected']
|
||||
|
||||
self.assertEqual(passed_filter, expected_pass,
|
||||
f"Test {i+1}: Filter pass expectation failed")
|
||||
self.assertEqual(filtered_action, expected_action,
|
||||
f"Test {i+1}: Expected {expected_action}, got {filtered_action}")
|
||||
|
||||
logger.info(f"Test {i+1}: {probs} -> {filtered_action} (passed: {passed_filter})")
|
||||
|
||||
logger.info("✅ Signal filtering logic test passed")
|
||||
|
||||
def test_signal_sequence_validation(self):
|
||||
"""Test signal sequence validation and oscillation prevention"""
|
||||
logger.info("Testing signal sequence validation...")
|
||||
|
||||
# Simulate a sequence of signals that might oscillate
|
||||
signal_sequence = ['BUY', 'SELL', 'BUY', 'SELL', 'HOLD', 'BUY']
|
||||
|
||||
# Simple oscillation detection
|
||||
oscillation_count = 0
|
||||
for i in range(1, len(signal_sequence)):
|
||||
if (signal_sequence[i-1] == 'BUY' and signal_sequence[i] == 'SELL') or \
|
||||
(signal_sequence[i-1] == 'SELL' and signal_sequence[i] == 'BUY'):
|
||||
oscillation_count += 1
|
||||
|
||||
# Count consecutive non-HOLD signals
|
||||
consecutive_trades = 0
|
||||
max_consecutive = 0
|
||||
for signal in signal_sequence:
|
||||
if signal != 'HOLD':
|
||||
consecutive_trades += 1
|
||||
max_consecutive = max(max_consecutive, consecutive_trades)
|
||||
else:
|
||||
consecutive_trades = 0
|
||||
|
||||
# Verify oscillation detection
|
||||
self.assertGreater(oscillation_count, 0, "Should detect oscillations in test sequence")
|
||||
self.assertGreater(max_consecutive, 1, "Should detect consecutive trades")
|
||||
|
||||
logger.info(f"Detected {oscillation_count} oscillations and max {max_consecutive} consecutive trades")
|
||||
logger.info("✅ Signal sequence validation test passed")
|
||||
|
||||
class TestMarketDataAnalysis(unittest.TestCase):
|
||||
"""Test suite for market data analysis functionality"""
|
||||
|
||||
def test_price_movement_calculation(self):
|
||||
"""Test price movement and trend calculation"""
|
||||
logger.info("Testing price movement calculation...")
|
||||
|
||||
# Mock price data
|
||||
prices = np.array([100.0, 101.0, 102.5, 101.8, 103.2, 102.9, 104.1])
|
||||
|
||||
# Calculate price movements
|
||||
price_changes = np.diff(prices)
|
||||
percentage_changes = (price_changes / prices[:-1]) * 100
|
||||
|
||||
# Calculate simple trend
|
||||
recent_trend = np.mean(percentage_changes[-3:]) # Last 3 changes
|
||||
trend_direction = 'uptrend' if recent_trend > 0.1 else 'downtrend' if recent_trend < -0.1 else 'sideways'
|
||||
|
||||
# Verify calculations
|
||||
self.assertEqual(len(price_changes), len(prices) - 1, "Should have n-1 price changes")
|
||||
self.assertEqual(len(percentage_changes), len(prices) - 1, "Should have n-1 percentage changes")
|
||||
|
||||
# Verify trend detection makes sense
|
||||
self.assertIn(trend_direction, ['uptrend', 'downtrend', 'sideways'], "Should detect valid trend")
|
||||
|
||||
logger.info(f"Price sequence: {prices}")
|
||||
logger.info(f"Recent trend: {trend_direction} ({recent_trend:.2f}%)")
|
||||
logger.info("✅ Price movement calculation test passed")
|
||||
|
||||
def test_volatility_measurement(self):
|
||||
"""Test volatility measurement"""
|
||||
logger.info("Testing volatility measurement...")
|
||||
|
||||
# Mock price data with different volatility
|
||||
stable_prices = np.array([100.0, 100.1, 99.9, 100.2, 99.8, 100.0])
|
||||
volatile_prices = np.array([100.0, 105.0, 95.0, 110.0, 90.0, 115.0])
|
||||
|
||||
# Calculate volatility (standard deviation of returns)
|
||||
def calculate_volatility(prices):
|
||||
returns = np.diff(prices) / prices[:-1]
|
||||
return np.std(returns) * 100 # As percentage
|
||||
|
||||
stable_vol = calculate_volatility(stable_prices)
|
||||
volatile_vol = calculate_volatility(volatile_prices)
|
||||
|
||||
# Verify volatility measurements
|
||||
self.assertLess(stable_vol, volatile_vol, "Stable prices should have lower volatility")
|
||||
self.assertGreater(volatile_vol, 5.0, "Volatile prices should have significant volatility")
|
||||
|
||||
logger.info(f"Stable volatility: {stable_vol:.2f}%")
|
||||
logger.info(f"Volatile volatility: {volatile_vol:.2f}%")
|
||||
logger.info("✅ Volatility measurement test passed")
|
||||
|
||||
def run_indicator_tests():
|
||||
"""Run indicator tests only"""
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestTechnicalIndicators)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(suite)
|
||||
return result.wasSuccessful()
|
||||
|
||||
def run_signal_tests():
|
||||
"""Run signal processing tests only"""
|
||||
test_suites = [
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestSignalProcessing),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestMarketDataAnalysis),
|
||||
]
|
||||
|
||||
combined_suite = unittest.TestSuite(test_suites)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(combined_suite)
|
||||
return result.wasSuccessful()
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all indicator and signal tests"""
|
||||
test_suites = [
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestTechnicalIndicators),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestSignalProcessing),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestMarketDataAnalysis),
|
||||
]
|
||||
|
||||
combined_suite = unittest.TestSuite(test_suites)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(combined_suite)
|
||||
return result.wasSuccessful()
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup_logging()
|
||||
logger.info("Running indicators and signals test suite...")
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
test_type = sys.argv[1]
|
||||
if test_type == "indicators":
|
||||
success = run_indicator_tests()
|
||||
elif test_type == "signals":
|
||||
success = run_signal_tests()
|
||||
else:
|
||||
success = run_all_tests()
|
||||
else:
|
||||
success = run_all_tests()
|
||||
|
||||
if success:
|
||||
logger.info("✅ All indicator and signal tests passed!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("❌ Some tests failed!")
|
||||
sys.exit(1)
|
274
tests/test_model_persistence.py
Normal file
274
tests/test_model_persistence.py
Normal file
@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Comprehensive test suite for model persistence and training functionality
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
import tempfile
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from utils.model_utils import robust_save, robust_load, get_model_info, verify_save_load_cycle
|
||||
|
||||
# Configure logging for tests
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MockAgent:
|
||||
"""Mock agent class for testing model persistence"""
|
||||
|
||||
def __init__(self, state_size=64, action_size=4, hidden_size=256):
|
||||
self.state_size = state_size
|
||||
self.action_size = action_size
|
||||
self.hidden_size = hidden_size
|
||||
self.epsilon = 0.1
|
||||
|
||||
# Create simple mock networks
|
||||
self.policy_net = torch.nn.Sequential(
|
||||
torch.nn.Linear(state_size, hidden_size),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(hidden_size, action_size)
|
||||
)
|
||||
|
||||
self.target_net = torch.nn.Sequential(
|
||||
torch.nn.Linear(state_size, hidden_size),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(hidden_size, action_size)
|
||||
)
|
||||
|
||||
self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001)
|
||||
|
||||
class TestModelPersistence(unittest.TestCase):
|
||||
"""Test suite for model saving and loading functionality"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.test_agent = MockAgent()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures"""
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_robust_save_basic(self):
|
||||
"""Test basic robust save functionality"""
|
||||
save_path = os.path.join(self.temp_dir, "test_model.pt")
|
||||
|
||||
success = robust_save(self.test_agent, save_path)
|
||||
self.assertTrue(success, "Robust save should succeed")
|
||||
self.assertTrue(os.path.exists(save_path), "Model file should exist")
|
||||
self.assertGreater(os.path.getsize(save_path), 0, "Model file should not be empty")
|
||||
|
||||
def test_robust_save_without_optimizer(self):
|
||||
"""Test robust save without optimizer state"""
|
||||
save_path = os.path.join(self.temp_dir, "test_model_no_opt.pt")
|
||||
|
||||
success = robust_save(self.test_agent, save_path, include_optimizer=False)
|
||||
self.assertTrue(success, "Robust save without optimizer should succeed")
|
||||
|
||||
# Verify that optimizer state is not included
|
||||
checkpoint = torch.load(save_path, map_location='cpu')
|
||||
self.assertNotIn('optimizer', checkpoint, "Optimizer state should not be saved")
|
||||
self.assertIn('policy_net', checkpoint, "Policy network should be saved")
|
||||
|
||||
def test_robust_load_basic(self):
|
||||
"""Test basic robust load functionality"""
|
||||
save_path = os.path.join(self.temp_dir, "test_model.pt")
|
||||
|
||||
# Save first
|
||||
success = robust_save(self.test_agent, save_path)
|
||||
self.assertTrue(success, "Save should succeed")
|
||||
|
||||
# Create new agent and load
|
||||
new_agent = MockAgent()
|
||||
success = robust_load(new_agent, save_path)
|
||||
self.assertTrue(success, "Load should succeed")
|
||||
|
||||
# Verify epsilon was loaded
|
||||
self.assertEqual(new_agent.epsilon, self.test_agent.epsilon, "Epsilon should match")
|
||||
|
||||
def test_get_model_info(self):
|
||||
"""Test model info extraction"""
|
||||
save_path = os.path.join(self.temp_dir, "test_model.pt")
|
||||
|
||||
# Test non-existent file
|
||||
info = get_model_info(save_path)
|
||||
self.assertFalse(info['exists'], "Non-existent file should return exists=False")
|
||||
|
||||
# Save model and test info
|
||||
robust_save(self.test_agent, save_path)
|
||||
info = get_model_info(save_path)
|
||||
|
||||
self.assertTrue(info['exists'], "Existing file should return exists=True")
|
||||
self.assertGreater(info['size_bytes'], 0, "File size should be greater than 0")
|
||||
self.assertTrue(info['has_optimizer'], "Should detect optimizer in checkpoint")
|
||||
self.assertEqual(info['parameters']['state_size'], self.test_agent.state_size)
|
||||
self.assertEqual(info['parameters']['action_size'], self.test_agent.action_size)
|
||||
|
||||
def test_save_load_cycle_verification(self):
|
||||
"""Test save/load cycle verification"""
|
||||
test_path = os.path.join(self.temp_dir, "cycle_test.pt")
|
||||
|
||||
success = verify_save_load_cycle(self.test_agent, test_path)
|
||||
self.assertTrue(success, "Save/load cycle should succeed")
|
||||
|
||||
# File should be cleaned up after verification
|
||||
self.assertFalse(os.path.exists(test_path), "Test file should be cleaned up")
|
||||
|
||||
def test_multiple_save_methods(self):
|
||||
"""Test that different save methods all work"""
|
||||
methods = ['regular', 'no_optimizer', 'pickle2']
|
||||
|
||||
for method in methods:
|
||||
with self.subTest(method=method):
|
||||
save_path = os.path.join(self.temp_dir, f"test_{method}.pt")
|
||||
|
||||
if method == 'regular':
|
||||
success = robust_save(self.test_agent, save_path)
|
||||
elif method == 'no_optimizer':
|
||||
success = robust_save(self.test_agent, save_path, include_optimizer=False)
|
||||
elif method == 'pickle2':
|
||||
# This would be tested by the robust_save fallback mechanism
|
||||
success = robust_save(self.test_agent, save_path)
|
||||
|
||||
self.assertTrue(success, f"{method} save should succeed")
|
||||
self.assertTrue(os.path.exists(save_path), f"{method} save should create file")
|
||||
|
||||
class TestTrainingMetrics(unittest.TestCase):
|
||||
"""Test suite for training metrics and monitoring functionality"""
|
||||
|
||||
def test_signal_distribution_calculation(self):
|
||||
"""Test signal distribution calculation"""
|
||||
# Mock predictions
|
||||
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0]) # SELL, HOLD, BUY
|
||||
|
||||
buy_count = np.sum(predictions == 2)
|
||||
sell_count = np.sum(predictions == 0)
|
||||
hold_count = np.sum(predictions == 1)
|
||||
total = len(predictions)
|
||||
|
||||
distribution = {
|
||||
"BUY": buy_count / total,
|
||||
"SELL": sell_count / total,
|
||||
"HOLD": hold_count / total
|
||||
}
|
||||
|
||||
self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
|
||||
self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
|
||||
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
|
||||
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)
|
||||
|
||||
def test_metrics_tracking_structure(self):
|
||||
"""Test metrics history structure for training monitoring"""
|
||||
metrics_history = {
|
||||
"epoch": [],
|
||||
"train_loss": [],
|
||||
"val_loss": [],
|
||||
"train_acc": [],
|
||||
"val_acc": [],
|
||||
"train_pnl": [],
|
||||
"val_pnl": [],
|
||||
"train_win_rate": [],
|
||||
"val_win_rate": [],
|
||||
"signal_distribution": []
|
||||
}
|
||||
|
||||
# Simulate adding metrics for one epoch
|
||||
metrics_history["epoch"].append(1)
|
||||
metrics_history["train_loss"].append(0.5)
|
||||
metrics_history["val_loss"].append(0.6)
|
||||
metrics_history["train_acc"].append(0.7)
|
||||
metrics_history["val_acc"].append(0.65)
|
||||
metrics_history["train_pnl"].append(0.1)
|
||||
metrics_history["val_pnl"].append(0.08)
|
||||
metrics_history["train_win_rate"].append(0.6)
|
||||
metrics_history["val_win_rate"].append(0.55)
|
||||
metrics_history["signal_distribution"].append({"BUY": 0.3, "SELL": 0.3, "HOLD": 0.4})
|
||||
|
||||
# Verify structure
|
||||
self.assertEqual(len(metrics_history["epoch"]), 1)
|
||||
self.assertEqual(metrics_history["epoch"][0], 1)
|
||||
self.assertIsInstance(metrics_history["signal_distribution"][0], dict)
|
||||
self.assertIn("BUY", metrics_history["signal_distribution"][0])
|
||||
|
||||
class TestModelArchitecture(unittest.TestCase):
|
||||
"""Test suite for model architecture verification"""
|
||||
|
||||
def test_model_parameter_consistency(self):
|
||||
"""Test that model parameters are consistent after save/load"""
|
||||
agent = MockAgent(state_size=32, action_size=3, hidden_size=128)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
save_path = os.path.join(temp_dir, "consistency_test.pt")
|
||||
|
||||
# Save model
|
||||
robust_save(agent, save_path)
|
||||
|
||||
# Load into new model with same architecture
|
||||
new_agent = MockAgent(state_size=32, action_size=3, hidden_size=128)
|
||||
robust_load(new_agent, save_path)
|
||||
|
||||
# Verify parameters match
|
||||
self.assertEqual(new_agent.state_size, agent.state_size)
|
||||
self.assertEqual(new_agent.action_size, agent.action_size)
|
||||
self.assertEqual(new_agent.hidden_size, agent.hidden_size)
|
||||
self.assertEqual(new_agent.epsilon, agent.epsilon)
|
||||
|
||||
def test_model_forward_pass(self):
|
||||
"""Test that model can perform forward pass after load"""
|
||||
agent = MockAgent()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
save_path = os.path.join(temp_dir, "forward_test.pt")
|
||||
|
||||
# Create test input
|
||||
test_input = torch.randn(1, agent.state_size)
|
||||
|
||||
# Get original output
|
||||
original_output = agent.policy_net(test_input)
|
||||
|
||||
# Save and load
|
||||
robust_save(agent, save_path)
|
||||
new_agent = MockAgent()
|
||||
robust_load(new_agent, save_path)
|
||||
|
||||
# Test forward pass works
|
||||
new_output = new_agent.policy_net(test_input)
|
||||
|
||||
self.assertEqual(new_output.shape, original_output.shape)
|
||||
# Outputs should be identical since we loaded the same weights
|
||||
torch.testing.assert_close(new_output, original_output)
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all test suites"""
|
||||
test_suites = [
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestModelPersistence),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestTrainingMetrics),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestModelArchitecture)
|
||||
]
|
||||
|
||||
combined_suite = unittest.TestSuite(test_suites)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(combined_suite)
|
||||
|
||||
return result.wasSuccessful()
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Running comprehensive model persistence and training tests...")
|
||||
success = run_all_tests()
|
||||
|
||||
if success:
|
||||
logger.info("All tests passed!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("Some tests failed!")
|
||||
sys.exit(1)
|
395
tests/test_training_integration.py
Normal file
395
tests/test_training_integration.py
Normal file
@ -0,0 +1,395 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive Training Integration Tests
|
||||
|
||||
This module consolidates and improves test functionality from multiple test files:
|
||||
- CNN training tests (from test_cnn_only.py, test_training.py)
|
||||
- Model testing (from test_model.py)
|
||||
- Chart data testing (from test_chart_data.py)
|
||||
- Integration testing between components
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
import unittest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from core.config import setup_logging, get_config
|
||||
from core.data_provider import DataProvider
|
||||
from training.cnn_trainer import CNNTrainer
|
||||
from training.rl_trainer import RLTrainer
|
||||
from dataprovider_realtime import RealTimeChart, TickStorage, BinanceHistoricalData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TestDataProviders(unittest.TestCase):
|
||||
"""Test suite for data provider functionality"""
|
||||
|
||||
def test_binance_historical_data(self):
|
||||
"""Test Binance historical data fetching"""
|
||||
logger.info("Testing Binance historical data fetch...")
|
||||
|
||||
try:
|
||||
binance_data = BinanceHistoricalData()
|
||||
df = binance_data.get_historical_candles("ETH/USDT", 60, 100)
|
||||
|
||||
self.assertIsNotNone(df, "Should fetch data successfully")
|
||||
self.assertFalse(df.empty, "Data should not be empty")
|
||||
self.assertGreater(len(df), 0, "Should have candles")
|
||||
|
||||
# Verify data structure
|
||||
required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
|
||||
for col in required_columns:
|
||||
self.assertIn(col, df.columns, f"Should have {col} column")
|
||||
|
||||
logger.info(f"✅ Successfully fetched {len(df)} candles")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Binance API test failed: {e}")
|
||||
self.skipTest("Binance API not available")
|
||||
|
||||
def test_tick_storage(self):
|
||||
"""Test TickStorage functionality"""
|
||||
logger.info("Testing TickStorage data loading...")
|
||||
|
||||
try:
|
||||
tick_storage = TickStorage("ETH/USDT", ["1m", "5m", "1h"])
|
||||
success = tick_storage.load_historical_data("ETH/USDT", limit=100)
|
||||
|
||||
if success:
|
||||
# Check timeframes
|
||||
for tf in ["1m", "5m", "1h"]:
|
||||
candles = tick_storage.get_candles(tf)
|
||||
logger.info(f" {tf}: {len(candles)} candles")
|
||||
|
||||
logger.info("✅ TickStorage working correctly")
|
||||
return True
|
||||
else:
|
||||
self.skipTest("Could not load tick storage data")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"TickStorage test failed: {e}")
|
||||
self.skipTest("TickStorage not available")
|
||||
|
||||
def test_chart_initialization(self):
|
||||
"""Test RealTimeChart initialization"""
|
||||
logger.info("Testing RealTimeChart initialization...")
|
||||
|
||||
try:
|
||||
chart = RealTimeChart(app=None, symbol="ETH/USDT", standalone=False)
|
||||
|
||||
# Test getting candles
|
||||
candles_1m = chart.get_candles(60)
|
||||
|
||||
self.assertIsInstance(candles_1m, list, "Should return list of candles")
|
||||
logger.info(f"✅ Chart initialized with {len(candles_1m)} 1m candles")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Chart initialization failed: {e}")
|
||||
self.skipTest("Chart initialization not available")
|
||||
|
||||
class TestCNNTraining(unittest.TestCase):
|
||||
"""Test suite for CNN training functionality"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
setup_logging()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures"""
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_cnn_quick_training(self):
|
||||
"""Test quick CNN training with small dataset"""
|
||||
logger.info("Testing CNN quick training...")
|
||||
|
||||
try:
|
||||
config = get_config()
|
||||
|
||||
# Test configuration
|
||||
symbols = ['ETH/USDT']
|
||||
timeframes = ['1m', '5m']
|
||||
num_samples = 100 # Very small for testing
|
||||
epochs = 1
|
||||
batch_size = 16
|
||||
|
||||
# Override config for quick test
|
||||
config._config['timeframes'] = timeframes
|
||||
|
||||
trainer = CNNTrainer(config)
|
||||
trainer.batch_size = batch_size
|
||||
trainer.epochs = epochs
|
||||
|
||||
# Train model
|
||||
save_path = os.path.join(self.temp_dir, 'test_cnn.pt')
|
||||
results = trainer.train(symbols, save_path=save_path, num_samples=num_samples)
|
||||
|
||||
# Verify results
|
||||
self.assertIsInstance(results, dict, "Should return results dict")
|
||||
self.assertIn('best_val_accuracy', results, "Should have accuracy metric")
|
||||
self.assertIn('total_epochs', results, "Should have epoch count")
|
||||
self.assertIn('training_time', results, "Should have training time")
|
||||
|
||||
# Verify model was saved
|
||||
self.assertTrue(os.path.exists(save_path), "Model should be saved")
|
||||
|
||||
logger.info(f"✅ CNN training completed successfully")
|
||||
logger.info(f" Best accuracy: {results['best_val_accuracy']:.4f}")
|
||||
logger.info(f" Training time: {results['training_time']:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CNN training test failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
if hasattr(trainer, 'close_tensorboard'):
|
||||
trainer.close_tensorboard()
|
||||
|
||||
class TestRLTraining(unittest.TestCase):
|
||||
"""Test suite for RL training functionality"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
setup_logging()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures"""
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_rl_quick_training(self):
|
||||
"""Test quick RL training with small dataset"""
|
||||
logger.info("Testing RL quick training...")
|
||||
|
||||
try:
|
||||
# Setup minimal configuration
|
||||
data_provider = DataProvider(['ETH/USDT'], ['1m', '5m'])
|
||||
trainer = RLTrainer(data_provider)
|
||||
|
||||
# Configure for very quick test
|
||||
trainer.num_episodes = 5
|
||||
trainer.max_steps_per_episode = 50
|
||||
trainer.evaluation_frequency = 3
|
||||
trainer.save_frequency = 10 # Don't save during test
|
||||
|
||||
# Train
|
||||
save_path = os.path.join(self.temp_dir, 'test_rl.pt')
|
||||
results = trainer.train(save_path=save_path)
|
||||
|
||||
# Verify results
|
||||
self.assertIsInstance(results, dict, "Should return results dict")
|
||||
self.assertIn('total_episodes', results, "Should have episode count")
|
||||
self.assertIn('best_reward', results, "Should have best reward")
|
||||
self.assertIn('final_evaluation', results, "Should have final evaluation")
|
||||
|
||||
logger.info(f"✅ RL training completed successfully")
|
||||
logger.info(f" Total episodes: {results['total_episodes']}")
|
||||
logger.info(f" Best reward: {results['best_reward']:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RL training test failed: {e}")
|
||||
raise
|
||||
|
||||
class TestExtendedTraining(unittest.TestCase):
|
||||
"""Test suite for extended training functionality (from test_model.py)"""
|
||||
|
||||
def test_metrics_tracking(self):
|
||||
"""Test comprehensive metrics tracking functionality"""
|
||||
logger.info("Testing extended metrics tracking...")
|
||||
|
||||
# Test metrics history structure
|
||||
metrics_history = {
|
||||
"epoch": [],
|
||||
"train_loss": [],
|
||||
"val_loss": [],
|
||||
"train_acc": [],
|
||||
"val_acc": [],
|
||||
"train_pnl": [],
|
||||
"val_pnl": [],
|
||||
"train_win_rate": [],
|
||||
"val_win_rate": [],
|
||||
"signal_distribution": []
|
||||
}
|
||||
|
||||
# Simulate adding metrics
|
||||
for epoch in range(3):
|
||||
metrics_history["epoch"].append(epoch + 1)
|
||||
metrics_history["train_loss"].append(0.5 - epoch * 0.1)
|
||||
metrics_history["val_loss"].append(0.6 - epoch * 0.1)
|
||||
metrics_history["train_acc"].append(0.6 + epoch * 0.05)
|
||||
metrics_history["val_acc"].append(0.55 + epoch * 0.05)
|
||||
metrics_history["train_pnl"].append(epoch * 0.1)
|
||||
metrics_history["val_pnl"].append(epoch * 0.08)
|
||||
metrics_history["train_win_rate"].append(0.5 + epoch * 0.1)
|
||||
metrics_history["val_win_rate"].append(0.45 + epoch * 0.1)
|
||||
metrics_history["signal_distribution"].append({
|
||||
"BUY": 0.3, "SELL": 0.3, "HOLD": 0.4
|
||||
})
|
||||
|
||||
# Verify structure
|
||||
self.assertEqual(len(metrics_history["epoch"]), 3)
|
||||
self.assertEqual(len(metrics_history["train_loss"]), 3)
|
||||
self.assertEqual(len(metrics_history["signal_distribution"]), 3)
|
||||
|
||||
# Verify improvement
|
||||
self.assertLess(metrics_history["train_loss"][-1], metrics_history["train_loss"][0])
|
||||
self.assertGreater(metrics_history["train_acc"][-1], metrics_history["train_acc"][0])
|
||||
|
||||
logger.info("✅ Metrics tracking test passed")
|
||||
|
||||
def test_signal_distribution_calculation(self):
|
||||
"""Test signal distribution calculation"""
|
||||
import numpy as np
|
||||
|
||||
# Mock predictions (SELL=0, HOLD=1, BUY=2)
|
||||
predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0])
|
||||
|
||||
buy_count = np.sum(predictions == 2)
|
||||
sell_count = np.sum(predictions == 0)
|
||||
hold_count = np.sum(predictions == 1)
|
||||
total = len(predictions)
|
||||
|
||||
distribution = {
|
||||
"BUY": buy_count / total,
|
||||
"SELL": sell_count / total,
|
||||
"HOLD": hold_count / total
|
||||
}
|
||||
|
||||
# Verify calculations
|
||||
self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
|
||||
self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
|
||||
self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
|
||||
self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)
|
||||
|
||||
logger.info("✅ Signal distribution calculation test passed")
|
||||
|
||||
class TestIntegration(unittest.TestCase):
|
||||
"""Integration tests between components"""
|
||||
|
||||
def test_training_pipeline_integration(self):
|
||||
"""Test that CNN and RL training can work together"""
|
||||
logger.info("Testing training pipeline integration...")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
try:
|
||||
# Quick CNN training
|
||||
config = get_config()
|
||||
config._config['timeframes'] = ['1m']
|
||||
|
||||
cnn_trainer = CNNTrainer(config)
|
||||
cnn_trainer.epochs = 1
|
||||
cnn_trainer.batch_size = 8
|
||||
|
||||
cnn_path = os.path.join(temp_dir, 'test_cnn.pt')
|
||||
cnn_results = cnn_trainer.train(['ETH/USDT'], save_path=cnn_path, num_samples=50)
|
||||
|
||||
# Quick RL training
|
||||
data_provider = DataProvider(['ETH/USDT'], ['1m'])
|
||||
rl_trainer = RLTrainer(data_provider)
|
||||
rl_trainer.num_episodes = 3
|
||||
rl_trainer.max_steps_per_episode = 25
|
||||
|
||||
rl_path = os.path.join(temp_dir, 'test_rl.pt')
|
||||
rl_results = rl_trainer.train(save_path=rl_path)
|
||||
|
||||
# Verify both trained successfully
|
||||
self.assertIsInstance(cnn_results, dict)
|
||||
self.assertIsInstance(rl_results, dict)
|
||||
self.assertTrue(os.path.exists(cnn_path))
|
||||
self.assertTrue(os.path.exists(rl_path))
|
||||
|
||||
logger.info("✅ Training pipeline integration test passed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Integration test failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
if 'cnn_trainer' in locals():
|
||||
cnn_trainer.close_tensorboard()
|
||||
|
||||
def run_quick_tests():
|
||||
"""Run only the quickest tests for fast validation"""
|
||||
test_suites = [
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining),
|
||||
]
|
||||
|
||||
combined_suite = unittest.TestSuite(test_suites)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(combined_suite)
|
||||
|
||||
return result.wasSuccessful()
|
||||
|
||||
def run_data_tests():
|
||||
"""Run data provider tests"""
|
||||
test_suites = [
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestDataProviders),
|
||||
]
|
||||
|
||||
combined_suite = unittest.TestSuite(test_suites)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(combined_suite)
|
||||
|
||||
return result.wasSuccessful()
|
||||
|
||||
def run_training_tests():
|
||||
"""Run training tests (slower)"""
|
||||
test_suites = [
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestRLTraining),
|
||||
]
|
||||
|
||||
combined_suite = unittest.TestSuite(test_suites)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(combined_suite)
|
||||
|
||||
return result.wasSuccessful()
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all test suites"""
|
||||
test_suites = [
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestDataProviders),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestCNNTraining),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestRLTraining),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestExtendedTraining),
|
||||
unittest.TestLoader().loadTestsFromTestCase(TestIntegration),
|
||||
]
|
||||
|
||||
combined_suite = unittest.TestSuite(test_suites)
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(combined_suite)
|
||||
|
||||
return result.wasSuccessful()
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup_logging()
|
||||
logger.info("Running comprehensive training integration tests...")
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
test_type = sys.argv[1]
|
||||
if test_type == "quick":
|
||||
success = run_quick_tests()
|
||||
elif test_type == "data":
|
||||
success = run_data_tests()
|
||||
elif test_type == "training":
|
||||
success = run_training_tests()
|
||||
else:
|
||||
success = run_all_tests()
|
||||
else:
|
||||
success = run_all_tests()
|
||||
|
||||
if success:
|
||||
logger.info("✅ All tests passed!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("❌ Some tests failed!")
|
||||
sys.exit(1)
|
@ -1,128 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import unittest
|
||||
from typing import Optional, Dict
|
||||
import websockets
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TestMEXCWebSocket(unittest.TestCase):
|
||||
async def test_websocket_connection(self):
|
||||
"""Test basic WebSocket connection and subscription"""
|
||||
uri = "wss://stream.mexc.com/ws"
|
||||
symbol = "ethusdt"
|
||||
|
||||
async with websockets.connect(uri) as ws:
|
||||
# Test subscription to deals
|
||||
sub_msg = {
|
||||
"op": "sub",
|
||||
"id": "test1",
|
||||
"topic": f"spot.deals.{symbol}"
|
||||
}
|
||||
|
||||
# Send subscription
|
||||
await ws.send(json.dumps(sub_msg))
|
||||
|
||||
# Wait for subscription confirmation and first message
|
||||
messages_received = 0
|
||||
trades_received = 0
|
||||
|
||||
while messages_received < 5: # Get at least 5 messages
|
||||
try:
|
||||
response = await asyncio.wait_for(ws.recv(), timeout=10)
|
||||
messages_received += 1
|
||||
|
||||
logger.info(f"Received message: {response[:200]}...")
|
||||
data = json.loads(response)
|
||||
|
||||
# Check message structure
|
||||
if isinstance(data, dict):
|
||||
if 'channel' in data:
|
||||
if data['channel'] == 'spot.deals':
|
||||
trades = data.get('data', [])
|
||||
if trades:
|
||||
trades_received += 1
|
||||
logger.info(f"Received trade data: {trades[0]}")
|
||||
|
||||
# Verify trade data structure
|
||||
trade = trades[0]
|
||||
self.assertIn('t', trade) # timestamp
|
||||
self.assertIn('p', trade) # price
|
||||
self.assertIn('v', trade) # volume
|
||||
self.assertIn('S', trade) # side
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.fail("Timeout waiting for WebSocket messages")
|
||||
|
||||
# Verify we received some trades
|
||||
self.assertGreater(trades_received, 0, "No trades received")
|
||||
|
||||
# Test unsubscribe
|
||||
unsub_msg = {
|
||||
"op": "unsub",
|
||||
"id": "test1",
|
||||
"topic": f"spot.deals.{symbol}"
|
||||
}
|
||||
await ws.send(json.dumps(unsub_msg))
|
||||
|
||||
async def test_kline_subscription(self):
|
||||
"""Test subscription to kline (candlestick) data"""
|
||||
uri = "wss://stream.mexc.com/ws"
|
||||
symbol = "ethusdt"
|
||||
|
||||
async with websockets.connect(uri) as ws:
|
||||
# Subscribe to 1m klines
|
||||
sub_msg = {
|
||||
"op": "sub",
|
||||
"id": "test2",
|
||||
"topic": f"spot.klines.{symbol}_1m"
|
||||
}
|
||||
|
||||
await ws.send(json.dumps(sub_msg))
|
||||
|
||||
messages_received = 0
|
||||
klines_received = 0
|
||||
|
||||
while messages_received < 5:
|
||||
try:
|
||||
response = await asyncio.wait_for(ws.recv(), timeout=10)
|
||||
messages_received += 1
|
||||
|
||||
logger.info(f"Received kline message: {response[:200]}...")
|
||||
data = json.loads(response)
|
||||
|
||||
if isinstance(data, dict):
|
||||
if data.get('channel') == 'spot.klines':
|
||||
kline_data = data.get('data', [])
|
||||
if kline_data:
|
||||
klines_received += 1
|
||||
logger.info(f"Received kline data: {kline_data[0]}")
|
||||
|
||||
# Verify kline data structure (should be an array)
|
||||
kline = kline_data[0]
|
||||
self.assertEqual(len(kline), 6) # Should have 6 elements
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.fail("Timeout waiting for kline data")
|
||||
|
||||
self.assertGreater(klines_received, 0, "No kline data received")
|
||||
|
||||
def run_tests():
|
||||
"""Run the WebSocket tests"""
|
||||
async def run_async_tests():
|
||||
# Create test suite
|
||||
suite = unittest.TestSuite()
|
||||
suite.addTest(TestMEXCWebSocket('test_websocket_connection'))
|
||||
suite.addTest(TestMEXCWebSocket('test_kline_subscription'))
|
||||
|
||||
# Run tests
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
runner.run(suite)
|
||||
|
||||
# Run tests in asyncio loop
|
||||
asyncio.run(run_async_tests())
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
Reference in New Issue
Block a user