MISC
This commit is contained in:
@@ -1071,8 +1071,9 @@ class DQNAgent:
|
|||||||
|
|
||||||
# If no experiences provided, sample from memory
|
# If no experiences provided, sample from memory
|
||||||
if experiences is None:
|
if experiences is None:
|
||||||
# Skip if memory is too small
|
# Skip if memory is too small (allow early training for GPU warmup)
|
||||||
if len(self.memory) < self.batch_size:
|
min_required = min(getattr(self, 'batch_size', 32), 16)
|
||||||
|
if len(self.memory) < min_required:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# Sample random mini-batch from memory
|
# Sample random mini-batch from memory
|
||||||
|
|||||||
@@ -1312,13 +1312,19 @@ class DataProvider:
|
|||||||
|
|
||||||
# For 1s timeframe, try to generate from WebSocket ticks first
|
# For 1s timeframe, try to generate from WebSocket ticks first
|
||||||
if timeframe == '1s':
|
if timeframe == '1s':
|
||||||
# logger.info(f"Attempting to generate 1s candles from WebSocket ticks for {symbol}")
|
# Attempt to generate from WebSocket ticks, but throttle attempts to avoid spam
|
||||||
|
if not hasattr(self, '_last_1s_generation_attempt'):
|
||||||
|
self._last_1s_generation_attempt = {}
|
||||||
|
now_ts = time.time()
|
||||||
|
last_attempt = self._last_1s_generation_attempt.get(symbol, 0)
|
||||||
|
generated_df = None
|
||||||
|
if now_ts - last_attempt >= 1.5:
|
||||||
|
self._last_1s_generation_attempt[symbol] = now_ts
|
||||||
generated_df = self._generate_1s_candles_from_ticks(symbol, limit)
|
generated_df = self._generate_1s_candles_from_ticks(symbol, limit)
|
||||||
if generated_df is not None and not generated_df.empty:
|
if generated_df is not None and not generated_df.empty:
|
||||||
# logger.info(f"Successfully generated 1s candles from WebSocket ticks for {symbol}")
|
|
||||||
return generated_df
|
return generated_df
|
||||||
else:
|
else:
|
||||||
logger.info(f"Could not generate 1s candles from ticks for {symbol}; trying Binance API")
|
logger.debug(f"Could not generate 1s candles from ticks for {symbol}; trying Binance API")
|
||||||
|
|
||||||
# Convert symbol format
|
# Convert symbol format
|
||||||
binance_symbol = symbol.replace('/', '').upper()
|
binance_symbol = symbol.replace('/', '').upper()
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ class EnhancedRLTrainingAdapter:
|
|||||||
'direction': direction,
|
'direction': direction,
|
||||||
'confidence': float(confidence),
|
'confidence': float(confidence),
|
||||||
'action': action_names[action_idx],
|
'action': action_names[action_idx],
|
||||||
'model_state': state,
|
'model_state': (state.tolist() if hasattr(state, 'tolist') else state),
|
||||||
'context': context
|
'context': context
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -238,6 +238,8 @@ class EnhancedRLTrainingAdapter:
|
|||||||
|
|
||||||
predicted_price = current_price * (1 + predicted_return) if (predicted_return is not None and current_price) else current_price
|
predicted_price = current_price * (1 + predicted_return) if (predicted_return is not None and current_price) else current_price
|
||||||
|
|
||||||
|
# Also attach DQN-formatted state if available for training consumption
|
||||||
|
dqn_state = self._convert_to_dqn_state(base_data, context)
|
||||||
return {
|
return {
|
||||||
'predicted_price': predicted_price,
|
'predicted_price': predicted_price,
|
||||||
'current_price': current_price,
|
'current_price': current_price,
|
||||||
@@ -246,6 +248,7 @@ class EnhancedRLTrainingAdapter:
|
|||||||
'predicted_return': predicted_return,
|
'predicted_return': predicted_return,
|
||||||
'action': action,
|
'action': action,
|
||||||
'model_output': model_output,
|
'model_output': model_output,
|
||||||
|
'model_state': (dqn_state.tolist() if hasattr(dqn_state, 'tolist') else dqn_state),
|
||||||
'context': context
|
'context': context
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1518,9 +1518,89 @@ class TradingOrchestrator:
|
|||||||
with open(self.ui_state_file, "w") as f:
|
with open(self.ui_state_file, "w") as f:
|
||||||
json.dump(ui_state, f, indent=4)
|
json.dump(ui_state, f, indent=4)
|
||||||
logger.debug(f"UI state saved to {self.ui_state_file}")
|
logger.debug(f"UI state saved to {self.ui_state_file}")
|
||||||
|
# Also append a session snapshot for persistence across restarts
|
||||||
|
self._append_session_snapshot()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error saving UI state: {e}")
|
logger.error(f"Error saving UI state: {e}")
|
||||||
|
|
||||||
|
def _append_session_snapshot(self):
|
||||||
|
"""Append current session metrics to persistent JSON until cleared manually."""
|
||||||
|
try:
|
||||||
|
session_file = os.path.join("data", "session_state.json")
|
||||||
|
os.makedirs(os.path.dirname(session_file), exist_ok=True)
|
||||||
|
|
||||||
|
# Load existing
|
||||||
|
existing = {}
|
||||||
|
if os.path.exists(session_file):
|
||||||
|
try:
|
||||||
|
with open(session_file, "r", encoding="utf-8") as f:
|
||||||
|
existing = json.load(f) or {}
|
||||||
|
except Exception:
|
||||||
|
existing = {}
|
||||||
|
|
||||||
|
# Collect metrics
|
||||||
|
balance = 0.0
|
||||||
|
pnl_total = 0.0
|
||||||
|
closed_trades = []
|
||||||
|
try:
|
||||||
|
if hasattr(self, "trading_executor") and self.trading_executor:
|
||||||
|
balance = float(getattr(self.trading_executor, "account_balance", 0.0) or 0.0)
|
||||||
|
if hasattr(self.trading_executor, "trade_history"):
|
||||||
|
for t in self.trading_executor.trade_history:
|
||||||
|
try:
|
||||||
|
closed_trades.append({
|
||||||
|
"symbol": t.symbol,
|
||||||
|
"side": t.side,
|
||||||
|
"qty": t.quantity,
|
||||||
|
"entry": t.entry_price,
|
||||||
|
"exit": t.exit_price,
|
||||||
|
"pnl": t.pnl,
|
||||||
|
"timestamp": getattr(t, "timestamp", None)
|
||||||
|
})
|
||||||
|
pnl_total += float(t.pnl or 0.0)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Models and performance (best-effort)
|
||||||
|
models = {}
|
||||||
|
try:
|
||||||
|
models = {
|
||||||
|
"dqn": {
|
||||||
|
"available": bool(getattr(self, "rl_agent", None)),
|
||||||
|
"last_losses": getattr(getattr(self, "rl_agent", None), "losses", [])[-10:] if getattr(getattr(self, "rl_agent", None), "losses", None) else []
|
||||||
|
},
|
||||||
|
"cnn": {
|
||||||
|
"available": bool(getattr(self, "cnn_model", None))
|
||||||
|
},
|
||||||
|
"cob_rl": {
|
||||||
|
"available": bool(getattr(self, "cob_rl_agent", None))
|
||||||
|
},
|
||||||
|
"decision_fusion": {
|
||||||
|
"available": bool(getattr(self, "decision_model", None))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
snapshot = {
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"balance": balance,
|
||||||
|
"session_pnl": pnl_total,
|
||||||
|
"closed_trades": closed_trades,
|
||||||
|
"models": models
|
||||||
|
}
|
||||||
|
|
||||||
|
if "snapshots" not in existing:
|
||||||
|
existing["snapshots"] = []
|
||||||
|
existing["snapshots"].append(snapshot)
|
||||||
|
|
||||||
|
with open(session_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(existing, f, indent=2)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error appending session snapshot: {e}")
|
||||||
|
|
||||||
def get_model_toggle_state(self, model_name: str) -> Dict[str, bool]:
|
def get_model_toggle_state(self, model_name: str) -> Dict[str, bool]:
|
||||||
"""Get toggle state for a model"""
|
"""Get toggle state for a model"""
|
||||||
key = self._normalize_model_name(model_name)
|
key = self._normalize_model_name(model_name)
|
||||||
|
|||||||
Reference in New Issue
Block a user