improvments and fixes
This commit is contained in:
parent
506458d55e
commit
2e901e18f2
@ -1,3 +1,5 @@
|
||||
pip install torch-tb-profiler
|
||||
|
||||
|
||||
|
||||
ensure we use GPU if available to train faster. during training we need to have RL loop that looks at streaming data, and retrospective backtesting/training on predictions. sincr the start of the traing we're only loosing. implement robust penalty and analysis when closing a loosing trade and improve the reward function.
|
||||
|
319
crypto/gogo2/data_cache.py
Normal file
319
crypto/gogo2/data_cache.py
Normal file
@ -0,0 +1,319 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger('trading_bot')
|
||||
|
||||
class OHLCVCache:
|
||||
"""
|
||||
A simple cache for OHLCV data from exchanges.
|
||||
Stores data in a structured format and provides backup when exchange is unavailable.
|
||||
"""
|
||||
def __init__(self, cache_dir="cache", max_age_hours=24):
|
||||
"""
|
||||
Initialize the OHLCV cache.
|
||||
|
||||
Args:
|
||||
cache_dir: Directory to store cache files
|
||||
max_age_hours: Maximum age of cached data in hours before considered stale
|
||||
"""
|
||||
self.cache_dir = cache_dir
|
||||
self.max_age_seconds = max_age_hours * 3600
|
||||
|
||||
# Create cache directory if it doesn't exist
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# In-memory cache for faster access
|
||||
self.memory_cache = {}
|
||||
|
||||
def _get_cache_filename(self, symbol, timeframe):
|
||||
"""Generate a standardized filename for the cache file"""
|
||||
# Replace / with _ in symbol name (e.g., ETH/USDT -> ETH_USDT)
|
||||
safe_symbol = symbol.replace('/', '_')
|
||||
return os.path.join(self.cache_dir, f"{safe_symbol}_{timeframe}.json")
|
||||
|
||||
def save(self, data, symbol, timeframe):
|
||||
"""
|
||||
Save OHLCV data to cache.
|
||||
|
||||
Args:
|
||||
data: List of dictionaries containing OHLCV data
|
||||
symbol: Trading pair symbol (e.g., 'ETH/USDT')
|
||||
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
|
||||
"""
|
||||
if not data:
|
||||
logger.warning(f"No data to cache for {symbol} ({timeframe})")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Convert data to a serializable format
|
||||
serializable_data = []
|
||||
for candle in data:
|
||||
serializable_data.append({
|
||||
'timestamp': candle['timestamp'],
|
||||
'open': float(candle['open']),
|
||||
'high': float(candle['high']),
|
||||
'low': float(candle['low']),
|
||||
'close': float(candle['close']),
|
||||
'volume': float(candle['volume'])
|
||||
})
|
||||
|
||||
# Create cache entry with metadata
|
||||
cache_entry = {
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'last_updated': int(time.time()),
|
||||
'data': serializable_data
|
||||
}
|
||||
|
||||
# Save to file
|
||||
filename = self._get_cache_filename(symbol, timeframe)
|
||||
with open(filename, 'w') as f:
|
||||
json.dump(cache_entry, f)
|
||||
|
||||
# Update in-memory cache
|
||||
cache_key = f"{symbol}_{timeframe}"
|
||||
self.memory_cache[cache_key] = cache_entry
|
||||
|
||||
logger.info(f"Cached {len(data)} candles for {symbol} ({timeframe})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving data to cache: {e}")
|
||||
return False
|
||||
|
||||
def load(self, symbol, timeframe, max_age_override=None):
|
||||
"""
|
||||
Load OHLCV data from cache.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol (e.g., 'ETH/USDT')
|
||||
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
|
||||
max_age_override: Override the default max age (in seconds)
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing OHLCV data, or None if cache is missing or stale
|
||||
"""
|
||||
cache_key = f"{symbol}_{timeframe}"
|
||||
max_age = max_age_override if max_age_override is not None else self.max_age_seconds
|
||||
|
||||
try:
|
||||
# Check in-memory cache first
|
||||
if cache_key in self.memory_cache:
|
||||
cache_entry = self.memory_cache[cache_key]
|
||||
|
||||
# Check if cache is fresh
|
||||
cache_age = int(time.time()) - cache_entry['last_updated']
|
||||
if cache_age <= max_age:
|
||||
logger.info(f"Using in-memory cache for {symbol} ({timeframe}), age: {cache_age//60} minutes")
|
||||
return cache_entry['data']
|
||||
|
||||
# Check file cache
|
||||
filename = self._get_cache_filename(symbol, timeframe)
|
||||
if not os.path.exists(filename):
|
||||
logger.info(f"No cache file found for {symbol} ({timeframe})")
|
||||
return None
|
||||
|
||||
# Load cache file
|
||||
with open(filename, 'r') as f:
|
||||
cache_entry = json.load(f)
|
||||
|
||||
# Check if cache is fresh
|
||||
cache_age = int(time.time()) - cache_entry['last_updated']
|
||||
if cache_age > max_age:
|
||||
logger.info(f"Cache for {symbol} ({timeframe}) is stale ({cache_age//60} minutes old)")
|
||||
return None
|
||||
|
||||
# Update in-memory cache
|
||||
self.memory_cache[cache_key] = cache_entry
|
||||
|
||||
logger.info(f"Loaded {len(cache_entry['data'])} candles from cache for {symbol} ({timeframe})")
|
||||
return cache_entry['data']
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading data from cache: {e}")
|
||||
return None
|
||||
|
||||
def append(self, new_candle, symbol, timeframe):
|
||||
"""
|
||||
Append a new candle to the cached data.
|
||||
|
||||
Args:
|
||||
new_candle: Dictionary containing a single OHLCV candle
|
||||
symbol: Trading pair symbol (e.g., 'ETH/USDT')
|
||||
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
|
||||
|
||||
Returns:
|
||||
Boolean indicating success
|
||||
"""
|
||||
try:
|
||||
# Load existing data
|
||||
data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for append
|
||||
|
||||
if data is None:
|
||||
data = []
|
||||
|
||||
# Check if the candle already exists (same timestamp)
|
||||
for i, candle in enumerate(data):
|
||||
if candle['timestamp'] == new_candle['timestamp']:
|
||||
# Update existing candle
|
||||
data[i] = {
|
||||
'timestamp': new_candle['timestamp'],
|
||||
'open': float(new_candle['open']),
|
||||
'high': float(new_candle['high']),
|
||||
'low': float(new_candle['low']),
|
||||
'close': float(new_candle['close']),
|
||||
'volume': float(new_candle['volume'])
|
||||
}
|
||||
# Save updated data
|
||||
return self.save(data, symbol, timeframe)
|
||||
|
||||
# Append new candle
|
||||
data.append({
|
||||
'timestamp': new_candle['timestamp'],
|
||||
'open': float(new_candle['open']),
|
||||
'high': float(new_candle['high']),
|
||||
'low': float(new_candle['low']),
|
||||
'close': float(new_candle['close']),
|
||||
'volume': float(new_candle['volume'])
|
||||
})
|
||||
|
||||
# Save updated data
|
||||
return self.save(data, symbol, timeframe)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error appending candle to cache: {e}")
|
||||
return False
|
||||
|
||||
def get_latest_timestamp(self, symbol, timeframe):
|
||||
"""
|
||||
Get the timestamp of the most recent candle in the cache.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol (e.g., 'ETH/USDT')
|
||||
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
|
||||
|
||||
Returns:
|
||||
Timestamp (milliseconds) of the most recent candle, or None if cache is empty
|
||||
"""
|
||||
data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for this check
|
||||
|
||||
if not data:
|
||||
return None
|
||||
|
||||
# Find the most recent timestamp
|
||||
latest_timestamp = max(candle['timestamp'] for candle in data)
|
||||
return latest_timestamp
|
||||
|
||||
def clear(self, symbol=None, timeframe=None):
|
||||
"""
|
||||
Clear cache for a specific symbol and timeframe, or all cache if not specified.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol (e.g., 'ETH/USDT'), or None to clear all symbols
|
||||
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h'), or None to clear all timeframes
|
||||
|
||||
Returns:
|
||||
Number of cache files deleted
|
||||
"""
|
||||
count = 0
|
||||
|
||||
try:
|
||||
if symbol and timeframe:
|
||||
# Clear specific cache
|
||||
filename = self._get_cache_filename(symbol, timeframe)
|
||||
if os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
count = 1
|
||||
|
||||
# Clear from memory cache
|
||||
cache_key = f"{symbol}_{timeframe}"
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
else:
|
||||
# Clear all matching caches
|
||||
for filename in os.listdir(self.cache_dir):
|
||||
file_path = os.path.join(self.cache_dir, filename)
|
||||
|
||||
# Skip directories
|
||||
if not os.path.isfile(file_path):
|
||||
continue
|
||||
|
||||
# Check if file matches the filter
|
||||
should_delete = True
|
||||
|
||||
if symbol:
|
||||
safe_symbol = symbol.replace('/', '_')
|
||||
if not filename.startswith(f"{safe_symbol}_"):
|
||||
should_delete = False
|
||||
|
||||
if timeframe:
|
||||
if not filename.endswith(f"_{timeframe}.json"):
|
||||
should_delete = False
|
||||
|
||||
# Delete file if it matches the filter
|
||||
if should_delete:
|
||||
os.remove(file_path)
|
||||
count += 1
|
||||
|
||||
# Clear memory cache
|
||||
keys_to_delete = []
|
||||
for cache_key in self.memory_cache:
|
||||
should_delete = True
|
||||
|
||||
if symbol:
|
||||
if not cache_key.startswith(f"{symbol}_"):
|
||||
should_delete = False
|
||||
|
||||
if timeframe:
|
||||
if not cache_key.endswith(f"_{timeframe}"):
|
||||
should_delete = False
|
||||
|
||||
if should_delete:
|
||||
keys_to_delete.append(cache_key)
|
||||
|
||||
for key in keys_to_delete:
|
||||
del self.memory_cache[key]
|
||||
|
||||
logger.info(f"Cleared {count} cache files")
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing cache: {e}")
|
||||
return 0
|
||||
|
||||
def to_dataframe(self, symbol, timeframe):
|
||||
"""
|
||||
Convert cached OHLCV data to a pandas DataFrame.
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol (e.g., 'ETH/USDT')
|
||||
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
|
||||
|
||||
Returns:
|
||||
pandas DataFrame with OHLCV data, or None if cache is missing
|
||||
"""
|
||||
data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for conversion
|
||||
|
||||
if not data:
|
||||
return None
|
||||
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# Convert timestamp to datetime
|
||||
df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
|
||||
|
||||
# Set datetime as index
|
||||
df.set_index('datetime', inplace=True)
|
||||
|
||||
return df
|
||||
|
||||
# Create a global instance for easy access
|
||||
ohlcv_cache = OHLCVCache()
|
@ -291,7 +291,7 @@ class EnhancedReplayBuffer:
|
||||
def update_priorities(self, indices, td_errors):
|
||||
for idx, td_error in zip(indices, td_errors):
|
||||
# Update priority based on TD error
|
||||
priority = abs(td_error) + 1e-5 # Small constant to ensure non-zero priority
|
||||
priority = float(abs(td_error) + 1e-5) # Small constant to ensure non-zero priority
|
||||
self.priorities[idx] = priority
|
||||
self.max_priority = max(self.max_priority, priority)
|
||||
|
||||
|
765
crypto/gogo2/enhanced_training.py
Normal file
765
crypto/gogo2/enhanced_training.py
Normal file
@ -0,0 +1,765 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.amp import GradScaler, autocast
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from datetime import datetime
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
# Import our enhanced models
|
||||
from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN, EnhancedReplayBuffer, train_price_predictor, prepare_multi_timeframe_data
|
||||
|
||||
# Constants
|
||||
TIMEFRAMES = ['1m', '15m', '1h']
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
LEARNING_RATE = 1e-4
|
||||
BATCH_SIZE = 64
|
||||
GAMMA = 0.99
|
||||
REPLAY_BUFFER_SIZE = 100000
|
||||
TARGET_UPDATE = 10
|
||||
NUM_EPISODES = 200
|
||||
MAX_STEPS_PER_EPISODE = 1000
|
||||
EPSILON_START = 1.0
|
||||
EPSILON_END = 0.01
|
||||
EPSILON_DECAY = 0.995
|
||||
SAVE_INTERVAL = 10
|
||||
CONTINUOUS_MODE = True
|
||||
CONTINUOUS_START_EPISODE = 0
|
||||
|
||||
def setup_tensorboard():
|
||||
"""Set up TensorBoard for logging training metrics"""
|
||||
current_time = datetime.now().strftime('%Y%m%d-%H%M%S')
|
||||
log_dir = os.path.join('runs', current_time)
|
||||
writer = SummaryWriter(log_dir)
|
||||
return writer
|
||||
|
||||
def save_models(price_model, dqn_model, optimizer, episode, rewards, profits, win_rates, best_reward, best_pnl, best_winrate):
|
||||
"""Save model checkpoints and clean up old ones to keep only top 5 and best PnL"""
|
||||
# Create models directory if it doesn't exist
|
||||
os.makedirs('models', exist_ok=True)
|
||||
|
||||
# Save latest models
|
||||
torch.save({
|
||||
'price_model_state_dict': price_model.state_dict(),
|
||||
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'episode': episode,
|
||||
'rewards': rewards,
|
||||
'profits': profits,
|
||||
'win_rates': win_rates
|
||||
}, 'models/enhanced_trading_agent_latest.pt')
|
||||
|
||||
# Save continuous training checkpoint
|
||||
continuous_model_path = f'models/enhanced_trading_agent_continuous_{episode}.pt'
|
||||
torch.save({
|
||||
'price_model_state_dict': price_model.state_dict(),
|
||||
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'episode': episode,
|
||||
'rewards': rewards,
|
||||
'profits': profits,
|
||||
'win_rates': win_rates
|
||||
}, continuous_model_path)
|
||||
|
||||
# Save best models
|
||||
if rewards[-1] > best_reward:
|
||||
best_reward = rewards[-1]
|
||||
torch.save({
|
||||
'price_model_state_dict': price_model.state_dict(),
|
||||
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'episode': episode,
|
||||
'rewards': rewards,
|
||||
'profits': profits,
|
||||
'win_rates': win_rates
|
||||
}, 'models/enhanced_trading_agent_best_reward.pt')
|
||||
|
||||
if profits[-1] > best_pnl:
|
||||
best_pnl = profits[-1]
|
||||
torch.save({
|
||||
'price_model_state_dict': price_model.state_dict(),
|
||||
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'episode': episode,
|
||||
'rewards': rewards,
|
||||
'profits': profits,
|
||||
'win_rates': win_rates
|
||||
}, 'models/enhanced_trading_agent_best_pnl.pt')
|
||||
|
||||
if win_rates[-1] > best_winrate:
|
||||
best_winrate = win_rates[-1]
|
||||
torch.save({
|
||||
'price_model_state_dict': price_model.state_dict(),
|
||||
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'episode': episode,
|
||||
'rewards': rewards,
|
||||
'profits': profits,
|
||||
'win_rates': win_rates
|
||||
}, 'models/enhanced_trading_agent_best_winrate.pt')
|
||||
|
||||
# Save final model at the end of training
|
||||
if episode == NUM_EPISODES - 1:
|
||||
torch.save({
|
||||
'price_model_state_dict': price_model.state_dict(),
|
||||
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'episode': episode,
|
||||
'rewards': rewards,
|
||||
'profits': profits,
|
||||
'win_rates': win_rates
|
||||
}, 'models/enhanced_trading_agent_final.pt')
|
||||
|
||||
# Clean up old models - keep only top 5 most recent and best PnL
|
||||
cleanup_model_files()
|
||||
|
||||
return best_reward, best_pnl, best_winrate
|
||||
|
||||
def cleanup_model_files():
|
||||
"""Keep only the top 5 most recent continuous models and the best models"""
|
||||
# Files we always want to keep
|
||||
essential_files = [
|
||||
'enhanced_trading_agent_latest.pt',
|
||||
'enhanced_trading_agent_best_reward.pt',
|
||||
'enhanced_trading_agent_best_pnl.pt',
|
||||
'enhanced_trading_agent_best_winrate.pt',
|
||||
'enhanced_trading_agent_final.pt'
|
||||
]
|
||||
|
||||
# Get all continuous training model files
|
||||
continuous_files = []
|
||||
for file in os.listdir('models'):
|
||||
if file.startswith('enhanced_trading_agent_continuous_') and file.endswith('.pt'):
|
||||
continuous_files.append(file)
|
||||
|
||||
# Sort continuous files by episode number (newest first)
|
||||
if continuous_files:
|
||||
try:
|
||||
continuous_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]), reverse=True)
|
||||
# Keep only the 5 most recent continuous files
|
||||
files_to_keep = essential_files + continuous_files[:5]
|
||||
except (ValueError, IndexError):
|
||||
# Handle case where filename format is unexpected
|
||||
print("Warning: Could not sort continuous files by episode number. Keeping all continuous files.")
|
||||
files_to_keep = essential_files + continuous_files
|
||||
else:
|
||||
files_to_keep = essential_files
|
||||
|
||||
# Delete all other model files
|
||||
for file in os.listdir('models'):
|
||||
if file.endswith('.pt') and file not in files_to_keep:
|
||||
try:
|
||||
os.remove(os.path.join('models', file))
|
||||
print(f"Deleted old model file: {file}")
|
||||
except Exception as e:
|
||||
print(f"Error deleting {file}: {e}")
|
||||
|
||||
def plot_training_results(rewards, profits, win_rates, episode):
|
||||
"""Plot training metrics"""
|
||||
plt.figure(figsize=(15, 15))
|
||||
|
||||
# Plot rewards
|
||||
plt.subplot(3, 1, 1)
|
||||
plt.plot(rewards)
|
||||
plt.title('Average Reward per Episode')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Reward')
|
||||
|
||||
# Plot profits
|
||||
plt.subplot(3, 1, 2)
|
||||
plt.plot(profits)
|
||||
plt.title('Profit/Loss per Episode')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('PnL ($)')
|
||||
|
||||
# Plot win rates
|
||||
plt.subplot(3, 1, 3)
|
||||
plt.plot(win_rates)
|
||||
plt.title('Win Rate per Episode')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Win Rate (%)')
|
||||
plt.ylim(0, 100)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('training_results.png')
|
||||
|
||||
# Also save episode-specific plots periodically
|
||||
if episode % 20 == 0:
|
||||
os.makedirs('visualizations', exist_ok=True)
|
||||
plt.savefig(f'visualizations/training_episode_{episode}.png')
|
||||
|
||||
plt.close()
|
||||
|
||||
def load_checkpoint(price_model, dqn_model, optimizer, episode=None):
|
||||
"""Load model checkpoint for continuous training"""
|
||||
if episode is not None:
|
||||
checkpoint_path = f'models/enhanced_trading_agent_continuous_{episode}.pt'
|
||||
else:
|
||||
checkpoint_path = 'models/enhanced_trading_agent_latest.pt'
|
||||
|
||||
if os.path.exists(checkpoint_path):
|
||||
print(f"Loading checkpoint from {checkpoint_path}")
|
||||
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
|
||||
|
||||
price_model.load_state_dict(checkpoint['price_model_state_dict'])
|
||||
dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
|
||||
start_episode = checkpoint['episode'] + 1
|
||||
rewards = checkpoint['rewards']
|
||||
profits = checkpoint['profits']
|
||||
win_rates = checkpoint['win_rates']
|
||||
|
||||
print(f"Resuming training from episode {start_episode}")
|
||||
return start_episode, rewards, profits, win_rates
|
||||
else:
|
||||
print("No checkpoint found, starting training from scratch")
|
||||
return 0, [], [], []
|
||||
|
||||
def enhanced_train_agent(exchange, num_episodes=NUM_EPISODES, continuous=CONTINUOUS_MODE, start_episode=CONTINUOUS_START_EPISODE):
|
||||
"""
|
||||
Train the enhanced trading agent using multi-timeframe data
|
||||
|
||||
Args:
|
||||
exchange: Exchange object to fetch data from
|
||||
num_episodes: Number of episodes to train for
|
||||
continuous: Whether to continue training from a checkpoint
|
||||
start_episode: Episode to start from if continuous training
|
||||
"""
|
||||
print(f"Training on device: {DEVICE}")
|
||||
|
||||
# Set up TensorBoard
|
||||
writer = setup_tensorboard()
|
||||
|
||||
# Initialize models
|
||||
state_dim = 100 # Increased state dimension for multi-timeframe features
|
||||
action_dim = 3 # Buy, Sell, Hold
|
||||
|
||||
price_model = EnhancedPricePredictionModel(
|
||||
input_dim=2, # Price and volume
|
||||
hidden_dim=256,
|
||||
num_layers=3,
|
||||
output_dim=5, # Predict next 5 candles
|
||||
num_timeframes=len(TIMEFRAMES)
|
||||
).to(DEVICE)
|
||||
|
||||
dqn_model = EnhancedDQN(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
hidden_dim=512
|
||||
).to(DEVICE)
|
||||
|
||||
target_dqn = EnhancedDQN(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
hidden_dim=512
|
||||
).to(DEVICE)
|
||||
|
||||
# Copy initial weights to target network
|
||||
target_dqn.load_state_dict(dqn_model.state_dict())
|
||||
|
||||
# Initialize optimizer
|
||||
optimizer = optim.Adam(list(price_model.parameters()) + list(dqn_model.parameters()), lr=LEARNING_RATE)
|
||||
|
||||
# Initialize replay buffer
|
||||
replay_buffer = EnhancedReplayBuffer(
|
||||
capacity=REPLAY_BUFFER_SIZE,
|
||||
alpha=0.6,
|
||||
beta=0.4,
|
||||
beta_increment=0.001,
|
||||
n_step=3,
|
||||
gamma=GAMMA
|
||||
)
|
||||
|
||||
# Initialize gradient scaler for mixed precision training
|
||||
scaler = GradScaler(enabled=(DEVICE.type == 'cuda'))
|
||||
|
||||
# Initialize tracking variables
|
||||
rewards = []
|
||||
profits = []
|
||||
win_rates = []
|
||||
best_reward = float('-inf')
|
||||
best_pnl = float('-inf')
|
||||
best_winrate = float('-inf')
|
||||
|
||||
# Load checkpoint if continuous training
|
||||
if continuous:
|
||||
start_episode, rewards, profits, win_rates = load_checkpoint(
|
||||
price_model, dqn_model, optimizer, start_episode
|
||||
)
|
||||
|
||||
# Prepare multi-timeframe data for price prediction model training
|
||||
data_loaders = prepare_multi_timeframe_data(exchange, TIMEFRAMES)
|
||||
|
||||
# Pre-train price prediction model
|
||||
print("Pre-training price prediction model...")
|
||||
train_price_predictor(price_model, data_loaders, optimizer, DEVICE, epochs=5)
|
||||
|
||||
# Main training loop
|
||||
epsilon = EPSILON_START
|
||||
|
||||
for episode in range(start_episode, num_episodes):
|
||||
print(f"Episode {episode+1}/{num_episodes}")
|
||||
|
||||
# Reset environment
|
||||
state = initialize_state(exchange, TIMEFRAMES)
|
||||
total_reward = 0
|
||||
trades = []
|
||||
wins = 0
|
||||
losses = 0
|
||||
|
||||
# Episode loop
|
||||
for step in range(MAX_STEPS_PER_EPISODE):
|
||||
# Epsilon-greedy action selection
|
||||
if np.random.random() < epsilon:
|
||||
action = np.random.randint(0, action_dim)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(DEVICE)
|
||||
q_values, _, _ = dqn_model(state_tensor)
|
||||
action = q_values.argmax().item()
|
||||
|
||||
# Execute action and get next state and reward
|
||||
next_state, reward, done, trade_info = step_environment(
|
||||
exchange, state, action, price_model, TIMEFRAMES, DEVICE
|
||||
)
|
||||
|
||||
# Store transition in replay buffer
|
||||
replay_buffer.push(
|
||||
torch.FloatTensor(state),
|
||||
action,
|
||||
reward,
|
||||
torch.FloatTensor(next_state),
|
||||
done
|
||||
)
|
||||
|
||||
# Update state and accumulate reward
|
||||
state = next_state
|
||||
total_reward += reward
|
||||
|
||||
# Track trade outcomes
|
||||
if trade_info is not None:
|
||||
trades.append(trade_info)
|
||||
if trade_info['pnl'] > 0:
|
||||
wins += 1
|
||||
elif trade_info['pnl'] < 0:
|
||||
losses += 1
|
||||
|
||||
# Learn from experiences if enough samples
|
||||
if len(replay_buffer) > BATCH_SIZE:
|
||||
learn(dqn_model, target_dqn, replay_buffer, optimizer, scaler, DEVICE)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Update target network
|
||||
if episode % TARGET_UPDATE == 0:
|
||||
target_dqn.load_state_dict(dqn_model.state_dict())
|
||||
|
||||
# Calculate episode metrics
|
||||
avg_reward = total_reward / (step + 1)
|
||||
total_pnl = sum(trade['pnl'] for trade in trades) if trades else 0
|
||||
win_rate = (wins / (wins + losses) * 100) if (wins + losses) > 0 else 0
|
||||
|
||||
# Decay epsilon
|
||||
epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
|
||||
|
||||
# Track metrics
|
||||
rewards.append(avg_reward)
|
||||
profits.append(total_pnl)
|
||||
win_rates.append(win_rate)
|
||||
|
||||
# Log to TensorBoard
|
||||
writer.add_scalar('Training/Reward', avg_reward, episode)
|
||||
writer.add_scalar('Training/Profit', total_pnl, episode)
|
||||
writer.add_scalar('Training/WinRate', win_rate, episode)
|
||||
writer.add_scalar('Training/Epsilon', epsilon, episode)
|
||||
|
||||
# Print episode summary
|
||||
print(f"Episode {episode+1} - Avg Reward: {avg_reward:.2f}, PnL: ${total_pnl:.2f}, Win Rate: {win_rate:.1f}%")
|
||||
|
||||
# Save models and plot results
|
||||
if episode % SAVE_INTERVAL == 0 or episode == num_episodes - 1:
|
||||
best_reward, best_pnl, best_winrate = save_models(
|
||||
price_model, dqn_model, optimizer, episode,
|
||||
rewards, profits, win_rates,
|
||||
best_reward, best_pnl, best_winrate
|
||||
)
|
||||
plot_training_results(rewards, profits, win_rates, episode)
|
||||
|
||||
# Close TensorBoard writer
|
||||
writer.close()
|
||||
|
||||
# Final save and plot
|
||||
best_reward, best_pnl, best_winrate = save_models(
|
||||
price_model, dqn_model, optimizer, num_episodes - 1,
|
||||
rewards, profits, win_rates,
|
||||
best_reward, best_pnl, best_winrate
|
||||
)
|
||||
plot_training_results(rewards, profits, win_rates, num_episodes - 1)
|
||||
|
||||
print("Training complete!")
|
||||
return price_model, dqn_model
|
||||
|
||||
def learn(dqn, target_dqn, replay_buffer, optimizer, scaler, device):
|
||||
"""Update the DQN model using experiences from the replay buffer"""
|
||||
# Sample from replay buffer
|
||||
states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(BATCH_SIZE)
|
||||
|
||||
# Move to device
|
||||
states = states.to(device)
|
||||
actions = actions.to(device)
|
||||
rewards = rewards.to(device)
|
||||
next_states = next_states.to(device)
|
||||
dones = dones.to(device)
|
||||
weights = weights.to(device)
|
||||
|
||||
# Get current Q values
|
||||
if device.type == 'cuda':
|
||||
with autocast(device_type='cuda', enabled=True):
|
||||
current_q_values, _, _ = dqn(states)
|
||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# Compute target Q values
|
||||
with torch.no_grad():
|
||||
next_q_values, _, _ = target_dqn(next_states)
|
||||
max_next_q_values = next_q_values.max(1)[0]
|
||||
target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values
|
||||
|
||||
# Compute loss with importance sampling weights
|
||||
td_errors = target_q_values - current_q_values
|
||||
loss = (weights * td_errors.pow(2)).mean()
|
||||
else:
|
||||
# CPU version without autocast
|
||||
current_q_values, _, _ = dqn(states)
|
||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# Compute target Q values
|
||||
with torch.no_grad():
|
||||
next_q_values, _, _ = target_dqn(next_states)
|
||||
max_next_q_values = next_q_values.max(1)[0]
|
||||
target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values
|
||||
|
||||
# Compute loss with importance sampling weights
|
||||
td_errors = target_q_values - current_q_values
|
||||
loss = (weights * td_errors.pow(2)).mean()
|
||||
|
||||
# Update priorities in replay buffer
|
||||
replay_buffer.update_priorities(indices, td_errors.abs().detach().cpu().numpy())
|
||||
|
||||
# Optimize the model with mixed precision
|
||||
optimizer.zero_grad()
|
||||
|
||||
if device.type == 'cuda':
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
# CPU version without scaler
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0)
|
||||
optimizer.step()
|
||||
|
||||
def initialize_state(exchange, timeframes):
|
||||
"""Initialize the state with data from multiple timeframes"""
|
||||
# Fetch data for each timeframe
|
||||
timeframe_data = {}
|
||||
for tf in timeframes:
|
||||
candles = exchange.fetch_ohlcv(timeframe=tf, limit=30)
|
||||
timeframe_data[tf] = candles
|
||||
|
||||
# Extract features from each timeframe
|
||||
state = []
|
||||
|
||||
for tf in timeframes:
|
||||
candles = timeframe_data[tf]
|
||||
|
||||
# Price features
|
||||
prices = [candle[4] for candle in candles[-10:]] # Last 10 close prices
|
||||
price_changes = [prices[i]/prices[i-1] - 1 for i in range(1, len(prices))]
|
||||
|
||||
# Volume features
|
||||
volumes = [candle[5] for candle in candles[-10:]] # Last 10 volumes
|
||||
volume_changes = [volumes[i]/volumes[i-1] - 1 for i in range(1, len(volumes))]
|
||||
|
||||
# Technical indicators
|
||||
# Simple Moving Averages
|
||||
sma_5 = sum(prices[-5:]) / 5
|
||||
sma_10 = sum(prices) / 10
|
||||
|
||||
# Relative Strength Index (simplified)
|
||||
gains = [max(0, price_changes[i]) for i in range(len(price_changes))]
|
||||
losses = [max(0, -price_changes[i]) for i in range(len(price_changes))]
|
||||
avg_gain = sum(gains) / len(gains)
|
||||
avg_loss = sum(losses) / len(losses)
|
||||
rs = avg_gain / (avg_loss + 1e-10) # Avoid division by zero
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
|
||||
# Add features to state
|
||||
state.extend(price_changes) # 9 features
|
||||
state.extend(volume_changes) # 9 features
|
||||
state.append(sma_5 / prices[-1] - 1) # 1 feature
|
||||
state.append(sma_10 / prices[-1] - 1) # 1 feature
|
||||
state.append(rsi / 100) # 1 feature
|
||||
|
||||
# Add market regime features
|
||||
# This is a placeholder - in a real implementation, you would use the market_regime_classifier
|
||||
# from the DQN model to predict the current market regime
|
||||
state.extend([0, 0, 0]) # 3 features for market regime (one-hot encoded)
|
||||
|
||||
# Add additional features to reach the expected dimension of 100
|
||||
# Calculate more technical indicators
|
||||
for tf in timeframes:
|
||||
candles = timeframe_data[tf]
|
||||
prices = [candle[4] for candle in candles[-20:]] # Last 20 close prices
|
||||
|
||||
# Bollinger Bands
|
||||
window = 20
|
||||
if len(prices) >= window:
|
||||
sma_20 = sum(prices[-window:]) / window
|
||||
std_dev = (sum((price - sma_20) ** 2 for price in prices[-window:]) / window) ** 0.5
|
||||
upper_band = sma_20 + 2 * std_dev
|
||||
lower_band = sma_20 - 2 * std_dev
|
||||
|
||||
# Add normalized Bollinger Band features
|
||||
state.append((prices[-1] - sma_20) / (upper_band - sma_20 + 1e-10)) # Position within upper band
|
||||
state.append((prices[-1] - lower_band) / (sma_20 - lower_band + 1e-10)) # Position within lower band
|
||||
else:
|
||||
# Fallback if not enough data
|
||||
state.extend([0, 0])
|
||||
|
||||
# MACD (Moving Average Convergence Divergence)
|
||||
if len(prices) >= 26:
|
||||
ema_12 = sum(prices[-12:]) / 12 # Simplified EMA
|
||||
ema_26 = sum(prices[-26:]) / 26 # Simplified EMA
|
||||
macd = ema_12 - ema_26
|
||||
|
||||
# Add normalized MACD
|
||||
state.append(macd / prices[-1])
|
||||
else:
|
||||
# Fallback if not enough data
|
||||
state.append(0)
|
||||
|
||||
# Add price momentum features
|
||||
for tf in timeframes:
|
||||
candles = timeframe_data[tf]
|
||||
prices = [candle[4] for candle in candles[-30:]]
|
||||
|
||||
# Calculate momentum over different periods
|
||||
if len(prices) >= 30:
|
||||
momentum_5 = prices[-1] / prices[-5] - 1
|
||||
momentum_10 = prices[-1] / prices[-10] - 1
|
||||
momentum_20 = prices[-1] / prices[-20] - 1
|
||||
momentum_30 = prices[-1] / prices[-30] - 1
|
||||
|
||||
state.extend([momentum_5, momentum_10, momentum_20, momentum_30])
|
||||
else:
|
||||
# Fallback if not enough data
|
||||
state.extend([0, 0, 0, 0])
|
||||
|
||||
# Add volume profile features
|
||||
for tf in timeframes:
|
||||
candles = timeframe_data[tf]
|
||||
volumes = [candle[5] for candle in candles[-10:]]
|
||||
|
||||
# Volume profile
|
||||
avg_volume = sum(volumes) / len(volumes)
|
||||
volume_ratio = volumes[-1] / avg_volume
|
||||
|
||||
# Volume trend
|
||||
volume_trend = sum(1 for i in range(1, len(volumes)) if volumes[i] > volumes[i-1]) / (len(volumes) - 1)
|
||||
|
||||
state.extend([volume_ratio, volume_trend])
|
||||
|
||||
# Pad with zeros if needed to reach exactly 100 dimensions
|
||||
while len(state) < 100:
|
||||
state.append(0)
|
||||
|
||||
# Ensure state has exactly 100 dimensions
|
||||
if len(state) > 100:
|
||||
state = state[:100]
|
||||
|
||||
assert len(state) == 100, f"State dimension mismatch: {len(state)} != 100"
|
||||
|
||||
return state
|
||||
|
||||
def step_environment(exchange, state, action, price_model, timeframes, device):
|
||||
"""
|
||||
Execute action in the environment and return next state, reward, done flag, and trade info
|
||||
|
||||
Args:
|
||||
exchange: Exchange object to interact with
|
||||
state: Current state
|
||||
action: Action to take (0: Hold, 1: Buy, 2: Sell)
|
||||
price_model: Price prediction model
|
||||
timeframes: List of timeframes to use
|
||||
device: Device to run models on
|
||||
|
||||
Returns:
|
||||
next_state: Next state after taking action
|
||||
reward: Reward received
|
||||
done: Whether episode is done
|
||||
trade_info: Information about the trade (if any)
|
||||
"""
|
||||
# Fetch latest data for each timeframe
|
||||
timeframe_data = {}
|
||||
for tf in timeframes:
|
||||
candles = exchange.fetch_ohlcv(timeframe=tf, limit=30)
|
||||
timeframe_data[tf] = candles
|
||||
|
||||
# Prepare inputs for price prediction model
|
||||
price_inputs = []
|
||||
for tf in timeframes:
|
||||
candles = timeframe_data[tf]
|
||||
# Extract price and volume data
|
||||
input_data = torch.tensor([
|
||||
[candle[4], candle[5]] for candle in candles[-30:] # Last 30 candles
|
||||
], dtype=torch.float32).unsqueeze(0).to(device) # Add batch dimension
|
||||
price_inputs.append(input_data)
|
||||
|
||||
# Get price and extrema predictions
|
||||
with torch.no_grad():
|
||||
price_pred, extrema_logits, volume_pred = price_model(price_inputs)
|
||||
|
||||
# Convert predictions to numpy
|
||||
price_pred = price_pred.cpu().numpy()[0] # Remove batch dimension
|
||||
extrema_probs = torch.sigmoid(extrema_logits).cpu().numpy()[0]
|
||||
volume_pred = volume_pred.cpu().numpy()[0]
|
||||
|
||||
# Execute action
|
||||
current_price = timeframe_data['1m'][-1][4] # Current close price
|
||||
trade_info = None
|
||||
reward = 0
|
||||
|
||||
if action == 1: # Buy
|
||||
# Check if we're at a predicted low point (good time to buy)
|
||||
is_predicted_low = any(extrema_probs[i*2+1] > 0.7 for i in range(5))
|
||||
|
||||
# Calculate entry quality based on predictions
|
||||
entry_quality = 0.5 # Default quality
|
||||
if is_predicted_low:
|
||||
entry_quality += 0.2 # Bonus for buying at predicted low
|
||||
|
||||
# Check volume confirmation
|
||||
volume_increasing = volume_pred[0] > timeframe_data['1m'][-1][5]
|
||||
if volume_increasing:
|
||||
entry_quality += 0.1 # Bonus for increasing volume
|
||||
|
||||
# Execute buy order
|
||||
# In a real implementation, this would interact with the exchange
|
||||
# For now, we'll simulate the trade
|
||||
trade_info = {
|
||||
'action': 'buy',
|
||||
'price': current_price,
|
||||
'size': 100 * entry_quality, # Size based on entry quality
|
||||
'entry_quality': entry_quality,
|
||||
'pnl': 0 # Will be updated later
|
||||
}
|
||||
|
||||
# Calculate reward
|
||||
# Base reward for taking action
|
||||
reward = 1
|
||||
|
||||
# Bonus for buying at predicted low
|
||||
if is_predicted_low:
|
||||
reward += 5
|
||||
print("Trading at predicted low - additional reward")
|
||||
|
||||
# Bonus for volume confirmation
|
||||
if volume_increasing:
|
||||
reward += 2
|
||||
print("Trading with high volume - additional reward")
|
||||
|
||||
elif action == 2: # Sell
|
||||
# Check if we're at a predicted high point (good time to sell)
|
||||
is_predicted_high = any(extrema_probs[i*2] > 0.7 for i in range(5))
|
||||
|
||||
# Calculate entry quality based on predictions
|
||||
entry_quality = 0.5 # Default quality
|
||||
if is_predicted_high:
|
||||
entry_quality += 0.2 # Bonus for selling at predicted high
|
||||
|
||||
# Check volume confirmation
|
||||
volume_increasing = volume_pred[0] > timeframe_data['1m'][-1][5]
|
||||
if volume_increasing:
|
||||
entry_quality += 0.1 # Bonus for increasing volume
|
||||
|
||||
# Execute sell order
|
||||
# In a real implementation, this would interact with the exchange
|
||||
# For now, we'll simulate the trade
|
||||
trade_info = {
|
||||
'action': 'sell',
|
||||
'price': current_price,
|
||||
'size': 100 * entry_quality, # Size based on entry quality
|
||||
'entry_quality': entry_quality,
|
||||
'pnl': 0 # Will be updated later
|
||||
}
|
||||
|
||||
# Calculate reward
|
||||
# Base reward for taking action
|
||||
reward = 1
|
||||
|
||||
# Bonus for selling at predicted high
|
||||
if is_predicted_high:
|
||||
reward += 5
|
||||
print("Trading at predicted high - additional reward")
|
||||
|
||||
# Bonus for volume confirmation
|
||||
if volume_increasing:
|
||||
reward += 2
|
||||
print("Trading with high volume - additional reward")
|
||||
|
||||
else: # Hold
|
||||
# Small reward for holding
|
||||
reward = 0.1
|
||||
|
||||
# Simulate trade outcome
|
||||
if trade_info is not None:
|
||||
# In a real implementation, this would be based on actual market movement
|
||||
# For now, we'll use the price prediction to simulate the outcome
|
||||
future_price = price_pred[0] # Price in the next candle
|
||||
|
||||
if trade_info['action'] == 'buy':
|
||||
# For buy, profit if price goes up
|
||||
pnl_pct = (future_price / current_price - 1) * 100
|
||||
trade_info['pnl'] = pnl_pct * trade_info['size'] / 100
|
||||
else: # sell
|
||||
# For sell, profit if price goes down
|
||||
pnl_pct = (1 - future_price / current_price) * 100
|
||||
trade_info['pnl'] = pnl_pct * trade_info['size'] / 100
|
||||
|
||||
# Adjust reward based on trade outcome
|
||||
reward += trade_info['pnl'] * 10 # Scale PnL for reward
|
||||
|
||||
# Update state
|
||||
next_state = initialize_state(exchange, timeframes)
|
||||
|
||||
# Check if episode is done
|
||||
# In a real implementation, this would be based on episode length or other criteria
|
||||
done = False
|
||||
|
||||
return next_state, reward, done, trade_info
|
||||
|
||||
# Main function to run training
|
||||
def main():
|
||||
from exchange_simulator import ExchangeSimulator
|
||||
|
||||
# Initialize exchange simulator
|
||||
exchange = ExchangeSimulator()
|
||||
|
||||
# Train agent
|
||||
price_model, dqn_model = enhanced_train_agent(
|
||||
exchange=exchange,
|
||||
num_episodes=NUM_EPISODES,
|
||||
continuous=CONTINUOUS_MODE,
|
||||
start_episode=CONTINUOUS_START_EPISODE
|
||||
)
|
||||
|
||||
print("Training complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
373
crypto/gogo2/exchange_simulator.py
Normal file
373
crypto/gogo2/exchange_simulator.py
Normal file
@ -0,0 +1,373 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
class ExchangeSimulator:
|
||||
"""
|
||||
A simple exchange simulator that generates realistic market data
|
||||
for testing trading algorithms without connecting to a real exchange.
|
||||
"""
|
||||
|
||||
def __init__(self, symbol="BTC/USDT", seed=42):
|
||||
"""
|
||||
Initialize the exchange simulator
|
||||
|
||||
Args:
|
||||
symbol: Trading pair symbol
|
||||
seed: Random seed for reproducibility
|
||||
"""
|
||||
self.symbol = symbol
|
||||
self.seed = seed
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
# Initialize data storage
|
||||
self.data = {}
|
||||
self.current_timestamp = datetime.now()
|
||||
|
||||
# Generate initial data for different timeframes
|
||||
self.timeframes = ['1m', '5m', '15m', '30m', '1h', '4h', '1d']
|
||||
self.timeframe_minutes = {
|
||||
'1m': 1,
|
||||
'5m': 5,
|
||||
'15m': 15,
|
||||
'30m': 30,
|
||||
'1h': 60,
|
||||
'4h': 240,
|
||||
'1d': 1440
|
||||
}
|
||||
|
||||
# Generate initial price around $50,000 (for BTC/USDT)
|
||||
self.base_price = 50000.0
|
||||
|
||||
# Generate data for each timeframe
|
||||
for tf in self.timeframes:
|
||||
self._generate_initial_data(tf)
|
||||
|
||||
def _generate_initial_data(self, timeframe, num_candles=1000):
|
||||
"""
|
||||
Generate initial historical data for a specific timeframe
|
||||
|
||||
Args:
|
||||
timeframe: Timeframe to generate data for
|
||||
num_candles: Number of candles to generate
|
||||
"""
|
||||
# Calculate time delta for this timeframe
|
||||
minutes = self.timeframe_minutes[timeframe]
|
||||
|
||||
# Generate timestamps
|
||||
end_time = self.current_timestamp
|
||||
timestamps = [end_time - timedelta(minutes=minutes * i) for i in range(num_candles)]
|
||||
timestamps.reverse() # Oldest first
|
||||
|
||||
# Generate price data with realistic patterns
|
||||
prices = self._generate_price_series(num_candles)
|
||||
|
||||
# Generate volume data with realistic patterns
|
||||
volumes = self._generate_volume_series(num_candles, timeframe)
|
||||
|
||||
# Create OHLCV data
|
||||
ohlcv_data = []
|
||||
for i in range(num_candles):
|
||||
# Calculate OHLC based on close price
|
||||
close = prices[i]
|
||||
high = close * (1 + np.random.uniform(0, 0.01))
|
||||
low = close * (1 - np.random.uniform(0, 0.01))
|
||||
open_price = prices[i-1] if i > 0 else close * (1 - np.random.uniform(-0.005, 0.005))
|
||||
|
||||
# Create candle
|
||||
candle = [
|
||||
int(timestamps[i].timestamp() * 1000), # Timestamp in milliseconds
|
||||
open_price, # Open
|
||||
high, # High
|
||||
low, # Low
|
||||
close, # Close
|
||||
volumes[i] # Volume
|
||||
]
|
||||
ohlcv_data.append(candle)
|
||||
|
||||
# Store data
|
||||
self.data[timeframe] = ohlcv_data
|
||||
|
||||
def _generate_price_series(self, length):
|
||||
"""
|
||||
Generate a realistic price series with trends, reversals, and volatility
|
||||
|
||||
Args:
|
||||
length: Number of prices to generate
|
||||
|
||||
Returns:
|
||||
List of prices
|
||||
"""
|
||||
# Start with base price
|
||||
prices = [self.base_price]
|
||||
|
||||
# Parameters for price generation
|
||||
trend_strength = 0.001 # Strength of trend
|
||||
volatility = 0.005 # Daily volatility
|
||||
mean_reversion = 0.001 # Mean reversion strength
|
||||
|
||||
# Generate price series
|
||||
for i in range(1, length):
|
||||
# Determine if we're in a trend
|
||||
if i % 100 == 0:
|
||||
# Change trend direction every ~100 candles
|
||||
trend_strength = -trend_strength
|
||||
|
||||
# Calculate price change
|
||||
trend = trend_strength * prices[-1]
|
||||
random_change = np.random.normal(0, volatility) * prices[-1]
|
||||
mean_reversion_change = mean_reversion * (self.base_price - prices[-1])
|
||||
|
||||
# Calculate new price
|
||||
new_price = prices[-1] + trend + random_change + mean_reversion_change
|
||||
|
||||
# Ensure price doesn't go negative
|
||||
new_price = max(new_price, prices[-1] * 0.9)
|
||||
|
||||
prices.append(new_price)
|
||||
|
||||
return prices
|
||||
|
||||
def _generate_volume_series(self, length, timeframe):
|
||||
"""
|
||||
Generate a realistic volume series with patterns
|
||||
|
||||
Args:
|
||||
length: Number of volumes to generate
|
||||
timeframe: Timeframe for volume scaling
|
||||
|
||||
Returns:
|
||||
List of volumes
|
||||
"""
|
||||
# Base volume depends on timeframe
|
||||
base_volume = {
|
||||
'1m': 10,
|
||||
'5m': 50,
|
||||
'15m': 150,
|
||||
'30m': 300,
|
||||
'1h': 600,
|
||||
'4h': 2400,
|
||||
'1d': 10000
|
||||
}[timeframe]
|
||||
|
||||
# Generate volume series
|
||||
volumes = []
|
||||
for i in range(length):
|
||||
# Volume tends to be higher at trend reversals and during volatile periods
|
||||
cycle_factor = 1 + 0.5 * np.sin(i / 20) # Cyclical pattern
|
||||
random_factor = np.random.lognormal(0, 0.5) # Random spikes
|
||||
|
||||
# Calculate volume
|
||||
volume = base_volume * cycle_factor * random_factor
|
||||
|
||||
# Add some volume spikes
|
||||
if random.random() < 0.05: # 5% chance of volume spike
|
||||
volume *= random.uniform(2, 5)
|
||||
|
||||
volumes.append(volume)
|
||||
|
||||
return volumes
|
||||
|
||||
def fetch_ohlcv(self, timeframe='1m', limit=100, since=None):
|
||||
"""
|
||||
Fetch OHLCV data for a specific timeframe
|
||||
|
||||
Args:
|
||||
timeframe: Timeframe to fetch data for
|
||||
limit: Number of candles to fetch
|
||||
since: Timestamp to fetch data since (not used in simulator)
|
||||
|
||||
Returns:
|
||||
List of OHLCV candles
|
||||
"""
|
||||
# Ensure timeframe exists
|
||||
if timeframe not in self.data:
|
||||
if timeframe in self.timeframe_minutes:
|
||||
self._generate_initial_data(timeframe)
|
||||
else:
|
||||
# Default to 1m if timeframe not supported
|
||||
timeframe = '1m'
|
||||
|
||||
# Get data
|
||||
data = self.data[timeframe]
|
||||
|
||||
# Return limited data
|
||||
return data[-limit:]
|
||||
|
||||
def update(self):
|
||||
"""
|
||||
Update the exchange data by generating a new candle for each timeframe
|
||||
"""
|
||||
# Update current timestamp
|
||||
self.current_timestamp = datetime.now()
|
||||
|
||||
# Update each timeframe
|
||||
for tf in self.timeframes:
|
||||
self._add_new_candle(tf)
|
||||
|
||||
def _add_new_candle(self, timeframe):
|
||||
"""
|
||||
Add a new candle to the specified timeframe
|
||||
|
||||
Args:
|
||||
timeframe: Timeframe to add candle to
|
||||
"""
|
||||
# Get existing data
|
||||
data = self.data[timeframe]
|
||||
|
||||
# Get last close price
|
||||
last_close = data[-1][4]
|
||||
|
||||
# Calculate time delta for this timeframe
|
||||
minutes = self.timeframe_minutes[timeframe]
|
||||
|
||||
# Calculate new timestamp
|
||||
new_timestamp = int((data[-1][0] / 1000 + minutes * 60) * 1000)
|
||||
|
||||
# Generate new price with some randomness
|
||||
price_change = np.random.normal(0, 0.002) * last_close
|
||||
new_close = last_close + price_change
|
||||
|
||||
# Calculate OHLC
|
||||
new_open = last_close
|
||||
new_high = max(new_open, new_close) * (1 + np.random.uniform(0, 0.005))
|
||||
new_low = min(new_open, new_close) * (1 - np.random.uniform(0, 0.005))
|
||||
|
||||
# Generate volume
|
||||
base_volume = data[-1][5]
|
||||
volume_change = np.random.normal(0, 0.2) * base_volume
|
||||
new_volume = max(base_volume + volume_change, base_volume * 0.5)
|
||||
|
||||
# Create new candle
|
||||
new_candle = [
|
||||
new_timestamp,
|
||||
new_open,
|
||||
new_high,
|
||||
new_low,
|
||||
new_close,
|
||||
new_volume
|
||||
]
|
||||
|
||||
# Add to data
|
||||
self.data[timeframe].append(new_candle)
|
||||
|
||||
def get_ticker(self, symbol=None):
|
||||
"""
|
||||
Get current ticker information
|
||||
|
||||
Args:
|
||||
symbol: Symbol to get ticker for (defaults to initialized symbol)
|
||||
|
||||
Returns:
|
||||
Dictionary with ticker information
|
||||
"""
|
||||
if symbol is None:
|
||||
symbol = self.symbol
|
||||
|
||||
# Get latest 1m candle
|
||||
latest_candle = self.data['1m'][-1]
|
||||
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'bid': latest_candle[4] * 0.9999, # Slightly below last price
|
||||
'ask': latest_candle[4] * 1.0001, # Slightly above last price
|
||||
'last': latest_candle[4],
|
||||
'high': latest_candle[2],
|
||||
'low': latest_candle[3],
|
||||
'volume': latest_candle[5],
|
||||
'timestamp': latest_candle[0]
|
||||
}
|
||||
|
||||
def create_order(self, symbol, type, side, amount, price=None):
|
||||
"""
|
||||
Simulate creating an order
|
||||
|
||||
Args:
|
||||
symbol: Symbol to create order for
|
||||
type: Order type (limit, market)
|
||||
side: Order side (buy, sell)
|
||||
amount: Order amount
|
||||
price: Order price (for limit orders)
|
||||
|
||||
Returns:
|
||||
Dictionary with order information
|
||||
"""
|
||||
# Get current ticker
|
||||
ticker = self.get_ticker(symbol)
|
||||
|
||||
# Determine execution price
|
||||
if type == 'market':
|
||||
if side == 'buy':
|
||||
execution_price = ticker['ask']
|
||||
else:
|
||||
execution_price = ticker['bid']
|
||||
else: # limit order
|
||||
execution_price = price
|
||||
|
||||
# Create order object
|
||||
order = {
|
||||
'id': f"order_{int(datetime.now().timestamp() * 1000)}",
|
||||
'symbol': symbol,
|
||||
'type': type,
|
||||
'side': side,
|
||||
'amount': amount,
|
||||
'price': execution_price,
|
||||
'cost': amount * execution_price,
|
||||
'filled': amount,
|
||||
'status': 'closed',
|
||||
'timestamp': int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
|
||||
return order
|
||||
|
||||
def fetch_balance(self):
|
||||
"""
|
||||
Fetch account balance (simulated)
|
||||
|
||||
Returns:
|
||||
Dictionary with balance information
|
||||
"""
|
||||
return {
|
||||
'total': {
|
||||
'USD': 10000.0,
|
||||
'BTC': 1.0
|
||||
},
|
||||
'free': {
|
||||
'USD': 5000.0,
|
||||
'BTC': 0.5
|
||||
},
|
||||
'used': {
|
||||
'USD': 5000.0,
|
||||
'BTC': 0.5
|
||||
}
|
||||
}
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Create exchange simulator
|
||||
exchange = ExchangeSimulator()
|
||||
|
||||
# Fetch some data
|
||||
ohlcv = exchange.fetch_ohlcv(timeframe='1h', limit=10)
|
||||
print("OHLCV data (1h timeframe):")
|
||||
for candle in ohlcv[-5:]:
|
||||
timestamp = datetime.fromtimestamp(candle[0] / 1000)
|
||||
print(f"{timestamp}: Open={candle[1]:.2f}, High={candle[2]:.2f}, Low={candle[3]:.2f}, Close={candle[4]:.2f}, Volume={candle[5]:.2f}")
|
||||
|
||||
# Get current ticker
|
||||
ticker = exchange.get_ticker()
|
||||
print(f"\nCurrent ticker: {ticker['last']:.2f}")
|
||||
|
||||
# Create a market buy order
|
||||
order = exchange.create_order("BTC/USDT", "market", "buy", 0.1)
|
||||
print(f"\nCreated order: {order}")
|
||||
|
||||
# Update the exchange (simulate time passing)
|
||||
exchange.update()
|
||||
|
||||
# Get updated ticker
|
||||
updated_ticker = exchange.get_ticker()
|
||||
print(f"\nUpdated ticker: {updated_ticker['last']:.2f}")
|
1605
crypto/gogo2/main.py
1605
crypto/gogo2/main.py
File diff suppressed because it is too large
Load Diff
305
crypto/gogo2/run_enhanced_training.py
Normal file
305
crypto/gogo2/run_enhanced_training.py
Normal file
@ -0,0 +1,305 @@
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from enhanced_training import enhanced_train_agent
|
||||
from exchange_simulator import ExchangeSimulator
|
||||
|
||||
def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='Enhanced Trading Bot Training')
|
||||
|
||||
parser.add_argument('--mode', type=str, default='train', choices=['train', 'continuous', 'evaluate', 'live', 'demo'],
|
||||
help='Mode to run the trading bot in')
|
||||
|
||||
parser.add_argument('--episodes', type=int, default=100,
|
||||
help='Number of episodes to train for')
|
||||
|
||||
parser.add_argument('--start-episode', type=int, default=0,
|
||||
help='Episode to start from for continuous training')
|
||||
|
||||
parser.add_argument('--device', type=str, default='auto',
|
||||
help='Device to train on (auto, cuda, cpu)')
|
||||
|
||||
parser.add_argument('--timeframes', type=str, default='1m,15m,1h',
|
||||
help='Comma-separated list of timeframes to use')
|
||||
|
||||
parser.add_argument('--refresh-data', action='store_true',
|
||||
help='Refresh data before training')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set device
|
||||
if args.device == 'auto':
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
else:
|
||||
device = torch.device(args.device)
|
||||
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Parse timeframes
|
||||
timeframes = args.timeframes.split(',')
|
||||
print(f"Using timeframes: {timeframes}")
|
||||
|
||||
# Initialize exchange simulator
|
||||
exchange = ExchangeSimulator()
|
||||
|
||||
# Run in specified mode
|
||||
if args.mode == 'train':
|
||||
# Train from scratch
|
||||
print(f"Training for {args.episodes} episodes...")
|
||||
enhanced_train_agent(
|
||||
exchange=exchange,
|
||||
num_episodes=args.episodes,
|
||||
continuous=False,
|
||||
start_episode=0
|
||||
)
|
||||
|
||||
elif args.mode == 'continuous':
|
||||
# Continue training from checkpoint
|
||||
print(f"Continuing training from episode {args.start_episode} for {args.episodes} episodes...")
|
||||
enhanced_train_agent(
|
||||
exchange=exchange,
|
||||
num_episodes=args.episodes,
|
||||
continuous=True,
|
||||
start_episode=args.start_episode
|
||||
)
|
||||
|
||||
elif args.mode == 'evaluate':
|
||||
# Evaluate the model
|
||||
print("Evaluating model...")
|
||||
evaluate_model(exchange, device)
|
||||
|
||||
elif args.mode == 'live' or args.mode == 'demo':
|
||||
# Run in live or demo mode
|
||||
is_demo = args.mode == 'demo'
|
||||
print(f"Running in {'demo' if is_demo else 'live'} mode...")
|
||||
run_live(exchange, device, is_demo=is_demo)
|
||||
|
||||
print("Done!")
|
||||
|
||||
def evaluate_model(exchange, device):
|
||||
"""
|
||||
Evaluate the trained model
|
||||
|
||||
Args:
|
||||
exchange: Exchange simulator
|
||||
device: Device to run on
|
||||
"""
|
||||
from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
# Load the best model
|
||||
model_path = 'models/enhanced_trading_agent_best_pnl.pt'
|
||||
if not os.path.exists(model_path):
|
||||
model_path = 'models/enhanced_trading_agent_latest.pt'
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print("No model found to evaluate!")
|
||||
return
|
||||
|
||||
print(f"Loading model from {model_path}")
|
||||
checkpoint = torch.load(model_path, map_location=device)
|
||||
|
||||
# Initialize models
|
||||
state_dim = 100
|
||||
action_dim = 3
|
||||
timeframes = ['1m', '15m', '1h']
|
||||
|
||||
price_model = EnhancedPricePredictionModel(
|
||||
input_dim=2,
|
||||
hidden_dim=256,
|
||||
num_layers=3,
|
||||
output_dim=5,
|
||||
num_timeframes=len(timeframes)
|
||||
).to(device)
|
||||
|
||||
dqn_model = EnhancedDQN(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
hidden_dim=512
|
||||
).to(device)
|
||||
|
||||
# Load model weights
|
||||
price_model.load_state_dict(checkpoint['price_model_state_dict'])
|
||||
dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
|
||||
|
||||
# Set models to evaluation mode
|
||||
price_model.eval()
|
||||
dqn_model.eval()
|
||||
|
||||
# Run evaluation
|
||||
num_steps = 1000
|
||||
total_reward = 0
|
||||
trades = []
|
||||
|
||||
# Initialize state
|
||||
from enhanced_training import initialize_state, step_environment
|
||||
state = initialize_state(exchange, timeframes)
|
||||
|
||||
for step in range(num_steps):
|
||||
# Select action
|
||||
with torch.no_grad():
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
||||
q_values, _, _ = dqn_model(state_tensor)
|
||||
action = q_values.argmax().item()
|
||||
|
||||
# Execute action
|
||||
next_state, reward, done, trade_info = step_environment(
|
||||
exchange, state, action, price_model, timeframes, device
|
||||
)
|
||||
|
||||
# Update state and accumulate reward
|
||||
state = next_state
|
||||
total_reward += reward
|
||||
|
||||
# Track trade
|
||||
if trade_info is not None:
|
||||
trades.append(trade_info)
|
||||
print(f"Trade: {trade_info['action']} at {trade_info['price']:.2f}, PnL: {trade_info['pnl']:.2f}")
|
||||
|
||||
# Update exchange (simulate time passing)
|
||||
if step % 10 == 0:
|
||||
exchange.update()
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Calculate metrics
|
||||
avg_reward = total_reward / num_steps
|
||||
total_pnl = sum(trade['pnl'] for trade in trades) if trades else 0
|
||||
wins = sum(1 for trade in trades if trade['pnl'] > 0)
|
||||
losses = sum(1 for trade in trades if trade['pnl'] < 0)
|
||||
win_rate = (wins / (wins + losses) * 100) if (wins + losses) > 0 else 0
|
||||
|
||||
print("\nEvaluation Results:")
|
||||
print(f"Average Reward: {avg_reward:.2f}")
|
||||
print(f"Total PnL: ${total_pnl:.2f}")
|
||||
print(f"Win Rate: {win_rate:.1f}% ({wins}/{wins+losses})")
|
||||
|
||||
def run_live(exchange, device, is_demo=True):
|
||||
"""
|
||||
Run the trading bot in live or demo mode
|
||||
|
||||
Args:
|
||||
exchange: Exchange simulator or real exchange
|
||||
device: Device to run on
|
||||
is_demo: Whether to run in demo mode (no real trades)
|
||||
"""
|
||||
from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN
|
||||
import torch
|
||||
import time
|
||||
|
||||
# Load the best model
|
||||
model_path = 'models/enhanced_trading_agent_best_pnl.pt'
|
||||
if not os.path.exists(model_path):
|
||||
model_path = 'models/enhanced_trading_agent_latest.pt'
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print("No model found to run in live mode!")
|
||||
return
|
||||
|
||||
print(f"Loading model from {model_path}")
|
||||
checkpoint = torch.load(model_path, map_location=device)
|
||||
|
||||
# Initialize models
|
||||
state_dim = 100
|
||||
action_dim = 3
|
||||
timeframes = ['1m', '15m', '1h']
|
||||
|
||||
price_model = EnhancedPricePredictionModel(
|
||||
input_dim=2,
|
||||
hidden_dim=256,
|
||||
num_layers=3,
|
||||
output_dim=5,
|
||||
num_timeframes=len(timeframes)
|
||||
).to(device)
|
||||
|
||||
dqn_model = EnhancedDQN(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
hidden_dim=512
|
||||
).to(device)
|
||||
|
||||
# Load model weights
|
||||
price_model.load_state_dict(checkpoint['price_model_state_dict'])
|
||||
dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
|
||||
|
||||
# Set models to evaluation mode
|
||||
price_model.eval()
|
||||
dqn_model.eval()
|
||||
|
||||
# Run live trading
|
||||
print(f"Running in {'demo' if is_demo else 'live'} mode...")
|
||||
print("Press Ctrl+C to stop")
|
||||
|
||||
# Initialize state
|
||||
from enhanced_training import initialize_state, step_environment
|
||||
state = initialize_state(exchange, timeframes)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Select action
|
||||
with torch.no_grad():
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
||||
q_values, _, market_regime = dqn_model(state_tensor)
|
||||
action = q_values.argmax().item()
|
||||
|
||||
# Get market regime prediction
|
||||
regime_probs = torch.softmax(market_regime, dim=1).cpu().numpy()[0]
|
||||
regime_names = ['Trending', 'Ranging', 'Volatile']
|
||||
predicted_regime = regime_names[regime_probs.argmax()]
|
||||
|
||||
# Get current price
|
||||
ticker = exchange.get_ticker()
|
||||
current_price = ticker['last']
|
||||
|
||||
# Print state
|
||||
print(f"\nCurrent price: ${current_price:.2f}")
|
||||
print(f"Predicted market regime: {predicted_regime} ({regime_probs.max()*100:.1f}% confidence)")
|
||||
|
||||
# Execute action
|
||||
next_state, reward, _, trade_info = step_environment(
|
||||
exchange, state, action, price_model, timeframes, device
|
||||
)
|
||||
|
||||
# Print action
|
||||
action_names = ['Hold', 'Buy', 'Sell']
|
||||
print(f"Action: {action_names[action]}")
|
||||
|
||||
if trade_info is not None:
|
||||
print(f"Trade: {trade_info['action']} at {trade_info['price']:.2f}, Size: {trade_info['size']:.2f}, Entry Quality: {trade_info['entry_quality']:.2f}")
|
||||
|
||||
# Execute real trade if not in demo mode
|
||||
if not is_demo:
|
||||
if trade_info['action'] == 'buy':
|
||||
order = exchange.create_order(
|
||||
symbol="BTC/USDT",
|
||||
type="market",
|
||||
side="buy",
|
||||
amount=trade_info['size'] / current_price
|
||||
)
|
||||
print(f"Executed buy order: {order}")
|
||||
else: # sell
|
||||
order = exchange.create_order(
|
||||
symbol="BTC/USDT",
|
||||
type="market",
|
||||
side="sell",
|
||||
amount=trade_info['size'] / current_price
|
||||
)
|
||||
print(f"Executed sell order: {order}")
|
||||
|
||||
# Update state
|
||||
state = next_state
|
||||
|
||||
# Update exchange (simulate time passing)
|
||||
exchange.update()
|
||||
|
||||
# Wait for next candle
|
||||
time.sleep(5) # In a real implementation, this would wait for the next candle
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopping live trading")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
185
crypto/gogo2/test_cache.py
Normal file
185
crypto/gogo2/test_cache.py
Normal file
@ -0,0 +1,185 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger('cache_test')
|
||||
|
||||
# Import our cache implementation
|
||||
from data_cache import ohlcv_cache
|
||||
|
||||
def generate_sample_data(num_candles=100):
|
||||
"""Generate sample OHLCV data for testing"""
|
||||
data = []
|
||||
base_timestamp = int(time.time() * 1000) - (num_candles * 60 * 1000) # Start from num_candles minutes ago
|
||||
|
||||
for i in range(num_candles):
|
||||
timestamp = base_timestamp + (i * 60 * 1000) # Add i minutes
|
||||
|
||||
# Generate some random-ish but realistic looking price data
|
||||
base_price = 1900.0 + (i * 0.5) # Slight uptrend
|
||||
open_price = base_price - 0.5 + (i % 3)
|
||||
close_price = base_price + 0.3 + ((i+1) % 4)
|
||||
high_price = max(open_price, close_price) + 1.0 + (i % 2)
|
||||
low_price = min(open_price, close_price) - 0.8 - (i % 2)
|
||||
volume = 10.0 + (i % 10) * 2.0
|
||||
|
||||
data.append({
|
||||
'timestamp': timestamp,
|
||||
'open': open_price,
|
||||
'high': high_price,
|
||||
'low': low_price,
|
||||
'close': close_price,
|
||||
'volume': volume
|
||||
})
|
||||
|
||||
return data
|
||||
|
||||
def test_cache_save_load():
|
||||
"""Test saving and loading data from cache"""
|
||||
logger.info("Testing cache save and load...")
|
||||
|
||||
# Generate sample data
|
||||
data = generate_sample_data(100)
|
||||
logger.info(f"Generated {len(data)} sample candles")
|
||||
|
||||
# Save to cache
|
||||
symbol = "ETH/USDT"
|
||||
timeframe = "1m"
|
||||
success = ohlcv_cache.save(data, symbol, timeframe)
|
||||
logger.info(f"Saved to cache: {success}")
|
||||
|
||||
# Load from cache
|
||||
cached_data = ohlcv_cache.load(symbol, timeframe)
|
||||
logger.info(f"Loaded {len(cached_data) if cached_data else 0} candles from cache")
|
||||
|
||||
# Verify data integrity
|
||||
if cached_data:
|
||||
first_original = data[0]
|
||||
first_cached = cached_data[0]
|
||||
logger.info(f"First original candle: {first_original}")
|
||||
logger.info(f"First cached candle: {first_cached}")
|
||||
|
||||
last_original = data[-1]
|
||||
last_cached = cached_data[-1]
|
||||
logger.info(f"Last original candle: {last_original}")
|
||||
logger.info(f"Last cached candle: {last_cached}")
|
||||
|
||||
return success and cached_data and len(cached_data) == len(data)
|
||||
|
||||
def test_cache_append():
|
||||
"""Test appending a new candle to cached data"""
|
||||
logger.info("Testing cache append...")
|
||||
|
||||
# Generate sample data
|
||||
data = generate_sample_data(100)
|
||||
|
||||
# Save to cache
|
||||
symbol = "ETH/USDT"
|
||||
timeframe = "5m"
|
||||
success = ohlcv_cache.save(data, symbol, timeframe)
|
||||
logger.info(f"Saved to cache: {success}")
|
||||
|
||||
# Generate a new candle
|
||||
last_timestamp = data[-1]['timestamp']
|
||||
new_timestamp = last_timestamp + (5 * 60 * 1000) # 5 minutes later
|
||||
new_candle = {
|
||||
'timestamp': new_timestamp,
|
||||
'open': 1950.0,
|
||||
'high': 1955.0,
|
||||
'low': 1948.0,
|
||||
'close': 1952.0,
|
||||
'volume': 15.0
|
||||
}
|
||||
|
||||
# Append to cache
|
||||
success = ohlcv_cache.append(new_candle, symbol, timeframe)
|
||||
logger.info(f"Appended to cache: {success}")
|
||||
|
||||
# Load from cache
|
||||
cached_data = ohlcv_cache.load(symbol, timeframe)
|
||||
logger.info(f"Loaded {len(cached_data) if cached_data else 0} candles from cache")
|
||||
|
||||
# Verify the new candle was appended
|
||||
if cached_data:
|
||||
last_cached = cached_data[-1]
|
||||
logger.info(f"New candle: {new_candle}")
|
||||
logger.info(f"Last cached candle: {last_cached}")
|
||||
|
||||
return success and cached_data and len(cached_data) == len(data) + 1
|
||||
|
||||
def test_cache_dataframe():
|
||||
"""Test converting cached data to a pandas DataFrame"""
|
||||
logger.info("Testing cache to DataFrame conversion...")
|
||||
|
||||
# Generate sample data
|
||||
data = generate_sample_data(100)
|
||||
|
||||
# Save to cache
|
||||
symbol = "ETH/USDT"
|
||||
timeframe = "15m"
|
||||
success = ohlcv_cache.save(data, symbol, timeframe)
|
||||
logger.info(f"Saved to cache: {success}")
|
||||
|
||||
# Convert to DataFrame
|
||||
df = ohlcv_cache.to_dataframe(symbol, timeframe)
|
||||
logger.info(f"Converted to DataFrame with {len(df) if df is not None else 0} rows")
|
||||
|
||||
# Display DataFrame info
|
||||
if df is not None:
|
||||
logger.info(f"DataFrame columns: {df.columns.tolist()}")
|
||||
logger.info(f"DataFrame index: {df.index.name}")
|
||||
logger.info(f"First row: {df.iloc[0].to_dict()}")
|
||||
logger.info(f"Last row: {df.iloc[-1].to_dict()}")
|
||||
|
||||
return success and df is not None and len(df) == len(data)
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
logger.info("Starting cache tests...")
|
||||
|
||||
# Run tests
|
||||
save_load_success = test_cache_save_load()
|
||||
append_success = test_cache_append()
|
||||
dataframe_success = test_cache_dataframe()
|
||||
|
||||
# Print results
|
||||
logger.info("Test results:")
|
||||
logger.info(f" Save/Load: {'PASS' if save_load_success else 'FAIL'}")
|
||||
logger.info(f" Append: {'PASS' if append_success else 'FAIL'}")
|
||||
logger.info(f" DataFrame: {'PASS' if dataframe_success else 'FAIL'}")
|
||||
|
||||
# Check cache directory contents
|
||||
cache_dir = ohlcv_cache.cache_dir
|
||||
logger.info(f"Cache directory: {cache_dir}")
|
||||
if os.path.exists(cache_dir):
|
||||
files = os.listdir(cache_dir)
|
||||
logger.info(f"Cache files: {files}")
|
||||
|
||||
# Print file sizes
|
||||
for file in files:
|
||||
file_path = os.path.join(cache_dir, file)
|
||||
size_kb = os.path.getsize(file_path) / 1024
|
||||
logger.info(f" {file}: {size_kb:.2f} KB")
|
||||
|
||||
# Print first few lines of each file
|
||||
with open(file_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
logger.info(f" Metadata: symbol={data.get('symbol')}, timeframe={data.get('timeframe')}, last_updated={datetime.fromtimestamp(data.get('last_updated')).strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
logger.info(f" Candles: {len(data.get('data', []))}")
|
||||
|
||||
return save_load_success and append_success and dataframe_success
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user