Files
gogo2/test_device_fix.py
2025-07-27 22:13:28 +03:00

141 lines
4.7 KiB
Python

#!/usr/bin/env python3
"""
Test script to verify device mismatch fixes for GPU training
"""
import torch
import logging
import sys
import os
# Add the project root to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from NN.models.enhanced_cnn import EnhancedCNN
from core.data_models import BaseDataInput, OHLCVBar
# Root logging setup for this script: records at INFO and above go to the
# default handler that basicConfig installs (basicConfig does nothing if the
# root logger already has handlers). Every function below logs via ``logger``.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
def _build_sample_base_data():
    """Build a synthetic ``BaseDataInput`` with 300 identical bars per series.

    Placeholder data only: list repetition stores 300 references to the same
    ``OHLCVBar``, and the same list is reused for every timeframe and for
    ``btc_ohlcv``. That is enough to drive the prediction/training code paths,
    not to evaluate model quality.
    """
    sample_ohlcv = [
        OHLCVBar(
            symbol="ETH/USDT",
            timeframe="1s",
            timestamp=1640995200.0,  # 2022-01-01
            open=50000.0,
            high=51000.0,
            low=49000.0,
            close=50500.0,
            volume=1000.0,
        )
    ] * 300  # 300 frames
    return BaseDataInput(
        symbol="ETH/USDT",
        timestamp=1640995200.0,
        ohlcv_1s=sample_ohlcv,
        ohlcv_1m=sample_ohlcv,
        ohlcv_5m=sample_ohlcv,
        ohlcv_15m=sample_ohlcv,
        btc_ohlcv=sample_ohlcv,
        cob_data={},
        ma_data={},
        technical_indicators={},
        last_predictions={},
    )


def _check_model_on_adapter_device(adapter):
    """Raise ``RuntimeError`` unless all model parameters sit on ``adapter.device``.

    An index-less adapter device (e.g. ``cuda``) matches any index of the same
    type, because ``torch.device('cuda') != torch.device('cuda:0')`` even
    though both name a CUDA device. Parameters split across several devices
    are always reported as a failure.
    """
    expected = torch.device(adapter.device)
    param_devices = {p.device for p in adapter.model.parameters()}
    if not param_devices:
        raise RuntimeError("Model has no parameters to check")
    if len(param_devices) > 1:
        raise RuntimeError(
            f"Model parameters are spread across devices: {sorted(map(str, param_devices))}"
        )
    (actual,) = param_devices
    index_matches = expected.index is None or actual.index == expected.index
    if actual.type != expected.type or not index_matches:
        raise RuntimeError(
            f"Model parameters are on {actual} but adapter device is {expected}"
        )


def test_device_consistency():
    """Test that all tensors are on the same device.

    Builds an ``EnhancedCNNAdapter``, checks that the model's parameters live
    on the adapter's device, then runs one prediction, adds three training
    samples and trains for one epoch.

    Returns:
        bool: True if every step succeeds; False if any step raises. The
        exception is logged together with its traceback.
    """
    logger.info("Testing device consistency for EnhancedCNN...")
    # Check if CUDA is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")
    try:
        # The adapter was previously used without ever being imported, so the
        # test always failed with NameError. It is imported here, inside the
        # try, so a missing module is reported as a test failure.
        # NOTE(review): module path assumed to be core/enhanced_cnn_adapter.py
        # (same package layout as core.orchestrator) -- confirm.
        from core.enhanced_cnn_adapter import EnhancedCNNAdapter

        adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
        logger.info(f"Adapter device: {adapter.device}")
        logger.info(f"Model device: {next(adapter.model.parameters()).device}")

        # Actually verify device placement instead of only logging it.
        _check_model_on_adapter_device(adapter)
        if torch.device(adapter.device).type != device.type:
            logger.warning(
                f"Adapter is on {adapter.device} although {device} is available"
            )

        base_data = _build_sample_base_data()

        logger.info("Testing prediction...")
        prediction = adapter.predict(base_data)
        logger.info(f"Prediction successful: {prediction.predictions['action']} (confidence: {prediction.confidence:.3f})")

        logger.info("Testing training sample addition...")
        adapter.add_training_sample(base_data, "BUY", 0.1)
        adapter.add_training_sample(base_data, "SELL", -0.05)
        adapter.add_training_sample(base_data, "HOLD", 0.02)

        logger.info("Testing training...")
        training_results = adapter.train(epochs=1)
        logger.info(f"Training results: {training_results}")
    except Exception as e:
        # logger.exception records the traceback through the logging handlers.
        logger.exception(f"❌ Device consistency test failed: {e}")
        return False

    logger.info("✅ All device consistency tests passed!")
    return True
def test_orchestrator_inference_history():
    """Test that orchestrator properly initializes inference history.

    Creates a ``TradingOrchestrator`` backed by a new ``DataProvider``. Checks
    that every key of ``orchestrator.model_registry.models`` also appears in
    ``orchestrator.inference_history``.

    Returns:
        bool: True if every registered model has an inference-history entry.
        False if any model is missing one (previously this was only logged
        as a warning and the test still passed), or if setup raises (the
        exception is logged with its traceback).
    """
    logger.info("Testing orchestrator inference history initialization...")
    try:
        # Imported lazily so an import failure is reported as a test failure
        # rather than aborting the whole script.
        from core.orchestrator import TradingOrchestrator
        from core.data_provider import DataProvider

        data_provider = DataProvider()
        orchestrator = TradingOrchestrator(data_provider=data_provider)

        history = orchestrator.inference_history
        logger.info(f"Inference history keys: {list(history.keys())}")

        registered_models = list(orchestrator.model_registry.models.keys())
        logger.info(f"Registered models: {registered_models}")

        missing = []
        for model_name in registered_models:
            if model_name in history:
                logger.info(f"{model_name} has inference history initialized")
            else:
                logger.warning(f"{model_name} missing inference history")
                missing.append(model_name)
    except Exception as e:
        logger.exception(f"❌ Orchestrator test failed: {e}")
        return False

    # A missing entry is exactly what this test exists to catch, so it must
    # fail the test instead of being reported as a pass.
    if missing:
        logger.error(f"❌ Inference history missing for models: {missing}")
        return False

    logger.info("✅ Orchestrator inference history test completed!")
    return True
if __name__ == "__main__":
    logger.info("Starting device fix verification tests...")

    # Run both checks up front, in order and without short-circuiting, so
    # each one logs its own outcome before the summary line.
    outcomes = [
        test_device_consistency(),              # Test 1: device consistency
        test_orchestrator_inference_history(),  # Test 2: inference history
    ]

    # Summary: exit status 0 only when every check passed.
    if all(outcomes):
        logger.info("🎉 All tests passed! Device issues should be fixed.")
        exit_code = 0
    else:
        logger.error("❌ Some tests failed. Please check the logs above.")
        exit_code = 1
    sys.exit(exit_code)