101 lines
3.9 KiB
Python
101 lines
3.9 KiB
Python
#!/usr/bin/env python3
"""
Test script to debug orchestrator prediction issues.

Run directly to build a ``TradingOrchestrator`` on top of a ``DataProvider``
and log its configuration, registered models, raw predictions, trading
decisions and fallback prediction so problems can be inspected by hand.
"""

import asyncio
import logging
import os
import sys

# Make the parent of this script's directory importable so the project's
# ``core`` package resolves. Resolve to an absolute path now: when the script
# is run from its own directory ``os.path.dirname(__file__)`` is ``''``, and a
# relative entry would silently break if the working directory later changed.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Project imports must come after the sys.path tweak above.
from core.orchestrator import TradingOrchestrator  # noqa: E402
from core.data_provider import DataProvider  # noqa: E402

# Set up logging: root logger at DEBUG so debug-level messages from any module
# whose logger propagates to the root are shown as well.
logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger(__name__)
async def test_orchestrator_predictions():
    """Exercise the orchestrator prediction pipeline and log the results.

    Creates a ``DataProvider`` and a ``TradingOrchestrator`` and then logs,
    in order:

    1. initialization,
    2. the confidence thresholds and model weights,
    3. the model registry's public methods and which optional model
       attributes are present / non-None,
    4. every prediction from ``_get_all_predictions`` for ETH/USDT,
    5. the outcome of three ``make_trading_decision`` calls, and
    6. the outcome of ``_generate_fallback_prediction``.

    Stages 4-6 each run inside their own error boundary: an exception in one
    is logged with its traceback and the remaining stages still run.
    Nothing is asserted; the log output is meant for manual inspection.
    Errors raised while constructing the provider or orchestrator propagate.
    """
    logger.info("=" * 60)
    logger.info("TESTING ORCHESTRATOR PREDICTIONS")
    logger.info("=" * 60)

    # Initialize components
    logger.info("1. Initializing orchestrator...")
    data_provider = DataProvider()
    orchestrator = TradingOrchestrator(data_provider)

    _log_configuration(orchestrator)
    _log_registered_models(orchestrator)

    symbol = 'ETH/USDT'
    stages = (
        ("predictions", lambda: _log_all_predictions(orchestrator, symbol)),
        ("decision making", lambda: _log_trading_decisions(orchestrator, symbol)),
        ("fallback prediction",
         lambda: _log_fallback_prediction(orchestrator, data_provider, symbol)),
    )
    for label, run_stage in stages:
        try:
            await run_stage()
        except Exception as e:
            # Broad catch is deliberate: this is a diagnostic boundary, so log
            # the full traceback and continue with the next stage.
            logger.exception(f" Error testing {label}: {e}")

    logger.info("=" * 60)
    logger.info("TEST COMPLETE")
    logger.info("=" * 60)


def _log_configuration(orchestrator):
    """Log the orchestrator's confidence thresholds and model weights.

    A missing attribute is reported as ``<not set>`` instead of aborting the
    whole diagnostic run with an ``AttributeError``.
    """
    logger.info("2. Checking orchestrator configuration...")
    for label, attr in (
        ("Confidence threshold", "confidence_threshold"),
        ("Confidence threshold close", "confidence_threshold_close"),
        ("Model weights", "model_weights"),
    ):
        logger.info(f" {label}: {getattr(orchestrator, attr, '<not set>')}")


def _log_registered_models(orchestrator):
    """Log the model registry's public methods and which models are loaded.

    The optional model attributes are inspected whether or not a registry
    exists, because they are separate attributes of the orchestrator.
    """
    logger.info("3. Checking registered models...")
    if hasattr(orchestrator, 'model_registry'):
        # List the registry's public API to see what it offers.
        registry_methods = [
            name for name in dir(orchestrator.model_registry)
            if not name.startswith('_')
        ]
        logger.info(f" Model registry methods: {registry_methods}")
    else:
        logger.info(" No model registry found")

    # Report each optional model only if the orchestrator defines the attribute.
    for label, attr in (
        ("RL Agent", "rl_agent"),
        ("CNN Model", "cnn_model"),
        ("Extrema Trainer", "extrema_trainer"),
    ):
        if hasattr(orchestrator, attr):
            logger.info(f" {label} available: {getattr(orchestrator, attr) is not None}")


async def _log_all_predictions(orchestrator, symbol):
    """Log every prediction ``_get_all_predictions`` returns for *symbol*."""
    logger.info("4. Testing prediction generation...")
    predictions = await orchestrator._get_all_predictions(symbol)
    logger.info(f" Total predictions: {len(predictions)}")
    for i, pred in enumerate(predictions, start=1):
        logger.info(
            f" Prediction {i}: {pred.action} "
            f"(confidence: {pred.confidence:.3f}, model: {pred.model_name})"
        )


async def _log_trading_decisions(orchestrator, symbol, attempts=3, delay=1.0):
    """Call ``make_trading_decision`` *attempts* times and log each result.

    Sleeps *delay* seconds between consecutive calls (not after the last one).
    """
    logger.info("5. Testing decision making...")
    for attempt in range(1, attempts + 1):
        decision = await orchestrator.make_trading_decision(symbol)
        if decision:
            logger.info(f" Decision {attempt}: {decision.action} (confidence: {decision.confidence:.3f})")
            logger.info(f" Reasoning: {decision.reasoning}")
        else:
            logger.info(f" Decision {attempt}: None")

        # Wait a bit between decisions; sleeping after the final one is wasted time.
        if attempt < attempts:
            await asyncio.sleep(delay)


async def _log_fallback_prediction(orchestrator, data_provider, symbol):
    """Log the orchestrator's fallback prediction at the current price.

    Skipped (with a log line) when the provider returns a falsy price
    (e.g. ``None`` or ``0``).
    """
    logger.info("6. Testing fallback prediction...")
    current_price = data_provider.get_current_price(symbol)
    if not current_price:
        logger.info(" No current price available for fallback test")
        return

    fallback = await orchestrator._generate_fallback_prediction(symbol, current_price)
    if fallback:
        logger.info(f" Fallback prediction: {fallback.action} (confidence: {fallback.confidence:.3f})")
    else:
        logger.info(" No fallback prediction generated")
if __name__ == "__main__":
    # Script entry point only (never on import): build the diagnostic
    # coroutine and drive it to completion with asyncio.run.
    diagnostic = test_orchestrator_predictions()
    asyncio.run(diagnostic)