402 lines
16 KiB
Python
402 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Comprehensive Indicators and Signals Test Suite
|
|
|
|
This module consolidates testing functionality for:
|
|
- Technical indicators (from test_indicators.py)
|
|
- Signal interpretation and processing (from test_signal_interpreter.py)
|
|
- Market data analysis
|
|
- Trading signal validation
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import unittest
|
|
import logging
|
|
import numpy as np
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
# Put the parent of this file's directory at the front of the import path so
# that the local ``core`` package (imported below) is found first.
project_root = Path(__file__).parent.parent
sys.path[0:0] = [str(project_root)]
|
|
|
|
from core.config import setup_logging
|
|
from core.data_provider import DataProvider
|
|
|
|
# Module-level logger named after this module; used by the tests and the
# ``__main__`` entry point.
logger = logging.getLogger(name=__name__)
|
|
|
|
class TestTechnicalIndicators(unittest.TestCase):
    """Test suite for technical indicators functionality

    These tests fetch fresh ETH/USDT 1h data through ``DataProvider``
    (``refresh=True``). A test is skipped when the data cannot be obtained,
    i.e. when the call raises or returns nothing usable.

    Only the data access is guarded. The assertions deliberately run outside
    any ``try`` block: ``AssertionError`` is a subclass of ``Exception``, so a
    broad ``except Exception`` around them would turn every genuine failure
    into a skip.
    """

    # Raw OHLCV columns; any other column in the fetched frame is treated as a
    # technical indicator.
    BASIC_COLS = ['timestamp', 'open', 'high', 'low', 'close', 'volume']

    # Substrings (matched case-insensitively) used to bucket indicator columns.
    TREND_KEYS = ['sma', 'ema', 'macd', 'adx', 'psar']
    MOMENTUM_KEYS = ['rsi', 'stoch', 'williams', 'cci']
    VOLATILITY_KEYS = ['bb_', 'atr', 'keltner']
    VOLUME_KEYS = ['volume', 'obv', 'vpt', 'mfi', 'ad_line', 'vwap']

    def setUp(self):
        """Set up test fixtures"""
        setup_logging()
        # Provider configured for the single symbol/timeframe used below.
        self.data_provider = DataProvider(['ETH/USDT'], ['1h'])

    def _indicator_columns(self, df):
        """Return the columns of *df* that are not raw OHLCV columns."""
        return [col for col in df.columns if col not in self.BASIC_COLS]

    def test_indicator_calculation(self):
        """Test that indicators are calculated correctly"""
        logger.info("Testing technical indicators calculation...")

        try:
            # Fetch data with indicators
            df = self.data_provider.get_historical_data('ETH/USDT', '1h', refresh=True, limit=100)
        except Exception as e:
            logger.warning(f"Indicator test failed: {e}")
            self.skipTest("Data or indicators not available")

        # No data means the source is unavailable, not that indicators are wrong.
        if df is None or len(df) == 0:
            self.skipTest("Data or indicators not available")

        # Check basic OHLCV columns
        for col in self.BASIC_COLS:
            self.assertIn(col, df.columns, f"Should have {col} column")

        # Check that indicators are calculated
        indicator_cols = self._indicator_columns(df)
        self.assertGreater(len(indicator_cols), 0, "Should have technical indicators")

        logger.info(f"✅ Successfully calculated {len(indicator_cols)} indicators")

    def test_indicator_categorization(self):
        """Test categorization of different indicator types"""
        logger.info("Testing indicator categorization...")

        try:
            df = self.data_provider.get_historical_data('ETH/USDT', '1h', refresh=True, limit=100)
        except Exception as e:
            logger.warning(f"Categorization test failed: {e}")
            self.skipTest("Indicator categorization not available")

        if df is None:
            self.skipTest("Could not fetch data for categorization test")

        indicator_cols = self._indicator_columns(df)

        def matching(keys):
            # Substring match, so one column can land in several categories
            # (e.g. a name containing both 'volume' and 'sma').
            return [col for col in indicator_cols
                    if any(key in str(col).lower() for key in keys)]

        # Categorize indicators
        trend_indicators = matching(self.TREND_KEYS)
        momentum_indicators = matching(self.MOMENTUM_KEYS)
        volatility_indicators = matching(self.VOLATILITY_KEYS)
        volume_indicators = matching(self.VOLUME_KEYS)

        # A column matching several categories is counted once per category.
        total_categorized = (len(trend_indicators) + len(momentum_indicators)
                             + len(volatility_indicators) + len(volume_indicators))

        logger.info("Indicator categories:")
        logger.info(f" Trend: {len(trend_indicators)}")
        logger.info(f" Momentum: {len(momentum_indicators)}")
        logger.info(f" Volatility: {len(volatility_indicators)}")
        logger.info(f" Volume: {len(volume_indicators)}")
        logger.info(f" Total categorized: {total_categorized}/{len(indicator_cols)}")

        self.assertGreater(total_categorized, 0, "Should have categorized indicators")

    def test_feature_matrix_creation(self):
        """Test multi-timeframe feature matrix creation

        Only the '1h' timeframe is requested, matching the timeframes the
        provider is constructed with in ``setUp``.
        """
        logger.info("Testing feature matrix creation...")

        try:
            feature_matrix = self.data_provider.get_feature_matrix('ETH/USDT', ['1h'], window_size=20)
        except Exception as e:
            logger.warning(f"Feature matrix test failed: {e}")
            self.skipTest("Feature matrix creation not available")

        if feature_matrix is None:
            self.skipTest("Could not create feature matrix")

        # The last axis holds the features; the meaning of the first two axes
        # is defined by DataProvider.get_feature_matrix (presumably
        # timeframe x window — TODO confirm).
        self.assertEqual(len(feature_matrix.shape), 3, "Should be 3D matrix")
        self.assertGreater(feature_matrix.shape[2], 0, "Should have features")

        logger.info(f"✅ Feature matrix shape: {feature_matrix.shape}")
