device tensor fix
This commit is contained in:
153
test_device_training_fix.py
Normal file
153
test_device_training_fix.py
Normal file
@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify device handling and training sample population fixes
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import torch
|
||||
from datetime import datetime
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_device_handling():
|
||||
"""Test that device handling is working correctly"""
|
||||
try:
|
||||
logger.info("Testing device handling...")
|
||||
|
||||
# Test 1: Check CUDA availability
|
||||
cuda_available = torch.cuda.is_available()
|
||||
device = torch.device("cuda" if cuda_available else "cpu")
|
||||
logger.info(f"CUDA available: {cuda_available}")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Test 2: Initialize CNN adapter
|
||||
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
|
||||
logger.info("Initializing CNN adapter...")
|
||||
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
|
||||
|
||||
logger.info(f"CNN adapter device: {cnn_adapter.device}")
|
||||
logger.info(f"CNN model device: {cnn_adapter.model.device}")
|
||||
|
||||
# Test 3: Create test data
|
||||
from core.data_models import BaseDataInput
|
||||
|
||||
logger.info("Creating test BaseDataInput...")
|
||||
base_data = BaseDataInput(
|
||||
symbol="ETH/USDT",
|
||||
timestamp=datetime.now(),
|
||||
ohlcv_1s=[],
|
||||
ohlcv_1m=[],
|
||||
ohlcv_1h=[],
|
||||
ohlcv_1d=[],
|
||||
btc_ohlcv_1s=[],
|
||||
cob_data=None,
|
||||
technical_indicators={},
|
||||
last_predictions={}
|
||||
)
|
||||
|
||||
# Test 4: Make prediction (this should not cause device mismatch)
|
||||
logger.info("Making prediction...")
|
||||
prediction = cnn_adapter.predict(base_data)
|
||||
|
||||
logger.info(f"Prediction successful: {prediction.predictions['action']}")
|
||||
logger.info(f"Confidence: {prediction.confidence:.4f}")
|
||||
|
||||
# Test 5: Add training samples
|
||||
logger.info("Adding training samples...")
|
||||
cnn_adapter.add_training_sample(base_data, "BUY", 0.1)
|
||||
cnn_adapter.add_training_sample(base_data, "SELL", -0.05)
|
||||
cnn_adapter.add_training_sample(base_data, "HOLD", 0.02)
|
||||
|
||||
logger.info(f"Training samples added: {len(cnn_adapter.training_data)}")
|
||||
|
||||
# Test 6: Try training if we have enough samples
|
||||
if len(cnn_adapter.training_data) >= 2:
|
||||
logger.info("Attempting training...")
|
||||
training_results = cnn_adapter.train(epochs=1)
|
||||
logger.info(f"Training results: {training_results}")
|
||||
else:
|
||||
logger.info("Not enough samples for training")
|
||||
|
||||
logger.info("✅ Device handling test passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Device handling test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def test_orchestrator_training():
|
||||
"""Test that orchestrator properly adds training samples"""
|
||||
try:
|
||||
logger.info("Testing orchestrator training integration...")
|
||||
|
||||
# Test 1: Initialize orchestrator
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.standardized_data_provider import StandardizedDataProvider
|
||||
|
||||
logger.info("Initializing data provider...")
|
||||
data_provider = StandardizedDataProvider()
|
||||
|
||||
logger.info("Initializing orchestrator...")
|
||||
orchestrator = TradingOrchestrator(data_provider=data_provider)
|
||||
|
||||
# Test 2: Check if CNN adapter is available
|
||||
if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter:
|
||||
logger.info(f"✅ CNN adapter available in orchestrator")
|
||||
logger.info(f"Initial training samples: {len(orchestrator.cnn_adapter.training_data)}")
|
||||
else:
|
||||
logger.warning("⚠️ CNN adapter not available in orchestrator")
|
||||
return False
|
||||
|
||||
# Test 3: Make a trading decision (this should add training samples)
|
||||
logger.info("Making trading decision...")
|
||||
decision = await orchestrator.make_trading_decision("ETH/USDT")
|
||||
|
||||
if decision:
|
||||
logger.info(f"Decision: {decision.action} (confidence: {decision.confidence:.4f})")
|
||||
logger.info(f"Training samples after decision: {len(orchestrator.cnn_adapter.training_data)}")
|
||||
else:
|
||||
logger.warning("No decision made")
|
||||
|
||||
# Test 4: Check inference history
|
||||
logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
|
||||
for model_name, history in orchestrator.inference_history.items():
|
||||
logger.info(f" {model_name}: {len(history)} records")
|
||||
|
||||
logger.info("✅ Orchestrator training test passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Orchestrator training test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def main():
|
||||
"""Run all tests"""
|
||||
logger.info("Starting device and training fix tests...")
|
||||
|
||||
# Test 1: Device handling
|
||||
test1_passed = test_device_handling()
|
||||
|
||||
# Test 2: Orchestrator training
|
||||
test2_passed = await test_orchestrator_training()
|
||||
|
||||
# Summary
|
||||
logger.info("\n" + "="*50)
|
||||
logger.info("TEST SUMMARY:")
|
||||
logger.info(f"Device handling: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
|
||||
logger.info(f"Orchestrator training: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
|
||||
|
||||
if test1_passed and test2_passed:
|
||||
logger.info("🎉 All tests passed! Device and training issues should be fixed.")
|
||||
else:
|
||||
logger.error("❌ Some tests failed. Please check the logs above.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
Reference in New Issue
Block a user