gogo2/test_universal_data_format.py
2025-05-26 16:02:40 +03:00

262 lines
10 KiB
Python

#!/usr/bin/env python3
"""
Test Universal Data Format Compliance
This script verifies that our enhanced trading system properly feeds
the 5 required timeseries streams to all models:
- ETH/USDT: ticks (1s), 1m, 1h, 1d
- BTC/USDT: ticks (1s) as reference
This is our universal trading system input format.
"""
import asyncio
import logging
import sys
from pathlib import Path
import numpy as np
# Put this script's own directory at the front of sys.path. This must run
# before the project imports that follow (`core.*`, `training.*`) --
# presumably those packages live alongside this file; verify the layout.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import get_config
from core.data_provider import DataProvider
from core.universal_data_adapter import UniversalDataAdapter, UniversalDataStream
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from training.enhanced_cnn_trainer import EnhancedCNNTrainer
from training.enhanced_rl_trainer import EnhancedRLTrainer
# Configure root logging once for the whole script: INFO level, with a
# timestamp / logger name / level prefix on every record.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by every check in this file.
logger = logging.getLogger(__name__)
def _log_section(title):
    """Log a section title (preceded by a blank line) and a 40-dash rule."""
    logger.info(f"\n{title}")
    logger.info("-" * 40)


def _check_universal_adapter(data_provider):
    """Test 1: build the adapter, fetch the universal stream and validate it.

    Returns:
        ``(adapter, universal_stream)`` on success, ``None`` when the stream
        is missing or fails ``validate_universal_format``.
    """
    _log_section("📊 TEST 1: Universal Data Adapter")

    adapter = UniversalDataAdapter(data_provider)
    universal_stream = adapter.get_universal_data_stream()
    if universal_stream is None:
        logger.error("❌ Failed to get universal data stream")
        return None

    # Validate format
    is_valid, issues = adapter.validate_universal_format(universal_stream)
    if not is_valid:
        logger.error(f"❌ Universal format validation failed: {issues}")
        return None

    logger.info("✅ Universal Data Adapter: PASSED")
    logger.info(f" ETH ticks: {len(universal_stream.eth_ticks)} samples")
    logger.info(f" ETH 1m: {len(universal_stream.eth_1m)} candles")
    logger.info(f" ETH 1h: {len(universal_stream.eth_1h)} candles")
    logger.info(f" ETH 1d: {len(universal_stream.eth_1d)} candles")
    logger.info(f" BTC reference: {len(universal_stream.btc_ticks)} samples")
    logger.info(f" Data quality: {universal_stream.metadata['data_quality']['overall_score']:.2f}")
    return adapter, universal_stream


async def _check_orchestrator(data_provider):
    """Test 2: build the orchestrator and request coordinated decisions.

    Returns:
        The orchestrator on success, ``None`` when it has no
        ``universal_adapter`` attribute.
    """
    _log_section("🎯 TEST 2: Enhanced Orchestrator")

    orchestrator = EnhancedTradingOrchestrator(data_provider)

    # Test that orchestrator uses universal adapter
    if not hasattr(orchestrator, 'universal_adapter'):
        logger.error("❌ Orchestrator missing universal_adapter")
        return None

    # Test coordinated decisions
    decisions = await orchestrator.make_coordinated_decisions()

    logger.info("✅ Enhanced Orchestrator: PASSED")
    logger.info(f" Generated {len(decisions)} decisions")
    logger.info(f" Universal adapter: {type(orchestrator.universal_adapter).__name__}")
    for symbol, decision in decisions.items():
        if decision:
            logger.info(f" {symbol}: {decision.action} (confidence: {decision.confidence:.2f})")
    return orchestrator


def _check_cnn_format(adapter, universal_stream):
    """Test 3: confirm the CNN-formatted data carries all five stream keys."""
    _log_section("🧠 TEST 3: CNN Model Data Format")

    # Format data for CNN
    cnn_data = adapter.format_for_model(universal_stream, 'cnn')

    required_cnn_keys = ['eth_ticks', 'eth_1m', 'eth_1h', 'eth_1d', 'btc_ticks']
    missing_keys = [key for key in required_cnn_keys if key not in cnn_data]
    if missing_keys:
        logger.error(f"❌ CNN data missing keys: {missing_keys}")
        return False

    logger.info("✅ CNN Model Data Format: PASSED")
    for key, data in cnn_data.items():
        if isinstance(data, np.ndarray):
            logger.info(f" {key}: shape {data.shape}")
        else:
            logger.info(f" {key}: {type(data)}")
    return True


