fix training & demo mode
This commit is contained in:
parent
c63b1a2daf
commit
e87207c1fa
@ -3,7 +3,7 @@ import time
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import datetime
|
||||
from datetime import datetime
|
||||
import random
|
||||
import logging
|
||||
import asyncio
|
||||
@ -28,6 +28,8 @@ from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as mpf
|
||||
import matplotlib.gridspec as gridspec
|
||||
import datetime
|
||||
from datetime import datetime as dt
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@ -53,7 +55,7 @@ GAMMA = 0.99 # Discount factor
|
||||
EPSILON_START = 1.0
|
||||
EPSILON_END = 0.05
|
||||
EPSILON_DECAY = 10000
|
||||
STATE_SIZE = 40 # Size of our state representation
|
||||
STATE_SIZE = 64 # Size of our state representation
|
||||
LEARNING_RATE = 1e-4
|
||||
TARGET_UPDATE = 10 # Update target network every 10 episodes
|
||||
|
||||
@ -1308,15 +1310,19 @@ class TradingEnvironment:
|
||||
|
||||
def update_price_predictions(self):
|
||||
"""Update price predictions"""
|
||||
if len(self.features['price']) < 30:
|
||||
if len(self.features['price']) < 30 or not hasattr(self, 'price_predictor') or self.price_predictor is None:
|
||||
self.predicted_prices = np.array([])
|
||||
return
|
||||
|
||||
# Get price history
|
||||
price_history = self.features['price']
|
||||
|
||||
# Get predictions
|
||||
self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5)
|
||||
try:
|
||||
# Get predictions
|
||||
self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error updating predictions: {e}")
|
||||
self.predicted_prices = np.array([])
|
||||
|
||||
def identify_optimal_trades(self):
|
||||
"""Identify optimal entry and exit points based on local extrema"""
|
||||
@ -1609,39 +1615,40 @@ def get_device():
|
||||
|
||||
# Update Agent class to use GPU properly
|
||||
class Agent:
|
||||
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4,
|
||||
device=None):
|
||||
if device is None:
|
||||
self.device = get_device()
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None):
|
||||
"""Initialize Agent with architecture parameters stored as attributes"""
|
||||
self.state_size = state_size
|
||||
self.action_size = action_size
|
||||
self.memory = ReplayMemory(MEMORY_SIZE)
|
||||
self.steps_done = 0
|
||||
self.epsilon = EPSILON_START # Initialize epsilon
|
||||
self.hidden_size = hidden_size # Store hidden_size as an instance attribute
|
||||
self.lstm_layers = lstm_layers # Store lstm_layers as an instance attribute
|
||||
self.attention_heads = attention_heads # Store attention_heads as an instance attribute
|
||||
|
||||
# Initialize policy and target networks
|
||||
# Set device
|
||||
self.device = device if device is not None else get_device()
|
||||
|
||||
# Initialize networks
|
||||
self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
||||
self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
self.target_net.eval()
|
||||
|
||||
# Initialize optimizer with weight decay for regularization
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
|
||||
# Initialize optimizer
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
|
||||
|
||||
# Initialize gradient scaler for mixed precision training
|
||||
self.scaler = amp.GradScaler()
|
||||
# Initialize replay memory
|
||||
self.memory = ReplayMemory(MEMORY_SIZE)
|
||||
|
||||
# TensorBoard writer
|
||||
self.writer = SummaryWriter()
|
||||
# Initialize exploration parameters
|
||||
self.epsilon = EPSILON_START
|
||||
self.epsilon_decay = EPSILON_DECAY
|
||||
self.epsilon_min = EPSILON_END
|
||||
|
||||
# For chart visualization
|
||||
self.chart_step = 0
|
||||
# Initialize step counter
|
||||
self.steps_done = 0
|
||||
|
||||
# Create models directory if it doesn't exist
|
||||
os.makedirs("models", exist_ok=True)
|
||||
# Initialize TensorBoard writer
|
||||
self.writer = None
|
||||
|
||||
# Rest of the initialization code...
|
||||
|
||||
def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8):
|
||||
"""Expand the model to handle more features or increase capacity"""
|
||||
@ -1727,7 +1734,7 @@ class Agent:
|
||||
|
||||
# Use mixed precision for forward/backward passes
|
||||
if self.device.type == "cuda":
|
||||
with amp.autocast():
|
||||
with torch.amp.autocast('cuda'):
|
||||
# Compute Q values
|
||||
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
|
||||
|
||||
@ -1795,54 +1802,74 @@ class Agent:
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
|
||||
def save(self, path="models/trading_agent_best_pnl.pt"):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
torch.save({
|
||||
'policy_net': self.policy_net.state_dict(),
|
||||
'target_net': self.target_net.state_dict(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'epsilon': self.epsilon,
|
||||
'steps_done': self.steps_done
|
||||
}, path)
|
||||
logger.info(f"Model saved to {path}")
|
||||
"""Save the model in a format compatible with PyTorch 2.6+"""
|
||||
try:
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
|
||||
# Ensure architecture parameters are set
|
||||
if not hasattr(self, 'hidden_size'):
|
||||
self.hidden_size = 256 # Default value
|
||||
logger.warning("Setting default hidden_size=256 for saving")
|
||||
|
||||
if not hasattr(self, 'lstm_layers'):
|
||||
self.lstm_layers = 2 # Default value
|
||||
logger.warning("Setting default lstm_layers=2 for saving")
|
||||
|
||||
if not hasattr(self, 'attention_heads'):
|
||||
self.attention_heads = 4 # Default value
|
||||
logger.warning("Setting default attention_heads=4 for saving")
|
||||
|
||||
# Save model state
|
||||
checkpoint = {
|
||||
'policy_net': self.policy_net.state_dict(),
|
||||
'target_net': self.target_net.state_dict(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'epsilon': self.epsilon,
|
||||
'state_size': self.state_size,
|
||||
'action_size': self.action_size,
|
||||
'hidden_size': self.hidden_size,
|
||||
'lstm_layers': self.lstm_layers,
|
||||
'attention_heads': self.attention_heads
|
||||
}
|
||||
|
||||
# Save with pickle_protocol=4 for better compatibility
|
||||
torch.save(checkpoint, path, _use_new_zipfile_serialization=True, pickle_protocol=4)
|
||||
logger.info(f"Model saved to {path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
def load(self, path="models/trading_agent_best_pnl.pt"):
|
||||
"""Load a trained model"""
|
||||
"""Load a trained model with improved error handling for PyTorch 2.6 compatibility"""
|
||||
try:
|
||||
# First try with weights_only=True (safer)
|
||||
checkpoint = torch.load(path, map_location=self.device)
|
||||
|
||||
# Check if model architecture matches
|
||||
# First try to load with weights_only=False (for models saved with older PyTorch versions)
|
||||
try:
|
||||
state_dict = checkpoint['policy_net']
|
||||
input_size = state_dict['fc1.weight'].shape[1]
|
||||
output_size = state_dict['advantage_stream.bias'].shape[0]
|
||||
hidden_size = state_dict['fc1.weight'].shape[0]
|
||||
logger.info(f"Attempting to load model with weights_only=False: {path}")
|
||||
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
|
||||
logger.info("Model loaded successfully with weights_only=False")
|
||||
except Exception as e1:
|
||||
logger.warning(f"Failed to load with weights_only=False: {e1}")
|
||||
|
||||
# If architecture doesn't match, rebuild the model
|
||||
if (input_size != self.state_size or
|
||||
output_size != self.action_size or
|
||||
hidden_size != self.hidden_size):
|
||||
logger.warning(f"Model architecture mismatch. Rebuilding model with: "
|
||||
f"state_size={input_size}, action_size={output_size}, "
|
||||
f"hidden_size={hidden_size}")
|
||||
# Try with safe_globals context manager
|
||||
try:
|
||||
logger.info("Attempting to load with safe_globals context manager")
|
||||
import numpy as np
|
||||
from torch.serialization import safe_globals
|
||||
|
||||
# Rebuild the model with the correct architecture
|
||||
self.state_size = input_size
|
||||
self.action_size = output_size
|
||||
self.hidden_size = hidden_size
|
||||
# Add numpy scalar to safe globals
|
||||
with safe_globals(['numpy._core.multiarray.scalar']):
|
||||
checkpoint = torch.load(path, map_location=self.device)
|
||||
logger.info("Model loaded successfully with safe_globals")
|
||||
except Exception as e2:
|
||||
logger.warning(f"Failed to load with safe_globals: {e2}")
|
||||
|
||||
# Recreate networks with correct architecture
|
||||
self.policy_net = DQN(self.state_size, self.action_size,
|
||||
self.hidden_size, self.lstm_layers,
|
||||
self.attention_heads).to(self.device)
|
||||
self.target_net = DQN(self.state_size, self.action_size,
|
||||
self.hidden_size, self.lstm_layers,
|
||||
self.attention_heads).to(self.device)
|
||||
|
||||
# Recreate optimizer
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking model architecture: {e}")
|
||||
# Last resort: try with pickle_module=pickle
|
||||
logger.info("Attempting to load with pickle_module")
|
||||
import pickle
|
||||
checkpoint = torch.load(path, map_location=self.device, pickle_module=pickle, weights_only=False)
|
||||
logger.info("Model loaded successfully with pickle_module")
|
||||
|
||||
# Load state dictionaries
|
||||
self.policy_net.load_state_dict(checkpoint['policy_net'])
|
||||
@ -1858,28 +1885,38 @@ class Agent:
|
||||
if 'epsilon' in checkpoint:
|
||||
self.epsilon = checkpoint['epsilon']
|
||||
|
||||
logger.info(f"Model loaded from {path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading model with default method: {e}")
|
||||
|
||||
# Try with weights_only=False as fallback
|
||||
try:
|
||||
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
|
||||
self.policy_net.load_state_dict(checkpoint['policy_net'])
|
||||
self.target_net.load_state_dict(checkpoint['target_net'])
|
||||
|
||||
# Load architecture parameters if available
|
||||
if 'state_size' in checkpoint:
|
||||
self.state_size = checkpoint['state_size']
|
||||
if 'action_size' in checkpoint:
|
||||
self.action_size = checkpoint['action_size']
|
||||
if 'hidden_size' in checkpoint:
|
||||
self.hidden_size = checkpoint['hidden_size']
|
||||
else:
|
||||
# If hidden_size not in checkpoint, infer from model
|
||||
try:
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
self.hidden_size = self.policy_net.fc1.weight.shape[0]
|
||||
logger.info(f"Inferred hidden_size={self.hidden_size} from model")
|
||||
except:
|
||||
logger.warning("Could not load optimizer state")
|
||||
self.hidden_size = 256 # Default value
|
||||
logger.warning(f"Could not infer hidden_size, using default: {self.hidden_size}")
|
||||
|
||||
if 'epsilon' in checkpoint:
|
||||
self.epsilon = checkpoint['epsilon']
|
||||
if 'lstm_layers' in checkpoint:
|
||||
self.lstm_layers = checkpoint['lstm_layers']
|
||||
else:
|
||||
self.lstm_layers = 2 # Default value
|
||||
|
||||
logger.info(f"Model loaded from {path} with weights_only=False")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
if 'attention_heads' in checkpoint:
|
||||
self.attention_heads = checkpoint['attention_heads']
|
||||
else:
|
||||
self.attention_heads = 4 # Default value
|
||||
|
||||
logger.info(f"Model loaded successfully from {path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
def add_chart_to_tensorboard(self, env, global_step):
|
||||
"""Add trading chart to TensorBoard"""
|
||||
@ -1903,9 +1940,10 @@ class Agent:
|
||||
self.writer.add_image('Trading Chart', chart_array, global_step)
|
||||
|
||||
# Add position information as text
|
||||
entry_price = env.entry_price if env.entry_price else 0.00
|
||||
position_info = f"""
|
||||
**Current Position**: {env.position.upper()}
|
||||
**Entry Price**: ${env.entry_price:.2f if env.entry_price else 0:.2f}
|
||||
**Entry Price**: ${entry_price:.2f}
|
||||
**Current Price**: ${env.data[-1]['close']:.2f}
|
||||
**Position Size**: ${env.position_size:.2f}
|
||||
**Unrealized PnL**: ${env.total_pnl:.2f}
|
||||
@ -2345,21 +2383,30 @@ async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit
|
||||
return []
|
||||
|
||||
async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50):
|
||||
"""
|
||||
Run the trading bot in live mode with enhanced error handling and monitoring
|
||||
|
||||
Args:
|
||||
agent: Trained trading agent
|
||||
env: Trading environment
|
||||
exchange: CCXT exchange instance
|
||||
symbol: Trading pair (default: ETH/USDT)
|
||||
timeframe: Candle timeframe (default: 1m)
|
||||
demo: If True, simulate trades without executing (default: True)
|
||||
leverage: Leverage for futures trading (default: 50x)
|
||||
"""
|
||||
"""Run the trading bot in live mode with enhanced error handling"""
|
||||
logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe")
|
||||
logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")
|
||||
|
||||
# Verify agent is properly initialized
|
||||
try:
|
||||
# Ensure agent has all required attributes
|
||||
if not hasattr(agent, 'hidden_size'):
|
||||
agent.hidden_size = 256 # Default value
|
||||
logger.warning("Agent missing hidden_size attribute, using default: 256")
|
||||
|
||||
if not hasattr(agent, 'lstm_layers'):
|
||||
agent.lstm_layers = 2 # Default value
|
||||
logger.warning("Agent missing lstm_layers attribute, using default: 2")
|
||||
|
||||
if not hasattr(agent, 'attention_heads'):
|
||||
agent.attention_heads = 4 # Default value
|
||||
logger.warning("Agent missing attention_heads attribute, using default: 4")
|
||||
|
||||
logger.info(f"Agent configuration: state_size={agent.state_size}, action_size={agent.action_size}, hidden_size={agent.hidden_size}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking agent configuration: {e}")
|
||||
# Continue anyway, as these are just informational attributes
|
||||
|
||||
if not demo:
|
||||
# Confirm with user before starting live trading
|
||||
confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
|
||||
@ -2378,7 +2425,10 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
|
||||
|
||||
# Initialize TensorBoard for monitoring
|
||||
if not hasattr(agent, 'writer') or agent.writer is None:
|
||||
agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
# Fix the datetime usage here
|
||||
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{current_time}')
|
||||
|
||||
# Track performance metrics
|
||||
trades_count = 0
|
||||
@ -2391,10 +2441,14 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
|
||||
|
||||
# Create directory for trade logs
|
||||
os.makedirs('trade_logs', exist_ok=True)
|
||||
trade_log_path = f'trade_logs/trades_{datetime.now().strftime("%Y%m%d_%H%M%S")}.csv'
|
||||
# Fix the datetime usage here
|
||||
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
trade_log_path = f'trade_logs/trades_{current_time}.csv'
|
||||
with open(trade_log_path, 'w') as f:
|
||||
f.write("timestamp,action,price,position_size,balance,pnl\n")
|
||||
|
||||
logger.info("Entering live trading loop...")
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
@ -2410,19 +2464,54 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
|
||||
|
||||
# Get current state and select action
|
||||
state = env.get_state()
|
||||
|
||||
# Verify state shape matches agent's expected input
|
||||
if state.shape[0] != agent.state_size:
|
||||
logger.warning(f"State size mismatch: got {state.shape[0]}, expected {agent.state_size}")
|
||||
# Pad or truncate state to match expected size
|
||||
if state.shape[0] < agent.state_size:
|
||||
state = np.pad(state, (0, agent.state_size - state.shape[0]))
|
||||
else:
|
||||
state = state[:agent.state_size]
|
||||
|
||||
action = agent.select_action(state, training=False)
|
||||
|
||||
# Ensure action is valid
|
||||
if action >= agent.action_size:
|
||||
logger.warning(f"Invalid action {action}, clipping to {agent.action_size-1}")
|
||||
action = agent.action_size - 1
|
||||
|
||||
# Log action
|
||||
action_name = "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE"
|
||||
logger.info(f"Step {step_counter}: Action selected: {action_name}, Price: ${env.data[-1]['close']:.2f}")
|
||||
|
||||
# Execute action
|
||||
if not demo:
|
||||
# Execute real trade on exchange
|
||||
current_price = env.data[-1]['close']
|
||||
trade_result = await env.execute_real_trade(exchange, action, current_price)
|
||||
if not trade_result['success']:
|
||||
logger.error(f"Trade execution failed: {trade_result['error']}")
|
||||
if trade_result is None or not isinstance(trade_result, dict) or not trade_result.get('success', False):
|
||||
error_msg = trade_result.get('error', 'Unknown error') if isinstance(trade_result, dict) else 'Trade execution failed'
|
||||
logger.error(f"Trade execution failed: {error_msg}")
|
||||
# Continue with simulated trade for tracking purposes
|
||||
|
||||
# Update environment with action (simulated in demo mode)
|
||||
next_state, reward, done, info = env.step(action)
|
||||
try:
|
||||
next_state, reward, done, info = env.step(action)
|
||||
except ValueError as e:
|
||||
# Handle case where step returns 3 values instead of 4
|
||||
if "not enough values to unpack" in str(e):
|
||||
logger.warning("Step function returned 3 values instead of 4, creating info dict")
|
||||
next_state, reward, done = env.step(action)
|
||||
info = {
|
||||
'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
|
||||
'price': env.current_price,
|
||||
'balance': env.balance,
|
||||
'position': env.position,
|
||||
'pnl': env.total_pnl
|
||||
}
|
||||
else:
|
||||
raise
|
||||
|
||||
# Log trade if position changed
|
||||
if env.position != prev_position:
|
||||
@ -2433,7 +2522,7 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
|
||||
|
||||
# Log trade details
|
||||
with open(trade_log_path, 'a') as f:
|
||||
f.write(f"{datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
|
||||
f.write(f"{datetime.datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
|
||||
|
||||
logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
|
||||
|
||||
@ -2472,10 +2561,12 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
|
||||
prev_position = env.position
|
||||
|
||||
# Wait for next candle
|
||||
logger.info(f"Waiting for next candle... (Step {step_counter})")
|
||||
await asyncio.sleep(10) # Check every 10 seconds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in live trading loop: {str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
logger.info("Continuing after error...")
|
||||
await asyncio.sleep(30) # Wait longer after an error
|
||||
@ -2549,7 +2640,26 @@ async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
|
||||
logger.error(f"Error fetching OHLCV data: {e}")
|
||||
return []
|
||||
|
||||
# Add this near the top of the file, after imports
|
||||
def ensure_pytorch_compatibility():
|
||||
"""Ensure compatibility with PyTorch 2.6+ for model loading"""
|
||||
try:
|
||||
import torch
|
||||
from torch.serialization import add_safe_globals
|
||||
import numpy as np
|
||||
|
||||
# Add numpy scalar to safe globals for PyTorch 2.6+
|
||||
add_safe_globals(['numpy._core.multiarray.scalar'])
|
||||
logger.info("Added numpy scalar to PyTorch safe globals")
|
||||
except (ImportError, AttributeError) as e:
|
||||
logger.warning(f"Could not configure PyTorch compatibility: {e}")
|
||||
logger.warning("This might cause issues with model loading in PyTorch 2.6+")
|
||||
|
||||
# Call this function at the start of the main function
|
||||
async def main():
|
||||
# Ensure PyTorch compatibility
|
||||
ensure_pytorch_compatibility()
|
||||
|
||||
parser = argparse.ArgumentParser(description='Trading Bot')
|
||||
parser.add_argument('--mode', type=str, choices=['train', 'eval', 'live'], default='train',
|
||||
help='Operation mode: train, eval, or live')
|
||||
@ -2583,60 +2693,84 @@ async def main():
|
||||
# Create environment
|
||||
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode)
|
||||
|
||||
# Fetch initial data
|
||||
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
|
||||
|
||||
# Create agent
|
||||
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
|
||||
|
||||
if args.mode == 'train':
|
||||
# Fetch initial data for training
|
||||
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
|
||||
|
||||
# Create agent with consistent parameters
|
||||
# Note: Using STATE_SIZE and action_size=4 for consistency
|
||||
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
|
||||
|
||||
# Train the agent
|
||||
logger.info(f"Starting training for {args.episodes} episodes...")
|
||||
stats = await train_agent(agent, env, num_episodes=args.episodes)
|
||||
|
||||
elif args.mode == 'evaluate':
|
||||
# Load trained model
|
||||
agent.load("models/trading_agent_best_pnl.pt")
|
||||
elif args.mode == 'eval' or args.mode == 'live':
|
||||
# Fetch initial data for the specified symbol and timeframe
|
||||
await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)
|
||||
|
||||
# Evaluate the agent
|
||||
logger.info("Evaluating agent...")
|
||||
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env)
|
||||
|
||||
elif args.mode == 'live':
|
||||
# Initialize exchange
|
||||
exchange = await initialize_exchange()
|
||||
|
||||
# Load model
|
||||
model_path = args.model if args.model else "models/trading_agent.pt"
|
||||
# Determine model path
|
||||
model_path = args.model if args.model else "models/trading_agent_best_pnl.pt"
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"Model file not found: {model_path}")
|
||||
return
|
||||
|
||||
# Initialize environment
|
||||
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=WINDOW_SIZE, demo=demo_mode)
|
||||
await env.fetch_initial_data(exchange, symbol=args.symbol, timeframe=args.timeframe)
|
||||
# Create agent with default parameters
|
||||
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
|
||||
|
||||
# Initialize agent
|
||||
state_size = env.get_state().shape[0]
|
||||
agent = Agent(state_size=state_size, action_size=3)
|
||||
agent.load(model_path)
|
||||
# Try to load the model
|
||||
try:
|
||||
# Add numpy scalar to safe globals before loading
|
||||
import numpy as np
|
||||
from torch.serialization import add_safe_globals
|
||||
|
||||
# Start live trading
|
||||
await live_trading(
|
||||
agent=agent,
|
||||
env=env,
|
||||
exchange=exchange,
|
||||
symbol=args.symbol,
|
||||
timeframe=args.timeframe,
|
||||
demo=demo_mode,
|
||||
leverage=args.leverage
|
||||
)
|
||||
# Add numpy scalar to safe globals
|
||||
add_safe_globals(['numpy._core.multiarray.scalar'])
|
||||
|
||||
# Load the model
|
||||
agent.load(model_path)
|
||||
logger.info(f"Model loaded successfully from {model_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
|
||||
# Ask user if they want to continue with a new model
|
||||
if args.mode == 'live':
|
||||
confirmation = input("Failed to load model. Continue with a new model? (y/n): ")
|
||||
if confirmation.lower() != 'y':
|
||||
logger.info("Live trading canceled by user")
|
||||
return
|
||||
logger.info("Continuing with a new model")
|
||||
else:
|
||||
logger.info("Continuing evaluation with a new model")
|
||||
|
||||
if args.mode == 'eval':
|
||||
# Evaluate the agent
|
||||
logger.info("Evaluating agent...")
|
||||
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env, num_episodes=args.episodes)
|
||||
|
||||
elif args.mode == 'live':
|
||||
# Start live trading
|
||||
logger.info(f"Starting live trading for {args.symbol} on {args.timeframe} timeframe")
|
||||
logger.info(f"Demo mode: {demo_mode}, Leverage: {args.leverage}x")
|
||||
|
||||
await live_trading(
|
||||
agent=agent,
|
||||
env=env,
|
||||
exchange=exchange,
|
||||
symbol=args.symbol,
|
||||
timeframe=args.timeframe,
|
||||
demo=demo_mode,
|
||||
leverage=args.leverage
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in main function: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
# Clean up exchange connection - safely close if possible
|
||||
# Clean up exchange connection
|
||||
if exchange:
|
||||
try:
|
||||
# Some CCXT exchanges have close method, others don't
|
||||
if hasattr(exchange, 'close'):
|
||||
await exchange.close()
|
||||
elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
|
||||
@ -2667,7 +2801,8 @@ def create_candlestick_figure(data, trade_signals, window_size=100, title=""):
|
||||
|
||||
# Plot candlesticks - use a simpler approach if mplfinance fails
|
||||
try:
|
||||
mpf.plot(df, type='candle', style='yahoo', ax=price_ax, volume=volume_ax)
|
||||
# Use a different style or approach that doesn't use 'type' parameter
|
||||
mpf.plot(df, type='candle', ax=price_ax, volume=volume_ax, style='yahoo')
|
||||
except Exception as e:
|
||||
logger.warning(f"Error plotting with mplfinance: {e}, falling back to simple plot")
|
||||
# Fallback to simple plot
|
||||
|
Loading…
x
Reference in New Issue
Block a user