added modes scripts
This commit is contained in:
parent
469d681c4b
commit
c63b1a2daf
166
crypto/gogo2/check_live_trading.py
Normal file
166
crypto/gogo2/check_live_trading.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import importlib
|
||||||
|
import asyncio
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[logging.StreamHandler()]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger("check_live_trading")
|
||||||
|
|
||||||
|
def check_dependencies():
|
||||||
|
"""Check if all required dependencies are installed"""
|
||||||
|
required_packages = [
|
||||||
|
"numpy", "pandas", "matplotlib", "mplfinance", "torch",
|
||||||
|
"dotenv", "ccxt", "websockets", "tensorboard",
|
||||||
|
"sklearn", "PIL", "asyncio"
|
||||||
|
]
|
||||||
|
|
||||||
|
missing_packages = []
|
||||||
|
|
||||||
|
for package in required_packages:
|
||||||
|
try:
|
||||||
|
if package == "dotenv":
|
||||||
|
importlib.import_module("dotenv")
|
||||||
|
elif package == "PIL":
|
||||||
|
importlib.import_module("PIL")
|
||||||
|
else:
|
||||||
|
importlib.import_module(package)
|
||||||
|
logger.info(f"✅ {package} is installed")
|
||||||
|
except ImportError:
|
||||||
|
missing_packages.append(package)
|
||||||
|
logger.error(f"❌ {package} is NOT installed")
|
||||||
|
|
||||||
|
if missing_packages:
|
||||||
|
logger.error(f"Missing packages: {', '.join(missing_packages)}")
|
||||||
|
logger.info("Install missing packages with: pip install -r requirements.txt")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def check_api_keys():
|
||||||
|
"""Check if API keys are configured"""
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
api_key = os.getenv('MEXC_API_KEY')
|
||||||
|
secret_key = os.getenv('MEXC_SECRET_KEY')
|
||||||
|
|
||||||
|
if not api_key or api_key == "your_api_key_here" or not secret_key or secret_key == "your_secret_key_here":
|
||||||
|
logger.error("❌ API keys are not properly configured in .env file")
|
||||||
|
logger.info("Please update your .env file with valid MEXC API keys")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info("✅ API keys are configured")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def check_model_files():
|
||||||
|
"""Check if trained model files exist"""
|
||||||
|
model_files = [
|
||||||
|
"models/trading_agent_best_pnl.pt",
|
||||||
|
"models/trading_agent_best_reward.pt",
|
||||||
|
"models/trading_agent_final.pt"
|
||||||
|
]
|
||||||
|
|
||||||
|
missing_models = []
|
||||||
|
|
||||||
|
for model_file in model_files:
|
||||||
|
if os.path.exists(model_file):
|
||||||
|
logger.info(f"✅ Model file exists: {model_file}")
|
||||||
|
else:
|
||||||
|
missing_models.append(model_file)
|
||||||
|
logger.error(f"❌ Model file missing: {model_file}")
|
||||||
|
|
||||||
|
if missing_models:
|
||||||
|
logger.warning("Some model files are missing. You need to train the model first.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def check_exchange_connection():
|
||||||
|
"""Test connection to MEXC exchange"""
|
||||||
|
try:
|
||||||
|
import ccxt
|
||||||
|
|
||||||
|
# Load API keys
|
||||||
|
load_dotenv()
|
||||||
|
api_key = os.getenv('MEXC_API_KEY')
|
||||||
|
secret_key = os.getenv('MEXC_SECRET_KEY')
|
||||||
|
|
||||||
|
if api_key == "your_api_key_here" or secret_key == "your_secret_key_here":
|
||||||
|
logger.warning("⚠️ Using placeholder API keys, skipping exchange connection test")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Initialize exchange
|
||||||
|
exchange = ccxt.mexc({
|
||||||
|
'apiKey': api_key,
|
||||||
|
'secret': secret_key,
|
||||||
|
'enableRateLimit': True
|
||||||
|
})
|
||||||
|
|
||||||
|
# Test connection by fetching markets
|
||||||
|
markets = exchange.fetch_markets()
|
||||||
|
logger.info(f"✅ Successfully connected to MEXC exchange")
|
||||||
|
logger.info(f"✅ Found {len(markets)} markets")
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Failed to connect to MEXC exchange: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def check_directories():
|
||||||
|
"""Check if required directories exist"""
|
||||||
|
required_dirs = ["models", "runs", "trade_logs"]
|
||||||
|
|
||||||
|
for directory in required_dirs:
|
||||||
|
if not os.path.exists(directory):
|
||||||
|
logger.info(f"Creating directory: {directory}")
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
|
||||||
|
logger.info("✅ All required directories exist")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Run all checks"""
|
||||||
|
logger.info("Running pre-flight checks for live trading...")
|
||||||
|
|
||||||
|
checks = [
|
||||||
|
("Dependencies", check_dependencies()),
|
||||||
|
("API Keys", check_api_keys()),
|
||||||
|
("Model Files", check_model_files()),
|
||||||
|
("Directories", check_directories()),
|
||||||
|
("Exchange Connection", await check_exchange_connection())
|
||||||
|
]
|
||||||
|
|
||||||
|
# Count failed checks
|
||||||
|
failed_checks = sum(1 for _, result in checks if not result)
|
||||||
|
|
||||||
|
# Print summary
|
||||||
|
logger.info("\n" + "="*50)
|
||||||
|
logger.info("LIVE TRADING PRE-FLIGHT CHECK SUMMARY")
|
||||||
|
logger.info("="*50)
|
||||||
|
|
||||||
|
for check_name, result in checks:
|
||||||
|
status = "✅ PASS" if result else "❌ FAIL"
|
||||||
|
logger.info(f"{check_name}: {status}")
|
||||||
|
|
||||||
|
logger.info("="*50)
|
||||||
|
|
||||||
|
if failed_checks == 0:
|
||||||
|
logger.info("🚀 All checks passed! You're ready for live trading.")
|
||||||
|
logger.info("\nRun live trading with:")
|
||||||
|
logger.info("python main.py --mode live --demo true --symbol ETH/USDT --timeframe 1m")
|
||||||
|
logger.info("\nFor real trading (after updating API keys):")
|
||||||
|
logger.info("python main.py --mode live --demo false --symbol ETH/USDT --timeframe 1m --leverage 50")
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
logger.error(f"❌ {failed_checks} check(s) failed. Please fix the issues before running live trading.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit_code = asyncio.run(main())
|
||||||
|
sys.exit(exit_code)
|
@ -448,6 +448,20 @@ class TradingEnvironment:
|
|||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Take an action in the environment and return the next state, reward, and done flag"""
|
"""Take an action in the environment and return the next state, reward, and done flag"""
|
||||||
|
# Check if we have enough data
|
||||||
|
if self.current_step >= len(self.data) - 1:
|
||||||
|
# We've reached the end of data
|
||||||
|
done = True
|
||||||
|
next_state = self.get_state()
|
||||||
|
info = {
|
||||||
|
'action': 'none',
|
||||||
|
'price': self.current_price,
|
||||||
|
'balance': self.balance,
|
||||||
|
'position': self.position,
|
||||||
|
'pnl': self.total_pnl
|
||||||
|
}
|
||||||
|
return next_state, 0, done, info
|
||||||
|
|
||||||
# Store current price before taking action
|
# Store current price before taking action
|
||||||
self.current_price = self.data[self.current_step]['close']
|
self.current_price = self.data[self.current_step]['close']
|
||||||
|
|
||||||
@ -486,7 +500,16 @@ class TradingEnvironment:
|
|||||||
# Get new state
|
# Get new state
|
||||||
next_state = self.get_state()
|
next_state = self.get_state()
|
||||||
|
|
||||||
return next_state, reward, done
|
# Create info dictionary
|
||||||
|
info = {
|
||||||
|
'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
|
||||||
|
'price': self.current_price,
|
||||||
|
'balance': self.balance,
|
||||||
|
'position': self.position,
|
||||||
|
'pnl': self.total_pnl
|
||||||
|
}
|
||||||
|
|
||||||
|
return next_state, reward, done, info
|
||||||
|
|
||||||
def check_sl_tp(self):
|
def check_sl_tp(self):
|
||||||
"""Check if stop loss or take profit has been hit"""
|
"""Check if stop loss or take profit has been hit"""
|
||||||
@ -709,22 +732,39 @@ class TradingEnvironment:
|
|||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
"""Create state representation for the agent with enhanced features"""
|
"""Create state representation for the agent with enhanced features"""
|
||||||
if len(self.data) < 30 or len(self.features['price']) == 0:
|
# Ensure we have enough data
|
||||||
|
if len(self.data) < 30 or self.current_step >= len(self.data) or len(self.features['price']) == 0:
|
||||||
# Return zeros if not enough data
|
# Return zeros if not enough data
|
||||||
return np.zeros(STATE_SIZE)
|
return np.zeros(STATE_SIZE)
|
||||||
|
|
||||||
# Create a normalized state vector with recent price action and indicators
|
# Create a normalized state vector with recent price action and indicators
|
||||||
state_components = []
|
state_components = []
|
||||||
|
|
||||||
# Price features (normalize recent prices by the latest price)
|
# Safely get the latest price
|
||||||
latest_price = self.features['price'][-1]
|
try:
|
||||||
price_features = np.array(self.features['price'][-10:]) / latest_price - 1.0
|
latest_price = self.features['price'][-1]
|
||||||
state_components.append(price_features)
|
except IndexError:
|
||||||
|
# If we can't get the latest price, return zeros
|
||||||
|
return np.zeros(STATE_SIZE)
|
||||||
|
|
||||||
# Volume features (normalize by max volume)
|
# Safely get price features
|
||||||
max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1
|
try:
|
||||||
vol_features = np.array(self.features['volume'][-5:]) / max_vol
|
# Price features (normalize recent prices by the latest price)
|
||||||
state_components.append(vol_features)
|
price_features = np.array(self.features['price'][-10:]) / latest_price - 1.0
|
||||||
|
state_components.append(price_features)
|
||||||
|
except (IndexError, ZeroDivisionError):
|
||||||
|
# If we can't get price features, use zeros
|
||||||
|
state_components.append(np.zeros(10))
|
||||||
|
|
||||||
|
# Safely get volume features
|
||||||
|
try:
|
||||||
|
# Volume features (normalize by max volume)
|
||||||
|
max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1
|
||||||
|
vol_features = np.array(self.features['volume'][-5:]) / max_vol
|
||||||
|
state_components.append(vol_features)
|
||||||
|
except (IndexError, ZeroDivisionError):
|
||||||
|
# If we can't get volume features, use zeros
|
||||||
|
state_components.append(np.zeros(5))
|
||||||
|
|
||||||
# Technical indicators
|
# Technical indicators
|
||||||
rsi = np.array(self.features['rsi'][-3:]) / 100.0 # Scale to 0-1
|
rsi = np.array(self.features['rsi'][-3:]) / 100.0 # Scale to 0-1
|
||||||
@ -872,10 +912,18 @@ class TradingEnvironment:
|
|||||||
# Combine all features
|
# Combine all features
|
||||||
state = np.concatenate([comp.flatten() for comp in state_components])
|
state = np.concatenate([comp.flatten() for comp in state_components])
|
||||||
|
|
||||||
# Replace any NaN values
|
# Replace any NaN or infinite values
|
||||||
state = np.nan_to_num(state, nan=0.0)
|
state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)
|
||||||
|
|
||||||
|
# Ensure the state has the correct size
|
||||||
|
if len(state) != STATE_SIZE:
|
||||||
|
logger.warning(f"State size mismatch: expected {STATE_SIZE}, got {len(state)}")
|
||||||
|
# Pad or truncate to match expected size
|
||||||
|
if len(state) < STATE_SIZE:
|
||||||
|
state = np.pad(state, (0, STATE_SIZE - len(state)))
|
||||||
|
else:
|
||||||
|
state = state[:STATE_SIZE]
|
||||||
|
|
||||||
# Return the state (the caller will handle sizing)
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def get_expanded_state_size(self):
|
def get_expanded_state_size(self):
|
||||||
@ -1461,7 +1509,7 @@ class TradingEnvironment:
|
|||||||
Returns:
|
Returns:
|
||||||
Potential profit percentage
|
Potential profit percentage
|
||||||
"""
|
"""
|
||||||
if len(self.data) <= 1:
|
if len(self.data) <= 1 or self.current_step >= len(self.data):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Get current price
|
# Get current price
|
||||||
@ -1471,6 +1519,7 @@ class TradingEnvironment:
|
|||||||
future_prices = []
|
future_prices = []
|
||||||
current_idx = self.current_step
|
current_idx = self.current_step
|
||||||
|
|
||||||
|
# Safely get future prices
|
||||||
for i in range(1, min(lookahead + 1, len(self.data) - current_idx)):
|
for i in range(1, min(lookahead + 1, len(self.data) - current_idx)):
|
||||||
if current_idx + i < len(self.data):
|
if current_idx + i < len(self.data):
|
||||||
future_prices.append(self.data[current_idx + i]['close'])
|
future_prices.append(self.data[current_idx + i]['close'])
|
||||||
@ -1500,8 +1549,9 @@ class TradingEnvironment:
|
|||||||
await exchange.set_leverage(self.leverage, symbol=self.futures_symbol)
|
await exchange.set_leverage(self.leverage, symbol=self.futures_symbol)
|
||||||
logger.info(f"Futures initialized with {self.leverage}x leverage")
|
logger.info(f"Futures initialized with {self.leverage}x leverage")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize futures: {e}")
|
logger.error(f"Failed to initialize futures trading: {str(e)}")
|
||||||
raise
|
logger.info("Falling back to demo mode for safety")
|
||||||
|
demo = True
|
||||||
|
|
||||||
async def execute_real_trade(self, exchange, action, current_price):
|
async def execute_real_trade(self, exchange, action, current_price):
|
||||||
"""Execute real futures trade on MEXC"""
|
"""Execute real futures trade on MEXC"""
|
||||||
@ -1744,7 +1794,7 @@ class Agent:
|
|||||||
def update_target_network(self):
|
def update_target_network(self):
|
||||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||||
|
|
||||||
def save(self, path="models/trading_agent.pt"):
|
def save(self, path="models/trading_agent_best_pnl.pt"):
|
||||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
torch.save({
|
torch.save({
|
||||||
'policy_net': self.policy_net.state_dict(),
|
'policy_net': self.policy_net.state_dict(),
|
||||||
@ -1755,37 +1805,88 @@ class Agent:
|
|||||||
}, path)
|
}, path)
|
||||||
logger.info(f"Model saved to {path}")
|
logger.info(f"Model saved to {path}")
|
||||||
|
|
||||||
def load(self, path="models/trading_agent.pt"):
|
def load(self, path="models/trading_agent_best_pnl.pt"):
|
||||||
if os.path.isfile(path):
|
"""Load a trained model"""
|
||||||
try:
|
try:
|
||||||
# First try with weights_only=True (safer)
|
# First try with weights_only=True (safer)
|
||||||
checkpoint = torch.load(path, weights_only=True)
|
checkpoint = torch.load(path, map_location=self.device)
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to load with weights_only=True: {e}")
|
|
||||||
try:
|
|
||||||
# Try with safe_globals for numpy.scalar
|
|
||||||
import numpy as np
|
|
||||||
from torch.serialization import safe_globals
|
|
||||||
with safe_globals([np.core.multiarray.scalar]):
|
|
||||||
checkpoint = torch.load(path, weights_only=True)
|
|
||||||
except Exception as e2:
|
|
||||||
logger.warning(f"Failed with safe_globals: {e2}")
|
|
||||||
# Fall back to weights_only=False if needed
|
|
||||||
checkpoint = torch.load(path, weights_only=False)
|
|
||||||
|
|
||||||
|
# Check if model architecture matches
|
||||||
|
try:
|
||||||
|
state_dict = checkpoint['policy_net']
|
||||||
|
input_size = state_dict['fc1.weight'].shape[1]
|
||||||
|
output_size = state_dict['advantage_stream.bias'].shape[0]
|
||||||
|
hidden_size = state_dict['fc1.weight'].shape[0]
|
||||||
|
|
||||||
|
# If architecture doesn't match, rebuild the model
|
||||||
|
if (input_size != self.state_size or
|
||||||
|
output_size != self.action_size or
|
||||||
|
hidden_size != self.hidden_size):
|
||||||
|
logger.warning(f"Model architecture mismatch. Rebuilding model with: "
|
||||||
|
f"state_size={input_size}, action_size={output_size}, "
|
||||||
|
f"hidden_size={hidden_size}")
|
||||||
|
|
||||||
|
# Rebuild the model with the correct architecture
|
||||||
|
self.state_size = input_size
|
||||||
|
self.action_size = output_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
|
# Recreate networks with correct architecture
|
||||||
|
self.policy_net = DQN(self.state_size, self.action_size,
|
||||||
|
self.hidden_size, self.lstm_layers,
|
||||||
|
self.attention_heads).to(self.device)
|
||||||
|
self.target_net = DQN(self.state_size, self.action_size,
|
||||||
|
self.hidden_size, self.lstm_layers,
|
||||||
|
self.attention_heads).to(self.device)
|
||||||
|
|
||||||
|
# Recreate optimizer
|
||||||
|
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error checking model architecture: {e}")
|
||||||
|
|
||||||
|
# Load state dictionaries
|
||||||
self.policy_net.load_state_dict(checkpoint['policy_net'])
|
self.policy_net.load_state_dict(checkpoint['policy_net'])
|
||||||
self.target_net.load_state_dict(checkpoint['target_net'])
|
self.target_net.load_state_dict(checkpoint['target_net'])
|
||||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
|
||||||
self.epsilon = checkpoint['epsilon']
|
# Try to load optimizer state
|
||||||
self.steps_done = checkpoint['steps_done']
|
try:
|
||||||
|
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not load optimizer state: {e}")
|
||||||
|
|
||||||
|
# Load epsilon if available
|
||||||
|
if 'epsilon' in checkpoint:
|
||||||
|
self.epsilon = checkpoint['epsilon']
|
||||||
|
|
||||||
logger.info(f"Model loaded from {path}")
|
logger.info(f"Model loaded from {path}")
|
||||||
return True
|
except Exception as e:
|
||||||
logger.warning(f"No model found at {path}")
|
logger.warning(f"Error loading model with default method: {e}")
|
||||||
return False
|
|
||||||
|
# Try with weights_only=False as fallback
|
||||||
|
try:
|
||||||
|
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
|
||||||
|
self.policy_net.load_state_dict(checkpoint['policy_net'])
|
||||||
|
self.target_net.load_state_dict(checkpoint['target_net'])
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
|
except:
|
||||||
|
logger.warning("Could not load optimizer state")
|
||||||
|
|
||||||
|
if 'epsilon' in checkpoint:
|
||||||
|
self.epsilon = checkpoint['epsilon']
|
||||||
|
|
||||||
|
logger.info(f"Model loaded from {path} with weights_only=False")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load model: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
def add_chart_to_tensorboard(self, env, global_step):
|
def add_chart_to_tensorboard(self, env, global_step):
|
||||||
"""Add trading chart to TensorBoard"""
|
"""Add trading chart to TensorBoard"""
|
||||||
try:
|
try:
|
||||||
|
if len(env.data) < 10:
|
||||||
|
return
|
||||||
|
|
||||||
# Create chart image
|
# Create chart image
|
||||||
chart_img = create_candlestick_figure(
|
chart_img = create_candlestick_figure(
|
||||||
env.data,
|
env.data,
|
||||||
@ -1807,11 +1908,12 @@ class Agent:
|
|||||||
**Entry Price**: ${env.entry_price:.2f if env.entry_price else 0:.2f}
|
**Entry Price**: ${env.entry_price:.2f if env.entry_price else 0:.2f}
|
||||||
**Current Price**: ${env.data[-1]['close']:.2f}
|
**Current Price**: ${env.data[-1]['close']:.2f}
|
||||||
**Position Size**: ${env.position_size:.2f}
|
**Position Size**: ${env.position_size:.2f}
|
||||||
**Unrealized PnL**: ${env.unrealized_pnl:.2f}
|
**Unrealized PnL**: ${env.total_pnl:.2f}
|
||||||
"""
|
"""
|
||||||
self.writer.add_text('Position', position_info, global_step)
|
self.writer.add_text('Position', position_info, global_step)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error adding chart to TensorBoard: {str(e)}")
|
logger.error(f"Error adding chart to TensorBoard: {str(e)}")
|
||||||
|
# Continue without visualization rather than crashing
|
||||||
|
|
||||||
async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
|
async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
|
||||||
"""Get live price data using websockets"""
|
"""Get live price data using websockets"""
|
||||||
@ -1853,8 +1955,7 @@ async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
|
|||||||
|
|
||||||
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000):
|
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000):
|
||||||
"""Train the agent using historical and live data with GPU acceleration"""
|
"""Train the agent using historical and live data with GPU acceleration"""
|
||||||
logger.info(f"Starting training on device: {agent.device}")
|
# Initialize statistics tracking
|
||||||
|
|
||||||
stats = {
|
stats = {
|
||||||
'episode_rewards': [],
|
'episode_rewards': [],
|
||||||
'episode_lengths': [],
|
'episode_lengths': [],
|
||||||
@ -1867,147 +1968,164 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
|
|||||||
'trade_analysis': []
|
'trade_analysis': []
|
||||||
}
|
}
|
||||||
|
|
||||||
best_reward = -float('inf')
|
# Track best models
|
||||||
best_pnl = -float('inf')
|
best_reward = float('-inf')
|
||||||
|
best_pnl = float('-inf')
|
||||||
|
|
||||||
try:
|
# Initialize TensorBoard writer if not already initialized
|
||||||
# Initialize price predictor
|
if not hasattr(agent, 'writer') or agent.writer is None:
|
||||||
env.initialize_price_predictor(agent.device)
|
agent.writer = SummaryWriter('runs/training')
|
||||||
|
|
||||||
for episode in range(num_episodes):
|
# Training loop
|
||||||
try:
|
for episode in range(num_episodes):
|
||||||
# Reset environment
|
try:
|
||||||
state = env.reset()
|
# Reset environment
|
||||||
episode_reward = 0
|
state = env.reset()
|
||||||
env.episode_pnl = 0.0 # Reset episode PnL
|
episode_reward = 0
|
||||||
|
prediction_loss = 0
|
||||||
|
|
||||||
# Identify optimal trade points for this episode
|
# Episode loop
|
||||||
env.identify_optimal_trades()
|
for step in range(max_steps_per_episode):
|
||||||
|
# Select action
|
||||||
|
action = agent.select_action(state)
|
||||||
|
|
||||||
# Train price predictor
|
# Take action
|
||||||
prediction_loss = env.train_price_predictor()
|
try:
|
||||||
|
next_state, reward, done, info = env.step(action)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in step function: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
# Update price predictions
|
# Store transition in replay memory
|
||||||
env.update_price_predictions()
|
agent.memory.push(state, action, reward, next_state, done)
|
||||||
|
|
||||||
for step in range(max_steps_per_episode):
|
# Move to the next state
|
||||||
# Select action
|
state = next_state
|
||||||
action = agent.select_action(state)
|
|
||||||
|
|
||||||
# Take action
|
# Update episode reward
|
||||||
next_state, reward, done = env.step(action)
|
episode_reward += reward
|
||||||
|
|
||||||
# Store experience
|
# Learn from experience
|
||||||
agent.memory.push(state, action, reward, next_state, done)
|
if len(agent.memory) > BATCH_SIZE:
|
||||||
|
agent.learn()
|
||||||
|
|
||||||
state = next_state
|
# Update price predictions periodically
|
||||||
episode_reward += reward
|
if step % 50 == 0:
|
||||||
|
|
||||||
# Learn from experience with mixed precision
|
|
||||||
try:
|
try:
|
||||||
loss = agent.learn()
|
|
||||||
if loss is not None:
|
|
||||||
agent.writer.add_scalar('Loss/train', loss, agent.steps_done)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Learning error in episode {episode}, step {step}: {e}")
|
|
||||||
|
|
||||||
# Update price predictions periodically
|
|
||||||
if step % 10 == 0:
|
|
||||||
env.update_price_predictions()
|
env.update_price_predictions()
|
||||||
|
env.identify_optimal_trades()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error updating predictions: {e}")
|
||||||
|
|
||||||
# Add chart to TensorBoard periodically
|
# Add chart to TensorBoard periodically
|
||||||
if step % 50 == 0 or (step == max_steps_per_episode - 1) or done:
|
if step % 50 == 0 or (step == max_steps_per_episode - 1) or done:
|
||||||
|
try:
|
||||||
global_step = episode * max_steps_per_episode + step
|
global_step = episode * max_steps_per_episode + step
|
||||||
agent.add_chart_to_tensorboard(env, global_step)
|
agent.add_chart_to_tensorboard(env, global_step)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error adding chart to TensorBoard: {e}")
|
||||||
|
|
||||||
if done:
|
# End episode if done
|
||||||
break
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
# Update target network
|
# Update target network periodically
|
||||||
if episode % TARGET_UPDATE == 0:
|
if episode % TARGET_UPDATE == 0:
|
||||||
agent.target_net.load_state_dict(agent.policy_net.state_dict())
|
agent.update_target_network()
|
||||||
|
|
||||||
# Calculate win rate
|
# Calculate win rate
|
||||||
if len(env.trades) > 0:
|
total_trades = env.win_count + env.loss_count
|
||||||
wins = sum(1 for trade in env.trades if trade.get('pnl_percent', 0) > 0)
|
win_rate = (env.win_count / total_trades * 100) if total_trades > 0 else 0
|
||||||
win_rate = wins / len(env.trades) * 100
|
|
||||||
else:
|
|
||||||
win_rate = 0
|
|
||||||
|
|
||||||
# Analyze trades
|
# Train price predictor
|
||||||
|
try:
|
||||||
|
if episode % 5 == 0 and len(env.data) > 50:
|
||||||
|
prediction_loss = env.train_price_predictor()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error training price predictor: {e}")
|
||||||
|
prediction_loss = 0
|
||||||
|
|
||||||
|
# Analyze trades
|
||||||
|
try:
|
||||||
trade_analysis = env.analyze_trades()
|
trade_analysis = env.analyze_trades()
|
||||||
stats['trade_analysis'].append(trade_analysis)
|
stats['trade_analysis'].append(trade_analysis)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error analyzing trades: {e}")
|
||||||
|
trade_analysis = {}
|
||||||
|
stats['trade_analysis'].append({})
|
||||||
|
|
||||||
# Calculate prediction accuracy
|
# Calculate prediction accuracy
|
||||||
prediction_accuracy = 0.0
|
prediction_accuracy = 0.0
|
||||||
|
try:
|
||||||
if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0:
|
if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0:
|
||||||
if len(env.data) > 5:
|
if len(env.data) > 5:
|
||||||
actual_prices = [candle['close'] for candle in env.data[-5:]]
|
actual_prices = [candle['close'] for candle in env.data[-5:]]
|
||||||
predicted = env.predicted_prices[:min(5, len(actual_prices))]
|
predicted = env.predicted_prices[:min(5, len(actual_prices))]
|
||||||
errors = [abs(p - a) / a for p, a in zip(predicted, actual_prices[:len(predicted)])]
|
errors = [abs(p - a) / a for p, a in zip(predicted, actual_prices[:len(predicted)])]
|
||||||
prediction_accuracy = 100 * (1 - sum(errors) / len(errors))
|
prediction_accuracy = 100 * (1 - sum(errors) / len(errors))
|
||||||
|
|
||||||
# Log statistics
|
|
||||||
stats['episode_rewards'].append(episode_reward)
|
|
||||||
stats['episode_lengths'].append(step + 1)
|
|
||||||
stats['balances'].append(env.balance)
|
|
||||||
stats['win_rates'].append(win_rate)
|
|
||||||
stats['episode_pnls'].append(env.episode_pnl)
|
|
||||||
stats['cumulative_pnl'].append(env.total_pnl)
|
|
||||||
stats['drawdowns'].append(env.max_drawdown * 100)
|
|
||||||
stats['prediction_accuracy'].append(prediction_accuracy)
|
|
||||||
|
|
||||||
# Log detailed trade analysis
|
|
||||||
if trade_analysis:
|
|
||||||
logger.info(f"Trade Analysis: Win Rate={trade_analysis.get('uptrend_win_rate', 0):.1f}% in uptrends, "
|
|
||||||
f"{trade_analysis.get('downtrend_win_rate', 0):.1f}% in downtrends | "
|
|
||||||
f"Avg Win=${trade_analysis.get('avg_win', 0):.2f}, Avg Loss=${trade_analysis.get('avg_loss', 0):.2f}")
|
|
||||||
|
|
||||||
# Log to TensorBoard
|
|
||||||
agent.writer.add_scalar('Reward/train', episode_reward, episode)
|
|
||||||
agent.writer.add_scalar('Balance/train', env.balance, episode)
|
|
||||||
agent.writer.add_scalar('WinRate/train', win_rate, episode)
|
|
||||||
agent.writer.add_scalar('PnL/episode', env.episode_pnl, episode)
|
|
||||||
agent.writer.add_scalar('PnL/cumulative', env.total_pnl, episode)
|
|
||||||
agent.writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode)
|
|
||||||
agent.writer.add_scalar('PredictionLoss', prediction_loss, episode)
|
|
||||||
agent.writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode)
|
|
||||||
|
|
||||||
# Add final chart for this episode
|
|
||||||
agent.add_chart_to_tensorboard(env, (episode + 1) * max_steps_per_episode)
|
|
||||||
|
|
||||||
logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, "
|
|
||||||
f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, "
|
|
||||||
f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}, "
|
|
||||||
f"Max Drawdown={env.max_drawdown*100:.1f}%, Pred Accuracy={prediction_accuracy:.1f}%")
|
|
||||||
|
|
||||||
# Save best model by reward
|
|
||||||
if episode_reward > best_reward:
|
|
||||||
best_reward = episode_reward
|
|
||||||
agent.save("models/trading_agent_best_reward.pt")
|
|
||||||
|
|
||||||
# Save best model by PnL
|
|
||||||
if env.episode_pnl > best_pnl:
|
|
||||||
best_pnl = env.episode_pnl
|
|
||||||
agent.save("models/trading_agent_best_pnl.pt")
|
|
||||||
|
|
||||||
# Save checkpoint
|
|
||||||
if episode % 10 == 0:
|
|
||||||
agent.save(f"models/trading_agent_episode_{episode}.pt")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in episode {episode}: {e}")
|
logger.warning(f"Error calculating prediction accuracy: {e}")
|
||||||
continue
|
|
||||||
|
|
||||||
# Save final model
|
# Log statistics
|
||||||
agent.save("models/trading_agent_final.pt")
|
stats['episode_rewards'].append(episode_reward)
|
||||||
|
stats['episode_lengths'].append(step + 1)
|
||||||
|
stats['balances'].append(env.balance)
|
||||||
|
stats['win_rates'].append(win_rate)
|
||||||
|
stats['episode_pnls'].append(env.episode_pnl)
|
||||||
|
stats['cumulative_pnl'].append(env.total_pnl)
|
||||||
|
stats['drawdowns'].append(env.max_drawdown * 100)
|
||||||
|
stats['prediction_accuracy'].append(prediction_accuracy)
|
||||||
|
|
||||||
# Plot training results
|
# Log detailed trade analysis
|
||||||
plot_training_results(stats)
|
if trade_analysis:
|
||||||
|
logger.info(f"Trade Analysis: Win Rate={trade_analysis.get('uptrend_win_rate', 0):.1f}% in uptrends, "
|
||||||
|
f"{trade_analysis.get('downtrend_win_rate', 0):.1f}% in downtrends | "
|
||||||
|
f"Avg Win=${trade_analysis.get('avg_win', 0):.2f}, Avg Loss=${trade_analysis.get('avg_loss', 0):.2f}")
|
||||||
|
|
||||||
except Exception as e:
|
# Log to TensorBoard
|
||||||
logger.error(f"Training failed: {e}")
|
agent.writer.add_scalar('Reward/train', episode_reward, episode)
|
||||||
raise
|
agent.writer.add_scalar('Balance/train', env.balance, episode)
|
||||||
|
agent.writer.add_scalar('WinRate/train', win_rate, episode)
|
||||||
|
agent.writer.add_scalar('PnL/episode', env.episode_pnl, episode)
|
||||||
|
agent.writer.add_scalar('PnL/cumulative', env.total_pnl, episode)
|
||||||
|
agent.writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode)
|
||||||
|
agent.writer.add_scalar('PredictionLoss', prediction_loss, episode)
|
||||||
|
agent.writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode)
|
||||||
|
|
||||||
|
# Add final chart for this episode
|
||||||
|
try:
|
||||||
|
agent.add_chart_to_tensorboard(env, (episode + 1) * max_steps_per_episode)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error adding final chart: {e}")
|
||||||
|
|
||||||
|
logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, "
|
||||||
|
f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, "
|
||||||
|
f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}, "
|
||||||
|
f"Max Drawdown={env.max_drawdown*100:.1f}%, Pred Accuracy={prediction_accuracy:.1f}%")
|
||||||
|
|
||||||
|
# Save best model by reward
|
||||||
|
if episode_reward > best_reward:
|
||||||
|
best_reward = episode_reward
|
||||||
|
agent.save("models/trading_agent_best_reward.pt")
|
||||||
|
|
||||||
|
# Save best model by PnL
|
||||||
|
if env.episode_pnl > best_pnl:
|
||||||
|
best_pnl = env.episode_pnl
|
||||||
|
agent.save("models/trading_agent_best_pnl.pt")
|
||||||
|
|
||||||
|
# Save checkpoint
|
||||||
|
if episode % 10 == 0:
|
||||||
|
agent.save(f"models/trading_agent_episode_{episode}.pt")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in episode {episode}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Save final model
|
||||||
|
agent.save("models/trading_agent_final.pt")
|
||||||
|
|
||||||
|
# Plot training results
|
||||||
|
plot_training_results(stats)
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
@ -2082,7 +2200,7 @@ def evaluate_agent(agent, env, num_episodes=10):
|
|||||||
while not done:
|
while not done:
|
||||||
# Select action (no exploration)
|
# Select action (no exploration)
|
||||||
action = agent.select_action(state, training=False)
|
action = agent.select_action(state, training=False)
|
||||||
next_state, reward, done = env.step(action)
|
next_state, reward, done, info = env.step(action)
|
||||||
|
|
||||||
state = next_state
|
state = next_state
|
||||||
episode_reward += reward
|
episode_reward += reward
|
||||||
@ -2150,7 +2268,7 @@ async def test_training():
|
|||||||
action = agent.select_action(state)
|
action = agent.select_action(state)
|
||||||
|
|
||||||
# Take action
|
# Take action
|
||||||
next_state, reward, done = env.step(action)
|
next_state, reward, done, info = env.step(action)
|
||||||
|
|
||||||
# Store experience
|
# Store experience
|
||||||
agent.memory.push(state, action, reward, next_state, done)
|
agent.memory.push(state, action, reward, next_state, done)
|
||||||
@ -2432,20 +2550,27 @@ async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
"""Main function to run the trading bot"""
|
parser = argparse.ArgumentParser(description='Trading Bot')
|
||||||
parser = argparse.ArgumentParser(description='Crypto Trading Bot')
|
parser.add_argument('--mode', type=str, choices=['train', 'eval', 'live'], default='train',
|
||||||
parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'live'],
|
help='Operation mode: train, eval, or live')
|
||||||
help='Mode to run the bot in')
|
parser.add_argument('--episodes', type=int, default=1000,
|
||||||
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
|
help='Number of episodes for training or evaluation')
|
||||||
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
|
parser.add_argument('--demo', type=str, choices=['true', 'false'], default='true',
|
||||||
parser.add_argument('--live', action='store_true', help='Run in live trading mode')
|
help='Run in demo mode (paper trading) if true')
|
||||||
parser.add_argument('--real', action='store_true', help='Execute real trades (default is demo/paper trading)')
|
parser.add_argument('--symbol', type=str, default='ETH/USDT',
|
||||||
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
|
help='Trading pair symbol')
|
||||||
parser.add_argument('--timeframe', type=str, default='1m', help='Candle timeframe')
|
parser.add_argument('--timeframe', type=str, default='1m',
|
||||||
parser.add_argument('--leverage', type=int, default=50, help='Leverage for futures trading')
|
help='Candle timeframe (1m, 5m, 15m, 1h, etc.)')
|
||||||
parser.add_argument('--model', type=str, help='Path to model file')
|
parser.add_argument('--leverage', type=int, default=50,
|
||||||
|
help='Leverage for futures trading')
|
||||||
|
parser.add_argument('--model', type=str, default=None,
|
||||||
|
help='Path to model file for evaluation or live trading')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Convert string boolean to actual boolean
|
||||||
|
demo_mode = args.demo.lower() == 'true'
|
||||||
|
|
||||||
# Get device (GPU or CPU)
|
# Get device (GPU or CPU)
|
||||||
device = get_device()
|
device = get_device()
|
||||||
|
|
||||||
@ -2456,7 +2581,7 @@ async def main():
|
|||||||
exchange = await initialize_exchange()
|
exchange = await initialize_exchange()
|
||||||
|
|
||||||
# Create environment
|
# Create environment
|
||||||
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=args.demo)
|
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode)
|
||||||
|
|
||||||
# Fetch initial data
|
# Fetch initial data
|
||||||
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
|
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
|
||||||
@ -2478,46 +2603,34 @@ async def main():
|
|||||||
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env)
|
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env)
|
||||||
|
|
||||||
elif args.mode == 'live':
|
elif args.mode == 'live':
|
||||||
# Add these arguments to the parser
|
# Initialize exchange
|
||||||
parser.add_argument('--live', action='store_true', help='Run in live trading mode')
|
exchange = await initialize_exchange()
|
||||||
parser.add_argument('--real', action='store_true', help='Execute real trades (default is demo/paper trading)')
|
|
||||||
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
|
|
||||||
parser.add_argument('--timeframe', type=str, default='1m', help='Candle timeframe')
|
|
||||||
parser.add_argument('--leverage', type=int, default=50, help='Leverage for futures trading')
|
|
||||||
parser.add_argument('--model', type=str, help='Path to model file')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# In the main function, add this section to handle live trading
|
# Load model
|
||||||
if args.live:
|
model_path = args.model if args.model else "models/trading_agent.pt"
|
||||||
# Initialize exchange
|
if not os.path.exists(model_path):
|
||||||
exchange = await initialize_exchange()
|
logger.error(f"Model file not found: {model_path}")
|
||||||
|
return
|
||||||
|
|
||||||
# Load the trained agent
|
# Initialize environment
|
||||||
model_path = args.model if args.model else "models/trading_agent.pt"
|
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=WINDOW_SIZE, demo=demo_mode)
|
||||||
if not os.path.exists(model_path):
|
await env.fetch_initial_data(exchange, symbol=args.symbol, timeframe=args.timeframe)
|
||||||
logger.error(f"Model file not found: {model_path}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Initialize environment with historical data
|
# Initialize agent
|
||||||
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=WINDOW_SIZE, demo=not args.real)
|
state_size = env.get_state().shape[0]
|
||||||
await env.fetch_initial_data(exchange, symbol=args.symbol, timeframe=args.timeframe)
|
agent = Agent(state_size=state_size, action_size=3)
|
||||||
|
agent.load(model_path)
|
||||||
|
|
||||||
# Initialize agent
|
# Start live trading
|
||||||
state_size = env.get_state().shape[0]
|
await live_trading(
|
||||||
agent = Agent(state_size=state_size, action_size=3)
|
agent=agent,
|
||||||
agent.load(model_path)
|
env=env,
|
||||||
logger.info(f"Loaded model from {model_path}")
|
exchange=exchange,
|
||||||
|
symbol=args.symbol,
|
||||||
# Start live trading
|
timeframe=args.timeframe,
|
||||||
await live_trading(
|
demo=demo_mode,
|
||||||
agent=agent,
|
leverage=args.leverage
|
||||||
env=env,
|
)
|
||||||
exchange=exchange,
|
|
||||||
symbol=args.symbol,
|
|
||||||
timeframe=args.timeframe,
|
|
||||||
demo=not args.real,
|
|
||||||
leverage=args.leverage
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up exchange connection - safely close if possible
|
# Clean up exchange connection - safely close if possible
|
||||||
@ -2552,30 +2665,37 @@ def create_candlestick_figure(data, trade_signals, window_size=100, title=""):
|
|||||||
price_ax = plt.subplot(gs[0])
|
price_ax = plt.subplot(gs[0])
|
||||||
volume_ax = plt.subplot(gs[1], sharex=price_ax)
|
volume_ax = plt.subplot(gs[1], sharex=price_ax)
|
||||||
|
|
||||||
# Plot candlesticks
|
# Plot candlesticks - use a simpler approach if mplfinance fails
|
||||||
mpf.plot(df, type='candle', style='yahoo', ax=price_ax, volume=volume_ax)
|
try:
|
||||||
|
mpf.plot(df, type='candle', style='yahoo', ax=price_ax, volume=volume_ax)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error plotting with mplfinance: {e}, falling back to simple plot")
|
||||||
|
# Fallback to simple plot
|
||||||
|
price_ax.plot(df.index, df['close'], label='Price')
|
||||||
|
volume_ax.bar(df.index, df['volume'], color='blue', alpha=0.5)
|
||||||
|
|
||||||
# Add trade signals
|
# Add trade signals
|
||||||
for signal in trade_signals:
|
for signal in trade_signals:
|
||||||
if signal['timestamp'] not in df.index:
|
try:
|
||||||
|
timestamp = pd.to_datetime(signal['timestamp'], unit='ms')
|
||||||
|
price = signal['price']
|
||||||
|
|
||||||
|
if signal['type'] == 'buy':
|
||||||
|
price_ax.plot(timestamp, price, '^', color='green', markersize=10)
|
||||||
|
elif signal['type'] == 'sell':
|
||||||
|
price_ax.plot(timestamp, price, 'v', color='red', markersize=10)
|
||||||
|
elif signal['type'] == 'close_long':
|
||||||
|
price_ax.plot(timestamp, price, 'x', color='gold', markersize=10)
|
||||||
|
elif signal['type'] == 'close_short':
|
||||||
|
price_ax.plot(timestamp, price, 'x', color='black', markersize=10)
|
||||||
|
elif 'stop_loss' in signal['type']:
|
||||||
|
price_ax.plot(timestamp, price, 'X', color='purple', markersize=10)
|
||||||
|
elif 'take_profit' in signal['type']:
|
||||||
|
price_ax.plot(timestamp, price, '*', color='cyan', markersize=10)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error plotting signal: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
timestamp = pd.to_datetime(signal['timestamp'], unit='ms')
|
|
||||||
price = signal['price']
|
|
||||||
|
|
||||||
if signal['type'] == 'buy':
|
|
||||||
price_ax.plot(timestamp, price, '^', color='green', markersize=10, label='Buy')
|
|
||||||
elif signal['type'] == 'sell':
|
|
||||||
price_ax.plot(timestamp, price, 'v', color='red', markersize=10, label='Sell')
|
|
||||||
elif signal['type'] == 'close_long':
|
|
||||||
price_ax.plot(timestamp, price, 'x', color='gold', markersize=10, label='Close Long')
|
|
||||||
elif signal['type'] == 'close_short':
|
|
||||||
price_ax.plot(timestamp, price, 'x', color='black', markersize=10, label='Close Short')
|
|
||||||
elif 'stop_loss' in signal['type']:
|
|
||||||
price_ax.plot(timestamp, price, 'X', color='purple', markersize=10, label='Stop Loss')
|
|
||||||
elif 'take_profit' in signal['type']:
|
|
||||||
price_ax.plot(timestamp, price, '*', color='cyan', markersize=10, label='Take Profit')
|
|
||||||
|
|
||||||
# Add balance and PnL annotation
|
# Add balance and PnL annotation
|
||||||
if trade_signals and 'balance' in trade_signals[-1] and 'pnl' in trade_signals[-1]:
|
if trade_signals and 'balance' in trade_signals[-1] and 'pnl' in trade_signals[-1]:
|
||||||
balance = trade_signals[-1]['balance']
|
balance = trade_signals[-1]['balance']
|
||||||
@ -2592,6 +2712,7 @@ def create_candlestick_figure(data, trade_signals, window_size=100, title=""):
|
|||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
fig.savefig(buf, format='png')
|
fig.savefig(buf, format='png')
|
||||||
buf.seek(0)
|
buf.seek(0)
|
||||||
|
plt.close(fig)
|
||||||
img = Image.open(buf)
|
img = Image.open(buf)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
numpy>=1.21.0
|
numpy>=1.21.0
|
||||||
pandas>=1.3.0
|
pandas>=1.3.0
|
||||||
matplotlib>=3.4.0
|
matplotlib>=3.4.0
|
||||||
|
mplfinance>=0.12.7
|
||||||
torch>=1.9.0
|
torch>=1.9.0
|
||||||
python-dotenv>=0.19.0
|
python-dotenv>=0.19.0
|
||||||
ccxt>=2.0.0
|
ccxt>=2.0.0
|
||||||
websockets>=10.0
|
websockets>=10.0
|
||||||
tensorboard>=2.6.0
|
tensorboard>=2.6.0
|
||||||
scikit-learn
|
scikit-learn>=1.0.0
|
||||||
|
Pillow>=9.0.0
|
||||||
|
asyncio>=3.4.3
|
477
crypto/gogo2/run_live_demo.py
Normal file
477
crypto/gogo2/run_live_demo.py
Normal file
@ -0,0 +1,477 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import random
|
||||||
|
import datetime
|
||||||
|
import torch
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import io
|
||||||
|
from PIL import Image
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler("live_trading.log"),
|
||||||
|
logging.StreamHandler()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger("live_trading")
|
||||||
|
|
||||||
|
def generate_mock_data(symbol, timeframe, limit=1000):
|
||||||
|
"""Generate mock OHLCV data for demo mode"""
|
||||||
|
logger.info(f"Generating mock data for {symbol} ({timeframe})")
|
||||||
|
|
||||||
|
# Set seed for reproducibility
|
||||||
|
np.random.seed(42)
|
||||||
|
|
||||||
|
# Generate timestamps
|
||||||
|
end_time = datetime.datetime.now()
|
||||||
|
start_time = end_time - datetime.timedelta(minutes=limit)
|
||||||
|
timestamps = [start_time + datetime.timedelta(minutes=i) for i in range(limit)]
|
||||||
|
|
||||||
|
# Convert to milliseconds
|
||||||
|
timestamps_ms = [int(ts.timestamp() * 1000) for ts in timestamps]
|
||||||
|
|
||||||
|
# Generate price data with realistic patterns
|
||||||
|
base_price = 3000.0 # Starting price
|
||||||
|
price_data = []
|
||||||
|
current_price = base_price
|
||||||
|
|
||||||
|
for i in range(limit):
|
||||||
|
# Random walk with momentum and volatility clusters
|
||||||
|
momentum = np.random.normal(0, 1)
|
||||||
|
volatility = 0.5 + 0.5 * np.sin(i / 100) # Cyclical volatility
|
||||||
|
|
||||||
|
# Price change with momentum and volatility
|
||||||
|
price_change = momentum * volatility * current_price * 0.005
|
||||||
|
|
||||||
|
# Add some trends and patterns
|
||||||
|
if i % 200 < 100: # Uptrend for 100 candles, then downtrend
|
||||||
|
price_change += current_price * 0.001
|
||||||
|
else:
|
||||||
|
price_change -= current_price * 0.0008
|
||||||
|
|
||||||
|
# Update current price
|
||||||
|
current_price += price_change
|
||||||
|
|
||||||
|
# Generate OHLCV data
|
||||||
|
open_price = current_price
|
||||||
|
close_price = current_price + np.random.normal(0, 1) * current_price * 0.002
|
||||||
|
high_price = max(open_price, close_price) + abs(np.random.normal(0, 1)) * current_price * 0.003
|
||||||
|
low_price = min(open_price, close_price) - abs(np.random.normal(0, 1)) * current_price * 0.003
|
||||||
|
volume = np.random.gamma(2, 100) * (1 + 0.5 * np.sin(i / 50)) # Cyclical volume
|
||||||
|
|
||||||
|
# Store data
|
||||||
|
price_data.append({
|
||||||
|
'timestamp': timestamps_ms[i],
|
||||||
|
'open': open_price,
|
||||||
|
'high': high_price,
|
||||||
|
'low': low_price,
|
||||||
|
'close': close_price,
|
||||||
|
'volume': volume
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(f"Generated {len(price_data)} mock candles")
|
||||||
|
return price_data
|
||||||
|
|
||||||
|
async def generate_mock_live_candles(initial_data, symbol, timeframe):
|
||||||
|
"""Generate mock live candles based on initial data"""
|
||||||
|
last_candle = initial_data[-1].copy()
|
||||||
|
last_timestamp = last_candle['timestamp']
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# Wait for next candle
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
# Update timestamp
|
||||||
|
if timeframe == '1m':
|
||||||
|
last_timestamp += 60 * 1000 # 1 minute in milliseconds
|
||||||
|
elif timeframe == '5m':
|
||||||
|
last_timestamp += 5 * 60 * 1000
|
||||||
|
elif timeframe == '15m':
|
||||||
|
last_timestamp += 15 * 60 * 1000
|
||||||
|
elif timeframe == '1h':
|
||||||
|
last_timestamp += 60 * 60 * 1000
|
||||||
|
else:
|
||||||
|
last_timestamp += 60 * 1000 # Default to 1 minute
|
||||||
|
|
||||||
|
# Generate new candle
|
||||||
|
last_price = last_candle['close']
|
||||||
|
price_change = np.random.normal(0, 1) * last_price * 0.002
|
||||||
|
|
||||||
|
# Add some persistence
|
||||||
|
if last_candle['close'] > last_candle['open']:
|
||||||
|
# Previous candle was green, more likely to continue up
|
||||||
|
price_change += last_price * 0.0005
|
||||||
|
else:
|
||||||
|
# Previous candle was red, more likely to continue down
|
||||||
|
price_change -= last_price * 0.0005
|
||||||
|
|
||||||
|
# Generate OHLCV data
|
||||||
|
open_price = last_price
|
||||||
|
close_price = last_price + price_change
|
||||||
|
high_price = max(open_price, close_price) + abs(np.random.normal(0, 1)) * last_price * 0.001
|
||||||
|
low_price = min(open_price, close_price) - abs(np.random.normal(0, 1)) * last_price * 0.001
|
||||||
|
volume = np.random.gamma(2, 100)
|
||||||
|
|
||||||
|
# Create new candle
|
||||||
|
new_candle = {
|
||||||
|
'timestamp': last_timestamp,
|
||||||
|
'open': open_price,
|
||||||
|
'high': high_price,
|
||||||
|
'low': low_price,
|
||||||
|
'close': close_price,
|
||||||
|
'volume': volume
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update last candle
|
||||||
|
last_candle = new_candle.copy()
|
||||||
|
|
||||||
|
yield new_candle
|
||||||
|
|
||||||
|
class MockExchange:
|
||||||
|
"""Mock exchange for demo mode"""
|
||||||
|
def __init__(self):
|
||||||
|
self.name = "MockExchange"
|
||||||
|
self.id = "mock"
|
||||||
|
|
||||||
|
async def fetch_ohlcv(self, symbol, timeframe, limit=1000):
|
||||||
|
"""Mock method to fetch OHLCV data"""
|
||||||
|
# Generate mock data
|
||||||
|
mock_data = generate_mock_data(symbol, timeframe, limit)
|
||||||
|
|
||||||
|
# Convert to CCXT format
|
||||||
|
ohlcv = []
|
||||||
|
for candle in mock_data:
|
||||||
|
ohlcv.append([
|
||||||
|
candle['timestamp'],
|
||||||
|
candle['open'],
|
||||||
|
candle['high'],
|
||||||
|
candle['low'],
|
||||||
|
candle['close'],
|
||||||
|
candle['volume']
|
||||||
|
])
|
||||||
|
|
||||||
|
return ohlcv
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Mock method to close exchange connection"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_model_info(model_path):
|
||||||
|
"""Extract model architecture information from saved model file"""
|
||||||
|
try:
|
||||||
|
# Load checkpoint with weights_only=False to get all information
|
||||||
|
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
|
||||||
|
|
||||||
|
# Extract model parameters
|
||||||
|
state_size = checkpoint['policy_net']['fc1.weight'].shape[1]
|
||||||
|
action_size = checkpoint['policy_net']['advantage_stream.bias'].shape[0]
|
||||||
|
hidden_size = checkpoint['policy_net']['fc1.weight'].shape[0]
|
||||||
|
|
||||||
|
# Try to extract LSTM layers and attention heads
|
||||||
|
lstm_layers = 2 # Default
|
||||||
|
attention_heads = 4 # Default
|
||||||
|
|
||||||
|
# Check if these parameters are stored in the checkpoint
|
||||||
|
if 'lstm_layers' in checkpoint:
|
||||||
|
lstm_layers = checkpoint['lstm_layers']
|
||||||
|
if 'attention_heads' in checkpoint:
|
||||||
|
attention_heads = checkpoint['attention_heads']
|
||||||
|
|
||||||
|
logger.info(f"Extracted model architecture: state_size={state_size}, action_size={action_size}, "
|
||||||
|
f"hidden_size={hidden_size}, lstm_layers={lstm_layers}, attention_heads={attention_heads}")
|
||||||
|
|
||||||
|
return state_size, action_size, hidden_size, lstm_layers, attention_heads
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to extract model info: {str(e)}")
|
||||||
|
logger.warning("Using default model architecture")
|
||||||
|
return 40, 3, 384, 2, 4 # Default values
|
||||||
|
|
||||||
|
async def run_live_demo():
|
||||||
|
"""Run the trading bot in live demo mode with enhanced error handling"""
|
||||||
|
parser = argparse.ArgumentParser(description='Live Trading Demo')
|
||||||
|
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
|
||||||
|
parser.add_argument('--timeframe', type=str, default='1m', help='Candle timeframe')
|
||||||
|
parser.add_argument('--model', type=str, default='models/trading_agent_best_pnl.pt', help='Path to model file')
|
||||||
|
parser.add_argument('--mock', action='store_true', help='Use mock data instead of real exchange data')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Import main module
|
||||||
|
import main
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Create directories if they don't exist
|
||||||
|
os.makedirs("trade_logs", exist_ok=True)
|
||||||
|
os.makedirs("runs", exist_ok=True)
|
||||||
|
|
||||||
|
# Check if model file exists
|
||||||
|
if not os.path.exists(args.model):
|
||||||
|
logger.error(f"Model file not found: {args.model}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
logger.info(f"Starting live trading demo for {args.symbol} on {args.timeframe} timeframe")
|
||||||
|
logger.info(f"Using model: {args.model}")
|
||||||
|
|
||||||
|
# Check API keys
|
||||||
|
api_key = os.getenv('MEXC_API_KEY')
|
||||||
|
secret_key = os.getenv('MEXC_SECRET_KEY')
|
||||||
|
use_mock = args.mock or not api_key or api_key == "your_api_key_here"
|
||||||
|
|
||||||
|
if use_mock:
|
||||||
|
logger.info("Using mock data for demo mode (no API keys required)")
|
||||||
|
exchange = MockExchange()
|
||||||
|
else:
|
||||||
|
# Initialize real exchange
|
||||||
|
exchange = await main.initialize_exchange()
|
||||||
|
|
||||||
|
# Initialize environment
|
||||||
|
env = main.TradingEnvironment(
|
||||||
|
initial_balance=float(os.getenv('INITIAL_BALANCE', 1000)),
|
||||||
|
window_size=30,
|
||||||
|
demo=True # Always use demo mode in this script
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fetch initial data
|
||||||
|
if use_mock:
|
||||||
|
# Use mock data
|
||||||
|
mock_data = generate_mock_data(args.symbol, args.timeframe, 1000)
|
||||||
|
env.data = mock_data
|
||||||
|
env._initialize_features()
|
||||||
|
success = True
|
||||||
|
else:
|
||||||
|
# Fetch real data
|
||||||
|
success = await env.fetch_initial_data(
|
||||||
|
exchange,
|
||||||
|
symbol=args.symbol,
|
||||||
|
timeframe=args.timeframe,
|
||||||
|
limit=1000
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
logger.error("Failed to fetch initial data. Exiting.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# Get model architecture from saved model
|
||||||
|
state_size, action_size, hidden_size, lstm_layers, attention_heads = get_model_info(args.model)
|
||||||
|
|
||||||
|
# Initialize agent with the correct architecture
|
||||||
|
agent = main.Agent(
|
||||||
|
state_size=state_size,
|
||||||
|
action_size=action_size,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
lstm_layers=lstm_layers,
|
||||||
|
attention_heads=attention_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load model with weights_only=False to handle numpy types
|
||||||
|
try:
|
||||||
|
# First try with weights_only=True (safer)
|
||||||
|
agent.load(args.model)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load model with weights_only=True: {str(e)}")
|
||||||
|
|
||||||
|
# Try with safe_globals
|
||||||
|
try:
|
||||||
|
import torch.serialization
|
||||||
|
with torch.serialization.safe_globals(['numpy._core.multiarray.scalar', 'numpy.dtype']):
|
||||||
|
agent.load(args.model)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed with safe_globals: {str(e)}")
|
||||||
|
|
||||||
|
# Last resort: try with weights_only=False
|
||||||
|
try:
|
||||||
|
# Monkey patch the load method temporarily
|
||||||
|
original_load = main.Agent.load
|
||||||
|
|
||||||
|
def patched_load(self, path):
|
||||||
|
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
|
||||||
|
self.policy_net.load_state_dict(checkpoint['policy_net'])
|
||||||
|
self.target_net.load_state_dict(checkpoint['target_net'])
|
||||||
|
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
|
self.epsilon = checkpoint.get('epsilon', 0.05)
|
||||||
|
logger.info(f"Model loaded from {path}")
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
main.Agent.load = patched_load
|
||||||
|
|
||||||
|
# Try loading
|
||||||
|
agent.load(args.model)
|
||||||
|
|
||||||
|
# Restore original method
|
||||||
|
main.Agent.load = original_load
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"All loading attempts failed: {str(e)}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
logger.info(f"Model loaded successfully")
|
||||||
|
|
||||||
|
# Initialize TensorBoard writer
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
agent.writer = SummaryWriter(f'runs/live_demo_{args.symbol.replace("/", "_")}_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}')
|
||||||
|
|
||||||
|
# Track performance metrics
|
||||||
|
trades_count = 0
|
||||||
|
winning_trades = 0
|
||||||
|
total_profit = 0
|
||||||
|
max_drawdown = 0
|
||||||
|
peak_balance = env.balance
|
||||||
|
step_counter = 0
|
||||||
|
prev_position = 'flat'
|
||||||
|
|
||||||
|
# Create trade log file
|
||||||
|
os.makedirs('trade_logs', exist_ok=True)
|
||||||
|
trade_log_path = f'trade_logs/trades_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}.csv'
|
||||||
|
with open(trade_log_path, 'w') as f:
|
||||||
|
f.write("timestamp,action,price,position_size,balance,pnl\n")
|
||||||
|
|
||||||
|
logger.info(f"Starting live trading simulation...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Set up mock live data generator if using mock data
|
||||||
|
if use_mock:
|
||||||
|
live_candle_generator = generate_mock_live_candles(env.data, args.symbol, args.timeframe)
|
||||||
|
|
||||||
|
# Main trading loop
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Get latest candle
|
||||||
|
if use_mock:
|
||||||
|
# Get mock candle
|
||||||
|
candle = await anext(live_candle_generator)
|
||||||
|
else:
|
||||||
|
# Get real candle
|
||||||
|
candle = await main.get_latest_candle(exchange, args.symbol)
|
||||||
|
|
||||||
|
if candle is None:
|
||||||
|
logger.warning("Failed to fetch latest candle, retrying in 5 seconds...")
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Add new data to environment
|
||||||
|
env.add_data(candle)
|
||||||
|
|
||||||
|
# Get current state and select action
|
||||||
|
state = env.get_state()
|
||||||
|
action = agent.select_action(state, training=False)
|
||||||
|
|
||||||
|
# Update environment with action (simulated)
|
||||||
|
next_state, reward, done = env.step(action)
|
||||||
|
|
||||||
|
# Create info dictionary (missing in the step function)
|
||||||
|
info = {
|
||||||
|
'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
|
||||||
|
'price': env.current_price,
|
||||||
|
'balance': env.balance,
|
||||||
|
'position': env.position,
|
||||||
|
'pnl': env.total_pnl
|
||||||
|
}
|
||||||
|
|
||||||
|
# Log trade if position changed
|
||||||
|
if env.position != prev_position:
|
||||||
|
trades_count += 1
|
||||||
|
if env.last_trade_profit > 0:
|
||||||
|
winning_trades += 1
|
||||||
|
total_profit += env.last_trade_profit
|
||||||
|
|
||||||
|
# Log trade details
|
||||||
|
with open(trade_log_path, 'a') as f:
|
||||||
|
f.write(f"{datetime.datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
|
||||||
|
|
||||||
|
logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
|
||||||
|
|
||||||
|
# Update performance metrics
|
||||||
|
if env.balance > peak_balance:
|
||||||
|
peak_balance = env.balance
|
||||||
|
current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
|
||||||
|
if current_drawdown > max_drawdown:
|
||||||
|
max_drawdown = current_drawdown
|
||||||
|
|
||||||
|
# Update TensorBoard metrics
|
||||||
|
step_counter += 1
|
||||||
|
agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
|
||||||
|
agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
|
||||||
|
agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter)
|
||||||
|
|
||||||
|
# Update chart visualization
|
||||||
|
if step_counter % 5 == 0 or env.position != prev_position:
|
||||||
|
agent.add_chart_to_tensorboard(env, step_counter)
|
||||||
|
|
||||||
|
# Log performance summary
|
||||||
|
if trades_count > 0:
|
||||||
|
win_rate = (winning_trades / trades_count) * 100
|
||||||
|
agent.writer.add_scalar('Live/WinRate', win_rate, step_counter)
|
||||||
|
|
||||||
|
performance_text = f"""
|
||||||
|
**Live Trading Performance**
|
||||||
|
Balance: ${env.balance:.2f}
|
||||||
|
Total PnL: ${env.total_pnl:.2f}
|
||||||
|
Trades: {trades_count}
|
||||||
|
Win Rate: {win_rate:.1f}%
|
||||||
|
Max Drawdown: {max_drawdown*100:.1f}%
|
||||||
|
"""
|
||||||
|
agent.writer.add_text('Performance', performance_text, step_counter)
|
||||||
|
|
||||||
|
prev_position = env.position
|
||||||
|
|
||||||
|
# Wait for next candle
|
||||||
|
await asyncio.sleep(1) # Faster updates in demo mode
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in live trading loop: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
logger.info("Continuing after error...")
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Live trading stopped by user")
|
||||||
|
|
||||||
|
# Final performance report
|
||||||
|
if trades_count > 0:
|
||||||
|
win_rate = (winning_trades / trades_count) * 100
|
||||||
|
logger.info(f"Trading session summary:")
|
||||||
|
logger.info(f"Total trades: {trades_count}")
|
||||||
|
logger.info(f"Win rate: {win_rate:.1f}%")
|
||||||
|
logger.info(f"Final balance: ${env.balance:.2f}")
|
||||||
|
logger.info(f"Total profit: ${total_profit:.2f}")
|
||||||
|
logger.info(f"Maximum drawdown: {max_drawdown*100:.1f}%")
|
||||||
|
logger.info(f"Trade log saved to: {trade_log_path}")
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Live trading stopped by user")
|
||||||
|
return 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in live trading: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Set environment variable to indicate we're in demo mode
|
||||||
|
os.environ['DEMO_MODE'] = 'true'
|
||||||
|
|
||||||
|
# Print banner
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("🤖 TRADING BOT - LIVE DEMO MODE 🤖")
|
||||||
|
print("="*60)
|
||||||
|
print("This is a DEMO mode with simulated trading (no real trades)")
|
||||||
|
print("Press Ctrl+C to stop the bot at any time")
|
||||||
|
print("="*60 + "\n")
|
||||||
|
|
||||||
|
# Run the async main function
|
||||||
|
exit_code = asyncio.run(run_live_demo())
|
||||||
|
sys.exit(exit_code)
|
69
crypto/gogo2/run_tensorboard.py
Normal file
69
crypto/gogo2/run_tensorboard.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import subprocess
|
||||||
|
import webbrowser
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def run_tensorboard():
|
||||||
|
"""Run TensorBoard server and open browser"""
|
||||||
|
parser = argparse.ArgumentParser(description='TensorBoard Launcher')
|
||||||
|
parser.add_argument('--port', type=int, default=6006, help='Port for TensorBoard server')
|
||||||
|
parser.add_argument('--logdir', type=str, default='runs', help='Log directory for TensorBoard')
|
||||||
|
parser.add_argument('--no-browser', action='store_true', help='Do not open browser automatically')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Create log directory if it doesn't exist
|
||||||
|
os.makedirs(args.logdir, exist_ok=True)
|
||||||
|
|
||||||
|
# Print banner
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("📊 TRADING BOT - TENSORBOARD MONITORING 📊")
|
||||||
|
print("="*60)
|
||||||
|
print(f"Starting TensorBoard server on port {args.port}")
|
||||||
|
print(f"Log directory: {args.logdir}")
|
||||||
|
print("Press Ctrl+C to stop the server")
|
||||||
|
print("="*60 + "\n")
|
||||||
|
|
||||||
|
# Start TensorBoard server
|
||||||
|
cmd = ["tensorboard", "--logdir", args.logdir, "--port", str(args.port)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Start TensorBoard process
|
||||||
|
process = subprocess.Popen(
|
||||||
|
cmd,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
universal_newlines=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for TensorBoard to start
|
||||||
|
time.sleep(3)
|
||||||
|
|
||||||
|
# Open browser
|
||||||
|
if not args.no_browser:
|
||||||
|
url = f"http://localhost:{args.port}"
|
||||||
|
print(f"Opening browser to {url}")
|
||||||
|
webbrowser.open(url)
|
||||||
|
|
||||||
|
# Print TensorBoard output
|
||||||
|
while True:
|
||||||
|
output = process.stdout.readline()
|
||||||
|
if output == '' and process.poll() is not None:
|
||||||
|
break
|
||||||
|
if output:
|
||||||
|
print(output.strip())
|
||||||
|
|
||||||
|
return process.poll()
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nStopping TensorBoard server...")
|
||||||
|
process.terminate()
|
||||||
|
return 0
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error running TensorBoard: {str(e)}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit_code = run_tensorboard()
|
||||||
|
sys.exit(exit_code)
|
14
crypto/gogo2/start_live_trading.ps1
Normal file
14
crypto/gogo2/start_live_trading.ps1
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# PowerShell script to start live trading demo and TensorBoard
|
||||||
|
|
||||||
|
Write-Host "Starting Trading Bot Live Demo..." -ForegroundColor Green
|
||||||
|
|
||||||
|
# Create a new PowerShell window for TensorBoard
|
||||||
|
Start-Process powershell -ArgumentList "-Command python run_tensorboard.py" -WindowStyle Normal
|
||||||
|
|
||||||
|
# Wait a moment for TensorBoard to start
|
||||||
|
Write-Host "Starting TensorBoard... Please wait" -ForegroundColor Yellow
|
||||||
|
Start-Sleep -Seconds 5
|
||||||
|
|
||||||
|
# Start the live trading demo in the current window
|
||||||
|
Write-Host "Starting Live Trading Demo with mock data..." -ForegroundColor Green
|
||||||
|
python run_live_demo.py --symbol ETH/USDT --timeframe 1m --model models/trading_agent_best_pnl.pt --mock
|
File diff suppressed because it is too large
Load Diff
@ -1 +1 @@
|
|||||||
episode_rewards,episode_lengths,balances,win_rates,episode_pnls,cumulative_pnl,drawdowns,prediction_accuracy
|
episode_rewards,episode_lengths,balances,win_rates,episode_pnls,cumulative_pnl,drawdowns,prediction_accuracy,trade_analysis
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user