import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import GradScaler, autocast
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from tensorboardX import SummaryWriter

# Import our enhanced models
from enhanced_models import (
    EnhancedPricePredictionModel,
    EnhancedDQN,
    EnhancedReplayBuffer,
    train_price_predictor,
    prepare_multi_timeframe_data,
)

# Constants
TIMEFRAMES = ['1m', '15m', '1h']
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
GAMMA = 0.99
REPLAY_BUFFER_SIZE = 100000
TARGET_UPDATE = 10
NUM_EPISODES = 200
MAX_STEPS_PER_EPISODE = 1000
EPSILON_START = 1.0
EPSILON_END = 0.01
EPSILON_DECAY = 0.995
SAVE_INTERVAL = 10
CONTINUOUS_MODE = True
CONTINUOUS_START_EPISODE = 0
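
# Note: EPSILON_DECAY = 0.995 is applied once per episode, so epsilon only falls to
# about 0.995**200 ≈ 0.37 over NUM_EPISODES = 200. Reaching EPSILON_END = 0.01 takes
# roughly ln(0.01) / ln(0.995) ≈ 919 episodes, so the floor mainly matters for long
# continuous-training runs.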


def setup_tensorboard():
    """Set up TensorBoard for logging training metrics"""
    current_time = datetime.now().strftime('%Y%m%d-%H%M%S')
    log_dir = os.path.join('runs', current_time)
    writer = SummaryWriter(log_dir)
    return writer
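
# The event files written by setup_tensorboard() can be viewed with the TensorBoard CLI:
#   tensorboard --logdir runs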


def save_models(price_model, dqn_model, optimizer, episode, rewards, profits, win_rates,
                best_reward, best_pnl, best_winrate):
    """Save model checkpoints and clean up old ones, keeping only the 5 most recent
    continuous checkpoints plus the best/latest/final models."""
    # Create models directory if it doesn't exist
    os.makedirs('models', exist_ok=True)

    # Build the checkpoint payload once and reuse it for every save target
    checkpoint = {
        'price_model_state_dict': price_model.state_dict(),
        'dqn_model_state_dict': dqn_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'episode': episode,
        'rewards': rewards,
        'profits': profits,
        'win_rates': win_rates
    }

    # Save latest model
    torch.save(checkpoint, 'models/enhanced_trading_agent_latest.pt')

    # Save continuous training checkpoint
    continuous_model_path = f'models/enhanced_trading_agent_continuous_{episode}.pt'
    torch.save(checkpoint, continuous_model_path)

    # Save best models
    if rewards[-1] > best_reward:
        best_reward = rewards[-1]
        torch.save(checkpoint, 'models/enhanced_trading_agent_best_reward.pt')

    if profits[-1] > best_pnl:
        best_pnl = profits[-1]
        torch.save(checkpoint, 'models/enhanced_trading_agent_best_pnl.pt')

    if win_rates[-1] > best_winrate:
        best_winrate = win_rates[-1]
        torch.save(checkpoint, 'models/enhanced_trading_agent_best_winrate.pt')

    # Save final model at the end of training
    if episode == NUM_EPISODES - 1:
        torch.save(checkpoint, 'models/enhanced_trading_agent_final.pt')

    # Clean up old models - keep only the 5 most recent continuous checkpoints and the best models
    cleanup_model_files()

    return best_reward, best_pnl, best_winrate
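
# Example (illustrative) of restoring the best-PnL checkpoint for evaluation, assuming
# price_model / dqn_model instances built with matching architectures:
#
#   checkpoint = torch.load('models/enhanced_trading_agent_best_pnl.pt', map_location=DEVICE)
#   price_model.load_state_dict(checkpoint['price_model_state_dict'])
#   dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
#   dqn_model.eval()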


def cleanup_model_files():
    """Keep only the 5 most recent continuous models and the best models"""
    # Files we always want to keep
    essential_files = [
        'enhanced_trading_agent_latest.pt',
        'enhanced_trading_agent_best_reward.pt',
        'enhanced_trading_agent_best_pnl.pt',
        'enhanced_trading_agent_best_winrate.pt',
        'enhanced_trading_agent_final.pt'
    ]

    # Get all continuous training model files
    continuous_files = []
    for file in os.listdir('models'):
        if file.startswith('enhanced_trading_agent_continuous_') and file.endswith('.pt'):
            continuous_files.append(file)

    # Sort continuous files by episode number (newest first)
    if continuous_files:
        try:
            continuous_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]), reverse=True)
            # Keep only the 5 most recent continuous files
            files_to_keep = essential_files + continuous_files[:5]
        except (ValueError, IndexError):
            # Handle the case where the filename format is unexpected
            print("Warning: Could not sort continuous files by episode number. Keeping all continuous files.")
            files_to_keep = essential_files + continuous_files
    else:
        files_to_keep = essential_files

    # Delete all other model files
    for file in os.listdir('models'):
        if file.endswith('.pt') and file not in files_to_keep:
            try:
                os.remove(os.path.join('models', file))
                print(f"Deleted old model file: {file}")
            except Exception as e:
                print(f"Error deleting {file}: {e}")
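
# The sort key above parses the episode number out of the filename: for example,
# 'enhanced_trading_agent_continuous_120.pt' -> last '_' token '120.pt' -> '120' -> 120.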


def plot_training_results(rewards, profits, win_rates, episode):
    """Plot training metrics"""
    plt.figure(figsize=(15, 15))

    # Plot rewards
    plt.subplot(3, 1, 1)
    plt.plot(rewards)
    plt.title('Average Reward per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Reward')

    # Plot profits
    plt.subplot(3, 1, 2)
    plt.plot(profits)
    plt.title('Profit/Loss per Episode')
    plt.xlabel('Episode')
    plt.ylabel('PnL ($)')

    # Plot win rates
    plt.subplot(3, 1, 3)
    plt.plot(win_rates)
    plt.title('Win Rate per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Win Rate (%)')
    plt.ylim(0, 100)

    plt.tight_layout()
    plt.savefig('training_results.png')

    # Also save episode-specific plots periodically
    if episode % 20 == 0:
        os.makedirs('visualizations', exist_ok=True)
        plt.savefig(f'visualizations/training_episode_{episode}.png')

    plt.close()


def load_checkpoint(price_model, dqn_model, optimizer, episode=None):
    """Load model checkpoint for continuous training"""
    if episode is not None:
        checkpoint_path = f'models/enhanced_trading_agent_continuous_{episode}.pt'
    else:
        checkpoint_path = 'models/enhanced_trading_agent_latest.pt'

    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

        price_model.load_state_dict(checkpoint['price_model_state_dict'])
        dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        start_episode = checkpoint['episode'] + 1
        rewards = checkpoint['rewards']
        profits = checkpoint['profits']
        win_rates = checkpoint['win_rates']

        print(f"Resuming training from episode {start_episode}")
        return start_episode, rewards, profits, win_rates
    else:
        print("No checkpoint found, starting training from scratch")
        return 0, [], [], []
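
# Example (illustrative): resume from a specific continuous checkpoint (episode 120 is a
# hypothetical value), or pass episode=None to fall back to the latest checkpoint:
#
#   start_episode, rewards, profits, win_rates = load_checkpoint(
#       price_model, dqn_model, optimizer, episode=120)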


def enhanced_train_agent(exchange, num_episodes=NUM_EPISODES, continuous=CONTINUOUS_MODE,
                         start_episode=CONTINUOUS_START_EPISODE):
    """
    Train the enhanced trading agent using multi-timeframe data

    Args:
        exchange: Exchange object to fetch data from
        num_episodes: Number of episodes to train for
        continuous: Whether to continue training from a checkpoint
        start_episode: Episode to start from if continuous training
    """
    print(f"Training on device: {DEVICE}")

    # Set up TensorBoard
    writer = setup_tensorboard()

    # Initialize models
    state_dim = 100  # Increased state dimension for multi-timeframe features
    action_dim = 3  # 0: Hold, 1: Buy, 2: Sell

    price_model = EnhancedPricePredictionModel(
        input_dim=2,  # Price and volume
        hidden_dim=256,
        num_layers=3,
        output_dim=5,  # Predict next 5 candles
        num_timeframes=len(TIMEFRAMES)
    ).to(DEVICE)

    dqn_model = EnhancedDQN(
        state_dim=state_dim,
        action_dim=action_dim,
        hidden_dim=512
    ).to(DEVICE)

    target_dqn = EnhancedDQN(
        state_dim=state_dim,
        action_dim=action_dim,
        hidden_dim=512
    ).to(DEVICE)

    # Copy initial weights to the target network
    target_dqn.load_state_dict(dqn_model.state_dict())

    # Initialize optimizer
    optimizer = optim.Adam(
        list(price_model.parameters()) + list(dqn_model.parameters()),
        lr=LEARNING_RATE
    )

    # Initialize prioritized, n-step replay buffer
    replay_buffer = EnhancedReplayBuffer(
        capacity=REPLAY_BUFFER_SIZE,
        alpha=0.6,
        beta=0.4,
        beta_increment=0.001,
        n_step=3,
        gamma=GAMMA
    )

    # Initialize gradient scaler for mixed precision training
    scaler = GradScaler(enabled=(DEVICE.type == 'cuda'))

    # Initialize tracking variables
    rewards = []
    profits = []
    win_rates = []
    best_reward = float('-inf')
    best_pnl = float('-inf')
    best_winrate = float('-inf')

    # Load checkpoint if continuous training; a start_episode of 0 means "resume from latest"
    if continuous:
        start_episode, rewards, profits, win_rates = load_checkpoint(
            price_model, dqn_model, optimizer,
            start_episode if start_episode > 0 else None
        )

    # Prepare multi-timeframe data for price prediction model training
    data_loaders = prepare_multi_timeframe_data(exchange, TIMEFRAMES)

    # Pre-train the price prediction model
    print("Pre-training price prediction model...")
    train_price_predictor(price_model, data_loaders, optimizer, DEVICE, epochs=5)

    # Main training loop
    epsilon = EPSILON_START

    for episode in range(start_episode, num_episodes):
        print(f"Episode {episode+1}/{num_episodes}")

        # Reset environment
        state = initialize_state(exchange, TIMEFRAMES)
        total_reward = 0
        trades = []
        wins = 0
        losses = 0

        # Episode loop
        for step in range(MAX_STEPS_PER_EPISODE):
            # Epsilon-greedy action selection
            if np.random.random() < epsilon:
                action = np.random.randint(0, action_dim)
            else:
                with torch.no_grad():
                    state_tensor = torch.FloatTensor(state).unsqueeze(0).to(DEVICE)
                    q_values, _, _ = dqn_model(state_tensor)
                    action = q_values.argmax().item()

            # Execute action and get next state and reward
            next_state, reward, done, trade_info = step_environment(
                exchange, state, action, price_model, TIMEFRAMES, DEVICE
            )

            # Store transition in replay buffer
            replay_buffer.push(
                torch.FloatTensor(state),
                action,
                reward,
                torch.FloatTensor(next_state),
                done
            )

            # Update state and accumulate reward
            state = next_state
            total_reward += reward

            # Track trade outcomes
            if trade_info is not None:
                trades.append(trade_info)
                if trade_info['pnl'] > 0:
                    wins += 1
                elif trade_info['pnl'] < 0:
                    losses += 1

            # Learn from experiences once enough samples are available
            if len(replay_buffer) > BATCH_SIZE:
                learn(dqn_model, target_dqn, replay_buffer, optimizer, scaler, DEVICE)

            if done:
                break

        # Update target network
        if episode % TARGET_UPDATE == 0:
            target_dqn.load_state_dict(dqn_model.state_dict())

        # Calculate episode metrics
        avg_reward = total_reward / (step + 1)
        total_pnl = sum(trade['pnl'] for trade in trades) if trades else 0
        win_rate = (wins / (wins + losses) * 100) if (wins + losses) > 0 else 0

        # Decay epsilon
        epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)

        # Track metrics
        rewards.append(avg_reward)
        profits.append(total_pnl)
        win_rates.append(win_rate)

        # Log to TensorBoard
        writer.add_scalar('Training/Reward', avg_reward, episode)
        writer.add_scalar('Training/Profit', total_pnl, episode)
        writer.add_scalar('Training/WinRate', win_rate, episode)
        writer.add_scalar('Training/Epsilon', epsilon, episode)

        # Print episode summary
        print(f"Episode {episode+1} - Avg Reward: {avg_reward:.2f}, PnL: ${total_pnl:.2f}, Win Rate: {win_rate:.1f}%")

        # Save models and plot results
        if episode % SAVE_INTERVAL == 0 or episode == num_episodes - 1:
            best_reward, best_pnl, best_winrate = save_models(
                price_model, dqn_model, optimizer, episode,
                rewards, profits, win_rates,
                best_reward, best_pnl, best_winrate
            )
            plot_training_results(rewards, profits, win_rates, episode)

    # Close TensorBoard writer
    writer.close()

    # Final save and plot
    best_reward, best_pnl, best_winrate = save_models(
        price_model, dqn_model, optimizer, num_episodes - 1,
        rewards, profits, win_rates,
        best_reward, best_pnl, best_winrate
    )
    plot_training_results(rewards, profits, win_rates, num_episodes - 1)

    print("Training complete!")
    return price_model, dqn_model
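
# Note on the replay-buffer settings above: alpha controls how strongly sampling is biased
# toward high-TD-error transitions (0 would be uniform), beta is the importance-sampling
# correction exponent (annealed toward 1.0 via beta_increment), and n_step=3 asks the buffer
# to store 3-step returns discounted by gamma, assuming EnhancedReplayBuffer follows the
# standard prioritized / n-step replay formulation.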


def learn(dqn, target_dqn, replay_buffer, optimizer, scaler, device):
    """Update the DQN model using experiences from the replay buffer"""
    # Sample from replay buffer
    states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(BATCH_SIZE)

    # Move to device
    states = states.to(device)
    actions = actions.to(device)
    rewards = rewards.to(device)
    next_states = next_states.to(device)
    dones = dones.to(device)
    weights = weights.to(device)

    # Get current Q values
    if device.type == 'cuda':
        with autocast(device_type='cuda', enabled=True):
            current_q_values, _, _ = dqn(states)
            current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

            # Compute target Q values
            with torch.no_grad():
                next_q_values, _, _ = target_dqn(next_states)
                max_next_q_values = next_q_values.max(1)[0]
                target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values

            # Compute loss with importance sampling weights
            td_errors = target_q_values - current_q_values
            loss = (weights * td_errors.pow(2)).mean()
    else:
        # CPU version without autocast
        current_q_values, _, _ = dqn(states)
        current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Compute target Q values
        with torch.no_grad():
            next_q_values, _, _ = target_dqn(next_states)
            max_next_q_values = next_q_values.max(1)[0]
            target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values

        # Compute loss with importance sampling weights
        td_errors = target_q_values - current_q_values
        loss = (weights * td_errors.pow(2)).mean()

    # Update priorities in the replay buffer
    replay_buffer.update_priorities(indices, td_errors.abs().detach().cpu().numpy())

    # Optimize the model (mixed precision on CUDA)
    optimizer.zero_grad()

    if device.type == 'cuda':
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
    else:
        # CPU version without the gradient scaler
        loss.backward()
        torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0)
        optimizer.step()
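
# The update above implements the standard DQN target with prioritized-replay weighting:
#     y_i  = r_i + gamma * (1 - done_i) * max_a' Q_target(s'_i, a')
#     loss = mean_i( w_i * (y_i - Q(s_i, a_i))^2 )
# where w_i are the importance-sampling weights returned by replay_buffer.sample(), and
# the new priorities are set to |y_i - Q(s_i, a_i)|.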


def initialize_state(exchange, timeframes):
    """Initialize the state with data from multiple timeframes"""
    # Fetch data for each timeframe
    timeframe_data = {}
    for tf in timeframes:
        candles = exchange.fetch_ohlcv(timeframe=tf, limit=30)
        timeframe_data[tf] = candles

    # Extract features from each timeframe
    state = []

    for tf in timeframes:
        candles = timeframe_data[tf]

        # Price features
        prices = [candle[4] for candle in candles[-10:]]  # Last 10 close prices
        price_changes = [prices[i] / prices[i-1] - 1 for i in range(1, len(prices))]

        # Volume features
        volumes = [candle[5] for candle in candles[-10:]]  # Last 10 volumes
        volume_changes = [volumes[i] / volumes[i-1] - 1 for i in range(1, len(volumes))]

        # Technical indicators
        # Simple Moving Averages
        sma_5 = sum(prices[-5:]) / 5
        sma_10 = sum(prices) / 10

        # Relative Strength Index (simplified)
        gains = [max(0, change) for change in price_changes]
        losses = [max(0, -change) for change in price_changes]
        avg_gain = sum(gains) / len(gains)
        avg_loss = sum(losses) / len(losses)
        rs = avg_gain / (avg_loss + 1e-10)  # Avoid division by zero
        rsi = 100 - (100 / (1 + rs))

        # Add features to state
        state.extend(price_changes)  # 9 features
        state.extend(volume_changes)  # 9 features
        state.append(sma_5 / prices[-1] - 1)  # 1 feature
        state.append(sma_10 / prices[-1] - 1)  # 1 feature
        state.append(rsi / 100)  # 1 feature

    # Add market regime features
    # This is a placeholder - in a real implementation, you would use the market_regime_classifier
    # from the DQN model to predict the current market regime
    state.extend([0, 0, 0])  # 3 features for market regime (one-hot encoded)

    # Add additional features to reach the expected dimension of 100
    # Calculate more technical indicators
    for tf in timeframes:
        candles = timeframe_data[tf]
        # Use the last 30 closes so the 26-period MACD below has enough data
        prices = [candle[4] for candle in candles[-30:]]

        # Bollinger Bands
        window = 20
        if len(prices) >= window:
            sma_20 = sum(prices[-window:]) / window
            std_dev = (sum((price - sma_20) ** 2 for price in prices[-window:]) / window) ** 0.5
            upper_band = sma_20 + 2 * std_dev
            lower_band = sma_20 - 2 * std_dev

            # Add normalized Bollinger Band features
            state.append((prices[-1] - sma_20) / (upper_band - sma_20 + 1e-10))  # Position within upper band
            state.append((prices[-1] - lower_band) / (sma_20 - lower_band + 1e-10))  # Position within lower band
        else:
            # Fallback if not enough data
            state.extend([0, 0])

        # MACD (Moving Average Convergence Divergence)
        if len(prices) >= 26:
            ema_12 = sum(prices[-12:]) / 12  # Simplified EMA
            ema_26 = sum(prices[-26:]) / 26  # Simplified EMA
            macd = ema_12 - ema_26

            # Add normalized MACD
            state.append(macd / prices[-1])
        else:
            # Fallback if not enough data
            state.append(0)

    # Add price momentum features
    for tf in timeframes:
        candles = timeframe_data[tf]
        prices = [candle[4] for candle in candles[-30:]]

        # Calculate momentum over different periods
        if len(prices) >= 30:
            momentum_5 = prices[-1] / prices[-5] - 1
            momentum_10 = prices[-1] / prices[-10] - 1
            momentum_20 = prices[-1] / prices[-20] - 1
            momentum_30 = prices[-1] / prices[-30] - 1

            state.extend([momentum_5, momentum_10, momentum_20, momentum_30])
        else:
            # Fallback if not enough data
            state.extend([0, 0, 0, 0])

    # Add volume profile features
    for tf in timeframes:
        candles = timeframe_data[tf]
        volumes = [candle[5] for candle in candles[-10:]]

        # Volume profile
        avg_volume = sum(volumes) / len(volumes)
        volume_ratio = volumes[-1] / avg_volume

        # Volume trend
        volume_trend = sum(1 for i in range(1, len(volumes)) if volumes[i] > volumes[i-1]) / (len(volumes) - 1)

        state.extend([volume_ratio, volume_trend])

    # Pad with zeros if needed to reach exactly 100 dimensions
    while len(state) < 100:
        state.append(0)

    # Ensure state has exactly 100 dimensions
    if len(state) > 100:
        state = state[:100]

    assert len(state) == 100, f"State dimension mismatch: {len(state)} != 100"

    return state
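
# For reference, with 3 timeframes the state vector built above breaks down as:
#   21 features per timeframe (9 price changes, 9 volume changes, 2 SMA ratios, 1 RSI) -> 63
#   3 market-regime placeholders                                                        -> 66
#   3 per timeframe (2 Bollinger positions, 1 MACD)                                     -> 75
#   4 momentum values per timeframe                                                     -> 87
#   2 volume-profile values per timeframe                                               -> 93
#   zero padding up to the expected state_dim of 100                                    -> 100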


def step_environment(exchange, state, action, price_model, timeframes, device):
    """
    Execute an action in the environment and return the next state, reward, done flag, and trade info

    Args:
        exchange: Exchange object to interact with
        state: Current state
        action: Action to take (0: Hold, 1: Buy, 2: Sell)
        price_model: Price prediction model
        timeframes: List of timeframes to use
        device: Device to run models on

    Returns:
        next_state: Next state after taking the action
        reward: Reward received
        done: Whether the episode is done
        trade_info: Information about the trade (if any)
    """
    # Fetch the latest data for each timeframe
    timeframe_data = {}
    for tf in timeframes:
        candles = exchange.fetch_ohlcv(timeframe=tf, limit=30)
        timeframe_data[tf] = candles

    # Prepare inputs for the price prediction model
    price_inputs = []
    for tf in timeframes:
        candles = timeframe_data[tf]
        # Extract price and volume data for the last 30 candles
        input_data = torch.tensor(
            [[candle[4], candle[5]] for candle in candles[-30:]],
            dtype=torch.float32
        ).unsqueeze(0).to(device)  # Add batch dimension
        price_inputs.append(input_data)

    # Get price and extrema predictions
    with torch.no_grad():
        price_pred, extrema_logits, volume_pred = price_model(price_inputs)

    # Convert predictions to numpy (drop the batch dimension)
    price_pred = price_pred.cpu().numpy()[0]
    extrema_probs = torch.sigmoid(extrema_logits).cpu().numpy()[0]
    volume_pred = volume_pred.cpu().numpy()[0]

    # Execute action
    current_price = timeframe_data['1m'][-1][4]  # Current close price
    trade_info = None
    reward = 0

    if action == 1:  # Buy
        # Check if we're at a predicted low point (good time to buy)
        is_predicted_low = any(extrema_probs[i*2+1] > 0.7 for i in range(5))

        # Calculate entry quality based on predictions
        entry_quality = 0.5  # Default quality
        if is_predicted_low:
            entry_quality += 0.2  # Bonus for buying at a predicted low

        # Check volume confirmation
        volume_increasing = volume_pred[0] > timeframe_data['1m'][-1][5]
        if volume_increasing:
            entry_quality += 0.1  # Bonus for increasing volume

        # Execute buy order
        # In a real implementation, this would interact with the exchange
        # For now, we simulate the trade
        trade_info = {
            'action': 'buy',
            'price': current_price,
            'size': 100 * entry_quality,  # Size based on entry quality
            'entry_quality': entry_quality,
            'pnl': 0  # Will be updated later
        }

        # Calculate reward: base reward for taking action
        reward = 1

        # Bonus for buying at a predicted low
        if is_predicted_low:
            reward += 5
            print("Trading at predicted low - additional reward")

        # Bonus for volume confirmation
        if volume_increasing:
            reward += 2
            print("Trading with high volume - additional reward")

    elif action == 2:  # Sell
        # Check if we're at a predicted high point (good time to sell)
        is_predicted_high = any(extrema_probs[i*2] > 0.7 for i in range(5))

        # Calculate entry quality based on predictions
        entry_quality = 0.5  # Default quality
        if is_predicted_high:
            entry_quality += 0.2  # Bonus for selling at a predicted high

        # Check volume confirmation
        volume_increasing = volume_pred[0] > timeframe_data['1m'][-1][5]
        if volume_increasing:
            entry_quality += 0.1  # Bonus for increasing volume

        # Execute sell order
        # In a real implementation, this would interact with the exchange
        # For now, we simulate the trade
        trade_info = {
            'action': 'sell',
            'price': current_price,
            'size': 100 * entry_quality,  # Size based on entry quality
            'entry_quality': entry_quality,
            'pnl': 0  # Will be updated later
        }

        # Calculate reward: base reward for taking action
        reward = 1

        # Bonus for selling at a predicted high
        if is_predicted_high:
            reward += 5
            print("Trading at predicted high - additional reward")

        # Bonus for volume confirmation
        if volume_increasing:
            reward += 2
            print("Trading with high volume - additional reward")

    else:  # Hold
        # Small reward for holding
        reward = 0.1

    # Simulate the trade outcome
    if trade_info is not None:
        # In a real implementation, this would be based on actual market movement
        # For now, we use the price prediction to simulate the outcome
        future_price = price_pred[0]  # Predicted price of the next candle

        if trade_info['action'] == 'buy':
            # For a buy, profit if the price goes up
            pnl_pct = (future_price / current_price - 1) * 100
        else:  # sell
            # For a sell, profit if the price goes down
            pnl_pct = (1 - future_price / current_price) * 100
        trade_info['pnl'] = pnl_pct * trade_info['size'] / 100

        # Adjust reward based on the trade outcome
        reward += trade_info['pnl'] * 10  # Scale PnL for reward

    # Update state
    next_state = initialize_state(exchange, timeframes)

    # Check if the episode is done
    # In a real implementation, this would be based on episode length or other criteria
    done = False

    return next_state, reward, done, trade_info
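
# Reward shaping used above, for reference:
#   hold: 0.1
#   buy/sell: 1 base, +5 if taken at a predicted extreme (probability > 0.7),
#   +2 on volume confirmation, plus 10x the simulated PnL of the trade.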


# Main function to run training
def main():
    from exchange_simulator import ExchangeSimulator

    # Initialize exchange simulator
    exchange = ExchangeSimulator()

    # Train agent
    price_model, dqn_model = enhanced_train_agent(
        exchange=exchange,
        num_episodes=NUM_EPISODES,
        continuous=CONTINUOUS_MODE,
        start_episode=CONTINUOUS_START_EPISODE
    )

    print("Training complete!")


if __name__ == "__main__":
    main()
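
# To start a fresh run instead of resuming from a checkpoint, call for example:
#   enhanced_train_agent(exchange, num_episodes=NUM_EPISODES, continuous=False)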