device tensor fix

This commit is contained in:
Dobromir Popov
2025-07-25 13:59:33 +03:00
parent 78b4bb0f06
commit 1f60c80d67
6 changed files with 495 additions and 45 deletions

153
test_device_training_fix.py Normal file
View File

@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""
Test script to verify device handling and training sample population fixes
"""
import logging
import asyncio
import torch
from datetime import datetime
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_device_handling():
"""Test that device handling is working correctly"""
try:
logger.info("Testing device handling...")
# Test 1: Check CUDA availability
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
logger.info(f"CUDA available: {cuda_available}")
logger.info(f"Using device: {device}")
# Test 2: Initialize CNN adapter
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
logger.info("Initializing CNN adapter...")
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
logger.info(f"CNN adapter device: {cnn_adapter.device}")
logger.info(f"CNN model device: {cnn_adapter.model.device}")
# Test 3: Create test data
from core.data_models import BaseDataInput
logger.info("Creating test BaseDataInput...")
base_data = BaseDataInput(
symbol="ETH/USDT",
timestamp=datetime.now(),
ohlcv_1s=[],
ohlcv_1m=[],
ohlcv_1h=[],
ohlcv_1d=[],
btc_ohlcv_1s=[],
cob_data=None,
technical_indicators={},
last_predictions={}
)
# Test 4: Make prediction (this should not cause device mismatch)
logger.info("Making prediction...")
prediction = cnn_adapter.predict(base_data)
logger.info(f"Prediction successful: {prediction.predictions['action']}")
logger.info(f"Confidence: {prediction.confidence:.4f}")
# Test 5: Add training samples
logger.info("Adding training samples...")
cnn_adapter.add_training_sample(base_data, "BUY", 0.1)
cnn_adapter.add_training_sample(base_data, "SELL", -0.05)
cnn_adapter.add_training_sample(base_data, "HOLD", 0.02)
logger.info(f"Training samples added: {len(cnn_adapter.training_data)}")
# Test 6: Try training if we have enough samples
if len(cnn_adapter.training_data) >= 2:
logger.info("Attempting training...")
training_results = cnn_adapter.train(epochs=1)
logger.info(f"Training results: {training_results}")
else:
logger.info("Not enough samples for training")
logger.info("✅ Device handling test passed!")
return True
except Exception as e:
logger.error(f"❌ Device handling test failed: {e}")
import traceback
traceback.print_exc()
return False
async def test_orchestrator_training():
"""Test that orchestrator properly adds training samples"""
try:
logger.info("Testing orchestrator training integration...")
# Test 1: Initialize orchestrator
from core.orchestrator import TradingOrchestrator
from core.standardized_data_provider import StandardizedDataProvider
logger.info("Initializing data provider...")
data_provider = StandardizedDataProvider()
logger.info("Initializing orchestrator...")
orchestrator = TradingOrchestrator(data_provider=data_provider)
# Test 2: Check if CNN adapter is available
if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter:
logger.info(f"✅ CNN adapter available in orchestrator")
logger.info(f"Initial training samples: {len(orchestrator.cnn_adapter.training_data)}")
else:
logger.warning("⚠️ CNN adapter not available in orchestrator")
return False
# Test 3: Make a trading decision (this should add training samples)
logger.info("Making trading decision...")
decision = await orchestrator.make_trading_decision("ETH/USDT")
if decision:
logger.info(f"Decision: {decision.action} (confidence: {decision.confidence:.4f})")
logger.info(f"Training samples after decision: {len(orchestrator.cnn_adapter.training_data)}")
else:
logger.warning("No decision made")
# Test 4: Check inference history
logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
for model_name, history in orchestrator.inference_history.items():
logger.info(f" {model_name}: {len(history)} records")
logger.info("✅ Orchestrator training test passed!")
return True
except Exception as e:
logger.error(f"❌ Orchestrator training test failed: {e}")
import traceback
traceback.print_exc()
return False
async def main():
"""Run all tests"""
logger.info("Starting device and training fix tests...")
# Test 1: Device handling
test1_passed = test_device_handling()
# Test 2: Orchestrator training
test2_passed = await test_orchestrator_training()
# Summary
logger.info("\n" + "="*50)
logger.info("TEST SUMMARY:")
logger.info(f"Device handling: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
logger.info(f"Orchestrator training: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
if test1_passed and test2_passed:
logger.info("🎉 All tests passed! Device and training issues should be fixed.")
else:
logger.error("❌ Some tests failed. Please check the logs above.")
if __name__ == "__main__":
asyncio.run(main())