improvements and fixes

Dobromir Popov 2025-03-12 16:52:49 +02:00
parent 506458d55e
commit 2e901e18f2
8 changed files with 2906 additions and 650 deletions

View File

@ -1,3 +1,5 @@
pip install torch-tb-profiler
Ensure we use the GPU, if available, to train faster. During training we need an RL loop that looks at streaming data, plus retrospective backtesting/training on predictions. Since the start of training we're only losing; implement a robust penalty and analysis when closing a losing trade, and improve the reward function.
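As a concrete starting point for the penalty described above, here is a minimal sketch of an asymmetric reward for closed trades: losses are scaled up relative to wins, and the drawdown endured during the trade adds a further penalty. The penalty_factor and the drawdown weight are illustrative assumptions, not values used elsewhere in this commit.

# Sketch only: asymmetric reward for a closed trade. penalty_factor and the
# 0.5 drawdown weight are illustrative assumptions, not project values.
def trade_close_reward(pnl_pct: float, max_drawdown_pct: float,
                       penalty_factor: float = 2.0) -> float:
    if pnl_pct >= 0:
        return pnl_pct  # wins pay linearly
    # Losing trade: amplify the loss and penalize the drawdown endured.
    return penalty_factor * pnl_pct - 0.5 * abs(max_drawdown_pct)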

crypto/gogo2/data_cache.py Normal file
View File

@ -0,0 +1,319 @@
import os
import json
import time
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import logging
# Set up logging
logger = logging.getLogger('trading_bot')
class OHLCVCache:
"""
A simple cache for OHLCV data from exchanges.
Stores data in a structured format and provides backup when exchange is unavailable.
"""
def __init__(self, cache_dir="cache", max_age_hours=24):
"""
Initialize the OHLCV cache.
Args:
cache_dir: Directory to store cache files
max_age_hours: Maximum age of cached data in hours before considered stale
"""
self.cache_dir = cache_dir
self.max_age_seconds = max_age_hours * 3600
# Create cache directory if it doesn't exist
os.makedirs(cache_dir, exist_ok=True)
# In-memory cache for faster access
self.memory_cache = {}
def _get_cache_filename(self, symbol, timeframe):
"""Generate a standardized filename for the cache file"""
# Replace / with _ in symbol name (e.g., ETH/USDT -> ETH_USDT)
safe_symbol = symbol.replace('/', '_')
return os.path.join(self.cache_dir, f"{safe_symbol}_{timeframe}.json")
def save(self, data, symbol, timeframe):
"""
Save OHLCV data to cache.
Args:
data: List of dictionaries containing OHLCV data
symbol: Trading pair symbol (e.g., 'ETH/USDT')
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
"""
if not data:
logger.warning(f"No data to cache for {symbol} ({timeframe})")
return False
try:
# Convert data to a serializable format
serializable_data = []
for candle in data:
serializable_data.append({
'timestamp': candle['timestamp'],
'open': float(candle['open']),
'high': float(candle['high']),
'low': float(candle['low']),
'close': float(candle['close']),
'volume': float(candle['volume'])
})
# Create cache entry with metadata
cache_entry = {
'symbol': symbol,
'timeframe': timeframe,
'last_updated': int(time.time()),
'data': serializable_data
}
# Save to file
filename = self._get_cache_filename(symbol, timeframe)
with open(filename, 'w') as f:
json.dump(cache_entry, f)
# Update in-memory cache
cache_key = f"{symbol}_{timeframe}"
self.memory_cache[cache_key] = cache_entry
logger.info(f"Cached {len(data)} candles for {symbol} ({timeframe})")
return True
except Exception as e:
logger.error(f"Error saving data to cache: {e}")
return False
def load(self, symbol, timeframe, max_age_override=None):
"""
Load OHLCV data from cache.
Args:
symbol: Trading pair symbol (e.g., 'ETH/USDT')
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
max_age_override: Override the default max age (in seconds)
Returns:
List of dictionaries containing OHLCV data, or None if cache is missing or stale
"""
cache_key = f"{symbol}_{timeframe}"
max_age = max_age_override if max_age_override is not None else self.max_age_seconds
try:
# Check in-memory cache first
if cache_key in self.memory_cache:
cache_entry = self.memory_cache[cache_key]
# Check if cache is fresh
cache_age = int(time.time()) - cache_entry['last_updated']
if cache_age <= max_age:
logger.info(f"Using in-memory cache for {symbol} ({timeframe}), age: {cache_age//60} minutes")
return cache_entry['data']
# Check file cache
filename = self._get_cache_filename(symbol, timeframe)
if not os.path.exists(filename):
logger.info(f"No cache file found for {symbol} ({timeframe})")
return None
# Load cache file
with open(filename, 'r') as f:
cache_entry = json.load(f)
# Check if cache is fresh
cache_age = int(time.time()) - cache_entry['last_updated']
if cache_age > max_age:
logger.info(f"Cache for {symbol} ({timeframe}) is stale ({cache_age//60} minutes old)")
return None
# Update in-memory cache
self.memory_cache[cache_key] = cache_entry
logger.info(f"Loaded {len(cache_entry['data'])} candles from cache for {symbol} ({timeframe})")
return cache_entry['data']
except Exception as e:
logger.error(f"Error loading data from cache: {e}")
return None
def append(self, new_candle, symbol, timeframe):
"""
Append a new candle to the cached data.
Args:
new_candle: Dictionary containing a single OHLCV candle
symbol: Trading pair symbol (e.g., 'ETH/USDT')
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
Returns:
Boolean indicating success
"""
try:
# Load existing data
data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for append
if data is None:
data = []
# Check if the candle already exists (same timestamp)
for i, candle in enumerate(data):
if candle['timestamp'] == new_candle['timestamp']:
# Update existing candle
data[i] = {
'timestamp': new_candle['timestamp'],
'open': float(new_candle['open']),
'high': float(new_candle['high']),
'low': float(new_candle['low']),
'close': float(new_candle['close']),
'volume': float(new_candle['volume'])
}
# Save updated data
return self.save(data, symbol, timeframe)
# Append new candle
data.append({
'timestamp': new_candle['timestamp'],
'open': float(new_candle['open']),
'high': float(new_candle['high']),
'low': float(new_candle['low']),
'close': float(new_candle['close']),
'volume': float(new_candle['volume'])
})
# Save updated data
return self.save(data, symbol, timeframe)
except Exception as e:
logger.error(f"Error appending candle to cache: {e}")
return False
def get_latest_timestamp(self, symbol, timeframe):
"""
Get the timestamp of the most recent candle in the cache.
Args:
symbol: Trading pair symbol (e.g., 'ETH/USDT')
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
Returns:
Timestamp (milliseconds) of the most recent candle, or None if cache is empty
"""
data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for this check
if not data:
return None
# Find the most recent timestamp
latest_timestamp = max(candle['timestamp'] for candle in data)
return latest_timestamp
def clear(self, symbol=None, timeframe=None):
"""
Clear cache for a specific symbol and timeframe, or all cache if not specified.
Args:
symbol: Trading pair symbol (e.g., 'ETH/USDT'), or None to clear all symbols
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h'), or None to clear all timeframes
Returns:
Number of cache files deleted
"""
count = 0
try:
if symbol and timeframe:
# Clear specific cache
filename = self._get_cache_filename(symbol, timeframe)
if os.path.exists(filename):
os.remove(filename)
count = 1
# Clear from memory cache
cache_key = f"{symbol}_{timeframe}"
if cache_key in self.memory_cache:
del self.memory_cache[cache_key]
else:
# Clear all matching caches
for filename in os.listdir(self.cache_dir):
file_path = os.path.join(self.cache_dir, filename)
# Skip directories
if not os.path.isfile(file_path):
continue
# Check if file matches the filter
should_delete = True
if symbol:
safe_symbol = symbol.replace('/', '_')
if not filename.startswith(f"{safe_symbol}_"):
should_delete = False
if timeframe:
if not filename.endswith(f"_{timeframe}.json"):
should_delete = False
# Delete file if it matches the filter
if should_delete:
os.remove(file_path)
count += 1
# Clear memory cache
keys_to_delete = []
for cache_key in self.memory_cache:
should_delete = True
if symbol:
if not cache_key.startswith(f"{symbol}_"):
should_delete = False
if timeframe:
if not cache_key.endswith(f"_{timeframe}"):
should_delete = False
if should_delete:
keys_to_delete.append(cache_key)
for key in keys_to_delete:
del self.memory_cache[key]
logger.info(f"Cleared {count} cache files")
return count
except Exception as e:
logger.error(f"Error clearing cache: {e}")
return 0
def to_dataframe(self, symbol, timeframe):
"""
Convert cached OHLCV data to a pandas DataFrame.
Args:
symbol: Trading pair symbol (e.g., 'ETH/USDT')
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
Returns:
pandas DataFrame with OHLCV data, or None if cache is missing
"""
data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for conversion
if not data:
return None
# Convert to DataFrame
df = pd.DataFrame(data)
# Convert timestamp to datetime
df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
# Set datetime as index
df.set_index('datetime', inplace=True)
return df
# Create a global instance for easy access
ohlcv_cache = OHLCVCache()
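A quick usage sketch for the cache defined above (the symbol, timeframe, and candle values are illustrative; the calls are the methods this file defines):

# Illustrative usage of the OHLCVCache defined above.
from data_cache import ohlcv_cache

candles = [{'timestamp': 1741790000000, 'open': 1900.0, 'high': 1905.0,
            'low': 1898.0, 'close': 1903.0, 'volume': 12.5}]
ohlcv_cache.save(candles, 'ETH/USDT', '1m')      # writes cache/ETH_USDT_1m.json
recent = ohlcv_cache.load('ETH/USDT', '1m')      # None if cache is missing or stale
df = ohlcv_cache.to_dataframe('ETH/USDT', '1m')  # DataFrame indexed by datetime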

View File

@ -291,7 +291,7 @@ class EnhancedReplayBuffer:
    def update_priorities(self, indices, td_errors):
        for idx, td_error in zip(indices, td_errors):
            # Update priority based on TD error
-           priority = abs(td_error) + 1e-5  # Small constant to ensure non-zero priority
+           priority = float(abs(td_error) + 1e-5)  # Small constant to ensure non-zero priority
            self.priorities[idx] = priority
            self.max_priority = max(self.max_priority, priority)
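For context on why the priority is kept strictly positive: in standard prioritized experience replay, sampling probabilities and importance-sampling weights are derived from the stored priorities as in the sketch below. The alpha=0.6 and beta=0.4 defaults match the values passed to EnhancedReplayBuffer later in this commit, but the buffer's actual sampling internals are not shown in this diff, so this is a sketch of the usual formulation rather than the file's own code.

# Standard PER math: P(i) = p_i^alpha / sum_k p_k^alpha, w_i = (N * P(i))^(-beta).
import numpy as np

def sampling_weights(priorities: np.ndarray, alpha: float = 0.6, beta: float = 0.4):
    probs = priorities ** alpha
    probs /= probs.sum()                   # sampling probability per transition
    weights = (len(priorities) * probs) ** (-beta)
    return probs, weights / weights.max()  # normalize so the largest weight is 1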

View File

@ -0,0 +1,765 @@
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import GradScaler, autocast
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from tensorboardX import SummaryWriter
# Import our enhanced models
from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN, EnhancedReplayBuffer, train_price_predictor, prepare_multi_timeframe_data
# Constants
TIMEFRAMES = ['1m', '15m', '1h']
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
GAMMA = 0.99
REPLAY_BUFFER_SIZE = 100000
TARGET_UPDATE = 10
NUM_EPISODES = 200
MAX_STEPS_PER_EPISODE = 1000
EPSILON_START = 1.0
EPSILON_END = 0.01
EPSILON_DECAY = 0.995
SAVE_INTERVAL = 10
CONTINUOUS_MODE = True
CONTINUOUS_START_EPISODE = 0
def setup_tensorboard():
"""Set up TensorBoard for logging training metrics"""
current_time = datetime.now().strftime('%Y%m%d-%H%M%S')
log_dir = os.path.join('runs', current_time)
writer = SummaryWriter(log_dir)
return writer
def save_models(price_model, dqn_model, optimizer, episode, rewards, profits, win_rates, best_reward, best_pnl, best_winrate):
"""Save model checkpoints and clean up old ones to keep only top 5 and best PnL"""
# Create models directory if it doesn't exist
os.makedirs('models', exist_ok=True)
# Save latest models
torch.save({
'price_model_state_dict': price_model.state_dict(),
'dqn_model_state_dict': dqn_model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'episode': episode,
'rewards': rewards,
'profits': profits,
'win_rates': win_rates
}, 'models/enhanced_trading_agent_latest.pt')
# Save continuous training checkpoint
continuous_model_path = f'models/enhanced_trading_agent_continuous_{episode}.pt'
torch.save({
'price_model_state_dict': price_model.state_dict(),
'dqn_model_state_dict': dqn_model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'episode': episode,
'rewards': rewards,
'profits': profits,
'win_rates': win_rates
}, continuous_model_path)
# Save best models
if rewards[-1] > best_reward:
best_reward = rewards[-1]
torch.save({
'price_model_state_dict': price_model.state_dict(),
'dqn_model_state_dict': dqn_model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'episode': episode,
'rewards': rewards,
'profits': profits,
'win_rates': win_rates
}, 'models/enhanced_trading_agent_best_reward.pt')
if profits[-1] > best_pnl:
best_pnl = profits[-1]
torch.save({
'price_model_state_dict': price_model.state_dict(),
'dqn_model_state_dict': dqn_model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'episode': episode,
'rewards': rewards,
'profits': profits,
'win_rates': win_rates
}, 'models/enhanced_trading_agent_best_pnl.pt')
if win_rates[-1] > best_winrate:
best_winrate = win_rates[-1]
torch.save({
'price_model_state_dict': price_model.state_dict(),
'dqn_model_state_dict': dqn_model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'episode': episode,
'rewards': rewards,
'profits': profits,
'win_rates': win_rates
}, 'models/enhanced_trading_agent_best_winrate.pt')
# Save final model at the end of training
if episode == NUM_EPISODES - 1:
torch.save({
'price_model_state_dict': price_model.state_dict(),
'dqn_model_state_dict': dqn_model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'episode': episode,
'rewards': rewards,
'profits': profits,
'win_rates': win_rates
}, 'models/enhanced_trading_agent_final.pt')
# Clean up old models - keep only top 5 most recent and best PnL
cleanup_model_files()
return best_reward, best_pnl, best_winrate
def cleanup_model_files():
"""Keep only the top 5 most recent continuous models and the best models"""
# Files we always want to keep
essential_files = [
'enhanced_trading_agent_latest.pt',
'enhanced_trading_agent_best_reward.pt',
'enhanced_trading_agent_best_pnl.pt',
'enhanced_trading_agent_best_winrate.pt',
'enhanced_trading_agent_final.pt'
]
# Get all continuous training model files
continuous_files = []
for file in os.listdir('models'):
if file.startswith('enhanced_trading_agent_continuous_') and file.endswith('.pt'):
continuous_files.append(file)
# Sort continuous files by episode number (newest first)
if continuous_files:
try:
continuous_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]), reverse=True)
# Keep only the 5 most recent continuous files
files_to_keep = essential_files + continuous_files[:5]
except (ValueError, IndexError):
# Handle case where filename format is unexpected
print("Warning: Could not sort continuous files by episode number. Keeping all continuous files.")
files_to_keep = essential_files + continuous_files
else:
files_to_keep = essential_files
# Delete all other model files
for file in os.listdir('models'):
if file.endswith('.pt') and file not in files_to_keep:
try:
os.remove(os.path.join('models', file))
print(f"Deleted old model file: {file}")
except Exception as e:
print(f"Error deleting {file}: {e}")
def plot_training_results(rewards, profits, win_rates, episode):
"""Plot training metrics"""
plt.figure(figsize=(15, 15))
# Plot rewards
plt.subplot(3, 1, 1)
plt.plot(rewards)
plt.title('Average Reward per Episode')
plt.xlabel('Episode')
plt.ylabel('Reward')
# Plot profits
plt.subplot(3, 1, 2)
plt.plot(profits)
plt.title('Profit/Loss per Episode')
plt.xlabel('Episode')
plt.ylabel('PnL ($)')
# Plot win rates
plt.subplot(3, 1, 3)
plt.plot(win_rates)
plt.title('Win Rate per Episode')
plt.xlabel('Episode')
plt.ylabel('Win Rate (%)')
plt.ylim(0, 100)
plt.tight_layout()
plt.savefig('training_results.png')
# Also save episode-specific plots periodically
if episode % 20 == 0:
os.makedirs('visualizations', exist_ok=True)
plt.savefig(f'visualizations/training_episode_{episode}.png')
plt.close()
def load_checkpoint(price_model, dqn_model, optimizer, episode=None):
"""Load model checkpoint for continuous training"""
if episode is not None:
checkpoint_path = f'models/enhanced_trading_agent_continuous_{episode}.pt'
else:
checkpoint_path = 'models/enhanced_trading_agent_latest.pt'
if os.path.exists(checkpoint_path):
print(f"Loading checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
price_model.load_state_dict(checkpoint['price_model_state_dict'])
dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_episode = checkpoint['episode'] + 1
rewards = checkpoint['rewards']
profits = checkpoint['profits']
win_rates = checkpoint['win_rates']
print(f"Resuming training from episode {start_episode}")
return start_episode, rewards, profits, win_rates
else:
print("No checkpoint found, starting training from scratch")
return 0, [], [], []
def enhanced_train_agent(exchange, num_episodes=NUM_EPISODES, continuous=CONTINUOUS_MODE, start_episode=CONTINUOUS_START_EPISODE):
"""
Train the enhanced trading agent using multi-timeframe data
Args:
exchange: Exchange object to fetch data from
num_episodes: Number of episodes to train for
continuous: Whether to continue training from a checkpoint
start_episode: Episode to start from if continuous training
"""
print(f"Training on device: {DEVICE}")
# Set up TensorBoard
writer = setup_tensorboard()
# Initialize models
state_dim = 100 # Increased state dimension for multi-timeframe features
action_dim = 3 # Buy, Sell, Hold
price_model = EnhancedPricePredictionModel(
input_dim=2, # Price and volume
hidden_dim=256,
num_layers=3,
output_dim=5, # Predict next 5 candles
num_timeframes=len(TIMEFRAMES)
).to(DEVICE)
dqn_model = EnhancedDQN(
state_dim=state_dim,
action_dim=action_dim,
hidden_dim=512
).to(DEVICE)
target_dqn = EnhancedDQN(
state_dim=state_dim,
action_dim=action_dim,
hidden_dim=512
).to(DEVICE)
# Copy initial weights to target network
target_dqn.load_state_dict(dqn_model.state_dict())
# Initialize optimizer
optimizer = optim.Adam(list(price_model.parameters()) + list(dqn_model.parameters()), lr=LEARNING_RATE)
# Initialize replay buffer
replay_buffer = EnhancedReplayBuffer(
capacity=REPLAY_BUFFER_SIZE,
alpha=0.6,
beta=0.4,
beta_increment=0.001,
n_step=3,
gamma=GAMMA
)
# Initialize gradient scaler for mixed precision training
scaler = GradScaler(enabled=(DEVICE.type == 'cuda'))
# Initialize tracking variables
rewards = []
profits = []
win_rates = []
best_reward = float('-inf')
best_pnl = float('-inf')
best_winrate = float('-inf')
# Load checkpoint if continuous training
if continuous:
start_episode, rewards, profits, win_rates = load_checkpoint(
price_model, dqn_model, optimizer, start_episode
)
# Prepare multi-timeframe data for price prediction model training
data_loaders = prepare_multi_timeframe_data(exchange, TIMEFRAMES)
# Pre-train price prediction model
print("Pre-training price prediction model...")
train_price_predictor(price_model, data_loaders, optimizer, DEVICE, epochs=5)
# Main training loop
epsilon = EPSILON_START
for episode in range(start_episode, num_episodes):
print(f"Episode {episode+1}/{num_episodes}")
# Reset environment
state = initialize_state(exchange, TIMEFRAMES)
total_reward = 0
trades = []
wins = 0
losses = 0
# Episode loop
for step in range(MAX_STEPS_PER_EPISODE):
# Epsilon-greedy action selection
if np.random.random() < epsilon:
action = np.random.randint(0, action_dim)
else:
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(DEVICE)
q_values, _, _ = dqn_model(state_tensor)
action = q_values.argmax().item()
# Execute action and get next state and reward
next_state, reward, done, trade_info = step_environment(
exchange, state, action, price_model, TIMEFRAMES, DEVICE
)
# Store transition in replay buffer
replay_buffer.push(
torch.FloatTensor(state),
action,
reward,
torch.FloatTensor(next_state),
done
)
# Update state and accumulate reward
state = next_state
total_reward += reward
# Track trade outcomes
if trade_info is not None:
trades.append(trade_info)
if trade_info['pnl'] > 0:
wins += 1
elif trade_info['pnl'] < 0:
losses += 1
# Learn from experiences if enough samples
if len(replay_buffer) > BATCH_SIZE:
learn(dqn_model, target_dqn, replay_buffer, optimizer, scaler, DEVICE)
if done:
break
# Update target network
if episode % TARGET_UPDATE == 0:
target_dqn.load_state_dict(dqn_model.state_dict())
# Calculate episode metrics
avg_reward = total_reward / (step + 1)
total_pnl = sum(trade['pnl'] for trade in trades) if trades else 0
win_rate = (wins / (wins + losses) * 100) if (wins + losses) > 0 else 0
# Decay epsilon
epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
# Track metrics
rewards.append(avg_reward)
profits.append(total_pnl)
win_rates.append(win_rate)
# Log to TensorBoard
writer.add_scalar('Training/Reward', avg_reward, episode)
writer.add_scalar('Training/Profit', total_pnl, episode)
writer.add_scalar('Training/WinRate', win_rate, episode)
writer.add_scalar('Training/Epsilon', epsilon, episode)
# Print episode summary
print(f"Episode {episode+1} - Avg Reward: {avg_reward:.2f}, PnL: ${total_pnl:.2f}, Win Rate: {win_rate:.1f}%")
# Save models and plot results
if episode % SAVE_INTERVAL == 0 or episode == num_episodes - 1:
best_reward, best_pnl, best_winrate = save_models(
price_model, dqn_model, optimizer, episode,
rewards, profits, win_rates,
best_reward, best_pnl, best_winrate
)
plot_training_results(rewards, profits, win_rates, episode)
# Close TensorBoard writer
writer.close()
# Final save and plot
best_reward, best_pnl, best_winrate = save_models(
price_model, dqn_model, optimizer, num_episodes - 1,
rewards, profits, win_rates,
best_reward, best_pnl, best_winrate
)
plot_training_results(rewards, profits, win_rates, num_episodes - 1)
print("Training complete!")
return price_model, dqn_model
def learn(dqn, target_dqn, replay_buffer, optimizer, scaler, device):
"""Update the DQN model using experiences from the replay buffer"""
# Sample from replay buffer
states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(BATCH_SIZE)
# Move to device
states = states.to(device)
actions = actions.to(device)
rewards = rewards.to(device)
next_states = next_states.to(device)
dones = dones.to(device)
weights = weights.to(device)
# Get current Q values
if device.type == 'cuda':
with autocast(device_type='cuda', enabled=True):
current_q_values, _, _ = dqn(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Compute target Q values
with torch.no_grad():
next_q_values, _, _ = target_dqn(next_states)
max_next_q_values = next_q_values.max(1)[0]
target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values
# Compute loss with importance sampling weights
td_errors = target_q_values - current_q_values
loss = (weights * td_errors.pow(2)).mean()
else:
# CPU version without autocast
current_q_values, _, _ = dqn(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Compute target Q values
with torch.no_grad():
next_q_values, _, _ = target_dqn(next_states)
max_next_q_values = next_q_values.max(1)[0]
target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values
# Compute loss with importance sampling weights
td_errors = target_q_values - current_q_values
loss = (weights * td_errors.pow(2)).mean()
# Update priorities in replay buffer
replay_buffer.update_priorities(indices, td_errors.abs().detach().cpu().numpy())
# Optimize the model with mixed precision
optimizer.zero_grad()
if device.type == 'cuda':
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
else:
# CPU version without scaler
loss.backward()
torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0)
optimizer.step()
def initialize_state(exchange, timeframes):
"""Initialize the state with data from multiple timeframes"""
# Fetch data for each timeframe
timeframe_data = {}
for tf in timeframes:
candles = exchange.fetch_ohlcv(timeframe=tf, limit=30)
timeframe_data[tf] = candles
# Extract features from each timeframe
state = []
for tf in timeframes:
candles = timeframe_data[tf]
# Price features
prices = [candle[4] for candle in candles[-10:]] # Last 10 close prices
price_changes = [prices[i]/prices[i-1] - 1 for i in range(1, len(prices))]
# Volume features
volumes = [candle[5] for candle in candles[-10:]] # Last 10 volumes
volume_changes = [volumes[i]/volumes[i-1] - 1 for i in range(1, len(volumes))]
# Technical indicators
# Simple Moving Averages
sma_5 = sum(prices[-5:]) / 5
sma_10 = sum(prices) / 10
# Relative Strength Index (simplified)
gains = [max(0, price_changes[i]) for i in range(len(price_changes))]
losses = [max(0, -price_changes[i]) for i in range(len(price_changes))]
avg_gain = sum(gains) / len(gains)
avg_loss = sum(losses) / len(losses)
rs = avg_gain / (avg_loss + 1e-10) # Avoid division by zero
rsi = 100 - (100 / (1 + rs))
# Add features to state
state.extend(price_changes) # 9 features
state.extend(volume_changes) # 9 features
state.append(sma_5 / prices[-1] - 1) # 1 feature
state.append(sma_10 / prices[-1] - 1) # 1 feature
state.append(rsi / 100) # 1 feature
# Add market regime features
# This is a placeholder - in a real implementation, you would use the market_regime_classifier
# from the DQN model to predict the current market regime
state.extend([0, 0, 0]) # 3 features for market regime (one-hot encoded)
# Add additional features to reach the expected dimension of 100
# Calculate more technical indicators
for tf in timeframes:
candles = timeframe_data[tf]
prices = [candle[4] for candle in candles[-20:]] # Last 20 close prices
# Bollinger Bands
window = 20
if len(prices) >= window:
sma_20 = sum(prices[-window:]) / window
std_dev = (sum((price - sma_20) ** 2 for price in prices[-window:]) / window) ** 0.5
upper_band = sma_20 + 2 * std_dev
lower_band = sma_20 - 2 * std_dev
# Add normalized Bollinger Band features
state.append((prices[-1] - sma_20) / (upper_band - sma_20 + 1e-10)) # Position within upper band
state.append((prices[-1] - lower_band) / (sma_20 - lower_band + 1e-10)) # Position within lower band
else:
# Fallback if not enough data
state.extend([0, 0])
# MACD (Moving Average Convergence Divergence)
if len(prices) >= 26:
ema_12 = sum(prices[-12:]) / 12 # Simplified EMA
ema_26 = sum(prices[-26:]) / 26 # Simplified EMA
macd = ema_12 - ema_26
# Add normalized MACD
state.append(macd / prices[-1])
else:
# Fallback if not enough data
state.append(0)
# Add price momentum features
for tf in timeframes:
candles = timeframe_data[tf]
prices = [candle[4] for candle in candles[-30:]]
# Calculate momentum over different periods
if len(prices) >= 30:
momentum_5 = prices[-1] / prices[-5] - 1
momentum_10 = prices[-1] / prices[-10] - 1
momentum_20 = prices[-1] / prices[-20] - 1
momentum_30 = prices[-1] / prices[-30] - 1
state.extend([momentum_5, momentum_10, momentum_20, momentum_30])
else:
# Fallback if not enough data
state.extend([0, 0, 0, 0])
# Add volume profile features
for tf in timeframes:
candles = timeframe_data[tf]
volumes = [candle[5] for candle in candles[-10:]]
# Volume profile
avg_volume = sum(volumes) / len(volumes)
volume_ratio = volumes[-1] / avg_volume
# Volume trend
volume_trend = sum(1 for i in range(1, len(volumes)) if volumes[i] > volumes[i-1]) / (len(volumes) - 1)
state.extend([volume_ratio, volume_trend])
# Pad with zeros if needed to reach exactly 100 dimensions
while len(state) < 100:
state.append(0)
# Ensure state has exactly 100 dimensions
if len(state) > 100:
state = state[:100]
assert len(state) == 100, f"State dimension mismatch: {len(state)} != 100"
return state
def step_environment(exchange, state, action, price_model, timeframes, device):
"""
Execute action in the environment and return next state, reward, done flag, and trade info
Args:
exchange: Exchange object to interact with
state: Current state
action: Action to take (0: Hold, 1: Buy, 2: Sell)
price_model: Price prediction model
timeframes: List of timeframes to use
device: Device to run models on
Returns:
next_state: Next state after taking action
reward: Reward received
done: Whether episode is done
trade_info: Information about the trade (if any)
"""
# Fetch latest data for each timeframe
timeframe_data = {}
for tf in timeframes:
candles = exchange.fetch_ohlcv(timeframe=tf, limit=30)
timeframe_data[tf] = candles
# Prepare inputs for price prediction model
price_inputs = []
for tf in timeframes:
candles = timeframe_data[tf]
# Extract price and volume data
input_data = torch.tensor([
[candle[4], candle[5]] for candle in candles[-30:] # Last 30 candles
], dtype=torch.float32).unsqueeze(0).to(device) # Add batch dimension
price_inputs.append(input_data)
# Get price and extrema predictions
with torch.no_grad():
price_pred, extrema_logits, volume_pred = price_model(price_inputs)
# Convert predictions to numpy
price_pred = price_pred.cpu().numpy()[0] # Remove batch dimension
extrema_probs = torch.sigmoid(extrema_logits).cpu().numpy()[0]
volume_pred = volume_pred.cpu().numpy()[0]
# Execute action
current_price = timeframe_data['1m'][-1][4] # Current close price
trade_info = None
reward = 0
if action == 1: # Buy
# Check if we're at a predicted low point (good time to buy)
is_predicted_low = any(extrema_probs[i*2+1] > 0.7 for i in range(5))
# Calculate entry quality based on predictions
entry_quality = 0.5 # Default quality
if is_predicted_low:
entry_quality += 0.2 # Bonus for buying at predicted low
# Check volume confirmation
volume_increasing = volume_pred[0] > timeframe_data['1m'][-1][5]
if volume_increasing:
entry_quality += 0.1 # Bonus for increasing volume
# Execute buy order
# In a real implementation, this would interact with the exchange
# For now, we'll simulate the trade
trade_info = {
'action': 'buy',
'price': current_price,
'size': 100 * entry_quality, # Size based on entry quality
'entry_quality': entry_quality,
'pnl': 0 # Will be updated later
}
# Calculate reward
# Base reward for taking action
reward = 1
# Bonus for buying at predicted low
if is_predicted_low:
reward += 5
print("Trading at predicted low - additional reward")
# Bonus for volume confirmation
if volume_increasing:
reward += 2
print("Trading with high volume - additional reward")
elif action == 2: # Sell
# Check if we're at a predicted high point (good time to sell)
is_predicted_high = any(extrema_probs[i*2] > 0.7 for i in range(5))
# Calculate entry quality based on predictions
entry_quality = 0.5 # Default quality
if is_predicted_high:
entry_quality += 0.2 # Bonus for selling at predicted high
# Check volume confirmation
volume_increasing = volume_pred[0] > timeframe_data['1m'][-1][5]
if volume_increasing:
entry_quality += 0.1 # Bonus for increasing volume
# Execute sell order
# In a real implementation, this would interact with the exchange
# For now, we'll simulate the trade
trade_info = {
'action': 'sell',
'price': current_price,
'size': 100 * entry_quality, # Size based on entry quality
'entry_quality': entry_quality,
'pnl': 0 # Will be updated later
}
# Calculate reward
# Base reward for taking action
reward = 1
# Bonus for selling at predicted high
if is_predicted_high:
reward += 5
print("Trading at predicted high - additional reward")
# Bonus for volume confirmation
if volume_increasing:
reward += 2
print("Trading with high volume - additional reward")
else: # Hold
# Small reward for holding
reward = 0.1
# Simulate trade outcome
if trade_info is not None:
# In a real implementation, this would be based on actual market movement
# For now, we'll use the price prediction to simulate the outcome
future_price = price_pred[0] # Price in the next candle
if trade_info['action'] == 'buy':
# For buy, profit if price goes up
pnl_pct = (future_price / current_price - 1) * 100
trade_info['pnl'] = pnl_pct * trade_info['size'] / 100
else: # sell
# For sell, profit if price goes down
pnl_pct = (1 - future_price / current_price) * 100
trade_info['pnl'] = pnl_pct * trade_info['size'] / 100
# Adjust reward based on trade outcome
reward += trade_info['pnl'] * 10 # Scale PnL for reward
# Update state
next_state = initialize_state(exchange, timeframes)
# Check if episode is done
# In a real implementation, this would be based on episode length or other criteria
done = False
return next_state, reward, done, trade_info
# Main function to run training
def main():
from exchange_simulator import ExchangeSimulator
# Initialize exchange simulator
exchange = ExchangeSimulator()
# Train agent
price_model, dqn_model = enhanced_train_agent(
exchange=exchange,
num_episodes=NUM_EPISODES,
continuous=CONTINUOUS_MODE,
start_episode=CONTINUOUS_START_EPISODE
)
print("Training complete!")
if __name__ == "__main__":
main()
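One thing worth noting about the constants above: EPSILON_DECAY = 0.995 is applied once per episode, so epsilon reaches the EPSILON_END = 0.01 floor only after about ln(0.01)/ln(0.995) ≈ 919 episodes. With NUM_EPISODES = 200, the agent is therefore still exploring roughly 37% of the time when training ends, which may be intended but is easy to miss:

import math
episodes_to_floor = math.log(0.01) / math.log(0.995)  # ≈ 919 episodes
epsilon_at_200 = max(0.01, 1.0 * 0.995 ** 200)        # ≈ 0.367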

View File

@ -0,0 +1,373 @@
import numpy as np
import pandas as pd
import os
import random
from datetime import datetime, timedelta
class ExchangeSimulator:
"""
A simple exchange simulator that generates realistic market data
for testing trading algorithms without connecting to a real exchange.
"""
def __init__(self, symbol="BTC/USDT", seed=42):
"""
Initialize the exchange simulator
Args:
symbol: Trading pair symbol
seed: Random seed for reproducibility
"""
self.symbol = symbol
self.seed = seed
np.random.seed(seed)
random.seed(seed)
# Initialize data storage
self.data = {}
self.current_timestamp = datetime.now()
# Generate initial data for different timeframes
self.timeframes = ['1m', '5m', '15m', '30m', '1h', '4h', '1d']
self.timeframe_minutes = {
'1m': 1,
'5m': 5,
'15m': 15,
'30m': 30,
'1h': 60,
'4h': 240,
'1d': 1440
}
# Generate initial price around $50,000 (for BTC/USDT)
self.base_price = 50000.0
# Generate data for each timeframe
for tf in self.timeframes:
self._generate_initial_data(tf)
def _generate_initial_data(self, timeframe, num_candles=1000):
"""
Generate initial historical data for a specific timeframe
Args:
timeframe: Timeframe to generate data for
num_candles: Number of candles to generate
"""
# Calculate time delta for this timeframe
minutes = self.timeframe_minutes[timeframe]
# Generate timestamps
end_time = self.current_timestamp
timestamps = [end_time - timedelta(minutes=minutes * i) for i in range(num_candles)]
timestamps.reverse() # Oldest first
# Generate price data with realistic patterns
prices = self._generate_price_series(num_candles)
# Generate volume data with realistic patterns
volumes = self._generate_volume_series(num_candles, timeframe)
# Create OHLCV data
ohlcv_data = []
for i in range(num_candles):
# Calculate OHLC based on close price
close = prices[i]
high = close * (1 + np.random.uniform(0, 0.01))
low = close * (1 - np.random.uniform(0, 0.01))
open_price = prices[i-1] if i > 0 else close * (1 - np.random.uniform(-0.005, 0.005))
# Create candle
candle = [
int(timestamps[i].timestamp() * 1000), # Timestamp in milliseconds
open_price, # Open
high, # High
low, # Low
close, # Close
volumes[i] # Volume
]
ohlcv_data.append(candle)
# Store data
self.data[timeframe] = ohlcv_data
def _generate_price_series(self, length):
"""
Generate a realistic price series with trends, reversals, and volatility
Args:
length: Number of prices to generate
Returns:
List of prices
"""
# Start with base price
prices = [self.base_price]
# Parameters for price generation
trend_strength = 0.001 # Strength of trend
volatility = 0.005 # Daily volatility
mean_reversion = 0.001 # Mean reversion strength
# Generate price series
for i in range(1, length):
# Determine if we're in a trend
if i % 100 == 0:
# Change trend direction every ~100 candles
trend_strength = -trend_strength
# Calculate price change
trend = trend_strength * prices[-1]
random_change = np.random.normal(0, volatility) * prices[-1]
mean_reversion_change = mean_reversion * (self.base_price - prices[-1])
# Calculate new price
new_price = prices[-1] + trend + random_change + mean_reversion_change
# Ensure price doesn't go negative
new_price = max(new_price, prices[-1] * 0.9)
prices.append(new_price)
return prices
def _generate_volume_series(self, length, timeframe):
"""
Generate a realistic volume series with patterns
Args:
length: Number of volumes to generate
timeframe: Timeframe for volume scaling
Returns:
List of volumes
"""
# Base volume depends on timeframe
base_volume = {
'1m': 10,
'5m': 50,
'15m': 150,
'30m': 300,
'1h': 600,
'4h': 2400,
'1d': 10000
}[timeframe]
# Generate volume series
volumes = []
for i in range(length):
# Volume tends to be higher at trend reversals and during volatile periods
cycle_factor = 1 + 0.5 * np.sin(i / 20) # Cyclical pattern
random_factor = np.random.lognormal(0, 0.5) # Random spikes
# Calculate volume
volume = base_volume * cycle_factor * random_factor
# Add some volume spikes
if random.random() < 0.05: # 5% chance of volume spike
volume *= random.uniform(2, 5)
volumes.append(volume)
return volumes
def fetch_ohlcv(self, timeframe='1m', limit=100, since=None):
"""
Fetch OHLCV data for a specific timeframe
Args:
timeframe: Timeframe to fetch data for
limit: Number of candles to fetch
since: Timestamp to fetch data since (not used in simulator)
Returns:
List of OHLCV candles
"""
# Ensure timeframe exists
if timeframe not in self.data:
if timeframe in self.timeframe_minutes:
self._generate_initial_data(timeframe)
else:
# Default to 1m if timeframe not supported
timeframe = '1m'
# Get data
data = self.data[timeframe]
# Return limited data
return data[-limit:]
def update(self):
"""
Update the exchange data by generating a new candle for each timeframe
"""
# Update current timestamp
self.current_timestamp = datetime.now()
# Update each timeframe
for tf in self.timeframes:
self._add_new_candle(tf)
def _add_new_candle(self, timeframe):
"""
Add a new candle to the specified timeframe
Args:
timeframe: Timeframe to add candle to
"""
# Get existing data
data = self.data[timeframe]
# Get last close price
last_close = data[-1][4]
# Calculate time delta for this timeframe
minutes = self.timeframe_minutes[timeframe]
# Calculate new timestamp
new_timestamp = int((data[-1][0] / 1000 + minutes * 60) * 1000)
# Generate new price with some randomness
price_change = np.random.normal(0, 0.002) * last_close
new_close = last_close + price_change
# Calculate OHLC
new_open = last_close
new_high = max(new_open, new_close) * (1 + np.random.uniform(0, 0.005))
new_low = min(new_open, new_close) * (1 - np.random.uniform(0, 0.005))
# Generate volume
base_volume = data[-1][5]
volume_change = np.random.normal(0, 0.2) * base_volume
new_volume = max(base_volume + volume_change, base_volume * 0.5)
# Create new candle
new_candle = [
new_timestamp,
new_open,
new_high,
new_low,
new_close,
new_volume
]
# Add to data
self.data[timeframe].append(new_candle)
def get_ticker(self, symbol=None):
"""
Get current ticker information
Args:
symbol: Symbol to get ticker for (defaults to initialized symbol)
Returns:
Dictionary with ticker information
"""
if symbol is None:
symbol = self.symbol
# Get latest 1m candle
latest_candle = self.data['1m'][-1]
return {
'symbol': symbol,
'bid': latest_candle[4] * 0.9999, # Slightly below last price
'ask': latest_candle[4] * 1.0001, # Slightly above last price
'last': latest_candle[4],
'high': latest_candle[2],
'low': latest_candle[3],
'volume': latest_candle[5],
'timestamp': latest_candle[0]
}
def create_order(self, symbol, type, side, amount, price=None):
"""
Simulate creating an order
Args:
symbol: Symbol to create order for
type: Order type (limit, market)
side: Order side (buy, sell)
amount: Order amount
price: Order price (for limit orders)
Returns:
Dictionary with order information
"""
# Get current ticker
ticker = self.get_ticker(symbol)
# Determine execution price
if type == 'market':
if side == 'buy':
execution_price = ticker['ask']
else:
execution_price = ticker['bid']
else: # limit order
execution_price = price
# Create order object
order = {
'id': f"order_{int(datetime.now().timestamp() * 1000)}",
'symbol': symbol,
'type': type,
'side': side,
'amount': amount,
'price': execution_price,
'cost': amount * execution_price,
'filled': amount,
'status': 'closed',
'timestamp': int(datetime.now().timestamp() * 1000)
}
return order
def fetch_balance(self):
"""
Fetch account balance (simulated)
Returns:
Dictionary with balance information
"""
return {
'total': {
'USD': 10000.0,
'BTC': 1.0
},
'free': {
'USD': 5000.0,
'BTC': 0.5
},
'used': {
'USD': 5000.0,
'BTC': 0.5
}
}
# Example usage
if __name__ == "__main__":
# Create exchange simulator
exchange = ExchangeSimulator()
# Fetch some data
ohlcv = exchange.fetch_ohlcv(timeframe='1h', limit=10)
print("OHLCV data (1h timeframe):")
for candle in ohlcv[-5:]:
timestamp = datetime.fromtimestamp(candle[0] / 1000)
print(f"{timestamp}: Open={candle[1]:.2f}, High={candle[2]:.2f}, Low={candle[3]:.2f}, Close={candle[4]:.2f}, Volume={candle[5]:.2f}")
# Get current ticker
ticker = exchange.get_ticker()
print(f"\nCurrent ticker: {ticker['last']:.2f}")
# Create a market buy order
order = exchange.create_order("BTC/USDT", "market", "buy", 0.1)
print(f"\nCreated order: {order}")
# Update the exchange (simulate time passing)
exchange.update()
# Get updated ticker
updated_ticker = exchange.get_ticker()
print(f"\nUpdated ticker: {updated_ticker['last']:.2f}")

File diff suppressed because it is too large

View File

@ -0,0 +1,305 @@
import argparse
import os
import torch
from enhanced_training import enhanced_train_agent
from exchange_simulator import ExchangeSimulator
def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='Enhanced Trading Bot Training')
parser.add_argument('--mode', type=str, default='train', choices=['train', 'continuous', 'evaluate', 'live', 'demo'],
help='Mode to run the trading bot in')
parser.add_argument('--episodes', type=int, default=100,
help='Number of episodes to train for')
parser.add_argument('--start-episode', type=int, default=0,
help='Episode to start from for continuous training')
parser.add_argument('--device', type=str, default='auto',
help='Device to train on (auto, cuda, cpu)')
parser.add_argument('--timeframes', type=str, default='1m,15m,1h',
help='Comma-separated list of timeframes to use')
parser.add_argument('--refresh-data', action='store_true',
help='Refresh data before training')
args = parser.parse_args()
# Set device
if args.device == 'auto':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
device = torch.device(args.device)
print(f"Using device: {device}")
# Parse timeframes
timeframes = args.timeframes.split(',')
print(f"Using timeframes: {timeframes}")
# Initialize exchange simulator
exchange = ExchangeSimulator()
# Run in specified mode
if args.mode == 'train':
# Train from scratch
print(f"Training for {args.episodes} episodes...")
enhanced_train_agent(
exchange=exchange,
num_episodes=args.episodes,
continuous=False,
start_episode=0
)
elif args.mode == 'continuous':
# Continue training from checkpoint
print(f"Continuing training from episode {args.start_episode} for {args.episodes} episodes...")
enhanced_train_agent(
exchange=exchange,
num_episodes=args.episodes,
continuous=True,
start_episode=args.start_episode
)
elif args.mode == 'evaluate':
# Evaluate the model
print("Evaluating model...")
evaluate_model(exchange, device)
elif args.mode == 'live' or args.mode == 'demo':
# Run in live or demo mode
is_demo = args.mode == 'demo'
print(f"Running in {'demo' if is_demo else 'live'} mode...")
run_live(exchange, device, is_demo=is_demo)
print("Done!")
def evaluate_model(exchange, device):
"""
Evaluate the trained model
Args:
exchange: Exchange simulator
device: Device to run on
"""
from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN
import torch
import numpy as np
# Load the best model
model_path = 'models/enhanced_trading_agent_best_pnl.pt'
if not os.path.exists(model_path):
model_path = 'models/enhanced_trading_agent_latest.pt'
if not os.path.exists(model_path):
print("No model found to evaluate!")
return
print(f"Loading model from {model_path}")
checkpoint = torch.load(model_path, map_location=device)
# Initialize models
state_dim = 100
action_dim = 3
timeframes = ['1m', '15m', '1h']
price_model = EnhancedPricePredictionModel(
input_dim=2,
hidden_dim=256,
num_layers=3,
output_dim=5,
num_timeframes=len(timeframes)
).to(device)
dqn_model = EnhancedDQN(
state_dim=state_dim,
action_dim=action_dim,
hidden_dim=512
).to(device)
# Load model weights
price_model.load_state_dict(checkpoint['price_model_state_dict'])
dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
# Set models to evaluation mode
price_model.eval()
dqn_model.eval()
# Run evaluation
num_steps = 1000
total_reward = 0
trades = []
# Initialize state
from enhanced_training import initialize_state, step_environment
state = initialize_state(exchange, timeframes)
for step in range(num_steps):
# Select action
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
q_values, _, _ = dqn_model(state_tensor)
action = q_values.argmax().item()
# Execute action
next_state, reward, done, trade_info = step_environment(
exchange, state, action, price_model, timeframes, device
)
# Update state and accumulate reward
state = next_state
total_reward += reward
# Track trade
if trade_info is not None:
trades.append(trade_info)
print(f"Trade: {trade_info['action']} at {trade_info['price']:.2f}, PnL: {trade_info['pnl']:.2f}")
# Update exchange (simulate time passing)
if step % 10 == 0:
exchange.update()
if done:
break
# Calculate metrics
avg_reward = total_reward / num_steps
total_pnl = sum(trade['pnl'] for trade in trades) if trades else 0
wins = sum(1 for trade in trades if trade['pnl'] > 0)
losses = sum(1 for trade in trades if trade['pnl'] < 0)
win_rate = (wins / (wins + losses) * 100) if (wins + losses) > 0 else 0
print("\nEvaluation Results:")
print(f"Average Reward: {avg_reward:.2f}")
print(f"Total PnL: ${total_pnl:.2f}")
print(f"Win Rate: {win_rate:.1f}% ({wins}/{wins+losses})")
def run_live(exchange, device, is_demo=True):
"""
Run the trading bot in live or demo mode
Args:
exchange: Exchange simulator or real exchange
device: Device to run on
is_demo: Whether to run in demo mode (no real trades)
"""
from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN
import torch
import time
# Load the best model
model_path = 'models/enhanced_trading_agent_best_pnl.pt'
if not os.path.exists(model_path):
model_path = 'models/enhanced_trading_agent_latest.pt'
if not os.path.exists(model_path):
print("No model found to run in live mode!")
return
print(f"Loading model from {model_path}")
checkpoint = torch.load(model_path, map_location=device)
# Initialize models
state_dim = 100
action_dim = 3
timeframes = ['1m', '15m', '1h']
price_model = EnhancedPricePredictionModel(
input_dim=2,
hidden_dim=256,
num_layers=3,
output_dim=5,
num_timeframes=len(timeframes)
).to(device)
dqn_model = EnhancedDQN(
state_dim=state_dim,
action_dim=action_dim,
hidden_dim=512
).to(device)
# Load model weights
price_model.load_state_dict(checkpoint['price_model_state_dict'])
dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
# Set models to evaluation mode
price_model.eval()
dqn_model.eval()
# Run live trading
print(f"Running in {'demo' if is_demo else 'live'} mode...")
print("Press Ctrl+C to stop")
# Initialize state
from enhanced_training import initialize_state, step_environment
state = initialize_state(exchange, timeframes)
try:
while True:
# Select action
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
q_values, _, market_regime = dqn_model(state_tensor)
action = q_values.argmax().item()
# Get market regime prediction
regime_probs = torch.softmax(market_regime, dim=1).cpu().numpy()[0]
regime_names = ['Trending', 'Ranging', 'Volatile']
predicted_regime = regime_names[regime_probs.argmax()]
# Get current price
ticker = exchange.get_ticker()
current_price = ticker['last']
# Print state
print(f"\nCurrent price: ${current_price:.2f}")
print(f"Predicted market regime: {predicted_regime} ({regime_probs.max()*100:.1f}% confidence)")
# Execute action
next_state, reward, _, trade_info = step_environment(
exchange, state, action, price_model, timeframes, device
)
# Print action
action_names = ['Hold', 'Buy', 'Sell']
print(f"Action: {action_names[action]}")
if trade_info is not None:
print(f"Trade: {trade_info['action']} at {trade_info['price']:.2f}, Size: {trade_info['size']:.2f}, Entry Quality: {trade_info['entry_quality']:.2f}")
# Execute real trade if not in demo mode
if not is_demo:
if trade_info['action'] == 'buy':
order = exchange.create_order(
symbol="BTC/USDT",
type="market",
side="buy",
amount=trade_info['size'] / current_price
)
print(f"Executed buy order: {order}")
else: # sell
order = exchange.create_order(
symbol="BTC/USDT",
type="market",
side="sell",
amount=trade_info['size'] / current_price
)
print(f"Executed sell order: {order}")
# Update state
state = next_state
# Update exchange (simulate time passing)
exchange.update()
# Wait for next candle
time.sleep(5) # In a real implementation, this would wait for the next candle
except KeyboardInterrupt:
print("\nStopping live trading")
if __name__ == "__main__":
main()
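Assuming this entry point is saved as, say, enhanced_main.py (the script's filename is not shown in this diff), typical invocations would be `python enhanced_main.py --mode train --episodes 100`, `python enhanced_main.py --mode continuous --start-episode 50`, or `python enhanced_main.py --mode demo --device cpu`.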

crypto/gogo2/test_cache.py Normal file
View File

@ -0,0 +1,185 @@
import os
import sys
import json
import logging
import time
from datetime import datetime
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout)
]
)
logger = logging.getLogger('cache_test')
# Import our cache implementation
from data_cache import ohlcv_cache
def generate_sample_data(num_candles=100):
"""Generate sample OHLCV data for testing"""
data = []
base_timestamp = int(time.time() * 1000) - (num_candles * 60 * 1000) # Start from num_candles minutes ago
for i in range(num_candles):
timestamp = base_timestamp + (i * 60 * 1000) # Add i minutes
# Generate some random-ish but realistic looking price data
base_price = 1900.0 + (i * 0.5) # Slight uptrend
open_price = base_price - 0.5 + (i % 3)
close_price = base_price + 0.3 + ((i+1) % 4)
high_price = max(open_price, close_price) + 1.0 + (i % 2)
low_price = min(open_price, close_price) - 0.8 - (i % 2)
volume = 10.0 + (i % 10) * 2.0
data.append({
'timestamp': timestamp,
'open': open_price,
'high': high_price,
'low': low_price,
'close': close_price,
'volume': volume
})
return data
def test_cache_save_load():
"""Test saving and loading data from cache"""
logger.info("Testing cache save and load...")
# Generate sample data
data = generate_sample_data(100)
logger.info(f"Generated {len(data)} sample candles")
# Save to cache
symbol = "ETH/USDT"
timeframe = "1m"
success = ohlcv_cache.save(data, symbol, timeframe)
logger.info(f"Saved to cache: {success}")
# Load from cache
cached_data = ohlcv_cache.load(symbol, timeframe)
logger.info(f"Loaded {len(cached_data) if cached_data else 0} candles from cache")
# Verify data integrity
if cached_data:
first_original = data[0]
first_cached = cached_data[0]
logger.info(f"First original candle: {first_original}")
logger.info(f"First cached candle: {first_cached}")
last_original = data[-1]
last_cached = cached_data[-1]
logger.info(f"Last original candle: {last_original}")
logger.info(f"Last cached candle: {last_cached}")
return success and cached_data and len(cached_data) == len(data)
def test_cache_append():
"""Test appending a new candle to cached data"""
logger.info("Testing cache append...")
# Generate sample data
data = generate_sample_data(100)
# Save to cache
symbol = "ETH/USDT"
timeframe = "5m"
success = ohlcv_cache.save(data, symbol, timeframe)
logger.info(f"Saved to cache: {success}")
# Generate a new candle
last_timestamp = data[-1]['timestamp']
new_timestamp = last_timestamp + (5 * 60 * 1000) # 5 minutes later
new_candle = {
'timestamp': new_timestamp,
'open': 1950.0,
'high': 1955.0,
'low': 1948.0,
'close': 1952.0,
'volume': 15.0
}
# Append to cache
success = ohlcv_cache.append(new_candle, symbol, timeframe)
logger.info(f"Appended to cache: {success}")
# Load from cache
cached_data = ohlcv_cache.load(symbol, timeframe)
logger.info(f"Loaded {len(cached_data) if cached_data else 0} candles from cache")
# Verify the new candle was appended
if cached_data:
last_cached = cached_data[-1]
logger.info(f"New candle: {new_candle}")
logger.info(f"Last cached candle: {last_cached}")
return success and cached_data and len(cached_data) == len(data) + 1
def test_cache_dataframe():
"""Test converting cached data to a pandas DataFrame"""
logger.info("Testing cache to DataFrame conversion...")
# Generate sample data
data = generate_sample_data(100)
# Save to cache
symbol = "ETH/USDT"
timeframe = "15m"
success = ohlcv_cache.save(data, symbol, timeframe)
logger.info(f"Saved to cache: {success}")
# Convert to DataFrame
df = ohlcv_cache.to_dataframe(symbol, timeframe)
logger.info(f"Converted to DataFrame with {len(df) if df is not None else 0} rows")
# Display DataFrame info
if df is not None:
logger.info(f"DataFrame columns: {df.columns.tolist()}")
logger.info(f"DataFrame index: {df.index.name}")
logger.info(f"First row: {df.iloc[0].to_dict()}")
logger.info(f"Last row: {df.iloc[-1].to_dict()}")
return success and df is not None and len(df) == len(data)
def main():
"""Run all tests"""
logger.info("Starting cache tests...")
# Run tests
save_load_success = test_cache_save_load()
append_success = test_cache_append()
dataframe_success = test_cache_dataframe()
# Print results
logger.info("Test results:")
logger.info(f" Save/Load: {'PASS' if save_load_success else 'FAIL'}")
logger.info(f" Append: {'PASS' if append_success else 'FAIL'}")
logger.info(f" DataFrame: {'PASS' if dataframe_success else 'FAIL'}")
# Check cache directory contents
cache_dir = ohlcv_cache.cache_dir
logger.info(f"Cache directory: {cache_dir}")
if os.path.exists(cache_dir):
files = os.listdir(cache_dir)
logger.info(f"Cache files: {files}")
# Print file sizes
for file in files:
file_path = os.path.join(cache_dir, file)
size_kb = os.path.getsize(file_path) / 1024
logger.info(f" {file}: {size_kb:.2f} KB")
# Print first few lines of each file
with open(file_path, 'r') as f:
data = json.load(f)
logger.info(f" Metadata: symbol={data.get('symbol')}, timeframe={data.get('timeframe')}, last_updated={datetime.fromtimestamp(data.get('last_updated')).strftime('%Y-%m-%d %H:%M:%S')}")
logger.info(f" Candles: {len(data.get('data', []))}")
return save_load_success and append_success and dataframe_success
if __name__ == "__main__":
main()
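The suite can be run directly with `python crypto/gogo2/test_cache.py`. Note that main() returns the combined pass/fail result but the `__main__` guard discards it; wrapping the call in `sys.exit(0 if main() else 1)` would let a CI runner detect failures from the exit code.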