diff --git a/.gitignore b/.gitignore index f1b39b3..9efaa78 100644 --- a/.gitignore +++ b/.gitignore @@ -48,3 +48,4 @@ chrome_user_data/* .env .env +training_data/* diff --git a/core/enhanced_cnn_adapter.py b/core/enhanced_cnn_adapter.py index 5428efb..c394649 100644 --- a/core/enhanced_cnn_adapter.py +++ b/core/enhanced_cnn_adapter.py @@ -70,6 +70,9 @@ class EnhancedCNNAdapter: else: self._load_best_checkpoint() + # Final device check and move + self._ensure_model_on_device() + logger.info(f"EnhancedCNNAdapter initialized on {self.device}") def _initialize_model(self): @@ -88,9 +91,10 @@ class EnhancedCNNAdapter: # Create model self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions) + # Ensure model is moved to the correct device self.model.to(self.device) - logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions}") + logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}") except Exception as e: logger.error(f"Error initializing EnhancedCNN model: {e}") @@ -102,7 +106,9 @@ class EnhancedCNNAdapter: if self.model and os.path.exists(checkpoint_path): success = self.model.load(checkpoint_path) if success: - logger.info(f"Loaded model from {checkpoint_path}") + # Ensure model is moved to the correct device after loading + self.model.to(self.device) + logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}") return True else: logger.warning(f"Failed to load model from {checkpoint_path}") @@ -146,7 +152,9 @@ class EnhancedCNNAdapter: success = self.model.load(best_checkpoint_path) if success: - logger.info(f"Loaded best checkpoint from {best_checkpoint_path}") + # Ensure model is moved to the correct device after loading + self.model.to(self.device) + logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}") # Log metrics metrics = best_checkpoint_metadata.get('metrics', {}) @@ -161,7 +169,17 @@ class EnhancedCNNAdapter: logger.error(f"Error loading best checkpoint: {e}") return False - + def _ensure_model_on_device(self): + """Ensure model and all its components are on the correct device""" + try: + if self.model: + self.model.to(self.device) + # Also ensure the model's internal device is set correctly + if hasattr(self.model, 'device'): + self.model.device = self.device + logger.debug(f"Model ensured on device {self.device}") + except Exception as e: + logger.error(f"Error ensuring model on device: {e}") def _create_default_output(self, symbol: str) -> ModelOutput: """Create default output when prediction fails""" @@ -235,6 +253,9 @@ class EnhancedCNNAdapter: if features.dim() == 1: features = features.unsqueeze(0) + # Ensure model is on correct device before prediction + self._ensure_model_on_device() + # Set model to evaluation mode self.model.eval() @@ -399,6 +420,9 @@ class EnhancedCNNAdapter: logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}") return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)} + # Ensure model is on correct device before training + self._ensure_model_on_device() + # Set model to training mode self.model.train() @@ -423,8 +447,8 @@ class EnhancedCNNAdapter: if len(batch) < 2: continue - # Prepare batch - features = torch.stack([sample[0] for sample in batch]) + # Prepare batch - ensure all tensors are on the correct device + features = torch.stack([sample[0].to(self.device) for sample in batch]) actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device) rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device) diff --git a/core/orchestrator.py b/core/orchestrator.py index d664319..fa3bc50 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -299,12 +299,12 @@ class TradingOrchestrator: logger.warning("DQN Agent not available") self.rl_agent = None - # Initialize CNN Model + # Initialize CNN Model with Adapter try: - from NN.models.standardized_cnn import StandardizedCNN + from core.enhanced_cnn_adapter import EnhancedCNNAdapter - self.cnn_model = StandardizedCNN() - self.cnn_model.to(self.device) # Move CNN model to the determined device + self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn") + self.cnn_model = self.cnn_adapter.model # Keep reference for compatibility self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN # Load best checkpoint and capture initial state @@ -332,11 +332,12 @@ class TradingOrchestrator: self.model_states['cnn']['best_loss'] = None logger.info("CNN starting fresh - no checkpoint found") - logger.info("Enhanced CNN model initialized") + logger.info("Enhanced CNN adapter initialized") except ImportError: try: from NN.models.standardized_cnn import StandardizedCNN self.cnn_model = StandardizedCNN() + self.cnn_adapter = None # No adapter available self.cnn_model.to(self.device) # Move basic CNN model to the determined device self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN @@ -359,6 +360,7 @@ class TradingOrchestrator: except ImportError: logger.warning("CNN model not available") self.cnn_model = None + self.cnn_adapter = None self.cnn_optimizer = None # Ensure optimizer is also None if model is not available # Initialize Extrema Trainer @@ -930,6 +932,11 @@ class TradingOrchestrator: if model.name not in self.model_performance: self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0} + # Initialize inference history for this model + if model.name not in self.inference_history: + self.inference_history[model.name] = deque(maxlen=self.max_memory_inferences) + logger.debug(f"Initialized inference history for {model.name}") + logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}") self._normalize_weights() return True @@ -1024,6 +1031,9 @@ class TradingOrchestrator: except Exception as e: logger.error(f"Error in decision callback: {e}") + # Add training samples based on current market conditions + await self._add_training_samples_from_predictions(symbol, predictions, current_price) + # Clean up memory periodically if len(self.recent_decisions[symbol]) % 200 == 0: # Reduced from 50 to 200 self.model_registry.cleanup_all_models() @@ -1034,6 +1044,47 @@ class TradingOrchestrator: logger.error(f"Error making trading decision for {symbol}: {e}") return None + async def _add_training_samples_from_predictions(self, symbol: str, predictions: List[Prediction], current_price: float): + """Add training samples to models based on current predictions and market conditions""" + try: + if not hasattr(self, 'cnn_adapter') or not self.cnn_adapter: + return + + # Get recent price data to evaluate if predictions would be correct + recent_prices = self.data_provider.get_recent_prices(symbol, limit=10) + if not recent_prices or len(recent_prices) < 2: + return + + # Calculate recent price change + price_change_pct = (current_price - recent_prices[-2]) / recent_prices[-2] * 100 + + # Add training samples for CNN predictions + for prediction in predictions: + if 'cnn' in prediction.model_name.lower(): + # Determine reward based on prediction accuracy + reward = 0.0 + + if prediction.action == 'BUY' and price_change_pct > 0.1: + reward = min(price_change_pct * 0.1, 1.0) # Positive reward for correct BUY + elif prediction.action == 'SELL' and price_change_pct < -0.1: + reward = min(abs(price_change_pct) * 0.1, 1.0) # Positive reward for correct SELL + elif prediction.action == 'HOLD' and abs(price_change_pct) < 0.1: + reward = 0.1 # Small positive reward for correct HOLD + else: + reward = -0.05 # Small negative reward for incorrect prediction + + # Add training sample + self.cnn_adapter.add_training_sample(symbol, prediction.action, reward) + logger.debug(f"Added CNN training sample: {prediction.action}, reward={reward:.3f}, price_change={price_change_pct:.2f}%") + + # Trigger training if we have enough samples + if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size: + training_results = self.cnn_adapter.train(epochs=1) + logger.info(f"CNN training completed: loss={training_results.get('loss', 0):.4f}, accuracy={training_results.get('accuracy', 0):.4f}") + + except Exception as e: + logger.error(f"Error adding training samples from predictions: {e}") + async def _get_all_predictions(self, symbol: str) -> List[Prediction]: """Get predictions from all registered models with input data storage""" predictions = [] @@ -1051,8 +1102,12 @@ class TradingOrchestrator: # Get CNN predictions for each timeframe cnn_predictions = await self._get_cnn_predictions(model, symbol) predictions.extend(cnn_predictions) - # Store input data for CNN + # Store input data for CNN - store for each prediction model_input = input_data.get('cnn_input') + if model_input is not None and cnn_predictions: + # Store inference data for each CNN prediction + for cnn_pred in cnn_predictions: + await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time) elif isinstance(model, RLAgentInterface): # Get RL prediction @@ -1062,6 +1117,8 @@ class TradingOrchestrator: prediction = rl_prediction # Store input data for RL model_input = input_data.get('rl_input') + if model_input is not None: + await self._store_inference_data_async(model_name, model_input, prediction, current_time) else: # Generic model interface @@ -1071,15 +1128,20 @@ class TradingOrchestrator: prediction = generic_prediction # Store input data for generic model model_input = input_data.get('generic_input') - - # Store inference data for training (per-model, async) - if prediction and model_input is not None: - await self._store_inference_data_async(model_name, model_input, prediction, current_time) + if model_input is not None: + await self._store_inference_data_async(model_name, model_input, prediction, current_time) except Exception as e: logger.error(f"Error getting prediction from {model_name}: {e}") continue + # Debug: Log inference history status (only if low record count) + total_records = sum(len(history) for history in self.inference_history.values()) + if total_records < 10: # Only log when we have few records + logger.debug(f"Total inference records across all models: {total_records}") + for model_name, history in self.inference_history.items(): + logger.debug(f" {model_name}: {len(history)} records") + # Trigger training based on previous inference data await self._trigger_model_training(symbol) @@ -1130,7 +1192,15 @@ class TradingOrchestrator: } } - return standardized_input + # Create model-specific input data + model_inputs = { + 'cnn_input': standardized_input, + 'rl_input': standardized_input, + 'generic_input': standardized_input, + 'standardized_input': standardized_input + } + + return model_inputs except Exception as e: logger.error(f"Error collecting standardized model input data: {e}") @@ -1139,6 +1209,9 @@ class TradingOrchestrator: async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime): """Store inference data per-model with async file operations and memory optimization""" try: + # Only log first few inference records to avoid spam + if len(self.inference_history.get(model_name, [])) < 3: + logger.debug(f"Storing inference data for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})") # Create comprehensive inference record inference_record = { 'timestamp': timestamp.isoformat(), @@ -1214,8 +1287,8 @@ class TradingOrchestrator: except Exception as e: logger.error(f"Error capping model files in {model_dir}: {e}") - def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray: - """Prepare standardized input data for CNN models""" + def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor: + """Prepare standardized input data for CNN models with proper GPU device placement""" try: # Create feature matrix from OHLCV data features = [] @@ -1242,16 +1315,18 @@ class TradingOrchestrator: feature_array = np.pad(feature_array, (0, 300 - len(feature_array)), 'constant') else: feature_array = feature_array[:300] - return feature_array.reshape(1, -1) + # Convert to tensor and move to GPU + return torch.tensor(feature_array.reshape(1, -1), dtype=torch.float32, device=self.device) else: - return np.zeros((1, 300)) + # Return zero tensor on GPU + return torch.zeros((1, 300), dtype=torch.float32, device=self.device) except Exception as e: logger.error(f"Error preparing CNN input data: {e}") - return np.zeros((1, 300)) + return torch.zeros((1, 300), dtype=torch.float32, device=self.device) - def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray: - """Prepare standardized input data for RL models""" + def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor: + """Prepare standardized input data for RL models with proper GPU device placement""" try: # Create state representation state_features = [] @@ -1279,13 +1354,15 @@ class TradingOrchestrator: state_array = np.pad(state_array, (0, expected_size - len(state_array)), 'constant') else: state_array = state_array[:expected_size] - return state_array + # Convert to tensor and move to GPU + return torch.tensor(state_array, dtype=torch.float32, device=self.device) else: - return np.zeros(100) + # Return zero tensor on GPU + return torch.zeros(100, dtype=torch.float32, device=self.device) except Exception as e: logger.error(f"Error preparing RL input data: {e}") - return np.zeros(100) + return torch.zeros(100, dtype=torch.float32, device=self.device) def _store_inference_data(self, symbol: str, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime): """Store comprehensive inference data for future training with persistent storage""" @@ -1336,10 +1413,12 @@ class TradingOrchestrator: 'outcome_evaluated': False } - # Store in memory (inference history) - if symbol in self.inference_history: - self.inference_history[symbol].append(inference_record) - logger.debug(f"Stored inference data for {model_name} on {symbol}") + # Store in memory (inference history) - keyed by model_name + if model_name not in self.inference_history: + self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences) + + self.inference_history[model_name].append(inference_record) + logger.debug(f"Stored inference data for {model_name} on {symbol}") # Persistent storage to disk (for long-term training data) self._save_inference_to_disk(inference_record) @@ -1512,6 +1591,12 @@ class TradingOrchestrator: for model_name, model_records in self.inference_history.items(): all_recent_records.extend(list(model_records)) + # Only log if we have few records (for debugging) + if len(all_recent_records) < 5: + logger.debug(f"Total inference records for training: {len(all_recent_records)}") + for model_name, model_records in self.inference_history.items(): + logger.debug(f" Model {model_name} has {len(model_records)} inference records") + if len(all_recent_records) < 2: logger.debug("Not enough inference records for training") return # Need at least 2 records to compare @@ -1521,12 +1606,11 @@ class TradingOrchestrator: if current_price is None: return - # Process records that are old enough to evaluate outcomes - cutoff_time = datetime.now() - timedelta(minutes=5) # 5 minutes ago - - for record in recent_records: - if record['timestamp'] < cutoff_time: - await self._evaluate_and_train_on_record(record, current_price) + # Train on the most recent inference record (last prediction made) + if all_recent_records: + # Get the most recent record for training + most_recent_record = max(all_recent_records, key=lambda x: datetime.fromisoformat(x['timestamp']) if isinstance(x['timestamp'], str) else x['timestamp']) + await self._evaluate_and_train_on_record(most_recent_record, current_price) except Exception as e: logger.error(f"Error triggering model training for {symbol}: {e}") @@ -1538,6 +1622,10 @@ class TradingOrchestrator: prediction = record['prediction'] timestamp = record['timestamp'] + # Convert timestamp string back to datetime if needed + if isinstance(timestamp, str): + timestamp = datetime.fromisoformat(timestamp) + # Calculate price change since prediction # This is a simplified outcome evaluation - you might want to make it more sophisticated time_diff = (datetime.now() - timestamp).total_seconds() / 60 # minutes @@ -1608,12 +1696,23 @@ class TradingOrchestrator: ) logger.debug(f"Added RL training experience: reward={reward}") - # Train CNN models - elif 'cnn' in model_name.lower() and self.cnn_model: - if hasattr(self.cnn_model, 'train_on_outcome'): - target = 1 if was_correct else 0 - self.cnn_model.train_on_outcome(model_input, target) - logger.debug(f"Trained CNN on outcome: target={target}") + # Train CNN models using adapter + elif 'cnn' in model_name.lower() and hasattr(self, 'cnn_adapter') and self.cnn_adapter: + # Use the adapter's add_training_sample method + actual_action = prediction['action'] + self.cnn_adapter.add_training_sample(record['symbol'], actual_action, reward) + logger.debug(f"Added CNN training sample: action={actual_action}, reward={reward}") + + # Trigger training if we have enough samples + if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size: + training_results = self.cnn_adapter.train(epochs=1) + logger.debug(f"CNN training results: {training_results}") + + # Fallback for raw CNN model + elif 'cnn' in model_name.lower() and self.cnn_model and hasattr(self.cnn_model, 'train_on_outcome'): + target = 1 if was_correct else 0 + self.cnn_model.train_on_outcome(model_input, target) + logger.debug(f"Trained CNN on outcome: target={target}") except Exception as e: logger.error(f"Error training model on outcome: {e}") @@ -2260,8 +2359,8 @@ class TradingOrchestrator: return if not ENHANCED_TRAINING_AVAILABLE: - logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled") - self.training_enabled = False + logger.info("EnhancedRealtimeTrainingSystem not available - using built-in training") + # Keep training enabled - we have built-in training capabilities return # Initialize the enhanced training system diff --git a/core/standardized_data_provider.py b/core/standardized_data_provider.py index 510acbf..0ae71ca 100644 --- a/core/standardized_data_provider.py +++ b/core/standardized_data_provider.py @@ -449,5 +449,37 @@ class StandardizedDataProvider(DataProvider): logger.info("Stopped real-time processing for standardized data") + except Exception as e: + logger.error(f"Error stopping real-time processing: {e}") + + def get_recent_prices(self, symbol: str, limit: int = 10) -> List[float]: + """ + Get recent prices for a symbol + + Args: + symbol: Trading symbol + limit: Number of recent prices to return + + Returns: + List[float]: List of recent prices + """ + try: + # Get recent OHLCV data using parent class method + df = self.get_historical_data(symbol, '1m', limit) + if df is None or df.empty: + return [] + + # Extract close prices from DataFrame + if 'close' in df.columns: + prices = df['close'].tolist() + return prices[-limit:] # Return most recent prices + else: + logger.warning(f"No 'close' column found in OHLCV data for {symbol}") + return [] + + except Exception as e: + logger.error(f"Error getting recent prices for {symbol}: {e}") + return [] + except Exception as e: logger.error(f"Error stopping real-time processing: {e}") \ No newline at end of file diff --git a/test_device_fix.py b/test_device_fix.py new file mode 100644 index 0000000..a43a029 --- /dev/null +++ b/test_device_fix.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +""" +Test script to verify device mismatch fixes for GPU training +""" + +import torch +import logging +import sys +import os + +# Add the project root to the path +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from core.enhanced_cnn_adapter import EnhancedCNNAdapter +from core.data_models import BaseDataInput, OHLCVBar + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def test_device_consistency(): + """Test that all tensors are on the same device""" + + logger.info("Testing device consistency for EnhancedCNN...") + + # Check if CUDA is available + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + logger.info(f"Using device: {device}") + + try: + # Initialize the adapter + adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn") + + # Verify adapter device + logger.info(f"Adapter device: {adapter.device}") + logger.info(f"Model device: {next(adapter.model.parameters()).device}") + + # Create sample data + sample_ohlcv = [ + OHLCVBar( + symbol="ETH/USDT", + timeframe="1s", + timestamp=1640995200.0, # 2022-01-01 + open=50000.0, + high=51000.0, + low=49000.0, + close=50500.0, + volume=1000.0 + ) + ] * 300 # 300 frames + + base_data = BaseDataInput( + symbol="ETH/USDT", + timestamp=1640995200.0, + ohlcv_1s=sample_ohlcv, + ohlcv_1m=sample_ohlcv, + ohlcv_5m=sample_ohlcv, + ohlcv_15m=sample_ohlcv, + btc_ohlcv=sample_ohlcv, + cob_data={}, + ma_data={}, + technical_indicators={}, + last_predictions={} + ) + + # Test prediction + logger.info("Testing prediction...") + prediction = adapter.predict(base_data) + logger.info(f"Prediction successful: {prediction.predictions['action']} (confidence: {prediction.confidence:.3f})") + + # Test training sample addition + logger.info("Testing training sample addition...") + adapter.add_training_sample(base_data, "BUY", 0.1) + adapter.add_training_sample(base_data, "SELL", -0.05) + adapter.add_training_sample(base_data, "HOLD", 0.02) + + # Test training + logger.info("Testing training...") + training_results = adapter.train(epochs=1) + logger.info(f"Training results: {training_results}") + + logger.info("✅ All device consistency tests passed!") + return True + + except Exception as e: + logger.error(f"❌ Device consistency test failed: {e}") + import traceback + traceback.print_exc() + return False + +def test_orchestrator_inference_history(): + """Test that orchestrator properly initializes inference history""" + + logger.info("Testing orchestrator inference history initialization...") + + try: + from core.orchestrator import TradingOrchestrator + from core.data_provider import DataProvider + + # Initialize orchestrator + data_provider = DataProvider() + orchestrator = TradingOrchestrator(data_provider=data_provider) + + # Check if inference history is initialized + logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}") + + # Check if models are registered + logger.info(f"Registered models: {list(orchestrator.model_registry.models.keys())}") + + # Verify each registered model has inference history + for model_name in orchestrator.model_registry.models.keys(): + if model_name in orchestrator.inference_history: + logger.info(f"✅ {model_name} has inference history initialized") + else: + logger.warning(f"❌ {model_name} missing inference history") + + logger.info("✅ Orchestrator inference history test completed!") + return True + + except Exception as e: + logger.error(f"❌ Orchestrator test failed: {e}") + import traceback + traceback.print_exc() + return False + +if __name__ == "__main__": + logger.info("Starting device fix verification tests...") + + # Test 1: Device consistency + test1_passed = test_device_consistency() + + # Test 2: Orchestrator inference history + test2_passed = test_orchestrator_inference_history() + + # Summary + if test1_passed and test2_passed: + logger.info("🎉 All tests passed! Device issues should be fixed.") + sys.exit(0) + else: + logger.error("❌ Some tests failed. Please check the logs above.") + sys.exit(1) \ No newline at end of file diff --git a/test_device_training_fix.py b/test_device_training_fix.py new file mode 100644 index 0000000..b883043 --- /dev/null +++ b/test_device_training_fix.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +Test script to verify device handling and training sample population fixes +""" + +import logging +import asyncio +import torch +from datetime import datetime + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +def test_device_handling(): + """Test that device handling is working correctly""" + try: + logger.info("Testing device handling...") + + # Test 1: Check CUDA availability + cuda_available = torch.cuda.is_available() + device = torch.device("cuda" if cuda_available else "cpu") + logger.info(f"CUDA available: {cuda_available}") + logger.info(f"Using device: {device}") + + # Test 2: Initialize CNN adapter + from core.enhanced_cnn_adapter import EnhancedCNNAdapter + + logger.info("Initializing CNN adapter...") + cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn") + + logger.info(f"CNN adapter device: {cnn_adapter.device}") + logger.info(f"CNN model device: {cnn_adapter.model.device}") + + # Test 3: Create test data + from core.data_models import BaseDataInput + + logger.info("Creating test BaseDataInput...") + base_data = BaseDataInput( + symbol="ETH/USDT", + timestamp=datetime.now(), + ohlcv_1s=[], + ohlcv_1m=[], + ohlcv_1h=[], + ohlcv_1d=[], + btc_ohlcv_1s=[], + cob_data=None, + technical_indicators={}, + last_predictions={} + ) + + # Test 4: Make prediction (this should not cause device mismatch) + logger.info("Making prediction...") + prediction = cnn_adapter.predict(base_data) + + logger.info(f"Prediction successful: {prediction.predictions['action']}") + logger.info(f"Confidence: {prediction.confidence:.4f}") + + # Test 5: Add training samples + logger.info("Adding training samples...") + cnn_adapter.add_training_sample(base_data, "BUY", 0.1) + cnn_adapter.add_training_sample(base_data, "SELL", -0.05) + cnn_adapter.add_training_sample(base_data, "HOLD", 0.02) + + logger.info(f"Training samples added: {len(cnn_adapter.training_data)}") + + # Test 6: Try training if we have enough samples + if len(cnn_adapter.training_data) >= 2: + logger.info("Attempting training...") + training_results = cnn_adapter.train(epochs=1) + logger.info(f"Training results: {training_results}") + else: + logger.info("Not enough samples for training") + + logger.info("✅ Device handling test passed!") + return True + + except Exception as e: + logger.error(f"❌ Device handling test failed: {e}") + import traceback + traceback.print_exc() + return False + +async def test_orchestrator_training(): + """Test that orchestrator properly adds training samples""" + try: + logger.info("Testing orchestrator training integration...") + + # Test 1: Initialize orchestrator + from core.orchestrator import TradingOrchestrator + from core.standardized_data_provider import StandardizedDataProvider + + logger.info("Initializing data provider...") + data_provider = StandardizedDataProvider() + + logger.info("Initializing orchestrator...") + orchestrator = TradingOrchestrator(data_provider=data_provider) + + # Test 2: Check if CNN adapter is available + if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter: + logger.info(f"✅ CNN adapter available in orchestrator") + logger.info(f"Initial training samples: {len(orchestrator.cnn_adapter.training_data)}") + else: + logger.warning("⚠️ CNN adapter not available in orchestrator") + return False + + # Test 3: Make a trading decision (this should add training samples) + logger.info("Making trading decision...") + decision = await orchestrator.make_trading_decision("ETH/USDT") + + if decision: + logger.info(f"Decision: {decision.action} (confidence: {decision.confidence:.4f})") + logger.info(f"Training samples after decision: {len(orchestrator.cnn_adapter.training_data)}") + else: + logger.warning("No decision made") + + # Test 4: Check inference history + logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}") + for model_name, history in orchestrator.inference_history.items(): + logger.info(f" {model_name}: {len(history)} records") + + logger.info("✅ Orchestrator training test passed!") + return True + + except Exception as e: + logger.error(f"❌ Orchestrator training test failed: {e}") + import traceback + traceback.print_exc() + return False + +async def main(): + """Run all tests""" + logger.info("Starting device and training fix tests...") + + # Test 1: Device handling + test1_passed = test_device_handling() + + # Test 2: Orchestrator training + test2_passed = await test_orchestrator_training() + + # Summary + logger.info("\n" + "="*50) + logger.info("TEST SUMMARY:") + logger.info(f"Device handling: {'✅ PASSED' if test1_passed else '❌ FAILED'}") + logger.info(f"Orchestrator training: {'✅ PASSED' if test2_passed else '❌ FAILED'}") + + if test1_passed and test2_passed: + logger.info("🎉 All tests passed! Device and training issues should be fixed.") + else: + logger.error("❌ Some tests failed. Please check the logs above.") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file