153 lines
5.8 KiB
Python
153 lines
5.8 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script to verify device handling and training sample population fixes
|
|
"""
|
|
|
|
import logging
|
|
import asyncio
|
|
import torch
|
|
from datetime import datetime
|
|
|
|
# Configure root logging for the script: INFO and above, each record prefixed
# with its timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by every test function in this script.
logger = logging.getLogger(__name__)
|
|
|
|
def _device_of(module):
    """Return the device a model reports or lives on, or ``None`` if unknown.

    ``torch.nn.Module`` has no built-in ``device`` attribute, so an explicit
    ``device`` attribute is used when the model defines one; otherwise the
    device of its first parameter (then first buffer) is returned.
    """
    explicit = getattr(module, "device", None)
    if explicit is not None:
        return explicit
    for accessor in ("parameters", "buffers"):
        tensors = getattr(module, accessor, None)
        if tensors is None:
            continue
        first = next(iter(tensors()), None)
        if first is not None:
            return first.device
    return None


def _device_type(device):
    """Return the backend type (e.g. ``'cpu'``, ``'cuda'``) of a device or device string."""
    if isinstance(device, torch.device):
        return device.type
    return torch.device(str(device)).type


def test_device_handling():
    """Check that the Enhanced CNN adapter predicts and trains on a consistent device.

    Steps: log CUDA availability, build an ``EnhancedCNNAdapter``, verify that
    the adapter and its model agree on device type, run one prediction on an
    empty ``BaseDataInput``, add three training samples and, if at least two
    are stored, run a single training epoch.

    Returns:
        True if every step completed; False if any step raised (the exception
        and its traceback are logged).
    """
    try:
        logger.info("Testing device handling...")

        # Test 1: Check CUDA availability
        cuda_available = torch.cuda.is_available()
        device = torch.device("cuda" if cuda_available else "cpu")
        logger.info(f"CUDA available: {cuda_available}")
        logger.info(f"Using device: {device}")

        # Test 2: Initialize CNN adapter. The import sits inside the try so an
        # ImportError is reported as a test failure.
        from core.enhanced_cnn_adapter import EnhancedCNNAdapter

        logger.info("Initializing CNN adapter...")
        cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")

        model_device = _device_of(cnn_adapter.model)
        logger.info(f"CNN adapter device: {cnn_adapter.device}")
        logger.info(f"CNN model device: {model_device}")

        # An adapter/model device disagreement is exactly what this test is
        # meant to catch, so fail instead of only logging both values.
        # Only the backend type is compared ('cuda' vs 'cuda:0' is not a mismatch).
        if model_device is None:
            logger.warning("Could not determine CNN model device; skipping device comparison")
        elif _device_type(cnn_adapter.device) != _device_type(model_device):
            raise RuntimeError(
                f"Device mismatch: adapter on {cnn_adapter.device}, model on {model_device}"
            )

        # Test 3: Create test data
        from core.data_models import BaseDataInput

        logger.info("Creating test BaseDataInput...")
        base_data = BaseDataInput(
            symbol="ETH/USDT",
            timestamp=datetime.now(),
            ohlcv_1s=[],
            ohlcv_1m=[],
            ohlcv_1h=[],
            ohlcv_1d=[],
            btc_ohlcv_1s=[],
            cob_data=None,
            technical_indicators={},
            last_predictions={},
        )

        # Test 4: Make prediction (this should not cause device mismatch)
        logger.info("Making prediction...")
        prediction = cnn_adapter.predict(base_data)
        logger.info(f"Prediction successful: {prediction.predictions['action']}")
        logger.info(f"Confidence: {prediction.confidence:.4f}")

        # Test 5: Add training samples
        logger.info("Adding training samples...")
        cnn_adapter.add_training_sample(base_data, "BUY", 0.1)
        cnn_adapter.add_training_sample(base_data, "SELL", -0.05)
        cnn_adapter.add_training_sample(base_data, "HOLD", 0.02)
        logger.info(f"Training samples added: {len(cnn_adapter.training_data)}")

        # Test 6: Try training if we have enough samples
        if len(cnn_adapter.training_data) >= 2:
            logger.info("Attempting training...")
            training_results = cnn_adapter.train(epochs=1)
            logger.info(f"Training results: {training_results}")
        else:
            logger.info("Not enough samples for training")

        logger.info("✅ Device handling test passed!")
        return True

    except Exception as e:
        # Test boundary: log the message plus full traceback and report failure.
        logger.exception(f"❌ Device handling test failed: {e}")
        return False
|
|
|
|
async def test_orchestrator_training():
    """Check that ``TradingOrchestrator`` exposes a CNN adapter and records training data.

    Builds a ``StandardizedDataProvider`` and a ``TradingOrchestrator``. It
    requires the orchestrator to have a truthy ``cnn_adapter`` and makes one
    trading decision for ``ETH/USDT``. It then logs the adapter's
    training-sample count before and after, plus the record count of each
    ``inference_history`` entry.

    Returns:
        True if all steps completed. A missing decision or an unchanged sample
        count only logs a warning. Returns False if the CNN adapter is missing
        or any step raised; the traceback is logged.
    """
    try:
        logger.info("Testing orchestrator training integration...")

        # Test 1: Initialize orchestrator
        from core.orchestrator import TradingOrchestrator
        from core.standardized_data_provider import StandardizedDataProvider

        logger.info("Initializing data provider...")
        data_provider = StandardizedDataProvider()

        logger.info("Initializing orchestrator...")
        orchestrator = TradingOrchestrator(data_provider=data_provider)

        # Test 2: Check if CNN adapter is available
        cnn_adapter = getattr(orchestrator, "cnn_adapter", None)
        if not cnn_adapter:
            logger.warning("⚠️ CNN adapter not available in orchestrator")
            return False
        logger.info("✅ CNN adapter available in orchestrator")
        samples_before = len(cnn_adapter.training_data)
        logger.info(f"Initial training samples: {samples_before}")

        # Test 3: Make a trading decision (this should add training samples)
        logger.info("Making trading decision...")
        decision = await orchestrator.make_trading_decision("ETH/USDT")

        if decision:
            logger.info(f"Decision: {decision.action} (confidence: {decision.confidence:.4f})")
            # Re-read through the orchestrator in case the decision replaced the adapter.
            samples_after = len(orchestrator.cnn_adapter.training_data)
            logger.info(f"Training samples after decision: {samples_after}")
            # The point of this test is sample population, so surface a
            # non-increase instead of silently passing over it.
            if samples_after <= samples_before:
                logger.warning(
                    f"Training sample count did not increase after decision "
                    f"({samples_before} -> {samples_after})"
                )
        else:
            logger.warning("No decision made")

        # Test 4: Check inference history
        logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
        for model_name, history in orchestrator.inference_history.items():
            logger.info(f" {model_name}: {len(history)} records")

        logger.info("✅ Orchestrator training test passed!")
        return True

    except Exception as e:
        # Test boundary: log the message plus full traceback and report failure.
        logger.exception(f"❌ Orchestrator training test failed: {e}")
        return False
|
|
|
|
async def main():
    """Run both fix-verification tests and log a pass/fail summary.

    The device-handling test runs synchronously and the orchestrator test is
    awaited. Both always run, even if the first one fails.

    Returns:
        True if both tests passed, False otherwise, so a caller can map the
        result to a process exit status.
    """
    logger.info("Starting device and training fix tests...")

    # Test 1: Device handling
    test1_passed = test_device_handling()

    # Test 2: Orchestrator training
    test2_passed = await test_orchestrator_training()

    # Summary
    logger.info("\n" + "=" * 50)
    logger.info("TEST SUMMARY:")
    logger.info(f"Device handling: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
    logger.info(f"Orchestrator training: {'✅ PASSED' if test2_passed else '❌ FAILED'}")

    all_passed = bool(test1_passed and test2_passed)
    if all_passed:
        logger.info("🎉 All tests passed! Device and training issues should be fixed.")
    else:
        logger.error("❌ Some tests failed. Please check the logs above.")
    return all_passed
|
|
|
|
if __name__ == "__main__":
    # Report failure to the shell/CI through a non-zero exit status. Only an
    # explicit False result counts as failure, so a main() that returns
    # nothing keeps exiting with status 0.
    if asyncio.run(main()) is False:
        raise SystemExit(1)