def _check_rl_format(adapter, universal_stream):
    """Test 4: confirm the RL-formatted data has a numpy ``state_vector``."""
    _log_section("🤖 TEST 4: RL Model Data Format")

    # Format data for RL
    rl_data = adapter.format_for_model(universal_stream, 'rl')

    if 'state_vector' not in rl_data:
        logger.error("❌ RL data missing state_vector")
        return False

    state_vector = rl_data['state_vector']
    if not isinstance(state_vector, np.ndarray):
        logger.error("❌ RL state_vector is not numpy array")
        return False

    logger.info("✅ RL Model Data Format: PASSED")
    logger.info(f" State vector shape: {state_vector.shape}")
    # Use .size (total element count) rather than len(): len() reports only
    # the first axis and raises TypeError on a 0-d array.
    logger.info(f" State vector size: {state_vector.size} features")
    return True


def _check_cnn_trainer(config, orchestrator):
    """Test 5: confirm the CNN trainer constructs and exposes its model info."""
    _log_section("🎓 TEST 5: CNN Trainer Integration")

    try:
        cnn_trainer = EnhancedCNNTrainer(config, orchestrator)
        logger.info("✅ CNN Trainer Integration: PASSED")
        logger.info(f" Model timeframes: {cnn_trainer.model.timeframes}")
        logger.info(f" Model device: {cnn_trainer.model.device}")
    except Exception as e:
        logger.error(f"❌ CNN Trainer Integration failed: {e}")
        return False
    return True


def _check_rl_trainer(config, orchestrator):
    """Test 6: confirm the RL trainer constructs and exposes its agents."""
    _log_section("🎮 TEST 6: RL Trainer Integration")

    try:
        rl_trainer = EnhancedRLTrainer(config, orchestrator)
        logger.info("✅ RL Trainer Integration: PASSED")
        logger.info(f" RL agents: {len(rl_trainer.agents)}")
        for symbol, agent in rl_trainer.agents.items():
            logger.info(f" {symbol} agent: {type(agent).__name__}")
    except Exception as e:
        logger.error(f"❌ RL Trainer Integration failed: {e}")
        return False
    return True


async def _check_data_flow(orchestrator, universal_stream):
    """Test 7: ask the orchestrator for ETH/USDT predictions on the stream.

    This check is advisory: an empty result only produces a warning.

    Returns:
        ``True`` when at least one prediction was produced, else ``False``.
    """
    _log_section("🔄 TEST 7: Data Flow Verification")

    # Look the state container up once; pass its last entry, or None if empty.
    # NOTE(review): presumably a per-symbol sequence of market states -- confirm.
    eth_states = orchestrator.market_states['ETH/USDT']
    latest_state = list(eth_states)[-1] if eth_states else None

    # Verify that models receive the correct data format
    test_predictions = await orchestrator._get_enhanced_predictions_universal(
        'ETH/USDT',
        latest_state,
        universal_stream
    )

    if not test_predictions:
        logger.warning("⚠️ No predictions generated (may be normal if no models loaded)")
        return False

    logger.info("✅ Data Flow Verification: PASSED")
    for pred in test_predictions:
        logger.info(f" Model: {pred.model_name}")
        logger.info(f" Action: {pred.overall_action}")
        logger.info(f" Confidence: {pred.overall_confidence:.2f}")
        logger.info(f" Timeframes: {len(pred.timeframe_predictions)}")
    return True


def _check_config_compliance(config):
    """Test 8: compare config symbols/timeframes with the universal format.

    This check is advisory: mismatches are logged as warnings only.
    """
    _log_section("⚙️ TEST 8: Configuration Compliance")

    # Check that config matches universal format
    expected_symbols = ['ETH/USDT', 'BTC/USDT']
    expected_timeframes = ['1s', '1m', '1h', '1d']

    config_symbols = config.symbols
    config_timeframes = config.timeframes

    symbols_match = all(symbol in config_symbols for symbol in expected_symbols)
    timeframes_match = all(tf in config_timeframes for tf in expected_timeframes)

    if not symbols_match:
        logger.warning("⚠️ Config symbols may not match universal format")
        logger.warning(f" Expected: {expected_symbols}")
        logger.warning(f" Config: {config_symbols}")

    if not timeframes_match:
        logger.warning("⚠️ Config timeframes may not match universal format")
        logger.warning(f" Expected: {expected_timeframes}")
        logger.warning(f" Config: {config_timeframes}")

    if symbols_match and timeframes_match:
        logger.info("✅ Configuration Compliance: PASSED")
    else:
        logger.info("⚠️ Configuration Compliance: PARTIAL")

    logger.info(f" Symbols: {config_symbols}")
    logger.info(f" Timeframes: {config_timeframes}")


