From c39b70f6fa78f451b7f5342dd2e7d1585a217791 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Tue, 26 Aug 2025 18:11:34 +0300 Subject: [PATCH] MISC --- NN/models/dqn_agent.py | 5 +- core/data_provider.py | 14 +++-- core/enhanced_rl_training_adapter.py | 5 +- core/orchestrator.py | 80 ++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 7 deletions(-) diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py index 815858e..61d8dfe 100644 --- a/NN/models/dqn_agent.py +++ b/NN/models/dqn_agent.py @@ -1071,8 +1071,9 @@ class DQNAgent: # If no experiences provided, sample from memory if experiences is None: - # Skip if memory is too small - if len(self.memory) < self.batch_size: + # Skip if memory is too small (allow early training for GPU warmup) + min_required = min(getattr(self, 'batch_size', 32), 16) + if len(self.memory) < min_required: return 0.0 # Sample random mini-batch from memory diff --git a/core/data_provider.py b/core/data_provider.py index 2ba15aa..c83fddf 100644 --- a/core/data_provider.py +++ b/core/data_provider.py @@ -1312,13 +1312,19 @@ class DataProvider: # For 1s timeframe, try to generate from WebSocket ticks first if timeframe == '1s': - # logger.info(f"Attempting to generate 1s candles from WebSocket ticks for {symbol}") - generated_df = self._generate_1s_candles_from_ticks(symbol, limit) + # Attempt to generate from WebSocket ticks, but throttle attempts to avoid spam + if not hasattr(self, '_last_1s_generation_attempt'): + self._last_1s_generation_attempt = {} + now_ts = time.time() + last_attempt = self._last_1s_generation_attempt.get(symbol, 0) + generated_df = None + if now_ts - last_attempt >= 1.5: + self._last_1s_generation_attempt[symbol] = now_ts + generated_df = self._generate_1s_candles_from_ticks(symbol, limit) if generated_df is not None and not generated_df.empty: - # logger.info(f"Successfully generated 1s candles from WebSocket ticks for {symbol}") return generated_df else: - logger.info(f"Could not generate 1s candles from ticks for {symbol}; trying Binance API") + logger.debug(f"Could not generate 1s candles from ticks for {symbol}; trying Binance API") # Convert symbol format binance_symbol = symbol.replace('/', '').upper() diff --git a/core/enhanced_rl_training_adapter.py b/core/enhanced_rl_training_adapter.py index 2cd27c3..42ede62 100644 --- a/core/enhanced_rl_training_adapter.py +++ b/core/enhanced_rl_training_adapter.py @@ -154,7 +154,7 @@ class EnhancedRLTrainingAdapter: 'direction': direction, 'confidence': float(confidence), 'action': action_names[action_idx], - 'model_state': state, + 'model_state': (state.tolist() if hasattr(state, 'tolist') else state), 'context': context } except Exception as e: @@ -238,6 +238,8 @@ class EnhancedRLTrainingAdapter: predicted_price = current_price * (1 + predicted_return) if (predicted_return is not None and current_price) else current_price + # Also attach DQN-formatted state if available for training consumption + dqn_state = self._convert_to_dqn_state(base_data, context) return { 'predicted_price': predicted_price, 'current_price': current_price, @@ -246,6 +248,7 @@ class EnhancedRLTrainingAdapter: 'predicted_return': predicted_return, 'action': action, 'model_output': model_output, + 'model_state': (dqn_state.tolist() if hasattr(dqn_state, 'tolist') else dqn_state), 'context': context } except Exception as e: diff --git a/core/orchestrator.py b/core/orchestrator.py index 13b23ae..320937d 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -1518,9 +1518,89 @@ class TradingOrchestrator: with open(self.ui_state_file, "w") as f: json.dump(ui_state, f, indent=4) logger.debug(f"UI state saved to {self.ui_state_file}") + # Also append a session snapshot for persistence across restarts + self._append_session_snapshot() except Exception as e: logger.error(f"Error saving UI state: {e}") + def _append_session_snapshot(self): + """Append current session metrics to persistent JSON until cleared manually.""" + try: + session_file = os.path.join("data", "session_state.json") + os.makedirs(os.path.dirname(session_file), exist_ok=True) + + # Load existing + existing = {} + if os.path.exists(session_file): + try: + with open(session_file, "r", encoding="utf-8") as f: + existing = json.load(f) or {} + except Exception: + existing = {} + + # Collect metrics + balance = 0.0 + pnl_total = 0.0 + closed_trades = [] + try: + if hasattr(self, "trading_executor") and self.trading_executor: + balance = float(getattr(self.trading_executor, "account_balance", 0.0) or 0.0) + if hasattr(self.trading_executor, "trade_history"): + for t in self.trading_executor.trade_history: + try: + closed_trades.append({ + "symbol": t.symbol, + "side": t.side, + "qty": t.quantity, + "entry": t.entry_price, + "exit": t.exit_price, + "pnl": t.pnl, + "timestamp": getattr(t, "timestamp", None) + }) + pnl_total += float(t.pnl or 0.0) + except Exception: + continue + except Exception: + pass + + # Models and performance (best-effort) + models = {} + try: + models = { + "dqn": { + "available": bool(getattr(self, "rl_agent", None)), + "last_losses": getattr(getattr(self, "rl_agent", None), "losses", [])[-10:] if getattr(getattr(self, "rl_agent", None), "losses", None) else [] + }, + "cnn": { + "available": bool(getattr(self, "cnn_model", None)) + }, + "cob_rl": { + "available": bool(getattr(self, "cob_rl_agent", None)) + }, + "decision_fusion": { + "available": bool(getattr(self, "decision_model", None)) + } + } + except Exception: + pass + + snapshot = { + "timestamp": datetime.now().isoformat(), + "balance": balance, + "session_pnl": pnl_total, + "closed_trades": closed_trades, + "models": models + } + + if "snapshots" not in existing: + existing["snapshots"] = [] + existing["snapshots"].append(snapshot) + + with open(session_file, "w", encoding="utf-8") as f: + json.dump(existing, f, indent=2) + except Exception as e: + logger.error(f"Error appending session snapshot: {e}") + def get_model_toggle_state(self, model_name: str) -> Dict[str, bool]: """Get toggle state for a model""" key = self._normalize_model_name(model_name)