201 lines
8.5 KiB
Python
201 lines
8.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test Enhanced COB Integration with RL and CNN Models
|
|
|
|
This script tests the integration of Consolidated Order Book (COB) data
|
|
with the real-time RL and CNN training pipeline.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
import numpy as np
|
|
import time
|
|
from datetime import datetime
|
|
|
|
# Put the directory that holds this script at the very front of the import
# path, so the project-local `core` imports below resolve when run directly.
project_root = Path(__file__).parent
sys.path[:0] = [str(project_root)]
|
|
|
|
from core.config import setup_logging
|
|
from core.data_provider import DataProvider
|
|
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
|
from core.cob_integration import COBIntegration
|
|
|
|
# Apply the project's logging configuration first, then obtain this module's
# logger so every message below goes through that configuration.
setup_logging()
logger = logging.getLogger(name=__name__)
|
|
|
|
class COBMLIntegrationTester:
    """Exercise the Consolidated Order Book (COB) integration with the ML models.

    Drives an ``EnhancedTradingOrchestrator`` through COB start-up, checks the
    COB CNN features / DQN state it exposes, checks that generated market
    states carry COB fields, watches the per-symbol COB feature history for
    live updates, and logs a pass/fail summary.

    Attributes:
        symbols: trading pairs under test.
        data_provider: ``DataProvider`` shared with the orchestrator and the
            universal data adapter.
        test_results: maps a check name to ``True`` (passed) or ``False``
            (failed). A sub-check that crashes records a ``False`` entry so the
            failure shows up in the summary instead of being silently dropped.
    """

    # Symbols tested when the caller does not supply any.
    DEFAULT_SYMBOLS = ('BTC/USDT', 'ETH/USDT')

    # Market-state attributes that carry COB-derived metrics, in report order.
    COB_METRIC_FIELDS = (
        'order_book_imbalance',
        'liquidity_depth',
        'exchange_diversity',
        'market_impact_estimate',
    )

    def __init__(self, symbols=None):
        """Create the tester.

        Args:
            symbols: optional iterable of symbol strings to test; defaults to
                ``DEFAULT_SYMBOLS``.
        """
        self.symbols = list(self.DEFAULT_SYMBOLS if symbols is None else symbols)
        self.data_provider = DataProvider()
        self.test_results = {}

    async def test_cob_ml_integration(self):
        """Run the full COB -> ML pipeline check and log a summary.

        Builds the orchestrator, starts COB integration, waits for data, runs
        the three sub-checks and prints the results. COB integration is
        stopped in ``finally`` so it is not left running when a step raises.

        Returns:
            bool: True only when at least one check was recorded and all passed.
        """
        logger.info("=" * 60)
        logger.info("TESTING COB INTEGRATION WITH RL AND CNN MODELS")
        logger.info("=" * 60)

        orchestrator = None
        cob_start_attempted = False
        try:
            logger.info("1. Initializing Enhanced Trading Orchestrator with COB...")
            orchestrator = EnhancedTradingOrchestrator(
                data_provider=self.data_provider,
                symbols=self.symbols,
                enhanced_rl_training=True,
                model_registry={},
            )

            logger.info("2. Starting COB Integration...")
            # Set before awaiting so a partially-completed start is still stopped.
            cob_start_attempted = True
            await orchestrator.start_cob_integration()
            await asyncio.sleep(5)  # Allow startup and data collection

            logger.info("3. Testing COB feature generation...")
            await self._test_cob_features(orchestrator)

            logger.info("4. Testing market state with COB data...")
            await self._test_market_state_cob(orchestrator)

            logger.info("5. Testing real-time COB callbacks...")
            await self._test_realtime_callbacks(orchestrator)
        except Exception as e:
            logger.exception(f"Error in COB ML integration test: {e}")
            self.test_results['integration_run'] = False
        finally:
            if cob_start_attempted and orchestrator is not None:
                try:
                    await orchestrator.stop_cob_integration()
                except Exception as e:
                    logger.exception(f"Error stopping COB integration: {e}")

        return self._print_test_results()

    def _record_availability(self, key, symbol, label, value):
        """Log whether *value* is present for *symbol* and store it under *key*."""
        available = value is not None
        if available:
            # np.shape reads `.shape` when present and otherwise converts via
            # asarray, so non-array containers do not abort the whole check.
            logger.info(f"✅ {symbol}: {label} available - shape: {np.shape(value)}")
        else:
            logger.warning(f"⚠️ {symbol}: {label} not available")
        self.test_results[key] = available

    async def _test_cob_features(self, orchestrator):
        """Record whether COB CNN features and COB DQN state exist per symbol.

        Looks each symbol up in ``orchestrator.latest_cob_features`` and
        ``orchestrator.latest_cob_state`` and records
        ``{symbol}_cob_cnn_features`` / ``{symbol}_cob_dqn_state``.
        """
        try:
            for symbol in self.symbols:
                cob_features = orchestrator.latest_cob_features.get(symbol)
                cob_state = orchestrator.latest_cob_state.get(symbol)
                self._record_availability(
                    f'{symbol}_cob_cnn_features', symbol, 'COB CNN features', cob_features
                )
                self._record_availability(
                    f'{symbol}_cob_dqn_state', symbol, 'COB DQN state', cob_state
                )
        except Exception as e:
            logger.exception(f"Error testing COB features: {e}")
            self.test_results['cob_features_check'] = False

    def _check_market_state(self, symbol, state):
        """Record which COB fields one market state carries and log its metrics."""
        checks = [
            ('cob_features', getattr(state, 'cob_features', None) is not None),
            ('cob_state', getattr(state, 'cob_state', None) is not None),
        ]
        # NOTE(review): hasattr only proves the attribute exists; if the state
        # class always defines these fields this check cannot fail.
        checks.extend((field, hasattr(state, field)) for field in self.COB_METRIC_FIELDS)

        for test_name, passed in checks:
            status = "✅" if passed else "❌"
            logger.info(f"{status} {symbol}: {test_name} - {passed}")
            self.test_results[f'{symbol}_market_state_{test_name}'] = passed

        metrics = {field: getattr(state, field, None) for field in self.COB_METRIC_FIELDS}
        if any(value is None for value in metrics.values()):
            # The numeric format specs below would raise on a missing/None metric.
            return

        logger.info(f"📊 {symbol} COB Metrics:")
        logger.info(f" Order Book Imbalance: {metrics['order_book_imbalance']:.4f}")
        logger.info(f" Liquidity Depth: ${metrics['liquidity_depth']:,.0f}")
        logger.info(f" Exchange Diversity: {metrics['exchange_diversity']}")
        logger.info(f" Market Impact (10k): {metrics['market_impact_estimate']:.4f}%")

    async def _test_market_state_cob(self, orchestrator):
        """Check that orchestrator-generated market states carry COB data.

        Builds a universal stream for ``self.symbols`` with
        ``UniversalDataAdapter``, asks the orchestrator for market states and
        inspects each symbol's state. A symbol with no state is recorded as a
        failure rather than silently skipped.
        """
        try:
            # Imported inside the try so an import failure only fails this check.
            from core.universal_data_adapter import UniversalDataAdapter

            adapter = UniversalDataAdapter(self.data_provider)
            universal_stream = await adapter.get_universal_stream(self.symbols)
            market_states = await orchestrator._get_all_market_states_universal(universal_stream)
        except Exception as e:
            logger.exception(f"Error testing market state COB: {e}")
            self.test_results['market_state_check'] = False
            return

        market_states = market_states or {}
        for symbol in self.symbols:
            if symbol not in market_states:
                logger.warning(f"⚠️ {symbol}: no market state generated")
                self.test_results[f'{symbol}_market_state_present'] = False
                continue
            try:
                self._check_market_state(symbol, market_states[symbol])
            except Exception as e:
                logger.exception(f"Error testing market state COB for {symbol}: {e}")
                self.test_results[f'{symbol}_market_state_check'] = False

    async def _test_realtime_callbacks(self, orchestrator, duration=10):
        """Check that COB feature updates keep arriving for each symbol.

        Compares the length of ``orchestrator.cob_feature_history[symbol]``
        before and after waiting *duration* seconds; a symbol missing from the
        history counts as length 0.

        NOTE(review): if the history containers are bounded and already full
        (e.g. a full ``deque(maxlen=...)``), their length stops growing and this
        reports no updates even while callbacks fire — confirm the container type.
        """

        def history_len(symbol):
            try:
                return len(orchestrator.cob_feature_history[symbol])
            except KeyError:
                return 0

        try:
            initial_counts = {s: history_len(s) for s in self.symbols}
            logger.info(f"Monitoring COB callbacks for {duration} seconds...")
            await asyncio.sleep(duration)
            final_counts = {s: history_len(s) for s in self.symbols}
        except Exception as e:
            logger.exception(f"Error testing realtime callbacks: {e}")
            self.test_results['realtime_callbacks_check'] = False
            return

        for symbol in self.symbols:
            updates = final_counts[symbol] - initial_counts[symbol]
            if updates > 0:
                logger.info(f"✅ {symbol}: Received {updates} COB feature updates")
                self.test_results[f'{symbol}_realtime_callbacks'] = True
            else:
                logger.warning(f"⚠️ {symbol}: No COB feature updates received")
                self.test_results[f'{symbol}_realtime_callbacks'] = False

    @staticmethod
    def _summarize(results):
        """Return ``(passed, total, pass_percentage)`` for a name -> outcome mapping.

        An empty mapping yields ``(0, 0, 0.0)`` instead of dividing by zero.
        """
        total = len(results)
        passed = sum(1 for outcome in results.values() if outcome)
        pass_rate = passed / total * 100 if total else 0.0
        return passed, total, pass_rate

    def _print_test_results(self):
        """Log every recorded check and an overall verdict.

        Returns:
            bool: True only when at least one check was recorded and all passed.
        """
        logger.info("=" * 60)
        logger.info("COB ML INTEGRATION TEST RESULTS")
        logger.info("=" * 60)

        passed, total, pass_rate = self._summarize(self.test_results)
        logger.info(f"Overall: {passed}/{total} tests passed ({pass_rate:.1f}%)")
        logger.info("")

        for test_name, result in self.test_results.items():
            status = "✅ PASS" if result else "❌ FAIL"
            logger.info(f"{status}: {test_name}")

        logger.info("=" * 60)

        if total == 0:
            # Otherwise `passed == total` would report an empty run as a success.
            logger.warning("🚨 NO TEST RESULTS RECORDED - Nothing was checked")
            return False
        if passed == total:
            logger.info("🎉 ALL TESTS PASSED - COB ML INTEGRATION WORKING!")
        elif passed > total * 0.8:
            logger.info("⚠️ MOSTLY WORKING - Some minor issues detected")
        else:
            logger.warning("🚨 INTEGRATION ISSUES - Significant problems detected")
        return passed == total
|
|
|
|
async def main():
    """Script entry point: build a fresh tester and run the COB/ML checks once."""
    await COBMLIntegrationTester().test_cob_ml_integration()
|
|
|
|
if __name__ == "__main__":
    # Run the async entry point on a fresh event loop when executed as a script.
    asyncio.run(main())