device tensor fix
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -48,3 +48,4 @@ chrome_user_data/*
|
||||
|
||||
.env
|
||||
.env
|
||||
training_data/*
|
||||
|
@ -70,6 +70,9 @@ class EnhancedCNNAdapter:
|
||||
else:
|
||||
self._load_best_checkpoint()
|
||||
|
||||
# Final device check and move
|
||||
self._ensure_model_on_device()
|
||||
|
||||
logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
|
||||
|
||||
def _initialize_model(self):
|
||||
@ -88,9 +91,10 @@ class EnhancedCNNAdapter:
|
||||
|
||||
# Create model
|
||||
self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
|
||||
# Ensure model is moved to the correct device
|
||||
self.model.to(self.device)
|
||||
|
||||
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions}")
|
||||
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing EnhancedCNN model: {e}")
|
||||
@ -102,7 +106,9 @@ class EnhancedCNNAdapter:
|
||||
if self.model and os.path.exists(checkpoint_path):
|
||||
success = self.model.load(checkpoint_path)
|
||||
if success:
|
||||
logger.info(f"Loaded model from {checkpoint_path}")
|
||||
# Ensure model is moved to the correct device after loading
|
||||
self.model.to(self.device)
|
||||
logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Failed to load model from {checkpoint_path}")
|
||||
@ -146,7 +152,9 @@ class EnhancedCNNAdapter:
|
||||
success = self.model.load(best_checkpoint_path)
|
||||
|
||||
if success:
|
||||
logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")
|
||||
# Ensure model is moved to the correct device after loading
|
||||
self.model.to(self.device)
|
||||
logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}")
|
||||
|
||||
# Log metrics
|
||||
metrics = best_checkpoint_metadata.get('metrics', {})
|
||||
@ -161,7 +169,17 @@ class EnhancedCNNAdapter:
|
||||
logger.error(f"Error loading best checkpoint: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _ensure_model_on_device(self):
|
||||
"""Ensure model and all its components are on the correct device"""
|
||||
try:
|
||||
if self.model:
|
||||
self.model.to(self.device)
|
||||
# Also ensure the model's internal device is set correctly
|
||||
if hasattr(self.model, 'device'):
|
||||
self.model.device = self.device
|
||||
logger.debug(f"Model ensured on device {self.device}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring model on device: {e}")
|
||||
|
||||
def _create_default_output(self, symbol: str) -> ModelOutput:
|
||||
"""Create default output when prediction fails"""
|
||||
@ -235,6 +253,9 @@ class EnhancedCNNAdapter:
|
||||
if features.dim() == 1:
|
||||
features = features.unsqueeze(0)
|
||||
|
||||
# Ensure model is on correct device before prediction
|
||||
self._ensure_model_on_device()
|
||||
|
||||
# Set model to evaluation mode
|
||||
self.model.eval()
|
||||
|
||||
@ -399,6 +420,9 @@ class EnhancedCNNAdapter:
|
||||
logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
|
||||
return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
|
||||
|
||||
# Ensure model is on correct device before training
|
||||
self._ensure_model_on_device()
|
||||
|
||||
# Set model to training mode
|
||||
self.model.train()
|
||||
|
||||
@ -423,8 +447,8 @@ class EnhancedCNNAdapter:
|
||||
if len(batch) < 2:
|
||||
continue
|
||||
|
||||
# Prepare batch
|
||||
features = torch.stack([sample[0] for sample in batch])
|
||||
# Prepare batch - ensure all tensors are on the correct device
|
||||
features = torch.stack([sample[0].to(self.device) for sample in batch])
|
||||
actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
|
||||
rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)
|
||||
|
||||
|
@ -299,12 +299,12 @@ class TradingOrchestrator:
|
||||
logger.warning("DQN Agent not available")
|
||||
self.rl_agent = None
|
||||
|
||||
# Initialize CNN Model
|
||||
# Initialize CNN Model with Adapter
|
||||
try:
|
||||
from NN.models.standardized_cnn import StandardizedCNN
|
||||
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
|
||||
self.cnn_model = StandardizedCNN()
|
||||
self.cnn_model.to(self.device) # Move CNN model to the determined device
|
||||
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
|
||||
self.cnn_model = self.cnn_adapter.model # Keep reference for compatibility
|
||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
|
||||
|
||||
# Load best checkpoint and capture initial state
|
||||
@ -332,11 +332,12 @@ class TradingOrchestrator:
|
||||
self.model_states['cnn']['best_loss'] = None
|
||||
logger.info("CNN starting fresh - no checkpoint found")
|
||||
|
||||
logger.info("Enhanced CNN model initialized")
|
||||
logger.info("Enhanced CNN adapter initialized")
|
||||
except ImportError:
|
||||
try:
|
||||
from NN.models.standardized_cnn import StandardizedCNN
|
||||
self.cnn_model = StandardizedCNN()
|
||||
self.cnn_adapter = None # No adapter available
|
||||
self.cnn_model.to(self.device) # Move basic CNN model to the determined device
|
||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN
|
||||
|
||||
@ -359,6 +360,7 @@ class TradingOrchestrator:
|
||||
except ImportError:
|
||||
logger.warning("CNN model not available")
|
||||
self.cnn_model = None
|
||||
self.cnn_adapter = None
|
||||
self.cnn_optimizer = None # Ensure optimizer is also None if model is not available
|
||||
|
||||
# Initialize Extrema Trainer
|
||||
@ -930,6 +932,11 @@ class TradingOrchestrator:
|
||||
if model.name not in self.model_performance:
|
||||
self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
|
||||
|
||||
# Initialize inference history for this model
|
||||
if model.name not in self.inference_history:
|
||||
self.inference_history[model.name] = deque(maxlen=self.max_memory_inferences)
|
||||
logger.debug(f"Initialized inference history for {model.name}")
|
||||
|
||||
logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
|
||||
self._normalize_weights()
|
||||
return True
|
||||
@ -1024,6 +1031,9 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in decision callback: {e}")
|
||||
|
||||
# Add training samples based on current market conditions
|
||||
await self._add_training_samples_from_predictions(symbol, predictions, current_price)
|
||||
|
||||
# Clean up memory periodically
|
||||
if len(self.recent_decisions[symbol]) % 200 == 0: # Reduced from 50 to 200
|
||||
self.model_registry.cleanup_all_models()
|
||||
@ -1034,6 +1044,47 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error making trading decision for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
async def _add_training_samples_from_predictions(self, symbol: str, predictions: List[Prediction], current_price: float):
|
||||
"""Add training samples to models based on current predictions and market conditions"""
|
||||
try:
|
||||
if not hasattr(self, 'cnn_adapter') or not self.cnn_adapter:
|
||||
return
|
||||
|
||||
# Get recent price data to evaluate if predictions would be correct
|
||||
recent_prices = self.data_provider.get_recent_prices(symbol, limit=10)
|
||||
if not recent_prices or len(recent_prices) < 2:
|
||||
return
|
||||
|
||||
# Calculate recent price change
|
||||
price_change_pct = (current_price - recent_prices[-2]) / recent_prices[-2] * 100
|
||||
|
||||
# Add training samples for CNN predictions
|
||||
for prediction in predictions:
|
||||
if 'cnn' in prediction.model_name.lower():
|
||||
# Determine reward based on prediction accuracy
|
||||
reward = 0.0
|
||||
|
||||
if prediction.action == 'BUY' and price_change_pct > 0.1:
|
||||
reward = min(price_change_pct * 0.1, 1.0) # Positive reward for correct BUY
|
||||
elif prediction.action == 'SELL' and price_change_pct < -0.1:
|
||||
reward = min(abs(price_change_pct) * 0.1, 1.0) # Positive reward for correct SELL
|
||||
elif prediction.action == 'HOLD' and abs(price_change_pct) < 0.1:
|
||||
reward = 0.1 # Small positive reward for correct HOLD
|
||||
else:
|
||||
reward = -0.05 # Small negative reward for incorrect prediction
|
||||
|
||||
# Add training sample
|
||||
self.cnn_adapter.add_training_sample(symbol, prediction.action, reward)
|
||||
logger.debug(f"Added CNN training sample: {prediction.action}, reward={reward:.3f}, price_change={price_change_pct:.2f}%")
|
||||
|
||||
# Trigger training if we have enough samples
|
||||
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
|
||||
training_results = self.cnn_adapter.train(epochs=1)
|
||||
logger.info(f"CNN training completed: loss={training_results.get('loss', 0):.4f}, accuracy={training_results.get('accuracy', 0):.4f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding training samples from predictions: {e}")
|
||||
|
||||
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
|
||||
"""Get predictions from all registered models with input data storage"""
|
||||
predictions = []
|
||||
@ -1051,8 +1102,12 @@ class TradingOrchestrator:
|
||||
# Get CNN predictions for each timeframe
|
||||
cnn_predictions = await self._get_cnn_predictions(model, symbol)
|
||||
predictions.extend(cnn_predictions)
|
||||
# Store input data for CNN
|
||||
# Store input data for CNN - store for each prediction
|
||||
model_input = input_data.get('cnn_input')
|
||||
if model_input is not None and cnn_predictions:
|
||||
# Store inference data for each CNN prediction
|
||||
for cnn_pred in cnn_predictions:
|
||||
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time)
|
||||
|
||||
elif isinstance(model, RLAgentInterface):
|
||||
# Get RL prediction
|
||||
@ -1062,6 +1117,8 @@ class TradingOrchestrator:
|
||||
prediction = rl_prediction
|
||||
# Store input data for RL
|
||||
model_input = input_data.get('rl_input')
|
||||
if model_input is not None:
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time)
|
||||
|
||||
else:
|
||||
# Generic model interface
|
||||
@ -1071,15 +1128,20 @@ class TradingOrchestrator:
|
||||
prediction = generic_prediction
|
||||
# Store input data for generic model
|
||||
model_input = input_data.get('generic_input')
|
||||
|
||||
# Store inference data for training (per-model, async)
|
||||
if prediction and model_input is not None:
|
||||
if model_input is not None:
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting prediction from {model_name}: {e}")
|
||||
continue
|
||||
|
||||
# Debug: Log inference history status (only if low record count)
|
||||
total_records = sum(len(history) for history in self.inference_history.values())
|
||||
if total_records < 10: # Only log when we have few records
|
||||
logger.debug(f"Total inference records across all models: {total_records}")
|
||||
for model_name, history in self.inference_history.items():
|
||||
logger.debug(f" {model_name}: {len(history)} records")
|
||||
|
||||
# Trigger training based on previous inference data
|
||||
await self._trigger_model_training(symbol)
|
||||
|
||||
@ -1130,7 +1192,15 @@ class TradingOrchestrator:
|
||||
}
|
||||
}
|
||||
|
||||
return standardized_input
|
||||
# Create model-specific input data
|
||||
model_inputs = {
|
||||
'cnn_input': standardized_input,
|
||||
'rl_input': standardized_input,
|
||||
'generic_input': standardized_input,
|
||||
'standardized_input': standardized_input
|
||||
}
|
||||
|
||||
return model_inputs
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error collecting standardized model input data: {e}")
|
||||
@ -1139,6 +1209,9 @@ class TradingOrchestrator:
|
||||
async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
|
||||
"""Store inference data per-model with async file operations and memory optimization"""
|
||||
try:
|
||||
# Only log first few inference records to avoid spam
|
||||
if len(self.inference_history.get(model_name, [])) < 3:
|
||||
logger.debug(f"Storing inference data for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
|
||||
# Create comprehensive inference record
|
||||
inference_record = {
|
||||
'timestamp': timestamp.isoformat(),
|
||||
@ -1214,8 +1287,8 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error capping model files in {model_dir}: {e}")
|
||||
|
||||
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
|
||||
"""Prepare standardized input data for CNN models"""
|
||||
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
|
||||
"""Prepare standardized input data for CNN models with proper GPU device placement"""
|
||||
try:
|
||||
# Create feature matrix from OHLCV data
|
||||
features = []
|
||||
@ -1242,16 +1315,18 @@ class TradingOrchestrator:
|
||||
feature_array = np.pad(feature_array, (0, 300 - len(feature_array)), 'constant')
|
||||
else:
|
||||
feature_array = feature_array[:300]
|
||||
return feature_array.reshape(1, -1)
|
||||
# Convert to tensor and move to GPU
|
||||
return torch.tensor(feature_array.reshape(1, -1), dtype=torch.float32, device=self.device)
|
||||
else:
|
||||
return np.zeros((1, 300))
|
||||
# Return zero tensor on GPU
|
||||
return torch.zeros((1, 300), dtype=torch.float32, device=self.device)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing CNN input data: {e}")
|
||||
return np.zeros((1, 300))
|
||||
return torch.zeros((1, 300), dtype=torch.float32, device=self.device)
|
||||
|
||||
def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
|
||||
"""Prepare standardized input data for RL models"""
|
||||
def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
|
||||
"""Prepare standardized input data for RL models with proper GPU device placement"""
|
||||
try:
|
||||
# Create state representation
|
||||
state_features = []
|
||||
@ -1279,13 +1354,15 @@ class TradingOrchestrator:
|
||||
state_array = np.pad(state_array, (0, expected_size - len(state_array)), 'constant')
|
||||
else:
|
||||
state_array = state_array[:expected_size]
|
||||
return state_array
|
||||
# Convert to tensor and move to GPU
|
||||
return torch.tensor(state_array, dtype=torch.float32, device=self.device)
|
||||
else:
|
||||
return np.zeros(100)
|
||||
# Return zero tensor on GPU
|
||||
return torch.zeros(100, dtype=torch.float32, device=self.device)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing RL input data: {e}")
|
||||
return np.zeros(100)
|
||||
return torch.zeros(100, dtype=torch.float32, device=self.device)
|
||||
|
||||
def _store_inference_data(self, symbol: str, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
|
||||
"""Store comprehensive inference data for future training with persistent storage"""
|
||||
@ -1336,9 +1413,11 @@ class TradingOrchestrator:
|
||||
'outcome_evaluated': False
|
||||
}
|
||||
|
||||
# Store in memory (inference history)
|
||||
if symbol in self.inference_history:
|
||||
self.inference_history[symbol].append(inference_record)
|
||||
# Store in memory (inference history) - keyed by model_name
|
||||
if model_name not in self.inference_history:
|
||||
self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences)
|
||||
|
||||
self.inference_history[model_name].append(inference_record)
|
||||
logger.debug(f"Stored inference data for {model_name} on {symbol}")
|
||||
|
||||
# Persistent storage to disk (for long-term training data)
|
||||
@ -1512,6 +1591,12 @@ class TradingOrchestrator:
|
||||
for model_name, model_records in self.inference_history.items():
|
||||
all_recent_records.extend(list(model_records))
|
||||
|
||||
# Only log if we have few records (for debugging)
|
||||
if len(all_recent_records) < 5:
|
||||
logger.debug(f"Total inference records for training: {len(all_recent_records)}")
|
||||
for model_name, model_records in self.inference_history.items():
|
||||
logger.debug(f" Model {model_name} has {len(model_records)} inference records")
|
||||
|
||||
if len(all_recent_records) < 2:
|
||||
logger.debug("Not enough inference records for training")
|
||||
return # Need at least 2 records to compare
|
||||
@ -1521,12 +1606,11 @@ class TradingOrchestrator:
|
||||
if current_price is None:
|
||||
return
|
||||
|
||||
# Process records that are old enough to evaluate outcomes
|
||||
cutoff_time = datetime.now() - timedelta(minutes=5) # 5 minutes ago
|
||||
|
||||
for record in recent_records:
|
||||
if record['timestamp'] < cutoff_time:
|
||||
await self._evaluate_and_train_on_record(record, current_price)
|
||||
# Train on the most recent inference record (last prediction made)
|
||||
if all_recent_records:
|
||||
# Get the most recent record for training
|
||||
most_recent_record = max(all_recent_records, key=lambda x: datetime.fromisoformat(x['timestamp']) if isinstance(x['timestamp'], str) else x['timestamp'])
|
||||
await self._evaluate_and_train_on_record(most_recent_record, current_price)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error triggering model training for {symbol}: {e}")
|
||||
@ -1538,6 +1622,10 @@ class TradingOrchestrator:
|
||||
prediction = record['prediction']
|
||||
timestamp = record['timestamp']
|
||||
|
||||
# Convert timestamp string back to datetime if needed
|
||||
if isinstance(timestamp, str):
|
||||
timestamp = datetime.fromisoformat(timestamp)
|
||||
|
||||
# Calculate price change since prediction
|
||||
# This is a simplified outcome evaluation - you might want to make it more sophisticated
|
||||
time_diff = (datetime.now() - timestamp).total_seconds() / 60 # minutes
|
||||
@ -1608,9 +1696,20 @@ class TradingOrchestrator:
|
||||
)
|
||||
logger.debug(f"Added RL training experience: reward={reward}")
|
||||
|
||||
# Train CNN models
|
||||
elif 'cnn' in model_name.lower() and self.cnn_model:
|
||||
if hasattr(self.cnn_model, 'train_on_outcome'):
|
||||
# Train CNN models using adapter
|
||||
elif 'cnn' in model_name.lower() and hasattr(self, 'cnn_adapter') and self.cnn_adapter:
|
||||
# Use the adapter's add_training_sample method
|
||||
actual_action = prediction['action']
|
||||
self.cnn_adapter.add_training_sample(record['symbol'], actual_action, reward)
|
||||
logger.debug(f"Added CNN training sample: action={actual_action}, reward={reward}")
|
||||
|
||||
# Trigger training if we have enough samples
|
||||
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
|
||||
training_results = self.cnn_adapter.train(epochs=1)
|
||||
logger.debug(f"CNN training results: {training_results}")
|
||||
|
||||
# Fallback for raw CNN model
|
||||
elif 'cnn' in model_name.lower() and self.cnn_model and hasattr(self.cnn_model, 'train_on_outcome'):
|
||||
target = 1 if was_correct else 0
|
||||
self.cnn_model.train_on_outcome(model_input, target)
|
||||
logger.debug(f"Trained CNN on outcome: target={target}")
|
||||
@ -2260,8 +2359,8 @@ class TradingOrchestrator:
|
||||
return
|
||||
|
||||
if not ENHANCED_TRAINING_AVAILABLE:
|
||||
logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled")
|
||||
self.training_enabled = False
|
||||
logger.info("EnhancedRealtimeTrainingSystem not available - using built-in training")
|
||||
# Keep training enabled - we have built-in training capabilities
|
||||
return
|
||||
|
||||
# Initialize the enhanced training system
|
||||
|
@ -451,3 +451,35 @@ class StandardizedDataProvider(DataProvider):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping real-time processing: {e}")
|
||||
|
||||
def get_recent_prices(self, symbol: str, limit: int = 10) -> List[float]:
|
||||
"""
|
||||
Get recent prices for a symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
limit: Number of recent prices to return
|
||||
|
||||
Returns:
|
||||
List[float]: List of recent prices
|
||||
"""
|
||||
try:
|
||||
# Get recent OHLCV data using parent class method
|
||||
df = self.get_historical_data(symbol, '1m', limit)
|
||||
if df is None or df.empty:
|
||||
return []
|
||||
|
||||
# Extract close prices from DataFrame
|
||||
if 'close' in df.columns:
|
||||
prices = df['close'].tolist()
|
||||
return prices[-limit:] # Return most recent prices
|
||||
else:
|
||||
logger.warning(f"No 'close' column found in OHLCV data for {symbol}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting recent prices for {symbol}: {e}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping real-time processing: {e}")
|
141
test_device_fix.py
Normal file
141
test_device_fix.py
Normal file
@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify device mismatch fixes for GPU training
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the project root to the path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
from core.data_models import BaseDataInput, OHLCVBar
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_device_consistency():
|
||||
"""Test that all tensors are on the same device"""
|
||||
|
||||
logger.info("Testing device consistency for EnhancedCNN...")
|
||||
|
||||
# Check if CUDA is available
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
try:
|
||||
# Initialize the adapter
|
||||
adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
|
||||
|
||||
# Verify adapter device
|
||||
logger.info(f"Adapter device: {adapter.device}")
|
||||
logger.info(f"Model device: {next(adapter.model.parameters()).device}")
|
||||
|
||||
# Create sample data
|
||||
sample_ohlcv = [
|
||||
OHLCVBar(
|
||||
symbol="ETH/USDT",
|
||||
timeframe="1s",
|
||||
timestamp=1640995200.0, # 2022-01-01
|
||||
open=50000.0,
|
||||
high=51000.0,
|
||||
low=49000.0,
|
||||
close=50500.0,
|
||||
volume=1000.0
|
||||
)
|
||||
] * 300 # 300 frames
|
||||
|
||||
base_data = BaseDataInput(
|
||||
symbol="ETH/USDT",
|
||||
timestamp=1640995200.0,
|
||||
ohlcv_1s=sample_ohlcv,
|
||||
ohlcv_1m=sample_ohlcv,
|
||||
ohlcv_5m=sample_ohlcv,
|
||||
ohlcv_15m=sample_ohlcv,
|
||||
btc_ohlcv=sample_ohlcv,
|
||||
cob_data={},
|
||||
ma_data={},
|
||||
technical_indicators={},
|
||||
last_predictions={}
|
||||
)
|
||||
|
||||
# Test prediction
|
||||
logger.info("Testing prediction...")
|
||||
prediction = adapter.predict(base_data)
|
||||
logger.info(f"Prediction successful: {prediction.predictions['action']} (confidence: {prediction.confidence:.3f})")
|
||||
|
||||
# Test training sample addition
|
||||
logger.info("Testing training sample addition...")
|
||||
adapter.add_training_sample(base_data, "BUY", 0.1)
|
||||
adapter.add_training_sample(base_data, "SELL", -0.05)
|
||||
adapter.add_training_sample(base_data, "HOLD", 0.02)
|
||||
|
||||
# Test training
|
||||
logger.info("Testing training...")
|
||||
training_results = adapter.train(epochs=1)
|
||||
logger.info(f"Training results: {training_results}")
|
||||
|
||||
logger.info("✅ All device consistency tests passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Device consistency test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_orchestrator_inference_history():
|
||||
"""Test that orchestrator properly initializes inference history"""
|
||||
|
||||
logger.info("Testing orchestrator inference history initialization...")
|
||||
|
||||
try:
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Initialize orchestrator
|
||||
data_provider = DataProvider()
|
||||
orchestrator = TradingOrchestrator(data_provider=data_provider)
|
||||
|
||||
# Check if inference history is initialized
|
||||
logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
|
||||
|
||||
# Check if models are registered
|
||||
logger.info(f"Registered models: {list(orchestrator.model_registry.models.keys())}")
|
||||
|
||||
# Verify each registered model has inference history
|
||||
for model_name in orchestrator.model_registry.models.keys():
|
||||
if model_name in orchestrator.inference_history:
|
||||
logger.info(f"✅ {model_name} has inference history initialized")
|
||||
else:
|
||||
logger.warning(f"❌ {model_name} missing inference history")
|
||||
|
||||
logger.info("✅ Orchestrator inference history test completed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Orchestrator test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting device fix verification tests...")
|
||||
|
||||
# Test 1: Device consistency
|
||||
test1_passed = test_device_consistency()
|
||||
|
||||
# Test 2: Orchestrator inference history
|
||||
test2_passed = test_orchestrator_inference_history()
|
||||
|
||||
# Summary
|
||||
if test1_passed and test2_passed:
|
||||
logger.info("🎉 All tests passed! Device issues should be fixed.")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("❌ Some tests failed. Please check the logs above.")
|
||||
sys.exit(1)
|
153
test_device_training_fix.py
Normal file
153
test_device_training_fix.py
Normal file
@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify device handling and training sample population fixes
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import torch
|
||||
from datetime import datetime
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_device_handling():
|
||||
"""Test that device handling is working correctly"""
|
||||
try:
|
||||
logger.info("Testing device handling...")
|
||||
|
||||
# Test 1: Check CUDA availability
|
||||
cuda_available = torch.cuda.is_available()
|
||||
device = torch.device("cuda" if cuda_available else "cpu")
|
||||
logger.info(f"CUDA available: {cuda_available}")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Test 2: Initialize CNN adapter
|
||||
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
|
||||
logger.info("Initializing CNN adapter...")
|
||||
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
|
||||
|
||||
logger.info(f"CNN adapter device: {cnn_adapter.device}")
|
||||
logger.info(f"CNN model device: {cnn_adapter.model.device}")
|
||||
|
||||
# Test 3: Create test data
|
||||
from core.data_models import BaseDataInput
|
||||
|
||||
logger.info("Creating test BaseDataInput...")
|
||||
base_data = BaseDataInput(
|
||||
symbol="ETH/USDT",
|
||||
timestamp=datetime.now(),
|
||||
ohlcv_1s=[],
|
||||
ohlcv_1m=[],
|
||||
ohlcv_1h=[],
|
||||
ohlcv_1d=[],
|
||||
btc_ohlcv_1s=[],
|
||||
cob_data=None,
|
||||
technical_indicators={},
|
||||
last_predictions={}
|
||||
)
|
||||
|
||||
# Test 4: Make prediction (this should not cause device mismatch)
|
||||
logger.info("Making prediction...")
|
||||
prediction = cnn_adapter.predict(base_data)
|
||||
|
||||
logger.info(f"Prediction successful: {prediction.predictions['action']}")
|
||||
logger.info(f"Confidence: {prediction.confidence:.4f}")
|
||||
|
||||
# Test 5: Add training samples
|
||||
logger.info("Adding training samples...")
|
||||
cnn_adapter.add_training_sample(base_data, "BUY", 0.1)
|
||||
cnn_adapter.add_training_sample(base_data, "SELL", -0.05)
|
||||
cnn_adapter.add_training_sample(base_data, "HOLD", 0.02)
|
||||
|
||||
logger.info(f"Training samples added: {len(cnn_adapter.training_data)}")
|
||||
|
||||
# Test 6: Try training if we have enough samples
|
||||
if len(cnn_adapter.training_data) >= 2:
|
||||
logger.info("Attempting training...")
|
||||
training_results = cnn_adapter.train(epochs=1)
|
||||
logger.info(f"Training results: {training_results}")
|
||||
else:
|
||||
logger.info("Not enough samples for training")
|
||||
|
||||
logger.info("✅ Device handling test passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Device handling test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def test_orchestrator_training():
|
||||
"""Test that orchestrator properly adds training samples"""
|
||||
try:
|
||||
logger.info("Testing orchestrator training integration...")
|
||||
|
||||
# Test 1: Initialize orchestrator
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.standardized_data_provider import StandardizedDataProvider
|
||||
|
||||
logger.info("Initializing data provider...")
|
||||
data_provider = StandardizedDataProvider()
|
||||
|
||||
logger.info("Initializing orchestrator...")
|
||||
orchestrator = TradingOrchestrator(data_provider=data_provider)
|
||||
|
||||
# Test 2: Check if CNN adapter is available
|
||||
if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter:
|
||||
logger.info(f"✅ CNN adapter available in orchestrator")
|
||||
logger.info(f"Initial training samples: {len(orchestrator.cnn_adapter.training_data)}")
|
||||
else:
|
||||
logger.warning("⚠️ CNN adapter not available in orchestrator")
|
||||
return False
|
||||
|
||||
# Test 3: Make a trading decision (this should add training samples)
|
||||
logger.info("Making trading decision...")
|
||||
decision = await orchestrator.make_trading_decision("ETH/USDT")
|
||||
|
||||
if decision:
|
||||
logger.info(f"Decision: {decision.action} (confidence: {decision.confidence:.4f})")
|
||||
logger.info(f"Training samples after decision: {len(orchestrator.cnn_adapter.training_data)}")
|
||||
else:
|
||||
logger.warning("No decision made")
|
||||
|
||||
# Test 4: Check inference history
|
||||
logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
|
||||
for model_name, history in orchestrator.inference_history.items():
|
||||
logger.info(f" {model_name}: {len(history)} records")
|
||||
|
||||
logger.info("✅ Orchestrator training test passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Orchestrator training test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def main():
|
||||
"""Run all tests"""
|
||||
logger.info("Starting device and training fix tests...")
|
||||
|
||||
# Test 1: Device handling
|
||||
test1_passed = test_device_handling()
|
||||
|
||||
# Test 2: Orchestrator training
|
||||
test2_passed = await test_orchestrator_training()
|
||||
|
||||
# Summary
|
||||
logger.info("\n" + "="*50)
|
||||
logger.info("TEST SUMMARY:")
|
||||
logger.info(f"Device handling: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
|
||||
logger.info(f"Orchestrator training: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
|
||||
|
||||
if test1_passed and test2_passed:
|
||||
logger.info("🎉 All tests passed! Device and training issues should be fixed.")
|
||||
else:
|
||||
logger.error("❌ Some tests failed. Please check the logs above.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
Reference in New Issue
Block a user