fix training & demo mode

Author: Dobromir Popov
Date: 2025-03-17 03:59:50 +02:00
parent c63b1a2daf
commit e87207c1fa


@@ -3,7 +3,7 @@ import time
 import json
 import numpy as np
 import pandas as pd
-import datetime
+from datetime import datetime
 import random
 import logging
 import asyncio
@@ -28,6 +28,8 @@ from matplotlib.figure import Figure
 from PIL import Image
 import matplotlib.pyplot as mpf
 import matplotlib.gridspec as gridspec
+import datetime
+from datetime import datetime as dt
 
 # Configure logging
 logging.basicConfig(
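
Review note: after these two hunks the module binds the name datetime twice — the first hunk imports the class (from datetime import datetime), this one re-imports the module (import datetime), and whichever binding runs last shadows the other. That is why later hunks keep flip-flopping between datetime.now() and datetime.datetime.now(). A minimal illustration of the two spellings, standard library only:

    import datetime                      # binds the module: datetime.datetime.now()
    from datetime import datetime as dt  # binds the class under an alias: dt.now()

    print(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
    print(dt.now().strftime("%Y%m%d_%H%M%S"))

Picking one spelling and deleting the other import would remove the whole class of bugs this commit is patching around.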
@@ -53,7 +55,7 @@ GAMMA = 0.99  # Discount factor
 EPSILON_START = 1.0
 EPSILON_END = 0.05
 EPSILON_DECAY = 10000
-STATE_SIZE = 40  # Size of our state representation
+STATE_SIZE = 64  # Size of our state representation
 LEARNING_RATE = 1e-4
 TARGET_UPDATE = 10  # Update target network every 10 episodes
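
Review note: raising STATE_SIZE from 40 to 64 changes the DQN's input layer, so checkpoints written before this commit no longer match the new default architecture. The architecture-aware save/load below and the state padding in live_trading() exist to absorb exactly this drift.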
@@ -1308,15 +1310,19 @@ class TradingEnvironment:
     def update_price_predictions(self):
         """Update price predictions"""
-        if len(self.features['price']) < 30:
+        if len(self.features['price']) < 30 or not hasattr(self, 'price_predictor') or self.price_predictor is None:
             self.predicted_prices = np.array([])
             return
 
         # Get price history
         price_history = self.features['price']
 
-        # Get predictions
-        self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5)
+        try:
+            # Get predictions
+            self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5)
+        except Exception as e:
+            logger.warning(f"Error updating predictions: {e}")
+            self.predicted_prices = np.array([])
 
     def identify_optimal_trades(self):
         """Identify optimal entry and exit points based on local extrema"""
@@ -1609,39 +1615,40 @@ def get_device():
 # Update Agent class to use GPU properly
 class Agent:
-    def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4,
-                 device=None):
-        if device is None:
-            self.device = get_device()
-        else:
-            self.device = device
-
-        self.state_size = state_size
-        self.action_size = action_size
-        self.memory = ReplayMemory(MEMORY_SIZE)
-        self.steps_done = 0
-        self.epsilon = EPSILON_START  # Initialize epsilon
-
-        # Initialize policy and target networks
+    def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None):
+        """Initialize Agent with architecture parameters stored as attributes"""
+        self.state_size = state_size
+        self.action_size = action_size
+        self.hidden_size = hidden_size          # Store hidden_size as an instance attribute
+        self.lstm_layers = lstm_layers          # Store lstm_layers as an instance attribute
+        self.attention_heads = attention_heads  # Store attention_heads as an instance attribute
+
+        # Set device
+        self.device = device if device is not None else get_device()
+
+        # Initialize networks
         self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
         self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
         self.target_net.load_state_dict(self.policy_net.state_dict())
-        self.target_net.eval()
-
-        # Initialize optimizer with weight decay for regularization
-        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
-
-        # Initialize gradient scaler for mixed precision training
-        self.scaler = amp.GradScaler()
-
-        # TensorBoard writer
-        self.writer = SummaryWriter()
-
-        # For chart visualization
-        self.chart_step = 0
-
-        # Create models directory if it doesn't exist
-        os.makedirs("models", exist_ok=True)
+
+        # Initialize optimizer
+        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
+
+        # Initialize replay memory
+        self.memory = ReplayMemory(MEMORY_SIZE)
+
+        # Initialize exploration parameters
+        self.epsilon = EPSILON_START
+        self.epsilon_decay = EPSILON_DECAY
+        self.epsilon_min = EPSILON_END
+
+        # Initialize step counter
+        self.steps_done = 0
+
+        # Initialize TensorBoard writer
+        self.writer = None
+
+        # Rest of the initialization code...
 
     def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8):
         """Expand the model to handle more features or increase capacity"""
@@ -1727,7 +1734,7 @@ class Agent:
         # Use mixed precision for forward/backward passes
         if self.device.type == "cuda":
-            with amp.autocast():
+            with torch.amp.autocast('cuda'):
                 # Compute Q values
                 current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
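
The switch to torch.amp.autocast('cuda') tracks the deprecation of the old torch.cuda.amp spelling. One thing to double-check: the rewritten __init__ above no longer builds self.scaler, so unless it is created in the elided "rest of the initialization code", any scaler.scale(loss) call in this training step now fails. A minimal sketch of the current-API pattern, assuming PyTorch 2.x, with a stand-in model and dummy batch:

    import torch
    import torch.nn as nn

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = nn.Linear(64, 4).to(device)        # stand-in for the DQN
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    use_amp = device.type == 'cuda'
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    states = torch.randn(32, 64, device=device)   # dummy batch
    targets = torch.randn(32, 4, device=device)

    optimizer.zero_grad()
    with torch.amp.autocast('cuda', enabled=use_amp):
        loss = nn.functional.mse_loss(model(states), targets)
    scaler.scale(loss).backward()   # scale to avoid fp16 gradient underflow
    scaler.step(optimizer)          # unscales, skips step on inf/nan gradients
    scaler.update()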
@@ -1795,54 +1802,74 @@ class Agent:
         self.target_net.load_state_dict(self.policy_net.state_dict())
 
     def save(self, path="models/trading_agent_best_pnl.pt"):
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-        torch.save({
-            'policy_net': self.policy_net.state_dict(),
-            'target_net': self.target_net.state_dict(),
-            'optimizer': self.optimizer.state_dict(),
-            'epsilon': self.epsilon,
-            'steps_done': self.steps_done
-        }, path)
-        logger.info(f"Model saved to {path}")
-
-    def load(self, path="models/trading_agent_best_pnl.pt"):
-        """Load a trained model"""
-        try:
-            # First try with weights_only=True (safer)
-            checkpoint = torch.load(path, map_location=self.device)
-
-            # Check if model architecture matches
-            try:
-                state_dict = checkpoint['policy_net']
-                input_size = state_dict['fc1.weight'].shape[1]
-                output_size = state_dict['advantage_stream.bias'].shape[0]
-                hidden_size = state_dict['fc1.weight'].shape[0]
-
-                # If architecture doesn't match, rebuild the model
-                if (input_size != self.state_size or
-                    output_size != self.action_size or
-                    hidden_size != self.hidden_size):
-                    logger.warning(f"Model architecture mismatch. Rebuilding model with: "
-                                   f"state_size={input_size}, action_size={output_size}, "
-                                   f"hidden_size={hidden_size}")
-
-                    # Rebuild the model with the correct architecture
-                    self.state_size = input_size
-                    self.action_size = output_size
-                    self.hidden_size = hidden_size
-
-                    # Recreate networks with correct architecture
-                    self.policy_net = DQN(self.state_size, self.action_size,
-                                          self.hidden_size, self.lstm_layers,
-                                          self.attention_heads).to(self.device)
-                    self.target_net = DQN(self.state_size, self.action_size,
-                                          self.hidden_size, self.lstm_layers,
-                                          self.attention_heads).to(self.device)
-
-                    # Recreate optimizer
-                    self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
-            except Exception as e:
-                logger.warning(f"Error checking model architecture: {e}")
+        """Save the model in a format compatible with PyTorch 2.6+"""
+        try:
+            # Create directory if it doesn't exist
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+
+            # Ensure architecture parameters are set
+            if not hasattr(self, 'hidden_size'):
+                self.hidden_size = 256  # Default value
+                logger.warning("Setting default hidden_size=256 for saving")
+
+            if not hasattr(self, 'lstm_layers'):
+                self.lstm_layers = 2  # Default value
+                logger.warning("Setting default lstm_layers=2 for saving")
+
+            if not hasattr(self, 'attention_heads'):
+                self.attention_heads = 4  # Default value
+                logger.warning("Setting default attention_heads=4 for saving")
+
+            # Save model state
+            checkpoint = {
+                'policy_net': self.policy_net.state_dict(),
+                'target_net': self.target_net.state_dict(),
+                'optimizer': self.optimizer.state_dict(),
+                'epsilon': self.epsilon,
+                'state_size': self.state_size,
+                'action_size': self.action_size,
+                'hidden_size': self.hidden_size,
+                'lstm_layers': self.lstm_layers,
+                'attention_heads': self.attention_heads
+            }
+
+            # Save with pickle_protocol=4 for better compatibility
+            torch.save(checkpoint, path, _use_new_zipfile_serialization=True, pickle_protocol=4)
+            logger.info(f"Model saved to {path}")
+        except Exception as e:
+            logger.error(f"Error saving model: {e}")
+            import traceback
+            logger.error(traceback.format_exc())
+
+    def load(self, path="models/trading_agent_best_pnl.pt"):
+        """Load a trained model with improved error handling for PyTorch 2.6 compatibility"""
+        try:
+            # First try to load with weights_only=False (for models saved with older PyTorch versions)
+            try:
+                logger.info(f"Attempting to load model with weights_only=False: {path}")
+                checkpoint = torch.load(path, map_location=self.device, weights_only=False)
+                logger.info("Model loaded successfully with weights_only=False")
+            except Exception as e1:
+                logger.warning(f"Failed to load with weights_only=False: {e1}")
+
+                # Try with safe_globals context manager
+                try:
+                    logger.info("Attempting to load with safe_globals context manager")
+                    import numpy as np
+                    from torch.serialization import safe_globals
+
+                    # Add numpy scalar to safe globals
+                    with safe_globals(['numpy._core.multiarray.scalar']):
+                        checkpoint = torch.load(path, map_location=self.device)
+                    logger.info("Model loaded successfully with safe_globals")
+                except Exception as e2:
+                    logger.warning(f"Failed to load with safe_globals: {e2}")
+
+                    # Last resort: try with pickle_module=pickle
+                    logger.info("Attempting to load with pickle_module")
+                    import pickle
+                    checkpoint = torch.load(path, map_location=self.device, pickle_module=pickle, weights_only=False)
+                    logger.info("Model loaded successfully with pickle_module")
 
             # Load state dictionaries
             self.policy_net.load_state_dict(checkpoint['policy_net'])
@@ -1858,28 +1885,38 @@ class Agent:
             if 'epsilon' in checkpoint:
                 self.epsilon = checkpoint['epsilon']
 
-            logger.info(f"Model loaded from {path}")
-        except Exception as e:
-            logger.warning(f"Error loading model with default method: {e}")
-
-            # Try with weights_only=False as fallback
-            try:
-                checkpoint = torch.load(path, map_location=self.device, weights_only=False)
-                self.policy_net.load_state_dict(checkpoint['policy_net'])
-                self.target_net.load_state_dict(checkpoint['target_net'])
-                try:
-                    self.optimizer.load_state_dict(checkpoint['optimizer'])
-                except:
-                    logger.warning("Could not load optimizer state")
-
-                if 'epsilon' in checkpoint:
-                    self.epsilon = checkpoint['epsilon']
-
-                logger.info(f"Model loaded from {path} with weights_only=False")
-            except Exception as e:
-                logger.error(f"Failed to load model: {e}")
-                raise
+            # Load architecture parameters if available
+            if 'state_size' in checkpoint:
+                self.state_size = checkpoint['state_size']
+            if 'action_size' in checkpoint:
+                self.action_size = checkpoint['action_size']
+            if 'hidden_size' in checkpoint:
+                self.hidden_size = checkpoint['hidden_size']
+            else:
+                # If hidden_size not in checkpoint, infer from model
+                try:
+                    self.hidden_size = self.policy_net.fc1.weight.shape[0]
+                    logger.info(f"Inferred hidden_size={self.hidden_size} from model")
+                except:
+                    self.hidden_size = 256  # Default value
+                    logger.warning(f"Could not infer hidden_size, using default: {self.hidden_size}")
+
+            if 'lstm_layers' in checkpoint:
+                self.lstm_layers = checkpoint['lstm_layers']
+            else:
+                self.lstm_layers = 2  # Default value
+
+            if 'attention_heads' in checkpoint:
+                self.attention_heads = checkpoint['attention_heads']
+            else:
+                self.attention_heads = 4  # Default value
+
+            logger.info(f"Model loaded successfully from {path}")
+        except Exception as e:
+            logger.error(f"Error loading model: {e}")
+            import traceback
+            logger.error(traceback.format_exc())
+            raise
 
     def add_chart_to_tensorboard(self, env, global_step):
         """Add trading chart to TensorBoard"""
@@ -1903,9 +1940,10 @@ class Agent:
             self.writer.add_image('Trading Chart', chart_array, global_step)
 
             # Add position information as text
+            entry_price = env.entry_price if env.entry_price else 0.00
             position_info = f"""
             **Current Position**: {env.position.upper()}
-            **Entry Price**: ${env.entry_price:.2f if env.entry_price else 0:.2f}
+            **Entry Price**: ${entry_price:.2f}
             **Current Price**: ${env.data[-1]['close']:.2f}
             **Position Size**: ${env.position_size:.2f}
             **Unrealized PnL**: ${env.total_pnl:.2f}
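
The replaced line was a latent runtime error: in an f-string, everything after the first top-level colon is handed to the formatter as the format spec, so {env.entry_price:.2f if env.entry_price else 0:.2f} passes the spec ".2f if env.entry_price else 0:.2f" and raises ValueError the first time it renders. Resolving the fallback before formatting, as the fix does, is the cure:

    entry_price = None  # e.g. no open position

    # Broken form (commented out): raises ValueError: Invalid format specifier
    # print(f"${entry_price:.2f if entry_price else 0:.2f}")

    safe_price = entry_price if entry_price else 0.00
    print(f"${safe_price:.2f}")  # -> $0.00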
@@ -2345,21 +2383,30 @@ async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit
         return []
 
 async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50):
-    """
-    Run the trading bot in live mode with enhanced error handling and monitoring
-
-    Args:
-        agent: Trained trading agent
-        env: Trading environment
-        exchange: CCXT exchange instance
-        symbol: Trading pair (default: ETH/USDT)
-        timeframe: Candle timeframe (default: 1m)
-        demo: If True, simulate trades without executing (default: True)
-        leverage: Leverage for futures trading (default: 50x)
-    """
+    """Run the trading bot in live mode with enhanced error handling"""
     logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe")
     logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")
 
+    # Verify agent is properly initialized
+    try:
+        # Ensure agent has all required attributes
+        if not hasattr(agent, 'hidden_size'):
+            agent.hidden_size = 256  # Default value
+            logger.warning("Agent missing hidden_size attribute, using default: 256")
+
+        if not hasattr(agent, 'lstm_layers'):
+            agent.lstm_layers = 2  # Default value
+            logger.warning("Agent missing lstm_layers attribute, using default: 2")
+
+        if not hasattr(agent, 'attention_heads'):
+            agent.attention_heads = 4  # Default value
+            logger.warning("Agent missing attention_heads attribute, using default: 4")
+
+        logger.info(f"Agent configuration: state_size={agent.state_size}, action_size={agent.action_size}, hidden_size={agent.hidden_size}")
+    except Exception as e:
+        logger.error(f"Error checking agent configuration: {e}")
+        # Continue anyway, as these are just informational attributes
+
     if not demo:
         # Confirm with user before starting live trading
         confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
@@ -2378,7 +2425,10 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
     # Initialize TensorBoard for monitoring
     if not hasattr(agent, 'writer') or agent.writer is None:
-        agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
+        from torch.utils.tensorboard import SummaryWriter
+        # Fix the datetime usage here
+        current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+        agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{current_time}')
 
     # Track performance metrics
     trades_count = 0
@@ -2391,10 +2441,14 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
     # Create directory for trade logs
     os.makedirs('trade_logs', exist_ok=True)
-    trade_log_path = f'trade_logs/trades_{datetime.now().strftime("%Y%m%d_%H%M%S")}.csv'
+    # Fix the datetime usage here
+    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    trade_log_path = f'trade_logs/trades_{current_time}.csv'
     with open(trade_log_path, 'w') as f:
         f.write("timestamp,action,price,position_size,balance,pnl\n")
 
+    logger.info("Entering live trading loop...")
+
     try:
         while True:
             try:
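
Since each session now writes a timestamped CSV with the header above, a quick post-run inspection might look like this (a sketch; pandas is already imported by this module):

    import pandas as pd

    trades = pd.read_csv(trade_log_path, parse_dates=['timestamp'])
    print(trades.groupby('action')['pnl'].sum())  # PnL per action type
    print(trades['balance'].iloc[-1])             # final balance of the session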
@@ -2410,19 +2464,54 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
                 # Get current state and select action
                 state = env.get_state()
 
+                # Verify state shape matches agent's expected input
+                if state.shape[0] != agent.state_size:
+                    logger.warning(f"State size mismatch: got {state.shape[0]}, expected {agent.state_size}")
+                    # Pad or truncate state to match expected size
+                    if state.shape[0] < agent.state_size:
+                        state = np.pad(state, (0, agent.state_size - state.shape[0]))
+                    else:
+                        state = state[:agent.state_size]
+
                 action = agent.select_action(state, training=False)
 
+                # Ensure action is valid
+                if action >= agent.action_size:
+                    logger.warning(f"Invalid action {action}, clipping to {agent.action_size-1}")
+                    action = agent.action_size - 1
+
+                # Log action
+                action_name = "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE"
+                logger.info(f"Step {step_counter}: Action selected: {action_name}, Price: ${env.data[-1]['close']:.2f}")
+
                 # Execute action
                 if not demo:
                     # Execute real trade on exchange
                     current_price = env.data[-1]['close']
                     trade_result = await env.execute_real_trade(exchange, action, current_price)
-                    if not trade_result['success']:
-                        logger.error(f"Trade execution failed: {trade_result['error']}")
+                    if trade_result is None or not isinstance(trade_result, dict) or not trade_result.get('success', False):
+                        error_msg = trade_result.get('error', 'Unknown error') if isinstance(trade_result, dict) else 'Trade execution failed'
+                        logger.error(f"Trade execution failed: {error_msg}")
                         # Continue with simulated trade for tracking purposes
 
                 # Update environment with action (simulated in demo mode)
-                next_state, reward, done, info = env.step(action)
+                try:
+                    next_state, reward, done, info = env.step(action)
+                except ValueError as e:
+                    # Handle case where step returns 3 values instead of 4
+                    if "not enough values to unpack" in str(e):
+                        logger.warning("Step function returned 3 values instead of 4, creating info dict")
+                        next_state, reward, done = env.step(action)
+                        info = {
+                            'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
+                            'price': env.current_price,
+                            'balance': env.balance,
+                            'position': env.position,
+                            'pnl': env.total_pnl
+                        }
+                    else:
+                        raise
 
                 # Log trade if position changed
                 if env.position != prev_position:
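
Padding or truncating the live state keeps the loop alive when STATE_SIZE drifts from the checkpoint, but zero-padded entries are features the network never saw during training, so treat this as a crash guard rather than a fix; retraining with a matching STATE_SIZE is the clean path. The guard as a standalone helper:

    import numpy as np

    def fit_state(state: np.ndarray, expected_size: int) -> np.ndarray:
        """Zero-pad or truncate a 1-D state vector to the network's input size."""
        if state.shape[0] < expected_size:
            return np.pad(state, (0, expected_size - state.shape[0]))
        return state[:expected_size]

    print(fit_state(np.ones(40), 64).shape)  # (64,) - zero-padded
    print(fit_state(np.ones(80), 64).shape)  # (64,) - truncated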
@@ -2433,7 +2522,7 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
                 # Log trade details
                 with open(trade_log_path, 'a') as f:
-                    f.write(f"{datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
+                    f.write(f"{datetime.datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
 
                 logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
@@ -2472,10 +2561,12 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
                 prev_position = env.position
 
                 # Wait for next candle
+                logger.info(f"Waiting for next candle... (Step {step_counter})")
                 await asyncio.sleep(10)  # Check every 10 seconds
 
             except Exception as e:
                 logger.error(f"Error in live trading loop: {str(e)}")
+                import traceback
                 logger.error(traceback.format_exc())
                 logger.info("Continuing after error...")
                 await asyncio.sleep(30)  # Wait longer after an error
@@ -2549,7 +2640,26 @@ async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
         logger.error(f"Error fetching OHLCV data: {e}")
         return []
 
+# Add this near the top of the file, after imports
+def ensure_pytorch_compatibility():
+    """Ensure compatibility with PyTorch 2.6+ for model loading"""
+    try:
+        import torch
+        from torch.serialization import add_safe_globals
+        import numpy as np
+
+        # Add numpy scalar to safe globals for PyTorch 2.6+
+        add_safe_globals(['numpy._core.multiarray.scalar'])
+        logger.info("Added numpy scalar to PyTorch safe globals")
+    except (ImportError, AttributeError) as e:
+        logger.warning(f"Could not configure PyTorch compatibility: {e}")
+        logger.warning("This might cause issues with model loading in PyTorch 2.6+")
+
+# Call this function at the start of the main function
 async def main():
+    # Ensure PyTorch compatibility
+    ensure_pytorch_compatibility()
+
     parser = argparse.ArgumentParser(description='Trading Bot')
     parser.add_argument('--mode', type=str, choices=['train', 'eval', 'live'], default='train',
                         help='Operation mode: train, eval, or live')
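
Review note: the same string-versus-object caveat flagged at the load() hunk applies here; add_safe_globals is documented to take the imported objects (e.g. the numpy scalar function itself), so the dotted-path string probably has no effect. See the allowlisting sketch above for the object-based form.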
@@ -2583,60 +2693,84 @@ async def main():
         # Create environment
         env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode)
 
-        # Fetch initial data
-        await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
-
-        # Create agent
-        agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
-
         if args.mode == 'train':
+            # Fetch initial data for training
+            await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
+
+            # Create agent with consistent parameters
+            # Note: Using STATE_SIZE and action_size=4 for consistency
+            agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
+
             # Train the agent
             logger.info(f"Starting training for {args.episodes} episodes...")
             stats = await train_agent(agent, env, num_episodes=args.episodes)
 
-        elif args.mode == 'evaluate':
-            # Load trained model
-            agent.load("models/trading_agent_best_pnl.pt")
-
-            # Evaluate the agent
-            logger.info("Evaluating agent...")
-            avg_reward, avg_profit, win_rate = evaluate_agent(agent, env)
-
-        elif args.mode == 'live':
-            # Initialize exchange
-            exchange = await initialize_exchange()
-
-            # Load model
-            model_path = args.model if args.model else "models/trading_agent.pt"
+        elif args.mode == 'eval' or args.mode == 'live':
+            # Fetch initial data for the specified symbol and timeframe
+            await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)
+
+            # Determine model path
+            model_path = args.model if args.model else "models/trading_agent_best_pnl.pt"
             if not os.path.exists(model_path):
                 logger.error(f"Model file not found: {model_path}")
                 return
 
-            # Initialize environment
-            env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=WINDOW_SIZE, demo=demo_mode)
-            await env.fetch_initial_data(exchange, symbol=args.symbol, timeframe=args.timeframe)
-
-            # Initialize agent
-            state_size = env.get_state().shape[0]
-            agent = Agent(state_size=state_size, action_size=3)
-            agent.load(model_path)
-
-            # Start live trading
-            await live_trading(
-                agent=agent,
-                env=env,
-                exchange=exchange,
-                symbol=args.symbol,
-                timeframe=args.timeframe,
-                demo=demo_mode,
-                leverage=args.leverage
-            )
+            # Create agent with default parameters
+            agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
+
+            # Try to load the model
+            try:
+                # Add numpy scalar to safe globals before loading
+                import numpy as np
+                from torch.serialization import add_safe_globals
+
+                # Add numpy scalar to safe globals
+                add_safe_globals(['numpy._core.multiarray.scalar'])
+
+                # Load the model
+                agent.load(model_path)
+                logger.info(f"Model loaded successfully from {model_path}")
+            except Exception as e:
+                logger.error(f"Failed to load model: {e}")
+                # Ask user if they want to continue with a new model
+                if args.mode == 'live':
+                    confirmation = input("Failed to load model. Continue with a new model? (y/n): ")
+                    if confirmation.lower() != 'y':
+                        logger.info("Live trading canceled by user")
+                        return
+                    logger.info("Continuing with a new model")
+                else:
+                    logger.info("Continuing evaluation with a new model")
+
+            if args.mode == 'eval':
+                # Evaluate the agent
+                logger.info("Evaluating agent...")
+                avg_reward, avg_profit, win_rate = evaluate_agent(agent, env, num_episodes=args.episodes)
+
+            elif args.mode == 'live':
+                # Start live trading
+                logger.info(f"Starting live trading for {args.symbol} on {args.timeframe} timeframe")
+                logger.info(f"Demo mode: {demo_mode}, Leverage: {args.leverage}x")
+
+                await live_trading(
+                    agent=agent,
+                    env=env,
+                    exchange=exchange,
+                    symbol=args.symbol,
+                    timeframe=args.timeframe,
+                    demo=demo_mode,
+                    leverage=args.leverage
+                )
+
+    except Exception as e:
+        logger.error(f"Error in main function: {e}")
+        import traceback
+        logger.error(traceback.format_exc())
 
     finally:
-        # Clean up exchange connection - safely close if possible
+        # Clean up exchange connection
         if exchange:
             try:
-                # Some CCXT exchanges have close method, others don't
                 if hasattr(exchange, 'close'):
                     await exchange.close()
                 elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
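
With eval and live unified, this also fixes the old elif args.mode == 'evaluate': branch, which could never match since the parser only accepts 'eval'. The intended invocations now look roughly like this (hypothetical script name; flags inferred from the argparse fragments visible in this diff):

    python main.py --mode train --episodes 100
    python main.py --mode eval --model models/trading_agent_best_pnl.pt
    python main.py --mode live --symbol ETH/USDT --timeframe 1m --leverage 50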
@@ -2667,7 +2801,8 @@ def create_candlestick_figure(data, trade_signals, window_size=100, title=""):
     # Plot candlesticks - use a simpler approach if mplfinance fails
     try:
-        mpf.plot(df, type='candle', style='yahoo', ax=price_ax, volume=volume_ax)
+        # Use a different style or approach that doesn't use 'type' parameter
+        mpf.plot(df, type='candle', ax=price_ax, volume=volume_ax, style='yahoo')
     except Exception as e:
         logger.warning(f"Error plotting with mplfinance: {e}, falling back to simple plot")
         # Fallback to simple plot
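
Review note: the new comment claims the call avoids the 'type' parameter, but the code still passes type='candle'. More importantly, the module was imported above as import matplotlib.pyplot as mpf, while mpf.plot(df, type=..., volume=...) is the mplfinance API; matplotlib.pyplot.plot accepts no such keywords, so this call raises and the fallback always runs. A sketch of the presumably intended import and external-axes call, assuming the mplfinance package is available:

    import matplotlib.pyplot as plt
    import mplfinance as mpf  # the package whose plot() takes type= and volume=
    import pandas as pd

    # Dummy OHLCV frame with the capitalized columns and DatetimeIndex mplfinance expects
    idx = pd.date_range("2025-03-17", periods=5, freq="1min")
    df = pd.DataFrame({"Open": [1, 2, 3, 4, 5], "High": [2, 3, 4, 5, 6],
                       "Low": [0.5, 1, 2, 3, 4], "Close": [1.5, 2.5, 3.5, 4.5, 5.5],
                       "Volume": [10, 20, 30, 40, 50]}, index=idx)

    fig = plt.figure()
    price_ax = fig.add_subplot(2, 1, 1)
    volume_ax = fig.add_subplot(2, 1, 2)
    mpf.plot(df, type='candle', ax=price_ax, volume=volume_ax)  # external-axes mode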