charts more or less OK

This commit is contained in:
Dobromir Popov 2025-04-01 21:37:08 +03:00
parent b0a57c5330
commit 938eef8bc9
3 changed files with 465 additions and 226 deletions

View File

@ -34,7 +34,7 @@ class RLTradingEnvironment(gym.Env):
Reinforcement Learning environment for trading with technical indicators
from multiple timeframes
"""
def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.001):
def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.0025, min_trade_interval=15):
super().__init__()
# Initialize attributes before parent class
@ -50,7 +50,8 @@ class RLTradingEnvironment(gym.Env):
# Trading parameters
self.initial_balance = 1.0
self.trading_fee = trading_fee
self.trading_fee = trading_fee # Increased from 0.001 to 0.0025 (0.25%)
self.min_trade_interval = min_trade_interval # Minimum steps between trades
# Define action and observation spaces
self.action_space = gym.spaces.Discrete(3) # 0: Buy, 1: Sell, 2: Hold
@ -76,6 +77,7 @@ class RLTradingEnvironment(gym.Env):
self.wins = 0
self.losses = 0
self.trade_history = []
self.last_trade_step = -self.min_trade_interval # Initialize to allow immediate first trade
# Get initial observation
observation = self._get_observation()
@ -150,24 +152,40 @@ class RLTradingEnvironment(gym.Env):
done = False
profit_pct = None # Initialize profit_pct variable
# Check if enough time has passed since last trade
trade_interval = self.current_step - self.last_trade_step
trade_interval_penalty = 0
# Execute action
if action == 0: # BUY
if self.position == 0: # Only buy if not already in position
# Apply extra penalty for trading too frequently
if trade_interval < self.min_trade_interval:
trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)
# Still allow the trade but with penalty
self.position = self.balance * (1 - self.trading_fee)
self.balance = 0
self.trades += 1
reward = 0 # Neutral reward for entering position
reward = -0.001 + trade_interval_penalty # Small cost for transaction + potential penalty
self.trade_entry_price = current_price
self.last_trade_step = self.current_step
elif action == 1: # SELL
if self.position > 0: # Only sell if in position
# Apply extra penalty for trading too frequently
if trade_interval < self.min_trade_interval:
trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)
# Still allow the trade but with penalty
# Calculate position value at current price
position_value = self.position * (1 + price_change)
self.balance = position_value * (1 - self.trading_fee)
# Calculate profit/loss from trade
profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price
reward = profit_pct * 10 # Scale reward by profit percentage
# Scale reward by profit percentage and apply trade interval penalty
reward = (profit_pct * 10) + trade_interval_penalty
# Update win/loss count
if profit_pct > 0:
@ -179,11 +197,13 @@ class RLTradingEnvironment(gym.Env):
self.trade_history.append({
'entry_price': self.trade_entry_price,
'exit_price': next_price,
'profit_pct': profit_pct
'profit_pct': profit_pct,
'trade_interval': trade_interval
})
# Reset position
# Reset position and update last trade step
self.position = 0
self.last_trade_step = self.current_step
# else: (action == 2 - HOLD) - no position change

View File

@ -33,6 +33,8 @@ class BinanceHistoricalData:
self.cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cache')
if not os.path.exists(self.cache_dir):
os.makedirs(self.cache_dir)
# Timestamp of last data update
self.last_update = None
def get_historical_candles(self, symbol, interval_seconds=3600, limit=1000):
"""
@ -61,15 +63,18 @@ class BinanceHistoricalData:
interval = interval_map.get(interval_seconds, "1h")
# Format symbol for Binance API (remove slash)
formatted_symbol = symbol.replace("/", "")
formatted_symbol = symbol.replace("/", "").lower()
# Check if we have cached data first
cache_file = self._get_cache_filename(formatted_symbol, interval)
cached_data = self._load_from_cache(formatted_symbol, interval)
# If we have cached data that's recent enough, use it
if cached_data is not None and len(cached_data) >= limit:
logger.info(f"Using cached historical data for {symbol} ({interval})")
return cached_data
cache_age_minutes = (datetime.now() - self.last_update).total_seconds() / 60 if self.last_update else 60
if cache_age_minutes < 15: # Only use cache if it's less than 15 minutes old
logger.info(f"Using cached historical data for {symbol} ({interval})")
return cached_data
try:
# Build URL for klines endpoint
@ -106,6 +111,7 @@ class BinanceHistoricalData:
# Save to cache for future use
self._save_to_cache(df, formatted_symbol, interval)
self.last_update = datetime.now()
logger.info(f"Fetched {len(df)} candles for {symbol} ({interval})")
return df
@ -594,11 +600,13 @@ class TickStorage:
logger.error(f"Error loading ticks from file: {str(e)}")
def load_historical_data(self, historical_data, symbol):
"""Load historical data"""
"""Load historical data for all timeframes"""
try:
# Load data for different timeframes
timeframes = [
(1, '1s'), # 1 second
(1, '1s'), # 1 second - limit to 20 minutes (1200 seconds)
(5, '5s'), # 5 seconds
(15, '15s'), # 15 seconds
(60, '1m'), # 1 minute
(300, '5m'), # 5 minutes
(900, '15m'), # 15 minutes
@ -611,9 +619,9 @@ class TickStorage:
# Set appropriate limits based on timeframe
limit = 1000 # Default
if interval_seconds == 1:
limit = 500 # 1s is too much data, limit to 500
limit = 1200 # 1s data - limit to 20 minutes as requested
elif interval_seconds < 60:
limit = 750 # For seconds-level data
limit = 500 # For seconds-level data
elif interval_seconds < 300:
limit = 1000 # 1m
elif interval_seconds < 900:
@ -623,34 +631,86 @@ class TickStorage:
else:
limit = 200 # hourly/daily data
df = historical_data.get_historical_candles(symbol, interval_seconds, limit)
if df is not None and not df.empty:
logger.info(f"Loaded {len(df)} historical candles for {symbol} ({interval_key})")
try:
# For 1s data, we might need to generate it from 1m data
if interval_seconds == 1:
# Get 1m data first
df_1m = historical_data.get_historical_candles(symbol, 60, 60) # Get 60 minutes of 1m data
if df_1m is not None and not df_1m.empty:
# Create simulated 1s data from 1m data
simulated_1s = []
for _, row in df_1m.iterrows():
# For each 1m candle, create 60 1s candles
start_time = row['timestamp']
for i in range(60):
# Calculate second-level timestamp
second_time = start_time + timedelta(seconds=i)
# Convert to our candle format and store
for _, row in df.iterrows():
candle = {
'timestamp': row['timestamp'],
'open': row['open'],
'high': row['high'],
'low': row['low'],
'close': row['close'],
'volume': row['volume']
}
self.candles[interval_key].append(candle)
# Create candle with random price movement around close price
close_price = row['close']
price_range = (row['high'] - row['low']) / 60 # Reduced range
# For 1m and above, also use the close price to simulate ticks
# but don't do this for seconds-level data as it creates too many ticks
if interval_seconds >= 60 and interval_key == '1m':
self.add_tick(
price=row['close'],
volume=row['volume'],
timestamp=row['timestamp']
)
# Interpolate price - gradual movement from open to close
progress = i / 60
interp_price = row['open'] + (row['close'] - row['open']) * progress
# Update latest price from most recent candle
if len(df) > 0:
self.latest_price = df.iloc[-1]['close']
# Add some small random movement
random_factor = np.random.normal(0, price_range * 0.5)
s_price = max(0, interp_price + random_factor)
# Create 1s candle
s_candle = {
'timestamp': second_time,
'open': s_price,
'high': s_price * 1.0001, # Tiny movement
'low': s_price * 0.9999, # Tiny movement
'close': s_price,
'volume': row['volume'] / 60 # Distribute volume
}
simulated_1s.append(s_candle)
# Add the simulated 1s candles to our candles storage
self.candles['1s'] = simulated_1s
logger.info(f"Generated {len(simulated_1s)} simulated 1s candles for {symbol}")
else:
# Load normal historical data
df = historical_data.get_historical_candles(symbol, interval_seconds, limit)
if df is not None and not df.empty:
logger.info(f"Loaded {len(df)} historical candles for {symbol} ({interval_key})")
# Convert to our candle format and store
candles = []
for _, row in df.iterrows():
candle = {
'timestamp': row['timestamp'],
'open': row['open'],
'high': row['high'],
'low': row['low'],
'close': row['close'],
'volume': row['volume']
}
candles.append(candle)
# Set the candles for this timeframe
self.candles[interval_key] = candles
# For 1m data, also use it to generate tick data
if interval_key == '1m':
for candle in candles[-20:]: # Use only the last 20 candles for tick data
self.add_tick(
price=candle['close'],
volume=candle['volume'] / 10, # Distribute volume
timestamp=candle['timestamp']
)
# Set latest price from most recent candle
if candles:
self.latest_price = candles[-1]['close']
logger.info(f"Set latest price to ${self.latest_price:.2f} from historical data")
except Exception as e:
logger.error(f"Error loading {interval_key} data: {e}")
continue
logger.info(f"Completed loading historical data for {symbol}")
@ -926,125 +986,137 @@ class RealTimeChart:
return {}, {}, [], "Error", "$0.00", "$0.00"
def _update_main_chart(self, interval=1):
"""Update the main chart with the selected timeframe"""
"""Update the main chart with OHLC data"""
try:
# Get candle data for the selected interval
candles = self.get_candles(interval_seconds=interval)
# Get candles for the interval
interval_key = self._get_interval_key(interval)
if not candles or len(candles) == 0:
# Return empty chart if no data
# Make sure we have data for this interval
if interval_key not in self.tick_storage.candles or not self.tick_storage.candles[interval_key]:
logger.warning(f"No candle data available for {interval_key}")
# Return empty figure with a message
fig = go.Figure()
fig.add_annotation(
text=f"No data available for {interval_key}",
xref="paper", yref="paper",
x=0.5, y=0.5, showarrow=False
)
fig.update_layout(title=f"{self.symbol} - {interval_key}")
return fig
# For rendering, limit to the last 500 candles for performance
candles = self.tick_storage.candles[interval_key][-500:]
# Ensure we have at least 1 candle
if not candles:
logger.warning(f"No historical candles available for {interval_key}")
return go.Figure()
# Create the candlestick chart
fig = go.Figure()
# Extract OHLC values
timestamps = [candle['timestamp'] for candle in candles]
opens = [candle['open'] for candle in candles]
highs = [candle['high'] for candle in candles]
lows = [candle['low'] for candle in candles]
closes = [candle['close'] for candle in candles]
volumes = [candle['volume'] for candle in candles]
# Create figure
fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
vertical_spacing=0.02,
row_heights=[0.8, 0.2],
specs=[[{"type": "candlestick"}],
[{"type": "bar"}]])
# Add candlestick trace
fig.add_trace(go.Candlestick(
x=[c['timestamp'] for c in candles],
open=[c['open'] for c in candles],
high=[c['high'] for c in candles],
low=[c['low'] for c in candles],
close=[c['close'] for c in candles],
name='Price'
))
x=timestamps,
open=opens,
high=highs,
low=lows,
close=closes,
name='OHLC',
increasing_line_color='rgba(0, 180, 0, 0.7)',
decreasing_line_color='rgba(255, 0, 0, 0.7)',
), row=1, col=1)
# Add volume as a bar chart below
# Add volume bars
fig.add_trace(go.Bar(
x=[c['timestamp'] for c in candles],
y=[c['volume'] for c in candles],
x=timestamps,
y=volumes,
name='Volume',
marker=dict(
color='rgba(0, 0, 255, 0.5)',
),
opacity=0.5,
yaxis='y2'
))
marker_color='rgba(100, 100, 255, 0.5)'
), row=2, col=1)
# Add buy/sell markers for trades
# Add trading markers if available
if hasattr(self, 'positions') and self.positions:
buy_times = []
# Get last 100 positions for display (to avoid too many markers)
positions = self.positions[-100:]
buy_timestamps = []
buy_prices = []
sell_times = []
sell_timestamps = []
sell_prices = []
# Use all positions for chart display
# Filter recent ones based on visible time range
now = datetime.now()
time_limit = now - timedelta(hours=24) # Show at most 24h of trades
for pos in positions:
if pos.action == 'BUY':
buy_timestamps.append(pos.entry_timestamp)
buy_prices.append(pos.entry_price)
elif pos.action == 'SELL':
sell_timestamps.append(pos.entry_timestamp) # Using entry_time for consistency
sell_prices.append(pos.entry_price) # Using entry_price for consistency
for position in self.positions:
if position.entry_timestamp > time_limit:
if position.action == "BUY":
buy_times.append(position.entry_timestamp)
buy_prices.append(position.entry_price)
elif position.action == "SELL" and position.exit_timestamp:
sell_times.append(position.exit_timestamp)
sell_prices.append(position.exit_price)
# Add buy markers (green triangles pointing up)
if buy_times:
# Add buy markers
if buy_timestamps:
fig.add_trace(go.Scatter(
x=buy_times,
x=buy_timestamps,
y=buy_prices,
mode='markers',
name='Buy',
marker=dict(
symbol='triangle-up',
size=10,
color='green',
line=dict(width=1, color='black')
size=15,
color='rgba(0, 180, 0, 0.8)',
line=dict(width=1, color='rgba(0, 180, 0, 1)')
)
))
), row=1, col=1)
# Add sell markers (red triangles pointing down)
if sell_times:
# Add sell markers
if sell_timestamps:
fig.add_trace(go.Scatter(
x=sell_times,
x=sell_timestamps,
y=sell_prices,
mode='markers',
name='Sell',
marker=dict(
symbol='triangle-down',
size=10,
color='red',
line=dict(width=1, color='black')
size=15,
color='rgba(255, 0, 0, 0.8)',
line=dict(width=1, color='rgba(255, 0, 0, 1)')
)
))
), row=1, col=1)
# Update layout
timeframe_label = f"{interval}s" if interval < 60 else f"{interval//60}m" if interval < 3600 else f"{interval//3600}h"
fig.update_layout(
title=f'{self.symbol} Price ({timeframe_label})',
xaxis_title='Time',
yaxis_title='Price',
template='plotly_dark',
xaxis_rangeslider_visible=False,
title=f"{self.symbol} - {interval_key}",
xaxis_title="Time",
yaxis_title="Price",
height=600,
hovermode='x unified',
legend=dict(
orientation="h",
yanchor="bottom",
y=1.02,
xanchor="right",
x=1
),
yaxis=dict(
domain=[0.25, 1]
),
yaxis2=dict(
domain=[0, 0.2],
title='Volume'
),
template="plotly_dark",
showlegend=True,
margin=dict(l=0, r=0, t=50, b=20),
legend=dict(orientation="h", y=1.02, x=0.5, xanchor="center"),
uirevision='true' # To maintain zoom level on updates
)
# Add timestamp to show when chart was last updated
fig.add_annotation(
text=f"Last updated: {datetime.now().strftime('%H:%M:%S')}",
xref="paper", yref="paper",
x=0.98, y=0.01,
showarrow=False,
font=dict(size=10, color="gray")
# Format Y-axis with enough decimal places for cryptocurrency
fig.update_yaxes(tickformat=".2f")
# Format X-axis with date/time
fig.update_xaxes(
rangeslider_visible=False,
rangebreaks=[
dict(bounds=["sat", "mon"]) # hide weekends
]
)
return fig
@ -1053,71 +1125,90 @@ class RealTimeChart:
logger.error(f"Error updating main chart: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return go.Figure() # Return empty figure on error
# Return empty figure on error
return go.Figure()
def _update_secondary_charts(self):
"""Create secondary charts with multiple timeframes (1m, 1h, 1d)"""
"""Update the secondary charts for other timeframes"""
try:
# Create subplot with 3 rows
# For each timeframe, create a small chart
secondary_timeframes = ['1m', '5m', '15m', '1h']
if not all(tf in self.tick_storage.candles for tf in secondary_timeframes):
logger.warning("Not all secondary timeframes available")
# Return empty figure with a message
fig = make_subplots(rows=1, cols=4)
for i, tf in enumerate(secondary_timeframes, 1):
fig.add_annotation(
text=f"No data for {tf}",
xref=f"x{i}", yref=f"y{i}",
x=0.5, y=0.5, showarrow=False
)
return fig
# Create subplots for each timeframe
fig = make_subplots(
rows=3, cols=1,
shared_xaxes=False,
vertical_spacing=0.05,
subplot_titles=('1 Minute', '1 Hour', '1 Day')
rows=1, cols=4,
subplot_titles=secondary_timeframes,
shared_yaxes=True
)
# Get data for each timeframe
candles_1m = self.get_candles(interval_seconds=60)
candles_1h = self.get_candles(interval_seconds=3600)
candles_1d = self.get_candles(interval_seconds=86400)
# Loop through each timeframe
for i, timeframe in enumerate(secondary_timeframes, 1):
interval_key = timeframe
# 1-minute chart (row 1)
if candles_1m and len(candles_1m) > 0:
fig.add_trace(go.Candlestick(
x=[c['timestamp'] for c in candles_1m],
open=[c['open'] for c in candles_1m],
high=[c['high'] for c in candles_1m],
low=[c['low'] for c in candles_1m],
close=[c['close'] for c in candles_1m],
name='1m Price',
showlegend=False
), row=1, col=1)
# Get candles for this timeframe
if interval_key in self.tick_storage.candles and self.tick_storage.candles[interval_key]:
# For rendering, limit to the last 100 candles for performance
candles = self.tick_storage.candles[interval_key][-100:]
# 1-hour chart (row 2)
if candles_1h and len(candles_1h) > 0:
fig.add_trace(go.Candlestick(
x=[c['timestamp'] for c in candles_1h],
open=[c['open'] for c in candles_1h],
high=[c['high'] for c in candles_1h],
low=[c['low'] for c in candles_1h],
close=[c['close'] for c in candles_1h],
name='1h Price',
showlegend=False
), row=2, col=1)
if candles:
# Extract OHLC values
timestamps = [candle['timestamp'] for candle in candles]
opens = [candle['open'] for candle in candles]
highs = [candle['high'] for candle in candles]
lows = [candle['low'] for candle in candles]
closes = [candle['close'] for candle in candles]
# 1-day chart (row 3)
if candles_1d and len(candles_1d) > 0:
fig.add_trace(go.Candlestick(
x=[c['timestamp'] for c in candles_1d],
open=[c['open'] for c in candles_1d],
high=[c['high'] for c in candles_1d],
low=[c['low'] for c in candles_1d],
close=[c['close'] for c in candles_1d],
name='1d Price',
showlegend=False
), row=3, col=1)
# Add candlestick trace
fig.add_trace(go.Candlestick(
x=timestamps,
open=opens,
high=highs,
low=lows,
close=closes,
name=interval_key,
increasing_line_color='rgba(0, 180, 0, 0.7)',
decreasing_line_color='rgba(255, 0, 0, 0.7)',
showlegend=False
), row=1, col=i)
else:
# Add empty annotation if no data
fig.add_annotation(
text=f"No data for {interval_key}",
xref=f"x{i}", yref=f"y{i}",
x=0.5, y=0.5, showarrow=False
)
# Update layout
fig.update_layout(
height=500,
template='plotly_dark',
margin=dict(l=50, r=50, t=30, b=30),
height=250,
template="plotly_dark",
showlegend=False,
hovermode='x unified'
margin=dict(l=0, r=0, t=30, b=0),
)
# Disable rangesliders for cleaner look
fig.update_xaxes(rangeslider_visible=False)
# Format Y-axis with 2 decimal places
fig.update_yaxes(tickformat=".2f")
# Format X-axis to show only the date (no time)
for i in range(1, 5):
fig.update_xaxes(
row=1, col=i,
rangeslider_visible=False,
rangebreaks=[dict(bounds=["sat", "mon"])], # hide weekends
tickformat="%m-%d" # Show month-day only
)
return fig
@ -1125,7 +1216,8 @@ class RealTimeChart:
logger.error(f"Error updating secondary charts: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return go.Figure() # Return empty figure on error
# Return empty figure on error
return make_subplots(rows=1, cols=4)
def _get_position_list_rows(self):
"""Generate HTML for the positions list (last 10 positions only)"""
@ -1259,6 +1351,19 @@ class RealTimeChart:
async def start_websocket(self):
"""Start the websocket connection for real-time data"""
try:
# Load historical data first to ensure we have candles for all timeframes
logger.info(f"Loading historical data for {self.symbol}")
# Initialize a BinanceHistoricalData instance
historical_data = BinanceHistoricalData()
# Load historical data for display
self.tick_storage.load_historical_data(historical_data, self.symbol)
# Make sure we update the charts once with historical data before websocket starts
# Update all the charts with the initial historical data
self._update_chart_and_positions()
# Initialize websocket
self.websocket = ExchangeWebSocket(self.symbol)
await self.websocket.connect()

View File

@ -299,14 +299,30 @@ class RLTrainingIntegrator:
# Create a custom environment class that includes our reward function modification
class EnhancedRLTradingEnvironment(RLTradingEnvironment):
def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.001):
"""Initialize with normalization parameters"""
super().__init__(features_1m, features_5m, features_15m, window_size, trading_fee)
# Initialize integrator and chart references
self.integrator = None # Will be set after initialization
self.chart = None # Will be set after initialization
# Make writer accessible to integrator callbacks
self.writer = None # Will be set by train_rl
def __init__(self, features_1m, features_5m, features_15m, window_size=20, trading_fee=0.0025, min_trade_interval=15):
super().__init__(features_1m, features_5m, features_15m, window_size, trading_fee, min_trade_interval)
# Reference to integrator for tracking
self.integrator = None
# Store the original data for extrema analysis
self.original_data = None
# RNN signal integration
self.signal_interpreter = None
self.last_rnn_signals = []
self.rnn_signal_weight = 0.3 # Weight for RNN signals in decision making
# TensorBoard writer
self.writer = None
def set_integrator(self, integrator):
"""Set reference to integrator for callbacks"""
self.integrator = integrator
def set_signal_interpreter(self, signal_interpreter):
"""Set reference to signal interpreter for RNN signal integration"""
self.signal_interpreter = signal_interpreter
def set_tensorboard_writer(self, writer):
"""Set the TensorBoard writer"""
@ -314,50 +330,133 @@ class RLTrainingIntegrator:
def _calculate_reward(self, action):
"""Override the reward calculation with our enhanced version"""
# Get the original reward calculation result
reward, pnl = super()._calculate_reward(action)
# Get current price (normalized from training data)
# Get current and next price
current_price = self.features_1m[self.current_step, -1]
next_price = self.features_1m[self.current_step + 1, -1]
# Get real market price if available
# Default values
pnl = 0.0
reward = -0.0001 # Small negative reward to discourage excessive actions
# Get real market price if available (from integrator)
real_market_price = None
if hasattr(self, 'chart') and self.chart and hasattr(self.chart, 'latest_price'):
real_market_price = self.chart.latest_price
if self.integrator and hasattr(self.integrator, 'chart') and self.integrator.chart:
if hasattr(self.integrator.chart, 'tick_storage'):
real_market_price = self.integrator.chart.tick_storage.get_latest_price()
# Pass through the integrator's reward modifier
if hasattr(self, 'integrator') and self.integrator is not None:
# Add price to history - use real market price if available
if real_market_price is not None:
# For extrema detection, use a normalized version of the real price
# to keep scale consistent with the model's price history
self.integrator.price_history.append(current_price)
# Calculate base reward based on position and price change
if action == 0: # BUY
# Apply fee directly as negative reward to discourage excessive trading
reward -= self.trading_fee
# Check if we already have a position
if self.integrator and self.integrator.current_position_size > 0:
reward -= 0.002 # Additional penalty for trying to buy when already in position
# If RNN signal available, incorporate it
if self.signal_interpreter and len(self.last_rnn_signals) > 0:
last_signal = self.last_rnn_signals[-1]
if last_signal['action'] == 'BUY':
# RNN also suggests BUY - boost reward
reward += 0.003 * self.rnn_signal_weight * last_signal.get('confidence', 1.0)
elif last_signal['action'] == 'SELL':
# RNN suggests opposite - reduce reward
reward -= 0.003 * self.rnn_signal_weight * last_signal.get('confidence', 1.0)
elif action == 1: # SELL
if self.integrator and self.integrator.current_position_size > 0:
# Calculate potential profit/loss
if self.integrator.entry_price:
price_to_use = real_market_price if real_market_price else current_price
pnl = (price_to_use - self.integrator.entry_price) / self.integrator.entry_price
# Base reward on actual PnL
reward = pnl * 10
# Apply fee as negative component
reward -= self.trading_fee
# If RNN signal available, incorporate it
if self.signal_interpreter and len(self.last_rnn_signals) > 0:
last_signal = self.last_rnn_signals[-1]
if last_signal['action'] == 'SELL':
# RNN also suggests SELL - boost reward
reward += 0.003 * self.rnn_signal_weight * last_signal.get('confidence', 1.0)
elif last_signal['action'] == 'BUY':
# RNN suggests opposite - reduce reward
reward -= 0.003 * self.rnn_signal_weight * last_signal.get('confidence', 1.0)
else:
self.integrator.price_history.append(current_price)
# No position to sell - penalize
reward = -0.005
# Apply extrema-based reward modifications
if len(self.integrator.price_history) > 20:
# Detect local extrema
tops_indices, bottoms_indices = self.integrator.extrema_detector.find_extrema(
self.integrator.price_history
)
elif action == 2: # HOLD
# Check if we're holding a profitable position
if self.integrator and self.integrator.current_position_size > 0 and self.integrator.entry_price:
price_to_use = real_market_price if real_market_price else current_price
pnl = (price_to_use - self.integrator.entry_price) / self.integrator.entry_price
# Calculate additional rewards based on extrema
if action == 0 and bottoms_indices and bottoms_indices[-1] > len(self.integrator.price_history) - 5:
# Bonus for buying near bottoms
reward += 0.01
if self.integrator.session_step % 50 == 0: # Log less frequently
# Display the real market price if available
display_price = real_market_price if real_market_price is not None else current_price
logger.info(f"BUY signal near bottom detected at price {display_price:.2f}! Adding bonus reward.")
# Encourage holding profitable positions
if pnl > 0:
reward = 0.0001 * pnl * 5 # Small positive reward for holding winner
elif action == 1 and tops_indices and tops_indices[-1] > len(self.integrator.price_history) - 5:
# Bonus for selling near tops
reward += 0.01
if self.integrator.session_step % 50 == 0: # Log less frequently
# Display the real market price if available
display_price = real_market_price if real_market_price is not None else current_price
logger.info(f"SELL signal near top detected at price {display_price:.2f}! Adding bonus reward.")
# If position is very profitable, increase hold reward
if pnl > 0.01: # Over 1% profit
reward *= 2
else:
# Small negative reward for holding losing position
reward = -0.0001 * abs(pnl) * 2
# If RNN signal suggests HOLD, add small reward
if self.signal_interpreter and len(self.last_rnn_signals) > 0:
last_signal = self.last_rnn_signals[-1]
if last_signal['action'] == 'HOLD':
reward += 0.0001 * self.rnn_signal_weight
# Add price to history - use real market price if available
if real_market_price is not None:
# For extrema detection, use a normalized version of the real price
# to keep scale consistent with the model's price history
self.integrator.price_history.append(current_price)
else:
self.integrator.price_history.append(current_price)
# Apply extrema-based reward modifications
if len(self.integrator.price_history) > 20:
# Detect local extrema
tops_indices, bottoms_indices = self.integrator.extrema_detector.find_extrema(
self.integrator.price_history
)
# Get current price and market context
current_price = self.integrator.price_history[-1]
# Check if we're near a local extrema (top or bottom)
is_near_bottom = any(i > len(self.integrator.price_history) - 5 for i in bottoms_indices)
is_near_top = any(i > len(self.integrator.price_history) - 5 for i in tops_indices)
# Modify reward based on action and extrema
if action == 0 and is_near_bottom: # BUY near bottom
logger.info("Buying near local bottom - adding bonus reward")
reward += 0.015 # Significant bonus
elif action == 0 and is_near_top: # BUY near top
logger.info("Buying near local top - applying penalty")
reward -= 0.01 # Penalty
elif action == 1 and is_near_top: # SELL near top
logger.info("Selling near local top - adding bonus reward")
reward += 0.015 # Significant bonus
elif action == 1 and is_near_bottom: # SELL near bottom
logger.info("Selling near local bottom - applying penalty")
reward -= 0.01 # Penalty
elif action == 2: # HOLD
if is_near_bottom and self.integrator.current_position_size > 0:
# Good to hold if we have positions at bottom
reward += 0.002 # Small bonus
elif is_near_top and self.integrator.current_position_size == 0:
# Good to hold if we have no positions at top
reward += 0.002 # Small bonus
# Limit extreme rewards
reward = max(min(reward, 0.5), -0.5)
return reward, pnl
@ -758,7 +857,22 @@ def _add_trade_compat(chart, price, timestamp, amount, pnl=0.0, action="BUY"):
# For SELL actions, close the position with given PnL
if action == "SELL":
# Find the most recent BUY position that hasn't been closed
entry_position = None
entry_price = price # Default if no open position found
for pos in reversed(chart.positions):
if pos.action == "BUY" and pos.is_open:
entry_position = pos
entry_price = pos.entry_price
# Mark this position as closed
pos.close(price, timestamp)
break
# Close this sell position with the right prices
position.entry_price = entry_price # Use the found entry price
position.close(price, timestamp)
# Use realistic PnL values rather than the enormous ones from the model
# Cap PnL to reasonable values based on position size and price
max_reasonable_pnl = price * amount * 0.05 # Max 5% profit per trade