Better rewards (min-profit deadzone after fees); fixed chart timezone (TZ) handling at last by converting to tz-naive local time

This commit is contained in:
Dobromir Popov
2025-08-08 01:53:17 +03:00
parent ded7e7f008
commit b80e1c1eba
3 changed files with 72 additions and 16 deletions

View File

@ -20,8 +20,10 @@ class TradingEnvironment(gym.Env):
window_size: int = 20,
risk_aversion: float = 0.2, # Controls how much to penalize volatility
price_scaling: str = 'zscore', # 'zscore', 'minmax', or 'raw'
reward_scaling: float = 10.0, # Scale factor for rewards
episode_penalty: float = 0.1): # Penalty for active positions at end of episode
reward_scaling: float = 10.0, # Scale factor for rewards
episode_penalty: float = 0.1, # Penalty for active positions at end of episode
min_profit_after_fees: float = 0.0005 # Deadzone: require >= 5 bps beyond fees
):
super(TradingEnvironment, self).__init__()
self.data = data
@ -33,6 +35,7 @@ class TradingEnvironment(gym.Env):
self.price_scaling = price_scaling
self.reward_scaling = reward_scaling
self.episode_penalty = episode_penalty
self.min_profit_after_fees = max(0.0, float(min_profit_after_fees))
# Preprocess data if needed
self._preprocess_data()
@ -177,8 +180,14 @@ class TradingEnvironment(gym.Env):
price_diff = current_price - self.entry_price
pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees
# Adjust reward based on PnL and risk
reward = pnl * self.reward_scaling
# Deadzone to discourage micro profits
if pnl > 0 and pnl < self.min_profit_after_fees:
reward = -self.fee_rate
elif pnl < 0 and abs(pnl) < self.min_profit_after_fees:
reward = pnl * self.reward_scaling * 0.5
else:
effective_pnl = pnl - (self.min_profit_after_fees if pnl > 0 else 0.0)
reward = effective_pnl * self.reward_scaling
# Track trade performance
self.total_trades += 1
@ -212,8 +221,12 @@ class TradingEnvironment(gym.Env):
price_diff = current_price - self.entry_price
unrealized_pnl = price_diff / self.entry_price
# Small reward/penalty based on unrealized P&L
reward = unrealized_pnl * 0.05 # Scale down to encourage holding good positions
# Encourage holding only if unrealized edge exceeds deadzone
unrealized_edge = unrealized_pnl
if abs(unrealized_edge) >= self.min_profit_after_fees:
reward = unrealized_edge * (self.reward_scaling * 0.2)
else:
reward = -0.0002
elif self.position < 0: # Short position
if action == 0: # BUY (close short)
@ -221,8 +234,13 @@ class TradingEnvironment(gym.Env):
price_diff = self.entry_price - current_price
pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees
# Adjust reward based on PnL and risk
reward = pnl * self.reward_scaling
if pnl > 0 and pnl < self.min_profit_after_fees:
reward = -self.fee_rate
elif pnl < 0 and abs(pnl) < self.min_profit_after_fees:
reward = pnl * self.reward_scaling * 0.5
else:
effective_pnl = pnl - (self.min_profit_after_fees if pnl > 0 else 0.0)
reward = effective_pnl * self.reward_scaling
# Track trade performance
self.total_trades += 1
@ -256,8 +274,12 @@ class TradingEnvironment(gym.Env):
price_diff = self.entry_price - current_price
unrealized_pnl = price_diff / self.entry_price
# Small reward/penalty based on unrealized P&L
reward = unrealized_pnl * 0.05 # Scale down to encourage holding good positions
# Encourage holding only if unrealized edge exceeds deadzone
unrealized_edge = unrealized_pnl
if abs(unrealized_edge) >= self.min_profit_after_fees:
reward = unrealized_edge * (self.reward_scaling * 0.2)
else:
reward = -0.0002
# Record the action
self.actions_taken.append(action)

Binary file not shown.

View File