|
|
|
|
class TestSignalProcessing(unittest.TestCase):
    """Test suite for signal interpretation and processing

    These tests are self-contained. They run small reference implementations
    of signal logic on hard-coded inputs and need no market data. Class
    indices follow the convention SELL=0, HOLD=1, BUY=2.
    """

    ACTION_MAP = {0: 'SELL', 1: 'HOLD', 2: 'BUY'}

    def test_signal_distribution_calculation(self):
        """Test signal distribution calculation"""
        logger.info("Testing signal distribution calculation...")

        # Mock predictions (SELL=0, HOLD=1, BUY=2)
        predictions = np.array([0, 1, 2, 1, 0, 2, 1, 1, 2, 0])
        total = len(predictions)

        distribution = {
            "BUY": np.sum(predictions == 2) / total,
            "SELL": np.sum(predictions == 0) / total,
            "HOLD": np.sum(predictions == 1) / total,
        }

        # Verify calculations
        self.assertAlmostEqual(distribution["BUY"], 0.3, places=2)
        self.assertAlmostEqual(distribution["SELL"], 0.3, places=2)
        self.assertAlmostEqual(distribution["HOLD"], 0.4, places=2)
        self.assertAlmostEqual(sum(distribution.values()), 1.0, places=2)

        logger.info("✅ Signal distribution calculation test passed")

    def test_basic_signal_interpretation(self):
        """Test basic signal interpretation logic

        The action is the argmax class. The confidence level is bucketed from
        the max probability: > 0.7 high, > 0.5 medium, otherwise low.
        Both action and confidence level are checked against the expectation.
        """
        logger.info("Testing basic signal interpretation...")

        # Test cases with different probability distributions
        test_cases = [
            {'probs': [0.8, 0.1, 0.1], 'expected_action': 'SELL', 'expected_confidence': 'high'},   # Strong SELL
            {'probs': [0.1, 0.1, 0.8], 'expected_action': 'BUY', 'expected_confidence': 'high'},    # Strong BUY
            {'probs': [0.1, 0.8, 0.1], 'expected_action': 'HOLD', 'expected_confidence': 'high'},   # Strong HOLD
            # Uncertain - SELL has the highest (but low) probability
            {'probs': [0.4, 0.3, 0.3], 'expected_action': 'SELL', 'expected_confidence': 'low'},
            # Very uncertain - BUY wins by 0.01
            {'probs': [0.33, 0.33, 0.34], 'expected_action': 'BUY', 'expected_confidence': 'low'},
        ]

        for i, test_case in enumerate(test_cases, start=1):
            probs = np.array(test_case['probs'])
            expected_action = test_case['expected_action']
            expected_confidence = test_case['expected_confidence']

            # Simple signal interpretation (argmax)
            predicted_action = self.ACTION_MAP[int(np.argmax(probs))]

            # Calculate confidence (max probability)
            confidence = np.max(probs)
            confidence_level = 'high' if confidence > 0.7 else 'medium' if confidence > 0.5 else 'low'

            # Verify predictions
            self.assertEqual(predicted_action, expected_action,
                             f"Test case {i}: Expected {expected_action}, got {predicted_action}")
            self.assertEqual(confidence_level, expected_confidence,
                             f"Test case {i}: Expected {expected_confidence} confidence, got {confidence_level}")

            logger.info(f"Test case {i}: {probs} -> {predicted_action} ({confidence_level} confidence)")

        logger.info("✅ Basic signal interpretation test passed")

    def test_signal_filtering_logic(self):
        """Test signal filtering and validation logic

        Thresholds are checked in priority order SELL, BUY, HOLD. If none is
        met, the signal falls back to HOLD and is reported as not passing.
        """
        logger.info("Testing signal filtering logic...")

        # Test threshold-based filtering
        buy_threshold = 0.6
        sell_threshold = 0.6
        hold_threshold = 0.7

        test_signals = [
            {'probs': [0.8, 0.1, 0.1], 'should_pass': True, 'expected': 'SELL'},   # Strong SELL (above threshold)
            {'probs': [0.5, 0.3, 0.2], 'should_pass': False, 'expected': 'HOLD'},  # Weak SELL (below threshold)
            {'probs': [0.1, 0.2, 0.7], 'should_pass': True, 'expected': 'BUY'},    # Strong BUY (above threshold)
            {'probs': [0.2, 0.8, 0.0], 'should_pass': True, 'expected': 'HOLD'},   # Strong HOLD (above threshold)
        ]

        for i, test in enumerate(test_signals, start=1):
            probs = np.array(test['probs'])
            sell_prob, hold_prob, buy_prob = probs

            # Apply threshold filtering
            if sell_prob >= sell_threshold:
                filtered_action, passed_filter = 'SELL', True
            elif buy_prob >= buy_threshold:
                filtered_action, passed_filter = 'BUY', True
            elif hold_prob >= hold_threshold:
                filtered_action, passed_filter = 'HOLD', True
            else:
                filtered_action, passed_filter = 'HOLD', False  # Default to HOLD if no threshold met

            # Verify filtering
            expected_pass = test['should_pass']
            expected_action = test['expected']

            self.assertEqual(passed_filter, expected_pass,
                             f"Test {i}: Filter pass expectation failed")
            self.assertEqual(filtered_action, expected_action,
                             f"Test {i}: Expected {expected_action}, got {filtered_action}")

            logger.info(f"Test {i}: {probs} -> {filtered_action} (passed: {passed_filter})")

        logger.info("✅ Signal filtering logic test passed")

    def test_signal_sequence_validation(self):
        """Test signal sequence validation and oscillation prevention"""
        logger.info("Testing signal sequence validation...")

        # Simulate a sequence of signals that might oscillate
        signal_sequence = ['BUY', 'SELL', 'BUY', 'SELL', 'HOLD', 'BUY']

        # An oscillation is any adjacent BUY->SELL or SELL->BUY flip.
        oscillation_count = sum(
            1 for prev, curr in zip(signal_sequence, signal_sequence[1:])
            if {prev, curr} == {'BUY', 'SELL'}
        )

        # Longest run of consecutive non-HOLD signals
        consecutive_trades = 0
        max_consecutive = 0
        for signal in signal_sequence:
            consecutive_trades = consecutive_trades + 1 if signal != 'HOLD' else 0
            max_consecutive = max(max_consecutive, consecutive_trades)

        # Sequence has three flips (B-S, S-B, B-S) and a 4-long run before HOLD.
        self.assertEqual(oscillation_count, 3, "Should detect oscillations in test sequence")
        self.assertEqual(max_consecutive, 4, "Should detect consecutive trades")

        logger.info(f"Detected {oscillation_count} oscillations and max {max_consecutive} consecutive trades")
        logger.info("✅ Signal sequence validation test passed")
