Compare commits
10 Commits
715261a3f9
...
fdde2ff587
Author | SHA1 | Date | |
---|---|---|---|
|
fdde2ff587 | ||
|
690c61f230 | ||
|
0148964409 | ||
|
3e924b32ac | ||
|
8dafb6d310 | ||
|
621a2505bd | ||
|
08b8da7c8f | ||
|
e884f0c9e6 | ||
|
cfddc996d7 | ||
|
6f78703ba1 |
23
.gitignore
vendored
23
.gitignore
vendored
@ -32,6 +32,27 @@ crypto/sol/.vs/*
|
||||
crypto/brian/models/best/*
|
||||
crypto/brian/models/last/*
|
||||
crypto/brian/live_chart.html
|
||||
crypto/gogo2/models/*
|
||||
crypto/gogo2/trading_bot.log
|
||||
*.log
|
||||
|
||||
crypto/gogo2/checkpoints/trading_agent_episode_*.pt
|
||||
*trading_agent_continuous_*.pt
|
||||
*trading_agent_episode_*.pt
|
||||
crypto/gogo2/models/trading_agent_continuous_150.pt
|
||||
crypto/gogo2/checkpoints/trading_agent_episode_0.pt
|
||||
crypto/gogo2/checkpoints/trading_agent_episode_10.pt
|
||||
crypto/gogo2/checkpoints/trading_agent_episode_20.pt
|
||||
crypto/gogo2/checkpoints/trading_agent_episode_40.pt
|
||||
crypto/gogo2/models/trading_agent_best_pnl.pt
|
||||
crypto/gogo2/models/trading_agent_best_reward.pt
|
||||
crypto/gogo2/models/trading_agent_best_winrate.pt
|
||||
crypto/gogo2/models/trading_agent_continuous_0.pt
|
||||
crypto/gogo2/models/trading_agent_continuous_50.pt
|
||||
crypto/gogo2/models/trading_agent_continuous_100.pt
|
||||
crypto/gogo2/models/trading_agent_continuous_150.pt
|
||||
crypto/gogo2/models/trading_agent_emergency.pt
|
||||
crypto/gogo2/models/trading_agent_episode_0.pt
|
||||
crypto/gogo2/models/trading_agent_episode_10.pt
|
||||
crypto/gogo2/models/trading_agent_episode_20.pt
|
||||
crypto/gogo2/models/trading_agent_episode_30.pt
|
||||
crypto/gogo2/models/trading_agent_final.pt
|
||||
|
1
crypto/gogo2/.gitattributes
vendored
Normal file
1
crypto/gogo2/.gitattributes
vendored
Normal file
@ -0,0 +1 @@
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
1
crypto/gogo2/.gitignore
vendored
Normal file
1
crypto/gogo2/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
*.pt
|
11
crypto/gogo2/.vscode/launch.json
vendored
11
crypto/gogo2/.vscode/launch.json
vendored
@ -6,7 +6,7 @@
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "main.py",
|
||||
"args": ["--mode", "train", "--episodes", "1000"],
|
||||
"args": ["--mode", "train", "--episodes", "100"],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
},
|
||||
@ -36,6 +36,15 @@
|
||||
"args": ["--mode", "live"],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
},
|
||||
{
|
||||
"name": "Continuous Training",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "main.py",
|
||||
"args": ["--mode", "continuous", "--refresh-data"],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
}
|
||||
]
|
||||
}
|
@ -1 +1 @@
|
||||
{"best_reward": 202.7441047517104, "best_pnl": -1.285678343969877, "best_win_rate": 38.70967741935484, "last_episode": 20, "timestamp": "2025-03-10T13:31:02.938465"}
|
||||
{"best_reward": 202.7441047517104, "best_pnl": 9.268344827764809, "best_win_rate": 73.33333333333333, "last_episode": 30, "timestamp": "2025-03-10T17:57:19.913481"}
|
@ -23,6 +23,9 @@ import copy
|
||||
import argparse
|
||||
import traceback
|
||||
import math
|
||||
import matplotlib.dates as mdates
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@ -426,7 +429,12 @@ class TradingEnvironment:
|
||||
def _update_features(self):
|
||||
"""Update technical indicators with new data"""
|
||||
self._initialize_features() # Recalculate all features
|
||||
|
||||
|
||||
async def fetch_new_data(env, exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
|
||||
"""Fetch new data for the environment"""
|
||||
# Call the environment's fetch_initial_data method
|
||||
return await env.fetch_initial_data(exchange, symbol, timeframe, limit)
|
||||
|
||||
async def fetch_initial_data(self, exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
|
||||
"""Fetch initial historical data for the environment"""
|
||||
try:
|
||||
@ -1676,7 +1684,7 @@ async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
|
||||
await asyncio.sleep(5)
|
||||
break
|
||||
|
||||
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None):
|
||||
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None, args=None):
|
||||
"""Train the agent using historical and live data with GPU acceleration"""
|
||||
logger.info(f"Starting training on device: {agent.device}")
|
||||
|
||||
@ -1700,6 +1708,7 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
|
||||
]
|
||||
current_stage = 0
|
||||
|
||||
# Initialize stats dictionary with the correct keys
|
||||
stats = {
|
||||
'episode_rewards': [],
|
||||
'episode_profits': [],
|
||||
@ -1758,18 +1767,11 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
|
||||
# Set risk factor for this episode
|
||||
env.risk_factor = risk_factor
|
||||
|
||||
# Refresh data with latest candles if exchange is provided
|
||||
if exchange is not None:
|
||||
try:
|
||||
logger.info(f"Fetching latest data for episode {episode}")
|
||||
latest_data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 100)
|
||||
if latest_data:
|
||||
# Add new data to environment
|
||||
for candle in latest_data:
|
||||
env.add_data(candle)
|
||||
logger.info(f"Added {len(latest_data)} new candles for episode {episode}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing data: {e}")
|
||||
# Update training data if exchange is available
|
||||
if exchange and args.refresh_data:
|
||||
# Fetch new data at the start of each episode
|
||||
logger.info(f"Refreshing data for episode {episode}")
|
||||
await env.fetch_new_data(exchange, "ETH/USDT", "1m", 100)
|
||||
|
||||
# Reset environment
|
||||
state = env.reset()
|
||||
@ -1796,6 +1798,59 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
|
||||
# Update price predictions
|
||||
env.update_price_predictions()
|
||||
|
||||
# Log OHLCV data to TensorBoard at the start of the episode
|
||||
if episode % 5 == 0: # Log every 5 episodes to avoid too much data
|
||||
# Create a DataFrame from the environment's data
|
||||
df_ohlcv = pd.DataFrame([{
|
||||
'timestamp': candle['timestamp'],
|
||||
'open': candle['open'],
|
||||
'high': candle['high'],
|
||||
'low': candle['low'],
|
||||
'close': candle['close'],
|
||||
'volume': candle['volume']
|
||||
} for candle in env.data[-100:]]) # Use last 100 candles
|
||||
|
||||
# Convert timestamp to datetime
|
||||
df_ohlcv['timestamp'] = pd.to_datetime(df_ohlcv['timestamp'], unit='ms')
|
||||
df_ohlcv.set_index('timestamp', inplace=True)
|
||||
|
||||
# Extract buy/sell signals from trades
|
||||
buy_signals = []
|
||||
sell_signals = []
|
||||
|
||||
if hasattr(env, 'trades') and env.trades:
|
||||
for trade in env.trades:
|
||||
if 'entry_time' in trade and 'entry' in trade:
|
||||
if trade['type'] == 'long':
|
||||
# Buy signal
|
||||
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||
buy_signals.append((entry_time, trade['entry']))
|
||||
|
||||
# Sell signal if closed
|
||||
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
|
||||
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||
sell_signals.append((exit_time, trade['exit']))
|
||||
|
||||
elif trade['type'] == 'short':
|
||||
# Sell short signal
|
||||
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||
sell_signals.append((entry_time, trade['entry']))
|
||||
|
||||
# Buy to cover signal if closed
|
||||
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
|
||||
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||
buy_signals.append((exit_time, trade['exit']))
|
||||
|
||||
# Log to TensorBoard
|
||||
log_ohlcv_to_tensorboard(
|
||||
agent.writer,
|
||||
df_ohlcv,
|
||||
buy_signals,
|
||||
sell_signals,
|
||||
episode,
|
||||
tag_prefix=f"episode_{episode}"
|
||||
)
|
||||
|
||||
while not done:
|
||||
# Select action
|
||||
action = agent.select_action(state)
|
||||
@ -1920,6 +1975,72 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
|
||||
if episode % 10 == 0 or episode == num_episodes - 1:
|
||||
visualize_training_results(env, agent, episode)
|
||||
|
||||
# After episode is complete, log final state with all trades
|
||||
if episode % 10 == 0 or episode == num_episodes - 1:
|
||||
# Create a DataFrame from the environment's data
|
||||
df_ohlcv = pd.DataFrame([{
|
||||
'timestamp': candle['timestamp'],
|
||||
'open': candle['open'],
|
||||
'high': candle['high'],
|
||||
'low': candle['low'],
|
||||
'close': candle['close'],
|
||||
'volume': candle['volume']
|
||||
} for candle in env.data[-100:]]) # Use last 100 candles
|
||||
|
||||
# Convert timestamp to datetime
|
||||
df_ohlcv['timestamp'] = pd.to_datetime(df_ohlcv['timestamp'], unit='ms')
|
||||
df_ohlcv.set_index('timestamp', inplace=True)
|
||||
|
||||
# Extract buy/sell signals from trades
|
||||
buy_signals = []
|
||||
sell_signals = []
|
||||
|
||||
if hasattr(env, 'trades') and env.trades:
|
||||
for trade in env.trades:
|
||||
if 'entry_time' in trade and 'entry' in trade:
|
||||
if trade['type'] == 'long':
|
||||
# Buy signal
|
||||
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||
buy_signals.append((entry_time, trade['entry']))
|
||||
|
||||
# Sell signal if closed
|
||||
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
|
||||
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||
sell_signals.append((exit_time, trade['exit']))
|
||||
|
||||
elif trade['type'] == 'short':
|
||||
# Sell short signal
|
||||
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||
sell_signals.append((entry_time, trade['entry']))
|
||||
|
||||
# Buy to cover signal if closed
|
||||
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
|
||||
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||
buy_signals.append((exit_time, trade['exit']))
|
||||
|
||||
# Log to TensorBoard - use a fixed tag to overwrite previous charts
|
||||
log_ohlcv_to_tensorboard(
|
||||
agent.writer,
|
||||
df_ohlcv,
|
||||
buy_signals,
|
||||
sell_signals,
|
||||
episode,
|
||||
tag_prefix="latest_trading_data" # Use a fixed tag to overwrite previous charts
|
||||
)
|
||||
|
||||
# Create visualization - only keep the latest one
|
||||
os.makedirs("visualizations", exist_ok=True)
|
||||
# Remove previous visualizations to save disk space
|
||||
for file in os.listdir("visualizations"):
|
||||
if file.startswith("training_episode_") and file.endswith(".png"):
|
||||
try:
|
||||
os.remove(os.path.join("visualizations", file))
|
||||
except:
|
||||
pass
|
||||
|
||||
# Create new visualization
|
||||
visualize_training_results(env, agent, episode)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in episode {episode}: {e}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
@ -1945,8 +2066,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
|
||||
return stats
|
||||
|
||||
def plot_training_results(stats):
|
||||
"""Plot detailed training results"""
|
||||
plt.figure(figsize=(20, 15))
|
||||
"""Plot training results"""
|
||||
plt.figure(figsize=(15, 15))
|
||||
|
||||
# Plot rewards
|
||||
plt.subplot(3, 2, 1)
|
||||
@ -1955,12 +2076,12 @@ def plot_training_results(stats):
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Reward')
|
||||
|
||||
# Plot balance
|
||||
# Plot balance/profits
|
||||
plt.subplot(3, 2, 2)
|
||||
plt.plot(stats['episode_profits'])
|
||||
plt.title('Account Balance')
|
||||
plt.title('Episode Profits')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Balance ($)')
|
||||
plt.ylabel('Profit ($)')
|
||||
|
||||
# Plot win rate
|
||||
plt.subplot(3, 2, 3)
|
||||
@ -1969,35 +2090,26 @@ def plot_training_results(stats):
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Win Rate (%)')
|
||||
|
||||
# Plot episode PnL
|
||||
# Plot trade count
|
||||
plt.subplot(3, 2, 4)
|
||||
plt.plot(stats['episode_pnls'])
|
||||
plt.title('Episode PnL')
|
||||
plt.plot(stats['trade_counts'])
|
||||
plt.title('Number of Trades')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('PnL ($)')
|
||||
plt.ylabel('Trades')
|
||||
|
||||
# Plot cumulative PnL
|
||||
# Plot prediction accuracy
|
||||
plt.subplot(3, 2, 5)
|
||||
plt.plot(stats['cumulative_pnl'])
|
||||
plt.title('Cumulative PnL')
|
||||
plt.plot(stats['prediction_accuracies'])
|
||||
plt.title('Prediction Accuracy')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Cumulative PnL ($)')
|
||||
|
||||
# Plot drawdown
|
||||
plt.subplot(3, 2, 6)
|
||||
plt.plot(stats['drawdowns'])
|
||||
plt.title('Maximum Drawdown')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Drawdown (%)')
|
||||
plt.ylabel('Accuracy (%)')
|
||||
|
||||
# Save the figure
|
||||
plt.tight_layout()
|
||||
plt.savefig('training_results.png')
|
||||
plt.close()
|
||||
|
||||
# Save statistics to CSV
|
||||
df = pd.DataFrame(stats)
|
||||
df.to_csv('training_stats.csv', index=False)
|
||||
|
||||
logger.info("Training statistics saved to training_stats.csv and training_results.png")
|
||||
logger.info("Training results saved to training_results.png")
|
||||
|
||||
def evaluate_agent(agent, env, num_episodes=10):
|
||||
"""Evaluate the agent on test data"""
|
||||
@ -2237,6 +2349,7 @@ async def get_latest_candle(exchange, symbol):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch latest candle: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
|
||||
"""Fetch OHLCV data with proper handling for both async and standard CCXT"""
|
||||
@ -2281,8 +2394,9 @@ async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
|
||||
async def main():
|
||||
"""Main function to run the trading bot"""
|
||||
parser = argparse.ArgumentParser(description='Crypto Trading Bot')
|
||||
parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'live'],
|
||||
help='Mode to run the bot in')
|
||||
parser.add_argument('--mode', type=str, default='train',
|
||||
choices=['train', 'evaluate', 'live', 'continuous'],
|
||||
help='Mode to run the bot in (train, evaluate, live, or continuous)')
|
||||
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
|
||||
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
|
||||
parser.add_argument('--refresh-data', action='store_true', help='Refresh data during training')
|
||||
@ -2294,49 +2408,85 @@ async def main():
|
||||
device = get_device(args.device)
|
||||
|
||||
exchange = None
|
||||
|
||||
try:
|
||||
# Initialize exchange
|
||||
exchange = await initialize_exchange()
|
||||
|
||||
# Create environment
|
||||
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=args.demo)
|
||||
# Create environment with the correct parameters
|
||||
env = TradingEnvironment(
|
||||
initial_balance=INITIAL_BALANCE,
|
||||
window_size=30,
|
||||
demo=args.demo or args.mode != 'live'
|
||||
)
|
||||
|
||||
# Fetch initial data
|
||||
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
|
||||
logger.info("Fetching initial data for ETH/USDT")
|
||||
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 500)
|
||||
|
||||
# Create agent
|
||||
# Initialize agent
|
||||
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
|
||||
|
||||
if args.mode == 'train':
|
||||
# Train the agent
|
||||
logger.info(f"Starting training for {args.episodes} episodes...")
|
||||
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange, args=args)
|
||||
|
||||
# Pass exchange to training function if refresh-data is enabled
|
||||
if args.refresh_data:
|
||||
logger.info("Data refresh enabled during training")
|
||||
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange)
|
||||
else:
|
||||
stats = await train_agent(agent, env, num_episodes=args.episodes)
|
||||
elif args.mode == 'continuous':
|
||||
# Run in continuous mode - train indefinitely
|
||||
logger.info("Starting continuous training mode. Press Ctrl+C to stop.")
|
||||
episode_counter = 0
|
||||
try:
|
||||
while True: # Run indefinitely until manually stopped
|
||||
# Train for a batch of episodes
|
||||
batch_size = 50 # Train in batches of 50 episodes
|
||||
logger.info(f"Starting training batch {episode_counter // batch_size + 1}")
|
||||
|
||||
# Refresh data at the start of each batch
|
||||
if exchange and args.refresh_data:
|
||||
logger.info("Refreshing data for new training batch")
|
||||
await env.fetch_new_data(exchange, "ETH/USDT", "1m", 500)
|
||||
logger.info(f"Updated environment with fresh candles")
|
||||
|
||||
# Train for a batch of episodes
|
||||
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange, args=args)
|
||||
|
||||
# Save model after each batch
|
||||
agent.save(f"models/trading_agent_continuous_{episode_counter}.pt")
|
||||
|
||||
# Increment counter
|
||||
episode_counter += batch_size
|
||||
|
||||
# Sleep briefly to prevent excessive API calls
|
||||
await asyncio.sleep(5)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Continuous training stopped by user")
|
||||
# Save final model
|
||||
agent.save("models/trading_agent_continuous_final.pt")
|
||||
logger.info("Final model saved")
|
||||
|
||||
elif args.mode == 'evaluate':
|
||||
# Load trained model
|
||||
# Load the best model
|
||||
agent.load("models/trading_agent_best_pnl.pt")
|
||||
|
||||
# Evaluate the agent
|
||||
logger.info("Evaluating agent...")
|
||||
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env)
|
||||
results = evaluate_agent(agent, env, num_episodes=10)
|
||||
logger.info(f"Evaluation results: {results}")
|
||||
|
||||
elif args.mode == 'live':
|
||||
# Load trained model
|
||||
# Load the best model
|
||||
agent.load("models/trading_agent_best_pnl.pt")
|
||||
|
||||
# Run live trading
|
||||
logger.info("Starting live trading...")
|
||||
await live_trading(agent, env, exchange, demo=args.demo)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
finally:
|
||||
# Clean up exchange connection - safely close if possible
|
||||
# Close exchange connection
|
||||
if exchange:
|
||||
try:
|
||||
# Some CCXT exchanges have close method, others don't
|
||||
@ -2515,6 +2665,128 @@ def visualize_training_results(env, agent, episode_num):
|
||||
logger.error(f"Error creating visualization: {e}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
def log_ohlcv_to_tensorboard(writer, df_ohlcv, buy_signals, sell_signals, step, tag_prefix="trading"):
|
||||
"""
|
||||
Log OHLCV chart with buy/sell signals to TensorBoard
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
writer : torch.utils.tensorboard.SummaryWriter
|
||||
TensorBoard writer instance
|
||||
df_ohlcv : pandas.DataFrame
|
||||
DataFrame with OHLCV data
|
||||
buy_signals : list of tuples
|
||||
List of (datetime, price) tuples for buy signals
|
||||
sell_signals : list of tuples
|
||||
List of (datetime, price) tuples for sell signals
|
||||
step : int
|
||||
Global step value to record
|
||||
tag_prefix : str
|
||||
Prefix for the tag in TensorBoard
|
||||
"""
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
||||
import numpy as np
|
||||
|
||||
# Check if DataFrame is empty
|
||||
if df_ohlcv.empty:
|
||||
logger.warning("Empty OHLCV DataFrame, skipping visualization")
|
||||
return
|
||||
|
||||
# Create figure
|
||||
fig = Figure(figsize=(12, 8))
|
||||
canvas = FigureCanvas(fig)
|
||||
|
||||
# Create subplots for price and volume
|
||||
ax1 = fig.add_subplot(2, 1, 1) # Price chart
|
||||
ax2 = fig.add_subplot(2, 1, 2, sharex=ax1) # Volume chart
|
||||
|
||||
# Plot OHLC
|
||||
dates = mdates.date2num(df_ohlcv.index.to_pydatetime())
|
||||
ohlc = np.column_stack((dates, df_ohlcv['open'], df_ohlcv['high'], df_ohlcv['low'], df_ohlcv['close']))
|
||||
|
||||
# Plot candlestick chart
|
||||
from matplotlib.lines import Line2D
|
||||
from matplotlib.patches import Rectangle
|
||||
|
||||
width = 0.6 / (len(df_ohlcv) + 1) # Adjust width based on number of candles
|
||||
|
||||
for i, (date, open_price, high, low, close) in enumerate(ohlc):
|
||||
# Determine candle color
|
||||
if close >= open_price:
|
||||
color = 'green'
|
||||
body_bottom = open_price
|
||||
body_height = close - open_price
|
||||
else:
|
||||
color = 'red'
|
||||
body_bottom = close
|
||||
body_height = open_price - close
|
||||
|
||||
# Plot candle body
|
||||
rect = Rectangle(
|
||||
xy=(date - width/2, body_bottom),
|
||||
width=width,
|
||||
height=body_height,
|
||||
facecolor=color,
|
||||
edgecolor='black',
|
||||
alpha=0.8
|
||||
)
|
||||
ax1.add_patch(rect)
|
||||
|
||||
# Plot wick
|
||||
ax1.plot([date, date], [low, high], color='black', linewidth=1)
|
||||
|
||||
# Plot buy signals
|
||||
if buy_signals:
|
||||
buy_dates = mdates.date2num([x[0] for x in buy_signals])
|
||||
buy_prices = [x[1] for x in buy_signals]
|
||||
ax1.scatter(buy_dates, buy_prices, marker='^', color='green', s=100, label='Buy')
|
||||
|
||||
# Plot sell signals
|
||||
if sell_signals:
|
||||
sell_dates = mdates.date2num([x[0] for x in sell_signals])
|
||||
sell_prices = [x[1] for x in sell_signals]
|
||||
ax1.scatter(sell_dates, sell_prices, marker='v', color='red', s=100, label='Sell')
|
||||
|
||||
# Plot volume
|
||||
ax2.bar(dates, df_ohlcv['volume'], width=width, color='blue', alpha=0.5)
|
||||
|
||||
# Format axes
|
||||
ax1.set_title(f'OHLC with Buy/Sell Signals - {tag_prefix}')
|
||||
ax1.set_ylabel('Price')
|
||||
ax1.legend()
|
||||
ax1.grid(True)
|
||||
|
||||
ax2.set_xlabel('Date')
|
||||
ax2.set_ylabel('Volume')
|
||||
ax2.grid(True)
|
||||
|
||||
# Format date
|
||||
date_format = mdates.DateFormatter('%Y-%m-%d %H:%M')
|
||||
ax2.xaxis.set_major_formatter(date_format)
|
||||
fig.autofmt_xdate()
|
||||
|
||||
# Adjust layout
|
||||
fig.tight_layout()
|
||||
|
||||
# Log to TensorBoard
|
||||
if tag_prefix == "latest_trading_data":
|
||||
# For the latest data, use a fixed tag without step to overwrite previous charts
|
||||
writer.add_figure(f"{tag_prefix}/ohlcv_chart", fig)
|
||||
else:
|
||||
# For other charts, include the step
|
||||
writer.add_figure(f"{tag_prefix}/ohlcv_chart", fig, global_step=step)
|
||||
|
||||
# Clean up
|
||||
plt.close(fig)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in log_ohlcv_to_tensorboard: {e}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
|
@ -6,4 +6,5 @@ python-dotenv>=0.19.0
|
||||
ccxt>=2.0.0
|
||||
websockets>=10.0
|
||||
tensorboard>=2.6.0
|
||||
scikit-learn
|
||||
scikit-learn
|
||||
mplfinance
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
192666
crypto/gogo2/trading_bot.log
192666
crypto/gogo2/trading_bot.log
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Before Width: | Height: | Size: 60 KiB After Width: | Height: | Size: 170 KiB |
BIN
crypto/gogo2/visualizations/training_episode_30.png
Normal file
BIN
crypto/gogo2/visualizations/training_episode_30.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 86 KiB |
Loading…
x
Reference in New Issue
Block a user