@ -1996,12 +1996,15 @@ class CleanTradingDashboard:
try:
if hasattr(df_historical.index, 'tz') and df_historical.index.tz is not None:
df_historical_local = df_historical.tz_convert(_local_tz) if _local_tz else df_historical
# Drop tzinfo (tz-naive) for plotting
df_historical_local.index = df_historical_local.index.tz_localize(None)
else:
# Treat as UTC then convert to local
# Treat as UTC then convert to local tz, and make tz-naive
df_historical_local = df_historical.copy()
df_historical_local.index = df_historical_local.index.tz_localize('UTC')
if _local_tz:
df_historical_local = df_historical_local.tz_convert(_local_tz)
df_historical_local.index = df_historical_local.index.tz_localize(None)
except Exception:
df_historical_local = df_historical
@ -2010,11 +2013,13 @@ class CleanTradingDashboard:
try:
if hasattr(df_live.index, 'tz') and df_live.index.tz is not None:
df_live_local = df_live.tz_convert(_local_tz) if _local_tz else df_live
df_live_local.index = df_live_local.index.tz_localize(None)
else:
df_live_local = df_live.copy()
df_live_local.index = df_live_local.index.tz_localize('UTC')
if _local_tz:
df_live_local = df_live_local.tz_convert(_local_tz)
df_live_local.index = df_live_local.index.tz_localize(None)
except Exception:
df_live_local = df_live
@ -2030,8 +2035,13 @@ class CleanTradingDashboard:
df_main = df_historical_local
main_source = "Historical 1m"
elif df_live is not None and not df_live.empty:
# No historical data, use live only
# No historical data, use live only (ensure tz-naive)
df_main = df_live.tail(180)
try:
if hasattr(df_main.index, 'tz') and df_main.index.tz is not None:
df_main.index = df_main.index.tz_convert(_local_tz).tz_localize(None) if _local_tz else df_main.index.tz_localize(None)
except Exception:
pass
main_source = "Live 1m (WebSocket)"
else:
# No data at all
@ -2043,11 +2053,14 @@ class CleanTradingDashboard:
if ws_data_1s is not None and not ws_data_1s.empty:
try:
if hasattr(ws_data_1s.index, 'tz') and ws_data_1s.index.tz is not None:
ws_data_1s = ws_data_1s.tz_convert(_local_tz) if _local_tz else ws_data_1s
ws_data_1s = ws_data_1s.copy()
ws_data_1s.index = ws_data_1s.index.tz_convert(_local_tz).tz_localize(None) if _local_tz else ws_data_1s.index.tz_localize(None)
else:
ws_data_1s = ws_data_1s.copy()
ws_data_1s.index = ws_data_1s.index.tz_localize('UTC')
if _local_tz:
ws_data_1s = ws_data_1s.tz_convert(_local_tz)
ws_data_1s.index = ws_data_1s.index.tz_convert(_local_tz)
ws_data_1s.index = ws_data_1s.index.tz_localize(None)
except Exception:
pass
@ -2083,6 +2096,14 @@ class CleanTradingDashboard:
)
has_mini_chart = False
# Ensure main df_main index is local tz-naive to avoid offsets
try:
if hasattr(df_main.index, 'tz') and df_main.index.tz is not None:
df_main = df_main.copy()
df_main.index = df_main.index.tz_convert(_local_tz).tz_localize(None) if _local_tz else df_main.index.tz_localize(None)
except Exception:
pass
# Main 1-minute candlestick chart
fig.add_trace(
go.Candlestick(
@ -2128,7 +2149,7 @@ class CleanTradingDashboard:
ptype = getattr(p, 'type', 'low')
if ts is None or price is None:
continue
# Convert pivot timestamp to local tz to match chart axes
# Convert pivot timestamp to local tz and make tz-naive to match chart axes
try:
if hasattr(ts, 'tzinfo') and ts.tzinfo is not None:
pt = ts.astimezone(_local_tz) if _local_tz else ts
@ -2136,6 +2157,11 @@ class CleanTradingDashboard:
# Assume UTC then convert
pt = ts.replace(tzinfo=timezone.utc)
pt = pt.astimezone(_local_tz) if _local_tz else pt
# Drop tzinfo for plotting
try:
pt = pt.replace(tzinfo=None)
except Exception:
pass
except Exception:
pt = ts
if start_ts <= pt <= end_ts:
@ -2163,6 +2189,13 @@ class CleanTradingDashboard:
# Mini 1-second chart (if available)
if has_mini_chart and ws_data_1s is not None:
# Align mini chart to local tz-naive
try:
if hasattr(ws_data_1s.index, 'tz') and ws_data_1s.index.tz is not None:
ws_data_1s = ws_data_1s.copy()
ws_data_1s.index = ws_data_1s.index.tz_convert(_local_tz).tz_localize(None) if _local_tz else ws_data_1s.index.tz_localize(None)
except Exception:
pass
fig.add_trace(
go.Scatter(
x=ws_data_1s.index,
@ -3307,7 +3340,8 @@ class CleanTradingDashboard:
# Convert to DataFrame
df = pd.DataFrame(symbol_ticks)
df['datetime'] = pd.to_datetime(df['datetime'])
# Force UTC-aware timestamps for all websocket ticks
df['datetime'] = pd.to_datetime(df['datetime'], utc=True)
df.set_index('datetime', inplace=True)
# Get the price column (could be 'price', 'close', or 'c')