4 Commits

Author SHA1 Message Date
26eeb9b35b ACTUAL TRAINING WORKING (WIP) 2025-07-25 14:08:25 +03:00
1f60c80d67 device tensor fix 2025-07-25 13:59:33 +03:00
78b4bb0f06 wip, training still disabled 2025-07-24 16:20:37 +03:00
045780758a wip symbols tidy up 2025-07-24 16:08:58 +03:00
9 changed files with 704 additions and 111 deletions

1
.gitignore vendored
View File

@ -48,3 +48,4 @@ chrome_user_data/*
.env .env
.env .env
training_data/*

View File

@ -74,6 +74,16 @@ Based on the existing implementation in `core/data_provider.py`, we'll enhance i
- 1,5,15 and 60s MA of the COB imbalance counting +- 5 COB buckets - 1,5,15 and 60s MA of the COB imbalance counting +- 5 COB buckets
- ***OUTPUTS***: suggested trade action (BUY/SELL) - ***OUTPUTS***: suggested trade action (BUY/SELL)
# Standardized input for all models:
{
'primary_symbol': 'ETH/USDT',
'reference_symbol': 'BTC/USDT',
'eth_data': {'ETH_1s': df, 'ETH_1m': df, 'ETH_1h': df, 'ETH_1d': df},
'btc_data': {'BTC_1s': df},
'current_prices': {'ETH': price, 'BTC': price},
'data_completeness': {...}
}
### 2. CNN Model ### 2. CNN Model
The CNN Model is responsible for analyzing patterns in market data and predicting pivot points across multiple timeframes. The CNN Model is responsible for analyzing patterns in market data and predicting pivot points across multiple timeframes.

View File

@ -197,7 +197,9 @@
- Ensure validation occurs before any model inference - Ensure validation occurs before any model inference
- _Requirements: 9.1, 9.4_ - _Requirements: 9.1, 9.4_
- [ ] 5.2. Implement persistent inference history storage - [x] 5.2. Implement persistent inference history storage
- Create InferenceHistoryStore class for persistent storage - Create InferenceHistoryStore class for persistent storage
- Store complete input data packages with each prediction - Store complete input data packages with each prediction
- Include timestamp, symbol, input features, prediction outputs, confidence scores - Include timestamp, symbol, input features, prediction outputs, confidence scores

View File

@ -70,6 +70,9 @@ class EnhancedCNNAdapter:
else: else:
self._load_best_checkpoint() self._load_best_checkpoint()
# Final device check and move
self._ensure_model_on_device()
logger.info(f"EnhancedCNNAdapter initialized on {self.device}") logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
def _initialize_model(self): def _initialize_model(self):
@ -88,9 +91,10 @@ class EnhancedCNNAdapter:
# Create model # Create model
self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions) self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
# Ensure model is moved to the correct device
self.model.to(self.device) self.model.to(self.device)
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions}") logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}")
except Exception as e: except Exception as e:
logger.error(f"Error initializing EnhancedCNN model: {e}") logger.error(f"Error initializing EnhancedCNN model: {e}")
@ -102,7 +106,9 @@ class EnhancedCNNAdapter:
if self.model and os.path.exists(checkpoint_path): if self.model and os.path.exists(checkpoint_path):
success = self.model.load(checkpoint_path) success = self.model.load(checkpoint_path)
if success: if success:
logger.info(f"Loaded model from {checkpoint_path}") # Ensure model is moved to the correct device after loading
self.model.to(self.device)
logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}")
return True return True
else: else:
logger.warning(f"Failed to load model from {checkpoint_path}") logger.warning(f"Failed to load model from {checkpoint_path}")
@ -146,7 +152,9 @@ class EnhancedCNNAdapter:
success = self.model.load(best_checkpoint_path) success = self.model.load(best_checkpoint_path)
if success: if success:
logger.info(f"Loaded best checkpoint from {best_checkpoint_path}") # Ensure model is moved to the correct device after loading
self.model.to(self.device)
logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}")
# Log metrics # Log metrics
metrics = best_checkpoint_metadata.get('metrics', {}) metrics = best_checkpoint_metadata.get('metrics', {})
@ -161,7 +169,17 @@ class EnhancedCNNAdapter:
logger.error(f"Error loading best checkpoint: {e}") logger.error(f"Error loading best checkpoint: {e}")
return False return False
def _ensure_model_on_device(self):
"""Ensure model and all its components are on the correct device"""
try:
if self.model:
self.model.to(self.device)
# Also ensure the model's internal device is set correctly
if hasattr(self.model, 'device'):
self.model.device = self.device
logger.debug(f"Model ensured on device {self.device}")
except Exception as e:
logger.error(f"Error ensuring model on device: {e}")
def _create_default_output(self, symbol: str) -> ModelOutput: def _create_default_output(self, symbol: str) -> ModelOutput:
"""Create default output when prediction fails""" """Create default output when prediction fails"""
@ -235,6 +253,9 @@ class EnhancedCNNAdapter:
if features.dim() == 1: if features.dim() == 1:
features = features.unsqueeze(0) features = features.unsqueeze(0)
# Ensure model is on correct device before prediction
self._ensure_model_on_device()
# Set model to evaluation mode # Set model to evaluation mode
self.model.eval() self.model.eval()
@ -399,6 +420,9 @@ class EnhancedCNNAdapter:
logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}") logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)} return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
# Ensure model is on correct device before training
self._ensure_model_on_device()
# Set model to training mode # Set model to training mode
self.model.train() self.model.train()
@ -423,8 +447,8 @@ class EnhancedCNNAdapter:
if len(batch) < 2: if len(batch) < 2:
continue continue
# Prepare batch # Prepare batch - ensure all tensors are on the correct device
features = torch.stack([sample[0] for sample in batch]) features = torch.stack([sample[0].to(self.device) for sample in batch])
actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device) actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device) rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)

View File

@ -110,7 +110,9 @@ class TradingOrchestrator:
self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.08) # Lowered from 0.10 self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.08) # Lowered from 0.10
# Decision frequency limit to prevent excessive trading # Decision frequency limit to prevent excessive trading
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30) self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) # Enhanced to support multiple symbols
self.symbol = self.config.get('symbol', "ETH/USDT") # main symbol we wre trading and making predictions on. only one!
self.ref_symbols = self.config.get('ref_symbols', [ 'BTC/USDT']) # Enhanced to support multiple reference symbols. ToDo: we can add 'SOL/USDT' later
# NEW: Aggressiveness parameters # NEW: Aggressiveness parameters
self.entry_aggressiveness = self.config.orchestrator.get('entry_aggressiveness', 0.5) # 0.0 = conservative, 1.0 = very aggressive self.entry_aggressiveness = self.config.orchestrator.get('entry_aggressiveness', 0.5) # 0.0 = conservative, 1.0 = very aggressive
@ -153,12 +155,11 @@ class TradingOrchestrator:
self.recent_cnn_predictions: Dict[str, deque] = {} # {symbol: List[Dict]} - Recent CNN predictions self.recent_cnn_predictions: Dict[str, deque] = {} # {symbol: List[Dict]} - Recent CNN predictions
self.prediction_accuracy_history: Dict[str, deque] = {} # {symbol: List[Dict]} - Prediction accuracy tracking self.prediction_accuracy_history: Dict[str, deque] = {} # {symbol: List[Dict]} - Prediction accuracy tracking
# Initialize prediction tracking for each symbol # Initialize prediction tracking for the primary trading symbol only
for symbol in self.symbols: self.recent_dqn_predictions[self.symbol] = deque(maxlen=100)
self.recent_dqn_predictions[symbol] = deque(maxlen=100) self.recent_cnn_predictions[self.symbol] = deque(maxlen=50)
self.recent_cnn_predictions[symbol] = deque(maxlen=50) self.prediction_accuracy_history[self.symbol] = deque(maxlen=200)
self.prediction_accuracy_history[symbol] = deque(maxlen=200) self.signal_accumulator[self.symbol] = []
self.signal_accumulator[symbol] = []
# Decision callbacks # Decision callbacks
self.decision_callbacks: List[Any] = [] self.decision_callbacks: List[Any] = []
@ -177,7 +178,7 @@ class TradingOrchestrator:
self.latest_cob_data: Dict[str, Any] = {} # {symbol: COBSnapshot} self.latest_cob_data: Dict[str, Any] = {} # {symbol: COBSnapshot}
self.latest_cob_features: Dict[str, Any] = {} # {symbol: np.ndarray} - CNN features self.latest_cob_features: Dict[str, Any] = {} # {symbol: np.ndarray} - CNN features
self.latest_cob_state: Dict[str, Any] = {} # {symbol: np.ndarray} - DQN state features self.latest_cob_state: Dict[str, Any] = {} # {symbol: np.ndarray} - DQN state features
self.cob_feature_history: Dict[str, List[Any]] = {symbol: [] for symbol in self.symbols} # Rolling history for models self.cob_feature_history: Dict[str, List[Any]] = {self.symbol: []} # Rolling history for primary trading symbol
# Enhanced ML Models # Enhanced ML Models
self.rl_agent: Any = None # DQN Agent self.rl_agent: Any = None # DQN Agent
@ -204,13 +205,13 @@ class TradingOrchestrator:
# Training tracking # Training tracking
self.last_trained_symbols: Dict[str, datetime] = {} self.last_trained_symbols: Dict[str, datetime] = {}
# INFERENCE DATA STORAGE - Store model inputs and outputs for training # INFERENCE DATA STORAGE - Per-model storage with memory optimization
self.inference_history: Dict[str, deque] = {} # {symbol: deque of inference records} self.inference_history: Dict[str, deque] = {} # {model_name: deque of last 5 inference records}
self.max_inference_history = 1000 # Keep last 1000 inference records per symbol self.max_memory_inferences = 5 # Keep only last 5 inferences in memory per model
self.max_disk_files_per_model = 200 # Cap disk files per model
# Initialize inference history for each symbol # Initialize inference history for each model (will be populated as models make predictions)
for symbol in self.symbols: # We'll create entries dynamically as models are used
self.inference_history[symbol] = deque(maxlen=self.max_inference_history)
# ENHANCED: Real-time Training System Integration # ENHANCED: Real-time Training System Integration
self.enhanced_training_system = None # Will be set to EnhancedRealtimeTrainingSystem if available self.enhanced_training_system = None # Will be set to EnhancedRealtimeTrainingSystem if available
@ -223,7 +224,7 @@ class TradingOrchestrator:
logger.info(f"Training enabled: {self.training_enabled}") logger.info(f"Training enabled: {self.training_enabled}")
logger.info(f"Confidence threshold: {self.confidence_threshold}") logger.info(f"Confidence threshold: {self.confidence_threshold}")
# logger.info(f"Decision frequency: {self.decision_frequency}s") # logger.info(f"Decision frequency: {self.decision_frequency}s")
logger.info(f"Symbols: {self.symbols}") logger.info(f"Primary symbol: {self.symbol}, Reference symbols: {self.ref_symbols}")
logger.info("Universal Data Adapter integrated for centralized data flow") logger.info("Universal Data Adapter integrated for centralized data flow")
# Start centralized data collection for all models and dashboard # Start centralized data collection for all models and dashboard
@ -298,12 +299,12 @@ class TradingOrchestrator:
logger.warning("DQN Agent not available") logger.warning("DQN Agent not available")
self.rl_agent = None self.rl_agent = None
# Initialize CNN Model # Initialize CNN Model with Adapter
try: try:
from NN.models.standardized_cnn import StandardizedCNN from core.enhanced_cnn_adapter import EnhancedCNNAdapter
self.cnn_model = StandardizedCNN() self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
self.cnn_model.to(self.device) # Move CNN model to the determined device self.cnn_model = self.cnn_adapter.model # Keep reference for compatibility
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
# Load best checkpoint and capture initial state # Load best checkpoint and capture initial state
@ -331,11 +332,12 @@ class TradingOrchestrator:
self.model_states['cnn']['best_loss'] = None self.model_states['cnn']['best_loss'] = None
logger.info("CNN starting fresh - no checkpoint found") logger.info("CNN starting fresh - no checkpoint found")
logger.info("Enhanced CNN model initialized") logger.info("Enhanced CNN adapter initialized")
except ImportError: except ImportError:
try: try:
from NN.models.standardized_cnn import StandardizedCNN from NN.models.standardized_cnn import StandardizedCNN
self.cnn_model = StandardizedCNN() self.cnn_model = StandardizedCNN()
self.cnn_adapter = None # No adapter available
self.cnn_model.to(self.device) # Move basic CNN model to the determined device self.cnn_model.to(self.device) # Move basic CNN model to the determined device
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN
@ -358,6 +360,7 @@ class TradingOrchestrator:
except ImportError: except ImportError:
logger.warning("CNN model not available") logger.warning("CNN model not available")
self.cnn_model = None self.cnn_model = None
self.cnn_adapter = None
self.cnn_optimizer = None # Ensure optimizer is also None if model is not available self.cnn_optimizer = None # Ensure optimizer is also None if model is not available
# Initialize Extrema Trainer # Initialize Extrema Trainer
@ -365,7 +368,7 @@ class TradingOrchestrator:
from core.extrema_trainer import ExtremaTrainer from core.extrema_trainer import ExtremaTrainer
self.extrema_trainer = ExtremaTrainer( self.extrema_trainer = ExtremaTrainer(
data_provider=self.data_provider, data_provider=self.data_provider,
symbols=self.symbols symbols=[self.symbol] # Only primary trading symbol
) )
# Load checkpoint and capture initial state # Load checkpoint and capture initial state
@ -617,7 +620,7 @@ class TradingOrchestrator:
async def start_continuous_trading(self, symbols: Optional[List[str]] = None): async def start_continuous_trading(self, symbols: Optional[List[str]] = None):
"""Start the continuous trading loop, using a decision model and trading executor""" """Start the continuous trading loop, using a decision model and trading executor"""
if symbols is None: if symbols is None:
symbols = self.symbols symbols = [self.symbol] # Only trade the primary symbol
if not self.realtime_processing_task: if not self.realtime_processing_task:
self.realtime_processing_task = asyncio.create_task(self._trading_decision_loop()) self.realtime_processing_task = asyncio.create_task(self._trading_decision_loop())
@ -638,9 +641,9 @@ class TradingOrchestrator:
logger.info("Trading decision loop started") logger.info("Trading decision loop started")
while self.running: while self.running:
try: try:
for symbol in self.symbols: # Only make decisions for the primary trading symbol
await self.make_trading_decision(symbol) await self.make_trading_decision(self.symbol)
await asyncio.sleep(1) # Small delay between symbols await asyncio.sleep(1)
await asyncio.sleep(self.decision_frequency) await asyncio.sleep(self.decision_frequency)
except Exception as e: except Exception as e:
@ -767,7 +770,7 @@ class TradingOrchestrator:
if COB_INTEGRATION_AVAILABLE and COBIntegration is not None: if COB_INTEGRATION_AVAILABLE and COBIntegration is not None:
try: try:
self.cob_integration = COBIntegration( self.cob_integration = COBIntegration(
symbols=self.symbols, symbols=[self.symbol] + self.ref_symbols, # Primary + reference symbols
data_provider=self.data_provider data_provider=self.data_provider
) )
logger.info("COB Integration initialized") logger.info("COB Integration initialized")
@ -929,6 +932,11 @@ class TradingOrchestrator:
if model.name not in self.model_performance: if model.name not in self.model_performance:
self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0} self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
# Initialize inference history for this model
if model.name not in self.inference_history:
self.inference_history[model.name] = deque(maxlen=self.max_memory_inferences)
logger.debug(f"Initialized inference history for {model.name}")
logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}") logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
self._normalize_weights() self._normalize_weights()
return True return True
@ -1023,6 +1031,9 @@ class TradingOrchestrator:
except Exception as e: except Exception as e:
logger.error(f"Error in decision callback: {e}") logger.error(f"Error in decision callback: {e}")
# Add training samples based on current market conditions
await self._add_training_samples_from_predictions(symbol, predictions, current_price)
# Clean up memory periodically # Clean up memory periodically
if len(self.recent_decisions[symbol]) % 200 == 0: # Reduced from 50 to 200 if len(self.recent_decisions[symbol]) % 200 == 0: # Reduced from 50 to 200
self.model_registry.cleanup_all_models() self.model_registry.cleanup_all_models()
@ -1033,6 +1044,47 @@ class TradingOrchestrator:
logger.error(f"Error making trading decision for {symbol}: {e}") logger.error(f"Error making trading decision for {symbol}: {e}")
return None return None
async def _add_training_samples_from_predictions(self, symbol: str, predictions: List[Prediction], current_price: float):
"""Add training samples to models based on current predictions and market conditions"""
try:
if not hasattr(self, 'cnn_adapter') or not self.cnn_adapter:
return
# Get recent price data to evaluate if predictions would be correct
recent_prices = self.data_provider.get_recent_prices(symbol, limit=10)
if not recent_prices or len(recent_prices) < 2:
return
# Calculate recent price change
price_change_pct = (current_price - recent_prices[-2]) / recent_prices[-2] * 100
# Add training samples for CNN predictions
for prediction in predictions:
if 'cnn' in prediction.model_name.lower():
# Determine reward based on prediction accuracy
reward = 0.0
if prediction.action == 'BUY' and price_change_pct > 0.1:
reward = min(price_change_pct * 0.1, 1.0) # Positive reward for correct BUY
elif prediction.action == 'SELL' and price_change_pct < -0.1:
reward = min(abs(price_change_pct) * 0.1, 1.0) # Positive reward for correct SELL
elif prediction.action == 'HOLD' and abs(price_change_pct) < 0.1:
reward = 0.1 # Small positive reward for correct HOLD
else:
reward = -0.05 # Small negative reward for incorrect prediction
# Add training sample
self.cnn_adapter.add_training_sample(symbol, prediction.action, reward)
logger.debug(f"Added CNN training sample: {prediction.action}, reward={reward:.3f}, price_change={price_change_pct:.2f}%")
# Trigger training if we have enough samples
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
training_results = self.cnn_adapter.train(epochs=1)
logger.info(f"CNN training completed: loss={training_results.get('loss', 0):.4f}, accuracy={training_results.get('accuracy', 0):.4f}")
except Exception as e:
logger.error(f"Error adding training samples from predictions: {e}")
async def _get_all_predictions(self, symbol: str) -> List[Prediction]: async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
"""Get predictions from all registered models with input data storage""" """Get predictions from all registered models with input data storage"""
predictions = [] predictions = []
@ -1050,8 +1102,12 @@ class TradingOrchestrator:
# Get CNN predictions for each timeframe # Get CNN predictions for each timeframe
cnn_predictions = await self._get_cnn_predictions(model, symbol) cnn_predictions = await self._get_cnn_predictions(model, symbol)
predictions.extend(cnn_predictions) predictions.extend(cnn_predictions)
# Store input data for CNN # Store input data for CNN - store for each prediction
model_input = input_data.get('cnn_input') model_input = input_data.get('cnn_input')
if model_input is not None and cnn_predictions:
# Store inference data for each CNN prediction
for cnn_pred in cnn_predictions:
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time, symbol)
elif isinstance(model, RLAgentInterface): elif isinstance(model, RLAgentInterface):
# Get RL prediction # Get RL prediction
@ -1061,6 +1117,8 @@ class TradingOrchestrator:
prediction = rl_prediction prediction = rl_prediction
# Store input data for RL # Store input data for RL
model_input = input_data.get('rl_input') model_input = input_data.get('rl_input')
if model_input is not None:
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
else: else:
# Generic model interface # Generic model interface
@ -1070,78 +1128,173 @@ class TradingOrchestrator:
prediction = generic_prediction prediction = generic_prediction
# Store input data for generic model # Store input data for generic model
model_input = input_data.get('generic_input') model_input = input_data.get('generic_input')
if model_input is not None:
# Store inference data for training await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
if prediction and model_input is not None:
self._store_inference_data(symbol, model_name, model_input, prediction, current_time)
except Exception as e: except Exception as e:
logger.error(f"Error getting prediction from {model_name}: {e}") logger.error(f"Error getting prediction from {model_name}: {e}")
continue continue
# Debug: Log inference history status (only if low record count)
total_records = sum(len(history) for history in self.inference_history.values())
if total_records < 10: # Only log when we have few records
logger.debug(f"Total inference records across all models: {total_records}")
for model_name, history in self.inference_history.items():
logger.debug(f" {model_name}: {len(history)} records")
# Trigger training based on previous inference data # Trigger training based on previous inference data
await self._trigger_model_training(symbol) await self._trigger_model_training(symbol)
return predictions return predictions
async def _collect_model_input_data(self, symbol: str) -> Dict[str, Any]: async def _collect_model_input_data(self, symbol: str) -> Dict[str, Any]:
"""Collect comprehensive input data for all models""" """Collect standardized input data for all models - ETH primary + BTC reference"""
try: try:
input_data = {} # Only collect data for ETH (primary symbol) - we inference only for ETH
if symbol != 'ETH/USDT':
return {}
# Get current market data from data provider # Standardized input: 4 ETH timeframes + 1s BTC reference
current_price = self.data_provider.get_current_price(symbol) eth_data = {}
eth_timeframes = ['1s', '1m', '1h', '1d']
# Collect OHLCV data for multiple timeframes # Collect ETH data for all timeframes
ohlcv_data = {} for tf in eth_timeframes:
timeframes = ['1s', '1m', '1h', '1d'] df = self.data_provider.get_historical_data('ETH/USDT', tf, limit=300)
for tf in timeframes:
df = self.data_provider.get_historical_data(symbol, tf, limit=300)
if df is not None and not df.empty: if df is not None and not df.empty:
ohlcv_data[tf] = df eth_data[f'ETH_{tf}'] = df
# Collect COB data if available # Collect BTC 1s reference data
cob_data = self.get_cob_snapshot(symbol) btc_1s = self.data_provider.get_historical_data('BTC/USDT', '1s', limit=300)
btc_data = {}
if btc_1s is not None and not btc_1s.empty:
btc_data['BTC_1s'] = btc_1s
# Collect technical indicators # Get current prices
technical_indicators = {} eth_price = self.data_provider.get_current_price('ETH/USDT')
if '1h' in ohlcv_data: btc_price = self.data_provider.get_current_price('BTC/USDT')
df = ohlcv_data['1h']
if len(df) > 20:
technical_indicators['sma_20'] = df['close'].rolling(20).mean().iloc[-1]
technical_indicators['rsi'] = self._calculate_rsi(df['close'])
# Prepare CNN input # Create standardized input package
cnn_input = self._prepare_cnn_input_data(ohlcv_data, cob_data, technical_indicators) standardized_input = {
# Prepare RL input
rl_input = self._prepare_rl_input_data(ohlcv_data, cob_data, technical_indicators)
# Prepare generic input
generic_input = {
'symbol': symbol,
'current_price': current_price,
'ohlcv_data': ohlcv_data,
'cob_data': cob_data,
'technical_indicators': technical_indicators
}
input_data = {
'cnn_input': cnn_input,
'rl_input': rl_input,
'generic_input': generic_input,
'timestamp': datetime.now(), 'timestamp': datetime.now(),
'symbol': symbol 'primary_symbol': 'ETH/USDT',
'reference_symbol': 'BTC/USDT',
'eth_data': eth_data,
'btc_data': btc_data,
'current_prices': {
'ETH': eth_price,
'BTC': btc_price
},
'data_completeness': {
'eth_timeframes': len(eth_data),
'btc_reference': len(btc_data),
'total_expected': 5 # 4 ETH + 1 BTC
}
} }
return input_data # Create model-specific input data
model_inputs = {
'cnn_input': standardized_input,
'rl_input': standardized_input,
'generic_input': standardized_input,
'standardized_input': standardized_input
}
return model_inputs
except Exception as e: except Exception as e:
logger.error(f"Error collecting model input data for {symbol}: {e}") logger.error(f"Error collecting standardized model input data: {e}")
return {} return {}
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray: async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime, symbol: str = None):
"""Prepare standardized input data for CNN models""" """Store inference data per-model with async file operations and memory optimization"""
try:
# Only log first few inference records to avoid spam
if len(self.inference_history.get(model_name, [])) < 3:
logger.debug(f"Storing inference data for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
# Extract symbol from prediction if not provided
if symbol is None:
symbol = getattr(prediction, 'symbol', 'ETH/USDT') # Default to ETH/USDT if not available
# Create comprehensive inference record
inference_record = {
'timestamp': timestamp.isoformat(),
'symbol': symbol,
'model_name': model_name,
'model_input': model_input,
'prediction': {
'action': prediction.action,
'confidence': prediction.confidence,
'probabilities': prediction.probabilities,
'timeframe': prediction.timeframe
},
'metadata': prediction.metadata or {}
}
# Store in memory (only last 5 per model)
if model_name not in self.inference_history:
self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences)
self.inference_history[model_name].append(inference_record)
# Async file storage (don't wait for completion)
asyncio.create_task(self._save_inference_to_disk_async(model_name, inference_record))
logger.debug(f"Stored inference data for {model_name} (memory: {len(self.inference_history[model_name])}/5)")
except Exception as e:
logger.error(f"Error storing inference data for {model_name}: {e}")
async def _save_inference_to_disk_async(self, model_name: str, inference_record: Dict):
"""Async save inference record to disk with file capping"""
try:
# Create model-specific directory
model_dir = Path(f"training_data/inference_history/{model_name}")
model_dir.mkdir(parents=True, exist_ok=True)
# Create filename with timestamp
timestamp_str = datetime.fromisoformat(inference_record['timestamp']).strftime('%Y%m%d_%H%M%S_%f')[:-3]
filename = f"inference_{timestamp_str}.json"
filepath = model_dir / filename
# Convert to JSON-serializable format
serializable_record = self._make_json_serializable(inference_record)
# Save to file
with open(filepath, 'w') as f:
json.dump(serializable_record, f, indent=2)
# Cap files per model (keep only latest 200)
await self._cap_model_files(model_dir)
logger.debug(f"Saved inference record to disk: {filepath}")
except Exception as e:
logger.error(f"Error saving inference to disk for {model_name}: {e}")
async def _cap_model_files(self, model_dir: Path):
"""Cap the number of files per model to max_disk_files_per_model"""
try:
# Get all inference files
files = list(model_dir.glob("inference_*.json"))
if len(files) > self.max_disk_files_per_model:
# Sort by modification time (oldest first)
files.sort(key=lambda x: x.stat().st_mtime)
# Remove oldest files
files_to_remove = files[:-self.max_disk_files_per_model]
for file_path in files_to_remove:
file_path.unlink()
logger.debug(f"Removed {len(files_to_remove)} old inference files from {model_dir.name}")
except Exception as e:
logger.error(f"Error capping model files in {model_dir}: {e}")
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
"""Prepare standardized input data for CNN models with proper GPU device placement"""
try: try:
# Create feature matrix from OHLCV data # Create feature matrix from OHLCV data
features = [] features = []
@ -1168,16 +1321,18 @@ class TradingOrchestrator:
feature_array = np.pad(feature_array, (0, 300 - len(feature_array)), 'constant') feature_array = np.pad(feature_array, (0, 300 - len(feature_array)), 'constant')
else: else:
feature_array = feature_array[:300] feature_array = feature_array[:300]
return feature_array.reshape(1, -1) # Convert to tensor and move to GPU
return torch.tensor(feature_array.reshape(1, -1), dtype=torch.float32, device=self.device)
else: else:
return np.zeros((1, 300)) # Return zero tensor on GPU
return torch.zeros((1, 300), dtype=torch.float32, device=self.device)
except Exception as e: except Exception as e:
logger.error(f"Error preparing CNN input data: {e}") logger.error(f"Error preparing CNN input data: {e}")
return np.zeros((1, 300)) return torch.zeros((1, 300), dtype=torch.float32, device=self.device)
def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray: def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
"""Prepare standardized input data for RL models""" """Prepare standardized input data for RL models with proper GPU device placement"""
try: try:
# Create state representation # Create state representation
state_features = [] state_features = []
@ -1205,13 +1360,15 @@ class TradingOrchestrator:
state_array = np.pad(state_array, (0, expected_size - len(state_array)), 'constant') state_array = np.pad(state_array, (0, expected_size - len(state_array)), 'constant')
else: else:
state_array = state_array[:expected_size] state_array = state_array[:expected_size]
return state_array # Convert to tensor and move to GPU
return torch.tensor(state_array, dtype=torch.float32, device=self.device)
else: else:
return np.zeros(100) # Return zero tensor on GPU
return torch.zeros(100, dtype=torch.float32, device=self.device)
except Exception as e: except Exception as e:
logger.error(f"Error preparing RL input data: {e}") logger.error(f"Error preparing RL input data: {e}")
return np.zeros(100) return torch.zeros(100, dtype=torch.float32, device=self.device)
def _store_inference_data(self, symbol: str, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime): def _store_inference_data(self, symbol: str, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
"""Store comprehensive inference data for future training with persistent storage""" """Store comprehensive inference data for future training with persistent storage"""
@ -1262,10 +1419,12 @@ class TradingOrchestrator:
'outcome_evaluated': False 'outcome_evaluated': False
} }
# Store in memory (inference history) # Store in memory (inference history) - keyed by model_name
if symbol in self.inference_history: if model_name not in self.inference_history:
self.inference_history[symbol].append(inference_record) self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences)
logger.debug(f"Stored inference data for {model_name} on {symbol}")
self.inference_history[model_name].append(inference_record)
logger.debug(f"Stored inference data for {model_name} on {symbol}")
# Persistent storage to disk (for long-term training data) # Persistent storage to disk (for long-term training data)
self._save_inference_to_disk(inference_record) self._save_inference_to_disk(inference_record)
@ -1350,6 +1509,35 @@ class TradingOrchestrator:
logger.error(f"Error loading inference history from disk: {e}") logger.error(f"Error loading inference history from disk: {e}")
return [] return []
async def load_model_inference_history(self, model_name: str, limit: int = 50) -> List[Dict]:
"""Load inference history for a specific model from disk"""
try:
model_dir = Path(f"training_data/inference_history/{model_name}")
if not model_dir.exists():
return []
# Get all inference files
files = list(model_dir.glob("inference_*.json"))
files.sort(key=lambda x: x.stat().st_mtime, reverse=True) # Newest first
# Load up to 'limit' files
inference_records = []
for filepath in files[:limit]:
try:
with open(filepath, 'r') as f:
record = json.load(f)
inference_records.append(record)
except Exception as e:
logger.warning(f"Error loading inference file {filepath}: {e}")
continue
logger.info(f"Loaded {len(inference_records)} inference records for {model_name}")
return inference_records
except Exception as e:
logger.error(f"Error loading model inference history for {model_name}: {e}")
return []
def get_model_training_data(self, model_name: str, symbol: str = None) -> List[Dict]: def get_model_training_data(self, model_name: str, symbol: str = None) -> List[Dict]:
"""Get training data for a specific model""" """Get training data for a specific model"""
try: try:
@ -1395,12 +1583,28 @@ class TradingOrchestrator:
async def _trigger_model_training(self, symbol: str): async def _trigger_model_training(self, symbol: str):
"""Trigger training for models based on previous inference data""" """Trigger training for models based on previous inference data"""
try: try:
if not self.training_enabled or symbol not in self.inference_history: if not self.training_enabled:
logger.debug("Training disabled, skipping model training")
return return
# Get recent inference records # Check if we have any inference history for any model
recent_records = list(self.inference_history[symbol]) if not self.inference_history:
if len(recent_records) < 2: logger.debug("No inference history available for training")
return
# Get recent inference records from all models (not symbol-based)
all_recent_records = []
for model_name, model_records in self.inference_history.items():
all_recent_records.extend(list(model_records))
# Only log if we have few records (for debugging)
if len(all_recent_records) < 5:
logger.debug(f"Total inference records for training: {len(all_recent_records)}")
for model_name, model_records in self.inference_history.items():
logger.debug(f" Model {model_name} has {len(model_records)} inference records")
if len(all_recent_records) < 2:
logger.debug("Not enough inference records for training")
return # Need at least 2 records to compare return # Need at least 2 records to compare
# Get current price for outcome evaluation # Get current price for outcome evaluation
@ -1408,12 +1612,11 @@ class TradingOrchestrator:
if current_price is None: if current_price is None:
return return
# Process records that are old enough to evaluate outcomes # Train on the most recent inference record (last prediction made)
cutoff_time = datetime.now() - timedelta(minutes=5) # 5 minutes ago if all_recent_records:
# Get the most recent record for training
for record in recent_records: most_recent_record = max(all_recent_records, key=lambda x: datetime.fromisoformat(x['timestamp']) if isinstance(x['timestamp'], str) else x['timestamp'])
if record['timestamp'] < cutoff_time: await self._evaluate_and_train_on_record(most_recent_record, current_price)
await self._evaluate_and_train_on_record(record, current_price)
except Exception as e: except Exception as e:
logger.error(f"Error triggering model training for {symbol}: {e}") logger.error(f"Error triggering model training for {symbol}: {e}")
@ -1425,6 +1628,10 @@ class TradingOrchestrator:
prediction = record['prediction'] prediction = record['prediction']
timestamp = record['timestamp'] timestamp = record['timestamp']
# Convert timestamp string back to datetime if needed
if isinstance(timestamp, str):
timestamp = datetime.fromisoformat(timestamp)
# Calculate price change since prediction # Calculate price change since prediction
# This is a simplified outcome evaluation - you might want to make it more sophisticated # This is a simplified outcome evaluation - you might want to make it more sophisticated
time_diff = (datetime.now() - timestamp).total_seconds() / 60 # minutes time_diff = (datetime.now() - timestamp).total_seconds() / 60 # minutes
@ -1495,12 +1702,23 @@ class TradingOrchestrator:
) )
logger.debug(f"Added RL training experience: reward={reward}") logger.debug(f"Added RL training experience: reward={reward}")
# Train CNN models # Train CNN models using adapter
elif 'cnn' in model_name.lower() and self.cnn_model: elif 'cnn' in model_name.lower() and hasattr(self, 'cnn_adapter') and self.cnn_adapter:
if hasattr(self.cnn_model, 'train_on_outcome'): # Use the adapter's add_training_sample method
target = 1 if was_correct else 0 actual_action = prediction['action']
self.cnn_model.train_on_outcome(model_input, target) self.cnn_adapter.add_training_sample(record['symbol'], actual_action, reward)
logger.debug(f"Trained CNN on outcome: target={target}") logger.debug(f"Added CNN training sample: action={actual_action}, reward={reward}")
# Trigger training if we have enough samples
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
training_results = self.cnn_adapter.train(epochs=1)
logger.debug(f"CNN training results: {training_results}")
# Fallback for raw CNN model
elif 'cnn' in model_name.lower() and self.cnn_model and hasattr(self.cnn_model, 'train_on_outcome'):
target = 1 if was_correct else 0
self.cnn_model.train_on_outcome(model_input, target)
logger.debug(f"Trained CNN on outcome: target={target}")
except Exception as e: except Exception as e:
logger.error(f"Error training model on outcome: {e}") logger.error(f"Error training model on outcome: {e}")
@ -2147,8 +2365,8 @@ class TradingOrchestrator:
return return
if not ENHANCED_TRAINING_AVAILABLE: if not ENHANCED_TRAINING_AVAILABLE:
logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled") logger.info("EnhancedRealtimeTrainingSystem not available - using built-in training")
self.training_enabled = False # Keep training enabled - we have built-in training capabilities
return return
# Initialize the enhanced training system # Initialize the enhanced training system

View File

@ -449,5 +449,37 @@ class StandardizedDataProvider(DataProvider):
logger.info("Stopped real-time processing for standardized data") logger.info("Stopped real-time processing for standardized data")
except Exception as e:
logger.error(f"Error stopping real-time processing: {e}")
def get_recent_prices(self, symbol: str, limit: int = 10) -> List[float]:
"""
Get recent prices for a symbol
Args:
symbol: Trading symbol
limit: Number of recent prices to return
Returns:
List[float]: List of recent prices
"""
try:
# Get recent OHLCV data using parent class method
df = self.get_historical_data(symbol, '1m', limit)
if df is None or df.empty:
return []
# Extract close prices from DataFrame
if 'close' in df.columns:
prices = df['close'].tolist()
return prices[-limit:] # Return most recent prices
else:
logger.warning(f"No 'close' column found in OHLCV data for {symbol}")
return []
except Exception as e:
logger.error(f"Error getting recent prices for {symbol}: {e}")
return []
except Exception as e: except Exception as e:
logger.error(f"Error stopping real-time processing: {e}") logger.error(f"Error stopping real-time processing: {e}")

View File

@ -16,6 +16,18 @@ ETH:
300s of 1s OHLCV data (5 min) 300s of 1s OHLCV data (5 min)
300 OHLCV + indicatros bars of each 1m 1h 1d and 1s BTC 300 OHLCV + indicatros bars of each 1m 1h 1d and 1s BTC
so:
# Standardized input for all models:
{
'primary_symbol': 'ETH/USDT',
'reference_symbol': 'BTC/USDT',
'eth_data': {'ETH_1s': df, 'ETH_1m': df, 'ETH_1h': df, 'ETH_1d': df},
'btc_data': {'BTC_1s': df},
'current_prices': {'ETH': price, 'BTC': price},
'data_completeness': {...}
}
RL model should have also access of the last hidden layers of the CNN model where patterns are learned. it can be empty if CNN model is not active or missing. as well as the output (predictions) of the CNN model for each timeframe (1s 1m 1h 1d) and next expected pivot point RL model should have also access of the last hidden layers of the CNN model where patterns are learned. it can be empty if CNN model is not active or missing. as well as the output (predictions) of the CNN model for each timeframe (1s 1m 1h 1d) and next expected pivot point
## CNN model ## CNN model

141
test_device_fix.py Normal file
View File

@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""
Test script to verify device mismatch fixes for GPU training
"""
import torch
import logging
import sys
import os
# Add the project root to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import BaseDataInput, OHLCVBar
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_device_consistency():
"""Test that all tensors are on the same device"""
logger.info("Testing device consistency for EnhancedCNN...")
# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")
try:
# Initialize the adapter
adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
# Verify adapter device
logger.info(f"Adapter device: {adapter.device}")
logger.info(f"Model device: {next(adapter.model.parameters()).device}")
# Create sample data
sample_ohlcv = [
OHLCVBar(
symbol="ETH/USDT",
timeframe="1s",
timestamp=1640995200.0, # 2022-01-01
open=50000.0,
high=51000.0,
low=49000.0,
close=50500.0,
volume=1000.0
)
] * 300 # 300 frames
base_data = BaseDataInput(
symbol="ETH/USDT",
timestamp=1640995200.0,
ohlcv_1s=sample_ohlcv,
ohlcv_1m=sample_ohlcv,
ohlcv_5m=sample_ohlcv,
ohlcv_15m=sample_ohlcv,
btc_ohlcv=sample_ohlcv,
cob_data={},
ma_data={},
technical_indicators={},
last_predictions={}
)
# Test prediction
logger.info("Testing prediction...")
prediction = adapter.predict(base_data)
logger.info(f"Prediction successful: {prediction.predictions['action']} (confidence: {prediction.confidence:.3f})")
# Test training sample addition
logger.info("Testing training sample addition...")
adapter.add_training_sample(base_data, "BUY", 0.1)
adapter.add_training_sample(base_data, "SELL", -0.05)
adapter.add_training_sample(base_data, "HOLD", 0.02)
# Test training
logger.info("Testing training...")
training_results = adapter.train(epochs=1)
logger.info(f"Training results: {training_results}")
logger.info("✅ All device consistency tests passed!")
return True
except Exception as e:
logger.error(f"❌ Device consistency test failed: {e}")
import traceback
traceback.print_exc()
return False
def test_orchestrator_inference_history():
"""Test that orchestrator properly initializes inference history"""
logger.info("Testing orchestrator inference history initialization...")
try:
from core.orchestrator import TradingOrchestrator
from core.data_provider import DataProvider
# Initialize orchestrator
data_provider = DataProvider()
orchestrator = TradingOrchestrator(data_provider=data_provider)
# Check if inference history is initialized
logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
# Check if models are registered
logger.info(f"Registered models: {list(orchestrator.model_registry.models.keys())}")
# Verify each registered model has inference history
for model_name in orchestrator.model_registry.models.keys():
if model_name in orchestrator.inference_history:
logger.info(f"{model_name} has inference history initialized")
else:
logger.warning(f"{model_name} missing inference history")
logger.info("✅ Orchestrator inference history test completed!")
return True
except Exception as e:
logger.error(f"❌ Orchestrator test failed: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
logger.info("Starting device fix verification tests...")
# Test 1: Device consistency
test1_passed = test_device_consistency()
# Test 2: Orchestrator inference history
test2_passed = test_orchestrator_inference_history()
# Summary
if test1_passed and test2_passed:
logger.info("🎉 All tests passed! Device issues should be fixed.")
sys.exit(0)
else:
logger.error("❌ Some tests failed. Please check the logs above.")
sys.exit(1)

153
test_device_training_fix.py Normal file
View File

@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""
Test script to verify device handling and training sample population fixes
"""
import logging
import asyncio
import torch
from datetime import datetime
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_device_handling():
"""Test that device handling is working correctly"""
try:
logger.info("Testing device handling...")
# Test 1: Check CUDA availability
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
logger.info(f"CUDA available: {cuda_available}")
logger.info(f"Using device: {device}")
# Test 2: Initialize CNN adapter
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
logger.info("Initializing CNN adapter...")
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
logger.info(f"CNN adapter device: {cnn_adapter.device}")
logger.info(f"CNN model device: {cnn_adapter.model.device}")
# Test 3: Create test data
from core.data_models import BaseDataInput
logger.info("Creating test BaseDataInput...")
base_data = BaseDataInput(
symbol="ETH/USDT",
timestamp=datetime.now(),
ohlcv_1s=[],
ohlcv_1m=[],
ohlcv_1h=[],
ohlcv_1d=[],
btc_ohlcv_1s=[],
cob_data=None,
technical_indicators={},
last_predictions={}
)
# Test 4: Make prediction (this should not cause device mismatch)
logger.info("Making prediction...")
prediction = cnn_adapter.predict(base_data)
logger.info(f"Prediction successful: {prediction.predictions['action']}")
logger.info(f"Confidence: {prediction.confidence:.4f}")
# Test 5: Add training samples
logger.info("Adding training samples...")
cnn_adapter.add_training_sample(base_data, "BUY", 0.1)
cnn_adapter.add_training_sample(base_data, "SELL", -0.05)
cnn_adapter.add_training_sample(base_data, "HOLD", 0.02)
logger.info(f"Training samples added: {len(cnn_adapter.training_data)}")
# Test 6: Try training if we have enough samples
if len(cnn_adapter.training_data) >= 2:
logger.info("Attempting training...")
training_results = cnn_adapter.train(epochs=1)
logger.info(f"Training results: {training_results}")
else:
logger.info("Not enough samples for training")
logger.info("✅ Device handling test passed!")
return True
except Exception as e:
logger.error(f"❌ Device handling test failed: {e}")
import traceback
traceback.print_exc()
return False
async def test_orchestrator_training():
"""Test that orchestrator properly adds training samples"""
try:
logger.info("Testing orchestrator training integration...")
# Test 1: Initialize orchestrator
from core.orchestrator import TradingOrchestrator
from core.standardized_data_provider import StandardizedDataProvider
logger.info("Initializing data provider...")
data_provider = StandardizedDataProvider()
logger.info("Initializing orchestrator...")
orchestrator = TradingOrchestrator(data_provider=data_provider)
# Test 2: Check if CNN adapter is available
if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter:
logger.info(f"✅ CNN adapter available in orchestrator")
logger.info(f"Initial training samples: {len(orchestrator.cnn_adapter.training_data)}")
else:
logger.warning("⚠️ CNN adapter not available in orchestrator")
return False
# Test 3: Make a trading decision (this should add training samples)
logger.info("Making trading decision...")
decision = await orchestrator.make_trading_decision("ETH/USDT")
if decision:
logger.info(f"Decision: {decision.action} (confidence: {decision.confidence:.4f})")
logger.info(f"Training samples after decision: {len(orchestrator.cnn_adapter.training_data)}")
else:
logger.warning("No decision made")
# Test 4: Check inference history
logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
for model_name, history in orchestrator.inference_history.items():
logger.info(f" {model_name}: {len(history)} records")
logger.info("✅ Orchestrator training test passed!")
return True
except Exception as e:
logger.error(f"❌ Orchestrator training test failed: {e}")
import traceback
traceback.print_exc()
return False
async def main():
"""Run all tests"""
logger.info("Starting device and training fix tests...")
# Test 1: Device handling
test1_passed = test_device_handling()
# Test 2: Orchestrator training
test2_passed = await test_orchestrator_training()
# Summary
logger.info("\n" + "="*50)
logger.info("TEST SUMMARY:")
logger.info(f"Device handling: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
logger.info(f"Orchestrator training: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
if test1_passed and test2_passed:
logger.info("🎉 All tests passed! Device and training issues should be fixed.")
else:
logger.error("❌ Some tests failed. Please check the logs above.")
if __name__ == "__main__":
asyncio.run(main())