folder stricture reorganize
This commit is contained in:
201
tests/test_nn_driven_trading.py
Normal file
201
tests/test_nn_driven_trading.py
Normal file
@ -0,0 +1,201 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test NN-Driven Trading System
|
||||
Demonstrates how the system now makes decisions using Neural Networks instead of algorithms
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def test_nn_driven_system():
|
||||
"""Test the NN-driven trading system"""
|
||||
logger.info("=== TESTING NN-DRIVEN TRADING SYSTEM ===")
|
||||
|
||||
try:
|
||||
# Import core components
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||
from core.nn_decision_fusion import ModelPrediction, MarketContext
|
||||
|
||||
# Initialize components
|
||||
config = get_config()
|
||||
data_provider = DataProvider()
|
||||
|
||||
# Initialize NN-driven orchestrator
|
||||
orchestrator = EnhancedTradingOrchestrator(
|
||||
data_provider=data_provider,
|
||||
symbols=['ETH/USDT', 'BTC/USDT'],
|
||||
enhanced_rl_training=True
|
||||
)
|
||||
|
||||
logger.info("✅ NN-driven orchestrator initialized")
|
||||
|
||||
# Test 1: Add mock CNN prediction
|
||||
cnn_prediction = ModelPrediction(
|
||||
model_name="williams_cnn",
|
||||
prediction_type="direction",
|
||||
value=0.6, # Bullish signal
|
||||
confidence=0.8,
|
||||
timestamp=datetime.now(),
|
||||
metadata={'timeframe': '1h', 'feature_importance': [0.2, 0.3, 0.5]}
|
||||
)
|
||||
|
||||
orchestrator.neural_fusion.add_prediction(cnn_prediction)
|
||||
logger.info("🔮 Added CNN prediction: BULLISH (0.6) with 80% confidence")
|
||||
|
||||
# Test 2: Add mock RL prediction
|
||||
rl_prediction = ModelPrediction(
|
||||
model_name="dqn_agent",
|
||||
prediction_type="action",
|
||||
value=0.4, # Moderate buy signal
|
||||
confidence=0.7,
|
||||
timestamp=datetime.now(),
|
||||
metadata={'action_probs': [0.4, 0.2, 0.4]} # [BUY, SELL, HOLD]
|
||||
)
|
||||
|
||||
orchestrator.neural_fusion.add_prediction(rl_prediction)
|
||||
logger.info("🔮 Added RL prediction: MODERATE_BUY (0.4) with 70% confidence")
|
||||
|
||||
# Test 3: Add mock COB RL prediction
|
||||
cob_prediction = ModelPrediction(
|
||||
model_name="cob_rl",
|
||||
prediction_type="direction",
|
||||
value=0.3, # Slightly bullish
|
||||
confidence=0.85,
|
||||
timestamp=datetime.now(),
|
||||
metadata={'cob_imbalance': 0.1, 'liquidity_depth': 150000}
|
||||
)
|
||||
|
||||
orchestrator.neural_fusion.add_prediction(cob_prediction)
|
||||
logger.info("🔮 Added COB RL prediction: SLIGHT_BULLISH (0.3) with 85% confidence")
|
||||
|
||||
# Test 4: Create market context
|
||||
market_context = MarketContext(
|
||||
symbol='ETH/USDT',
|
||||
current_price=2441.50,
|
||||
price_change_1m=0.002, # 0.2% up in 1m
|
||||
price_change_5m=0.008, # 0.8% up in 5m
|
||||
volume_ratio=1.2, # 20% above average volume
|
||||
volatility=0.015, # 1.5% volatility
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
logger.info(f"📊 Market Context: ETH/USDT at ${market_context.current_price}")
|
||||
logger.info(f" 📈 Price changes: 1m: {market_context.price_change_1m:.3f}, 5m: {market_context.price_change_5m:.3f}")
|
||||
logger.info(f" 📊 Volume ratio: {market_context.volume_ratio:.2f}, Volatility: {market_context.volatility:.3f}")
|
||||
|
||||
# Test 5: Make NN decision
|
||||
fusion_decision = orchestrator.neural_fusion.make_decision(
|
||||
symbol='ETH/USDT',
|
||||
market_context=market_context,
|
||||
min_confidence=0.25
|
||||
)
|
||||
|
||||
if fusion_decision:
|
||||
logger.info("🧠 === NN DECISION RESULT ===")
|
||||
logger.info(f" Action: {fusion_decision.action}")
|
||||
logger.info(f" Confidence: {fusion_decision.confidence:.3f}")
|
||||
logger.info(f" Expected Return: {fusion_decision.expected_return:.3f}")
|
||||
logger.info(f" Risk Score: {fusion_decision.risk_score:.3f}")
|
||||
logger.info(f" Position Size: {fusion_decision.position_size:.4f} ETH")
|
||||
logger.info(f" Reasoning: {fusion_decision.reasoning}")
|
||||
logger.info(" Model Contributions:")
|
||||
for model, contribution in fusion_decision.model_contributions.items():
|
||||
logger.info(f" - {model}: {contribution:.1%}")
|
||||
else:
|
||||
logger.warning("❌ No NN decision generated")
|
||||
|
||||
# Test 6: Test coordinated decisions
|
||||
logger.info("\n🎯 Testing coordinated NN decisions...")
|
||||
decisions = await orchestrator.make_coordinated_decisions()
|
||||
|
||||
if decisions:
|
||||
logger.info(f"✅ Generated {len(decisions)} NN-driven trading decisions:")
|
||||
for i, decision in enumerate(decisions):
|
||||
logger.info(f" Decision {i+1}: {decision.symbol} {decision.action} "
|
||||
f"({decision.confidence:.3f} confidence, "
|
||||
f"{decision.quantity:.4f} size)")
|
||||
if hasattr(decision, 'metadata') and decision.metadata:
|
||||
if decision.metadata.get('nn_driven'):
|
||||
logger.info(f" 🧠 NN-DRIVEN: {decision.metadata.get('reasoning', 'No reasoning')}")
|
||||
else:
|
||||
logger.info("ℹ️ No trading decisions generated (insufficient confidence)")
|
||||
|
||||
# Test 7: Check NN system status
|
||||
nn_status = orchestrator.neural_fusion.get_status()
|
||||
logger.info("\n📊 NN System Status:")
|
||||
logger.info(f" Device: {nn_status['device']}")
|
||||
logger.info(f" Training Mode: {nn_status['training_mode']}")
|
||||
logger.info(f" Registered Models: {nn_status['registered_models']}")
|
||||
logger.info(f" Recent Predictions: {nn_status['recent_predictions']}")
|
||||
logger.info(f" Model Parameters: {nn_status['model_parameters']:,}")
|
||||
|
||||
# Test 8: Demonstrate different confidence scenarios
|
||||
logger.info("\n🔬 Testing different confidence scenarios...")
|
||||
|
||||
# Low confidence scenario
|
||||
low_conf_prediction = ModelPrediction(
|
||||
model_name="williams_cnn",
|
||||
prediction_type="direction",
|
||||
value=0.1, # Weak signal
|
||||
confidence=0.2, # Low confidence
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
orchestrator.neural_fusion.add_prediction(low_conf_prediction)
|
||||
low_conf_decision = orchestrator.neural_fusion.make_decision(
|
||||
symbol='ETH/USDT',
|
||||
market_context=market_context,
|
||||
min_confidence=0.25
|
||||
)
|
||||
|
||||
if low_conf_decision:
|
||||
logger.info(f" Low confidence result: {low_conf_decision.action} (should be HOLD)")
|
||||
else:
|
||||
logger.info(" ✅ Low confidence correctly resulted in no decision")
|
||||
|
||||
# High confidence scenario
|
||||
high_conf_prediction = ModelPrediction(
|
||||
model_name="williams_cnn",
|
||||
prediction_type="direction",
|
||||
value=0.8, # Strong signal
|
||||
confidence=0.95, # Very high confidence
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
orchestrator.neural_fusion.add_prediction(high_conf_prediction)
|
||||
high_conf_decision = orchestrator.neural_fusion.make_decision(
|
||||
symbol='ETH/USDT',
|
||||
market_context=market_context,
|
||||
min_confidence=0.25
|
||||
)
|
||||
|
||||
if high_conf_decision:
|
||||
logger.info(f" High confidence result: {high_conf_decision.action} "
|
||||
f"(conf: {high_conf_decision.confidence:.3f}, "
|
||||
f"size: {high_conf_decision.position_size:.4f})")
|
||||
|
||||
logger.info("\n✅ NN-DRIVEN TRADING SYSTEM TEST COMPLETE")
|
||||
logger.info("🎯 Key Benefits Demonstrated:")
|
||||
logger.info(" 1. Multiple NN models provide predictions")
|
||||
logger.info(" 2. Central NN fusion makes final decisions")
|
||||
logger.info(" 3. Market context influences decisions")
|
||||
logger.info(" 4. Confidence thresholds prevent bad trades")
|
||||
logger.info(" 5. Position sizing based on NN outputs")
|
||||
logger.info(" 6. Clear reasoning for every decision")
|
||||
logger.info(" 7. Model contribution tracking")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in NN-driven system test: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_nn_driven_system())
|
Reference in New Issue
Block a user