fix training & demo mode

Dobromir Popov 2025-03-17 03:59:50 +02:00
parent c63b1a2daf
commit e87207c1fa


@ -3,7 +3,7 @@ import time
import json
import numpy as np
import pandas as pd
import datetime
from datetime import datetime
import random
import logging
import asyncio
@ -28,6 +28,8 @@ from matplotlib.figure import Figure
from PIL import Image
import matplotlib.pyplot as mpf
import matplotlib.gridspec as gridspec
import datetime
from datetime import datetime as dt
# Configure logging
logging.basicConfig(
@ -53,7 +55,7 @@ GAMMA = 0.99 # Discount factor
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = 10000
STATE_SIZE = 40 # Size of our state representation
STATE_SIZE = 64 # Size of our state representation
LEARNING_RATE = 1e-4
TARGET_UPDATE = 10 # Update target network every 10 episodes
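Editor's note: these constants drive the agent's epsilon-greedy exploration. A minimal sketch of the exponential anneal such values are typically plugged into (an assumption; the repository may use a different decay formula):

import math

EPSILON_START, EPSILON_END, EPSILON_DECAY = 1.0, 0.05, 10000

def epsilon_at(steps_done):
    # Anneal the exploration rate from EPSILON_START toward EPSILON_END over ~EPSILON_DECAY steps.
    return EPSILON_END + (EPSILON_START - EPSILON_END) * math.exp(-steps_done / EPSILON_DECAY)

print(epsilon_at(0))       # 1.0 at the start of training
print(epsilon_at(10000))   # ~0.40 after one decay constant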
@ -1308,15 +1310,19 @@ class TradingEnvironment:
def update_price_predictions(self):
"""Update price predictions"""
if len(self.features['price']) < 30:
if len(self.features['price']) < 30 or not hasattr(self, 'price_predictor') or self.price_predictor is None:
self.predicted_prices = np.array([])
return
# Get price history
price_history = self.features['price']
# Get predictions
self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5)
try:
# Get predictions
self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5)
except Exception as e:
logger.warning(f"Error updating predictions: {e}")
self.predicted_prices = np.array([])
def identify_optimal_trades(self):
"""Identify optimal entry and exit points based on local extrema"""
@ -1609,39 +1615,40 @@ def get_device():
# Update Agent class to use GPU properly
class Agent:
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4,
device=None):
if device is None:
self.device = get_device()
else:
self.device = device
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None):
"""Initialize Agent with architecture parameters stored as attributes"""
self.state_size = state_size
self.action_size = action_size
self.memory = ReplayMemory(MEMORY_SIZE)
self.steps_done = 0
self.epsilon = EPSILON_START # Initialize epsilon
self.hidden_size = hidden_size # Store hidden_size as an instance attribute
self.lstm_layers = lstm_layers # Store lstm_layers as an instance attribute
self.attention_heads = attention_heads # Store attention_heads as an instance attribute
# Initialize policy and target networks
# Set device
self.device = device if device is not None else get_device()
# Initialize networks
self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
self.target_net.load_state_dict(self.policy_net.state_dict())
self.target_net.eval()
# Initialize optimizer with weight decay for regularization
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
# Initialize optimizer
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
# Initialize gradient scaler for mixed precision training
self.scaler = amp.GradScaler()
# Initialize replay memory
self.memory = ReplayMemory(MEMORY_SIZE)
# TensorBoard writer
self.writer = SummaryWriter()
# Initialize exploration parameters
self.epsilon = EPSILON_START
self.epsilon_decay = EPSILON_DECAY
self.epsilon_min = EPSILON_END
# For chart visualization
self.chart_step = 0
# Initialize step counter
self.steps_done = 0
# Create models directory if it doesn't exist
os.makedirs("models", exist_ok=True)
# Initialize TensorBoard writer
self.writer = None
# Rest of the initialization code...
def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8):
"""Expand the model to handle more features or increase capacity"""
@ -1727,7 +1734,7 @@ class Agent:
# Use mixed precision for forward/backward passes
if self.device.type == "cuda":
with amp.autocast():
with torch.amp.autocast('cuda'):
# Compute Q values
current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1))
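Editor's note: the autocast change above moves from the deprecated torch.cuda.amp entry point to the device-agnostic torch.amp API. A minimal, self-contained sketch of the full mixed-precision update on recent PyTorch (a stand-in linear net replaces the DQN; shapes and hyperparameters are illustrative only):

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
policy_net = nn.Linear(64, 4).to(device)                     # stand-in for the DQN
optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-4)
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda'))

states = torch.randn(32, 64, device=device)
actions = torch.randint(0, 4, (32,), device=device)
target_q = torch.randn(32, device=device)

with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):   # newer spelling of torch.cuda.amp.autocast()
    q_values = policy_net(states).gather(1, actions.unsqueeze(1))
    loss = F.smooth_l1_loss(q_values, target_q.unsqueeze(1))

optimizer.zero_grad()
scaler.scale(loss).backward()    # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)           # unscale gradients, then take the optimizer step
scaler.update()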
@ -1795,54 +1802,74 @@ class Agent:
self.target_net.load_state_dict(self.policy_net.state_dict())
def save(self, path="models/trading_agent_best_pnl.pt"):
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save({
'policy_net': self.policy_net.state_dict(),
'target_net': self.target_net.state_dict(),
'optimizer': self.optimizer.state_dict(),
'epsilon': self.epsilon,
'steps_done': self.steps_done
}, path)
logger.info(f"Model saved to {path}")
"""Save the model in a format compatible with PyTorch 2.6+"""
try:
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(path), exist_ok=True)
# Ensure architecture parameters are set
if not hasattr(self, 'hidden_size'):
self.hidden_size = 256 # Default value
logger.warning("Setting default hidden_size=256 for saving")
if not hasattr(self, 'lstm_layers'):
self.lstm_layers = 2 # Default value
logger.warning("Setting default lstm_layers=2 for saving")
if not hasattr(self, 'attention_heads'):
self.attention_heads = 4 # Default value
logger.warning("Setting default attention_heads=4 for saving")
# Save model state
checkpoint = {
'policy_net': self.policy_net.state_dict(),
'target_net': self.target_net.state_dict(),
'optimizer': self.optimizer.state_dict(),
'epsilon': self.epsilon,
'state_size': self.state_size,
'action_size': self.action_size,
'hidden_size': self.hidden_size,
'lstm_layers': self.lstm_layers,
'attention_heads': self.attention_heads
}
# Save with pickle_protocol=4 for better compatibility
torch.save(checkpoint, path, _use_new_zipfile_serialization=True, pickle_protocol=4)
logger.info(f"Model saved to {path}")
except Exception as e:
logger.error(f"Error saving model: {e}")
import traceback
logger.error(traceback.format_exc())
def load(self, path="models/trading_agent_best_pnl.pt"):
"""Load a trained model"""
"""Load a trained model with improved error handling for PyTorch 2.6 compatibility"""
try:
# First try with weights_only=True (safer)
checkpoint = torch.load(path, map_location=self.device)
# Check if model architecture matches
# First try to load with weights_only=False (for models saved with older PyTorch versions)
try:
state_dict = checkpoint['policy_net']
input_size = state_dict['fc1.weight'].shape[1]
output_size = state_dict['advantage_stream.bias'].shape[0]
hidden_size = state_dict['fc1.weight'].shape[0]
logger.info(f"Attempting to load model with weights_only=False: {path}")
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
logger.info("Model loaded successfully with weights_only=False")
except Exception as e1:
logger.warning(f"Failed to load with weights_only=False: {e1}")
# If architecture doesn't match, rebuild the model
if (input_size != self.state_size or
output_size != self.action_size or
hidden_size != self.hidden_size):
logger.warning(f"Model architecture mismatch. Rebuilding model with: "
f"state_size={input_size}, action_size={output_size}, "
f"hidden_size={hidden_size}")
# Try with safe_globals context manager
try:
logger.info("Attempting to load with safe_globals context manager")
import numpy as np
from torch.serialization import safe_globals
# Rebuild the model with the correct architecture
self.state_size = input_size
self.action_size = output_size
self.hidden_size = hidden_size
# Add numpy scalar to safe globals
with safe_globals(['numpy._core.multiarray.scalar']):
checkpoint = torch.load(path, map_location=self.device)
logger.info("Model loaded successfully with safe_globals")
except Exception as e2:
logger.warning(f"Failed to load with safe_globals: {e2}")
# Recreate networks with correct architecture
self.policy_net = DQN(self.state_size, self.action_size,
self.hidden_size, self.lstm_layers,
self.attention_heads).to(self.device)
self.target_net = DQN(self.state_size, self.action_size,
self.hidden_size, self.lstm_layers,
self.attention_heads).to(self.device)
# Recreate optimizer
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
except Exception as e:
logger.warning(f"Error checking model architecture: {e}")
# Last resort: try with pickle_module=pickle
logger.info("Attempting to load with pickle_module")
import pickle
checkpoint = torch.load(path, map_location=self.device, pickle_module=pickle, weights_only=False)
logger.info("Model loaded successfully with pickle_module")
# Load state dictionaries
self.policy_net.load_state_dict(checkpoint['policy_net'])
@ -1858,28 +1885,38 @@ class Agent:
if 'epsilon' in checkpoint:
self.epsilon = checkpoint['epsilon']
logger.info(f"Model loaded from {path}")
except Exception as e:
logger.warning(f"Error loading model with default method: {e}")
# Try with weights_only=False as fallback
try:
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
self.policy_net.load_state_dict(checkpoint['policy_net'])
self.target_net.load_state_dict(checkpoint['target_net'])
# Load architecture parameters if available
if 'state_size' in checkpoint:
self.state_size = checkpoint['state_size']
if 'action_size' in checkpoint:
self.action_size = checkpoint['action_size']
if 'hidden_size' in checkpoint:
self.hidden_size = checkpoint['hidden_size']
else:
# If hidden_size not in checkpoint, infer from model
try:
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.hidden_size = self.policy_net.fc1.weight.shape[0]
logger.info(f"Inferred hidden_size={self.hidden_size} from model")
except:
logger.warning("Could not load optimizer state")
self.hidden_size = 256 # Default value
logger.warning(f"Could not infer hidden_size, using default: {self.hidden_size}")
if 'epsilon' in checkpoint:
self.epsilon = checkpoint['epsilon']
if 'lstm_layers' in checkpoint:
self.lstm_layers = checkpoint['lstm_layers']
else:
self.lstm_layers = 2 # Default value
logger.info(f"Model loaded from {path} with weights_only=False")
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
if 'attention_heads' in checkpoint:
self.attention_heads = checkpoint['attention_heads']
else:
self.attention_heads = 4 # Default value
logger.info(f"Model loaded successfully from {path}")
except Exception as e:
logger.error(f"Error loading model: {e}")
import traceback
logger.error(traceback.format_exc())
raise
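Editor's note: the fallback chain above exists because PyTorch 2.6 flipped the torch.load default to weights_only=True, which rejects the numpy scalars pickled into older checkpoints. A minimal sketch of the allow-listing approach, assuming PyTorch >= 2.5 where the safe_globals context manager is available (the diff passes a dotted name as a string; this sketch passes the object itself):

import numpy as np
import torch
from torch.serialization import safe_globals

# NumPy 2.x moved the scalar reducer to numpy._core; NumPy 1.x exposes it under numpy.core.
scalar = getattr(np, "_core", getattr(np, "core", None)).multiarray.scalar

# Allow-list just that global for this one load, then revert to the stricter default.
with safe_globals([scalar]):
    checkpoint = torch.load("models/trading_agent_best_pnl.pt", map_location="cpu")

print(sorted(checkpoint.keys()))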
def add_chart_to_tensorboard(self, env, global_step):
"""Add trading chart to TensorBoard"""
@ -1903,9 +1940,10 @@ class Agent:
self.writer.add_image('Trading Chart', chart_array, global_step)
# Add position information as text
entry_price = env.entry_price if env.entry_price else 0.00
position_info = f"""
**Current Position**: {env.position.upper()}
**Entry Price**: ${env.entry_price:.2f if env.entry_price else 0:.2f}
**Entry Price**: ${entry_price:.2f}
**Current Price**: ${env.data[-1]['close']:.2f}
**Position Size**: ${env.position_size:.2f}
**Unrealized PnL**: ${env.total_pnl:.2f}
@ -2345,21 +2383,30 @@ async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit
return []
async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50):
"""
Run the trading bot in live mode with enhanced error handling and monitoring
Args:
agent: Trained trading agent
env: Trading environment
exchange: CCXT exchange instance
symbol: Trading pair (default: ETH/USDT)
timeframe: Candle timeframe (default: 1m)
demo: If True, simulate trades without executing (default: True)
leverage: Leverage for futures trading (default: 50x)
"""
"""Run the trading bot in live mode with enhanced error handling"""
logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe")
logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}")
# Verify agent is properly initialized
try:
# Ensure agent has all required attributes
if not hasattr(agent, 'hidden_size'):
agent.hidden_size = 256 # Default value
logger.warning("Agent missing hidden_size attribute, using default: 256")
if not hasattr(agent, 'lstm_layers'):
agent.lstm_layers = 2 # Default value
logger.warning("Agent missing lstm_layers attribute, using default: 2")
if not hasattr(agent, 'attention_heads'):
agent.attention_heads = 4 # Default value
logger.warning("Agent missing attention_heads attribute, using default: 4")
logger.info(f"Agent configuration: state_size={agent.state_size}, action_size={agent.action_size}, hidden_size={agent.hidden_size}")
except Exception as e:
logger.error(f"Error checking agent configuration: {e}")
# Continue anyway, as these are just informational attributes
if not demo:
# Confirm with user before starting live trading
confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ")
@ -2378,7 +2425,10 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
# Initialize TensorBoard for monitoring
if not hasattr(agent, 'writer') or agent.writer is None:
agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
from torch.utils.tensorboard import SummaryWriter
# Fix the datetime usage here
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{current_time}')
# Track performance metrics
trades_count = 0
@ -2391,10 +2441,14 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
# Create directory for trade logs
os.makedirs('trade_logs', exist_ok=True)
trade_log_path = f'trade_logs/trades_{datetime.now().strftime("%Y%m%d_%H%M%S")}.csv'
# Fix the datetime usage here
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
trade_log_path = f'trade_logs/trades_{current_time}.csv'
with open(trade_log_path, 'w') as f:
f.write("timestamp,action,price,position_size,balance,pnl\n")
logger.info("Entering live trading loop...")
try:
while True:
try:
@ -2410,19 +2464,54 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
# Get current state and select action
state = env.get_state()
# Verify state shape matches agent's expected input
if state.shape[0] != agent.state_size:
logger.warning(f"State size mismatch: got {state.shape[0]}, expected {agent.state_size}")
# Pad or truncate state to match expected size
if state.shape[0] < agent.state_size:
state = np.pad(state, (0, agent.state_size - state.shape[0]))
else:
state = state[:agent.state_size]
action = agent.select_action(state, training=False)
# Ensure action is valid
if action >= agent.action_size:
logger.warning(f"Invalid action {action}, clipping to {agent.action_size-1}")
action = agent.action_size - 1
# Log action
action_name = "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE"
logger.info(f"Step {step_counter}: Action selected: {action_name}, Price: ${env.data[-1]['close']:.2f}")
# Execute action
if not demo:
# Execute real trade on exchange
current_price = env.data[-1]['close']
trade_result = await env.execute_real_trade(exchange, action, current_price)
if not trade_result['success']:
logger.error(f"Trade execution failed: {trade_result['error']}")
if trade_result is None or not isinstance(trade_result, dict) or not trade_result.get('success', False):
error_msg = trade_result.get('error', 'Unknown error') if isinstance(trade_result, dict) else 'Trade execution failed'
logger.error(f"Trade execution failed: {error_msg}")
# Continue with simulated trade for tracking purposes
# Update environment with action (simulated in demo mode)
next_state, reward, done, info = env.step(action)
try:
next_state, reward, done, info = env.step(action)
except ValueError as e:
# Handle case where step returns 3 values instead of 4
if "not enough values to unpack" in str(e):
logger.warning("Step function returned 3 values instead of 4, creating info dict")
next_state, reward, done = env.step(action)
info = {
'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
'price': env.current_price,
'balance': env.balance,
'position': env.position,
'pnl': env.total_pnl
}
else:
raise
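Editor's note: the ValueError handling above probes the unpack at the call site. A small wrapper (hypothetical helper name step_env, a sketch only) normalizes the return arity with a single call to env.step():

def step_env(env, action):
    """Call env.step once and normalize the result to (next_state, reward, done, info)."""
    result = env.step(action)
    if len(result) == 4:
        return result
    next_state, reward, done = result
    return next_state, reward, done, {}    # synthesize an empty info dict for 3-tuple envs

Usage: next_state, reward, done, info = step_env(env, action).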
# Log trade if position changed
if env.position != prev_position:
@ -2433,7 +2522,7 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
# Log trade details
with open(trade_log_path, 'a') as f:
f.write(f"{datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
f.write(f"{datetime.datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
@ -2472,10 +2561,12 @@ async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m",
prev_position = env.position
# Wait for next candle
logger.info(f"Waiting for next candle... (Step {step_counter})")
await asyncio.sleep(10) # Check every 10 seconds
except Exception as e:
logger.error(f"Error in live trading loop: {str(e)}")
import traceback
logger.error(traceback.format_exc())
logger.info("Continuing after error...")
await asyncio.sleep(30) # Wait longer after an error
@ -2549,7 +2640,26 @@ async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
logger.error(f"Error fetching OHLCV data: {e}")
return []
# Add this near the top of the file, after imports
def ensure_pytorch_compatibility():
"""Ensure compatibility with PyTorch 2.6+ for model loading"""
try:
import torch
from torch.serialization import add_safe_globals
import numpy as np
# Add numpy scalar to safe globals for PyTorch 2.6+
add_safe_globals(['numpy._core.multiarray.scalar'])
logger.info("Added numpy scalar to PyTorch safe globals")
except (ImportError, AttributeError) as e:
logger.warning(f"Could not configure PyTorch compatibility: {e}")
logger.warning("This might cause issues with model loading in PyTorch 2.6+")
# Call this function at the start of the main function
async def main():
# Ensure PyTorch compatibility
ensure_pytorch_compatibility()
parser = argparse.ArgumentParser(description='Trading Bot')
parser.add_argument('--mode', type=str, choices=['train', 'eval', 'live'], default='train',
help='Operation mode: train, eval, or live')
@ -2583,60 +2693,84 @@ async def main():
# Create environment
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode)
# Fetch initial data
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
# Create agent
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
if args.mode == 'train':
# Fetch initial data for training
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
# Create agent with consistent parameters
# Note: Using STATE_SIZE and action_size=4 for consistency
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
# Train the agent
logger.info(f"Starting training for {args.episodes} episodes...")
stats = await train_agent(agent, env, num_episodes=args.episodes)
elif args.mode == 'evaluate':
# Load trained model
agent.load("models/trading_agent_best_pnl.pt")
elif args.mode == 'eval' or args.mode == 'live':
# Fetch initial data for the specified symbol and timeframe
await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)
# Evaluate the agent
logger.info("Evaluating agent...")
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env)
elif args.mode == 'live':
# Initialize exchange
exchange = await initialize_exchange()
# Load model
model_path = args.model if args.model else "models/trading_agent.pt"
# Determine model path
model_path = args.model if args.model else "models/trading_agent_best_pnl.pt"
if not os.path.exists(model_path):
logger.error(f"Model file not found: {model_path}")
return
# Initialize environment
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=WINDOW_SIZE, demo=demo_mode)
await env.fetch_initial_data(exchange, symbol=args.symbol, timeframe=args.timeframe)
# Create agent with default parameters
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
# Initialize agent
state_size = env.get_state().shape[0]
agent = Agent(state_size=state_size, action_size=3)
agent.load(model_path)
# Try to load the model
try:
# Add numpy scalar to safe globals before loading
import numpy as np
from torch.serialization import add_safe_globals
# Start live trading
await live_trading(
agent=agent,
env=env,
exchange=exchange,
symbol=args.symbol,
timeframe=args.timeframe,
demo=demo_mode,
leverage=args.leverage
)
# Add numpy scalar to safe globals
add_safe_globals(['numpy._core.multiarray.scalar'])
# Load the model
agent.load(model_path)
logger.info(f"Model loaded successfully from {model_path}")
except Exception as e:
logger.error(f"Failed to load model: {e}")
# Ask user if they want to continue with a new model
if args.mode == 'live':
confirmation = input("Failed to load model. Continue with a new model? (y/n): ")
if confirmation.lower() != 'y':
logger.info("Live trading canceled by user")
return
logger.info("Continuing with a new model")
else:
logger.info("Continuing evaluation with a new model")
if args.mode == 'eval':
# Evaluate the agent
logger.info("Evaluating agent...")
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env, num_episodes=args.episodes)
elif args.mode == 'live':
# Start live trading
logger.info(f"Starting live trading for {args.symbol} on {args.timeframe} timeframe")
logger.info(f"Demo mode: {demo_mode}, Leverage: {args.leverage}x")
await live_trading(
agent=agent,
env=env,
exchange=exchange,
symbol=args.symbol,
timeframe=args.timeframe,
demo=demo_mode,
leverage=args.leverage
)
except Exception as e:
logger.error(f"Error in main function: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
# Clean up exchange connection - safely close if possible
# Clean up exchange connection
if exchange:
try:
# Some CCXT exchanges have close method, others don't
if hasattr(exchange, 'close'):
await exchange.close()
elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
@ -2667,7 +2801,8 @@ def create_candlestick_figure(data, trade_signals, window_size=100, title=""):
# Plot candlesticks - use a simpler approach if mplfinance fails
try:
mpf.plot(df, type='candle', style='yahoo', ax=price_ax, volume=volume_ax)
# Use a different style or approach that doesn't use 'type' parameter
mpf.plot(df, type='candle', ax=price_ax, volume=volume_ax, style='yahoo')
except Exception as e:
logger.warning(f"Error plotting with mplfinance: {e}, falling back to simple plot")
# Fallback to simple plot
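Editor's note: for the simple-plot fallback mentioned above, a matplotlib-only sketch that avoids mplfinance entirely (it assumes the DataFrame carries open/high/low/close columns; adjust the names to the casing the chart code actually uses):

import matplotlib.pyplot as plt
import pandas as pd

def plot_simple_candles(ax, df):
    """Draw each bar as a thin high-low wick plus a thick open-close body."""
    o, h, l, c = (df[k].to_numpy() for k in ("open", "high", "low", "close"))
    for i in range(len(df)):
        color = "green" if c[i] >= o[i] else "red"
        ax.vlines(i, l[i], h[i], color=color, linewidth=1)   # wick
        ax.vlines(i, o[i], c[i], color=color, linewidth=4)   # body
    ax.set_xlim(-1, len(df))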