Commit d5c291d15c — "more features, but dead-end"
Parent: 2e901e18f2

12  crypto/gogo2/.vscode/launch.json (vendored)
@@ -5,8 +5,8 @@
             "name": "Train Bot",
             "type": "python",
             "request": "launch",
-            "program": "main.py",
-            "args": ["--mode", "train", "--episodes", "100"],
+            "program": "main_multiu_broken.py",
+            "args": ["--mode", "train", "--episodes", "10000"],
             "console": "integratedTerminal",
             "justMyCode": true
         },
@@ -14,7 +14,7 @@
             "name": "Evaluate Bot",
             "type": "python",
             "request": "launch",
-            "program": "main.py",
+            "program": "main_multiu_broken.py",
             "args": ["--mode", "eval", "--episodes", "10"],
             "console": "integratedTerminal",
             "justMyCode": true
@@ -23,7 +23,7 @@
             "name": "Live Trading (Demo)",
             "type": "python",
             "request": "launch",
-            "program": "main.py",
+            "program": "main_multiu_broken.py",
             "args": ["--mode", "live", "--demo"],
             "console": "integratedTerminal",
             "justMyCode": true
@@ -32,7 +32,7 @@
             "name": "Live Trading (Real)",
             "type": "python",
             "request": "launch",
-            "program": "main.py",
+            "program": "main_multiu_broken.py",
             "args": ["--mode", "live"],
             "console": "integratedTerminal",
             "justMyCode": true
@@ -41,7 +41,7 @@
             "name": "Continuous Training",
             "type": "python",
             "request": "launch",
-            "program": "main.py",
+            "program": "main_multiu_broken.py",
             "args": ["--mode", "continuous", "--refresh-data"],
             "console": "integratedTerminal",
             "justMyCode": true

@@ -202,6 +202,28 @@ class EnhancedReplayBuffer:
         self.n_step_buffer = []
         self.max_priority = 1.0
 
+    def add(self, state, action, reward, next_state, done):
+        """
+        Add a new experience to the buffer (simplified version of push for compatibility)
+
+        Args:
+            state: Current state
+            action: Action taken
+            reward: Reward received
+            next_state: Next state
+            done: Whether the episode is done
+        """
+        # Store in replay buffer with max priority
+        if len(self.buffer) < self.capacity:
+            self.buffer.append(None)
+        self.buffer[self.position] = (state, action, reward, next_state, done)
+
+        # Set priority to max priority to ensure it gets sampled
+        self.priorities[self.position] = self.max_priority
+
+        # Move position pointer
+        self.position = (self.position + 1) % self.capacity
+
     def push(self, state, action, reward, next_state, done):
         # Store experience in n-step buffer
         self.n_step_buffer.append((state, action, reward, next_state, done))
@@ -278,15 +300,8 @@ class EnhancedReplayBuffer:
             next_states.append(next_state)
             dones.append(done)
 
-        return (
-            torch.stack(states),
-            torch.tensor(actions),
-            torch.tensor(rewards, dtype=torch.float32),
-            torch.stack(next_states),
-            torch.tensor(dones, dtype=torch.float32),
-            indices,
-            weights
-        )
+        # Return only the states, actions, rewards, next_states, dones for compatibility with learn function
+        return states, actions, rewards, next_states, dones
 
     def update_priorities(self, indices, td_errors):
         for idx, td_error in zip(indices, td_errors):
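A minimal smoke test of the interface above might look like the following sketch (not part of the commit; it assumes EnhancedReplayBuffer lives in enhanced_models and that states are plain 100-element lists, as in the training code further down):

from enhanced_models import EnhancedReplayBuffer

buffer = EnhancedReplayBuffer(10000)      # capacity-only construction, as in the new training code
state, next_state = [0.0] * 100, [0.0] * 100

# add() stores one transition at max priority; push() would route it through the n-step buffer instead
for _ in range(64):
    buffer.add(state, 0, 0.0, next_state, False)

if len(buffer) >= 32:
    states, actions, rewards, next_states, dones = buffer.sample(32)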
@@ -2,6 +2,7 @@ import os
 import torch
 import torch.nn as nn
 import torch.optim as optim
 import torch.nn.functional as F
+from torch.amp import GradScaler, autocast
 import numpy as np
 import matplotlib.pyplot as plt
@@ -9,7 +10,7 @@ from datetime import datetime
 from tensorboardX import SummaryWriter
 
 # Import our enhanced models
-from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN, EnhancedReplayBuffer, train_price_predictor, prepare_multi_timeframe_data
+from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN, EnhancedReplayBuffer
 
 # Constants
 TIMEFRAMES = ['1m', '15m', '1h']
@ -218,285 +219,453 @@ def load_checkpoint(price_model, dqn_model, optimizer, episode=None):
|
||||
print("No checkpoint found, starting training from scratch")
|
||||
return 0, [], [], []
|
||||
|
||||
def enhanced_train_agent(exchange, num_episodes=NUM_EPISODES, continuous=CONTINUOUS_MODE, start_episode=CONTINUOUS_START_EPISODE):
|
||||
def enhanced_train_agent(exchange, num_episodes=NUM_EPISODES, continuous=CONTINUOUS_MODE, start_episode=CONTINUOUS_START_EPISODE, verbose=False):
|
||||
"""
|
||||
Train the enhanced trading agent using multi-timeframe data
|
||||
Train an enhanced trading agent with multi-timeframe models
|
||||
|
||||
Args:
|
||||
exchange: Exchange object to fetch data from
|
||||
exchange: Exchange simulator or real exchange
|
||||
num_episodes: Number of episodes to train for
|
||||
continuous: Whether to continue training from a checkpoint
|
||||
start_episode: Episode to start from if continuous training
|
||||
start_episode: Episode to start from for continuous training
|
||||
verbose: Whether to enable verbose logging
|
||||
"""
|
||||
print(f"Training on device: {DEVICE}")
|
||||
|
||||
# Set up TensorBoard
|
||||
writer = setup_tensorboard()
|
||||
|
||||
# Initialize models
|
||||
state_dim = 100 # Increased state dimension for multi-timeframe features
|
||||
action_dim = 3 # Buy, Sell, Hold
|
||||
state_dim = 100 # Increased state dimension for enhanced features
|
||||
action_dim = 3 # 0: HOLD, 1: BUY, 2: SELL
|
||||
|
||||
# Initialize price prediction model with multi-timeframe support
|
||||
price_model = EnhancedPricePredictionModel(
|
||||
input_dim=2, # Price and volume
|
||||
hidden_dim=256,
|
||||
num_layers=3,
|
||||
output_dim=5, # Predict next 5 candles
|
||||
output_dim=5, # OHLCV prediction
|
||||
num_timeframes=len(TIMEFRAMES)
|
||||
).to(DEVICE)
|
||||
|
||||
# Initialize DQN model with enhanced architecture
|
||||
dqn_model = EnhancedDQN(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
hidden_dim=512
|
||||
).to(DEVICE)
|
||||
|
||||
target_dqn = EnhancedDQN(
|
||||
# Initialize target network
|
||||
target_model = EnhancedDQN(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
hidden_dim=512
|
||||
).to(DEVICE)
|
||||
|
||||
# Copy initial weights to target network
|
||||
target_dqn.load_state_dict(dqn_model.state_dict())
|
||||
target_model.load_state_dict(dqn_model.state_dict())
|
||||
|
||||
# Initialize optimizer
|
||||
optimizer = optim.Adam(list(price_model.parameters()) + list(dqn_model.parameters()), lr=LEARNING_RATE)
|
||||
optimizer = optim.Adam(dqn_model.parameters(), lr=LEARNING_RATE)
|
||||
|
||||
# Initialize replay buffer
|
||||
replay_buffer = EnhancedReplayBuffer(
|
||||
capacity=REPLAY_BUFFER_SIZE,
|
||||
alpha=0.6,
|
||||
beta=0.4,
|
||||
beta_increment=0.001,
|
||||
n_step=3,
|
||||
gamma=GAMMA
|
||||
)
|
||||
# Initialize replay buffer with prioritized experience replay
|
||||
replay_buffer = EnhancedReplayBuffer(REPLAY_BUFFER_SIZE)
|
||||
|
||||
# Initialize gradient scaler for mixed precision training
|
||||
scaler = GradScaler(enabled=(DEVICE.type == 'cuda'))
|
||||
|
||||
# Initialize tracking variables
|
||||
# Initialize training metrics
|
||||
rewards = []
|
||||
profits = []
|
||||
win_rates = []
|
||||
best_reward = float('-inf')
|
||||
best_pnl = float('-inf')
|
||||
best_winrate = float('-inf')
|
||||
best_winrate = 0
|
||||
|
||||
# Load checkpoint if continuous training
|
||||
# Initialize epsilon for exploration
|
||||
epsilon = EPSILON_START
|
||||
|
||||
# Load checkpoint if continuing training
|
||||
if continuous:
|
||||
start_episode, rewards, profits, win_rates = load_checkpoint(
|
||||
price_model, dqn_model, optimizer, start_episode
|
||||
)
|
||||
try:
|
||||
checkpoint_path = 'models/enhanced_trading_agent_latest.pt'
|
||||
if os.path.exists(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
|
||||
price_model.load_state_dict(checkpoint['price_model_state_dict'])
|
||||
dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
|
||||
target_model.load_state_dict(checkpoint['dqn_model_state_dict'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
|
||||
# Prepare multi-timeframe data for price prediction model training
|
||||
data_loaders = prepare_multi_timeframe_data(exchange, TIMEFRAMES)
|
||||
if 'rewards' in checkpoint:
|
||||
rewards = checkpoint['rewards']
|
||||
if 'profits' in checkpoint:
|
||||
profits = checkpoint['profits']
|
||||
if 'win_rates' in checkpoint:
|
||||
win_rates = checkpoint['win_rates']
|
||||
if 'best_reward' in checkpoint:
|
||||
best_reward = checkpoint['best_reward']
|
||||
if 'best_pnl' in checkpoint:
|
||||
best_pnl = checkpoint['best_pnl']
|
||||
if 'best_winrate' in checkpoint:
|
||||
best_winrate = checkpoint['best_winrate']
|
||||
|
||||
print(f"Loaded checkpoint from {checkpoint_path}")
|
||||
print(f"Continuing from episode {start_episode}")
|
||||
|
||||
# Decay epsilon based on start episode
|
||||
for _ in range(start_episode):
|
||||
epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
|
||||
else:
|
||||
print(f"No checkpoint found at {checkpoint_path}, starting from scratch")
|
||||
except Exception as e:
|
||||
print(f"Error loading checkpoint: {e}")
|
||||
print("Starting from scratch")
|
||||
|
||||
# Set models to training mode
|
||||
price_model.train()
|
||||
dqn_model.train()
|
||||
|
||||
# Initialize gradient scaler for mixed precision training
|
||||
scaler = GradScaler()
|
||||
|
||||
print(f"Training on device: {DEVICE}")
|
||||
|
||||
# Pre-train price prediction model
|
||||
print("Pre-training price prediction model...")
|
||||
train_price_predictor(price_model, data_loaders, optimizer, DEVICE, epochs=5)
|
||||
train_price_predictor(price_model, exchange, TIMEFRAMES, DEVICE, num_epochs=5, batch_size=32)
|
||||
|
||||
# Main training loop
|
||||
epsilon = EPSILON_START
|
||||
|
||||
for episode in range(start_episode, num_episodes):
|
||||
print(f"Episode {episode+1}/{num_episodes}")
|
||||
|
||||
# Reset environment
|
||||
for episode in range(start_episode, start_episode + num_episodes):
|
||||
# Initialize state
|
||||
state = initialize_state(exchange, TIMEFRAMES)
|
||||
total_reward = 0
|
||||
|
||||
# Reset environment for new episode
|
||||
exchange.reset()
|
||||
|
||||
# Track episode metrics
|
||||
episode_reward = 0
|
||||
episode_steps = 0
|
||||
done = False
|
||||
trades = []
|
||||
wins = 0
|
||||
losses = 0
|
||||
|
||||
# Enable verbose logging for prediction validation if requested
|
||||
if verbose:
|
||||
# Set logging level to DEBUG for more detailed logs
|
||||
import logging
|
||||
logging.getLogger("trading_bot").setLevel(logging.DEBUG)
|
||||
|
||||
# Add a hook to log prediction validations
|
||||
def log_prediction_validation(pred, actual, was_correct):
|
||||
if was_correct:
|
||||
print(f"CORRECT prediction: predicted={pred:.2f}, actual={actual:.2f}")
|
||||
else:
|
||||
print(f"INCORRECT prediction: predicted={pred:.2f}, actual={actual:.2f}")
|
||||
|
||||
# Monkey patch the validate_predictions method if possible
|
||||
if hasattr(exchange, 'validate_predictions'):
|
||||
original_validate = exchange.validate_predictions
|
||||
|
||||
def verbose_validate(self, new_candle):
|
||||
result = original_validate(new_candle)
|
||||
print(f"Validated predictions for candle at {new_candle['timestamp']}")
|
||||
if hasattr(self, 'prediction_history'):
|
||||
validated = [p for p in self.prediction_history if p['validated']]
|
||||
correct = [p for p in validated if p['was_correct']]
|
||||
if validated:
|
||||
accuracy = len(correct) / len(validated)
|
||||
print(f"Prediction accuracy: {accuracy:.2f} ({len(correct)}/{len(validated)})")
|
||||
return result
|
||||
|
||||
exchange.validate_predictions = verbose_validate.__get__(exchange, exchange.__class__)
|
||||
|
||||
# Episode loop
|
||||
for step in range(MAX_STEPS_PER_EPISODE):
|
||||
while not done and episode_steps < MAX_STEPS_PER_EPISODE:
|
||||
# Epsilon-greedy action selection
|
||||
if np.random.random() < epsilon:
|
||||
action = np.random.randint(0, action_dim)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(DEVICE)
|
||||
q_values, _, _ = dqn_model(state_tensor)
|
||||
q_values = dqn_model(torch.FloatTensor(state).unsqueeze(0).to(DEVICE))
|
||||
action = q_values.argmax().item()
|
||||
|
||||
# Execute action and get next state and reward
|
||||
next_state, reward, done, trade_info = step_environment(
|
||||
exchange, state, action, price_model, TIMEFRAMES, DEVICE
|
||||
)
|
||||
# Take action in environment
|
||||
next_state, reward, done, info = exchange.step(action)
|
||||
|
||||
# Store transition in replay buffer
|
||||
replay_buffer.push(
|
||||
torch.FloatTensor(state),
|
||||
action,
|
||||
reward,
|
||||
torch.FloatTensor(next_state),
|
||||
done
|
||||
)
|
||||
replay_buffer.add(state, action, reward, next_state, done)
|
||||
|
||||
# Update state and accumulate reward
|
||||
state = next_state
|
||||
total_reward += reward
|
||||
episode_reward += reward
|
||||
|
||||
# Track trade outcomes
|
||||
if trade_info is not None:
|
||||
trades.append(trade_info)
|
||||
if trade_info['pnl'] > 0:
|
||||
if 'trade' in info and info['trade']:
|
||||
trades.append(info['trade'])
|
||||
if 'pnl_dollar' in info['trade']:
|
||||
if info['trade']['pnl_dollar'] > 0:
|
||||
wins += 1
|
||||
elif trade_info['pnl'] < 0:
|
||||
elif info['trade']['pnl_dollar'] < 0:
|
||||
losses += 1
|
||||
|
||||
# Log trade to TensorBoard
|
||||
if writer:
|
||||
writer.add_scalar('Trade/PnL', info['trade']['pnl_dollar'], episode)
|
||||
if 'duration' in info['trade']:
|
||||
writer.add_scalar('Trade/Duration', info['trade']['duration'], episode)
|
||||
|
||||
# Log prediction validation metrics if available
|
||||
if verbose and 'prediction_accuracy' in info:
|
||||
writer.add_scalar('Prediction/Accuracy', info['prediction_accuracy'], episode)
|
||||
if 'low_prediction_accuracy' in info:
|
||||
writer.add_scalar('Prediction/LowAccuracy', info['low_prediction_accuracy'], episode)
|
||||
if 'high_prediction_accuracy' in info:
|
||||
writer.add_scalar('Prediction/HighAccuracy', info['high_prediction_accuracy'], episode)
|
||||
|
||||
# Learn from experiences if enough samples
|
||||
if len(replay_buffer) > BATCH_SIZE:
|
||||
learn(dqn_model, target_dqn, replay_buffer, optimizer, scaler, DEVICE)
|
||||
learn(dqn_model, target_model, replay_buffer, optimizer, scaler, DEVICE)
|
||||
|
||||
if done:
|
||||
break
|
||||
episode_steps += 1
|
||||
|
||||
# Log verbose information about the current state if requested
|
||||
if verbose and episode_steps % 10 == 0:
|
||||
print(f"Episode {episode+1}, Step {episode_steps}, Action: {['HOLD', 'BUY', 'SELL'][action]}, Reward: {reward:.4f}")
|
||||
if hasattr(exchange, 'position') and exchange.position != 'FLAT':
|
||||
print(f" Position: {exchange.position}, Entry Price: {exchange.entry_price:.2f}, Current PnL: {exchange.calculate_pnl():.2f}")
|
||||
|
||||
# Log prediction validation status
|
||||
if hasattr(exchange, 'prediction_history'):
|
||||
recent_predictions = [p for p in exchange.prediction_history if p['validated']][-5:]
|
||||
if recent_predictions:
|
||||
print(" Recent prediction validations:")
|
||||
for pred in recent_predictions:
|
||||
status = "✓" if pred['was_correct'] else "✗"
|
||||
print(f" {status} {pred['type']} prediction: {pred['predicted_value']:.2f} vs {pred['actual_value']:.2f}")
|
||||
|
||||
# Update target network
|
||||
if episode % TARGET_UPDATE == 0:
|
||||
target_dqn.load_state_dict(dqn_model.state_dict())
|
||||
target_model.load_state_dict(dqn_model.state_dict())
|
||||
|
||||
# Calculate episode metrics
|
||||
avg_reward = total_reward / (step + 1)
|
||||
total_pnl = sum(trade['pnl'] for trade in trades) if trades else 0
|
||||
avg_reward = episode_reward / max(1, episode_steps)
|
||||
total_pnl = sum(trade.get('pnl_dollar', 0) for trade in trades) if trades else 0
|
||||
win_rate = (wins / (wins + losses) * 100) if (wins + losses) > 0 else 0
|
||||
|
||||
# Decay epsilon
|
||||
epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
|
||||
|
||||
# Track metrics
|
||||
# Store metrics
|
||||
rewards.append(avg_reward)
|
||||
profits.append(total_pnl)
|
||||
win_rates.append(win_rate)
|
||||
|
||||
# Decay epsilon
|
||||
epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
|
||||
|
||||
# Log episode metrics
|
||||
print(f"Episode {episode+1}: Reward={avg_reward:.4f}, PnL=${total_pnl:.2f}, Win Rate={win_rate:.1f}%, Epsilon={epsilon:.4f}")
|
||||
|
||||
# Log to TensorBoard
|
||||
writer.add_scalar('Training/Reward', avg_reward, episode)
|
||||
writer.add_scalar('Training/Profit', total_pnl, episode)
|
||||
writer.add_scalar('Training/WinRate', win_rate, episode)
|
||||
writer.add_scalar('Training/Epsilon', epsilon, episode)
|
||||
if writer:
|
||||
writer.add_scalar('Metrics/Reward', avg_reward, episode)
|
||||
writer.add_scalar('Metrics/PnL', total_pnl, episode)
|
||||
writer.add_scalar('Metrics/WinRate', win_rate, episode)
|
||||
writer.add_scalar('Metrics/Epsilon', epsilon, episode)
|
||||
|
||||
# Print episode summary
|
||||
print(f"Episode {episode+1} - Avg Reward: {avg_reward:.2f}, PnL: ${total_pnl:.2f}, Win Rate: {win_rate:.1f}%")
|
||||
# Log prediction validation metrics if available
|
||||
if hasattr(exchange, 'prediction_history'):
|
||||
validated_predictions = [p for p in exchange.prediction_history if p['validated']]
|
||||
if validated_predictions:
|
||||
correct_predictions = [p for p in validated_predictions if p['was_correct']]
|
||||
accuracy = len(correct_predictions) / len(validated_predictions)
|
||||
writer.add_scalar('Prediction/OverallAccuracy', accuracy, episode)
|
||||
|
||||
# Save models and plot results
|
||||
if episode % SAVE_INTERVAL == 0 or episode == num_episodes - 1:
|
||||
best_reward, best_pnl, best_winrate = save_models(
|
||||
price_model, dqn_model, optimizer, episode,
|
||||
rewards, profits, win_rates,
|
||||
best_reward, best_pnl, best_winrate
|
||||
)
|
||||
plot_training_results(rewards, profits, win_rates, episode)
|
||||
# Log separate accuracies for low and high predictions
|
||||
low_predictions = [p for p in validated_predictions if p['type'] == 'low']
|
||||
high_predictions = [p for p in validated_predictions if p['type'] == 'high']
|
||||
|
||||
# Close TensorBoard writer
|
||||
writer.close()
|
||||
if low_predictions:
|
||||
correct_lows = [p for p in low_predictions if p['was_correct']]
|
||||
low_accuracy = len(correct_lows) / len(low_predictions)
|
||||
writer.add_scalar('Prediction/LowAccuracy', low_accuracy, episode)
|
||||
|
||||
# Final save and plot
|
||||
best_reward, best_pnl, best_winrate = save_models(
|
||||
price_model, dqn_model, optimizer, num_episodes - 1,
|
||||
rewards, profits, win_rates,
|
||||
best_reward, best_pnl, best_winrate
|
||||
)
|
||||
plot_training_results(rewards, profits, win_rates, num_episodes - 1)
|
||||
if high_predictions:
|
||||
correct_highs = [p for p in high_predictions if p['was_correct']]
|
||||
high_accuracy = len(correct_highs) / len(high_predictions)
|
||||
writer.add_scalar('Prediction/HighAccuracy', high_accuracy, episode)
|
||||
|
||||
print("Training complete!")
|
||||
return price_model, dqn_model
|
||||
if verbose:
|
||||
print(f"Prediction accuracy: {accuracy:.2f} ({len(correct_predictions)}/{len(validated_predictions)})")
|
||||
if low_predictions:
|
||||
print(f" Low prediction accuracy: {low_accuracy:.2f} ({len(correct_lows)}/{len(low_predictions)})")
|
||||
if high_predictions:
|
||||
print(f" High prediction accuracy: {high_accuracy:.2f} ({len(correct_highs)}/{len(high_predictions)})")
|
||||
|
||||
# Save best models
|
||||
if avg_reward > best_reward:
|
||||
best_reward = avg_reward
|
||||
torch.save({
|
||||
'price_model_state_dict': price_model.state_dict(),
|
||||
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'episode': episode,
|
||||
'reward': best_reward,
|
||||
'pnl': total_pnl,
|
||||
'win_rate': win_rate
|
||||
}, 'models/enhanced_trading_agent_best_reward.pt')
|
||||
print(f"Saved best reward model: {best_reward:.4f}")
|
||||
|
||||
if total_pnl > best_pnl:
|
||||
best_pnl = total_pnl
|
||||
torch.save({
|
||||
'price_model_state_dict': price_model.state_dict(),
|
||||
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'episode': episode,
|
||||
'reward': avg_reward,
|
||||
'pnl': best_pnl,
|
||||
'win_rate': win_rate
|
||||
}, 'models/enhanced_trading_agent_best_pnl.pt')
|
||||
print(f"Saved best PnL model: ${best_pnl:.2f}")
|
||||
|
||||
if win_rate > best_winrate:
|
||||
best_winrate = win_rate
|
||||
torch.save({
|
||||
'price_model_state_dict': price_model.state_dict(),
|
||||
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'episode': episode,
|
||||
'reward': avg_reward,
|
||||
'pnl': total_pnl,
|
||||
'win_rate': best_winrate
|
||||
}, 'models/enhanced_trading_agent_best_winrate.pt')
|
||||
print(f"Saved best win rate model: {best_winrate:.1f}%")
|
||||
|
||||
# Save latest model for continuous training
|
||||
torch.save({
|
||||
'price_model_state_dict': price_model.state_dict(),
|
||||
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'episode': episode,
|
||||
'rewards': rewards,
|
||||
'profits': profits,
|
||||
'win_rates': win_rates,
|
||||
'best_reward': best_reward,
|
||||
'best_pnl': best_pnl,
|
||||
'best_winrate': best_winrate
|
||||
}, 'models/enhanced_trading_agent_latest.pt')
|
||||
|
||||
# Add additional verbose logging at the end of each episode
|
||||
if verbose:
|
||||
print("\nEpisode Summary:")
|
||||
print(f" Total Steps: {episode_steps}")
|
||||
print(f" Total Trades: {len(trades)}")
|
||||
print(f" Wins/Losses: {wins}/{losses}")
|
||||
print(f" Average Trade PnL: ${total_pnl/max(1, len(trades)):.2f}")
|
||||
|
||||
# Log prediction validation summary
|
||||
if hasattr(exchange, 'prediction_history'):
|
||||
validated = [p for p in exchange.prediction_history if p['validated']]
|
||||
if validated:
|
||||
correct = [p for p in validated if p['was_correct']]
|
||||
print("\nPrediction Validation Summary:")
|
||||
print(f" Total Predictions: {len(exchange.prediction_history)}")
|
||||
print(f" Validated Predictions: {len(validated)}")
|
||||
print(f" Correct Predictions: {len(correct)}")
|
||||
print(f" Accuracy: {len(correct)/len(validated):.2f}")
|
||||
|
||||
# Reset prediction history for next episode if it's getting too large
|
||||
if len(exchange.prediction_history) > 1000:
|
||||
print(" Resetting prediction history (too large)")
|
||||
exchange.prediction_history = exchange.prediction_history[-100:]
|
||||
|
||||
# Return final metrics
|
||||
return rewards, profits, win_rates
|
||||
|
||||
def learn(dqn, target_dqn, replay_buffer, optimizer, scaler, device):
    """Update the DQN model using experiences from the replay buffer"""
    # Sample from replay buffer
    states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(BATCH_SIZE)
    """
    Update the DQN model using experiences from the replay buffer

    # Move to device
    states = states.to(device)
    actions = actions.to(device)
    rewards = rewards.to(device)
    next_states = next_states.to(device)
    dones = dones.to(device)
    weights = weights.to(device)
    Args:
        dqn: The DQN model to update
        target_dqn: The target DQN model for stable Q-value estimates
        replay_buffer: Replay buffer containing experiences
        optimizer: Optimizer for updating the DQN model
        scaler: Gradient scaler for mixed precision training
        device: Device to train on (CPU or GPU)
    """
    if len(replay_buffer) < BATCH_SIZE:
        return

    # Sample batch from replay buffer
    states, actions, rewards, next_states, dones = replay_buffer.sample(BATCH_SIZE)

    # Convert to tensors
    states = torch.FloatTensor(states).to(device)
    actions = torch.LongTensor(actions).to(device)
    rewards = torch.FloatTensor(rewards).to(device)
    next_states = torch.FloatTensor(next_states).to(device)
    dones = torch.FloatTensor(dones).to(device)

    # Get current Q values
    if device.type == 'cuda':
        with autocast(device_type='cuda', enabled=True):
            current_q_values, _, _ = dqn(states)
            current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
    with autocast(device_type='cuda' if device.type == 'cuda' else 'cpu'):
        q_values = dqn(states)
        if isinstance(q_values, tuple):
            q_values = q_values[0]  # Extract the tensor from the tuple
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Compute target Q values
        # Get next Q values from target network
        with torch.no_grad():
            next_q_values, _, _ = target_dqn(next_states)
            max_next_q_values = next_q_values.max(1)[0]
            target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values
            next_q_values = target_dqn(next_states)
            if isinstance(next_q_values, tuple):
                next_q_values = next_q_values[0]  # Extract the tensor from the tuple
            next_q_values = next_q_values.max(1)[0]
            target_q_values = rewards + GAMMA * next_q_values * (1 - dones)

        # Compute loss with importance sampling weights
        td_errors = target_q_values - current_q_values
        loss = (weights * td_errors.pow(2)).mean()
    else:
        # CPU version without autocast
        current_q_values, _, _ = dqn(states)
        current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
        # Calculate loss
        loss = F.smooth_l1_loss(q_values, target_q_values)

        # Compute target Q values
        with torch.no_grad():
            next_q_values, _, _ = target_dqn(next_states)
            max_next_q_values = next_q_values.max(1)[0]
            target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values

        # Compute loss with importance sampling weights
        td_errors = target_q_values - current_q_values
        loss = (weights * td_errors.pow(2)).mean()

    # Update priorities in replay buffer
    replay_buffer.update_priorities(indices, td_errors.abs().detach().cpu().numpy())

    # Optimize the model with mixed precision
    # Optimize the model
    optimizer.zero_grad()

    if device.type == 'cuda':
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
    else:
        # CPU version without scaler
        loss.backward()
        torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0)
        optimizer.step()

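A quick numeric check of the TD target used in both code paths above (a hedged illustration; it assumes GAMMA = 0.99, which is defined elsewhere in this module):

# target = reward + GAMMA * max_next_q * (1 - done)
reward, max_next_q = 1.0, 2.0
assert abs((reward + 0.99 * max_next_q * (1 - 0.0)) - 2.98) < 1e-9   # non-terminal step
assert (reward + 0.99 * max_next_q * (1 - 1.0)) == 1.0               # terminal step keeps only the reward
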
def initialize_state(exchange, timeframes):
    """Initialize the state with data from multiple timeframes"""
    """
    Initialize the state for the trading agent using multi-timeframe data

    Args:
        exchange: Exchange object to fetch data from
        timeframes: List of timeframes to use

    Returns:
        state: Initial state vector
    """
    # Initialize empty state
    state = []

    # Fetch data for each timeframe
    timeframe_data = {}
    for tf in timeframes:
        candles = exchange.fetch_ohlcv(timeframe=tf, limit=30)
        if not candles or len(candles) < 10:
            print(f"Not enough data for timeframe {tf}, using zeros")
            # Use zeros if not enough data
            state.extend([0] * 20)  # 20 features per timeframe
            continue

        timeframe_data[tf] = candles

    # Extract features from each timeframe
    state = []

    for tf in timeframes:
        candles = timeframe_data[tf]

        # Price features
        # Extract features for this timeframe
        prices = [candle[4] for candle in candles[-10:]]  # Last 10 close prices
        price_changes = [prices[i]/prices[i-1] - 1 for i in range(1, len(prices))]

        # Volume features
        volumes = [candle[5] for candle in candles[-10:]]  # Last 10 volumes
        volume_changes = [volumes[i]/volumes[i-1] - 1 for i in range(1, len(volumes))]

        # Technical indicators
        # Simple Moving Averages
        sma_5 = sum(prices[-5:]) / 5
        sma_10 = sum(prices) / 10

        # Relative Strength Index (simplified)
        gains = [max(0, price_changes[i]) for i in range(len(price_changes))]
        losses = [max(0, -price_changes[i]) for i in range(len(price_changes))]
        avg_gain = sum(gains) / len(gains)
        avg_loss = sum(losses) / len(losses)
        rs = avg_gain / (avg_loss + 1e-10)  # Avoid division by zero
        # RSI (simplified)
        gains = [max(0, pc) for pc in price_changes]
        losses = [max(0, -pc) for pc in price_changes]
        avg_gain = sum(gains) / len(gains) if gains else 0
        avg_loss = sum(losses) / len(losses) if losses else 1e-10
        rs = avg_gain / avg_loss
        rsi = 100 - (100 / (1 + rs))

        # Add features to state
@@ -506,85 +675,15 @@ def initialize_state(exchange, timeframes):
        state.append(sma_10 / prices[-1] - 1)  # 1 feature
        state.append(rsi / 100)  # 1 feature

        # Add market regime features
        # This is a placeholder - in a real implementation, you would use the market_regime_classifier
        # from the DQN model to predict the current market regime
        state.extend([0, 0, 0])  # 3 features for market regime (one-hot encoded)
        # Add market regime features (placeholder)
        state.extend([0, 0, 0])  # 3 features for market regime

    # Add additional features to reach the expected dimension of 100
    # Calculate more technical indicators
    for tf in timeframes:
        candles = timeframe_data[tf]
        prices = [candle[4] for candle in candles[-20:]]  # Last 20 close prices

        # Bollinger Bands
        window = 20
        if len(prices) >= window:
            sma_20 = sum(prices[-window:]) / window
            std_dev = (sum((price - sma_20) ** 2 for price in prices[-window:]) / window) ** 0.5
            upper_band = sma_20 + 2 * std_dev
            lower_band = sma_20 - 2 * std_dev

            # Add normalized Bollinger Band features
            state.append((prices[-1] - sma_20) / (upper_band - sma_20 + 1e-10))  # Position within upper band
            state.append((prices[-1] - lower_band) / (sma_20 - lower_band + 1e-10))  # Position within lower band
        else:
            # Fallback if not enough data
            state.extend([0, 0])

        # MACD (Moving Average Convergence Divergence)
        if len(prices) >= 26:
            ema_12 = sum(prices[-12:]) / 12  # Simplified EMA
            ema_26 = sum(prices[-26:]) / 26  # Simplified EMA
            macd = ema_12 - ema_26

            # Add normalized MACD
            state.append(macd / prices[-1])
        else:
            # Fallback if not enough data
            state.append(0)

    # Add price momentum features
    for tf in timeframes:
        candles = timeframe_data[tf]
        prices = [candle[4] for candle in candles[-30:]]

        # Calculate momentum over different periods
        if len(prices) >= 30:
            momentum_5 = prices[-1] / prices[-5] - 1
            momentum_10 = prices[-1] / prices[-10] - 1
            momentum_20 = prices[-1] / prices[-20] - 1
            momentum_30 = prices[-1] / prices[-30] - 1

            state.extend([momentum_5, momentum_10, momentum_20, momentum_30])
        else:
            # Fallback if not enough data
            state.extend([0, 0, 0, 0])

    # Add volume profile features
    for tf in timeframes:
        candles = timeframe_data[tf]
        volumes = [candle[5] for candle in candles[-10:]]

        # Volume profile
        avg_volume = sum(volumes) / len(volumes)
        volume_ratio = volumes[-1] / avg_volume

        # Volume trend
        volume_trend = sum(1 for i in range(1, len(volumes)) if volumes[i] > volumes[i-1]) / (len(volumes) - 1)

        state.extend([volume_ratio, volume_trend])

    # Pad with zeros if needed to reach exactly 100 dimensions
    while len(state) < 100:
        state.append(0)

    # Ensure state has exactly 100 dimensions
    if len(state) > 100:
    # Pad state to reach expected dimension of 100
    if len(state) < 100:
        state.extend([0] * (100 - len(state)))
    elif len(state) > 100:
        state = state[:100]

    assert len(state) == 100, f"State dimension mismatch: {len(state)} != 100"

    return state

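The RSI above averages every gain and loss in the lookback window; Wilder's textbook formulation smooths them recursively instead. A hedged sketch for comparison (not part of the commit):

def wilder_rsi(prices, period=14):
    # assumes len(prices) > period
    gains, losses = [], []
    for prev, cur in zip(prices, prices[1:]):
        change = cur - prev
        gains.append(max(0.0, change))
        losses.append(max(0.0, -change))
    avg_gain = sum(gains[:period]) / period
    avg_loss = sum(losses[:period]) / period
    for g, l in zip(gains[period:], losses[period:]):
        avg_gain = (avg_gain * (period - 1) + g) / period
        avg_loss = (avg_loss * (period - 1) + l) / period
    rs = avg_gain / (avg_loss + 1e-10)
    return 100 - 100 / (1 + rs)
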
def step_environment(exchange, state, action, price_model, timeframes, device):
@@ -744,6 +843,145 @@ def step_environment(exchange, state, action, price_model, timeframes, device):

    return next_state, reward, done, trade_info

def train_price_predictor(model, exchange, timeframes, device, num_epochs=5, batch_size=32):
    """
    Train the price prediction model using data from the exchange

    Args:
        model: The EnhancedPricePredictionModel
        exchange: Exchange object to fetch data from
        timeframes: List of timeframes to use
        device: Device to train on (CPU or GPU)
        num_epochs: Number of training epochs
        batch_size: Batch size for training
    """
    print(f"Training price prediction model for {num_epochs} epochs with batch size {batch_size}")

    # Create optimizer
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Set model to training mode
    model.train()

    # Fetch data for each timeframe
    timeframe_data = {}
    for tf in timeframes:
        # Fetch more data for training
        candles = exchange.fetch_ohlcv(timeframe=tf, limit=500)
        if not candles or len(candles) < 30:
            print(f"Not enough data for timeframe {tf}, skipping training")
            return model
        timeframe_data[tf] = candles

    # Prepare training data
    for epoch in range(num_epochs):
        total_loss = 0
        num_batches = 0

        # Create batches
        for i in range(0, len(timeframe_data[timeframes[0]]) - 30, batch_size):
            if i + 30 + 5 >= len(timeframe_data[timeframes[0]]):
                break

            # Prepare inputs for each timeframe
            inputs_list = []
            for tf in timeframes:
                if i + 30 >= len(timeframe_data[tf]):
                    continue

                # Extract price and volume data for input
                input_data = []
                for j in range(i, i + 30):
                    if j < len(timeframe_data[tf]):
                        candle = timeframe_data[tf][j]
                        # Use close price and volume
                        input_data.append([candle[4], candle[5]])

                # Convert to tensor and add batch dimension
                input_tensor = torch.tensor(input_data, dtype=torch.float32).unsqueeze(0).to(device)
                inputs_list.append(input_tensor)

            # Skip if we don't have data for all timeframes
            if len(inputs_list) != len(timeframes):
                continue

            # Prepare targets (next 5 candles)
            target_data = []
            for j in range(i + 30, i + 35):
                if j < len(timeframe_data[timeframes[0]]):
                    candle = timeframe_data[timeframes[0]][j]
                    # OHLCV values
                    target_data.append([candle[1], candle[2], candle[3], candle[4], candle[5]])

            # Convert targets to tensor
            price_targets = torch.tensor(target_data, dtype=torch.float32).to(device)

            # Create extrema targets (binary classification for high/low points)
            # Make sure it has the same shape as the model output (batch_size, 10)
            extrema_targets = torch.zeros(1, 10, dtype=torch.float32).to(device)  # 5 time steps, 2 classes each

            # Create volume targets
            volume_targets = torch.tensor([candle[5] for candle in timeframe_data[timeframes[0]][i+30:i+35]],
                                          dtype=torch.float32).to(device)

            # Zero gradients
            optimizer.zero_grad()

            try:
                # Forward pass
                price_pred, extrema_logits, volume_pred = model(inputs_list)

                # Ensure targets have the same shape as predictions
                if price_pred.shape != price_targets.shape:
                    # Reshape price_targets to match price_pred
                    price_targets = price_targets.view(price_pred.shape)

                if volume_pred.shape != volume_targets.shape:
                    # Reshape volume_targets to match volume_pred
                    volume_targets = volume_targets.view(volume_pred.shape)

                # Calculate losses
                price_loss = F.mse_loss(price_pred, price_targets)
                extrema_loss = F.binary_cross_entropy_with_logits(extrema_logits, extrema_targets)
                volume_loss = F.mse_loss(volume_pred, volume_targets)

                # Combined loss with weighting
                loss = price_loss + 0.5 * extrema_loss + 0.3 * volume_loss

                # Backward pass
                loss.backward()

                # Gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

                if num_batches % 10 == 0:
                    print(f"Epoch {epoch+1}/{num_epochs}, Batch {num_batches}, Loss: {loss.item():.6f}")
            except Exception as e:
                print(f"Error in batch: {e}")
                print(f"Shapes - price_pred: {price_pred.shape}, price_targets: {price_targets.shape}")
                print(f"Shapes - extrema_logits: {extrema_logits.shape}, extrema_targets: {extrema_targets.shape}")
                print(f"Shapes - volume_pred: {volume_pred.shape}, volume_targets: {volume_targets.shape}")
                continue

        if num_batches > 0:
            avg_loss = total_loss / num_batches
            print(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.6f}")
        else:
            print(f"Epoch {epoch+1}/{num_epochs}, No batches processed")

        # Learning rate scheduling
        if epoch > 0 and epoch % 2 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.9

    print("Price prediction model training completed")
    return model

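The manual decay above (multiplying the learning rate by 0.9 every other epoch) could also be written with torch's built-in scheduler; a hedged sketch of the rough equivalent, not what the commit does:

from torch.optim.lr_scheduler import StepLR

scheduler = StepLR(optimizer, step_size=2, gamma=0.9)
for epoch in range(num_epochs):
    # ... run the training batches for this epoch ...
    scheduler.step()   # multiplies the learning rate by 0.9 every 2 epochs
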
# Main function to run training
def main():
    from exchange_simulator import ExchangeSimulator
@@ -752,7 +990,7 @@ def main():
    exchange = ExchangeSimulator()

    # Train agent
    price_model, dqn_model = enhanced_train_agent(
    rewards, profits, win_rates = enhanced_train_agent(
        exchange=exchange,
        num_episodes=NUM_EPISODES,
        continuous=CONTINUOUS_MODE,

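With the changed return signature, a caller consumes the metric lists rather than the models (the model weights are written to models/ inside the training loop). A hedged usage sketch:

rewards, profits, win_rates = enhanced_train_agent(
    exchange=ExchangeSimulator(),
    num_episodes=10,
    continuous=False,
    start_episode=0,
    verbose=True,
)
print(f"Last episode: reward={rewards[-1]:.4f}, PnL=${profits[-1]:.2f}, win rate={win_rates[-1]:.1f}%")
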
@ -345,6 +345,439 @@ class ExchangeSimulator:
|
||||
}
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the exchange simulator to its initial state
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
# Reset timestamp
|
||||
self.current_timestamp = datetime.now()
|
||||
|
||||
# Regenerate data for each timeframe
|
||||
for tf in self.timeframes:
|
||||
self._generate_initial_data(tf)
|
||||
|
||||
# Reset any internal state
|
||||
self.position = 'flat'
|
||||
self.position_size = 0
|
||||
self.entry_price = 0
|
||||
self.stop_loss = 0
|
||||
self.take_profit = 0
|
||||
|
||||
# Reset prediction history if it exists
|
||||
if hasattr(self, 'prediction_history'):
|
||||
self.prediction_history = []
|
||||
|
||||
return self
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
Take a step in the environment by executing an action
|
||||
|
||||
Args:
|
||||
action: Action to take (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT)
|
||||
|
||||
Returns:
|
||||
next_state: Next state after taking action
|
||||
reward: Reward received
|
||||
done: Whether episode is done
|
||||
info: Additional information
|
||||
"""
|
||||
# Get current price
|
||||
current_price = self.data['1m'][-1][4]
|
||||
|
||||
# Initialize info dictionary
|
||||
info = {
|
||||
'price': current_price,
|
||||
'timestamp': self.data['1m'][-1][0],
|
||||
'trade': None
|
||||
}
|
||||
|
||||
# Process action
|
||||
if action == 0: # HOLD
|
||||
pass # No action needed
|
||||
|
||||
elif action == 1: # BUY/LONG
|
||||
if self.position == 'flat':
|
||||
# Open a new long position
|
||||
self.position = 'long'
|
||||
self.entry_price = current_price
|
||||
self.position_size = 100 # Simplified position sizing
|
||||
|
||||
# Set stop loss and take profit levels
|
||||
self.stop_loss = current_price * 0.99 # 1% stop loss
|
||||
self.take_profit = current_price * 1.02 # 2% take profit
|
||||
|
||||
# Record entry time
|
||||
self.entry_time = self.data['1m'][-1][0]
|
||||
|
||||
# Add to info
|
||||
info['trade'] = {
|
||||
'type': 'long',
|
||||
'entry': current_price,
|
||||
'entry_time': self.data['1m'][-1][0],
|
||||
'size': self.position_size,
|
||||
'stop_loss': self.stop_loss,
|
||||
'take_profit': self.take_profit
|
||||
}
|
||||
|
||||
elif self.position == 'short':
|
||||
# Close short position and open long
|
||||
pnl = self.entry_price - current_price
|
||||
pnl_percent = pnl / self.entry_price * 100
|
||||
pnl_dollar = pnl_percent / 100 * self.position_size
|
||||
|
||||
# Add to info
|
||||
info['trade'] = {
|
||||
'type': 'short',
|
||||
'entry': self.entry_price,
|
||||
'exit': current_price,
|
||||
'entry_time': self.entry_time,
|
||||
'exit_time': self.data['1m'][-1][0],
|
||||
'pnl_percent': pnl_percent,
|
||||
'pnl_dollar': pnl_dollar,
|
||||
'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60) # Duration in minutes
|
||||
}
|
||||
|
||||
# Open new long position
|
||||
self.position = 'long'
|
||||
self.entry_price = current_price
|
||||
self.position_size = 100 # Simplified position sizing
|
||||
|
||||
# Set stop loss and take profit levels
|
||||
self.stop_loss = current_price * 0.99 # 1% stop loss
|
||||
self.take_profit = current_price * 1.02 # 2% take profit
|
||||
|
||||
# Record entry time
|
||||
self.entry_time = self.data['1m'][-1][0]
|
||||
|
||||
elif action == 2: # SELL/SHORT
|
||||
if self.position == 'flat':
|
||||
# Open a new short position
|
||||
self.position = 'short'
|
||||
self.entry_price = current_price
|
||||
self.position_size = 100 # Simplified position sizing
|
||||
|
||||
# Set stop loss and take profit levels
|
||||
self.stop_loss = current_price * 1.01 # 1% stop loss
|
||||
self.take_profit = current_price * 0.98 # 2% take profit
|
||||
|
||||
# Record entry time
|
||||
self.entry_time = self.data['1m'][-1][0]
|
||||
|
||||
# Add to info
|
||||
info['trade'] = {
|
||||
'type': 'short',
|
||||
'entry': current_price,
|
||||
'entry_time': self.data['1m'][-1][0],
|
||||
'size': self.position_size,
|
||||
'stop_loss': self.stop_loss,
|
||||
'take_profit': self.take_profit
|
||||
}
|
||||
|
||||
elif self.position == 'long':
|
||||
# Close long position and open short
|
||||
pnl = current_price - self.entry_price
|
||||
pnl_percent = pnl / self.entry_price * 100
|
||||
pnl_dollar = pnl_percent / 100 * self.position_size
|
||||
|
||||
# Add to info
|
||||
info['trade'] = {
|
||||
'type': 'long',
|
||||
'entry': self.entry_price,
|
||||
'exit': current_price,
|
||||
'entry_time': self.entry_time,
|
||||
'exit_time': self.data['1m'][-1][0],
|
||||
'pnl_percent': pnl_percent,
|
||||
'pnl_dollar': pnl_dollar,
|
||||
'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60) # Duration in minutes
|
||||
}
|
||||
|
||||
# Open new short position
|
||||
self.position = 'short'
|
||||
self.entry_price = current_price
|
||||
self.position_size = 100 # Simplified position sizing
|
||||
|
||||
# Set stop loss and take profit levels
|
||||
self.stop_loss = current_price * 1.01 # 1% stop loss
|
||||
self.take_profit = current_price * 0.98 # 2% take profit
|
||||
|
||||
# Record entry time
|
||||
self.entry_time = self.data['1m'][-1][0]
|
||||
|
||||
# Generate next candle
|
||||
self._add_new_candle('1m')
|
||||
|
||||
# Check if stop loss or take profit has been hit
|
||||
self._check_sl_tp(info)
|
||||
|
||||
# Validate predictions if available
|
||||
if hasattr(self, 'prediction_history') and len(self.prediction_history) > 0:
|
||||
self.validate_predictions(self.data['1m'][-1])
|
||||
|
||||
# Prepare next state (simplified)
|
||||
next_state = self._get_state()
|
||||
|
||||
# Calculate reward (simplified)
|
||||
reward = 0
|
||||
if info['trade'] is not None and 'pnl_dollar' in info['trade']:
|
||||
reward = info['trade']['pnl_dollar']
|
||||
|
||||
# Check if done (simplified)
|
||||
done = False
|
||||
|
||||
return next_state, reward, done, info
|
||||
|
||||
def _get_state(self):
|
||||
"""
|
||||
Get the current state of the environment
|
||||
|
||||
Returns:
|
||||
List representing the current state
|
||||
"""
|
||||
# Simplified state representation
|
||||
state = []
|
||||
|
||||
# Add price features
|
||||
for tf in ['1m', '5m', '15m']:
|
||||
if tf in self.data:
|
||||
# Get last 10 candles
|
||||
candles = self.data[tf][-10:]
|
||||
|
||||
# Extract close prices
|
||||
prices = [candle[4] for candle in candles]
|
||||
|
||||
# Calculate price changes
|
||||
price_changes = [prices[i]/prices[i-1] - 1 for i in range(1, len(prices))]
|
||||
|
||||
# Add to state
|
||||
state.extend(price_changes)
|
||||
|
||||
# Add current price relative to SMA
|
||||
sma_5 = sum(prices[-5:]) / 5
|
||||
sma_10 = sum(prices) / 10
|
||||
state.append(prices[-1] / sma_5 - 1)
|
||||
state.append(prices[-1] / sma_10 - 1)
|
||||
|
||||
# Pad state to 100 dimensions
|
||||
while len(state) < 100:
|
||||
state.append(0)
|
||||
|
||||
# Ensure state has exactly 100 dimensions
|
||||
if len(state) > 100:
|
||||
state = state[:100]
|
||||
|
||||
return state
|
||||
|
||||
def _check_sl_tp(self, info):
|
||||
"""
|
||||
Check if stop loss or take profit has been hit
|
||||
|
||||
Args:
|
||||
info: Info dictionary to update
|
||||
"""
|
||||
if self.position == 'flat':
|
||||
return
|
||||
|
||||
# Get current price
|
||||
current_price = self.data['1m'][-1][4]
|
||||
|
||||
if self.position == 'long':
|
||||
# Check stop loss
|
||||
if current_price <= self.stop_loss:
|
||||
# Stop loss hit
|
||||
pnl_percent = (self.stop_loss - self.entry_price) / self.entry_price * 100
|
||||
pnl_dollar = pnl_percent / 100 * self.position_size
|
||||
|
||||
# Add to info
|
||||
info['trade'] = {
|
||||
'type': 'long',
|
||||
'entry': self.entry_price,
|
||||
'exit': self.stop_loss,
|
||||
'entry_time': self.entry_time,
|
||||
'exit_time': self.data['1m'][-1][0],
|
||||
'pnl_percent': pnl_percent,
|
||||
'pnl_dollar': pnl_dollar,
|
||||
'reason': 'stop_loss',
|
||||
'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60) # Duration in minutes
|
||||
}
|
||||
|
||||
# Reset position
|
||||
self.position = 'flat'
|
||||
self.entry_price = 0
|
||||
self.position_size = 0
|
||||
self.stop_loss = 0
|
||||
self.take_profit = 0
|
||||
|
||||
# Check take profit
|
||||
elif current_price >= self.take_profit:
|
||||
# Take profit hit
|
||||
pnl_percent = (self.take_profit - self.entry_price) / self.entry_price * 100
|
||||
pnl_dollar = pnl_percent / 100 * self.position_size
|
||||
|
||||
# Add to info
|
||||
info['trade'] = {
|
||||
'type': 'long',
|
||||
'entry': self.entry_price,
|
||||
'exit': self.take_profit,
|
||||
'entry_time': self.entry_time,
|
||||
'exit_time': self.data['1m'][-1][0],
|
||||
'pnl_percent': pnl_percent,
|
||||
'pnl_dollar': pnl_dollar,
|
||||
'reason': 'take_profit',
|
||||
'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60) # Duration in minutes
|
||||
}
|
||||
|
||||
# Reset position
|
||||
self.position = 'flat'
|
||||
self.entry_price = 0
|
||||
self.position_size = 0
|
||||
self.stop_loss = 0
|
||||
self.take_profit = 0
|
||||
|
||||
elif self.position == 'short':
|
||||
# Check stop loss
|
||||
if current_price >= self.stop_loss:
|
||||
# Stop loss hit
|
||||
pnl_percent = (self.entry_price - self.stop_loss) / self.entry_price * 100
|
||||
pnl_dollar = pnl_percent / 100 * self.position_size
|
||||
|
||||
# Add to info
|
||||
info['trade'] = {
|
||||
'type': 'short',
|
||||
'entry': self.entry_price,
|
||||
'exit': self.stop_loss,
|
||||
'entry_time': self.entry_time,
|
||||
'exit_time': self.data['1m'][-1][0],
|
||||
'pnl_percent': pnl_percent,
|
||||
'pnl_dollar': pnl_dollar,
|
||||
'reason': 'stop_loss',
|
||||
'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60) # Duration in minutes
|
||||
}
|
||||
|
||||
# Reset position
|
||||
self.position = 'flat'
|
||||
self.entry_price = 0
|
||||
self.position_size = 0
|
||||
self.stop_loss = 0
|
||||
self.take_profit = 0
|
||||
|
||||
# Check take profit
|
||||
elif current_price <= self.take_profit:
|
||||
# Take profit hit
|
||||
pnl_percent = (self.entry_price - self.take_profit) / self.entry_price * 100
|
||||
pnl_dollar = pnl_percent / 100 * self.position_size
|
||||
|
||||
# Add to info
|
||||
info['trade'] = {
|
||||
'type': 'short',
|
||||
'entry': self.entry_price,
|
||||
'exit': self.take_profit,
|
||||
'entry_time': self.entry_time,
|
||||
'exit_time': self.data['1m'][-1][0],
|
||||
'pnl_percent': pnl_percent,
|
||||
'pnl_dollar': pnl_dollar,
|
||||
'reason': 'take_profit',
|
||||
'duration': (self.data['1m'][-1][0] - self.entry_time) / (1000 * 60) # Duration in minutes
|
||||
}
|
||||
|
||||
# Reset position
|
||||
self.position = 'flat'
|
||||
self.entry_price = 0
|
||||
self.position_size = 0
|
||||
self.stop_loss = 0
|
||||
self.take_profit = 0
|
||||
|
||||
def validate_predictions(self, new_candle):
|
||||
"""
|
||||
Validate previous extrema predictions against new candle data
|
||||
|
||||
Args:
|
||||
new_candle: New candle data to validate against
|
||||
"""
|
||||
if not hasattr(self, 'prediction_history') or not self.prediction_history:
|
||||
return
|
||||
|
||||
# Extract candle data
|
||||
timestamp = new_candle[0]
|
||||
high_price = new_candle[2]
|
||||
low_price = new_candle[3]
|
||||
|
||||
# Track validation metrics
|
||||
validated_count = 0
|
||||
correct_count = 0
|
||||
|
||||
# Check each prediction that hasn't been validated yet
|
||||
for pred in self.prediction_history:
|
||||
if pred.get('validated', False):
|
||||
continue
|
||||
|
||||
# Check if this prediction's time has come (or passed)
|
||||
if 'predicted_timestamp' in pred and timestamp >= pred['predicted_timestamp']:
|
||||
pred['validated'] = True
|
||||
validated_count += 1
|
||||
|
||||
# Check if prediction was correct
|
||||
if pred['type'] == 'low':
|
||||
# A low prediction is correct if price went within 0.5% of predicted low
|
||||
price_diff_percent = abs(low_price - pred['price']) / pred['price'] * 100
|
||||
pred['actual_price'] = low_price
|
||||
pred['price_diff_percent'] = price_diff_percent
|
||||
|
||||
# Consider correct if within 0.5% or price went lower than predicted
|
||||
was_correct = price_diff_percent < 0.5 or low_price <= pred['price']
|
||||
pred['was_correct'] = was_correct
|
||||
|
||||
if was_correct:
|
||||
correct_count += 1
|
||||
|
||||
elif pred['type'] == 'high':
|
||||
# A high prediction is correct if price went within 0.5% of predicted high
|
||||
price_diff_percent = abs(high_price - pred['price']) / pred['price'] * 100
|
||||
pred['actual_price'] = high_price
|
||||
pred['price_diff_percent'] = price_diff_percent
|
||||
|
||||
# Consider correct if within 0.5% or price went higher than predicted
|
||||
was_correct = price_diff_percent < 0.5 or high_price >= pred['price']
|
||||
pred['was_correct'] = was_correct
|
||||
|
||||
if was_correct:
|
||||
correct_count += 1
|
||||
|
||||
# Return validation metrics
|
||||
if validated_count > 0:
|
||||
return {
|
||||
'validated_count': validated_count,
|
||||
'correct_count': correct_count,
|
||||
'accuracy': correct_count / validated_count
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def calculate_pnl(self):
|
||||
"""
|
||||
Calculate the current profit/loss of the open position
|
||||
|
||||
Returns:
|
||||
float: Current PnL in dollars, 0 if no position is open
|
||||
"""
|
||||
if self.position == 'flat':
|
||||
return 0.0
|
||||
|
||||
current_price = self.data['1m'][-1][4]
|
||||
|
||||
if self.position == 'long':
|
||||
pnl_percent = (current_price - self.entry_price) / self.entry_price * 100
|
||||
elif self.position == 'short':
|
||||
pnl_percent = (self.entry_price - current_price) / self.entry_price * 100
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
pnl_dollar = pnl_percent / 100 * self.position_size
|
||||
return pnl_dollar
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Create exchange simulator
|
||||
|
3750  crypto/gogo2/main.py — file diff suppressed because it is too large

24  crypto/gogo2/main_function.py (new file)
@@ -0,0 +1,24 @@
import asyncio
from exchange_simulator import ExchangeSimulator
import logging

# Set up logging
logger = logging.getLogger(__name__)

async def main():
    """
    Main function to run the training process.
    """
    # Initialize exchange simulator
    exchange = ExchangeSimulator()

    # Train agent
    print("Starting training process...")
    # Add your training code here
    print("Training complete!")

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")
8195  crypto/gogo2/main_multiu_broken.py (new file) — file diff suppressed because it is too large

30  crypto/gogo2/main_multiu_fixed.py (new file)
@@ -0,0 +1,30 @@
import asyncio
import logging
from exchange_simulator import ExchangeSimulator

# Set up logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

async def main():
    """
    Main function to run the training process.
    """
    # Initialize exchange simulator
    exchange = ExchangeSimulator()

    # Train agent
    print("Starting training process...")
    # Add your training code here
    print("Training complete!")

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")
    except Exception as e:
        logger.error(f"Error running main: {e}")
        import traceback
        logger.error(traceback.format_exc())

@@ -1,9 +1,15 @@
import argparse
import os
import sys
import asyncio
import torch
from enhanced_training import enhanced_train_agent
from exchange_simulator import ExchangeSimulator

# Fix for Windows asyncio
if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Enhanced Trading Bot Training')
@@ -26,6 +32,9 @@ def main():
    parser.add_argument('--refresh-data', action='store_true',
                        help='Refresh data before training')

    parser.add_argument('--verbose', action='store_true',
                        help='Enable verbose logging')

    args = parser.parse_args()

    # Set device
@@ -51,7 +60,8 @@ def main():
            exchange=exchange,
            num_episodes=args.episodes,
            continuous=False,
            start_episode=0
            start_episode=0,
            verbose=args.verbose
        )

    elif args.mode == 'continuous':
@@ -61,7 +71,8 @@ def main():
            exchange=exchange,
            num_episodes=args.episodes,
            continuous=True,
            start_episode=args.start_episode
            start_episode=args.start_episode,
            verbose=args.verbose
        )

    elif args.mode == 'evaluate':

18  crypto/gogo2/run_main.py (new file)
@@ -0,0 +1,18 @@
import asyncio
import logging
from enhanced_training import main

# Set up logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")
    except Exception as e:
        logger.error(f"Error running main: {e}")
        import traceback
        logger.error(traceback.format_exc())

56  crypto/gogo2/visualize_logs.py (new file)
@@ -0,0 +1,56 @@
import os
import argparse
import subprocess
import webbrowser
import time
from pathlib import Path

def main():
    parser = argparse.ArgumentParser(description='Visualize TensorBoard logs')
    parser.add_argument('--logdir', type=str, default='./logs', help='Directory containing TensorBoard logs')
    parser.add_argument('--port', type=int, default=6006, help='Port for TensorBoard server')
    args = parser.parse_args()

    log_dir = Path(args.logdir)

    if not log_dir.exists():
        print(f"Log directory {log_dir} does not exist. Creating it...")
        log_dir.mkdir(parents=True, exist_ok=True)

    # Check if TensorBoard is installed
    try:
        subprocess.run(['tensorboard', '--version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("TensorBoard not found. Installing...")
        subprocess.run(['pip', 'install', 'tensorboard'], check=True)

    # Start TensorBoard server
    print(f"Starting TensorBoard server on port {args.port}...")
    tensorboard_process = subprocess.Popen(
        ['tensorboard', '--logdir', str(log_dir), '--port', str(args.port)],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )

    # Wait for TensorBoard to start
    time.sleep(3)

    # Open browser
    url = f"http://localhost:{args.port}"
    print(f"Opening TensorBoard in browser: {url}")
    webbrowser.open(url)

    print("TensorBoard is running. Press Ctrl+C to stop.")

    try:
        # Keep the script running until interrupted
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("Stopping TensorBoard server...")
        tensorboard_process.terminate()
        tensorboard_process.wait()
        print("TensorBoard server stopped.")

if __name__ == "__main__":
    main()