fix training & demo mode
This commit is contained in:
parent
c63b1a2daf
commit
e87207c1fa
@ -3,7 +3,7 @@ import time
|
|||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import datetime
|
from datetime import datetime
|
||||||
import random
|
import random
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -28,6 +28,8 @@ from matplotlib.figure import Figure
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import matplotlib.pyplot as mpf
|
import matplotlib.pyplot as mpf
|
||||||
import matplotlib.gridspec as gridspec
|
import matplotlib.gridspec as gridspec
|
||||||
|
import datetime
|
||||||
|
from datetime import datetime as dt
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@ -53,7 +55,7 @@ GAMMA = 0.99 # Discount factor
|
|||||||
EPSILON_START = 1.0
|
EPSILON_START = 1.0
|
||||||
EPSILON_END = 0.05
|
EPSILON_END = 0.05
|
||||||
EPSILON_DECAY = 10000
|
EPSILON_DECAY = 10000
|
||||||
STATE_SIZE = 40 # Size of our state representation
|
STATE_SIZE = 64 # Size of our state representation
|
||||||
LEARNING_RATE = 1e-4
|
LEARNING_RATE = 1e-4
|
||||||
TARGET_UPDATE = 10 # Update target network every 10 episodes
|
TARGET_UPDATE = 10 # Update target network every 10 episodes
|
||||||
|
|
||||||
@ -1308,15 +1310,19 @@ class TradingEnvironment:
|
|||||||
|
|
||||||
def update_price_predictions(self):
|
def update_price_predictions(self):
|
||||||
"""Update price predictions"""
|
"""Update price predictions"""
|
||||||
if len(self.features['price']) < 30:
|
if len(self.features['price']) < 30 or not hasattr(self, 'price_predictor') or self.price_predictor is None:
|
||||||
self.predicted_prices = np.array([])
|
self.predicted_prices = np.array([])
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get price history
|
# Get price history
|
||||||
price_history = self.features['price']
|
price_history = self.features['price']
|
||||||
|
|
||||||
# Get predictions
|
try:
|
||||||
self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5)
|
# Get predictions
|
||||||
|
self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error updating predictions: {e}")
|
||||||
|
self.predicted_prices = np.array([])
|
||||||
|
|
||||||
def identify_optimal_trades(self):
|
def identify_optimal_trades(self):
|
||||||
"""Identify optimal entry and exit points based on local extrema"""
|
"""Identify optimal entry and exit points based on local extrema"""
|
||||||
@ -1609,39 +1615,40 @@ def get_device():
|
|||||||
|
|
||||||
# Update Agent class to use GPU properly
|
# Update Agent class to use GPU properly
|
||||||
class Agent:
|
class Agent:
|
||||||
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4,
|
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None):
|
||||||
device=None):
|
"""Initialize Agent with architecture parameters stored as attributes"""
|
||||||
if device is None:
|
|
||||||
self.device = get_device()
|
|
||||||
else:
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
self.state_size = state_size
|
self.state_size = state_size
|
||||||
self.action_size = action_size
|
self.action_size = action_size
|
||||||
self.memory = ReplayMemory(MEMORY_SIZE)
|
self.hidden_size = hidden_size # Store hidden_size as an instance attribute
|
||||||
self.steps_done = 0
|
self.lstm_layers = lstm_layers # Store lstm_layers as an instance attribute
|
||||||
self.epsilon = EPSILON_START # Initialize epsilon
|
self.attention_heads = attention_heads # Store attention_heads as an instance attribute
|
||||||
|
|
||||||
# Initialize policy and target networks
|
# Set device
|
||||||
|
self.device = device if device is not None else get_device()
|
||||||
|
|
||||||
|
# Initialize networks
|
||||||
self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
||||||
self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
||||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||||
self.target_net.eval()
|
|
||||||
|
|
||||||
# Initialize optimizer with weight decay for regularization
|
# Initialize optimizer
|
||||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
|
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
|
||||||
|
|
||||||
# Initialize gradient scaler for mixed precision training
|
# Initialize replay memory
|
||||||
self.scaler = amp.GradScaler()
|
self.memory = ReplayMemory(MEMORY_SIZE)
|
||||||
|
|
||||||
# TensorBoard writer
|
# Initialize exploration parameters
|
||||||
self.writer = SummaryWriter()
|
self.epsilon = EPSILON_START
|
||||||
|
self.epsilon_decay = EPSILON_DECAY
|
||||||
|
self.epsilon_min = EPSILON_END
|
||||||
|
|
||||||
# For chart visualization
|
# Initialize step counter
|
||||||
self.chart_step = 0
|
self.steps_done = 0
|
||||||
|
|
||||||
# Create models directory if it doesn't exist
|
# Initialize TensorBoard writer
|
||||||
os.makedirs("models", exist_ok=True)
|
self.writer = None
|
||||||
|
|
||||||
|
# Rest of the initialization code...
|
||||||
|
|
||||||
def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8):
|
def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8):
|
||||||
"""Expand the model to handle more features or increase capacity"""
|
"""Expand the model to handle more features or increase capacity"""
|
||||||
@ -1727,7 +1734,7 @@ class Agent:
|
|||||||
|
|
||||||
# Use mixed precision for forward/backward passes
|
# Use mixed precision for forward/backward passes
|
||||||
if self.device.type == "cuda":
|
if self.device.type == "cuda":
|
||||||
with amp.autocast():
|
with torch.amp.autocast('cuda'):
|
||||||
# Compute Q values
|
# Compute Q values
|
||||||
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
|
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
|
||||||
|
|
||||||
@ -1795,54 +1802,74 @@ class Agent:
|
|||||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||||
|
|
||||||
def save(self, path="models/trading_agent_best_pnl.pt"):
|
def save(self, path="models/trading_agent_best_pnl.pt"):
|
||||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
"""Save the model in a format compatible with PyTorch 2.6+"""
|
||||||
torch.save({
|
try:
|
||||||
'policy_net': self.policy_net.state_dict(),
|
# Create directory if it doesn't exist
|
||||||
'target_net': self.target_net.state_dict(),
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
'optimizer': self.optimizer.state_dict(),
|
|
||||||
'epsilon': self.epsilon,
|
# Ensure architecture parameters are set
|
||||||
'steps_done': self.steps_done
|
if not hasattr(self, 'hidden_size'):
|
||||||
}, path)
|
self.hidden_size = 256 # Default value
|
||||||
logger.info(f"Model saved to {path}")
|
logger.warning("Setting default hidden_size=256 for saving")
|
||||||
|
|
||||||
|
if not hasattr(self, 'lstm_layers'):
|
||||||
|
self.lstm_layers = 2 # Default value
|
||||||
|
logger.warning("Setting default lstm_layers=2 for saving")
|
||||||
|
|
||||||
|
if not hasattr(self, 'attention_heads'):
|
||||||
|
self.attention_heads = 4 # Default value
|
||||||
|
logger.warning("Setting default attention_heads=4 for saving")
|
||||||
|
|
||||||
|
# Save model state
|
||||||
|
checkpoint = {
|
||||||
|
'policy_net': self.policy_net.state_dict(),
|
||||||
|
'target_net': self.target_net.state_dict(),
|
||||||
|
'optimizer': self.optimizer.state_dict(),
|
||||||
|
'epsilon': self.epsilon,
|
||||||
|
'state_size': self.state_size,
|
||||||
|
'action_size': self.action_size,
|
||||||
|
'hidden_size': self.hidden_size,
|
||||||
|
'lstm_layers': self.lstm_layers,
|
||||||
|
'attention_heads': self.attention_heads
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save with pickle_protocol=4 for better compatibility
|
||||||
|
torch.save(checkpoint, path, _use_new_zipfile_serialization=True, pickle_protocol=4)
|
||||||
|
logger.info(f"Model saved to {path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving model: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
def load(self, path="models/trading_agent_best_pnl.pt"):
|
def load(self, path="models/trading_agent_best_pnl.pt"):
|
||||||
"""Load a trained model"""
|
"""Load a trained model with improved error handling for PyTorch 2.6 compatibility"""
|
||||||
try:
|
try:
|
||||||
# First try with weights_only=True (safer)
|
# First try to load with weights_only=False (for models saved with older PyTorch versions)
|
||||||
checkpoint = torch.load(path, map_location=self.device)
|
|
||||||
|
|
||||||
# Check if model architecture matches
|
|
||||||
try:
|
try:
|
||||||
state_dict = checkpoint['policy_net']
|
logger.info(f"Attempting to load model with weights_only=False: {path}")
|
||||||
input_size = state_dict['fc1.weight'].shape[1]
|
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
|
||||||
output_size = state_dict['advantage_stream.bias'].shape[0]
|
logger.info("Model loaded successfully with weights_only=False")
|
||||||
hidden_size = state_dict['fc1.weight'].shape[0]
|
except Exception as e1:
|
||||||
|
logger.warning(f"Failed to load with weights_only=False: {e1}")
|
||||||
|
|
||||||
# If architecture doesn't match, rebuild the model
|
# Try with safe_globals context manager
|
||||||
if (input_size != self.state_size or
|
try:
|
||||||
output_size != self.action_size or
|
logger.info("Attempting to load with safe_globals context manager")
|
||||||
hidden_size != self.hidden_size):
|
import numpy as np
|
||||||
logger.warning(f"Model architecture mismatch. Rebuilding model with: "
|
from torch.serialization import safe_globals
|
||||||
f"state_size={input_size}, action_size={output_size}, "
|
|
||||||
f"hidden_size={hidden_size}")
|
|
||||||
|
|
||||||
# Rebuild the model with the correct architecture
|
# Add numpy scalar to safe globals
|
||||||
self.state_size = input_size
|
with safe_globals(['numpy._core.multiarray.scalar']):
|
||||||
self.action_size = output_size
|
checkpoint = torch.load(path, map_location=self.device)
|
||||||
self.hidden_size = hidden_size
|
logger.info("Model loaded successfully with safe_globals")
|
||||||
|
except Exception as e2:
|
||||||
|
logger.warning(f"Failed to load with safe_globals: {e2}")
|
||||||
|
|
||||||
# Recreate networks with correct architecture
|
# Last resort: try with pickle_module=pickle
|
||||||
self.policy_net = DQN(self.state_size, self.action_size,
|
logger.info("Attempting to load with pickle_module")
|
||||||
self.hidden_size, self.lstm_layers,
|
import pickle
|
||||||
self.attention_heads).to(self.device)
|
checkpoint = torch.load(path, map_location=self.device, pickle_module=pickle, weights_only=False)
|
||||||
self.target_net = DQN(self.state_size, self.action_size,
|
logger.info("Model loaded successfully with pickle_module")
|
||||||
self.hidden_size, self.lstm_layers,
|
|
||||||
self.attention_heads).to(self.device)
|
|
||||||
|
|
||||||
# Recreate optimizer
|
|
||||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error checking model architecture: {e}")
|
|
||||||
|
|
||||||
# Load state dictionaries
|
# Load state dictionaries
|
||||||
self.policy_net.load_state_dict(checkpoint['policy_net'])
|
self.policy_net.load_state_dict(checkpoint['policy_net'])
|
||||||
@ -1858,28 +1885,38 @@ class Agent:
|
|||||||
if 'epsilon' in checkpoint:
|
if 'epsilon' in checkpoint:
|
||||||
self.epsilon = checkpoint['epsilon']
|
self.epsilon = checkpoint['epsilon']
|
||||||
|
|
||||||
logger.info(f"Model loaded from {path}")
|
# Load architecture parameters if available
|
||||||
except Exception as e:
|
if 'state_size' in checkpoint:
|
||||||
logger.warning(f"Error loading model with default method: {e}")
|
self.state_size = checkpoint['state_size']
|
||||||
|
if 'action_size' in checkpoint:
|
||||||
# Try with weights_only=False as fallback
|
self.action_size = checkpoint['action_size']
|
||||||
try:
|
if 'hidden_size' in checkpoint:
|
||||||
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
|
self.hidden_size = checkpoint['hidden_size']
|
||||||
self.policy_net.load_state_dict(checkpoint['policy_net'])
|
else:
|
||||||
self.target_net.load_state_dict(checkpoint['target_net'])
|
# If hidden_size not in checkpoint, infer from model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
self.hidden_size = self.policy_net.fc1.weight.shape[0]
|
||||||
|
logger.info(f"Inferred hidden_size={self.hidden_size} from model")
|
||||||
except:
|
except:
|
||||||
logger.warning("Could not load optimizer state")
|
self.hidden_size = 256 # Default value
|
||||||
|
logger.warning(f"Could not infer hidden_size, using default: {self.hidden_size}")
|
||||||
|
|
||||||
if 'epsilon' in checkpoint:
|
if 'lstm_layers' in checkpoint:
|
||||||
self.epsilon = checkpoint['epsilon']
|
self.lstm_layers = checkpoint['lstm_layers']
|
||||||
|
else:
|
||||||
|
self.lstm_layers = 2 # Default value
|
||||||
|
|
||||||
logger.info(f"Model loaded from {path} with weights_only=False")
|
if 'attention_heads' in checkpoint:
|
||||||
except Exception as e:
|
self.attention_heads = checkpoint['attention_heads']
|
||||||
logger.error(f"Failed to load model: {e}")
|
else:
|
||||||
raise
|
self.attention_heads = 4 # Default value
|
||||||
|
|
||||||
|
logger.info(f"Model loaded successfully from {path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading model: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
raise
|
||||||
|
|
||||||
def add_chart_to_tensorboard(self, env, global_step):
|
def add_chart_to_tensorboard(self, env, global_step):
|
||||||
"""Add trading chart to TensorBoard"""
|
"""Add trading chart to TensorBoard"""
|
||||||
@ -1903,9 +1940,10 @@ class Agent:
|
|||||||
self.writer.add_image('Trading Chart', chart_array, global_step)
|
self.writer.add_image('Trading Chart', chart_array, global_step)
|
||||||
|
|
||||||
# Add position information as text
|
# Add position information as text
|
||||||
|
entry_price = env.entry_price if env.entry_price else 0.00
|
||||||
position_info = f"""
|
position_info = f"""
|
||||||
**Current Position**: {env.position.upper()}
|
**Current Position**: {env.position.upper()}
|
||||||
**Entry Price**: ${env.entry_price:.2f if env.entry_price else 0:.2f}
|
**Entry Price**: ${entry_price:.2f}
|
||||||
**Current Price**: ${env.data[-1]['close']:.2f}
|
**Current Price**: ${env.data[-1]['close']:.2f}
|
||||||
**Position Size**: ${env.position_size:.2f}
|
**Position Size**: ${env.position_size:.2f}
|
||||||
**Unrealized PnL**: ${env.total_pnl:.2f}
|
**Unrealized PnL**: ${env.total_pnl:.2f}
|
||||||
@ -2345,21 +2383,30 @@ async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50):
|
async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50):
|
||||||
"""
|
"""Run the trading bot in live mode with enhanced error handling"""
|
||||||
Run the trading bot in live mode with enhanced error handling and monitoring
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent: Trained trading agent
|
|
||||||
env: Trading environment
|
|
||||||
exchange: CCXT exchange instance
|
|
||||||
symbol: Trading pair (default: ETH/USDT)
|
|
||||||
timeframe: Candle timeframe (default: 1m)
|
|
||||||
demo: If True, simulate trades without executing (default: True)
|
|
||||||
leverage: Leverage for futures trading (default: 50x)
|
|
||||||
"""
|
|
||||||
logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe")
|
logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe")
|
||||||
logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")
|
logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")
|
||||||
|
|
||||||
|
# Verify agent is properly initialized
|
||||||
|
try:
|
||||||
|
# Ensure agent has all required attributes
|
||||||
|
if not hasattr(agent, 'hidden_size'):
|
||||||
|
agent.hidden_size = 256 # Default value
|
||||||
|
logger.warning("Agent missing hidden_size attribute, using default: 256")
|
||||||
|
|
||||||
|
if not hasattr(agent, 'lstm_layers'):
|
||||||
|
agent.lstm_layers = 2 # Default value
|
||||||
|
logger.warning("Agent missing lstm_layers attribute, using default: 2")
|
||||||
|
|
||||||
|
if not hasattr(agent, 'attention_heads'):
|
||||||
|
agent.attention_heads = 4 # Default value
|
||||||
|
logger.warning("Agent missing attention_heads attribute, using default: 4")
|
||||||
|
|
||||||
|
logger.info(f"Agent configuration: state_size={agent.state_size}, action_size={agent.action_size}, hidden_size={agent.hidden_size}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking agent configuration: {e}")
|
||||||
|
# Continue anyway, as these are just informational attributes
|
||||||
|
|
||||||
if not demo:
|
if not demo:
|
||||||
# Confirm with user before starting live trading
|
# Confirm with user before starting live trading
|
||||||
confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
|
confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
|
||||||
@ -2378,7 +2425,10 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
|
|||||||
|
|
||||||
# Initialize TensorBoard for monitoring
|
# Initialize TensorBoard for monitoring
|
||||||
if not hasattr(agent, 'writer') or agent.writer is None:
|
if not hasattr(agent, 'writer') or agent.writer is None:
|
||||||
agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
# Fix the datetime usage here
|
||||||
|
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{current_time}')
|
||||||
|
|
||||||
# Track performance metrics
|
# Track performance metrics
|
||||||
trades_count = 0
|
trades_count = 0
|
||||||
@ -2391,10 +2441,14 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
|
|||||||
|
|
||||||
# Create directory for trade logs
|
# Create directory for trade logs
|
||||||
os.makedirs('trade_logs', exist_ok=True)
|
os.makedirs('trade_logs', exist_ok=True)
|
||||||
trade_log_path = f'trade_logs/trades_{datetime.now().strftime("%Y%m%d_%H%M%S")}.csv'
|
# Fix the datetime usage here
|
||||||
|
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
trade_log_path = f'trade_logs/trades_{current_time}.csv'
|
||||||
with open(trade_log_path, 'w') as f:
|
with open(trade_log_path, 'w') as f:
|
||||||
f.write("timestamp,action,price,position_size,balance,pnl\n")
|
f.write("timestamp,action,price,position_size,balance,pnl\n")
|
||||||
|
|
||||||
|
logger.info("Entering live trading loop...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -2410,19 +2464,54 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
|
|||||||
|
|
||||||
# Get current state and select action
|
# Get current state and select action
|
||||||
state = env.get_state()
|
state = env.get_state()
|
||||||
|
|
||||||
|
# Verify state shape matches agent's expected input
|
||||||
|
if state.shape[0] != agent.state_size:
|
||||||
|
logger.warning(f"State size mismatch: got {state.shape[0]}, expected {agent.state_size}")
|
||||||
|
# Pad or truncate state to match expected size
|
||||||
|
if state.shape[0] < agent.state_size:
|
||||||
|
state = np.pad(state, (0, agent.state_size - state.shape[0]))
|
||||||
|
else:
|
||||||
|
state = state[:agent.state_size]
|
||||||
|
|
||||||
action = agent.select_action(state, training=False)
|
action = agent.select_action(state, training=False)
|
||||||
|
|
||||||
|
# Ensure action is valid
|
||||||
|
if action >= agent.action_size:
|
||||||
|
logger.warning(f"Invalid action {action}, clipping to {agent.action_size-1}")
|
||||||
|
action = agent.action_size - 1
|
||||||
|
|
||||||
|
# Log action
|
||||||
|
action_name = "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE"
|
||||||
|
logger.info(f"Step {step_counter}: Action selected: {action_name}, Price: ${env.data[-1]['close']:.2f}")
|
||||||
|
|
||||||
# Execute action
|
# Execute action
|
||||||
if not demo:
|
if not demo:
|
||||||
# Execute real trade on exchange
|
# Execute real trade on exchange
|
||||||
current_price = env.data[-1]['close']
|
current_price = env.data[-1]['close']
|
||||||
trade_result = await env.execute_real_trade(exchange, action, current_price)
|
trade_result = await env.execute_real_trade(exchange, action, current_price)
|
||||||
if not trade_result['success']:
|
if trade_result is None or not isinstance(trade_result, dict) or not trade_result.get('success', False):
|
||||||
logger.error(f"Trade execution failed: {trade_result['error']}")
|
error_msg = trade_result.get('error', 'Unknown error') if isinstance(trade_result, dict) else 'Trade execution failed'
|
||||||
|
logger.error(f"Trade execution failed: {error_msg}")
|
||||||
# Continue with simulated trade for tracking purposes
|
# Continue with simulated trade for tracking purposes
|
||||||
|
|
||||||
# Update environment with action (simulated in demo mode)
|
# Update environment with action (simulated in demo mode)
|
||||||
next_state, reward, done, info = env.step(action)
|
try:
|
||||||
|
next_state, reward, done, info = env.step(action)
|
||||||
|
except ValueError as e:
|
||||||
|
# Handle case where step returns 3 values instead of 4
|
||||||
|
if "not enough values to unpack" in str(e):
|
||||||
|
logger.warning("Step function returned 3 values instead of 4, creating info dict")
|
||||||
|
next_state, reward, done = env.step(action)
|
||||||
|
info = {
|
||||||
|
'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
|
||||||
|
'price': env.current_price,
|
||||||
|
'balance': env.balance,
|
||||||
|
'position': env.position,
|
||||||
|
'pnl': env.total_pnl
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
# Log trade if position changed
|
# Log trade if position changed
|
||||||
if env.position != prev_position:
|
if env.position != prev_position:
|
||||||
@ -2433,7 +2522,7 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
|
|||||||
|
|
||||||
# Log trade details
|
# Log trade details
|
||||||
with open(trade_log_path, 'a') as f:
|
with open(trade_log_path, 'a') as f:
|
||||||
f.write(f"{datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
|
f.write(f"{datetime.datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
|
||||||
|
|
||||||
logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
|
logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
|
||||||
|
|
||||||
@ -2472,10 +2561,12 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
|
|||||||
prev_position = env.position
|
prev_position = env.position
|
||||||
|
|
||||||
# Wait for next candle
|
# Wait for next candle
|
||||||
|
logger.info(f"Waiting for next candle... (Step {step_counter})")
|
||||||
await asyncio.sleep(10) # Check every 10 seconds
|
await asyncio.sleep(10) # Check every 10 seconds
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in live trading loop: {str(e)}")
|
logger.error(f"Error in live trading loop: {str(e)}")
|
||||||
|
import traceback
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
logger.info("Continuing after error...")
|
logger.info("Continuing after error...")
|
||||||
await asyncio.sleep(30) # Wait longer after an error
|
await asyncio.sleep(30) # Wait longer after an error
|
||||||
@ -2549,7 +2640,26 @@ async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
|
|||||||
logger.error(f"Error fetching OHLCV data: {e}")
|
logger.error(f"Error fetching OHLCV data: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
# Add this near the top of the file, after imports
|
||||||
|
def ensure_pytorch_compatibility():
|
||||||
|
"""Ensure compatibility with PyTorch 2.6+ for model loading"""
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from torch.serialization import add_safe_globals
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Add numpy scalar to safe globals for PyTorch 2.6+
|
||||||
|
add_safe_globals(['numpy._core.multiarray.scalar'])
|
||||||
|
logger.info("Added numpy scalar to PyTorch safe globals")
|
||||||
|
except (ImportError, AttributeError) as e:
|
||||||
|
logger.warning(f"Could not configure PyTorch compatibility: {e}")
|
||||||
|
logger.warning("This might cause issues with model loading in PyTorch 2.6+")
|
||||||
|
|
||||||
|
# Call this function at the start of the main function
|
||||||
async def main():
|
async def main():
|
||||||
|
# Ensure PyTorch compatibility
|
||||||
|
ensure_pytorch_compatibility()
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Trading Bot')
|
parser = argparse.ArgumentParser(description='Trading Bot')
|
||||||
parser.add_argument('--mode', type=str, choices=['train', 'eval', 'live'], default='train',
|
parser.add_argument('--mode', type=str, choices=['train', 'eval', 'live'], default='train',
|
||||||
help='Operation mode: train, eval, or live')
|
help='Operation mode: train, eval, or live')
|
||||||
@ -2583,60 +2693,84 @@ async def main():
|
|||||||
# Create environment
|
# Create environment
|
||||||
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode)
|
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode)
|
||||||
|
|
||||||
# Fetch initial data
|
|
||||||
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
|
|
||||||
|
|
||||||
# Create agent
|
|
||||||
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
|
|
||||||
|
|
||||||
if args.mode == 'train':
|
if args.mode == 'train':
|
||||||
|
# Fetch initial data for training
|
||||||
|
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
|
||||||
|
|
||||||
|
# Create agent with consistent parameters
|
||||||
|
# Note: Using STATE_SIZE and action_size=4 for consistency
|
||||||
|
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
|
||||||
|
|
||||||
# Train the agent
|
# Train the agent
|
||||||
logger.info(f"Starting training for {args.episodes} episodes...")
|
logger.info(f"Starting training for {args.episodes} episodes...")
|
||||||
stats = await train_agent(agent, env, num_episodes=args.episodes)
|
stats = await train_agent(agent, env, num_episodes=args.episodes)
|
||||||
|
|
||||||
elif args.mode == 'evaluate':
|
elif args.mode == 'eval' or args.mode == 'live':
|
||||||
# Load trained model
|
# Fetch initial data for the specified symbol and timeframe
|
||||||
agent.load("models/trading_agent_best_pnl.pt")
|
await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)
|
||||||
|
|
||||||
# Evaluate the agent
|
# Determine model path
|
||||||
logger.info("Evaluating agent...")
|
model_path = args.model if args.model else "models/trading_agent_best_pnl.pt"
|
||||||
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env)
|
|
||||||
|
|
||||||
elif args.mode == 'live':
|
|
||||||
# Initialize exchange
|
|
||||||
exchange = await initialize_exchange()
|
|
||||||
|
|
||||||
# Load model
|
|
||||||
model_path = args.model if args.model else "models/trading_agent.pt"
|
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
logger.error(f"Model file not found: {model_path}")
|
logger.error(f"Model file not found: {model_path}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Initialize environment
|
# Create agent with default parameters
|
||||||
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=WINDOW_SIZE, demo=demo_mode)
|
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
|
||||||
await env.fetch_initial_data(exchange, symbol=args.symbol, timeframe=args.timeframe)
|
|
||||||
|
|
||||||
# Initialize agent
|
# Try to load the model
|
||||||
state_size = env.get_state().shape[0]
|
try:
|
||||||
agent = Agent(state_size=state_size, action_size=3)
|
# Add numpy scalar to safe globals before loading
|
||||||
agent.load(model_path)
|
import numpy as np
|
||||||
|
from torch.serialization import add_safe_globals
|
||||||
|
|
||||||
# Start live trading
|
# Add numpy scalar to safe globals
|
||||||
await live_trading(
|
add_safe_globals(['numpy._core.multiarray.scalar'])
|
||||||
agent=agent,
|
|
||||||
env=env,
|
|
||||||
exchange=exchange,
|
|
||||||
symbol=args.symbol,
|
|
||||||
timeframe=args.timeframe,
|
|
||||||
demo=demo_mode,
|
|
||||||
leverage=args.leverage
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Load the model
|
||||||
|
agent.load(model_path)
|
||||||
|
logger.info(f"Model loaded successfully from {model_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load model: {e}")
|
||||||
|
|
||||||
|
# Ask user if they want to continue with a new model
|
||||||
|
if args.mode == 'live':
|
||||||
|
confirmation = input("Failed to load model. Continue with a new model? (y/n): ")
|
||||||
|
if confirmation.lower() != 'y':
|
||||||
|
logger.info("Live trading canceled by user")
|
||||||
|
return
|
||||||
|
logger.info("Continuing with a new model")
|
||||||
|
else:
|
||||||
|
logger.info("Continuing evaluation with a new model")
|
||||||
|
|
||||||
|
if args.mode == 'eval':
|
||||||
|
# Evaluate the agent
|
||||||
|
logger.info("Evaluating agent...")
|
||||||
|
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env, num_episodes=args.episodes)
|
||||||
|
|
||||||
|
elif args.mode == 'live':
|
||||||
|
# Start live trading
|
||||||
|
logger.info(f"Starting live trading for {args.symbol} on {args.timeframe} timeframe")
|
||||||
|
logger.info(f"Demo mode: {demo_mode}, Leverage: {args.leverage}x")
|
||||||
|
|
||||||
|
await live_trading(
|
||||||
|
agent=agent,
|
||||||
|
env=env,
|
||||||
|
exchange=exchange,
|
||||||
|
symbol=args.symbol,
|
||||||
|
timeframe=args.timeframe,
|
||||||
|
demo=demo_mode,
|
||||||
|
leverage=args.leverage
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in main function: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
finally:
|
finally:
|
||||||
# Clean up exchange connection - safely close if possible
|
# Clean up exchange connection
|
||||||
if exchange:
|
if exchange:
|
||||||
try:
|
try:
|
||||||
# Some CCXT exchanges have close method, others don't
|
|
||||||
if hasattr(exchange, 'close'):
|
if hasattr(exchange, 'close'):
|
||||||
await exchange.close()
|
await exchange.close()
|
||||||
elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
|
elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
|
||||||
@ -2667,7 +2801,8 @@ def create_candlestick_figure(data, trade_signals, window_size=100, title=""):
|
|||||||
|
|
||||||
# Plot candlesticks - use a simpler approach if mplfinance fails
|
# Plot candlesticks - use a simpler approach if mplfinance fails
|
||||||
try:
|
try:
|
||||||
mpf.plot(df, type='candle', style='yahoo', ax=price_ax, volume=volume_ax)
|
# Use a different style or approach that doesn't use 'type' parameter
|
||||||
|
mpf.plot(df, type='candle', ax=price_ax, volume=volume_ax, style='yahoo')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error plotting with mplfinance: {e}, falling back to simple plot")
|
logger.warning(f"Error plotting with mplfinance: {e}, falling back to simple plot")
|
||||||
# Fallback to simple plot
|
# Fallback to simple plot
|
||||||
|
Loading…
x
Reference in New Issue
Block a user