added plots; fixes
This commit is contained in:
parent
d9d0ba9da8
commit
469d681c4b
43
crypto/gogo2/.vscode/launch.json
vendored
43
crypto/gogo2/.vscode/launch.json
vendored
@ -24,18 +24,53 @@
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "main.py",
|
||||
"args": ["--mode", "live", "--demo"],
|
||||
"args": [
|
||||
"--mode", "live",
|
||||
"--demo", "true",
|
||||
"--symbol", "ETH/USDT",
|
||||
"--timeframe", "1m"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
"justMyCode": true,
|
||||
"env": {
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Live Trading (Real)",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "main.py",
|
||||
"args": ["--mode", "live"],
|
||||
"args": [
|
||||
"--mode", "live",
|
||||
"--demo", "false",
|
||||
"--symbol", "ETH/USDT",
|
||||
"--timeframe", "1m",
|
||||
"--leverage", "50"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
"justMyCode": true,
|
||||
"env": {
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Live Trading (BTC Futures)",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "main.py",
|
||||
"args": [
|
||||
"--mode", "live",
|
||||
"--demo", "false",
|
||||
"--symbol", "BTC/USDT",
|
||||
"--timeframe", "5m",
|
||||
"--leverage", "20"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"env": {
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
@ -26,6 +26,8 @@ import io
|
||||
import matplotlib.dates as mdates
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as mpf
|
||||
import matplotlib.gridspec as gridspec
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@ -1782,35 +1784,34 @@ class Agent:
|
||||
return False
|
||||
|
||||
def add_chart_to_tensorboard(self, env, global_step):
|
||||
"""Add enhanced trading chart to TensorBoard"""
|
||||
if len(env.data) < 10: # Minimum data to show
|
||||
return
|
||||
|
||||
"""Add trading chart to TensorBoard"""
|
||||
try:
|
||||
# Create chart with annotations
|
||||
# Create chart image
|
||||
chart_img = create_candlestick_figure(
|
||||
env.data,
|
||||
env.trade_signals,
|
||||
window_size=100,
|
||||
title=f"Trading Chart (Step {global_step})"
|
||||
title=f"Trading Chart - Step {global_step}"
|
||||
)
|
||||
|
||||
# Add to TensorBoard
|
||||
self.writer.add_image('Trading Chart', np.array(chart_img).transpose(2, 0, 1), global_step)
|
||||
self.chart_step = global_step
|
||||
|
||||
# Also log position information
|
||||
if env.position != 'flat':
|
||||
position_info = {
|
||||
'position_type': env.position,
|
||||
'entry_price': env.entry_price,
|
||||
'position_size': env.position_size,
|
||||
'unrealized_pnl': env.total_pnl
|
||||
}
|
||||
self.writer.add_text('Position', str(position_info), global_step)
|
||||
if chart_img is not None:
|
||||
# Convert PIL image to numpy array for TensorBoard
|
||||
chart_array = np.array(chart_img)
|
||||
# TensorBoard expects [C, H, W] format
|
||||
chart_array = np.transpose(chart_array, (2, 0, 1))
|
||||
self.writer.add_image('Trading Chart', chart_array, global_step)
|
||||
|
||||
# Add position information as text
|
||||
position_info = f"""
|
||||
**Current Position**: {env.position.upper()}
|
||||
**Entry Price**: ${env.entry_price:.2f if env.entry_price else 0:.2f}
|
||||
**Current Price**: ${env.data[-1]['close']:.2f}
|
||||
**Position Size**: ${env.position_size:.2f}
|
||||
**Unrealized PnL**: ${env.unrealized_pnl:.2f}
|
||||
"""
|
||||
self.writer.add_text('Position', position_info, global_step)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating chart: {e}")
|
||||
logger.error(f"Error adding chart to TensorBoard: {str(e)}")
|
||||
|
||||
async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
|
||||
"""Get live price data using websockets"""
|
||||
@ -2225,122 +2226,155 @@ async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit
|
||||
logger.error(f"Failed to fetch historical data: {e}")
|
||||
return []
|
||||
|
||||
async def live_trading(agent, env, exchange, demo=True):
|
||||
"""Run live trading with the trained agent"""
|
||||
logger.info(f"Starting live trading (demo mode: {demo})")
|
||||
async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50):
|
||||
"""
|
||||
Run the trading bot in live mode with enhanced error handling and monitoring
|
||||
|
||||
try:
|
||||
# Subscribe to websocket for real-time data
|
||||
symbol = "ETH/USDT"
|
||||
timeframe = "1m"
|
||||
Args:
|
||||
agent: Trained trading agent
|
||||
env: Trading environment
|
||||
exchange: CCXT exchange instance
|
||||
symbol: Trading pair (default: ETH/USDT)
|
||||
timeframe: Candle timeframe (default: 1m)
|
||||
demo: If True, simulate trades without executing (default: True)
|
||||
leverage: Leverage for futures trading (default: 50x)
|
||||
"""
|
||||
logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe")
|
||||
logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")
|
||||
|
||||
# Initialize with historical data
|
||||
success = await env.fetch_initial_data(exchange, symbol, timeframe, 100)
|
||||
if not success:
|
||||
logger.error("Failed to initialize with historical data")
|
||||
if not demo:
|
||||
# Confirm with user before starting live trading
|
||||
confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
|
||||
if confirmation != "CONFIRM":
|
||||
logger.info("Live trading canceled by user")
|
||||
return
|
||||
|
||||
# Main trading loop
|
||||
step_counter = 0
|
||||
# Initialize futures trading if not in demo mode
|
||||
try:
|
||||
await env.initialize_futures(exchange)
|
||||
logger.info(f"Futures trading initialized with {leverage}x leverage")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize futures trading: {str(e)}")
|
||||
logger.info("Falling back to demo mode for safety")
|
||||
demo = True
|
||||
|
||||
# For online learning
|
||||
states = []
|
||||
actions = []
|
||||
rewards = []
|
||||
next_states = []
|
||||
dones = []
|
||||
# Initialize TensorBoard for monitoring
|
||||
if not hasattr(agent, 'writer') or agent.writer is None:
|
||||
agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
|
||||
|
||||
# Track performance metrics
|
||||
trades_count = 0
|
||||
winning_trades = 0
|
||||
total_profit = 0
|
||||
max_drawdown = 0
|
||||
peak_balance = env.balance
|
||||
step_counter = 0
|
||||
prev_position = 'flat'
|
||||
|
||||
# Create directory for trade logs
|
||||
os.makedirs('trade_logs', exist_ok=True)
|
||||
trade_log_path = f'trade_logs/trades_{datetime.now().strftime("%Y%m%d_%H%M%S")}.csv'
|
||||
with open(trade_log_path, 'w') as f:
|
||||
f.write("timestamp,action,price,position_size,balance,pnl\n")
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Wait for the next candle (1 minute)
|
||||
await asyncio.sleep(5) # Check every 5 seconds
|
||||
try:
|
||||
# Fetch latest candle data
|
||||
candle = await get_latest_candle(exchange, symbol)
|
||||
if candle is None:
|
||||
logger.warning("Failed to fetch latest candle, retrying in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
|
||||
# Fetch latest candle
|
||||
latest_candle = await get_latest_candle(exchange, symbol)
|
||||
# Add new data to environment
|
||||
env.add_data(candle)
|
||||
|
||||
if not latest_candle:
|
||||
logger.warning("No latest candle received, skipping update")
|
||||
continue
|
||||
# Get current state and select action
|
||||
state = env.get_state()
|
||||
action = agent.select_action(state, training=False)
|
||||
|
||||
# Update environment with new data
|
||||
env.add_data(latest_candle)
|
||||
# Execute action
|
||||
if not demo:
|
||||
# Execute real trade on exchange
|
||||
current_price = env.data[-1]['close']
|
||||
trade_result = await env.execute_real_trade(exchange, action, current_price)
|
||||
if not trade_result['success']:
|
||||
logger.error(f"Trade execution failed: {trade_result['error']}")
|
||||
# Continue with simulated trade for tracking purposes
|
||||
|
||||
# Get current state
|
||||
state = env.get_state()
|
||||
# Update environment with action (simulated in demo mode)
|
||||
next_state, reward, done, info = env.step(action)
|
||||
|
||||
# Select action (no exploration in live trading)
|
||||
action = agent.select_action(state, training=False)
|
||||
# Log trade if position changed
|
||||
if env.position != prev_position:
|
||||
trades_count += 1
|
||||
if env.last_trade_profit > 0:
|
||||
winning_trades += 1
|
||||
total_profit += env.last_trade_profit
|
||||
|
||||
# Take action
|
||||
next_state, reward, done = env.step(action)
|
||||
# Log trade details
|
||||
with open(trade_log_path, 'a') as f:
|
||||
f.write(f"{datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
|
||||
|
||||
# Store experience for online learning
|
||||
states.append(state)
|
||||
actions.append(action)
|
||||
rewards.append(reward)
|
||||
next_states.append(next_state)
|
||||
dones.append(done)
|
||||
logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
|
||||
|
||||
# Online learning - update the model with new experiences
|
||||
if len(states) >= 10: # Batch size for online learning
|
||||
# Store experiences in replay memory
|
||||
for i in range(len(states)):
|
||||
agent.memory.push(states[i], actions[i], rewards[i], next_states[i], dones[i])
|
||||
# Update performance metrics
|
||||
if env.balance > peak_balance:
|
||||
peak_balance = env.balance
|
||||
current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
|
||||
if current_drawdown > max_drawdown:
|
||||
max_drawdown = current_drawdown
|
||||
|
||||
# Learn from experiences if we have enough samples
|
||||
if len(agent.memory) > 32:
|
||||
loss = agent.learn()
|
||||
if loss is not None:
|
||||
agent.writer.add_scalar('Live/Loss', loss, step_counter)
|
||||
|
||||
# Clear the temporary storage
|
||||
states = []
|
||||
actions = []
|
||||
rewards = []
|
||||
next_states = []
|
||||
dones = []
|
||||
|
||||
# Save the updated model periodically
|
||||
if step_counter % 100 == 0:
|
||||
agent.save("models/trading_agent_live_updated.pt")
|
||||
logger.info("Updated model saved during live trading")
|
||||
|
||||
# Log trading activity
|
||||
action_names = ["HOLD", "BUY", "SELL", "CLOSE"]
|
||||
logger.info(f"Price: ${latest_candle['close']:.2f} | Action: {action_names[action]}")
|
||||
|
||||
# Log performance metrics
|
||||
if env.trades:
|
||||
wins = sum(1 for t in env.trades if t.get('pnl_percent', 0) > 0)
|
||||
win_rate = wins / len(env.trades) * 100
|
||||
total_pnl = sum(t.get('pnl_dollar', 0) for t in env.trades)
|
||||
|
||||
logger.info(f"Balance: ${env.balance:.2f} | Trades: {len(env.trades)} | "
|
||||
f"Win Rate: {win_rate:.1f}% | Total PnL: ${total_pnl:.2f}")
|
||||
|
||||
# Analyze recent trades
|
||||
trade_analysis = env.analyze_trades()
|
||||
if trade_analysis:
|
||||
logger.info(f"Recent Performance: Win Rate={trade_analysis.get('uptrend_win_rate', 0):.1f}% in uptrends, "
|
||||
f"{trade_analysis.get('downtrend_win_rate', 0):.1f}% in downtrends")
|
||||
|
||||
# Add chart to TensorBoard periodically
|
||||
step_counter += 1
|
||||
if step_counter % 10 == 0: # Update chart every 10 steps
|
||||
agent.add_chart_to_tensorboard(env, step_counter)
|
||||
|
||||
# Also log current PnL and balance
|
||||
# Update TensorBoard metrics
|
||||
step_counter += 1
|
||||
agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
|
||||
agent.writer.add_scalar('Live/TotalPnL', env.total_pnl, step_counter)
|
||||
agent.writer.add_scalar('Live/WinRate',
|
||||
(env.win_count / (env.win_count + env.loss_count) * 100)
|
||||
if (env.win_count + env.loss_count) > 0 else 0,
|
||||
step_counter)
|
||||
agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
|
||||
agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter)
|
||||
|
||||
# Update chart visualization
|
||||
if step_counter % 5 == 0 or env.position != prev_position:
|
||||
agent.add_chart_to_tensorboard(env, step_counter)
|
||||
|
||||
# Log performance summary
|
||||
if trades_count > 0:
|
||||
win_rate = (winning_trades / trades_count) * 100
|
||||
agent.writer.add_scalar('Live/WinRate', win_rate, step_counter)
|
||||
|
||||
performance_text = f"""
|
||||
**Live Trading Performance**
|
||||
Balance: ${env.balance:.2f}
|
||||
Total PnL: ${env.total_pnl:.2f}
|
||||
Trades: {trades_count}
|
||||
Win Rate: {win_rate:.1f}%
|
||||
Max Drawdown: {max_drawdown*100:.1f}%
|
||||
"""
|
||||
agent.writer.add_text('Performance', performance_text, step_counter)
|
||||
|
||||
prev_position = env.position
|
||||
|
||||
# Wait for next candle
|
||||
await asyncio.sleep(10) # Check every 10 seconds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in live trading loop: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
logger.info("Continuing after error...")
|
||||
await asyncio.sleep(30) # Wait longer after an error
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Live trading stopped by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in live trading: {e}")
|
||||
raise
|
||||
|
||||
# Final performance report
|
||||
if trades_count > 0:
|
||||
win_rate = (winning_trades / trades_count) * 100
|
||||
logger.info(f"Trading session summary:")
|
||||
logger.info(f"Total trades: {trades_count}")
|
||||
logger.info(f"Win rate: {win_rate:.1f}%")
|
||||
logger.info(f"Final balance: ${env.balance:.2f}")
|
||||
logger.info(f"Total profit: ${total_profit:.2f}")
|
||||
logger.info(f"Maximum drawdown: {max_drawdown*100:.1f}%")
|
||||
logger.info(f"Trade log saved to: {trade_log_path}")
|
||||
|
||||
async def get_latest_candle(exchange, symbol):
|
||||
"""Get the latest candle data"""
|
||||
@ -2404,6 +2438,12 @@ async def main():
|
||||
help='Mode to run the bot in')
|
||||
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
|
||||
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
|
||||
parser.add_argument('--live', action='store_true', help='Run in live trading mode')
|
||||
parser.add_argument('--real', action='store_true', help='Execute real trades (default is demo/paper trading)')
|
||||
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
|
||||
parser.add_argument('--timeframe', type=str, default='1m', help='Candle timeframe')
|
||||
parser.add_argument('--leverage', type=int, default=50, help='Leverage for futures trading')
|
||||
parser.add_argument('--model', type=str, help='Path to model file')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get device (GPU or CPU)
|
||||
@ -2438,12 +2478,46 @@ async def main():
|
||||
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env)
|
||||
|
||||
elif args.mode == 'live':
|
||||
# Load trained model
|
||||
agent.load("models/trading_agent_best_pnl.pt")
|
||||
# Add these arguments to the parser
|
||||
parser.add_argument('--live', action='store_true', help='Run in live trading mode')
|
||||
parser.add_argument('--real', action='store_true', help='Execute real trades (default is demo/paper trading)')
|
||||
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
|
||||
parser.add_argument('--timeframe', type=str, default='1m', help='Candle timeframe')
|
||||
parser.add_argument('--leverage', type=int, default=50, help='Leverage for futures trading')
|
||||
parser.add_argument('--model', type=str, help='Path to model file')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run live trading
|
||||
logger.info("Starting live trading...")
|
||||
await live_trading(agent, env, exchange, demo=args.demo)
|
||||
# In the main function, add this section to handle live trading
|
||||
if args.live:
|
||||
# Initialize exchange
|
||||
exchange = await initialize_exchange()
|
||||
|
||||
# Load the trained agent
|
||||
model_path = args.model if args.model else "models/trading_agent.pt"
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"Model file not found: {model_path}")
|
||||
return
|
||||
|
||||
# Initialize environment with historical data
|
||||
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=WINDOW_SIZE, demo=not args.real)
|
||||
await env.fetch_initial_data(exchange, symbol=args.symbol, timeframe=args.timeframe)
|
||||
|
||||
# Initialize agent
|
||||
state_size = env.get_state().shape[0]
|
||||
agent = Agent(state_size=state_size, action_size=3)
|
||||
agent.load(model_path)
|
||||
logger.info(f"Loaded model from {model_path}")
|
||||
|
||||
# Start live trading
|
||||
await live_trading(
|
||||
agent=agent,
|
||||
env=env,
|
||||
exchange=exchange,
|
||||
symbol=args.symbol,
|
||||
timeframe=args.timeframe,
|
||||
demo=not args.real,
|
||||
leverage=args.leverage
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up exchange connection - safely close if possible
|
||||
@ -2458,6 +2532,73 @@ async def main():
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not properly close exchange connection: {e}")
|
||||
|
||||
# Add this function near the top with other utility functions
|
||||
def create_candlestick_figure(data, trade_signals, window_size=100, title=""):
|
||||
"""Create a candlestick chart with trade signals for TensorBoard visualization"""
|
||||
if len(data) < 10:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Create figure
|
||||
fig = plt.figure(figsize=(12, 8))
|
||||
|
||||
# Prepare data for plotting
|
||||
df = pd.DataFrame(data[-window_size:])
|
||||
df['date'] = pd.to_datetime(df['timestamp'], unit='ms')
|
||||
df.set_index('date', inplace=True)
|
||||
|
||||
# Create subplot grid
|
||||
gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
|
||||
price_ax = plt.subplot(gs[0])
|
||||
volume_ax = plt.subplot(gs[1], sharex=price_ax)
|
||||
|
||||
# Plot candlesticks
|
||||
mpf.plot(df, type='candle', style='yahoo', ax=price_ax, volume=volume_ax)
|
||||
|
||||
# Add trade signals
|
||||
for signal in trade_signals:
|
||||
if signal['timestamp'] not in df.index:
|
||||
continue
|
||||
|
||||
timestamp = pd.to_datetime(signal['timestamp'], unit='ms')
|
||||
price = signal['price']
|
||||
|
||||
if signal['type'] == 'buy':
|
||||
price_ax.plot(timestamp, price, '^', color='green', markersize=10, label='Buy')
|
||||
elif signal['type'] == 'sell':
|
||||
price_ax.plot(timestamp, price, 'v', color='red', markersize=10, label='Sell')
|
||||
elif signal['type'] == 'close_long':
|
||||
price_ax.plot(timestamp, price, 'x', color='gold', markersize=10, label='Close Long')
|
||||
elif signal['type'] == 'close_short':
|
||||
price_ax.plot(timestamp, price, 'x', color='black', markersize=10, label='Close Short')
|
||||
elif 'stop_loss' in signal['type']:
|
||||
price_ax.plot(timestamp, price, 'X', color='purple', markersize=10, label='Stop Loss')
|
||||
elif 'take_profit' in signal['type']:
|
||||
price_ax.plot(timestamp, price, '*', color='cyan', markersize=10, label='Take Profit')
|
||||
|
||||
# Add balance and PnL annotation
|
||||
if trade_signals and 'balance' in trade_signals[-1] and 'pnl' in trade_signals[-1]:
|
||||
balance = trade_signals[-1]['balance']
|
||||
pnl = trade_signals[-1]['pnl']
|
||||
price_ax.annotate(f"Balance: ${balance:.2f}\nPnL: ${pnl:.2f}",
|
||||
xy=(0.02, 0.95), xycoords='axes fraction',
|
||||
bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))
|
||||
|
||||
# Set title and format
|
||||
price_ax.set_title(title)
|
||||
fig.tight_layout()
|
||||
|
||||
# Convert to image
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(buf, format='png')
|
||||
buf.seek(0)
|
||||
img = Image.open(buf)
|
||||
return img
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating chart: {str(e)}")
|
||||
return None
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
|
Loading…
x
Reference in New Issue
Block a user