more features, but dead-end

This commit is contained in:
parent 2e901e18f2
commit d5c291d15c

Changed file: crypto/gogo2/.vscode/launch.json (vendored, 12 lines)
@@ -5,8 +5,8 @@
             "name": "Train Bot",
             "type": "python",
             "request": "launch",
-            "program": "main.py",
-            "args": ["--mode", "train", "--episodes", "100"],
+            "program": "main_multiu_broken.py",
+            "args": ["--mode", "train", "--episodes", "10000"],
             "console": "integratedTerminal",
             "justMyCode": true
         },
@@ -14,7 +14,7 @@
             "name": "Evaluate Bot",
             "type": "python",
             "request": "launch",
-            "program": "main.py",
+            "program": "main_multiu_broken.py",
             "args": ["--mode", "eval", "--episodes", "10"],
             "console": "integratedTerminal",
             "justMyCode": true
@@ -23,7 +23,7 @@
             "name": "Live Trading (Demo)",
             "type": "python",
             "request": "launch",
-            "program": "main.py",
+            "program": "main_multiu_broken.py",
             "args": ["--mode", "live", "--demo"],
             "console": "integratedTerminal",
             "justMyCode": true
@@ -32,7 +32,7 @@
             "name": "Live Trading (Real)",
             "type": "python",
             "request": "launch",
-            "program": "main.py",
+            "program": "main_multiu_broken.py",
             "args": ["--mode", "live"],
             "console": "integratedTerminal",
             "justMyCode": true
@@ -41,7 +41,7 @@
             "name": "Continuous Training",
             "type": "python",
             "request": "launch",
-            "program": "main.py",
+            "program": "main_multiu_broken.py",
             "args": ["--mode", "continuous", "--refresh-data"],
             "console": "integratedTerminal",
             "justMyCode": true
@@ -202,6 +202,28 @@ class EnhancedReplayBuffer:
         self.n_step_buffer = []
         self.max_priority = 1.0
 
+    def add(self, state, action, reward, next_state, done):
+        """
+        Add a new experience to the buffer (simplified version of push for compatibility)
+
+        Args:
+            state: Current state
+            action: Action taken
+            reward: Reward received
+            next_state: Next state
+            done: Whether the episode is done
+        """
+        # Store in replay buffer with max priority
+        if len(self.buffer) < self.capacity:
+            self.buffer.append(None)
+        self.buffer[self.position] = (state, action, reward, next_state, done)
+
+        # Set priority to max priority to ensure it gets sampled
+        self.priorities[self.position] = self.max_priority
+
+        # Move position pointer
+        self.position = (self.position + 1) % self.capacity
+
     def push(self, state, action, reward, next_state, done):
         # Store experience in n-step buffer
         self.n_step_buffer.append((state, action, reward, next_state, done))
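
Note on the add() hunk above: the new method is a plain ring-buffer insert stamped with the buffer's current max priority, so fresh experiences are guaranteed to be sampled at least once. A minimal self-contained sketch of that pattern (the attribute names mirror the diff; the class and data are illustrative):

    import numpy as np

    class RingBufferSketch:
        def __init__(self, capacity):
            self.capacity = capacity
            self.buffer = []
            self.priorities = np.zeros(capacity, dtype=np.float32)
            self.position = 0
            self.max_priority = 1.0

        def add(self, experience):
            # Grow until full, then overwrite the oldest slot
            if len(self.buffer) < self.capacity:
                self.buffer.append(None)
            self.buffer[self.position] = experience
            # Stamp with max priority so new samples are drawn at least once
            self.priorities[self.position] = self.max_priority
            self.position = (self.position + 1) % self.capacity

    buf = RingBufferSketch(capacity=3)
    for i in range(5):
        buf.add(('state', i))
    print(buf.buffer)  # the two oldest entries have been overwritten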
@@ -278,15 +300,8 @@ class EnhancedReplayBuffer:
             next_states.append(next_state)
             dones.append(done)
 
-        return (
-            torch.stack(states),
-            torch.tensor(actions),
-            torch.tensor(rewards, dtype=torch.float32),
-            torch.stack(next_states),
-            torch.tensor(dones, dtype=torch.float32),
-            indices,
-            weights
-        )
+        # Return only the states, actions, rewards, next_states, dones for compatibility with learn function
+        return states, actions, rewards, next_states, dones
 
     def update_priorities(self, indices, td_errors):
         for idx, td_error in zip(indices, td_errors):
@@ -2,6 +2,7 @@ import os
 import torch
 import torch.nn as nn
 import torch.optim as optim
+import torch.nn.functional as F
 from torch.amp import GradScaler, autocast
 import numpy as np
 import matplotlib.pyplot as plt
@@ -9,7 +10,7 @@ from datetime import datetime
 from tensorboardX import SummaryWriter
 
 # Import our enhanced models
-from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN, EnhancedReplayBuffer, train_price_predictor, prepare_multi_timeframe_data
+from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN, EnhancedReplayBuffer
 
 # Constants
 TIMEFRAMES = ['1m', '15m', '1h']
@@ -218,285 +219,453 @@ def load_checkpoint(price_model, dqn_model, optimizer, episode=None):
         print("No checkpoint found, starting training from scratch")
         return 0, [], [], []
 
-def enhanced_train_agent(exchange, num_episodes=NUM_EPISODES, continuous=CONTINUOUS_MODE, start_episode=CONTINUOUS_START_EPISODE):
+def enhanced_train_agent(exchange, num_episodes=NUM_EPISODES, continuous=CONTINUOUS_MODE, start_episode=CONTINUOUS_START_EPISODE, verbose=False):
     """
-    Train the enhanced trading agent using multi-timeframe data
+    Train an enhanced trading agent with multi-timeframe models
 
     Args:
-        exchange: Exchange object to fetch data from
+        exchange: Exchange simulator or real exchange
         num_episodes: Number of episodes to train for
         continuous: Whether to continue training from a checkpoint
-        start_episode: Episode to start from if continuous training
+        start_episode: Episode to start from for continuous training
+        verbose: Whether to enable verbose logging
     """
-    print(f"Training on device: {DEVICE}")
 
     # Set up TensorBoard
     writer = setup_tensorboard()
 
     # Initialize models
-    state_dim = 100  # Increased state dimension for multi-timeframe features
-    action_dim = 3  # Buy, Sell, Hold
+    state_dim = 100  # Increased state dimension for enhanced features
+    action_dim = 3  # 0: HOLD, 1: BUY, 2: SELL
 
+    # Initialize price prediction model with multi-timeframe support
     price_model = EnhancedPricePredictionModel(
         input_dim=2,  # Price and volume
         hidden_dim=256,
         num_layers=3,
-        output_dim=5,  # Predict next 5 candles
+        output_dim=5,  # OHLCV prediction
         num_timeframes=len(TIMEFRAMES)
     ).to(DEVICE)
 
+    # Initialize DQN model with enhanced architecture
     dqn_model = EnhancedDQN(
         state_dim=state_dim,
         action_dim=action_dim,
         hidden_dim=512
     ).to(DEVICE)
 
-    target_dqn = EnhancedDQN(
+    # Initialize target network
+    target_model = EnhancedDQN(
         state_dim=state_dim,
         action_dim=action_dim,
         hidden_dim=512
     ).to(DEVICE)
-
-    # Copy initial weights to target network
-    target_dqn.load_state_dict(dqn_model.state_dict())
+    target_model.load_state_dict(dqn_model.state_dict())
 
     # Initialize optimizer
-    optimizer = optim.Adam(list(price_model.parameters()) + list(dqn_model.parameters()), lr=LEARNING_RATE)
+    optimizer = optim.Adam(dqn_model.parameters(), lr=LEARNING_RATE)
 
-    # Initialize replay buffer
-    replay_buffer = EnhancedReplayBuffer(
-        capacity=REPLAY_BUFFER_SIZE,
-        alpha=0.6,
-        beta=0.4,
-        beta_increment=0.001,
-        n_step=3,
-        gamma=GAMMA
-    )
+    # Initialize replay buffer with prioritized experience replay
+    replay_buffer = EnhancedReplayBuffer(REPLAY_BUFFER_SIZE)
 
-    # Initialize gradient scaler for mixed precision training
-    scaler = GradScaler(enabled=(DEVICE.type == 'cuda'))
-
-    # Initialize tracking variables
+    # Initialize training metrics
    rewards = []
     profits = []
     win_rates = []
     best_reward = float('-inf')
     best_pnl = float('-inf')
-    best_winrate = float('-inf')
+    best_winrate = 0
 
-    # Load checkpoint if continuous training
+    # Initialize epsilon for exploration
+    epsilon = EPSILON_START
+
+    # Load checkpoint if continuing training
     if continuous:
-        start_episode, rewards, profits, win_rates = load_checkpoint(
-            price_model, dqn_model, optimizer, start_episode
-        )
+        try:
+            checkpoint_path = 'models/enhanced_trading_agent_latest.pt'
+            if os.path.exists(checkpoint_path):
+                checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
+                price_model.load_state_dict(checkpoint['price_model_state_dict'])
+                dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
+                target_model.load_state_dict(checkpoint['dqn_model_state_dict'])
+                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
 
-    # Prepare multi-timeframe data for price prediction model training
-    data_loaders = prepare_multi_timeframe_data(exchange, TIMEFRAMES)
+                if 'rewards' in checkpoint:
+                    rewards = checkpoint['rewards']
+                if 'profits' in checkpoint:
+                    profits = checkpoint['profits']
+                if 'win_rates' in checkpoint:
+                    win_rates = checkpoint['win_rates']
+                if 'best_reward' in checkpoint:
+                    best_reward = checkpoint['best_reward']
+                if 'best_pnl' in checkpoint:
+                    best_pnl = checkpoint['best_pnl']
+                if 'best_winrate' in checkpoint:
+                    best_winrate = checkpoint['best_winrate']
+
+                print(f"Loaded checkpoint from {checkpoint_path}")
+                print(f"Continuing from episode {start_episode}")
+
+                # Decay epsilon based on start episode
+                for _ in range(start_episode):
+                    epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
+            else:
+                print(f"No checkpoint found at {checkpoint_path}, starting from scratch")
+        except Exception as e:
+            print(f"Error loading checkpoint: {e}")
+            print("Starting from scratch")
+
+    # Set models to training mode
+    price_model.train()
+    dqn_model.train()
+
+    # Initialize gradient scaler for mixed precision training
+    scaler = GradScaler()
+
+    print(f"Training on device: {DEVICE}")
 
     # Pre-train price prediction model
     print("Pre-training price prediction model...")
-    train_price_predictor(price_model, data_loaders, optimizer, DEVICE, epochs=5)
+    train_price_predictor(price_model, exchange, TIMEFRAMES, DEVICE, num_epochs=5, batch_size=32)
 
     # Main training loop
-    epsilon = EPSILON_START
-
-    for episode in range(start_episode, num_episodes):
-        print(f"Episode {episode+1}/{num_episodes}")
-
-        # Reset environment
+    for episode in range(start_episode, start_episode + num_episodes):
+        # Initialize state
         state = initialize_state(exchange, TIMEFRAMES)
-        total_reward = 0
+
+        # Reset environment for new episode
+        exchange.reset()
+
+        # Track episode metrics
+        episode_reward = 0
+        episode_steps = 0
+        done = False
         trades = []
         wins = 0
         losses = 0
 
+        # Enable verbose logging for prediction validation if requested
+        if verbose:
+            # Set logging level to DEBUG for more detailed logs
+            import logging
+            logging.getLogger("trading_bot").setLevel(logging.DEBUG)
+
+            # Add a hook to log prediction validations
+            def log_prediction_validation(pred, actual, was_correct):
+                if was_correct:
+                    print(f"CORRECT prediction: predicted={pred:.2f}, actual={actual:.2f}")
+                else:
+                    print(f"INCORRECT prediction: predicted={pred:.2f}, actual={actual:.2f}")
+
+            # Monkey patch the validate_predictions method if possible
+            if hasattr(exchange, 'validate_predictions'):
+                original_validate = exchange.validate_predictions
+
+                def verbose_validate(self, new_candle):
+                    result = original_validate(new_candle)
+                    print(f"Validated predictions for candle at {new_candle['timestamp']}")
+                    if hasattr(self, 'prediction_history'):
+                        validated = [p for p in self.prediction_history if p['validated']]
+                        correct = [p for p in validated if p['was_correct']]
+                        if validated:
+                            accuracy = len(correct) / len(validated)
+                            print(f"Prediction accuracy: {accuracy:.2f} ({len(correct)}/{len(validated)})")
+                    return result
+
+                exchange.validate_predictions = verbose_validate.__get__(exchange, exchange.__class__)
+
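
Note on the monkey patch above: assigning a bare function to an instance attribute would not pass `self` automatically, so the hunk re-binds it through the descriptor protocol. A minimal self-contained demonstration of the `__get__` trick (class and names are illustrative, not from the repo):

    class Exchange:
        def validate(self, candle):
            return f"validated {candle}"

    ex = Exchange()
    original = ex.validate  # already a bound method, so it carries its own self

    def verbose_validate(self, candle):
        result = original(candle)  # call the captured bound method
        print(f"hook saw candle={candle}")
        return result

    # __get__ binds the function to this instance, so self is supplied automatically
    ex.validate = verbose_validate.__get__(ex, Exchange)
    print(ex.validate("c1"))  # prints the hook line, then "validated c1"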
         # Episode loop
-        for step in range(MAX_STEPS_PER_EPISODE):
+        while not done and episode_steps < MAX_STEPS_PER_EPISODE:
             # Epsilon-greedy action selection
             if np.random.random() < epsilon:
                 action = np.random.randint(0, action_dim)
             else:
                 with torch.no_grad():
-                    state_tensor = torch.FloatTensor(state).unsqueeze(0).to(DEVICE)
-                    q_values, _, _ = dqn_model(state_tensor)
+                    q_values = dqn_model(torch.FloatTensor(state).unsqueeze(0).to(DEVICE))
                     action = q_values.argmax().item()
 
-            # Execute action and get next state and reward
-            next_state, reward, done, trade_info = step_environment(
-                exchange, state, action, price_model, TIMEFRAMES, DEVICE
-            )
+            # Take action in environment
+            next_state, reward, done, info = exchange.step(action)
 
             # Store transition in replay buffer
-            replay_buffer.push(
-                torch.FloatTensor(state),
-                action,
-                reward,
-                torch.FloatTensor(next_state),
-                done
-            )
+            replay_buffer.add(state, action, reward, next_state, done)
 
             # Update state and accumulate reward
             state = next_state
-            total_reward += reward
+            episode_reward += reward
 
             # Track trade outcomes
-            if trade_info is not None:
-                trades.append(trade_info)
-                if trade_info['pnl'] > 0:
-                    wins += 1
-                elif trade_info['pnl'] < 0:
-                    losses += 1
+            if 'trade' in info and info['trade']:
+                trades.append(info['trade'])
+                if 'pnl_dollar' in info['trade']:
+                    if info['trade']['pnl_dollar'] > 0:
+                        wins += 1
+                    elif info['trade']['pnl_dollar'] < 0:
+                        losses += 1
+
+                    # Log trade to TensorBoard
+                    if writer:
+                        writer.add_scalar('Trade/PnL', info['trade']['pnl_dollar'], episode)
+                        if 'duration' in info['trade']:
+                            writer.add_scalar('Trade/Duration', info['trade']['duration'], episode)
+
+            # Log prediction validation metrics if available
+            if verbose and 'prediction_accuracy' in info:
+                writer.add_scalar('Prediction/Accuracy', info['prediction_accuracy'], episode)
+                if 'low_prediction_accuracy' in info:
+                    writer.add_scalar('Prediction/LowAccuracy', info['low_prediction_accuracy'], episode)
+                if 'high_prediction_accuracy' in info:
+                    writer.add_scalar('Prediction/HighAccuracy', info['high_prediction_accuracy'], episode)
 
             # Learn from experiences if enough samples
             if len(replay_buffer) > BATCH_SIZE:
-                learn(dqn_model, target_dqn, replay_buffer, optimizer, scaler, DEVICE)
+                learn(dqn_model, target_model, replay_buffer, optimizer, scaler, DEVICE)
 
-            if done:
-                break
+            episode_steps += 1
+
+            # Log verbose information about the current state if requested
+            if verbose and episode_steps % 10 == 0:
+                print(f"Episode {episode+1}, Step {episode_steps}, Action: {['HOLD', 'BUY', 'SELL'][action]}, Reward: {reward:.4f}")
+                if hasattr(exchange, 'position') and exchange.position != 'FLAT':
+                    print(f"  Position: {exchange.position}, Entry Price: {exchange.entry_price:.2f}, Current PnL: {exchange.calculate_pnl():.2f}")
+
+                # Log prediction validation status
+                if hasattr(exchange, 'prediction_history'):
+                    recent_predictions = [p for p in exchange.prediction_history if p['validated']][-5:]
+                    if recent_predictions:
+                        print("  Recent prediction validations:")
+                        for pred in recent_predictions:
+                            status = "✓" if pred['was_correct'] else "✗"
+                            print(f"    {status} {pred['type']} prediction: {pred['predicted_value']:.2f} vs {pred['actual_value']:.2f}")
 
         # Update target network
         if episode % TARGET_UPDATE == 0:
-            target_dqn.load_state_dict(dqn_model.state_dict())
+            target_model.load_state_dict(dqn_model.state_dict())
 
         # Calculate episode metrics
-        avg_reward = total_reward / (step + 1)
-        total_pnl = sum(trade['pnl'] for trade in trades) if trades else 0
+        avg_reward = episode_reward / max(1, episode_steps)
+        total_pnl = sum(trade.get('pnl_dollar', 0) for trade in trades) if trades else 0
         win_rate = (wins / (wins + losses) * 100) if (wins + losses) > 0 else 0
 
-        # Decay epsilon
-        epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
-
-        # Track metrics
+        # Store metrics
         rewards.append(avg_reward)
         profits.append(total_pnl)
         win_rates.append(win_rate)
 
+        # Decay epsilon
+        epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
+
+        # Log episode metrics
+        print(f"Episode {episode+1}: Reward={avg_reward:.4f}, PnL=${total_pnl:.2f}, Win Rate={win_rate:.1f}%, Epsilon={epsilon:.4f}")
+
         # Log to TensorBoard
-        writer.add_scalar('Training/Reward', avg_reward, episode)
-        writer.add_scalar('Training/Profit', total_pnl, episode)
-        writer.add_scalar('Training/WinRate', win_rate, episode)
-        writer.add_scalar('Training/Epsilon', epsilon, episode)
+        if writer:
+            writer.add_scalar('Metrics/Reward', avg_reward, episode)
+            writer.add_scalar('Metrics/PnL', total_pnl, episode)
+            writer.add_scalar('Metrics/WinRate', win_rate, episode)
+            writer.add_scalar('Metrics/Epsilon', epsilon, episode)
 
-        # Print episode summary
-        print(f"Episode {episode+1} - Avg Reward: {avg_reward:.2f}, PnL: ${total_pnl:.2f}, Win Rate: {win_rate:.1f}%")
-
-        # Save models and plot results
-        if episode % SAVE_INTERVAL == 0 or episode == num_episodes - 1:
-            best_reward, best_pnl, best_winrate = save_models(
-                price_model, dqn_model, optimizer, episode,
-                rewards, profits, win_rates,
-                best_reward, best_pnl, best_winrate
-            )
-            plot_training_results(rewards, profits, win_rates, episode)
-
-    # Close TensorBoard writer
-    writer.close()
-
-    # Final save and plot
-    best_reward, best_pnl, best_winrate = save_models(
-        price_model, dqn_model, optimizer, num_episodes - 1,
-        rewards, profits, win_rates,
-        best_reward, best_pnl, best_winrate
-    )
-    plot_training_results(rewards, profits, win_rates, num_episodes - 1)
-
-    print("Training complete!")
-    return price_model, dqn_model
+            # Log prediction validation metrics if available
+            if hasattr(exchange, 'prediction_history'):
+                validated_predictions = [p for p in exchange.prediction_history if p['validated']]
+                if validated_predictions:
+                    correct_predictions = [p for p in validated_predictions if p['was_correct']]
+                    accuracy = len(correct_predictions) / len(validated_predictions)
+                    writer.add_scalar('Prediction/OverallAccuracy', accuracy, episode)
+
+                    # Log separate accuracies for low and high predictions
+                    low_predictions = [p for p in validated_predictions if p['type'] == 'low']
+                    high_predictions = [p for p in validated_predictions if p['type'] == 'high']
+
+                    if low_predictions:
+                        correct_lows = [p for p in low_predictions if p['was_correct']]
+                        low_accuracy = len(correct_lows) / len(low_predictions)
+                        writer.add_scalar('Prediction/LowAccuracy', low_accuracy, episode)
+
+                    if high_predictions:
+                        correct_highs = [p for p in high_predictions if p['was_correct']]
+                        high_accuracy = len(correct_highs) / len(high_predictions)
+                        writer.add_scalar('Prediction/HighAccuracy', high_accuracy, episode)
+
+                    if verbose:
+                        print(f"Prediction accuracy: {accuracy:.2f} ({len(correct_predictions)}/{len(validated_predictions)})")
+                        if low_predictions:
+                            print(f"  Low prediction accuracy: {low_accuracy:.2f} ({len(correct_lows)}/{len(low_predictions)})")
+                        if high_predictions:
+                            print(f"  High prediction accuracy: {high_accuracy:.2f} ({len(correct_highs)}/{len(high_predictions)})")
 
+        # Save best models
+        if avg_reward > best_reward:
+            best_reward = avg_reward
+            torch.save({
+                'price_model_state_dict': price_model.state_dict(),
+                'dqn_model_state_dict': dqn_model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'episode': episode,
+                'reward': best_reward,
+                'pnl': total_pnl,
+                'win_rate': win_rate
+            }, 'models/enhanced_trading_agent_best_reward.pt')
+            print(f"Saved best reward model: {best_reward:.4f}")
+
+        if total_pnl > best_pnl:
+            best_pnl = total_pnl
+            torch.save({
+                'price_model_state_dict': price_model.state_dict(),
+                'dqn_model_state_dict': dqn_model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'episode': episode,
+                'reward': avg_reward,
+                'pnl': best_pnl,
+                'win_rate': win_rate
+            }, 'models/enhanced_trading_agent_best_pnl.pt')
+            print(f"Saved best PnL model: ${best_pnl:.2f}")
+
+        if win_rate > best_winrate:
+            best_winrate = win_rate
+            torch.save({
+                'price_model_state_dict': price_model.state_dict(),
+                'dqn_model_state_dict': dqn_model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'episode': episode,
+                'reward': avg_reward,
+                'pnl': total_pnl,
+                'win_rate': best_winrate
+            }, 'models/enhanced_trading_agent_best_winrate.pt')
+            print(f"Saved best win rate model: {best_winrate:.1f}%")
+
+        # Save latest model for continuous training
+        torch.save({
+            'price_model_state_dict': price_model.state_dict(),
+            'dqn_model_state_dict': dqn_model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'episode': episode,
+            'rewards': rewards,
+            'profits': profits,
+            'win_rates': win_rates,
+            'best_reward': best_reward,
+            'best_pnl': best_pnl,
+            'best_winrate': best_winrate
+        }, 'models/enhanced_trading_agent_latest.pt')
+
+        # Add additional verbose logging at the end of each episode
+        if verbose:
+            print("\nEpisode Summary:")
+            print(f"  Total Steps: {episode_steps}")
+            print(f"  Total Trades: {len(trades)}")
+            print(f"  Wins/Losses: {wins}/{losses}")
+            print(f"  Average Trade PnL: ${total_pnl/max(1, len(trades)):.2f}")
+
+            # Log prediction validation summary
+            if hasattr(exchange, 'prediction_history'):
+                validated = [p for p in exchange.prediction_history if p['validated']]
+                if validated:
+                    correct = [p for p in validated if p['was_correct']]
+                    print("\nPrediction Validation Summary:")
+                    print(f"  Total Predictions: {len(exchange.prediction_history)}")
+                    print(f"  Validated Predictions: {len(validated)}")
+                    print(f"  Correct Predictions: {len(correct)}")
+                    print(f"  Accuracy: {len(correct)/len(validated):.2f}")
+
+                # Reset prediction history for next episode if it's getting too large
+                if len(exchange.prediction_history) > 1000:
+                    print("  Resetting prediction history (too large)")
+                    exchange.prediction_history = exchange.prediction_history[-100:]
+
+    # Return final metrics
+    return rewards, profits, win_rates
 
 def learn(dqn, target_dqn, replay_buffer, optimizer, scaler, device):
-    """Update the DQN model using experiences from the replay buffer"""
-    # Sample from replay buffer
-    states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(BATCH_SIZE)
-
-    # Move to device
-    states = states.to(device)
-    actions = actions.to(device)
-    rewards = rewards.to(device)
-    next_states = next_states.to(device)
-    dones = dones.to(device)
-    weights = weights.to(device)
+    """
+    Update the DQN model using experiences from the replay buffer
+
+    Args:
+        dqn: The DQN model to update
+        target_dqn: The target DQN model for stable Q-value estimates
+        replay_buffer: Replay buffer containing experiences
+        optimizer: Optimizer for updating the DQN model
+        scaler: Gradient scaler for mixed precision training
+        device: Device to train on (CPU or GPU)
+    """
+    if len(replay_buffer) < BATCH_SIZE:
+        return
+
+    # Sample batch from replay buffer
+    states, actions, rewards, next_states, dones = replay_buffer.sample(BATCH_SIZE)
+
+    # Convert to tensors
+    states = torch.FloatTensor(states).to(device)
+    actions = torch.LongTensor(actions).to(device)
+    rewards = torch.FloatTensor(rewards).to(device)
+    next_states = torch.FloatTensor(next_states).to(device)
+    dones = torch.FloatTensor(dones).to(device)
 
     # Get current Q values
-    if device.type == 'cuda':
-        with autocast(device_type='cuda', enabled=True):
-            current_q_values, _, _ = dqn(states)
-            current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
+    with autocast(device_type='cuda' if device.type == 'cuda' else 'cpu'):
+        q_values = dqn(states)
+        if isinstance(q_values, tuple):
+            q_values = q_values[0]  # Extract the tensor from the tuple
+        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
 
-            # Compute target Q values
-            with torch.no_grad():
-                next_q_values, _, _ = target_dqn(next_states)
-                max_next_q_values = next_q_values.max(1)[0]
-                target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values
+        # Get next Q values from target network
+        with torch.no_grad():
+            next_q_values = target_dqn(next_states)
+            if isinstance(next_q_values, tuple):
+                next_q_values = next_q_values[0]  # Extract the tensor from the tuple
+            next_q_values = next_q_values.max(1)[0]
+            target_q_values = rewards + GAMMA * next_q_values * (1 - dones)
 
-            # Compute loss with importance sampling weights
-            td_errors = target_q_values - current_q_values
-            loss = (weights * td_errors.pow(2)).mean()
-    else:
-        # CPU version without autocast
-        current_q_values, _, _ = dqn(states)
-        current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
-
-        # Compute target Q values
-        with torch.no_grad():
-            next_q_values, _, _ = target_dqn(next_states)
-            max_next_q_values = next_q_values.max(1)[0]
-            target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values
-
-        # Compute loss with importance sampling weights
-        td_errors = target_q_values - current_q_values
-        loss = (weights * td_errors.pow(2)).mean()
-
-    # Update priorities in replay buffer
-    replay_buffer.update_priorities(indices, td_errors.abs().detach().cpu().numpy())
-
-    # Optimize the model with mixed precision
+        # Calculate loss
+        loss = F.smooth_l1_loss(q_values, target_q_values)
+
+    # Optimize the model
     optimizer.zero_grad()
-
-    if device.type == 'cuda':
-        scaler.scale(loss).backward()
-        scaler.unscale_(optimizer)
-        torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0)
-        scaler.step(optimizer)
-        scaler.update()
-    else:
-        # CPU version without scaler
-        loss.backward()
-        torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0)
-        optimizer.step()
+    scaler.scale(loss).backward()
+    scaler.step(optimizer)
+    scaler.update()
 
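
Note on the learn() rewrite above: the new version drops importance-sampled squared TD error in favor of the standard one-step DQN target with a Huber (smooth L1) loss. A tiny self-contained sanity check of that target computation (shapes and values are illustrative):

    import torch
    import torch.nn.functional as F

    GAMMA = 0.99
    rewards = torch.tensor([1.0, -0.5, 0.0, 2.0])
    dones = torch.tensor([0.0, 0.0, 1.0, 0.0])
    q_values = torch.randn(4)                      # Q(s, a) for the taken actions
    next_q_values = torch.randn(4, 3).max(1)[0]    # max over actions of Q_target(s', .)

    # Terminal states (done=1) contribute only their immediate reward
    target_q = rewards + GAMMA * next_q_values * (1 - dones)
    loss = F.smooth_l1_loss(q_values, target_q)
    print(loss.item())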
 def initialize_state(exchange, timeframes):
-    """Initialize the state with data from multiple timeframes"""
+    """
+    Initialize the state for the trading agent using multi-timeframe data
+
+    Args:
+        exchange: Exchange object to fetch data from
+        timeframes: List of timeframes to use
+
+    Returns:
+        state: Initial state vector
+    """
+    # Initialize empty state
+    state = []
+
     # Fetch data for each timeframe
     timeframe_data = {}
     for tf in timeframes:
         candles = exchange.fetch_ohlcv(timeframe=tf, limit=30)
+        if not candles or len(candles) < 10:
+            print(f"Not enough data for timeframe {tf}, using zeros")
+            # Use zeros if not enough data
+            state.extend([0] * 20)  # 20 features per timeframe
+            continue
+
         timeframe_data[tf] = candles
 
-    # Extract features from each timeframe
-    state = []
-
-    for tf in timeframes:
-        candles = timeframe_data[tf]
-
-        # Price features
+        # Extract features for this timeframe
         prices = [candle[4] for candle in candles[-10:]]  # Last 10 close prices
         price_changes = [prices[i]/prices[i-1] - 1 for i in range(1, len(prices))]
 
-        # Volume features
         volumes = [candle[5] for candle in candles[-10:]]  # Last 10 volumes
         volume_changes = [volumes[i]/volumes[i-1] - 1 for i in range(1, len(volumes))]
 
         # Technical indicators
-        # Simple Moving Averages
         sma_5 = sum(prices[-5:]) / 5
         sma_10 = sum(prices) / 10
 
-        # Relative Strength Index (simplified)
-        gains = [max(0, price_changes[i]) for i in range(len(price_changes))]
-        losses = [max(0, -price_changes[i]) for i in range(len(price_changes))]
-        avg_gain = sum(gains) / len(gains)
-        avg_loss = sum(losses) / len(losses)
-        rs = avg_gain / (avg_loss + 1e-10)  # Avoid division by zero
+        # RSI (simplified)
+        gains = [max(0, pc) for pc in price_changes]
+        losses = [max(0, -pc) for pc in price_changes]
+        avg_gain = sum(gains) / len(gains) if gains else 0
+        avg_loss = sum(losses) / len(losses) if losses else 1e-10
+        rs = avg_gain / avg_loss
         rsi = 100 - (100 / (1 + rs))
 
         # Add features to state
@@ -506,85 +675,15 @@ def initialize_state(exchange, timeframes):
         state.append(sma_10 / prices[-1] - 1)  # 1 feature
         state.append(rsi / 100)  # 1 feature
 
-    # Add market regime features
-    # This is a placeholder - in a real implementation, you would use the market_regime_classifier
-    # from the DQN model to predict the current market regime
-    state.extend([0, 0, 0])  # 3 features for market regime (one-hot encoded)
+    # Add market regime features (placeholder)
+    state.extend([0, 0, 0])  # 3 features for market regime
 
-    # Add additional features to reach the expected dimension of 100
-    # Calculate more technical indicators
-    for tf in timeframes:
-        candles = timeframe_data[tf]
-        prices = [candle[4] for candle in candles[-20:]]  # Last 20 close prices
-
-        # Bollinger Bands
-        window = 20
-        if len(prices) >= window:
-            sma_20 = sum(prices[-window:]) / window
-            std_dev = (sum((price - sma_20) ** 2 for price in prices[-window:]) / window) ** 0.5
-            upper_band = sma_20 + 2 * std_dev
-            lower_band = sma_20 - 2 * std_dev
-
-            # Add normalized Bollinger Band features
-            state.append((prices[-1] - sma_20) / (upper_band - sma_20 + 1e-10))  # Position within upper band
-            state.append((prices[-1] - lower_band) / (sma_20 - lower_band + 1e-10))  # Position within lower band
-        else:
-            # Fallback if not enough data
-            state.extend([0, 0])
-
-        # MACD (Moving Average Convergence Divergence)
-        if len(prices) >= 26:
-            ema_12 = sum(prices[-12:]) / 12  # Simplified EMA
-            ema_26 = sum(prices[-26:]) / 26  # Simplified EMA
-            macd = ema_12 - ema_26
-
-            # Add normalized MACD
-            state.append(macd / prices[-1])
-        else:
-            # Fallback if not enough data
-            state.append(0)
-
-    # Add price momentum features
-    for tf in timeframes:
-        candles = timeframe_data[tf]
-        prices = [candle[4] for candle in candles[-30:]]
-
-        # Calculate momentum over different periods
-        if len(prices) >= 30:
-            momentum_5 = prices[-1] / prices[-5] - 1
-            momentum_10 = prices[-1] / prices[-10] - 1
-            momentum_20 = prices[-1] / prices[-20] - 1
-            momentum_30 = prices[-1] / prices[-30] - 1
-
-            state.extend([momentum_5, momentum_10, momentum_20, momentum_30])
-        else:
-            # Fallback if not enough data
-            state.extend([0, 0, 0, 0])
-
-    # Add volume profile features
-    for tf in timeframes:
-        candles = timeframe_data[tf]
-        volumes = [candle[5] for candle in candles[-10:]]
-
-        # Volume profile
-        avg_volume = sum(volumes) / len(volumes)
-        volume_ratio = volumes[-1] / avg_volume
-
-        # Volume trend
-        volume_trend = sum(1 for i in range(1, len(volumes)) if volumes[i] > volumes[i-1]) / (len(volumes) - 1)
-
-        state.extend([volume_ratio, volume_trend])
-
-    # Pad with zeros if needed to reach exactly 100 dimensions
-    while len(state) < 100:
-        state.append(0)
-
-    # Ensure state has exactly 100 dimensions
-    if len(state) > 100:
+    # Pad state to reach expected dimension of 100
+    if len(state) < 100:
+        state.extend([0] * (100 - len(state)))
+    elif len(state) > 100:
         state = state[:100]
 
-    assert len(state) == 100, f"State dimension mismatch: {len(state)} != 100"
-
     return state
 
 def step_environment(exchange, state, action, price_model, timeframes, device):
@@ -744,6 +843,145 @@ def step_environment(exchange, state, action, price_model, timeframes, device):
 
     return next_state, reward, done, trade_info
 
+def train_price_predictor(model, exchange, timeframes, device, num_epochs=5, batch_size=32):
+    """
+    Train the price prediction model using data from the exchange
+
+    Args:
+        model: The EnhancedPricePredictionModel
+        exchange: Exchange object to fetch data from
+        timeframes: List of timeframes to use
+        device: Device to train on (CPU or GPU)
+        num_epochs: Number of training epochs
+        batch_size: Batch size for training
+    """
+    print(f"Training price prediction model for {num_epochs} epochs with batch size {batch_size}")
+
+    # Create optimizer
+    optimizer = optim.Adam(model.parameters(), lr=0.001)
+
+    # Set model to training mode
+    model.train()
+
+    # Fetch data for each timeframe
+    timeframe_data = {}
+    for tf in timeframes:
+        # Fetch more data for training
+        candles = exchange.fetch_ohlcv(timeframe=tf, limit=500)
+        if not candles or len(candles) < 30:
+            print(f"Not enough data for timeframe {tf}, skipping training")
+            return model
+        timeframe_data[tf] = candles
+
+    # Prepare training data
+    for epoch in range(num_epochs):
+        total_loss = 0
+        num_batches = 0
+
+        # Create batches
+        for i in range(0, len(timeframe_data[timeframes[0]]) - 30, batch_size):
+            if i + 30 + 5 >= len(timeframe_data[timeframes[0]]):
+                break
+
+            # Prepare inputs for each timeframe
+            inputs_list = []
+            for tf in timeframes:
+                if i + 30 >= len(timeframe_data[tf]):
+                    continue
+
+                # Extract price and volume data for input
+                input_data = []
+                for j in range(i, i + 30):
+                    if j < len(timeframe_data[tf]):
+                        candle = timeframe_data[tf][j]
+                        # Use close price and volume
+                        input_data.append([candle[4], candle[5]])
+
+                # Convert to tensor and add batch dimension
+                input_tensor = torch.tensor(input_data, dtype=torch.float32).unsqueeze(0).to(device)
+                inputs_list.append(input_tensor)
+
+            # Skip if we don't have data for all timeframes
+            if len(inputs_list) != len(timeframes):
+                continue
+
+            # Prepare targets (next 5 candles)
+            target_data = []
+            for j in range(i + 30, i + 35):
+                if j < len(timeframe_data[timeframes[0]]):
+                    candle = timeframe_data[timeframes[0]][j]
+                    # OHLCV values
+                    target_data.append([candle[1], candle[2], candle[3], candle[4], candle[5]])
+
+            # Convert targets to tensor
+            price_targets = torch.tensor(target_data, dtype=torch.float32).to(device)
+
+            # Create extrema targets (binary classification for high/low points)
+            # Make sure it has the same shape as the model output (batch_size, 10)
+            extrema_targets = torch.zeros(1, 10, dtype=torch.float32).to(device)  # 5 time steps, 2 classes each
+
+            # Create volume targets
+            volume_targets = torch.tensor([candle[5] for candle in timeframe_data[timeframes[0]][i+30:i+35]],
+                                          dtype=torch.float32).to(device)
+
+            # Zero gradients
+            optimizer.zero_grad()
+
+            try:
+                # Forward pass
+                price_pred, extrema_logits, volume_pred = model(inputs_list)
+
+                # Ensure targets have the same shape as predictions
+                if price_pred.shape != price_targets.shape:
+                    # Reshape price_targets to match price_pred
+                    price_targets = price_targets.view(price_pred.shape)
+
+                if volume_pred.shape != volume_targets.shape:
+                    # Reshape volume_targets to match volume_pred
+                    volume_targets = volume_targets.view(volume_pred.shape)
+
+                # Calculate losses
+                price_loss = F.mse_loss(price_pred, price_targets)
+                extrema_loss = F.binary_cross_entropy_with_logits(extrema_logits, extrema_targets)
+                volume_loss = F.mse_loss(volume_pred, volume_targets)
+
+                # Combined loss with weighting
+                loss = price_loss + 0.5 * extrema_loss + 0.3 * volume_loss
+
+                # Backward pass
+                loss.backward()
+
+                # Gradient clipping to prevent exploding gradients
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+
+                optimizer.step()
+
+                total_loss += loss.item()
+                num_batches += 1
+
+                if num_batches % 10 == 0:
+                    print(f"Epoch {epoch+1}/{num_epochs}, Batch {num_batches}, Loss: {loss.item():.6f}")
+            except Exception as e:
+                print(f"Error in batch: {e}")
+                print(f"Shapes - price_pred: {price_pred.shape}, price_targets: {price_targets.shape}")
+                print(f"Shapes - extrema_logits: {extrema_logits.shape}, extrema_targets: {extrema_targets.shape}")
+                print(f"Shapes - volume_pred: {volume_pred.shape}, volume_targets: {volume_targets.shape}")
+                continue
+
+        if num_batches > 0:
+            avg_loss = total_loss / num_batches
+            print(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.6f}")
+        else:
+            print(f"Epoch {epoch+1}/{num_epochs}, No batches processed")
+
+        # Learning rate scheduling
+        if epoch > 0 and epoch % 2 == 0:
+            for param_group in optimizer.param_groups:
+                param_group['lr'] *= 0.9
+
+    print("Price prediction model training completed")
+    return model
+
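
Note on the trainer added above: it slices each timeframe into 30-candle input windows with the following 5 candles as OHLCV targets. A self-contained sketch of just that windowing step, run over fake candles (window and horizon sizes mirror the diff; the data and names are illustrative):

    import torch

    # Fake OHLCV candles: [timestamp, open, high, low, close, volume]
    candles = [[t, 1.0, 1.1, 0.9, 1.0 + 0.01 * t, 100.0 + t] for t in range(50)]

    WINDOW, HORIZON = 30, 5
    inputs, targets = [], []
    for i in range(len(candles) - WINDOW - HORIZON + 1):
        # Inputs: close price and volume over the window, shape (WINDOW, 2)
        inputs.append([[c[4], c[5]] for c in candles[i:i + WINDOW]])
        # Targets: OHLCV of the next HORIZON candles, shape (HORIZON, 5)
        targets.append([c[1:6] for c in candles[i + WINDOW:i + WINDOW + HORIZON]])

    x = torch.tensor(inputs, dtype=torch.float32)
    y = torch.tensor(targets, dtype=torch.float32)
    print(x.shape, y.shape)  # torch.Size([16, 30, 2]) torch.Size([16, 5, 5])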
 # Main function to run training
 def main():
     from exchange_simulator import ExchangeSimulator
@@ -752,7 +990,7 @@ def main():
     exchange = ExchangeSimulator()
 
     # Train agent
-    price_model, dqn_model = enhanced_train_agent(
+    rewards, profits, win_rates = enhanced_train_agent(
         exchange=exchange,
         num_episodes=NUM_EPISODES,
         continuous=CONTINUOUS_MODE,
@@ -345,6 +345,439 @@ class ExchangeSimulator:
             }
         }
 
+    def reset(self):
+        """
+        Reset the exchange simulator to its initial state
+
+        Returns:
+            Self for method chaining
+        """
+        # Reset timestamp
+        self.current_timestamp = datetime.now()
+
+        # Regenerate data for each timeframe
+        for tf in self.timeframes:
+            self._generate_initial_data(tf)
+
+        # Reset any internal state
+        self.position = 'flat'
+        self.position_size = 0
+        self.entry_price = 0
+        self.stop_loss = 0
+        self.take_profit = 0
+
+        # Reset prediction history if it exists
+        if hasattr(self, 'prediction_history'):
+            self.prediction_history = []
+
+        return self
+
+    def step(self, action):
+        """
+        Take a step in the environment by executing an action
+
+        Args:
+            action: Action to take (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT)
+
+        Returns:
+            next_state: Next state after taking action
+            reward: Reward received
+            done: Whether episode is done
+            info: Additional information
+        """
+        # Get current price
+        current_price = self.data['1m'][-1][4]
+
+        # Initialize info dictionary
+        info = {
+            'price': current_price,
+            'timestamp': self.data['1m'][-1][0],
+            'trade': None
+        }
+
+        # Process action
+        if action == 0:  # HOLD
+            pass  # No action needed
+
+        elif action == 1:  # BUY/LONG
+            if self.position == 'flat':
+                # Open a new long position
+                self.position = 'long'
+                self.entry_price = current_price
+                self.position_size = 100  # Simplified position sizing
+
+                # Set stop loss and take profit levels
+                self.stop_loss = current_price * 0.99  # 1% stop loss
+                self.take_profit = current_price * 1.02  # 2% take profit
+
+                # Record entry time
+                self.entry_time = self.data['1m'][-1][0]
+
+                # Add to info
+                info['trade'] = {
+                    'type': 'long',
+                    'entry': current_price,
+                    'entry_time': self.data['1m'][-1][0],
+                    'size': self.position_size,
+                    'stop_loss': self.stop_loss,
+                    'take_profit': self.take_profit
+                }
+
+            elif self.position == 'short':
+                # Close short position and open long
+                pnl = self.entry_price - current_price
+                pnl_percent = pnl / self.entry_price * 100
+                pnl_dollar = pnl_percent / 100 * self.position_size
+
+                # Add to info
+                info['trade'] = {
+                    'type': 'short',
+                    'entry': self.entry_price,
+                    'exit': current_price,
+                    'entry_time': self.entry_time,
+                    'exit_time': self.data['1m'][-1][0],
+                    'pnl_percent': pnl_percent,
+                    'pnl_dollar': pnl_dollar,
+                    'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60)  # Duration in minutes
+                }
+
+                # Open new long position
+                self.position = 'long'
+                self.entry_price = current_price
+                self.position_size = 100  # Simplified position sizing
+
+                # Set stop loss and take profit levels
+                self.stop_loss = current_price * 0.99  # 1% stop loss
+                self.take_profit = current_price * 1.02  # 2% take profit
+
+                # Record entry time
+                self.entry_time = self.data['1m'][-1][0]
+
+        elif action == 2:  # SELL/SHORT
+            if self.position == 'flat':
+                # Open a new short position
+                self.position = 'short'
+                self.entry_price = current_price
+                self.position_size = 100  # Simplified position sizing
+
+                # Set stop loss and take profit levels
+                self.stop_loss = current_price * 1.01  # 1% stop loss
+                self.take_profit = current_price * 0.98  # 2% take profit
+
+                # Record entry time
+                self.entry_time = self.data['1m'][-1][0]
+
+                # Add to info
+                info['trade'] = {
+                    'type': 'short',
+                    'entry': current_price,
+                    'entry_time': self.data['1m'][-1][0],
+                    'size': self.position_size,
+                    'stop_loss': self.stop_loss,
+                    'take_profit': self.take_profit
+                }
+
+            elif self.position == 'long':
+                # Close long position and open short
+                pnl = current_price - self.entry_price
+                pnl_percent = pnl / self.entry_price * 100
+                pnl_dollar = pnl_percent / 100 * self.position_size
+
+                # Add to info
+                info['trade'] = {
+                    'type': 'long',
+                    'entry': self.entry_price,
+                    'exit': current_price,
+                    'entry_time': self.entry_time,
+                    'exit_time': self.data['1m'][-1][0],
+                    'pnl_percent': pnl_percent,
+                    'pnl_dollar': pnl_dollar,
+                    'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60)  # Duration in minutes
+                }
+
+                # Open new short position
+                self.position = 'short'
+                self.entry_price = current_price
+                self.position_size = 100  # Simplified position sizing
+
+                # Set stop loss and take profit levels
+                self.stop_loss = current_price * 1.01  # 1% stop loss
+                self.take_profit = current_price * 0.98  # 2% take profit
+
+                # Record entry time
+                self.entry_time = self.data['1m'][-1][0]
+
+        # Generate next candle
+        self._add_new_candle('1m')
+
+        # Check if stop loss or take profit has been hit
+        self._check_sl_tp(info)
+
+        # Validate predictions if available
+        if hasattr(self, 'prediction_history') and len(self.prediction_history) > 0:
+            self.validate_predictions(self.data['1m'][-1])
+
+        # Prepare next state (simplified)
+        next_state = self._get_state()
+
+        # Calculate reward (simplified)
+        reward = 0
+        if info['trade'] is not None and 'pnl_dollar' in info['trade']:
+            reward = info['trade']['pnl_dollar']
+
+        # Check if done (simplified)
+        done = False
+
+        return next_state, reward, done, info
+
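
Note on step() above: the simulator now follows the familiar gym-style (state, reward, done, info) contract the training loop relies on. A sketch of a driving loop under that contract, using a stand-in environment so it runs on its own (DummyEnv and its numbers are illustrative, not the repo's simulator):

    import random

    class DummyEnv:
        """Stand-in for ExchangeSimulator; only the reset/step contract matters here."""
        def reset(self):
            self.t = 0
            return [0.0] * 100
        def step(self, action):
            self.t += 1
            next_state = [0.0] * 100
            reward = random.uniform(-1, 1)
            done = self.t >= 10
            info = {'trade': None}
            return next_state, reward, done, info

    env = DummyEnv()
    state, done, total = env.reset(), False, 0.0
    while not done:
        action = random.randint(0, 2)  # 0: HOLD, 1: BUY, 2: SELL
        state, reward, done, info = env.step(action)
        total += reward
    print(f"episode reward: {total:.3f}")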
+    def _get_state(self):
+        """
+        Get the current state of the environment
+
+        Returns:
+            List representing the current state
+        """
+        # Simplified state representation
+        state = []
+
+        # Add price features
+        for tf in ['1m', '5m', '15m']:
+            if tf in self.data:
+                # Get last 10 candles
+                candles = self.data[tf][-10:]
+
+                # Extract close prices
+                prices = [candle[4] for candle in candles]
+
+                # Calculate price changes
+                price_changes = [prices[i]/prices[i-1] - 1 for i in range(1, len(prices))]
+
+                # Add to state
+                state.extend(price_changes)
+
+                # Add current price relative to SMA
+                sma_5 = sum(prices[-5:]) / 5
+                sma_10 = sum(prices) / 10
+                state.append(prices[-1] / sma_5 - 1)
+                state.append(prices[-1] / sma_10 - 1)
+
+        # Pad state to 100 dimensions
+        while len(state) < 100:
+            state.append(0)
+
+        # Ensure state has exactly 100 dimensions
+        if len(state) > 100:
+            state = state[:100]
+
+        return state
+
+    def _check_sl_tp(self, info):
+        """
+        Check if stop loss or take profit has been hit
+
+        Args:
+            info: Info dictionary to update
+        """
+        if self.position == 'flat':
+            return
+
+        # Get current price
+        current_price = self.data['1m'][-1][4]
+
+        if self.position == 'long':
+            # Check stop loss
+            if current_price <= self.stop_loss:
+                # Stop loss hit
+                pnl_percent = (self.stop_loss - self.entry_price) / self.entry_price * 100
+                pnl_dollar = pnl_percent / 100 * self.position_size
+
+                # Add to info
+                info['trade'] = {
+                    'type': 'long',
+                    'entry': self.entry_price,
+                    'exit': self.stop_loss,
+                    'entry_time': self.entry_time,
+                    'exit_time': self.data['1m'][-1][0],
+                    'pnl_percent': pnl_percent,
+                    'pnl_dollar': pnl_dollar,
+                    'reason': 'stop_loss',
+                    'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60)  # Duration in minutes
+                }
+
+                # Reset position
+                self.position = 'flat'
+                self.entry_price = 0
+                self.position_size = 0
+                self.stop_loss = 0
+                self.take_profit = 0
+
+            # Check take profit
+            elif current_price >= self.take_profit:
+                # Take profit hit
+                pnl_percent = (self.take_profit - self.entry_price) / self.entry_price * 100
+                pnl_dollar = pnl_percent / 100 * self.position_size
+
+                # Add to info
+                info['trade'] = {
+                    'type': 'long',
+                    'entry': self.entry_price,
+                    'exit': self.take_profit,
+                    'entry_time': self.entry_time,
+                    'exit_time': self.data['1m'][-1][0],
+                    'pnl_percent': pnl_percent,
+                    'pnl_dollar': pnl_dollar,
+                    'reason': 'take_profit',
+                    'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60)  # Duration in minutes
+                }
+
+                # Reset position
+                self.position = 'flat'
+                self.entry_price = 0
+                self.position_size = 0
+                self.stop_loss = 0
+                self.take_profit = 0
+
+        elif self.position == 'short':
+            # Check stop loss
+            if current_price >= self.stop_loss:
+                # Stop loss hit
+                pnl_percent = (self.entry_price - self.stop_loss) / self.entry_price * 100
+                pnl_dollar = pnl_percent / 100 * self.position_size
+
+                # Add to info
+                info['trade'] = {
+                    'type': 'short',
+                    'entry': self.entry_price,
+                    'exit': self.stop_loss,
+                    'entry_time': self.entry_time,
+                    'exit_time': self.data['1m'][-1][0],
+                    'pnl_percent': pnl_percent,
+                    'pnl_dollar': pnl_dollar,
+                    'reason': 'stop_loss',
+                    'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60)  # Duration in minutes
+                }
+
+                # Reset position
+                self.position = 'flat'
+                self.entry_price = 0
+                self.position_size = 0
+                self.stop_loss = 0
+                self.take_profit = 0
+
+            # Check take profit
+            elif current_price <= self.take_profit:
+                # Take profit hit
+                pnl_percent = (self.entry_price - self.take_profit) / self.entry_price * 100
+                pnl_dollar = pnl_percent / 100 * self.position_size
+
+                # Add to info
+                info['trade'] = {
+                    'type': 'short',
+                    'entry': self.entry_price,
+                    'exit': self.take_profit,
+                    'entry_time': self.entry_time,
+                    'exit_time': self.data['1m'][-1][0],
+                    'pnl_percent': pnl_percent,
+                    'pnl_dollar': pnl_dollar,
+                    'reason': 'take_profit',
+                    'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60)  # Duration in minutes
+                }
+
+                # Reset position
+                self.position = 'flat'
+                self.entry_price = 0
+                self.position_size = 0
+                self.stop_loss = 0
+                self.take_profit = 0
+
|
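All four exit branches in _check_sl_tp repeat the same bookkeeping. A hypothetical refactor (not in this commit) that collapses them into one helper; the sign convention is +1 for long, -1 for short:

    def _close_position(self, info, exit_price, reason):
        """Hypothetical helper: record the closed trade in info and flatten the position."""
        direction = 1 if self.position == 'long' else -1
        pnl_percent = direction * (exit_price - self.entry_price) / self.entry_price * 100
        info['trade'] = {
            'type': self.position,
            'entry': self.entry_price,
            'exit': exit_price,
            'entry_time': self.entry_time,
            'exit_time': self.data['1m'][-1][0],
            'pnl_percent': pnl_percent,
            'pnl_dollar': pnl_percent / 100 * self.position_size,
            'reason': reason,
            'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60)  # Duration in minutes
        }
        self.position = 'flat'
        self.entry_price = self.position_size = self.stop_loss = self.take_profit = 0

Each branch then reduces to a one-liner such as self._close_position(info, self.stop_loss, 'stop_loss').
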
    def validate_predictions(self, new_candle):
        """
        Validate previous extrema predictions against new candle data

        Args:
            new_candle: New candle data to validate against

        Returns:
            dict with validation metrics if any prediction was validated, else None
        """
        if not hasattr(self, 'prediction_history') or not self.prediction_history:
            return

        # Extract candle data
        timestamp = new_candle[0]
        high_price = new_candle[2]
        low_price = new_candle[3]

        # Track validation metrics
        validated_count = 0
        correct_count = 0

        # Check each prediction that hasn't been validated yet
        for pred in self.prediction_history:
            if pred.get('validated', False):
                continue

            # Check if this prediction's time has come (or passed)
            if 'predicted_timestamp' in pred and timestamp >= pred['predicted_timestamp']:
                pred['validated'] = True
                validated_count += 1

                # Check whether the prediction was correct
                if pred['type'] == 'low':
                    # A low prediction is correct if the actual low came within 0.5%
                    # of the predicted low, or undershot it
                    price_diff_percent = abs(low_price - pred['price']) / pred['price'] * 100
                    pred['actual_price'] = low_price
                    pred['price_diff_percent'] = price_diff_percent

                    was_correct = price_diff_percent < 0.5 or low_price <= pred['price']
                    pred['was_correct'] = was_correct

                    if was_correct:
                        correct_count += 1

                elif pred['type'] == 'high':
                    # A high prediction is correct if the actual high came within 0.5%
                    # of the predicted high, or overshot it
                    price_diff_percent = abs(high_price - pred['price']) / pred['price'] * 100
                    pred['actual_price'] = high_price
                    pred['price_diff_percent'] = price_diff_percent

                    was_correct = price_diff_percent < 0.5 or high_price >= pred['price']
                    pred['was_correct'] = was_correct

                    if was_correct:
                        correct_count += 1

        # Return validation metrics
        if validated_count > 0:
            return {
                'validated_count': validated_count,
                'correct_count': correct_count,
                'accuracy': correct_count / validated_count
            }

        return None

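A sketch of how the validator would be driven from a live feed; the simulator instance name is an assumption, and the candle layout ([timestamp, open, high, low, close, ...]) matches the indexing used above:

new_candle = simulator.data['1m'][-1]
metrics = simulator.validate_predictions(new_candle)
if metrics:
    print(f"validated {metrics['validated_count']} predictions, accuracy {metrics['accuracy']:.1%}")
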
    def calculate_pnl(self):
        """
        Calculate the current profit/loss of the open position

        Returns:
            float: Current PnL in dollars, 0 if no position is open
        """
        if self.position == 'flat':
            return 0.0

        current_price = self.data['1m'][-1][4]

        if self.position == 'long':
            pnl_percent = (current_price - self.entry_price) / self.entry_price * 100
        elif self.position == 'short':
            pnl_percent = (self.entry_price - current_price) / self.entry_price * 100
        else:
            return 0.0

        pnl_dollar = pnl_percent / 100 * self.position_size
        return pnl_dollar

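A quick arithmetic check of the formula, with assumed numbers (long from 2000.0, marked at 2050.0, $100 position):

pnl_percent = (2050.0 - 2000.0) / 2000.0 * 100  # 2.5
pnl_dollar = pnl_percent / 100 * 100.0          # 2.50 dollars
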
# Example usage
if __name__ == "__main__":
    # Create exchange simulator
3750
crypto/gogo2/main.py
File diff suppressed because it is too large
Load Diff
24
crypto/gogo2/main_function.py
Normal file
@ -0,0 +1,24 @@
import asyncio
import logging

from exchange_simulator import ExchangeSimulator

# Set up logging
logger = logging.getLogger(__name__)


async def main():
    """
    Main function to run the training process.
    """
    # Initialize exchange simulator
    exchange = ExchangeSimulator()

    # Train agent
    print("Starting training process...")
    # Add your training code here
    print("Training complete!")


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")
8195
crypto/gogo2/main_multiu_broken.py
Normal file
File diff suppressed because it is too large
Load Diff
30
crypto/gogo2/main_multiu_fixed.py
Normal file
@ -0,0 +1,30 @@
import asyncio
import logging

from exchange_simulator import ExchangeSimulator

# Set up logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


async def main():
    """
    Main function to run the training process.
    """
    # Initialize exchange simulator
    exchange = ExchangeSimulator()

    # Train agent
    print("Starting training process...")
    # Add your training code here
    print("Training complete!")


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")
    except Exception as e:
        logger.error(f"Error running main: {e}")
        import traceback
        logger.error(traceback.format_exc())
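Apart from the shared stub body, main_multiu_fixed.py differs from main_function.py above only in configuring the root logger up front and logging a full traceback on any unhandled exception instead of letting it propagate.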
@ -1,9 +1,15 @@
 import argparse
 import os
+import sys
+import asyncio
 import torch
 from enhanced_training import enhanced_train_agent
 from exchange_simulator import ExchangeSimulator
 
+# Fix for Windows asyncio
+if sys.platform == 'win32':
+    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
 def main():
     # Parse command line arguments
     parser = argparse.ArgumentParser(description='Enhanced Trading Bot Training')
@ -26,6 +32,9 @@ def main():
     parser.add_argument('--refresh-data', action='store_true',
                         help='Refresh data before training')
 
+    parser.add_argument('--verbose', action='store_true',
+                        help='Enable verbose logging')
+
     args = parser.parse_args()
 
     # Set device
@ -51,7 +60,8 @@ def main():
             exchange=exchange,
             num_episodes=args.episodes,
             continuous=False,
-            start_episode=0
+            start_episode=0,
+            verbose=args.verbose
         )
 
     elif args.mode == 'continuous':
@ -61,7 +71,8 @@ def main():
             exchange=exchange,
             num_episodes=args.episodes,
             continuous=True,
-            start_episode=args.start_episode
+            start_episode=args.start_episode,
+            verbose=args.verbose
         )
 
     elif args.mode == 'evaluate':
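Two notes on this hunk: Python 3.8+ defaults to the Proactor event loop on Windows, and some networking libraries still require the selector loop, which is presumably what the WindowsSelectorEventLoopPolicy override guards against; and the new --verbose flag is simply threaded through to enhanced_train_agent in both the one-shot and continuous branches.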
18
crypto/gogo2/run_main.py
Normal file
@ -0,0 +1,18 @@
import asyncio
import logging

from enhanced_training import main

# Set up logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")
    except Exception as e:
        logger.error(f"Error running main: {e}")
        import traceback
        logger.error(traceback.format_exc())
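One caveat: asyncio.run(main()) assumes enhanced_training.main is a coroutine function; if it is a plain synchronous function, asyncio.run raises a ValueError ("a coroutine was expected") instead of running it.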
56
crypto/gogo2/visualize_logs.py
Normal file
@ -0,0 +1,56 @@
import argparse
import subprocess
import time
import webbrowser
from pathlib import Path


def main():
    parser = argparse.ArgumentParser(description='Visualize TensorBoard logs')
    parser.add_argument('--logdir', type=str, default='./logs', help='Directory containing TensorBoard logs')
    parser.add_argument('--port', type=int, default=6006, help='Port for TensorBoard server')
    args = parser.parse_args()

    log_dir = Path(args.logdir)

    if not log_dir.exists():
        print(f"Log directory {log_dir} does not exist. Creating it...")
        log_dir.mkdir(parents=True, exist_ok=True)

    # Check if TensorBoard is installed
    try:
        subprocess.run(['tensorboard', '--version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("TensorBoard not found. Installing...")
        subprocess.run(['pip', 'install', 'tensorboard'], check=True)

    # Start TensorBoard server
    print(f"Starting TensorBoard server on port {args.port}...")
    tensorboard_process = subprocess.Popen(
        ['tensorboard', '--logdir', str(log_dir), '--port', str(args.port)],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )

    # Wait for TensorBoard to start
    time.sleep(3)

    # Open browser
    url = f"http://localhost:{args.port}"
    print(f"Opening TensorBoard in browser: {url}")
    webbrowser.open(url)

    print("TensorBoard is running. Press Ctrl+C to stop.")

    try:
        # Keep the script running until interrupted
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("Stopping TensorBoard server...")
        tensorboard_process.terminate()
        tensorboard_process.wait()
        print("TensorBoard server stopped.")


if __name__ == "__main__":
    main()
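Typical use is plain python visualize_logs.py with the defaults (--logdir ./logs, --port 6006); Ctrl+C stops the foreground loop and then terminates the TensorBoard child process. Note that the script pip-installs TensorBoard on the fly when the CLI probe fails, which may be unwelcome in a pinned environment.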