better rewards (fee-aware profit deadzone), fixed TZ at last (chart timestamps converted to local time and made tz-naive)

This commit is contained in:
Dobromir Popov
2025-08-08 01:53:17 +03:00
parent ded7e7f008
commit b80e1c1eba
3 changed files with 72 additions and 16 deletions

View File

@ -21,7 +21,9 @@ class TradingEnvironment(gym.Env):
risk_aversion: float = 0.2, # Controls how much to penalize volatility risk_aversion: float = 0.2, # Controls how much to penalize volatility
price_scaling: str = 'zscore', # 'zscore', 'minmax', or 'raw' price_scaling: str = 'zscore', # 'zscore', 'minmax', or 'raw'
reward_scaling: float = 10.0, # Scale factor for rewards reward_scaling: float = 10.0, # Scale factor for rewards
episode_penalty: float = 0.1): # Penalty for active positions at end of episode episode_penalty: float = 0.1, # Penalty for active positions at end of episode
min_profit_after_fees: float = 0.0005 # Deadzone: require >= 5 bps beyond fees
):
super(TradingEnvironment, self).__init__() super(TradingEnvironment, self).__init__()
self.data = data self.data = data
@ -33,6 +35,7 @@ class TradingEnvironment(gym.Env):
self.price_scaling = price_scaling self.price_scaling = price_scaling
self.reward_scaling = reward_scaling self.reward_scaling = reward_scaling
self.episode_penalty = episode_penalty self.episode_penalty = episode_penalty
self.min_profit_after_fees = max(0.0, float(min_profit_after_fees))
# Preprocess data if needed # Preprocess data if needed
self._preprocess_data() self._preprocess_data()
@ -177,8 +180,14 @@ class TradingEnvironment(gym.Env):
price_diff = current_price - self.entry_price price_diff = current_price - self.entry_price
pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees
# Adjust reward based on PnL and risk # Deadzone to discourage micro profits
reward = pnl * self.reward_scaling if pnl > 0 and pnl < self.min_profit_after_fees:
reward = -self.fee_rate
elif pnl < 0 and abs(pnl) < self.min_profit_after_fees:
reward = pnl * self.reward_scaling * 0.5
else:
effective_pnl = pnl - (self.min_profit_after_fees if pnl > 0 else 0.0)
reward = effective_pnl * self.reward_scaling
# Track trade performance # Track trade performance
self.total_trades += 1 self.total_trades += 1
@ -212,8 +221,12 @@ class TradingEnvironment(gym.Env):
price_diff = current_price - self.entry_price price_diff = current_price - self.entry_price
unrealized_pnl = price_diff / self.entry_price unrealized_pnl = price_diff / self.entry_price
# Small reward/penalty based on unrealized P&L # Encourage holding only if unrealized edge exceeds deadzone
reward = unrealized_pnl * 0.05 # Scale down to encourage holding good positions unrealized_edge = unrealized_pnl
if abs(unrealized_edge) >= self.min_profit_after_fees:
reward = unrealized_edge * (self.reward_scaling * 0.2)
else:
reward = -0.0002
elif self.position < 0: # Short position elif self.position < 0: # Short position
if action == 0: # BUY (close short) if action == 0: # BUY (close short)
@ -221,8 +234,13 @@ class TradingEnvironment(gym.Env):
price_diff = self.entry_price - current_price price_diff = self.entry_price - current_price
pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees
# Adjust reward based on PnL and risk if pnl > 0 and pnl < self.min_profit_after_fees:
reward = pnl * self.reward_scaling reward = -self.fee_rate
elif pnl < 0 and abs(pnl) < self.min_profit_after_fees:
reward = pnl * self.reward_scaling * 0.5
else:
effective_pnl = pnl - (self.min_profit_after_fees if pnl > 0 else 0.0)
reward = effective_pnl * self.reward_scaling
# Track trade performance # Track trade performance
self.total_trades += 1 self.total_trades += 1
@ -256,8 +274,12 @@ class TradingEnvironment(gym.Env):
price_diff = self.entry_price - current_price price_diff = self.entry_price - current_price
unrealized_pnl = price_diff / self.entry_price unrealized_pnl = price_diff / self.entry_price
# Small reward/penalty based on unrealized P&L # Encourage holding only if unrealized edge exceeds deadzone
reward = unrealized_pnl * 0.05 # Scale down to encourage holding good positions unrealized_edge = unrealized_pnl
if abs(unrealized_edge) >= self.min_profit_after_fees:
reward = unrealized_edge * (self.reward_scaling * 0.2)
else:
reward = -0.0002
# Record the action # Record the action
self.actions_taken.append(action) self.actions_taken.append(action)

