From 56f1110df3c3ab75e8d50873f71ba4d5d4bb77a7 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Wed, 2 Jul 2025 00:38:29 +0300 Subject: [PATCH] feed COB to the models --- web/clean_dashboard.py | 85 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 82 insertions(+), 3 deletions(-) diff --git a/web/clean_dashboard.py b/web/clean_dashboard.py index adb5b0d..23cd211 100644 --- a/web/clean_dashboard.py +++ b/web/clean_dashboard.py @@ -3630,12 +3630,16 @@ class CleanTradingDashboard: def _feed_cob_data_to_models(self, symbol: str, cob_snapshot: dict): """Feed COB data to models for training and inference""" try: + # Calculate cumulative imbalance for model feeding + cumulative_imbalance = self._calculate_cumulative_imbalance(symbol) + # Create 15-second history for model feeding history_data = { 'symbol': symbol, 'current_snapshot': cob_snapshot, 'history': self.cob_data_history[symbol][-15:], # Last 15 seconds 'bucketed_data': self.cob_bucketed_data[symbol], + 'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance 'timestamp': cob_snapshot['timestamp'] } @@ -3643,7 +3647,7 @@ class CleanTradingDashboard: if hasattr(self.orchestrator, '_on_cob_dashboard_data'): try: self.orchestrator._on_cob_dashboard_data(symbol, history_data) - logger.debug(f"COB data fed to orchestrator for {symbol}") + logger.debug(f"COB data fed to orchestrator for {symbol} with cumulative imbalance: {cumulative_imbalance}") except Exception as e: logger.debug(f"Error feeding COB data to orchestrator: {e}") @@ -4026,6 +4030,10 @@ class CleanTradingDashboard: current_price = self._get_current_price('ETH/USDT') if not current_price: return training_data + + # Get cumulative imbalance for training + cumulative_imbalance = self._calculate_cumulative_imbalance('ETH/USDT') + df = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=50) if df is not None and not df.empty: for i in range(1, min(len(df), 20)): @@ -4035,6 +4043,7 @@ class CleanTradingDashboard: sample = { 'timestamp': df.index[i], 'price': curr_price, 'prev_price': prev_price, 'price_change': price_change, 'volume': float(df['volume'].iloc[i]), + 'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance 'action': 'BUY' if price_change > 0.001 else 'SELL' if price_change < -0.001 else 'HOLD' } training_data.append(sample) @@ -4043,7 +4052,8 @@ class CleanTradingDashboard: for tick in recent_ticks: sample = { 'timestamp': tick.get('datetime', datetime.now()), 'price': tick.get('price', current_price), - 'volume': tick.get('volume', 0), 'tick_data': True + 'volume': tick.get('volume', 0), 'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance + 'tick_data': True } training_data.append(sample) return training_data @@ -4058,13 +4068,34 @@ class CleanTradingDashboard: return agent = self.orchestrator.rl_agent training_samples = 0 + total_loss = 0 + loss_count = 0 + for data in market_data[-10:]: try: price = data.get('price', 0) prev_price = data.get('prev_price', price) price_change = data.get('price_change', 0) volume = data.get('volume', 0) - state = np.array([price / 10000, price_change, volume / 1000000, 1.0 if price > prev_price else 0.0, abs(price_change) * 100]) + cumulative_imbalance = data.get('cumulative_imbalance', {}) + + # Extract imbalance values for state + imbalance_1s = cumulative_imbalance.get('1s', 0.0) + imbalance_5s = cumulative_imbalance.get('5s', 0.0) + imbalance_15s = cumulative_imbalance.get('15s', 0.0) + imbalance_60s = cumulative_imbalance.get('60s', 0.0) + + state = np.array([ + price / 10000, + price_change, + volume / 1000000, + 1.0 if price > prev_price else 0.0, + abs(price_change) * 100, + imbalance_1s, + imbalance_5s, + imbalance_15s, + imbalance_60s + ]) if hasattr(agent, 'state_dim') and len(state) < agent.state_dim: padded_state = np.zeros(agent.state_dim) padded_state[:len(state)] = state @@ -4079,17 +4110,58 @@ class CleanTradingDashboard: training_samples += 1 except Exception as e: logger.debug(f"Error adding market experience to DQN memory: {e}") + if hasattr(agent, 'memory') and len(agent.memory) >= 32: for _ in range(3): try: loss = agent.replay() if loss is not None: + total_loss += loss + loss_count += 1 self.orchestrator.update_model_loss('dqn', loss) if not hasattr(agent, 'losses'): agent.losses = [] agent.losses.append(loss) if len(agent.losses) > 1000: agent.losses = agent.losses[-1000:] except Exception as e: logger.debug(f"DQN training step failed: {e}") + + # Save checkpoint after training + if loss_count > 0: + try: + from utils.checkpoint_manager import save_checkpoint + avg_loss = total_loss / loss_count + + # Prepare checkpoint data + checkpoint_data = { + 'model_state_dict': agent.model.state_dict() if hasattr(agent, 'model') else None, + 'target_model_state_dict': agent.target_model.state_dict() if hasattr(agent, 'target_model') else None, + 'optimizer_state_dict': agent.optimizer.state_dict() if hasattr(agent, 'optimizer') else None, + 'memory_size': len(agent.memory), + 'training_samples': training_samples, + 'losses': agent.losses[-100:] if hasattr(agent, 'losses') else [] + } + + performance_metrics = { + 'loss': avg_loss, + 'memory_size': len(agent.memory), + 'training_samples': training_samples, + 'model_parameters': sum(p.numel() for p in agent.model.parameters()) if hasattr(agent, 'model') else 0 + } + + metadata = save_checkpoint( + model=checkpoint_data, + model_name="dqn_agent", + model_type="dqn", + performance_metrics=performance_metrics, + training_metadata={'training_iterations': loss_count} + ) + + if metadata: + logger.info(f"DQN checkpoint saved: {metadata.checkpoint_id} (loss={avg_loss:.4f})") + + except Exception as e: + logger.error(f"Error saving DQN checkpoint: {e}") + logger.info(f"DQN TRAINING: Added {training_samples} experiences, memory size: {len(agent.memory)}") except Exception as e: logger.error(f"Error in real DQN training: {e}") @@ -4109,10 +4181,17 @@ class CleanTradingDashboard: current_price = current_data.get('price', 0) next_price = next_data.get('price', current_price) price_change = (next_price - current_price) / current_price if current_price > 0 else 0 + cumulative_imbalance = current_data.get('cumulative_imbalance', {}) + features = np.random.randn(100) features[0] = current_price / 10000 features[1] = price_change features[2] = current_data.get('volume', 0) / 1000000 + # Add cumulative imbalance features + features[3] = cumulative_imbalance.get('1s', 0.0) + features[4] = cumulative_imbalance.get('5s', 0.0) + features[5] = cumulative_imbalance.get('15s', 0.0) + features[6] = cumulative_imbalance.get('60s', 0.0) if price_change > 0.001: target = 2 elif price_change < -0.001: target = 0 else: target = 1