|
|
|
|
class TestMarketDataAnalysis(unittest.TestCase):
    """Test suite for market data analysis functionality

    Self-contained checks on hard-coded price series; no market data needed.
    """

    def test_price_movement_calculation(self):
        """Test price movement and trend calculation

        The trend is the mean percentage change over the last three steps.
        Above +0.1% is an uptrend, below -0.1% a downtrend, otherwise sideways.
        """
        logger.info("Testing price movement calculation...")

        # Mock price data
        prices = np.array([100.0, 101.0, 102.5, 101.8, 103.2, 102.9, 104.1])

        # Calculate price movements
        price_changes = np.diff(prices)
        percentage_changes = (price_changes / prices[:-1]) * 100

        # Calculate simple trend
        recent_trend = np.mean(percentage_changes[-3:])  # Last 3 changes
        trend_direction = 'uptrend' if recent_trend > 0.1 else 'downtrend' if recent_trend < -0.1 else 'sideways'

        # Verify calculations
        self.assertEqual(len(price_changes), len(prices) - 1, "Should have n-1 price changes")
        self.assertEqual(len(percentage_changes), len(prices) - 1, "Should have n-1 percentage changes")
        self.assertAlmostEqual(percentage_changes[0], 1.0, places=6)  # 100 -> 101

        # Last three changes are ~+1.38%, ~-0.29%, ~+1.17% (mean ~+0.75%),
        # so the series must be classified as an uptrend.
        self.assertEqual(trend_direction, 'uptrend', "Should detect uptrend in rising series")

        logger.info(f"Price sequence: {prices}")
        logger.info(f"Recent trend: {trend_direction} ({recent_trend:.2f}%)")
        logger.info("✅ Price movement calculation test passed")

    def test_volatility_measurement(self):
        """Test volatility measurement"""
        logger.info("Testing volatility measurement...")

        # Mock price data with different volatility
        stable_prices = np.array([100.0, 100.1, 99.9, 100.2, 99.8, 100.0])
        volatile_prices = np.array([100.0, 105.0, 95.0, 110.0, 90.0, 115.0])

        def calculate_volatility(prices):
            # Population standard deviation of simple returns, in percent.
            returns = np.diff(prices) / prices[:-1]
            return np.std(returns) * 100

        stable_vol = calculate_volatility(stable_prices)
        volatile_vol = calculate_volatility(volatile_prices)

        # Verify volatility measurements
        self.assertLess(stable_vol, volatile_vol, "Stable prices should have lower volatility")
        self.assertGreater(volatile_vol, 5.0, "Volatile prices should have significant volatility")

        logger.info(f"Stable volatility: {stable_vol:.2f}%")
        logger.info(f"Volatile volatility: {volatile_vol:.2f}%")
        logger.info("✅ Volatility measurement test passed")
