#!/usr/bin/env python3
"""
Test Universal Data Format Compliance

This script verifies that our enhanced trading system properly feeds
the 5 required timeseries streams to all models:
- ETH/USDT: ticks (1s), 1m, 1h, 1d
- BTC/USDT: ticks (1s) as reference

This is our universal trading system input format.
"""

import asyncio
import logging
import sys
from pathlib import Path

import numpy as np

# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from core.config import get_config
from core.data_provider import DataProvider
from core.universal_data_adapter import UniversalDataAdapter, UniversalDataStream
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from training.enhanced_cnn_trainer import EnhancedCNNTrainer
from training.enhanced_rl_trainer import EnhancedRLTrainer

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


async def test_universal_data_format():
"""Test that all components properly use the universal 5-timeseries format"""
|
|
logger.info("="*80)
|
|
logger.info("🧪 TESTING UNIVERSAL DATA FORMAT COMPLIANCE")
|
|
logger.info("="*80)
|
|
|
|
try:
|
|
# Initialize components
|
|
config = get_config()
|
|
data_provider = DataProvider(config)
|
|
|
|
# Test 1: Universal Data Adapter
|
|
logger.info("\n📊 TEST 1: Universal Data Adapter")
|
|
logger.info("-" * 40)
|
|
|
|
adapter = UniversalDataAdapter(data_provider)
|
|
universal_stream = adapter.get_universal_data_stream()
|
|
|
|
if universal_stream is None:
|
|
logger.error("❌ Failed to get universal data stream")
|
|
return False
|
|
|
|
# Validate format
|
|
is_valid, issues = adapter.validate_universal_format(universal_stream)
|
|
if not is_valid:
|
|
logger.error(f"❌ Universal format validation failed: {issues}")
|
|
return False
|
|
|
|
logger.info("✅ Universal Data Adapter: PASSED")
|
|
logger.info(f" ETH ticks: {len(universal_stream.eth_ticks)} samples")
|
|
logger.info(f" ETH 1m: {len(universal_stream.eth_1m)} candles")
|
|
logger.info(f" ETH 1h: {len(universal_stream.eth_1h)} candles")
|
|
logger.info(f" ETH 1d: {len(universal_stream.eth_1d)} candles")
|
|
logger.info(f" BTC reference: {len(universal_stream.btc_ticks)} samples")
|
|
logger.info(f" Data quality: {universal_stream.metadata['data_quality']['overall_score']:.2f}")
|
|
|
|
# Test 2: Enhanced Orchestrator
|
|
logger.info("\n🎯 TEST 2: Enhanced Orchestrator")
|
|
logger.info("-" * 40)
|
|
|
|
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
|
|
|
# Test that orchestrator uses universal adapter
|
|
if not hasattr(orchestrator, 'universal_adapter'):
|
|
logger.error("❌ Orchestrator missing universal_adapter")
|
|
return False
|
|
|
|
# Test coordinated decisions
|
|
decisions = await orchestrator.make_coordinated_decisions()
|
|
|
|
logger.info("✅ Enhanced Orchestrator: PASSED")
|
|
logger.info(f" Generated {len(decisions)} decisions")
|
|
logger.info(f" Universal adapter: {type(orchestrator.universal_adapter).__name__}")
|
|
|
|
for symbol, decision in decisions.items():
|
|
if decision:
|
|
logger.info(f" {symbol}: {decision.action} (confidence: {decision.confidence:.2f})")
|
|
|
|
# Test 3: CNN Model Data Format
|
|
logger.info("\n🧠 TEST 3: CNN Model Data Format")
|
|
logger.info("-" * 40)
|
|
|
|
# Format data for CNN
|
|
cnn_data = adapter.format_for_model(universal_stream, 'cnn')
|
|
|
|
required_cnn_keys = ['eth_ticks', 'eth_1m', 'eth_1h', 'eth_1d', 'btc_ticks']
|
|
missing_keys = [key for key in required_cnn_keys if key not in cnn_data]
|
|
|
|
if missing_keys:
|
|
logger.error(f"❌ CNN data missing keys: {missing_keys}")
|
|
return False
|
|
|
|
logger.info("✅ CNN Model Data Format: PASSED")
|
|
        for key, data in cnn_data.items():
            if isinstance(data, np.ndarray):
                logger.info(f"   {key}: shape {data.shape}")
            else:
                logger.info(f"   {key}: {type(data)}")

        # Test 4: RL Model Data Format
        logger.info("\n🤖 TEST 4: RL Model Data Format")
        logger.info("-" * 40)

        # Format data for RL
        rl_data = adapter.format_for_model(universal_stream, 'rl')

        if 'state_vector' not in rl_data:
            logger.error("❌ RL data missing state_vector")
            return False

        state_vector = rl_data['state_vector']
        if not isinstance(state_vector, np.ndarray):
            logger.error("❌ RL state_vector is not a numpy array")
            return False

        logger.info("✅ RL Model Data Format: PASSED")
logger.info(f" State vector shape: {state_vector.shape}")
|
|
logger.info(f" State vector size: {len(state_vector)} features")
|
|
|
|
# Test 5: CNN Trainer Integration
|
|
logger.info("\n🎓 TEST 5: CNN Trainer Integration")
|
|
logger.info("-" * 40)
|
|
|
|
try:
|
|
cnn_trainer = EnhancedCNNTrainer(config, orchestrator)
|
|
logger.info("✅ CNN Trainer Integration: PASSED")
|
|
logger.info(f" Model timeframes: {cnn_trainer.model.timeframes}")
|
|
logger.info(f" Model device: {cnn_trainer.model.device}")
|
|
except Exception as e:
|
|
logger.error(f"❌ CNN Trainer Integration failed: {e}")
|
|
return False
|
|
|
|
# Test 6: RL Trainer Integration
|
|
logger.info("\n🎮 TEST 6: RL Trainer Integration")
|
|
logger.info("-" * 40)
|
|
|
|
try:
|
|
rl_trainer = EnhancedRLTrainer(config, orchestrator)
|
|
logger.info("✅ RL Trainer Integration: PASSED")
|
|
logger.info(f" RL agents: {len(rl_trainer.agents)}")
|
|
for symbol, agent in rl_trainer.agents.items():
|
|
logger.info(f" {symbol} agent: {type(agent).__name__}")
|
|
except Exception as e:
|
|
logger.error(f"❌ RL Trainer Integration failed: {e}")
|
|
return False
|
|
|
|
# Test 7: Data Flow Verification
|
|
logger.info("\n🔄 TEST 7: Data Flow Verification")
|
|
logger.info("-" * 40)
|
|
|
|
# Verify that models receive the correct data format
|
|
test_predictions = await orchestrator._get_enhanced_predictions_universal(
|
|
'ETH/USDT',
|
|
list(orchestrator.market_states['ETH/USDT'])[-1] if orchestrator.market_states['ETH/USDT'] else None,
|
|
universal_stream
|
|
)
|
|
|
|
if test_predictions:
|
|
logger.info("✅ Data Flow Verification: PASSED")
|
|
            for pred in test_predictions:
                logger.info(f"   Model: {pred.model_name}")
                logger.info(f"   Action: {pred.overall_action}")
                logger.info(f"   Confidence: {pred.overall_confidence:.2f}")
                logger.info(f"   Timeframes: {len(pred.timeframe_predictions)}")
        else:
            logger.warning("⚠️ No predictions generated (may be normal if no models loaded)")

        # Test 8: Configuration Compliance
        logger.info("\n⚙️ TEST 8: Configuration Compliance")
        logger.info("-" * 40)

        # Check that config matches universal format
        expected_symbols = ['ETH/USDT', 'BTC/USDT']
        expected_timeframes = ['1s', '1m', '1h', '1d']

        config_symbols = config.symbols
        config_timeframes = config.timeframes

        symbols_match = all(symbol in config_symbols for symbol in expected_symbols)
        timeframes_match = all(tf in config_timeframes for tf in expected_timeframes)

        if not symbols_match:
            logger.warning("⚠️ Config symbols may not match the universal format")
            logger.warning(f"   Expected: {expected_symbols}")
            logger.warning(f"   Config: {config_symbols}")

        if not timeframes_match:
            logger.warning("⚠️ Config timeframes may not match the universal format")
            logger.warning(f"   Expected: {expected_timeframes}")
            logger.warning(f"   Config: {config_timeframes}")

        if symbols_match and timeframes_match:
            logger.info("✅ Configuration Compliance: PASSED")
        else:
            logger.info("⚠️ Configuration Compliance: PARTIAL")

        logger.info(f"   Symbols: {config_symbols}")
        logger.info(f"   Timeframes: {config_timeframes}")

        # Final Summary
        logger.info("\n" + "="*80)
        logger.info("🎉 UNIVERSAL DATA FORMAT TEST SUMMARY")
        logger.info("="*80)
        logger.info("✅ All core tests PASSED!")
        logger.info("")
        logger.info("📋 VERIFIED COMPLIANCE:")
        logger.info("   ✓ Universal Data Adapter working")
        logger.info("   ✓ Enhanced Orchestrator using universal format")
        logger.info("   ✓ CNN models receive 5 timeseries streams")
        logger.info("   ✓ RL models receive combined state vector")
        logger.info("   ✓ Trainers properly integrated")
        logger.info("   ✓ Data flow verified")
        logger.info("")
        logger.info("🎯 UNIVERSAL FORMAT ACTIVE:")
        logger.info("   1. ETH/USDT ticks (1s) ✓")
        logger.info("   2. ETH/USDT 1m ✓")
        logger.info("   3. ETH/USDT 1h ✓")
        logger.info("   4. ETH/USDT 1d ✓")
        logger.info("   5. BTC/USDT reference ticks ✓")
        logger.info("")
        logger.info("🚀 Your enhanced trading system is ready with the universal data format!")
        logger.info("="*80)

        return True

    except Exception as e:
        logger.error(f"❌ Universal data format test failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return False


async def main():
"""Main test function"""
|
|
logger.info("🚀 Starting Universal Data Format Compliance Test...")
|
|
|
|
success = await test_universal_data_format()
|
|
|
|
if success:
|
|
logger.info("\n🎉 All tests passed! Universal data format is properly implemented.")
|
|
logger.info("Your enhanced trading system respects the 5-timeseries input format.")
|
|
else:
|
|
logger.error("\n💥 Tests failed! Please check the universal data format implementation.")
|
|
sys.exit(1)
|
|
|
|
if __name__ == "__main__":
|
|
asyncio.run(main()) |
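
# Illustrative usage sketch: reusing the compliance check from another script
# or a CI job. This assumes the file is saved as test_universal_data_format.py
# at the project root; adjust the import to match the actual filename.
#
#   import asyncio
#   from test_universal_data_format import test_universal_data_format
#
#   passed = asyncio.run(test_universal_data_format())
#   assert passed, "Universal data format compliance check failed"
#
# Running the file directly (python test_universal_data_format.py) exits with
# status 1 on failure, so it can also be wired into CI as a smoke test.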