realtime data in main

This commit is contained in:
Dobromir Popov 2025-03-19 04:18:55 +02:00
parent a8954bbf99
commit 1a1c410922

360
main.py
View File

@ -38,6 +38,7 @@ from dash.dependencies import Input, Output, State
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from threading import Thread
import socket
# Configure logging
logging.basicConfig(
@ -971,12 +972,12 @@ class TradingEnvironment:
def calculate_reward(self, action):
"""Calculate reward for the given action with improved penalties for losing trades"""
"""Calculate reward for the given action with aggressive rewards for profitable trades and volume/price action signals"""
reward = 0
# Base reward for actions
if action == 0: # HOLD
reward = -0.01 # Small penalty for doing nothing
reward = -0.05 # Increased penalty for doing nothing to encourage more trading
elif action == 1: # BUY/LONG
if self.position == 'flat':
@ -990,13 +991,26 @@ class TradingEnvironment:
# Check if this is an optimal buy point (bottom)
current_idx = len(self.features['price']) - 1
if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms:
reward += 2.0 # Bonus for buying at a bottom
reward += 3.0 # Increased bonus for buying at a bottom
# Check for volume spike (indicating potential big movement)
if len(self.features['volume']) > 5:
avg_volume = np.mean(self.features['volume'][-5:-1])
current_volume = self.features['volume'][-1]
if current_volume > avg_volume * 1.5:
reward += 2.0 # Bonus for entering during high volume
# Check for price action signals
if self.features['rsi'][-1] < 30: # Oversold condition
reward += 1.5 # Bonus for buying at oversold levels
# Check if we're buying in a clear uptrend (good)
if self.is_uptrend():
reward += 1.0 # Bonus for buying in uptrend
elif self.is_downtrend():
reward -= 0.25 # Reduced penalty for buying in downtrend
else:
# Check if we're buying in a downtrend (bad)
if self.is_downtrend():
reward -= 0.5 # Penalty for buying in downtrend
else:
reward += 0.1 # Small reward for opening a position
reward += 0.2 # Small reward for opening a position
logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
@ -1066,9 +1080,26 @@ class TradingEnvironment:
# Check if this is an optimal sell point (top)
current_idx = len(self.features['price']) - 1
if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
reward += 2.0 # Bonus for selling at a top
reward += 3.0 # Increased bonus for selling at a top
# Check for volume spike
if len(self.features['volume']) > 5:
avg_volume = np.mean(self.features['volume'][-5:-1])
current_volume = self.features['volume'][-1]
if current_volume > avg_volume * 1.5:
reward += 2.0 # Bonus for entering during high volume
# Check for price action signals
if self.features['rsi'][-1] > 70: # Overbought condition
reward += 1.5 # Bonus for selling at overbought levels
# Check if we're selling in a clear downtrend (good)
if self.is_downtrend():
reward += 1.0 # Bonus for selling in downtrend
elif self.is_uptrend():
reward -= 0.25 # Reduced penalty for selling in uptrend
else:
reward += 0.1 # Small reward for opening a position
reward += 0.2 # Small reward for opening a position
logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
@ -1714,15 +1745,32 @@ class Agent:
sample = random.random()
if training:
# Epsilon decay
# More aggressive epsilon decay for faster exploitation
self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \
np.exp(-1. * self.steps_done / EPSILON_DECAY)
np.exp(-1.5 * self.steps_done / EPSILON_DECAY) # Increased decay factor
self.steps_done += 1
# Lower threshold for exploration, especially in live trading
if not training:
# In live trading, be much more aggressive with exploitation
self.epsilon = max(EPSILON_END, self.epsilon * 0.95)
if sample > self.epsilon or not training:
with torch.no_grad():
state_tensor = torch.FloatTensor(state).to(self.device)
action_values = self.policy_net(state_tensor)
# Add temperature-based sampling for more aggressive actions
# when the model is confident (higher action differences)
if not training: # More aggressive in live trading
values = action_values.cpu().numpy()
max_value = np.max(values)
value_diff = max_value - np.mean(values)
# If there's a clear best action, always take it
if value_diff > 0.5:
return action_values.max(1)[1].item()
return action_values.max(1)[1].item()
else:
return random.randrange(self.action_size)
@ -2761,7 +2809,7 @@ async def process_websocket_ticks(websocket, env, agent=None, demo=True, timefra
continue
# Convert timestamp to datetime
tick_time = datetime.fromtimestamp(timestamp / 1000)
tick_time = datetime.datetime.fromtimestamp(timestamp / 1000)
# For 1-minute candles, track the minute
if timeframe == "1m":
@ -2856,6 +2904,8 @@ async def main():
help='Path to model file for evaluation or live trading')
parser.add_argument('--use-websocket', action='store_true',
help='Use Binance WebSocket for real-time data instead of CCXT (for live mode)')
parser.add_argument('--dashboard', action='store_true',
help='Enable Dash dashboard visualization for real-time trading')
args = parser.parse_args()
@ -2876,7 +2926,7 @@ async def main():
if args.mode == 'train':
# Fetch initial data for training
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
await env.fetch_initial_data(exchange, args.symbol,args.timeframe, 1000)
# Create agent with consistent parameters
# Note: Using STATE_SIZE and action_size=4 for consistency
@ -2942,7 +2992,8 @@ async def main():
symbol=args.symbol,
timeframe=args.timeframe,
demo=demo_mode,
leverage=args.leverage
leverage=args.leverage,
use_dashboard=args.dashboard
)
else:
logger.info("Using CCXT for real-time data")
@ -3048,7 +3099,7 @@ def create_candlestick_figure(data, trade_signals, window_size=100, title=""):
logger.error(f"Error creating chart: {str(e)}")
return None
async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50):
async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50, use_dashboard=False):
"""Run the trading bot in live mode using Binance WebSocket for real-time data
Args:
@ -3058,6 +3109,7 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
timeframe: The candlestick timeframe (e.g., "1m")
demo: Whether to run in demo mode (paper trading)
leverage: The leverage to use for trading
use_dashboard: Whether to display the real-time dashboard
Returns:
None
@ -3075,9 +3127,25 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
# Initialize TensorBoard for monitoring
if not hasattr(agent, 'writer') or agent.writer is None:
from torch.utils.tensorboard import SummaryWriter
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
agent.writer = SummaryWriter(f'runs/live_ws_{symbol.replace("/", "_")}_{current_time}')
# Initialize Dash dashboard if enabled
dashboard = None
if use_dashboard:
try:
dashboard = TradingDashboard(symbol)
dashboard_started = dashboard.start() # Start the dashboard in a separate thread
if dashboard_started:
logger.info(f"Trading dashboard enabled at http://localhost:8060")
else:
logger.warning("Failed to start trading dashboard, continuing without visualization")
dashboard = None
except Exception as e:
logger.error(f"Error initializing dashboard: {e}")
logger.error(traceback.format_exc())
dashboard = None
# Track performance metrics
trades_count = 0
winning_trades = 0
@ -3088,7 +3156,7 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
# Create directory for trade logs
os.makedirs('trade_logs', exist_ok=True)
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
trade_log_path = f'trade_logs/trades_ws_{current_time}.csv'
with open(trade_log_path, 'w') as f:
f.write("timestamp,action,price,position_size,balance,pnl\n")
@ -3109,6 +3177,10 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
# Reset environment with historical data
env.reset()
# Update dashboard with initial data if enabled
if dashboard:
dashboard.update_data(env=env, candles=env.data, trade_signals=env.trade_signals)
# Initialize futures trading if not in demo mode
exchange = None
if not demo:
@ -3144,7 +3216,7 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
total_profit += env.last_trade_profit
# Log trade details
current_time = datetime.now().isoformat()
current_time = datetime.datetime.now().isoformat()
action_name = "HOLD" if getattr(env, 'last_action', 0) == 0 else "BUY" if getattr(env, 'last_action', 0) == 1 else "SELL" if getattr(env, 'last_action', 0) == 2 else "CLOSE"
with open(trade_log_path, 'a') as f:
f.write(f"{current_time},{action_name},{env.current_price},{env.position_size},{env.balance},{getattr(env, 'last_trade_profit', 0)}\n")
@ -3184,6 +3256,10 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
"""
agent.writer.add_text('Performance', performance_text, step_counter)
# Update the dashboard with latest data if enabled
if dashboard:
dashboard.update_data(env=env, candles=env.data, trade_signals=env.trade_signals)
prev_position = env.position
# Sleep for a short time to prevent CPU hogging
@ -3231,6 +3307,252 @@ async def live_trading_with_websocket(agent, env, symbol="ETH/USDT", timeframe="
if 'exchange' in locals() and exchange:
await exchange.close()
def ensure_pytorch_compatibility():
    """Allow-list NumPy's scalar reconstructor for restricted ``torch.load``.

    Checkpoints that contain NumPy scalars pickle a reference to
    ``numpy.core.multiarray.scalar`` (``numpy._core.multiarray.scalar`` on
    NumPy 2.x). The ``weights_only`` unpickler rejects globals that have not
    been registered via ``torch.serialization.add_safe_globals``.

    ``add_safe_globals`` must be given the callable itself, not its dotted
    name as a string: the allow-list is keyed on the object's
    ``__module__``/``__qualname__``, which plain strings do not provide.

    The function is best-effort: every failure is logged as a warning and
    swallowed so that start-up is never blocked.

    Returns:
        None
    """
    try:
        import torch.serialization

        add_safe_globals = getattr(torch.serialization, 'add_safe_globals', None)
        if add_safe_globals is None:
            # Older PyTorch releases have no allow-list API at all.
            logger.warning("PyTorch serialization module doesn't have add_safe_globals method")
            return

        # NumPy 2.x moved the implementation to the private ``numpy._core``
        # package; try that location first so the deprecated ``numpy.core``
        # alias is only touched on older NumPy versions.
        try:
            from numpy._core.multiarray import scalar as numpy_scalar
        except ImportError:
            from numpy.core.multiarray import scalar as numpy_scalar

        add_safe_globals([numpy_scalar])
        logger.info("PyTorch safe globals registered for compatibility")
    except Exception as e:
        # Deliberately broad: this is an optional compatibility shim.
        logger.warning(f"PyTorch compatibility check failed: {e}")
class TradingDashboard:
    """Dash dashboard that visualizes candles, trade signals and account state.

    The trading loop pushes fresh data with :meth:`update_data`; the browser
    polls the server every five seconds through a ``dcc.Interval`` and the
    callbacks render whatever this object currently holds. Data is read
    directly from the instance inside the callbacks: mutating ``app.layout``
    after the page has loaded is not propagated to connected browsers, so it
    cannot be used as a transport for live updates.

    Attributes:
        symbol: Trading pair shown in the page title and chart title.
        env: Trading environment whose ``balance``, ``total_pnl``,
            ``position`` and (optionally) ``trades`` are displayed.
        candles: Candle records with ``timestamp`` (epoch milliseconds),
            ``open``, ``high``, ``low``, ``close`` and ``volume`` fields.
            NOTE(review): assumed to be a sequence of dicts -- confirm
            against what the caller passes as ``env.data``.
        trade_signals: Dicts with ``timestamp`` (epoch milliseconds),
            ``type`` and ``price`` keys.
        app: The underlying ``dash.Dash`` application.
        host, port: Address the server was started on (``None`` until
            :meth:`start` succeeds in finding a port).
        thread: Background thread running the server, or ``None``.
        is_running: ``True`` while the server thread is believed to be up.
    """

    # Marker (color, symbol) per signal type; anything else is drawn as an
    # orange circle.
    _SIGNAL_STYLES = {
        'buy': ('green', 'triangle-up'),
        'sell': ('red', 'triangle-down'),
    }
    _DEFAULT_SIGNAL_STYLE = ('orange', 'circle')

    # Number of most recent candles drawn on the chart.
    _VISIBLE_CANDLES = 100

    def __init__(self, symbol="ETH/USDT"):
        """Build the Dash app, its layout and its callbacks.

        Args:
            symbol: Trading pair displayed in the headings.
        """
        self.symbol = symbol
        self.env = None
        self.candles = []
        self.trade_signals = []

        # Filled in by start() once a free port has been found.
        self.host = None
        self.port = None

        # Create Dash app
        self.app = dash.Dash(__name__, suppress_callback_exceptions=True)

        # Create basic layout
        self.app.layout = html.Div([
            # Header
            html.H1(f"Trading Dashboard - {symbol}", style={'textAlign': 'center'}),

            # Main content
            html.Div([
                # Chart
                html.Div([
                    dcc.Graph(id='candlestick-chart', style={'height': '70vh'}),
                    dcc.Interval(id='interval-component', interval=5*1000, n_intervals=0)
                ], style={'width': '70%', 'display': 'inline-block'}),

                # Trading info
                html.Div([
                    html.Div([
                        html.H3("Account Info"),
                        html.Div(id='account-info')
                    ]),
                    html.Div([
                        html.H3("Recent Trades"),
                        html.Div(id='recent-trades')
                    ])
                ], style={'width': '30%', 'display': 'inline-block', 'verticalAlign': 'top'})
            ])
        ])

        # Setup callbacks
        self._setup_callbacks()

        # Thread for running the server
        self.thread = None
        self.is_running = False

    @staticmethod
    def _pnl_color(value):
        """Return the display color for a profit/loss figure.

        Args:
            value: Numeric PnL.

        Returns:
            ``'green'`` for gains, ``'red'`` for losses, ``'white'`` for zero.
        """
        if value > 0:
            return 'green'
        if value < 0:
            return 'red'
        return 'white'

    def _setup_callbacks(self):
        """Register the interval-driven Dash callbacks on ``self.app``."""

        @self.app.callback(
            Output('candlestick-chart', 'figure'),
            [Input('interval-component', 'n_intervals')]
        )
        def update_chart(n):
            # The tick count itself is irrelevant; it only triggers a redraw.
            return self._build_figure()

        @self.app.callback(
            [Output('account-info', 'children'),
             Output('recent-trades', 'children')],
            [Input('interval-component', 'n_intervals')]
        )
        def update_account_info(n):
            return self._build_account_panels()

    def _build_figure(self):
        """Render the current candles, volume and trade signals.

        Returns:
            A two-row plotly figure (price on top, volume below). The figure
            has no traces when no candles have been supplied yet.
        """
        # Take local references once so a concurrent update_data() call
        # cannot swap the data halfway through rendering.
        candles = self.candles
        signals = list(self.trade_signals) if self.trade_signals is not None else []

        # Create figure with subplots
        fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                            vertical_spacing=0.1, row_heights=[0.7, 0.3])

        # len() instead of truthiness so array-like containers work as well.
        if candles is not None and len(candles) > 0:
            df = pd.DataFrame(candles[-self._VISIBLE_CANDLES:])
            df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')

            # Add candlestick trace
            fig.add_trace(
                go.Candlestick(
                    x=df['timestamp'],
                    open=df['open'],
                    high=df['high'],
                    low=df['low'],
                    close=df['close'],
                    name='Price'
                ),
                row=1, col=1
            )

            # Add volume trace
            fig.add_trace(
                go.Bar(
                    x=df['timestamp'],
                    y=df['volume'],
                    name='Volume'
                ),
                row=2, col=1
            )

            # Only signals that fall inside the visible window are drawn.
            first_visible_ms = df['timestamp'].iloc[0].timestamp() * 1000
            for signal in signals:
                if signal['timestamp'] < first_visible_ms:
                    continue
                marker_color, marker_symbol = self._SIGNAL_STYLES.get(
                    signal['type'], self._DEFAULT_SIGNAL_STYLE)
                fig.add_trace(
                    go.Scatter(
                        x=[pd.to_datetime(signal['timestamp'], unit='ms')],
                        y=[signal['price']],
                        mode='markers',
                        marker=dict(
                            color=marker_color,
                            size=12,
                            symbol=marker_symbol
                        ),
                        name=signal['type'].capitalize(),
                        showlegend=False
                    ),
                    row=1, col=1
                )

        # Update layout
        fig.update_layout(
            title=f'{self.symbol} Trading Chart',
            xaxis_rangeslider_visible=False,
            template='plotly_dark'
        )
        return fig

    def _build_account_panels(self):
        """Render the account summary and the five most recent trades.

        Returns:
            A ``(account_info, recent_trades)`` pair of Dash children. Plain
            placeholder strings are returned while no environment is attached.
        """
        env = self.env
        if env is None:
            return "No data available", "No trades available"

        # Account info
        account_info = html.Div([
            html.P(f"Balance: ${env.balance:.2f}"),
            html.P(f"PnL: ${env.total_pnl:.2f}",
                   style={'color': self._pnl_color(env.total_pnl)}),
            html.P(f"Position: {env.position.upper()}")
        ])

        # Recent trades, newest first
        trades = getattr(env, 'trades', None)
        if trades:
            recent_trades = [
                html.Div([
                    html.P(f"{trade['action'].upper()} at ${trade['price']:.2f}"),
                    html.P(f"PnL: ${trade['pnl']:.2f}",
                           style={'color': self._pnl_color(trade['pnl'])})
                ], style={'border': '1px solid #ddd', 'padding': '10px', 'marginBottom': '5px'})
                for trade in reversed(trades[-5:])
            ]
        else:
            recent_trades = [html.P("No trades yet")]

        return account_info, recent_trades

    def update_data(self, env=None, candles=None, trade_signals=None):
        """Replace the data shown by the dashboard.

        Arguments left as ``None`` keep their previous value; any other value,
        including an empty container, replaces the stored one. The new data
        becomes visible in the browser on the next interval tick.

        Args:
            env: Trading environment to read account state from.
            candles: Candle records (see the class docstring).
            trade_signals: Trade signal records (see the class docstring).
        """
        if env is not None:
            self.env = env
        if candles is not None:
            self.candles = candles
        if trade_signals is not None:
            self.trade_signals = trade_signals

    @staticmethod
    def _find_free_port(host, first_port, attempts=10):
        """Return the first bindable port in ``first_port .. first_port+attempts-1``.

        A fresh probe socket is used for every candidate and closed right
        away. The check is advisory only: another process may still grab the
        port before the server binds it.

        Args:
            host: Interface to probe.
            first_port: First port number to try.
            attempts: How many consecutive ports to try.

        Returns:
            The free port number, or ``None`` if every candidate was taken.
        """
        for candidate in range(first_port, first_port + attempts):
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                try:
                    probe.bind((host, candidate))
                except OSError:
                    logger.warning(f"Port {candidate} is already in use")
                    continue
            return candidate
        return None

    def start(self, host='localhost', port=8060):
        """Start the dashboard server in a daemon thread.

        Tries ``port`` and up to nine following ports. On success the address
        actually used is available as ``self.host`` / ``self.port``.

        Args:
            host: Interface to listen on.
            port: Preferred port.

        Returns:
            ``True`` if the server thread is up after a one-second grace
            period; ``False`` if the dashboard was already running, no port
            was free, or the server failed during start-up.
        """
        if self.is_running:
            return False

        free_port = self._find_free_port(host, port)
        if free_port is None:
            logger.error("Could not find an available port for dashboard")
            return False
        self.host, self.port = host, free_port

        # Raise the flag *before* starting the thread so that an immediate
        # failure inside _run_server can lower it without being overwritten.
        self.is_running = True
        self.thread = Thread(target=self._run_server, args=(host, free_port))
        self.thread.daemon = True  # This ensures the thread will exit when the main program does
        self.thread.start()

        # Wait a short time to let the server initialize (or fail).
        time.sleep(1.0)

        if not self.thread.is_alive() or not self.is_running:
            logger.error("Dashboard thread failed to start")
            self.is_running = False
            return False

        logger.info(f"Trading dashboard started at http://{host}:{free_port}")
        return True

    def _run_server(self, host, port):
        """Run the Dash server; blocks until the server stops.

        Args:
            host: Interface to listen on.
            port: Port to listen on.
        """
        try:
            logger.info(f"Starting Dash server on {host}:{port}")
            # Dash 3 removed run_server in favor of run; fall back to the
            # old name only on releases that predate run.
            run = getattr(self.app, 'run', None)
            if run is None:
                run = self.app.run_server
            run(debug=False, host=host, port=port, use_reloader=False, threaded=True)
        except Exception as e:
            logger.error(f"Error running dashboard server: {e}")
        finally:
            # Whether the server crashed or returned normally, it is no
            # longer serving requests.
            self.is_running = False
if __name__ == "__main__":
try:
asyncio.run(main())