device tensor fix

This commit is contained in:
Dobromir Popov
2025-07-25 13:59:33 +03:00
parent 78b4bb0f06
commit 1f60c80d67
6 changed files with 495 additions and 45 deletions

1
.gitignore vendored
View File

@ -48,3 +48,4 @@ chrome_user_data/*
.env .env
.env .env
training_data/*

View File

@ -70,6 +70,9 @@ class EnhancedCNNAdapter:
else: else:
self._load_best_checkpoint() self._load_best_checkpoint()
# Final device check and move
self._ensure_model_on_device()
logger.info(f"EnhancedCNNAdapter initialized on {self.device}") logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
def _initialize_model(self): def _initialize_model(self):
@ -88,9 +91,10 @@ class EnhancedCNNAdapter:
# Create model # Create model
self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions) self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
# Ensure model is moved to the correct device
self.model.to(self.device) self.model.to(self.device)
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions}") logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}")
except Exception as e: except Exception as e:
logger.error(f"Error initializing EnhancedCNN model: {e}") logger.error(f"Error initializing EnhancedCNN model: {e}")
@ -102,7 +106,9 @@ class EnhancedCNNAdapter:
if self.model and os.path.exists(checkpoint_path): if self.model and os.path.exists(checkpoint_path):
success = self.model.load(checkpoint_path) success = self.model.load(checkpoint_path)
if success: if success:
logger.info(f"Loaded model from {checkpoint_path}") # Ensure model is moved to the correct device after loading
self.model.to(self.device)
logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}")
return True return True
else: else:
logger.warning(f"Failed to load model from {checkpoint_path}") logger.warning(f"Failed to load model from {checkpoint_path}")
@ -146,7 +152,9 @@ class EnhancedCNNAdapter:
success = self.model.load(best_checkpoint_path) success = self.model.load(best_checkpoint_path)
if success: if success:
logger.info(f"Loaded best checkpoint from {best_checkpoint_path}") # Ensure model is moved to the correct device after loading
self.model.to(self.device)
logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}")
# Log metrics # Log metrics
metrics = best_checkpoint_metadata.get('metrics', {}) metrics = best_checkpoint_metadata.get('metrics', {})
@ -161,7 +169,17 @@ class EnhancedCNNAdapter:
logger.error(f"Error loading best checkpoint: {e}") logger.error(f"Error loading best checkpoint: {e}")
return False return False
def _ensure_model_on_device(self):
"""Ensure model and all its components are on the correct device"""
try:
if self.model:
self.model.to(self.device)
# Also ensure the model's internal device is set correctly
if hasattr(self.model, 'device'):
self.model.device = self.device
logger.debug(f"Model ensured on device {self.device}")
except Exception as e:
logger.error(f"Error ensuring model on device: {e}")
def _create_default_output(self, symbol: str) -> ModelOutput: def _create_default_output(self, symbol: str) -> ModelOutput:
"""Create default output when prediction fails""" """Create default output when prediction fails"""
@ -235,6 +253,9 @@ class EnhancedCNNAdapter:
if features.dim() == 1: if features.dim() == 1:
features = features.unsqueeze(0) features = features.unsqueeze(0)
# Ensure model is on correct device before prediction
self._ensure_model_on_device()
# Set model to evaluation mode # Set model to evaluation mode
self.model.eval() self.model.eval()
@ -399,6 +420,9 @@ class EnhancedCNNAdapter:
logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}") logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)} return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
# Ensure model is on correct device before training
self._ensure_model_on_device()
# Set model to training mode # Set model to training mode
self.model.train() self.model.train()
@ -423,8 +447,8 @@ class EnhancedCNNAdapter:
if len(batch) < 2: if len(batch) < 2:
continue continue
# Prepare batch # Prepare batch - ensure all tensors are on the correct device
features = torch.stack([sample[0] for sample in batch]) features = torch.stack([sample[0].to(self.device) for sample in batch])
actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device) actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device) rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)

View File

@ -299,12 +299,12 @@ class TradingOrchestrator:
logger.warning("DQN Agent not available") logger.warning("DQN Agent not available")
self.rl_agent = None self.rl_agent = None
# Initialize CNN Model # Initialize CNN Model with Adapter
try: try:
from NN.models.standardized_cnn import StandardizedCNN from core.enhanced_cnn_adapter import EnhancedCNNAdapter
self.cnn_model = StandardizedCNN() self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
self.cnn_model.to(self.device) # Move CNN model to the determined device self.cnn_model = self.cnn_adapter.model # Keep reference for compatibility
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
# Load best checkpoint and capture initial state # Load best checkpoint and capture initial state
@ -332,11 +332,12 @@ class TradingOrchestrator:
self.model_states['cnn']['best_loss'] = None self.model_states['cnn']['best_loss'] = None
logger.info("CNN starting fresh - no checkpoint found") logger.info("CNN starting fresh - no checkpoint found")
logger.info("Enhanced CNN model initialized") logger.info("Enhanced CNN adapter initialized")
except ImportError: except ImportError:
try: try:
from NN.models.standardized_cnn import StandardizedCNN from NN.models.standardized_cnn import StandardizedCNN
self.cnn_model = StandardizedCNN() self.cnn_model = StandardizedCNN()
self.cnn_adapter = None # No adapter available
self.cnn_model.to(self.device) # Move basic CNN model to the determined device self.cnn_model.to(self.device) # Move basic CNN model to the determined device
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN
@ -359,6 +360,7 @@ class TradingOrchestrator:
except ImportError: except ImportError:
logger.warning("CNN model not available") logger.warning("CNN model not available")
self.cnn_model = None self.cnn_model = None
self.cnn_adapter = None
self.cnn_optimizer = None # Ensure optimizer is also None if model is not available self.cnn_optimizer = None # Ensure optimizer is also None if model is not available
# Initialize Extrema Trainer # Initialize Extrema Trainer
@ -930,6 +932,11 @@ class TradingOrchestrator:
if model.name not in self.model_performance: if model.name not in self.model_performance:
self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0} self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
# Initialize inference history for this model
if model.name not in self.inference_history:
self.inference_history[model.name] = deque(maxlen=self.max_memory_inferences)
logger.debug(f"Initialized inference history for {model.name}")
logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}") logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
self._normalize_weights() self._normalize_weights()
return True return True
@ -1024,6 +1031,9 @@ class TradingOrchestrator:
except Exception as e: except Exception as e:
logger.error(f"Error in decision callback: {e}") logger.error(f"Error in decision callback: {e}")
# Add training samples based on current market conditions
await self._add_training_samples_from_predictions(symbol, predictions, current_price)
# Clean up memory periodically # Clean up memory periodically
if len(self.recent_decisions[symbol]) % 200 == 0: # Reduced from 50 to 200 if len(self.recent_decisions[symbol]) % 200 == 0: # Reduced from 50 to 200
self.model_registry.cleanup_all_models() self.model_registry.cleanup_all_models()
@ -1034,6 +1044,47 @@ class TradingOrchestrator:
logger.error(f"Error making trading decision for {symbol}: {e}") logger.error(f"Error making trading decision for {symbol}: {e}")
return None return None
async def _add_training_samples_from_predictions(self, symbol: str, predictions: List[Prediction], current_price: float):
"""Add training samples to models based on current predictions and market conditions"""
try:
if not hasattr(self, 'cnn_adapter') or not self.cnn_adapter:
return
# Get recent price data to evaluate if predictions would be correct
recent_prices = self.data_provider.get_recent_prices(symbol, limit=10)
if not recent_prices or len(recent_prices) < 2:
return
# Calculate recent price change
price_change_pct = (current_price - recent_prices[-2]) / recent_prices[-2] * 100
# Add training samples for CNN predictions
for prediction in predictions:
if 'cnn' in prediction.model_name.lower():
# Determine reward based on prediction accuracy
reward = 0.0
if prediction.action == 'BUY' and price_change_pct > 0.1:
reward = min(price_change_pct * 0.1, 1.0) # Positive reward for correct BUY
elif prediction.action == 'SELL' and price_change_pct < -0.1:
reward = min(abs(price_change_pct) * 0.1, 1.0) # Positive reward for correct SELL
elif prediction.action == 'HOLD' and abs(price_change_pct) < 0.1:
reward = 0.1 # Small positive reward for correct HOLD
else:
reward = -0.05 # Small negative reward for incorrect prediction
# Add training sample
self.cnn_adapter.add_training_sample(symbol, prediction.action, reward)
logger.debug(f"Added CNN training sample: {prediction.action}, reward={reward:.3f}, price_change={price_change_pct:.2f}%")
# Trigger training if we have enough samples
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
training_results = self.cnn_adapter.train(epochs=1)
logger.info(f"CNN training completed: loss={training_results.get('loss', 0):.4f}, accuracy={training_results.get('accuracy', 0):.4f}")
except Exception as e:
logger.error(f"Error adding training samples from predictions: {e}")
async def _get_all_predictions(self, symbol: str) -> List[Prediction]: async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
"""Get predictions from all registered models with input data storage""" """Get predictions from all registered models with input data storage"""
predictions = [] predictions = []
@ -1051,8 +1102,12 @@ class TradingOrchestrator:
# Get CNN predictions for each timeframe # Get CNN predictions for each timeframe
cnn_predictions = await self._get_cnn_predictions(model, symbol) cnn_predictions = await self._get_cnn_predictions(model, symbol)
predictions.extend(cnn_predictions) predictions.extend(cnn_predictions)
# Store input data for CNN # Store input data for CNN - store for each prediction
model_input = input_data.get('cnn_input') model_input = input_data.get('cnn_input')
if model_input is not None and cnn_predictions:
# Store inference data for each CNN prediction
for cnn_pred in cnn_predictions:
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time)
elif isinstance(model, RLAgentInterface): elif isinstance(model, RLAgentInterface):
# Get RL prediction # Get RL prediction
@ -1062,6 +1117,8 @@ class TradingOrchestrator:
prediction = rl_prediction prediction = rl_prediction
# Store input data for RL # Store input data for RL
model_input = input_data.get('rl_input') model_input = input_data.get('rl_input')
if model_input is not None:
await self._store_inference_data_async(model_name, model_input, prediction, current_time)
else: else:
# Generic model interface # Generic model interface
@ -1071,15 +1128,20 @@ class TradingOrchestrator:
prediction = generic_prediction prediction = generic_prediction
# Store input data for generic model # Store input data for generic model
model_input = input_data.get('generic_input') model_input = input_data.get('generic_input')
if model_input is not None:
# Store inference data for training (per-model, async) await self._store_inference_data_async(model_name, model_input, prediction, current_time)
if prediction and model_input is not None:
await self._store_inference_data_async(model_name, model_input, prediction, current_time)
except Exception as e: except Exception as e:
logger.error(f"Error getting prediction from {model_name}: {e}") logger.error(f"Error getting prediction from {model_name}: {e}")
continue continue
# Debug: Log inference history status (only if low record count)
total_records = sum(len(history) for history in self.inference_history.values())
if total_records < 10: # Only log when we have few records
logger.debug(f"Total inference records across all models: {total_records}")
for model_name, history in self.inference_history.items():
logger.debug(f" {model_name}: {len(history)} records")
# Trigger training based on previous inference data # Trigger training based on previous inference data
await self._trigger_model_training(symbol) await self._trigger_model_training(symbol)
@ -1130,7 +1192,15 @@ class TradingOrchestrator:
} }
} }
return standardized_input # Create model-specific input data
model_inputs = {
'cnn_input': standardized_input,
'rl_input': standardized_input,
'generic_input': standardized_input,
'standardized_input': standardized_input
}
return model_inputs
except Exception as e: except Exception as e:
logger.error(f"Error collecting standardized model input data: {e}") logger.error(f"Error collecting standardized model input data: {e}")
@ -1139,6 +1209,9 @@ class TradingOrchestrator:
async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime): async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
"""Store inference data per-model with async file operations and memory optimization""" """Store inference data per-model with async file operations and memory optimization"""
try: try:
# Only log first few inference records to avoid spam
if len(self.inference_history.get(model_name, [])) < 3:
logger.debug(f"Storing inference data for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
# Create comprehensive inference record # Create comprehensive inference record
inference_record = { inference_record = {
'timestamp': timestamp.isoformat(), 'timestamp': timestamp.isoformat(),
@ -1214,8 +1287,8 @@ class TradingOrchestrator:
except Exception as e: except Exception as e:
logger.error(f"Error capping model files in {model_dir}: {e}") logger.error(f"Error capping model files in {model_dir}: {e}")
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray: def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
"""Prepare standardized input data for CNN models""" """Prepare standardized input data for CNN models with proper GPU device placement"""
try: try:
# Create feature matrix from OHLCV data # Create feature matrix from OHLCV data
features = [] features = []
@ -1242,16 +1315,18 @@ class TradingOrchestrator:
feature_array = np.pad(feature_array, (0, 300 - len(feature_array)), 'constant') feature_array = np.pad(feature_array, (0, 300 - len(feature_array)), 'constant')
else: else:
feature_array = feature_array[:300] feature_array = feature_array[:300]
return feature_array.reshape(1, -1) # Convert to tensor and move to GPU
return torch.tensor(feature_array.reshape(1, -1), dtype=torch.float32, device=self.device)
else: else:
return np.zeros((1, 300)) # Return zero tensor on GPU
return torch.zeros((1, 300), dtype=torch.float32, device=self.device)
except Exception as e: except Exception as e:
logger.error(f"Error preparing CNN input data: {e}") logger.error(f"Error preparing CNN input data: {e}")
return np.zeros((1, 300)) return torch.zeros((1, 300), dtype=torch.float32, device=self.device)
def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray: def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
"""Prepare standardized input data for RL models""" """Prepare standardized input data for RL models with proper GPU device placement"""
try: try:
# Create state representation # Create state representation
state_features = [] state_features = []
@ -1279,13 +1354,15 @@ class TradingOrchestrator:
state_array = np.pad(state_array, (0, expected_size - len(state_array)), 'constant') state_array = np.pad(state_array, (0, expected_size - len(state_array)), 'constant')
else: else:
state_array = state_array[:expected_size] state_array = state_array[:expected_size]
return state_array # Convert to tensor and move to GPU
return torch.tensor(state_array, dtype=torch.float32, device=self.device)
else: else:
return np.zeros(100) # Return zero tensor on GPU
return torch.zeros(100, dtype=torch.float32, device=self.device)
except Exception as e: except Exception as e:
logger.error(f"Error preparing RL input data: {e}") logger.error(f"Error preparing RL input data: {e}")
return np.zeros(100) return torch.zeros(100, dtype=torch.float32, device=self.device)
def _store_inference_data(self, symbol: str, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime): def _store_inference_data(self, symbol: str, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
"""Store comprehensive inference data for future training with persistent storage""" """Store comprehensive inference data for future training with persistent storage"""
@ -1336,10 +1413,12 @@ class TradingOrchestrator:
'outcome_evaluated': False 'outcome_evaluated': False
} }
# Store in memory (inference history) # Store in memory (inference history) - keyed by model_name
if symbol in self.inference_history: if model_name not in self.inference_history:
self.inference_history[symbol].append(inference_record) self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences)
logger.debug(f"Stored inference data for {model_name} on {symbol}")
self.inference_history[model_name].append(inference_record)
logger.debug(f"Stored inference data for {model_name} on {symbol}")
# Persistent storage to disk (for long-term training data) # Persistent storage to disk (for long-term training data)
self._save_inference_to_disk(inference_record) self._save_inference_to_disk(inference_record)
@ -1512,6 +1591,12 @@ class TradingOrchestrator:
for model_name, model_records in self.inference_history.items(): for model_name, model_records in self.inference_history.items():
all_recent_records.extend(list(model_records)) all_recent_records.extend(list(model_records))
# Only log if we have few records (for debugging)
if len(all_recent_records) < 5:
logger.debug(f"Total inference records for training: {len(all_recent_records)}")
for model_name, model_records in self.inference_history.items():
logger.debug(f" Model {model_name} has {len(model_records)} inference records")
if len(all_recent_records) < 2: if len(all_recent_records) < 2:
logger.debug("Not enough inference records for training") logger.debug("Not enough inference records for training")
return # Need at least 2 records to compare return # Need at least 2 records to compare
@ -1521,12 +1606,11 @@ class TradingOrchestrator:
if current_price is None: if current_price is None:
return return
# Process records that are old enough to evaluate outcomes # Train on the most recent inference record (last prediction made)
cutoff_time = datetime.now() - timedelta(minutes=5) # 5 minutes ago if all_recent_records:
# Get the most recent record for training
for record in recent_records: most_recent_record = max(all_recent_records, key=lambda x: datetime.fromisoformat(x['timestamp']) if isinstance(x['timestamp'], str) else x['timestamp'])
if record['timestamp'] < cutoff_time: await self._evaluate_and_train_on_record(most_recent_record, current_price)
await self._evaluate_and_train_on_record(record, current_price)
except Exception as e: except Exception as e:
logger.error(f"Error triggering model training for {symbol}: {e}") logger.error(f"Error triggering model training for {symbol}: {e}")
@ -1538,6 +1622,10 @@ class TradingOrchestrator:
prediction = record['prediction'] prediction = record['prediction']
timestamp = record['timestamp'] timestamp = record['timestamp']
# Convert timestamp string back to datetime if needed
if isinstance(timestamp, str):
timestamp = datetime.fromisoformat(timestamp)
# Calculate price change since prediction # Calculate price change since prediction
# This is a simplified outcome evaluation - you might want to make it more sophisticated # This is a simplified outcome evaluation - you might want to make it more sophisticated
time_diff = (datetime.now() - timestamp).total_seconds() / 60 # minutes time_diff = (datetime.now() - timestamp).total_seconds() / 60 # minutes
@ -1608,12 +1696,23 @@ class TradingOrchestrator:
) )
logger.debug(f"Added RL training experience: reward={reward}") logger.debug(f"Added RL training experience: reward={reward}")
# Train CNN models # Train CNN models using adapter
elif 'cnn' in model_name.lower() and self.cnn_model: elif 'cnn' in model_name.lower() and hasattr(self, 'cnn_adapter') and self.cnn_adapter:
if hasattr(self.cnn_model, 'train_on_outcome'): # Use the adapter's add_training_sample method
target = 1 if was_correct else 0 actual_action = prediction['action']
self.cnn_model.train_on_outcome(model_input, target) self.cnn_adapter.add_training_sample(record['symbol'], actual_action, reward)
logger.debug(f"Trained CNN on outcome: target={target}") logger.debug(f"Added CNN training sample: action={actual_action}, reward={reward}")
# Trigger training if we have enough samples
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
training_results = self.cnn_adapter.train(epochs=1)
logger.debug(f"CNN training results: {training_results}")
# Fallback for raw CNN model
elif 'cnn' in model_name.lower() and self.cnn_model and hasattr(self.cnn_model, 'train_on_outcome'):
target = 1 if was_correct else 0
self.cnn_model.train_on_outcome(model_input, target)
logger.debug(f"Trained CNN on outcome: target={target}")
except Exception as e: except Exception as e:
logger.error(f"Error training model on outcome: {e}") logger.error(f"Error training model on outcome: {e}")
@ -2260,8 +2359,8 @@ class TradingOrchestrator:
return return
if not ENHANCED_TRAINING_AVAILABLE: if not ENHANCED_TRAINING_AVAILABLE:
logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled") logger.info("EnhancedRealtimeTrainingSystem not available - using built-in training")
self.training_enabled = False # Keep training enabled - we have built-in training capabilities
return return
# Initialize the enhanced training system # Initialize the enhanced training system

View File

@ -449,5 +449,37 @@ class StandardizedDataProvider(DataProvider):
logger.info("Stopped real-time processing for standardized data") logger.info("Stopped real-time processing for standardized data")
except Exception as e:
logger.error(f"Error stopping real-time processing: {e}")
def get_recent_prices(self, symbol: str, limit: int = 10) -> List[float]:
"""
Get recent prices for a symbol
Args:
symbol: Trading symbol
limit: Number of recent prices to return
Returns:
List[float]: List of recent prices
"""
try:
# Get recent OHLCV data using parent class method
df = self.get_historical_data(symbol, '1m', limit)
if df is None or df.empty:
return []
# Extract close prices from DataFrame
if 'close' in df.columns:
prices = df['close'].tolist()
return prices[-limit:] # Return most recent prices
else:
logger.warning(f"No 'close' column found in OHLCV data for {symbol}")
return []
except Exception as e:
logger.error(f"Error getting recent prices for {symbol}: {e}")
return []
except Exception as e: except Exception as e:
logger.error(f"Error stopping real-time processing: {e}") logger.error(f"Error stopping real-time processing: {e}")

141
test_device_fix.py Normal file
View File

@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""
Test script to verify device mismatch fixes for GPU training
"""
import torch
import logging
import sys
import os
# Add the project root to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import BaseDataInput, OHLCVBar
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_device_consistency():
"""Test that all tensors are on the same device"""
logger.info("Testing device consistency for EnhancedCNN...")
# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")
try:
# Initialize the adapter
adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
# Verify adapter device
logger.info(f"Adapter device: {adapter.device}")
logger.info(f"Model device: {next(adapter.model.parameters()).device}")
# Create sample data
sample_ohlcv = [
OHLCVBar(
symbol="ETH/USDT",
timeframe="1s",
timestamp=1640995200.0, # 2022-01-01
open=50000.0,
high=51000.0,
low=49000.0,
close=50500.0,
volume=1000.0
)
] * 300 # 300 frames
base_data = BaseDataInput(
symbol="ETH/USDT",
timestamp=1640995200.0,
ohlcv_1s=sample_ohlcv,
ohlcv_1m=sample_ohlcv,
ohlcv_5m=sample_ohlcv,
ohlcv_15m=sample_ohlcv,
btc_ohlcv=sample_ohlcv,
cob_data={},
ma_data={},
technical_indicators={},
last_predictions={}
)
# Test prediction
logger.info("Testing prediction...")
prediction = adapter.predict(base_data)
logger.info(f"Prediction successful: {prediction.predictions['action']} (confidence: {prediction.confidence:.3f})")
# Test training sample addition
logger.info("Testing training sample addition...")
adapter.add_training_sample(base_data, "BUY", 0.1)
adapter.add_training_sample(base_data, "SELL", -0.05)
adapter.add_training_sample(base_data, "HOLD", 0.02)
# Test training
logger.info("Testing training...")
training_results = adapter.train(epochs=1)
logger.info(f"Training results: {training_results}")
logger.info("✅ All device consistency tests passed!")
return True
except Exception as e:
logger.error(f"❌ Device consistency test failed: {e}")
import traceback
traceback.print_exc()
return False
def test_orchestrator_inference_history():
"""Test that orchestrator properly initializes inference history"""
logger.info("Testing orchestrator inference history initialization...")
try:
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
# Initialize orchestrator
data_provider = DataProvider()
orchestrator = TradingOrchestrator(data_provider=data_provider)
# Check if inference history is initialized
logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
# Check if models are registered
logger.info(f"Registered models: {list(orchestrator.model_registry.models.keys())}")
# Verify each registered model has inference history
for model_name in orchestrator.model_registry.models.keys():
if model_name in orchestrator.inference_history:
logger.info(f"{model_name} has inference history initialized")
else:
logger.warning(f"{model_name} missing inference history")
logger.info("✅ Orchestrator inference history test completed!")
return True
except Exception as e:
logger.error(f"❌ Orchestrator test failed: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
logger.info("Starting device fix verification tests...")
# Test 1: Device consistency
test1_passed = test_device_consistency()
# Test 2: Orchestrator inference history
test2_passed = test_orchestrator_inference_history()
# Summary
if test1_passed and test2_passed:
logger.info("🎉 All tests passed! Device issues should be fixed.")
sys.exit(0)
else:
logger.error("❌ Some tests failed. Please check the logs above.")
sys.exit(1)

153
test_device_training_fix.py Normal file
View File

@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""
Test script to verify device handling and training sample population fixes
"""
import logging
import asyncio
import torch
from datetime import datetime
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_device_handling():
"""Test that device handling is working correctly"""
try:
logger.info("Testing device handling...")
# Test 1: Check CUDA availability
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
logger.info(f"CUDA available: {cuda_available}")
logger.info(f"Using device: {device}")
# Test 2: Initialize CNN adapter
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
logger.info("Initializing CNN adapter...")
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
logger.info(f"CNN adapter device: {cnn_adapter.device}")
logger.info(f"CNN model device: {cnn_adapter.model.device}")
# Test 3: Create test data
from core.data_models import BaseDataInput
logger.info("Creating test BaseDataInput...")
base_data = BaseDataInput(
symbol="ETH/USDT",
timestamp=datetime.now(),
ohlcv_1s=[],
ohlcv_1m=[],
ohlcv_1h=[],
ohlcv_1d=[],
btc_ohlcv_1s=[],
cob_data=None,
technical_indicators={},
last_predictions={}
)
# Test 4: Make prediction (this should not cause device mismatch)
logger.info("Making prediction...")
prediction = cnn_adapter.predict(base_data)
logger.info(f"Prediction successful: {prediction.predictions['action']}")
logger.info(f"Confidence: {prediction.confidence:.4f}")
# Test 5: Add training samples
logger.info("Adding training samples...")
cnn_adapter.add_training_sample(base_data, "BUY", 0.1)
cnn_adapter.add_training_sample(base_data, "SELL", -0.05)
cnn_adapter.add_training_sample(base_data, "HOLD", 0.02)
logger.info(f"Training samples added: {len(cnn_adapter.training_data)}")
# Test 6: Try training if we have enough samples
if len(cnn_adapter.training_data) >= 2:
logger.info("Attempting training...")
training_results = cnn_adapter.train(epochs=1)
logger.info(f"Training results: {training_results}")
else:
logger.info("Not enough samples for training")
logger.info("✅ Device handling test passed!")
return True
except Exception as e:
logger.error(f"❌ Device handling test failed: {e}")
import traceback
traceback.print_exc()
return False
async def test_orchestrator_training():
"""Test that orchestrator properly adds training samples"""
try:
logger.info("Testing orchestrator training integration...")
# Test 1: Initialize orchestrator
from core.orchestrator import TradingOrchestrator
from core.standardized_data_provider import StandardizedDataProvider
logger.info("Initializing data provider...")
data_provider = StandardizedDataProvider()
logger.info("Initializing orchestrator...")
orchestrator = TradingOrchestrator(data_provider=data_provider)
# Test 2: Check if CNN adapter is available
if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter:
logger.info(f"✅ CNN adapter available in orchestrator")
logger.info(f"Initial training samples: {len(orchestrator.cnn_adapter.training_data)}")
else:
logger.warning("⚠️ CNN adapter not available in orchestrator")
return False
# Test 3: Make a trading decision (this should add training samples)
logger.info("Making trading decision...")
decision = await orchestrator.make_trading_decision("ETH/USDT")
if decision:
logger.info(f"Decision: {decision.action} (confidence: {decision.confidence:.4f})")
logger.info(f"Training samples after decision: {len(orchestrator.cnn_adapter.training_data)}")
else:
logger.warning("No decision made")
# Test 4: Check inference history
logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
for model_name, history in orchestrator.inference_history.items():
logger.info(f" {model_name}: {len(history)} records")
logger.info("✅ Orchestrator training test passed!")
return True
except Exception as e:
logger.error(f"❌ Orchestrator training test failed: {e}")
import traceback
traceback.print_exc()
return False
async def main():
"""Run all tests"""
logger.info("Starting device and training fix tests...")
# Test 1: Device handling
test1_passed = test_device_handling()
# Test 2: Orchestrator training
test2_passed = await test_orchestrator_training()
# Summary
logger.info("\n" + "="*50)
logger.info("TEST SUMMARY:")
logger.info(f"Device handling: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
logger.info(f"Orchestrator training: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
if test1_passed and test2_passed:
logger.info("🎉 All tests passed! Device and training issues should be fixed.")
else:
logger.error("❌ Some tests failed. Please check the logs above.")
if __name__ == "__main__":
asyncio.run(main())