def _log_summary(data_flow_verified):
    """Log the closing summary; report data flow only if it was exercised."""
    logger.info("\n" + "="*80)
    logger.info("🎉 UNIVERSAL DATA FORMAT TEST SUMMARY")
    logger.info("="*80)
    logger.info("✅ All core tests PASSED!")
    logger.info("")
    logger.info("📋 VERIFIED COMPLIANCE:")
    logger.info(" ✓ Universal Data Adapter working")
    logger.info(" ✓ Enhanced Orchestrator using universal format")
    logger.info(" ✓ CNN models receive 5 timeseries streams")
    logger.info(" ✓ RL models receive combined state vector")
    logger.info(" ✓ Trainers properly integrated")
    # Do not claim the data flow was verified when Test 7 produced nothing.
    if data_flow_verified:
        logger.info(" ✓ Data flow verified")
    else:
        logger.info(" ⚠️ Data flow not verified (no predictions generated)")
    logger.info("")
    logger.info("🎯 UNIVERSAL FORMAT ACTIVE:")
    logger.info(" 1. ETH/USDT ticks (1s) ✓")
    logger.info(" 2. ETH/USDT 1m ✓")
    logger.info(" 3. ETH/USDT 1h ✓")
    logger.info(" 4. ETH/USDT 1d ✓")
    logger.info(" 5. BTC/USDT reference ticks ✓")
    logger.info("")
    logger.info("🚀 Your enhanced trading system is ready with universal data format!")
    logger.info("="*80)


async def test_universal_data_format():
    """Test that all components properly use the universal 5-timeseries format.

    Runs eight checks in order. Tests 1-6 (adapter, orchestrator, CNN/RL data
    formats, CNN/RL trainers) are hard requirements: the first failure logs an
    error and stops the run. Tests 7-8 (data flow, configuration) are
    advisory and only log warnings.

    Returns:
        ``True`` when every hard check passed, ``False`` when one failed or
        any ``Exception`` was raised inside the checks (it is logged with its
        traceback rather than propagated).
    """
    logger.info("="*80)
    logger.info("🧪 TESTING UNIVERSAL DATA FORMAT COMPLIANCE")
    logger.info("="*80)

    try:
        # Initialize components
        config = get_config()
        data_provider = DataProvider(config)

        adapter_result = _check_universal_adapter(data_provider)
        if adapter_result is None:
            return False
        adapter, universal_stream = adapter_result

        orchestrator = await _check_orchestrator(data_provider)
        if orchestrator is None:
            return False

        if not _check_cnn_format(adapter, universal_stream):
            return False
        if not _check_rl_format(adapter, universal_stream):
            return False
        if not _check_cnn_trainer(config, orchestrator):
            return False
        if not _check_rl_trainer(config, orchestrator):
            return False

        # Advisory checks: they never fail the run.
        data_flow_verified = await _check_data_flow(orchestrator, universal_stream)
        _check_config_compliance(config)

        _log_summary(data_flow_verified)
        return True

    except Exception as e:
        # logger.exception records the message and the active traceback together.
        logger.exception(f"❌ Universal data format test failed: {e}")
        return False
async def main():
    """Run the compliance test; exit the process with status 1 on failure."""
    logger.info("🚀 Starting Universal Data Format Compliance Test...")

    passed = await test_universal_data_format()
    if not passed:
        logger.error("\n💥 Tests failed! Please check the universal data format implementation.")
        sys.exit(1)

    # Reached only on success: sys.exit above raises SystemExit.
    logger.info("\n🎉 All tests passed! Universal data format is properly implemented.")
    logger.info("Your enhanced trading system respects the 5-timeseries input format.")
# Run the async entry point only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())