device tensor fix

Dobromir Popov
2025-07-25 13:59:33 +03:00
parent 78b4bb0f06
commit 1f60c80d67
6 changed files with 495 additions and 45 deletions
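The core pattern behind this fix, as a minimal sketch (the linear layer and tensor shapes below are illustrative stand-ins, not the project's EnhancedCNN): the model's parameters and every input tensor must sit on the same torch.device before a forward pass or training step, otherwise PyTorch raises a device-mismatch error.

import torch
import torch.nn as nn

# Illustrative stand-in for the real model; shapes chosen to mirror the 300-feature input used below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(300, 3).to(device)        # move parameters to the target device

features = torch.randn(1, 300)              # inputs are often created on CPU (e.g. from numpy)
features = features.to(device)              # must be moved to match the model's device

with torch.no_grad():
    logits = model(features)                # no device-mismatch error once both sides agree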

.gitignore (vendored), 1 addition

@ -48,3 +48,4 @@ chrome_user_data/*
.env
.env
training_data/*

core/enhanced_cnn_adapter.py

@ -70,6 +70,9 @@ class EnhancedCNNAdapter:
else:
self._load_best_checkpoint()
# Final device check and move
self._ensure_model_on_device()
logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
def _initialize_model(self):
@ -88,9 +91,10 @@ class EnhancedCNNAdapter:
# Create model
self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
# Ensure model is moved to the correct device
self.model.to(self.device)
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions}")
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}")
except Exception as e:
logger.error(f"Error initializing EnhancedCNN model: {e}")
@ -102,7 +106,9 @@ class EnhancedCNNAdapter:
if self.model and os.path.exists(checkpoint_path):
success = self.model.load(checkpoint_path)
if success:
logger.info(f"Loaded model from {checkpoint_path}")
# Ensure model is moved to the correct device after loading
self.model.to(self.device)
logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}")
return True
else:
logger.warning(f"Failed to load model from {checkpoint_path}")
@ -146,7 +152,9 @@ class EnhancedCNNAdapter:
success = self.model.load(best_checkpoint_path)
if success:
logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")
# Ensure model is moved to the correct device after loading
self.model.to(self.device)
logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}")
# Log metrics
metrics = best_checkpoint_metadata.get('metrics', {})
@ -161,7 +169,17 @@ class EnhancedCNNAdapter:
logger.error(f"Error loading best checkpoint: {e}")
return False
def _ensure_model_on_device(self):
"""Ensure model and all its components are on the correct device"""
try:
if self.model:
self.model.to(self.device)
# Also ensure the model's internal device is set correctly
if hasattr(self.model, 'device'):
self.model.device = self.device
logger.debug(f"Model ensured on device {self.device}")
except Exception as e:
logger.error(f"Error ensuring model on device: {e}")
def _create_default_output(self, symbol: str) -> ModelOutput:
"""Create default output when prediction fails"""
@ -235,6 +253,9 @@ class EnhancedCNNAdapter:
if features.dim() == 1:
features = features.unsqueeze(0)
# Ensure model is on correct device before prediction
self._ensure_model_on_device()
# Set model to evaluation mode
self.model.eval()
@ -399,6 +420,9 @@ class EnhancedCNNAdapter:
logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
# Ensure model is on correct device before training
self._ensure_model_on_device()
# Set model to training mode
self.model.train()
@ -423,8 +447,8 @@ class EnhancedCNNAdapter:
if len(batch) < 2:
continue
# Prepare batch
features = torch.stack([sample[0] for sample in batch])
# Prepare batch - ensure all tensors are on the correct device
features = torch.stack([sample[0].to(self.device) for sample in batch])
actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)
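Why the batch preparation now moves each sample individually: torch.stack refuses to combine tensors that sit on different devices, so normalizing every sample with .to(device) before stacking tolerates samples cached on CPU and GPU alike. A minimal sketch (shapes and sources are illustrative):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cached samples may live on different devices (e.g. loaded from disk vs produced during inference).
batch = [torch.randn(300), torch.randn(300).to(device)]

# Normalize each sample first; stacking mixed-device tensors would raise an error.
features = torch.stack([s.to(device) for s in batch])
assert features.device.type == device.type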

core/orchestrator.py

@ -299,12 +299,12 @@ class TradingOrchestrator:
logger.warning("DQN Agent not available")
self.rl_agent = None
# Initialize CNN Model
# Initialize CNN Model with Adapter
try:
from NN.models.standardized_cnn import StandardizedCNN
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
self.cnn_model = StandardizedCNN()
self.cnn_model.to(self.device) # Move CNN model to the determined device
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
self.cnn_model = self.cnn_adapter.model # Keep reference for compatibility
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
# Load best checkpoint and capture initial state
@ -332,11 +332,12 @@ class TradingOrchestrator:
self.model_states['cnn']['best_loss'] = None
logger.info("CNN starting fresh - no checkpoint found")
logger.info("Enhanced CNN model initialized")
logger.info("Enhanced CNN adapter initialized")
except ImportError:
try:
from NN.models.standardized_cnn import StandardizedCNN
self.cnn_model = StandardizedCNN()
self.cnn_adapter = None # No adapter available
self.cnn_model.to(self.device) # Move basic CNN model to the determined device
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN
@ -359,6 +360,7 @@ class TradingOrchestrator:
except ImportError:
logger.warning("CNN model not available")
self.cnn_model = None
self.cnn_adapter = None
self.cnn_optimizer = None # Ensure optimizer is also None if model is not available
# Initialize Extrema Trainer
@ -930,6 +932,11 @@ class TradingOrchestrator:
if model.name not in self.model_performance:
self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
# Initialize inference history for this model
if model.name not in self.inference_history:
self.inference_history[model.name] = deque(maxlen=self.max_memory_inferences)
logger.debug(f"Initialized inference history for {model.name}")
logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
self._normalize_weights()
return True
@ -1024,6 +1031,9 @@ class TradingOrchestrator:
except Exception as e:
logger.error(f"Error in decision callback: {e}")
# Add training samples based on current market conditions
await self._add_training_samples_from_predictions(symbol, predictions, current_price)
# Clean up memory periodically
if len(self.recent_decisions[symbol]) % 200 == 0: # Reduced from 50 to 200
self.model_registry.cleanup_all_models()
@ -1034,6 +1044,47 @@ class TradingOrchestrator:
logger.error(f"Error making trading decision for {symbol}: {e}")
return None
async def _add_training_samples_from_predictions(self, symbol: str, predictions: List[Prediction], current_price: float):
"""Add training samples to models based on current predictions and market conditions"""
try:
if not hasattr(self, 'cnn_adapter') or not self.cnn_adapter:
return
# Get recent price data to evaluate if predictions would be correct
recent_prices = self.data_provider.get_recent_prices(symbol, limit=10)
if not recent_prices or len(recent_prices) < 2:
return
# Calculate recent price change
price_change_pct = (current_price - recent_prices[-2]) / recent_prices[-2] * 100
# Add training samples for CNN predictions
for prediction in predictions:
if 'cnn' in prediction.model_name.lower():
# Determine reward based on prediction accuracy
reward = 0.0
if prediction.action == 'BUY' and price_change_pct > 0.1:
reward = min(price_change_pct * 0.1, 1.0) # Positive reward for correct BUY
elif prediction.action == 'SELL' and price_change_pct < -0.1:
reward = min(abs(price_change_pct) * 0.1, 1.0) # Positive reward for correct SELL
elif prediction.action == 'HOLD' and abs(price_change_pct) < 0.1:
reward = 0.1 # Small positive reward for correct HOLD
else:
reward = -0.05 # Small negative reward for incorrect prediction
# Add training sample
self.cnn_adapter.add_training_sample(symbol, prediction.action, reward)
logger.debug(f"Added CNN training sample: {prediction.action}, reward={reward:.3f}, price_change={price_change_pct:.2f}%")
# Trigger training if we have enough samples
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
training_results = self.cnn_adapter.train(epochs=1)
logger.info(f"CNN training completed: loss={training_results.get('loss', 0):.4f}, accuracy={training_results.get('accuracy', 0):.4f}")
except Exception as e:
logger.error(f"Error adding training samples from predictions: {e}")
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
"""Get predictions from all registered models with input data storage"""
predictions = []
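The reward rule added in _add_training_samples_from_predictions above can be read as a small pure function. This sketch restates the thresholds from the hunk and checks one worked case; the helper name is ours, not part of the codebase.

def shaped_reward(action: str, price_change_pct: float) -> float:
    """Sketch of the reward rule above; thresholds copied from the diff, not tunable config."""
    if action == 'BUY' and price_change_pct > 0.1:
        return min(price_change_pct * 0.1, 1.0)
    if action == 'SELL' and price_change_pct < -0.1:
        return min(abs(price_change_pct) * 0.1, 1.0)
    if action == 'HOLD' and abs(price_change_pct) < 0.1:
        return 0.1
    return -0.05

# Worked example: a BUY followed by a +0.5% move earns 0.05; the same BUY on a -0.3% move earns -0.05.
assert abs(shaped_reward('BUY', 0.5) - 0.05) < 1e-9
assert shaped_reward('BUY', -0.3) == -0.05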
@ -1051,8 +1102,12 @@ class TradingOrchestrator:
# Get CNN predictions for each timeframe
cnn_predictions = await self._get_cnn_predictions(model, symbol)
predictions.extend(cnn_predictions)
# Store input data for CNN
# Store input data for CNN - store for each prediction
model_input = input_data.get('cnn_input')
if model_input is not None and cnn_predictions:
# Store inference data for each CNN prediction
for cnn_pred in cnn_predictions:
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time)
elif isinstance(model, RLAgentInterface):
# Get RL prediction
@ -1062,6 +1117,8 @@ class TradingOrchestrator:
prediction = rl_prediction
# Store input data for RL
model_input = input_data.get('rl_input')
if model_input is not None:
await self._store_inference_data_async(model_name, model_input, prediction, current_time)
else:
# Generic model interface
@ -1071,15 +1128,20 @@ class TradingOrchestrator:
prediction = generic_prediction
# Store input data for generic model
model_input = input_data.get('generic_input')
# Store inference data for training (per-model, async)
if prediction and model_input is not None:
await self._store_inference_data_async(model_name, model_input, prediction, current_time)
if model_input is not None:
await self._store_inference_data_async(model_name, model_input, prediction, current_time)
except Exception as e:
logger.error(f"Error getting prediction from {model_name}: {e}")
continue
# Debug: Log inference history status (only if low record count)
total_records = sum(len(history) for history in self.inference_history.values())
if total_records < 10: # Only log when we have few records
logger.debug(f"Total inference records across all models: {total_records}")
for model_name, history in self.inference_history.items():
logger.debug(f" {model_name}: {len(history)} records")
# Trigger training based on previous inference data
await self._trigger_model_training(symbol)
@ -1130,7 +1192,15 @@ class TradingOrchestrator:
}
}
return standardized_input
# Create model-specific input data
model_inputs = {
'cnn_input': standardized_input,
'rl_input': standardized_input,
'generic_input': standardized_input,
'standardized_input': standardized_input
}
return model_inputs
except Exception as e:
logger.error(f"Error collecting standardized model input data: {e}")
@ -1139,6 +1209,9 @@ class TradingOrchestrator:
async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
"""Store inference data per-model with async file operations and memory optimization"""
try:
# Only log first few inference records to avoid spam
if len(self.inference_history.get(model_name, [])) < 3:
logger.debug(f"Storing inference data for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
# Create comprehensive inference record
inference_record = {
'timestamp': timestamp.isoformat(),
@ -1214,8 +1287,8 @@ class TradingOrchestrator:
except Exception as e:
logger.error(f"Error capping model files in {model_dir}: {e}")
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
"""Prepare standardized input data for CNN models"""
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
"""Prepare standardized input data for CNN models with proper GPU device placement"""
try:
# Create feature matrix from OHLCV data
features = []
@ -1242,16 +1315,18 @@ class TradingOrchestrator:
feature_array = np.pad(feature_array, (0, 300 - len(feature_array)), 'constant')
else:
feature_array = feature_array[:300]
return feature_array.reshape(1, -1)
# Convert to tensor and move to GPU
return torch.tensor(feature_array.reshape(1, -1), dtype=torch.float32, device=self.device)
else:
return np.zeros((1, 300))
# Return zero tensor on GPU
return torch.zeros((1, 300), dtype=torch.float32, device=self.device)
except Exception as e:
logger.error(f"Error preparing CNN input data: {e}")
return np.zeros((1, 300))
return torch.zeros((1, 300), dtype=torch.float32, device=self.device)
def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
"""Prepare standardized input data for RL models"""
def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
"""Prepare standardized input data for RL models with proper GPU device placement"""
try:
# Create state representation
state_features = []
@ -1279,13 +1354,15 @@ class TradingOrchestrator:
state_array = np.pad(state_array, (0, expected_size - len(state_array)), 'constant')
else:
state_array = state_array[:expected_size]
return state_array
# Convert to tensor and move to GPU
return torch.tensor(state_array, dtype=torch.float32, device=self.device)
else:
return np.zeros(100)
# Return zero tensor on GPU
return torch.zeros(100, dtype=torch.float32, device=self.device)
except Exception as e:
logger.error(f"Error preparing RL input data: {e}")
return np.zeros(100)
return torch.zeros(100, dtype=torch.float32, device=self.device)
def _store_inference_data(self, symbol: str, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
"""Store comprehensive inference data for future training with persistent storage"""
@ -1336,10 +1413,12 @@ class TradingOrchestrator:
'outcome_evaluated': False
}
# Store in memory (inference history)
if symbol in self.inference_history:
self.inference_history[symbol].append(inference_record)
logger.debug(f"Stored inference data for {model_name} on {symbol}")
# Store in memory (inference history) - keyed by model_name
if model_name not in self.inference_history:
self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences)
self.inference_history[model_name].append(inference_record)
logger.debug(f"Stored inference data for {model_name} on {symbol}")
# Persistent storage to disk (for long-term training data)
self._save_inference_to_disk(inference_record)
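Keying the in-memory history by model name and backing it with a bounded deque keeps memory flat no matter how long the orchestrator runs. A minimal sketch of that eviction behavior (the cap value here is illustrative, not the project's configured max_memory_inferences):

from collections import deque

max_memory_inferences = 100      # illustrative cap
inference_history = {}

model_name = "enhanced_cnn"
inference_history.setdefault(model_name, deque(maxlen=max_memory_inferences))
for i in range(250):
    inference_history[model_name].append({"record": i})

assert len(inference_history[model_name]) == 100   # oldest records were evicted automatically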
@ -1512,6 +1591,12 @@ class TradingOrchestrator:
for model_name, model_records in self.inference_history.items():
all_recent_records.extend(list(model_records))
# Only log if we have few records (for debugging)
if len(all_recent_records) < 5:
logger.debug(f"Total inference records for training: {len(all_recent_records)}")
for model_name, model_records in self.inference_history.items():
logger.debug(f" Model {model_name} has {len(model_records)} inference records")
if len(all_recent_records) < 2:
logger.debug("Not enough inference records for training")
return # Need at least 2 records to compare
@ -1521,12 +1606,11 @@ class TradingOrchestrator:
if current_price is None:
return
# Process records that are old enough to evaluate outcomes
cutoff_time = datetime.now() - timedelta(minutes=5) # 5 minutes ago
for record in recent_records:
if record['timestamp'] < cutoff_time:
await self._evaluate_and_train_on_record(record, current_price)
# Train on the most recent inference record (last prediction made)
if all_recent_records:
# Get the most recent record for training
most_recent_record = max(all_recent_records, key=lambda x: datetime.fromisoformat(x['timestamp']) if isinstance(x['timestamp'], str) else x['timestamp'])
await self._evaluate_and_train_on_record(most_recent_record, current_price)
except Exception as e:
logger.error(f"Error triggering model training for {symbol}: {e}")
@ -1538,6 +1622,10 @@ class TradingOrchestrator:
prediction = record['prediction']
timestamp = record['timestamp']
# Convert timestamp string back to datetime if needed
if isinstance(timestamp, str):
timestamp = datetime.fromisoformat(timestamp)
# Calculate price change since prediction
# This is a simplified outcome evaluation - you might want to make it more sophisticated
time_diff = (datetime.now() - timestamp).total_seconds() / 60 # minutes
@ -1608,12 +1696,23 @@ class TradingOrchestrator:
)
logger.debug(f"Added RL training experience: reward={reward}")
# Train CNN models
elif 'cnn' in model_name.lower() and self.cnn_model:
if hasattr(self.cnn_model, 'train_on_outcome'):
target = 1 if was_correct else 0
self.cnn_model.train_on_outcome(model_input, target)
logger.debug(f"Trained CNN on outcome: target={target}")
# Train CNN models using adapter
elif 'cnn' in model_name.lower() and hasattr(self, 'cnn_adapter') and self.cnn_adapter:
# Use the adapter's add_training_sample method
actual_action = prediction['action']
self.cnn_adapter.add_training_sample(record['symbol'], actual_action, reward)
logger.debug(f"Added CNN training sample: action={actual_action}, reward={reward}")
# Trigger training if we have enough samples
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
training_results = self.cnn_adapter.train(epochs=1)
logger.debug(f"CNN training results: {training_results}")
# Fallback for raw CNN model
elif 'cnn' in model_name.lower() and self.cnn_model and hasattr(self.cnn_model, 'train_on_outcome'):
target = 1 if was_correct else 0
self.cnn_model.train_on_outcome(model_input, target)
logger.debug(f"Trained CNN on outcome: target={target}")
except Exception as e:
logger.error(f"Error training model on outcome: {e}")
@ -2260,8 +2359,8 @@ class TradingOrchestrator:
return
if not ENHANCED_TRAINING_AVAILABLE:
logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled")
self.training_enabled = False
logger.info("EnhancedRealtimeTrainingSystem not available - using built-in training")
# Keep training enabled - we have built-in training capabilities
return
# Initialize the enhanced training system

core/standardized_data_provider.py

@ -449,5 +449,37 @@ class StandardizedDataProvider(DataProvider):
logger.info("Stopped real-time processing for standardized data")
except Exception as e:
logger.error(f"Error stopping real-time processing: {e}")
def get_recent_prices(self, symbol: str, limit: int = 10) -> List[float]:
"""
Get recent prices for a symbol
Args:
symbol: Trading symbol
limit: Number of recent prices to return
Returns:
List[float]: List of recent prices
"""
try:
# Get recent OHLCV data using parent class method
df = self.get_historical_data(symbol, '1m', limit)
if df is None or df.empty:
return []
# Extract close prices from DataFrame
if 'close' in df.columns:
prices = df['close'].tolist()
return prices[-limit:] # Return most recent prices
else:
logger.warning(f"No 'close' column found in OHLCV data for {symbol}")
return []
except Exception as e:
logger.error(f"Error getting recent prices for {symbol}: {e}")
return []
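A usage sketch for the new helper (symbol and limit are illustrative; assumes the provider has 1m history available):

from core.standardized_data_provider import StandardizedDataProvider

provider = StandardizedDataProvider()
prices = provider.get_recent_prices("ETH/USDT", limit=10)
if len(prices) >= 2:
    change_pct = (prices[-1] - prices[-2]) / prices[-2] * 100
    print(f"last move: {change_pct:.2f}%")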

test_device_fix.py (new file, 141 lines)

@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""
Test script to verify device mismatch fixes for GPU training
"""
import torch
import logging
import sys
import os
# Add the project root to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import BaseDataInput, OHLCVBar
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_device_consistency():
"""Test that all tensors are on the same device"""
logger.info("Testing device consistency for EnhancedCNN...")
# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")
try:
# Initialize the adapter
adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
# Verify adapter device
logger.info(f"Adapter device: {adapter.device}")
logger.info(f"Model device: {next(adapter.model.parameters()).device}")
# Create sample data
sample_ohlcv = [
OHLCVBar(
symbol="ETH/USDT",
timeframe="1s",
timestamp=1640995200.0, # 2022-01-01
open=50000.0,
high=51000.0,
low=49000.0,
close=50500.0,
volume=1000.0
)
] * 300 # 300 frames
base_data = BaseDataInput(
symbol="ETH/USDT",
timestamp=1640995200.0,
ohlcv_1s=sample_ohlcv,
ohlcv_1m=sample_ohlcv,
ohlcv_5m=sample_ohlcv,
ohlcv_15m=sample_ohlcv,
btc_ohlcv=sample_ohlcv,
cob_data={},
ma_data={},
technical_indicators={},
last_predictions={}
)
# Test prediction
logger.info("Testing prediction...")
prediction = adapter.predict(base_data)
logger.info(f"Prediction successful: {prediction.predictions['action']} (confidence: {prediction.confidence:.3f})")
# Test training sample addition
logger.info("Testing training sample addition...")
adapter.add_training_sample(base_data, "BUY", 0.1)
adapter.add_training_sample(base_data, "SELL", -0.05)
adapter.add_training_sample(base_data, "HOLD", 0.02)
# Test training
logger.info("Testing training...")
training_results = adapter.train(epochs=1)
logger.info(f"Training results: {training_results}")
logger.info("✅ All device consistency tests passed!")
return True
except Exception as e:
logger.error(f"❌ Device consistency test failed: {e}")
import traceback
traceback.print_exc()
return False
def test_orchestrator_inference_history():
"""Test that orchestrator properly initializes inference history"""
logger.info("Testing orchestrator inference history initialization...")
try:
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
# Initialize orchestrator
data_provider = DataProvider()
orchestrator = TradingOrchestrator(data_provider=data_provider)
# Check if inference history is initialized
logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
# Check if models are registered
logger.info(f"Registered models: {list(orchestrator.model_registry.models.keys())}")
# Verify each registered model has inference history
for model_name in orchestrator.model_registry.models.keys():
if model_name in orchestrator.inference_history:
logger.info(f"{model_name} has inference history initialized")
else:
logger.warning(f"{model_name} missing inference history")
logger.info("✅ Orchestrator inference history test completed!")
return True
except Exception as e:
logger.error(f"❌ Orchestrator test failed: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
logger.info("Starting device fix verification tests...")
# Test 1: Device consistency
test1_passed = test_device_consistency()
# Test 2: Orchestrator inference history
test2_passed = test_orchestrator_inference_history()
# Summary
if test1_passed and test2_passed:
logger.info("🎉 All tests passed! Device issues should be fixed.")
sys.exit(0)
else:
logger.error("❌ Some tests failed. Please check the logs above.")
sys.exit(1)

test_device_training_fix.py (new file, 153 lines)

@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""
Test script to verify device handling and training sample population fixes
"""
import logging
import asyncio
import torch
from datetime import datetime
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_device_handling():
"""Test that device handling is working correctly"""
try:
logger.info("Testing device handling...")
# Test 1: Check CUDA availability
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
logger.info(f"CUDA available: {cuda_available}")
logger.info(f"Using device: {device}")
# Test 2: Initialize CNN adapter
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
logger.info("Initializing CNN adapter...")
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
logger.info(f"CNN adapter device: {cnn_adapter.device}")
logger.info(f"CNN model device: {cnn_adapter.model.device}")
# Test 3: Create test data
from core.data_models import BaseDataInput
logger.info("Creating test BaseDataInput...")
base_data = BaseDataInput(
symbol="ETH/USDT",
timestamp=datetime.now(),
ohlcv_1s=[],
ohlcv_1m=[],
ohlcv_1h=[],
ohlcv_1d=[],
btc_ohlcv_1s=[],
cob_data=None,
technical_indicators={},
last_predictions={}
)
# Test 4: Make prediction (this should not cause device mismatch)
logger.info("Making prediction...")
prediction = cnn_adapter.predict(base_data)
logger.info(f"Prediction successful: {prediction.predictions['action']}")
logger.info(f"Confidence: {prediction.confidence:.4f}")
# Test 5: Add training samples
logger.info("Adding training samples...")
cnn_adapter.add_training_sample(base_data, "BUY", 0.1)
cnn_adapter.add_training_sample(base_data, "SELL", -0.05)
cnn_adapter.add_training_sample(base_data, "HOLD", 0.02)
logger.info(f"Training samples added: {len(cnn_adapter.training_data)}")
# Test 6: Try training if we have enough samples
if len(cnn_adapter.training_data) >= 2:
logger.info("Attempting training...")
training_results = cnn_adapter.train(epochs=1)
logger.info(f"Training results: {training_results}")
else:
logger.info("Not enough samples for training")
logger.info("✅ Device handling test passed!")
return True
except Exception as e:
logger.error(f"❌ Device handling test failed: {e}")
import traceback
traceback.print_exc()
return False
async def test_orchestrator_training():
"""Test that orchestrator properly adds training samples"""
try:
logger.info("Testing orchestrator training integration...")
# Test 1: Initialize orchestrator
from core.orchestrator import TradingOrchestrator
from core.standardized_data_provider import StandardizedDataProvider
logger.info("Initializing data provider...")
data_provider = StandardizedDataProvider()
logger.info("Initializing orchestrator...")
orchestrator = TradingOrchestrator(data_provider=data_provider)
# Test 2: Check if CNN adapter is available
if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter:
logger.info(f"✅ CNN adapter available in orchestrator")
logger.info(f"Initial training samples: {len(orchestrator.cnn_adapter.training_data)}")
else:
logger.warning("⚠️ CNN adapter not available in orchestrator")
return False
# Test 3: Make a trading decision (this should add training samples)
logger.info("Making trading decision...")
decision = await orchestrator.make_trading_decision("ETH/USDT")
if decision:
logger.info(f"Decision: {decision.action} (confidence: {decision.confidence:.4f})")
logger.info(f"Training samples after decision: {len(orchestrator.cnn_adapter.training_data)}")
else:
logger.warning("No decision made")
# Test 4: Check inference history
logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
for model_name, history in orchestrator.inference_history.items():
logger.info(f" {model_name}: {len(history)} records")
logger.info("✅ Orchestrator training test passed!")
return True
except Exception as e:
logger.error(f"❌ Orchestrator training test failed: {e}")
import traceback
traceback.print_exc()
return False
async def main():
"""Run all tests"""
logger.info("Starting device and training fix tests...")
# Test 1: Device handling
test1_passed = test_device_handling()
# Test 2: Orchestrator training
test2_passed = await test_orchestrator_training()
# Summary
logger.info("\n" + "="*50)
logger.info("TEST SUMMARY:")
logger.info(f"Device handling: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
logger.info(f"Orchestrator training: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
if test1_passed and test2_passed:
logger.info("🎉 All tests passed! Device and training issues should be fixed.")
else:
logger.error("❌ Some tests failed. Please check the logs above.")
if __name__ == "__main__":
asyncio.run(main())