|
|
|
|
def run_indicator_tests():
    """Run only the TestTechnicalIndicators suite with verbose output.

    Returns:
        bool: True when the run recorded no failures or errors.
    """
    loader = unittest.TestLoader()
    outcome = unittest.TextTestRunner(verbosity=2).run(
        loader.loadTestsFromTestCase(TestTechnicalIndicators)
    )
    return outcome.wasSuccessful()
|
|
|
|
def run_signal_tests():
    """Run the signal-processing and market-data suites with verbose output.

    Returns:
        bool: True when the run recorded no failures or errors.
    """
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for case_class in (TestSignalProcessing, TestMarketDataAnalysis):
        suite.addTest(loader.loadTestsFromTestCase(case_class))
    return unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful()
|
|
|
|
def run_all_tests():
    """Run every test case class in this module with verbose output.

    Returns:
        bool: True when the run recorded no failures or errors.
    """
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for case_class in (TestTechnicalIndicators, TestSignalProcessing, TestMarketDataAnalysis):
        suite.addTest(loader.loadTestsFromTestCase(case_class))
    return unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful()
|
|
|
|
if __name__ == "__main__":
    # Usage: script [indicators|signals]; with no argument every suite runs.
    setup_logging()
    logger.info("Running indicators and signals test suite...")

    runners = {
        "indicators": run_indicator_tests,
        "signals": run_signal_tests,
    }
    test_type = sys.argv[1] if len(sys.argv) > 1 else None

    # An unrecognised argument still falls back to running everything, but is
    # reported so a typo is not mistaken for a targeted run.
    if test_type is not None and test_type not in runners:
        logger.warning(f"Unknown test type {test_type!r}; running all tests")

    success = runners.get(test_type, run_all_tests)()

    if success:
        logger.info("✅ All indicator and signal tests passed!")
        sys.exit(0)
    else:
        logger.error("❌ Some tests failed!")
        sys.exit(1)