Binary file not shown.

View File

@ -1996,12 +1996,15 @@ class CleanTradingDashboard:
try: try:
if hasattr(df_historical.index, 'tz') and df_historical.index.tz is not None: if hasattr(df_historical.index, 'tz') and df_historical.index.tz is not None:
df_historical_local = df_historical.tz_convert(_local_tz) if _local_tz else df_historical df_historical_local = df_historical.tz_convert(_local_tz) if _local_tz else df_historical
# Drop tzinfo (tz-naive) for plotting
df_historical_local.index = df_historical_local.index.tz_localize(None)
else: else:
# Treat as UTC then convert to local # Treat as UTC then convert to local tz, and make tz-naive
df_historical_local = df_historical.copy() df_historical_local = df_historical.copy()
df_historical_local.index = df_historical_local.index.tz_localize('UTC') df_historical_local.index = df_historical_local.index.tz_localize('UTC')
if _local_tz: if _local_tz:
df_historical_local = df_historical_local.tz_convert(_local_tz) df_historical_local = df_historical_local.tz_convert(_local_tz)
df_historical_local.index = df_historical_local.index.tz_localize(None)
except Exception: except Exception:
df_historical_local = df_historical df_historical_local = df_historical
@ -2010,11 +2013,13 @@ class CleanTradingDashboard:
try: try:
if hasattr(df_live.index, 'tz') and df_live.index.tz is not None: if hasattr(df_live.index, 'tz') and df_live.index.tz is not None:
df_live_local = df_live.tz_convert(_local_tz) if _local_tz else df_live df_live_local = df_live.tz_convert(_local_tz) if _local_tz else df_live
df_live_local.index = df_live_local.index.tz_localize(None)
else: else:
df_live_local = df_live.copy() df_live_local = df_live.copy()
df_live_local.index = df_live_local.index.tz_localize('UTC') df_live_local.index = df_live_local.index.tz_localize('UTC')
if _local_tz: if _local_tz:
df_live_local = df_live_local.tz_convert(_local_tz) df_live_local = df_live_local.tz_convert(_local_tz)
df_live_local.index = df_live_local.index.tz_localize(None)
except Exception: except Exception:
df_live_local = df_live df_live_local = df_live
@ -2030,8 +2035,13 @@ class CleanTradingDashboard:
df_main = df_historical_local df_main = df_historical_local
main_source = "Historical 1m" main_source = "Historical 1m"
elif df_live is not None and not df_live.empty: elif df_live is not None and not df_live.empty:
# No historical data, use live only # No historical data, use live only (ensure tz-naive)
df_main = df_live.tail(180) df_main = df_live.tail(180)
try:
if hasattr(df_main.index, 'tz') and df_main.index.tz is not None:
df_main.index = df_main.index.tz_convert(_local_tz).tz_localize(None) if _local_tz else df_main.index.tz_localize(None)
except Exception:
pass
main_source = "Live 1m (WebSocket)" main_source = "Live 1m (WebSocket)"
else: else:
# No data at all # No data at all
@ -2043,11 +2053,14 @@ class CleanTradingDashboard:
if ws_data_1s is not None and not ws_data_1s.empty: if ws_data_1s is not None and not ws_data_1s.empty:
try: try:
if hasattr(ws_data_1s.index, 'tz') and ws_data_1s.index.tz is not None: if hasattr(ws_data_1s.index, 'tz') and ws_data_1s.index.tz is not None:
ws_data_1s = ws_data_1s.tz_convert(_local_tz) if _local_tz else ws_data_1s ws_data_1s = ws_data_1s.copy()
ws_data_1s.index = ws_data_1s.index.tz_convert(_local_tz).tz_localize(None) if _local_tz else ws_data_1s.index.tz_localize(None)
else: else:
ws_data_1s = ws_data_1s.copy()
ws_data_1s.index = ws_data_1s.index.tz_localize('UTC') ws_data_1s.index = ws_data_1s.index.tz_localize('UTC')
if _local_tz: if _local_tz:
ws_data_1s = ws_data_1s.tz_convert(_local_tz) ws_data_1s.index = ws_data_1s.index.tz_convert(_local_tz)
ws_data_1s.index = ws_data_1s.index.tz_localize(None)
except Exception: except Exception:
pass pass
@ -2083,6 +2096,14 @@ class CleanTradingDashboard:
) )
has_mini_chart = False has_mini_chart = False
# Ensure main df_main index is local tz-naive to avoid offsets
try:
if hasattr(df_main.index, 'tz') and df_main.index.tz is not None:
df_main = df_main.copy()
df_main.index = df_main.index.tz_convert(_local_tz).tz_localize(None) if _local_tz else df_main.index.tz_localize(None)
except Exception:
pass
# Main 1-minute candlestick chart # Main 1-minute candlestick chart
fig.add_trace( fig.add_trace(
go.Candlestick( go.Candlestick(
@ -2128,7 +2149,7 @@ class CleanTradingDashboard:
ptype = getattr(p, 'type', 'low') ptype = getattr(p, 'type', 'low')
if ts is None or price is None: if ts is None or price is None:
continue continue
# Convert pivot timestamp to local tz to match chart axes # Convert pivot timestamp to local tz and make tz-naive to match chart axes
try: try:
if hasattr(ts, 'tzinfo') and ts.tzinfo is not None: if hasattr(ts, 'tzinfo') and ts.tzinfo is not None:
pt = ts.astimezone(_local_tz) if _local_tz else ts pt = ts.astimezone(_local_tz) if _local_tz else ts
@ -2136,6 +2157,11 @@ class CleanTradingDashboard:
# Assume UTC then convert # Assume UTC then convert
pt = ts.replace(tzinfo=timezone.utc) pt = ts.replace(tzinfo=timezone.utc)
pt = pt.astimezone(_local_tz) if _local_tz else pt pt = pt.astimezone(_local_tz) if _local_tz else pt
# Drop tzinfo for plotting
try:
pt = pt.replace(tzinfo=None)
except Exception:
pass
except Exception: except Exception:
pt = ts pt = ts
if start_ts <= pt <= end_ts: if start_ts <= pt <= end_ts:
@ -2163,6 +2189,13 @@ class CleanTradingDashboard:
# Mini 1-second chart (if available) # Mini 1-second chart (if available)
if has_mini_chart and ws_data_1s is not None: if has_mini_chart and ws_data_1s is not None:
# Align mini chart to local tz-naive
try:
if hasattr(ws_data_1s.index, 'tz') and ws_data_1s.index.tz is not None:
ws_data_1s = ws_data_1s.copy()
ws_data_1s.index = ws_data_1s.index.tz_convert(_local_tz).tz_localize(None) if _local_tz else ws_data_1s.index.tz_localize(None)
except Exception:
pass
fig.add_trace( fig.add_trace(
go.Scatter( go.Scatter(
x=ws_data_1s.index, x=ws_data_1s.index,
@ -3307,7 +3340,8 @@ class CleanTradingDashboard:
# Convert to DataFrame # Convert to DataFrame
df = pd.DataFrame(symbol_ticks) df = pd.DataFrame(symbol_ticks)
df['datetime'] = pd.to_datetime(df['datetime']) # Force UTC-aware timestamps for all websocket ticks
df['datetime'] = pd.to_datetime(df['datetime'], utc=True)
df.set_index('datetime', inplace=True) df.set_index('datetime', inplace=True)
# Get the price column (could be 'price', 'close', or 'c') # Get the price column (could be 'price', 'close', or 'c')