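Improve reward shaping with a minimum-profit deadzone; fix chart timezone handling (local tz-naive indexes, UTC-aware websocket ticks)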
This commit is contained in:
@ -21,7 +21,9 @@ class TradingEnvironment(gym.Env):
|
||||
risk_aversion: float = 0.2, # Controls how much to penalize volatility
|
||||
price_scaling: str = 'zscore', # 'zscore', 'minmax', or 'raw'
|
||||
reward_scaling: float = 10.0, # Scale factor for rewards
|
||||
episode_penalty: float = 0.1): # Penalty for active positions at end of episode
|
||||
episode_penalty: float = 0.1, # Penalty for active positions at end of episode
|
||||
min_profit_after_fees: float = 0.0005 # Deadzone: require >= 5 bps beyond fees
|
||||
):
|
||||
super(TradingEnvironment, self).__init__()
|
||||
|
||||
self.data = data
|
||||
@ -33,6 +35,7 @@ class TradingEnvironment(gym.Env):
|
||||
self.price_scaling = price_scaling
|
||||
self.reward_scaling = reward_scaling
|
||||
self.episode_penalty = episode_penalty
|
||||
self.min_profit_after_fees = max(0.0, float(min_profit_after_fees))
|
||||
|
||||
# Preprocess data if needed
|
||||
self._preprocess_data()
|
||||
@ -177,8 +180,14 @@ class TradingEnvironment(gym.Env):
|
||||
price_diff = current_price - self.entry_price
|
||||
pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees
|
||||
|
||||
# Adjust reward based on PnL and risk
|
||||
reward = pnl * self.reward_scaling
|
||||
# Deadzone to discourage micro profits
|
||||
if pnl > 0 and pnl < self.min_profit_after_fees:
|
||||
reward = -self.fee_rate
|
||||
elif pnl < 0 and abs(pnl) < self.min_profit_after_fees:
|
||||
reward = pnl * self.reward_scaling * 0.5
|
||||
else:
|
||||
effective_pnl = pnl - (self.min_profit_after_fees if pnl > 0 else 0.0)
|
||||
reward = effective_pnl * self.reward_scaling
|
||||
|
||||
# Track trade performance
|
||||
self.total_trades += 1
|
||||
@ -212,8 +221,12 @@ class TradingEnvironment(gym.Env):
|
||||
price_diff = current_price - self.entry_price
|
||||
unrealized_pnl = price_diff / self.entry_price
|
||||
|
||||
# Small reward/penalty based on unrealized P&L
|
||||
reward = unrealized_pnl * 0.05 # Scale down to encourage holding good positions
|
||||
# Encourage holding only if unrealized edge exceeds deadzone
|
||||
unrealized_edge = unrealized_pnl
|
||||
if abs(unrealized_edge) >= self.min_profit_after_fees:
|
||||
reward = unrealized_edge * (self.reward_scaling * 0.2)
|
||||
else:
|
||||
reward = -0.0002
|
||||
|
||||
elif self.position < 0: # Short position
|
||||
if action == 0: # BUY (close short)
|
||||
@ -221,8 +234,13 @@ class TradingEnvironment(gym.Env):
|
||||
price_diff = self.entry_price - current_price
|
||||
pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees
|
||||
|
||||
# Adjust reward based on PnL and risk
|
||||
reward = pnl * self.reward_scaling
|
||||
if pnl > 0 and pnl < self.min_profit_after_fees:
|
||||
reward = -self.fee_rate
|
||||
elif pnl < 0 and abs(pnl) < self.min_profit_after_fees:
|
||||
reward = pnl * self.reward_scaling * 0.5
|
||||
else:
|
||||
effective_pnl = pnl - (self.min_profit_after_fees if pnl > 0 else 0.0)
|
||||
reward = effective_pnl * self.reward_scaling
|
||||
|
||||
# Track trade performance
|
||||
self.total_trades += 1
|
||||
@ -256,8 +274,12 @@ class TradingEnvironment(gym.Env):
|
||||
price_diff = self.entry_price - current_price
|
||||
unrealized_pnl = price_diff / self.entry_price
|
||||
|
||||
# Small reward/penalty based on unrealized P&L
|
||||
reward = unrealized_pnl * 0.05 # Scale down to encourage holding good positions
|
||||
# Encourage holding only if unrealized edge exceeds deadzone
|
||||
unrealized_edge = unrealized_pnl
|
||||
if abs(unrealized_edge) >= self.min_profit_after_fees:
|
||||
reward = unrealized_edge * (self.reward_scaling * 0.2)
|
||||
else:
|
||||
reward = -0.0002
|
||||
|
||||
# Record the action
|
||||
self.actions_taken.append(action)
|
||||
|
Binary file not shown.
@ -1996,12 +1996,15 @@ class CleanTradingDashboard:
|
||||
try:
|
||||
if hasattr(df_historical.index, 'tz') and df_historical.index.tz is not None:
|
||||
df_historical_local = df_historical.tz_convert(_local_tz) if _local_tz else df_historical
|
||||
# Drop tzinfo (tz-naive) for plotting
|
||||
df_historical_local.index = df_historical_local.index.tz_localize(None)
|
||||
else:
|
||||
# Treat as UTC then convert to local
|
||||
# Treat as UTC then convert to local tz, and make tz-naive
|
||||
df_historical_local = df_historical.copy()
|
||||
df_historical_local.index = df_historical_local.index.tz_localize('UTC')
|
||||
if _local_tz:
|
||||
df_historical_local = df_historical_local.tz_convert(_local_tz)
|
||||
df_historical_local.index = df_historical_local.index.tz_localize(None)
|
||||
except Exception:
|
||||
df_historical_local = df_historical
|
||||
|
||||
@ -2010,11 +2013,13 @@ class CleanTradingDashboard:
|
||||
try:
|
||||
if hasattr(df_live.index, 'tz') and df_live.index.tz is not None:
|
||||
df_live_local = df_live.tz_convert(_local_tz) if _local_tz else df_live
|
||||
df_live_local.index = df_live_local.index.tz_localize(None)
|
||||
else:
|
||||
df_live_local = df_live.copy()
|
||||
df_live_local.index = df_live_local.index.tz_localize('UTC')
|
||||
if _local_tz:
|
||||
df_live_local = df_live_local.tz_convert(_local_tz)
|
||||
df_live_local.index = df_live_local.index.tz_localize(None)
|
||||
except Exception:
|
||||
df_live_local = df_live
|
||||
|
||||
@ -2030,8 +2035,13 @@ class CleanTradingDashboard:
|
||||
df_main = df_historical_local
|
||||
main_source = "Historical 1m"
|
||||
elif df_live is not None and not df_live.empty:
|
||||
# No historical data, use live only
|
||||
# No historical data, use live only (ensure tz-naive)
|
||||
df_main = df_live.tail(180)
|
||||
try:
|
||||
if hasattr(df_main.index, 'tz') and df_main.index.tz is not None:
|
||||
df_main.index = df_main.index.tz_convert(_local_tz).tz_localize(None) if _local_tz else df_main.index.tz_localize(None)
|
||||
except Exception:
|
||||
pass
|
||||
main_source = "Live 1m (WebSocket)"
|
||||
else:
|
||||
# No data at all
|
||||
@ -2043,11 +2053,14 @@ class CleanTradingDashboard:
|
||||
if ws_data_1s is not None and not ws_data_1s.empty:
|
||||
try:
|
||||
if hasattr(ws_data_1s.index, 'tz') and ws_data_1s.index.tz is not None:
|
||||
ws_data_1s = ws_data_1s.tz_convert(_local_tz) if _local_tz else ws_data_1s
|
||||
ws_data_1s = ws_data_1s.copy()
|
||||
ws_data_1s.index = ws_data_1s.index.tz_convert(_local_tz).tz_localize(None) if _local_tz else ws_data_1s.index.tz_localize(None)
|
||||
else:
|
||||
ws_data_1s = ws_data_1s.copy()
|
||||
ws_data_1s.index = ws_data_1s.index.tz_localize('UTC')
|
||||
if _local_tz:
|
||||
ws_data_1s = ws_data_1s.tz_convert(_local_tz)
|
||||
ws_data_1s.index = ws_data_1s.index.tz_convert(_local_tz)
|
||||
ws_data_1s.index = ws_data_1s.index.tz_localize(None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@ -2083,6 +2096,14 @@ class CleanTradingDashboard:
|
||||
)
|
||||
has_mini_chart = False
|
||||
|
||||
# Ensure main df_main index is local tz-naive to avoid offsets
|
||||
try:
|
||||
if hasattr(df_main.index, 'tz') and df_main.index.tz is not None:
|
||||
df_main = df_main.copy()
|
||||
df_main.index = df_main.index.tz_convert(_local_tz).tz_localize(None) if _local_tz else df_main.index.tz_localize(None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Main 1-minute candlestick chart
|
||||
fig.add_trace(
|
||||
go.Candlestick(
|
||||
@ -2128,7 +2149,7 @@ class CleanTradingDashboard:
|
||||
ptype = getattr(p, 'type', 'low')
|
||||
if ts is None or price is None:
|
||||
continue
|
||||
# Convert pivot timestamp to local tz to match chart axes
|
||||
# Convert pivot timestamp to local tz and make tz-naive to match chart axes
|
||||
try:
|
||||
if hasattr(ts, 'tzinfo') and ts.tzinfo is not None:
|
||||
pt = ts.astimezone(_local_tz) if _local_tz else ts
|
||||
@ -2136,6 +2157,11 @@ class CleanTradingDashboard:
|
||||
# Assume UTC then convert
|
||||
pt = ts.replace(tzinfo=timezone.utc)
|
||||
pt = pt.astimezone(_local_tz) if _local_tz else pt
|
||||
# Drop tzinfo for plotting
|
||||
try:
|
||||
pt = pt.replace(tzinfo=None)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pt = ts
|
||||
if start_ts <= pt <= end_ts:
|
||||
@ -2163,6 +2189,13 @@ class CleanTradingDashboard:
|
||||
|
||||
# Mini 1-second chart (if available)
|
||||
if has_mini_chart and ws_data_1s is not None:
|
||||
# Align mini chart to local tz-naive
|
||||
try:
|
||||
if hasattr(ws_data_1s.index, 'tz') and ws_data_1s.index.tz is not None:
|
||||
ws_data_1s = ws_data_1s.copy()
|
||||
ws_data_1s.index = ws_data_1s.index.tz_convert(_local_tz).tz_localize(None) if _local_tz else ws_data_1s.index.tz_localize(None)
|
||||
except Exception:
|
||||
pass
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=ws_data_1s.index,
|
||||
@ -3307,7 +3340,8 @@ class CleanTradingDashboard:
|
||||
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(symbol_ticks)
|
||||
df['datetime'] = pd.to_datetime(df['datetime'])
|
||||
# Force UTC-aware timestamps for all websocket ticks
|
||||
df['datetime'] = pd.to_datetime(df['datetime'], utc=True)
|
||||
df.set_index('datetime', inplace=True)
|
||||
|
||||
# Get the price column (could be 'price', 'close', or 'c')
|
||||
|
Reference in New Issue
Block a user