improvments and fixes
This commit is contained in:
parent
506458d55e
commit
2e901e18f2
@ -1,3 +1,5 @@
|
|||||||
|
pip install torch-tb-profiler
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ensure we use GPU if available to train faster. during training we need to have RL loop that looks at streaming data, and retrospective backtesting/training on predictions. sincr the start of the traing we're only loosing. implement robust penalty and analysis when closing a loosing trade and improve the reward function.
|
ensure we use GPU if available to train faster. during training we need to have RL loop that looks at streaming data, and retrospective backtesting/training on predictions. sincr the start of the traing we're only loosing. implement robust penalty and analysis when closing a loosing trade and improve the reward function.
|
||||||
|
319
crypto/gogo2/data_cache.py
Normal file
319
crypto/gogo2/data_cache.py
Normal file
@ -0,0 +1,319 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Set up logging
|
||||||
|
logger = logging.getLogger('trading_bot')
|
||||||
|
|
||||||
|
class OHLCVCache:
|
||||||
|
"""
|
||||||
|
A simple cache for OHLCV data from exchanges.
|
||||||
|
Stores data in a structured format and provides backup when exchange is unavailable.
|
||||||
|
"""
|
||||||
|
def __init__(self, cache_dir="cache", max_age_hours=24):
|
||||||
|
"""
|
||||||
|
Initialize the OHLCV cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_dir: Directory to store cache files
|
||||||
|
max_age_hours: Maximum age of cached data in hours before considered stale
|
||||||
|
"""
|
||||||
|
self.cache_dir = cache_dir
|
||||||
|
self.max_age_seconds = max_age_hours * 3600
|
||||||
|
|
||||||
|
# Create cache directory if it doesn't exist
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# In-memory cache for faster access
|
||||||
|
self.memory_cache = {}
|
||||||
|
|
||||||
|
def _get_cache_filename(self, symbol, timeframe):
|
||||||
|
"""Generate a standardized filename for the cache file"""
|
||||||
|
# Replace / with _ in symbol name (e.g., ETH/USDT -> ETH_USDT)
|
||||||
|
safe_symbol = symbol.replace('/', '_')
|
||||||
|
return os.path.join(self.cache_dir, f"{safe_symbol}_{timeframe}.json")
|
||||||
|
|
||||||
|
def save(self, data, symbol, timeframe):
|
||||||
|
"""
|
||||||
|
Save OHLCV data to cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: List of dictionaries containing OHLCV data
|
||||||
|
symbol: Trading pair symbol (e.g., 'ETH/USDT')
|
||||||
|
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
|
||||||
|
"""
|
||||||
|
if not data:
|
||||||
|
logger.warning(f"No data to cache for {symbol} ({timeframe})")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert data to a serializable format
|
||||||
|
serializable_data = []
|
||||||
|
for candle in data:
|
||||||
|
serializable_data.append({
|
||||||
|
'timestamp': candle['timestamp'],
|
||||||
|
'open': float(candle['open']),
|
||||||
|
'high': float(candle['high']),
|
||||||
|
'low': float(candle['low']),
|
||||||
|
'close': float(candle['close']),
|
||||||
|
'volume': float(candle['volume'])
|
||||||
|
})
|
||||||
|
|
||||||
|
# Create cache entry with metadata
|
||||||
|
cache_entry = {
|
||||||
|
'symbol': symbol,
|
||||||
|
'timeframe': timeframe,
|
||||||
|
'last_updated': int(time.time()),
|
||||||
|
'data': serializable_data
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save to file
|
||||||
|
filename = self._get_cache_filename(symbol, timeframe)
|
||||||
|
with open(filename, 'w') as f:
|
||||||
|
json.dump(cache_entry, f)
|
||||||
|
|
||||||
|
# Update in-memory cache
|
||||||
|
cache_key = f"{symbol}_{timeframe}"
|
||||||
|
self.memory_cache[cache_key] = cache_entry
|
||||||
|
|
||||||
|
logger.info(f"Cached {len(data)} candles for {symbol} ({timeframe})")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving data to cache: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def load(self, symbol, timeframe, max_age_override=None):
|
||||||
|
"""
|
||||||
|
Load OHLCV data from cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading pair symbol (e.g., 'ETH/USDT')
|
||||||
|
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
|
||||||
|
max_age_override: Override the default max age (in seconds)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionaries containing OHLCV data, or None if cache is missing or stale
|
||||||
|
"""
|
||||||
|
cache_key = f"{symbol}_{timeframe}"
|
||||||
|
max_age = max_age_override if max_age_override is not None else self.max_age_seconds
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check in-memory cache first
|
||||||
|
if cache_key in self.memory_cache:
|
||||||
|
cache_entry = self.memory_cache[cache_key]
|
||||||
|
|
||||||
|
# Check if cache is fresh
|
||||||
|
cache_age = int(time.time()) - cache_entry['last_updated']
|
||||||
|
if cache_age <= max_age:
|
||||||
|
logger.info(f"Using in-memory cache for {symbol} ({timeframe}), age: {cache_age//60} minutes")
|
||||||
|
return cache_entry['data']
|
||||||
|
|
||||||
|
# Check file cache
|
||||||
|
filename = self._get_cache_filename(symbol, timeframe)
|
||||||
|
if not os.path.exists(filename):
|
||||||
|
logger.info(f"No cache file found for {symbol} ({timeframe})")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Load cache file
|
||||||
|
with open(filename, 'r') as f:
|
||||||
|
cache_entry = json.load(f)
|
||||||
|
|
||||||
|
# Check if cache is fresh
|
||||||
|
cache_age = int(time.time()) - cache_entry['last_updated']
|
||||||
|
if cache_age > max_age:
|
||||||
|
logger.info(f"Cache for {symbol} ({timeframe}) is stale ({cache_age//60} minutes old)")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Update in-memory cache
|
||||||
|
self.memory_cache[cache_key] = cache_entry
|
||||||
|
|
||||||
|
logger.info(f"Loaded {len(cache_entry['data'])} candles from cache for {symbol} ({timeframe})")
|
||||||
|
return cache_entry['data']
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading data from cache: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def append(self, new_candle, symbol, timeframe):
|
||||||
|
"""
|
||||||
|
Append a new candle to the cached data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_candle: Dictionary containing a single OHLCV candle
|
||||||
|
symbol: Trading pair symbol (e.g., 'ETH/USDT')
|
||||||
|
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Boolean indicating success
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Load existing data
|
||||||
|
data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for append
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
data = []
|
||||||
|
|
||||||
|
# Check if the candle already exists (same timestamp)
|
||||||
|
for i, candle in enumerate(data):
|
||||||
|
if candle['timestamp'] == new_candle['timestamp']:
|
||||||
|
# Update existing candle
|
||||||
|
data[i] = {
|
||||||
|
'timestamp': new_candle['timestamp'],
|
||||||
|
'open': float(new_candle['open']),
|
||||||
|
'high': float(new_candle['high']),
|
||||||
|
'low': float(new_candle['low']),
|
||||||
|
'close': float(new_candle['close']),
|
||||||
|
'volume': float(new_candle['volume'])
|
||||||
|
}
|
||||||
|
# Save updated data
|
||||||
|
return self.save(data, symbol, timeframe)
|
||||||
|
|
||||||
|
# Append new candle
|
||||||
|
data.append({
|
||||||
|
'timestamp': new_candle['timestamp'],
|
||||||
|
'open': float(new_candle['open']),
|
||||||
|
'high': float(new_candle['high']),
|
||||||
|
'low': float(new_candle['low']),
|
||||||
|
'close': float(new_candle['close']),
|
||||||
|
'volume': float(new_candle['volume'])
|
||||||
|
})
|
||||||
|
|
||||||
|
# Save updated data
|
||||||
|
return self.save(data, symbol, timeframe)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error appending candle to cache: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_latest_timestamp(self, symbol, timeframe):
|
||||||
|
"""
|
||||||
|
Get the timestamp of the most recent candle in the cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading pair symbol (e.g., 'ETH/USDT')
|
||||||
|
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Timestamp (milliseconds) of the most recent candle, or None if cache is empty
|
||||||
|
"""
|
||||||
|
data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for this check
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find the most recent timestamp
|
||||||
|
latest_timestamp = max(candle['timestamp'] for candle in data)
|
||||||
|
return latest_timestamp
|
||||||
|
|
||||||
|
def clear(self, symbol=None, timeframe=None):
|
||||||
|
"""
|
||||||
|
Clear cache for a specific symbol and timeframe, or all cache if not specified.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading pair symbol (e.g., 'ETH/USDT'), or None to clear all symbols
|
||||||
|
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h'), or None to clear all timeframes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of cache files deleted
|
||||||
|
"""
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if symbol and timeframe:
|
||||||
|
# Clear specific cache
|
||||||
|
filename = self._get_cache_filename(symbol, timeframe)
|
||||||
|
if os.path.exists(filename):
|
||||||
|
os.remove(filename)
|
||||||
|
count = 1
|
||||||
|
|
||||||
|
# Clear from memory cache
|
||||||
|
cache_key = f"{symbol}_{timeframe}"
|
||||||
|
if cache_key in self.memory_cache:
|
||||||
|
del self.memory_cache[cache_key]
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Clear all matching caches
|
||||||
|
for filename in os.listdir(self.cache_dir):
|
||||||
|
file_path = os.path.join(self.cache_dir, filename)
|
||||||
|
|
||||||
|
# Skip directories
|
||||||
|
if not os.path.isfile(file_path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if file matches the filter
|
||||||
|
should_delete = True
|
||||||
|
|
||||||
|
if symbol:
|
||||||
|
safe_symbol = symbol.replace('/', '_')
|
||||||
|
if not filename.startswith(f"{safe_symbol}_"):
|
||||||
|
should_delete = False
|
||||||
|
|
||||||
|
if timeframe:
|
||||||
|
if not filename.endswith(f"_{timeframe}.json"):
|
||||||
|
should_delete = False
|
||||||
|
|
||||||
|
# Delete file if it matches the filter
|
||||||
|
if should_delete:
|
||||||
|
os.remove(file_path)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
# Clear memory cache
|
||||||
|
keys_to_delete = []
|
||||||
|
for cache_key in self.memory_cache:
|
||||||
|
should_delete = True
|
||||||
|
|
||||||
|
if symbol:
|
||||||
|
if not cache_key.startswith(f"{symbol}_"):
|
||||||
|
should_delete = False
|
||||||
|
|
||||||
|
if timeframe:
|
||||||
|
if not cache_key.endswith(f"_{timeframe}"):
|
||||||
|
should_delete = False
|
||||||
|
|
||||||
|
if should_delete:
|
||||||
|
keys_to_delete.append(cache_key)
|
||||||
|
|
||||||
|
for key in keys_to_delete:
|
||||||
|
del self.memory_cache[key]
|
||||||
|
|
||||||
|
logger.info(f"Cleared {count} cache files")
|
||||||
|
return count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error clearing cache: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def to_dataframe(self, symbol, timeframe):
|
||||||
|
"""
|
||||||
|
Convert cached OHLCV data to a pandas DataFrame.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading pair symbol (e.g., 'ETH/USDT')
|
||||||
|
timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pandas DataFrame with OHLCV data, or None if cache is missing
|
||||||
|
"""
|
||||||
|
data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for conversion
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Convert to DataFrame
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
|
||||||
|
# Convert timestamp to datetime
|
||||||
|
df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
|
||||||
|
|
||||||
|
# Set datetime as index
|
||||||
|
df.set_index('datetime', inplace=True)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
# Create a global instance for easy access
|
||||||
|
ohlcv_cache = OHLCVCache()
|
@ -291,7 +291,7 @@ class EnhancedReplayBuffer:
|
|||||||
def update_priorities(self, indices, td_errors):
|
def update_priorities(self, indices, td_errors):
|
||||||
for idx, td_error in zip(indices, td_errors):
|
for idx, td_error in zip(indices, td_errors):
|
||||||
# Update priority based on TD error
|
# Update priority based on TD error
|
||||||
priority = abs(td_error) + 1e-5 # Small constant to ensure non-zero priority
|
priority = float(abs(td_error) + 1e-5) # Small constant to ensure non-zero priority
|
||||||
self.priorities[idx] = priority
|
self.priorities[idx] = priority
|
||||||
self.max_priority = max(self.max_priority, priority)
|
self.max_priority = max(self.max_priority, priority)
|
||||||
|
|
||||||
|
765
crypto/gogo2/enhanced_training.py
Normal file
765
crypto/gogo2/enhanced_training.py
Normal file
@ -0,0 +1,765 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.amp import GradScaler, autocast
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from datetime import datetime
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
|
# Import our enhanced models
|
||||||
|
from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN, EnhancedReplayBuffer, train_price_predictor, prepare_multi_timeframe_data
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
TIMEFRAMES = ['1m', '15m', '1h']
|
||||||
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
LEARNING_RATE = 1e-4
|
||||||
|
BATCH_SIZE = 64
|
||||||
|
GAMMA = 0.99
|
||||||
|
REPLAY_BUFFER_SIZE = 100000
|
||||||
|
TARGET_UPDATE = 10
|
||||||
|
NUM_EPISODES = 200
|
||||||
|
MAX_STEPS_PER_EPISODE = 1000
|
||||||
|
EPSILON_START = 1.0
|
||||||
|
EPSILON_END = 0.01
|
||||||
|
EPSILON_DECAY = 0.995
|
||||||
|
SAVE_INTERVAL = 10
|
||||||
|
CONTINUOUS_MODE = True
|
||||||
|
CONTINUOUS_START_EPISODE = 0
|
||||||
|
|
||||||
|
def setup_tensorboard():
|
||||||
|
"""Set up TensorBoard for logging training metrics"""
|
||||||
|
current_time = datetime.now().strftime('%Y%m%d-%H%M%S')
|
||||||
|
log_dir = os.path.join('runs', current_time)
|
||||||
|
writer = SummaryWriter(log_dir)
|
||||||
|
return writer
|
||||||
|
|
||||||
|
def save_models(price_model, dqn_model, optimizer, episode, rewards, profits, win_rates, best_reward, best_pnl, best_winrate):
|
||||||
|
"""Save model checkpoints and clean up old ones to keep only top 5 and best PnL"""
|
||||||
|
# Create models directory if it doesn't exist
|
||||||
|
os.makedirs('models', exist_ok=True)
|
||||||
|
|
||||||
|
# Save latest models
|
||||||
|
torch.save({
|
||||||
|
'price_model_state_dict': price_model.state_dict(),
|
||||||
|
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'episode': episode,
|
||||||
|
'rewards': rewards,
|
||||||
|
'profits': profits,
|
||||||
|
'win_rates': win_rates
|
||||||
|
}, 'models/enhanced_trading_agent_latest.pt')
|
||||||
|
|
||||||
|
# Save continuous training checkpoint
|
||||||
|
continuous_model_path = f'models/enhanced_trading_agent_continuous_{episode}.pt'
|
||||||
|
torch.save({
|
||||||
|
'price_model_state_dict': price_model.state_dict(),
|
||||||
|
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'episode': episode,
|
||||||
|
'rewards': rewards,
|
||||||
|
'profits': profits,
|
||||||
|
'win_rates': win_rates
|
||||||
|
}, continuous_model_path)
|
||||||
|
|
||||||
|
# Save best models
|
||||||
|
if rewards[-1] > best_reward:
|
||||||
|
best_reward = rewards[-1]
|
||||||
|
torch.save({
|
||||||
|
'price_model_state_dict': price_model.state_dict(),
|
||||||
|
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'episode': episode,
|
||||||
|
'rewards': rewards,
|
||||||
|
'profits': profits,
|
||||||
|
'win_rates': win_rates
|
||||||
|
}, 'models/enhanced_trading_agent_best_reward.pt')
|
||||||
|
|
||||||
|
if profits[-1] > best_pnl:
|
||||||
|
best_pnl = profits[-1]
|
||||||
|
torch.save({
|
||||||
|
'price_model_state_dict': price_model.state_dict(),
|
||||||
|
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'episode': episode,
|
||||||
|
'rewards': rewards,
|
||||||
|
'profits': profits,
|
||||||
|
'win_rates': win_rates
|
||||||
|
}, 'models/enhanced_trading_agent_best_pnl.pt')
|
||||||
|
|
||||||
|
if win_rates[-1] > best_winrate:
|
||||||
|
best_winrate = win_rates[-1]
|
||||||
|
torch.save({
|
||||||
|
'price_model_state_dict': price_model.state_dict(),
|
||||||
|
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'episode': episode,
|
||||||
|
'rewards': rewards,
|
||||||
|
'profits': profits,
|
||||||
|
'win_rates': win_rates
|
||||||
|
}, 'models/enhanced_trading_agent_best_winrate.pt')
|
||||||
|
|
||||||
|
# Save final model at the end of training
|
||||||
|
if episode == NUM_EPISODES - 1:
|
||||||
|
torch.save({
|
||||||
|
'price_model_state_dict': price_model.state_dict(),
|
||||||
|
'dqn_model_state_dict': dqn_model.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'episode': episode,
|
||||||
|
'rewards': rewards,
|
||||||
|
'profits': profits,
|
||||||
|
'win_rates': win_rates
|
||||||
|
}, 'models/enhanced_trading_agent_final.pt')
|
||||||
|
|
||||||
|
# Clean up old models - keep only top 5 most recent and best PnL
|
||||||
|
cleanup_model_files()
|
||||||
|
|
||||||
|
return best_reward, best_pnl, best_winrate
|
||||||
|
|
||||||
|
def cleanup_model_files():
|
||||||
|
"""Keep only the top 5 most recent continuous models and the best models"""
|
||||||
|
# Files we always want to keep
|
||||||
|
essential_files = [
|
||||||
|
'enhanced_trading_agent_latest.pt',
|
||||||
|
'enhanced_trading_agent_best_reward.pt',
|
||||||
|
'enhanced_trading_agent_best_pnl.pt',
|
||||||
|
'enhanced_trading_agent_best_winrate.pt',
|
||||||
|
'enhanced_trading_agent_final.pt'
|
||||||
|
]
|
||||||
|
|
||||||
|
# Get all continuous training model files
|
||||||
|
continuous_files = []
|
||||||
|
for file in os.listdir('models'):
|
||||||
|
if file.startswith('enhanced_trading_agent_continuous_') and file.endswith('.pt'):
|
||||||
|
continuous_files.append(file)
|
||||||
|
|
||||||
|
# Sort continuous files by episode number (newest first)
|
||||||
|
if continuous_files:
|
||||||
|
try:
|
||||||
|
continuous_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]), reverse=True)
|
||||||
|
# Keep only the 5 most recent continuous files
|
||||||
|
files_to_keep = essential_files + continuous_files[:5]
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
# Handle case where filename format is unexpected
|
||||||
|
print("Warning: Could not sort continuous files by episode number. Keeping all continuous files.")
|
||||||
|
files_to_keep = essential_files + continuous_files
|
||||||
|
else:
|
||||||
|
files_to_keep = essential_files
|
||||||
|
|
||||||
|
# Delete all other model files
|
||||||
|
for file in os.listdir('models'):
|
||||||
|
if file.endswith('.pt') and file not in files_to_keep:
|
||||||
|
try:
|
||||||
|
os.remove(os.path.join('models', file))
|
||||||
|
print(f"Deleted old model file: {file}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error deleting {file}: {e}")
|
||||||
|
|
||||||
|
def plot_training_results(rewards, profits, win_rates, episode):
|
||||||
|
"""Plot training metrics"""
|
||||||
|
plt.figure(figsize=(15, 15))
|
||||||
|
|
||||||
|
# Plot rewards
|
||||||
|
plt.subplot(3, 1, 1)
|
||||||
|
plt.plot(rewards)
|
||||||
|
plt.title('Average Reward per Episode')
|
||||||
|
plt.xlabel('Episode')
|
||||||
|
plt.ylabel('Reward')
|
||||||
|
|
||||||
|
# Plot profits
|
||||||
|
plt.subplot(3, 1, 2)
|
||||||
|
plt.plot(profits)
|
||||||
|
plt.title('Profit/Loss per Episode')
|
||||||
|
plt.xlabel('Episode')
|
||||||
|
plt.ylabel('PnL ($)')
|
||||||
|
|
||||||
|
# Plot win rates
|
||||||
|
plt.subplot(3, 1, 3)
|
||||||
|
plt.plot(win_rates)
|
||||||
|
plt.title('Win Rate per Episode')
|
||||||
|
plt.xlabel('Episode')
|
||||||
|
plt.ylabel('Win Rate (%)')
|
||||||
|
plt.ylim(0, 100)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig('training_results.png')
|
||||||
|
|
||||||
|
# Also save episode-specific plots periodically
|
||||||
|
if episode % 20 == 0:
|
||||||
|
os.makedirs('visualizations', exist_ok=True)
|
||||||
|
plt.savefig(f'visualizations/training_episode_{episode}.png')
|
||||||
|
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
def load_checkpoint(price_model, dqn_model, optimizer, episode=None):
|
||||||
|
"""Load model checkpoint for continuous training"""
|
||||||
|
if episode is not None:
|
||||||
|
checkpoint_path = f'models/enhanced_trading_agent_continuous_{episode}.pt'
|
||||||
|
else:
|
||||||
|
checkpoint_path = 'models/enhanced_trading_agent_latest.pt'
|
||||||
|
|
||||||
|
if os.path.exists(checkpoint_path):
|
||||||
|
print(f"Loading checkpoint from {checkpoint_path}")
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
|
||||||
|
|
||||||
|
price_model.load_state_dict(checkpoint['price_model_state_dict'])
|
||||||
|
dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
|
||||||
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
|
|
||||||
|
start_episode = checkpoint['episode'] + 1
|
||||||
|
rewards = checkpoint['rewards']
|
||||||
|
profits = checkpoint['profits']
|
||||||
|
win_rates = checkpoint['win_rates']
|
||||||
|
|
||||||
|
print(f"Resuming training from episode {start_episode}")
|
||||||
|
return start_episode, rewards, profits, win_rates
|
||||||
|
else:
|
||||||
|
print("No checkpoint found, starting training from scratch")
|
||||||
|
return 0, [], [], []
|
||||||
|
|
||||||
|
def enhanced_train_agent(exchange, num_episodes=NUM_EPISODES, continuous=CONTINUOUS_MODE, start_episode=CONTINUOUS_START_EPISODE):
|
||||||
|
"""
|
||||||
|
Train the enhanced trading agent using multi-timeframe data
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exchange: Exchange object to fetch data from
|
||||||
|
num_episodes: Number of episodes to train for
|
||||||
|
continuous: Whether to continue training from a checkpoint
|
||||||
|
start_episode: Episode to start from if continuous training
|
||||||
|
"""
|
||||||
|
print(f"Training on device: {DEVICE}")
|
||||||
|
|
||||||
|
# Set up TensorBoard
|
||||||
|
writer = setup_tensorboard()
|
||||||
|
|
||||||
|
# Initialize models
|
||||||
|
state_dim = 100 # Increased state dimension for multi-timeframe features
|
||||||
|
action_dim = 3 # Buy, Sell, Hold
|
||||||
|
|
||||||
|
price_model = EnhancedPricePredictionModel(
|
||||||
|
input_dim=2, # Price and volume
|
||||||
|
hidden_dim=256,
|
||||||
|
num_layers=3,
|
||||||
|
output_dim=5, # Predict next 5 candles
|
||||||
|
num_timeframes=len(TIMEFRAMES)
|
||||||
|
).to(DEVICE)
|
||||||
|
|
||||||
|
dqn_model = EnhancedDQN(
|
||||||
|
state_dim=state_dim,
|
||||||
|
action_dim=action_dim,
|
||||||
|
hidden_dim=512
|
||||||
|
).to(DEVICE)
|
||||||
|
|
||||||
|
target_dqn = EnhancedDQN(
|
||||||
|
state_dim=state_dim,
|
||||||
|
action_dim=action_dim,
|
||||||
|
hidden_dim=512
|
||||||
|
).to(DEVICE)
|
||||||
|
|
||||||
|
# Copy initial weights to target network
|
||||||
|
target_dqn.load_state_dict(dqn_model.state_dict())
|
||||||
|
|
||||||
|
# Initialize optimizer
|
||||||
|
optimizer = optim.Adam(list(price_model.parameters()) + list(dqn_model.parameters()), lr=LEARNING_RATE)
|
||||||
|
|
||||||
|
# Initialize replay buffer
|
||||||
|
replay_buffer = EnhancedReplayBuffer(
|
||||||
|
capacity=REPLAY_BUFFER_SIZE,
|
||||||
|
alpha=0.6,
|
||||||
|
beta=0.4,
|
||||||
|
beta_increment=0.001,
|
||||||
|
n_step=3,
|
||||||
|
gamma=GAMMA
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize gradient scaler for mixed precision training
|
||||||
|
scaler = GradScaler(enabled=(DEVICE.type == 'cuda'))
|
||||||
|
|
||||||
|
# Initialize tracking variables
|
||||||
|
rewards = []
|
||||||
|
profits = []
|
||||||
|
win_rates = []
|
||||||
|
best_reward = float('-inf')
|
||||||
|
best_pnl = float('-inf')
|
||||||
|
best_winrate = float('-inf')
|
||||||
|
|
||||||
|
# Load checkpoint if continuous training
|
||||||
|
if continuous:
|
||||||
|
start_episode, rewards, profits, win_rates = load_checkpoint(
|
||||||
|
price_model, dqn_model, optimizer, start_episode
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare multi-timeframe data for price prediction model training
|
||||||
|
data_loaders = prepare_multi_timeframe_data(exchange, TIMEFRAMES)
|
||||||
|
|
||||||
|
# Pre-train price prediction model
|
||||||
|
print("Pre-training price prediction model...")
|
||||||
|
train_price_predictor(price_model, data_loaders, optimizer, DEVICE, epochs=5)
|
||||||
|
|
||||||
|
# Main training loop
|
||||||
|
epsilon = EPSILON_START
|
||||||
|
|
||||||
|
for episode in range(start_episode, num_episodes):
|
||||||
|
print(f"Episode {episode+1}/{num_episodes}")
|
||||||
|
|
||||||
|
# Reset environment
|
||||||
|
state = initialize_state(exchange, TIMEFRAMES)
|
||||||
|
total_reward = 0
|
||||||
|
trades = []
|
||||||
|
wins = 0
|
||||||
|
losses = 0
|
||||||
|
|
||||||
|
# Episode loop
|
||||||
|
for step in range(MAX_STEPS_PER_EPISODE):
|
||||||
|
# Epsilon-greedy action selection
|
||||||
|
if np.random.random() < epsilon:
|
||||||
|
action = np.random.randint(0, action_dim)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(DEVICE)
|
||||||
|
q_values, _, _ = dqn_model(state_tensor)
|
||||||
|
action = q_values.argmax().item()
|
||||||
|
|
||||||
|
# Execute action and get next state and reward
|
||||||
|
next_state, reward, done, trade_info = step_environment(
|
||||||
|
exchange, state, action, price_model, TIMEFRAMES, DEVICE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store transition in replay buffer
|
||||||
|
replay_buffer.push(
|
||||||
|
torch.FloatTensor(state),
|
||||||
|
action,
|
||||||
|
reward,
|
||||||
|
torch.FloatTensor(next_state),
|
||||||
|
done
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update state and accumulate reward
|
||||||
|
state = next_state
|
||||||
|
total_reward += reward
|
||||||
|
|
||||||
|
# Track trade outcomes
|
||||||
|
if trade_info is not None:
|
||||||
|
trades.append(trade_info)
|
||||||
|
if trade_info['pnl'] > 0:
|
||||||
|
wins += 1
|
||||||
|
elif trade_info['pnl'] < 0:
|
||||||
|
losses += 1
|
||||||
|
|
||||||
|
# Learn from experiences if enough samples
|
||||||
|
if len(replay_buffer) > BATCH_SIZE:
|
||||||
|
learn(dqn_model, target_dqn, replay_buffer, optimizer, scaler, DEVICE)
|
||||||
|
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Update target network
|
||||||
|
if episode % TARGET_UPDATE == 0:
|
||||||
|
target_dqn.load_state_dict(dqn_model.state_dict())
|
||||||
|
|
||||||
|
# Calculate episode metrics
|
||||||
|
avg_reward = total_reward / (step + 1)
|
||||||
|
total_pnl = sum(trade['pnl'] for trade in trades) if trades else 0
|
||||||
|
win_rate = (wins / (wins + losses) * 100) if (wins + losses) > 0 else 0
|
||||||
|
|
||||||
|
# Decay epsilon
|
||||||
|
epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
|
||||||
|
|
||||||
|
# Track metrics
|
||||||
|
rewards.append(avg_reward)
|
||||||
|
profits.append(total_pnl)
|
||||||
|
win_rates.append(win_rate)
|
||||||
|
|
||||||
|
# Log to TensorBoard
|
||||||
|
writer.add_scalar('Training/Reward', avg_reward, episode)
|
||||||
|
writer.add_scalar('Training/Profit', total_pnl, episode)
|
||||||
|
writer.add_scalar('Training/WinRate', win_rate, episode)
|
||||||
|
writer.add_scalar('Training/Epsilon', epsilon, episode)
|
||||||
|
|
||||||
|
# Print episode summary
|
||||||
|
print(f"Episode {episode+1} - Avg Reward: {avg_reward:.2f}, PnL: ${total_pnl:.2f}, Win Rate: {win_rate:.1f}%")
|
||||||
|
|
||||||
|
# Save models and plot results
|
||||||
|
if episode % SAVE_INTERVAL == 0 or episode == num_episodes - 1:
|
||||||
|
best_reward, best_pnl, best_winrate = save_models(
|
||||||
|
price_model, dqn_model, optimizer, episode,
|
||||||
|
rewards, profits, win_rates,
|
||||||
|
best_reward, best_pnl, best_winrate
|
||||||
|
)
|
||||||
|
plot_training_results(rewards, profits, win_rates, episode)
|
||||||
|
|
||||||
|
# Close TensorBoard writer
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
# Final save and plot
|
||||||
|
best_reward, best_pnl, best_winrate = save_models(
|
||||||
|
price_model, dqn_model, optimizer, num_episodes - 1,
|
||||||
|
rewards, profits, win_rates,
|
||||||
|
best_reward, best_pnl, best_winrate
|
||||||
|
)
|
||||||
|
plot_training_results(rewards, profits, win_rates, num_episodes - 1)
|
||||||
|
|
||||||
|
print("Training complete!")
|
||||||
|
return price_model, dqn_model
|
||||||
|
|
||||||
|
def learn(dqn, target_dqn, replay_buffer, optimizer, scaler, device):
|
||||||
|
"""Update the DQN model using experiences from the replay buffer"""
|
||||||
|
# Sample from replay buffer
|
||||||
|
states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(BATCH_SIZE)
|
||||||
|
|
||||||
|
# Move to device
|
||||||
|
states = states.to(device)
|
||||||
|
actions = actions.to(device)
|
||||||
|
rewards = rewards.to(device)
|
||||||
|
next_states = next_states.to(device)
|
||||||
|
dones = dones.to(device)
|
||||||
|
weights = weights.to(device)
|
||||||
|
|
||||||
|
# Get current Q values
|
||||||
|
if device.type == 'cuda':
|
||||||
|
with autocast(device_type='cuda', enabled=True):
|
||||||
|
current_q_values, _, _ = dqn(states)
|
||||||
|
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||||
|
|
||||||
|
# Compute target Q values
|
||||||
|
with torch.no_grad():
|
||||||
|
next_q_values, _, _ = target_dqn(next_states)
|
||||||
|
max_next_q_values = next_q_values.max(1)[0]
|
||||||
|
target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values
|
||||||
|
|
||||||
|
# Compute loss with importance sampling weights
|
||||||
|
td_errors = target_q_values - current_q_values
|
||||||
|
loss = (weights * td_errors.pow(2)).mean()
|
||||||
|
else:
|
||||||
|
# CPU version without autocast
|
||||||
|
current_q_values, _, _ = dqn(states)
|
||||||
|
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||||
|
|
||||||
|
# Compute target Q values
|
||||||
|
with torch.no_grad():
|
||||||
|
next_q_values, _, _ = target_dqn(next_states)
|
||||||
|
max_next_q_values = next_q_values.max(1)[0]
|
||||||
|
target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values
|
||||||
|
|
||||||
|
# Compute loss with importance sampling weights
|
||||||
|
td_errors = target_q_values - current_q_values
|
||||||
|
loss = (weights * td_errors.pow(2)).mean()
|
||||||
|
|
||||||
|
# Update priorities in replay buffer
|
||||||
|
replay_buffer.update_priorities(indices, td_errors.abs().detach().cpu().numpy())
|
||||||
|
|
||||||
|
# Optimize the model with mixed precision
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if device.type == 'cuda':
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0)
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
else:
|
||||||
|
# CPU version without scaler
|
||||||
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0)
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
def initialize_state(exchange, timeframes):
|
||||||
|
"""Initialize the state with data from multiple timeframes"""
|
||||||
|
# Fetch data for each timeframe
|
||||||
|
timeframe_data = {}
|
||||||
|
for tf in timeframes:
|
||||||
|
candles = exchange.fetch_ohlcv(timeframe=tf, limit=30)
|
||||||
|
timeframe_data[tf] = candles
|
||||||
|
|
||||||
|
# Extract features from each timeframe
|
||||||
|
state = []
|
||||||
|
|
||||||
|
for tf in timeframes:
|
||||||
|
candles = timeframe_data[tf]
|
||||||
|
|
||||||
|
# Price features
|
||||||
|
prices = [candle[4] for candle in candles[-10:]] # Last 10 close prices
|
||||||
|
price_changes = [prices[i]/prices[i-1] - 1 for i in range(1, len(prices))]
|
||||||
|
|
||||||
|
# Volume features
|
||||||
|
volumes = [candle[5] for candle in candles[-10:]] # Last 10 volumes
|
||||||
|
volume_changes = [volumes[i]/volumes[i-1] - 1 for i in range(1, len(volumes))]
|
||||||
|
|
||||||
|
# Technical indicators
|
||||||
|
# Simple Moving Averages
|
||||||
|
sma_5 = sum(prices[-5:]) / 5
|
||||||
|
sma_10 = sum(prices) / 10
|
||||||
|
|
||||||
|
# Relative Strength Index (simplified)
|
||||||
|
gains = [max(0, price_changes[i]) for i in range(len(price_changes))]
|
||||||
|
losses = [max(0, -price_changes[i]) for i in range(len(price_changes))]
|
||||||
|
avg_gain = sum(gains) / len(gains)
|
||||||
|
avg_loss = sum(losses) / len(losses)
|
||||||
|
rs = avg_gain / (avg_loss + 1e-10) # Avoid division by zero
|
||||||
|
rsi = 100 - (100 / (1 + rs))
|
||||||
|
|
||||||
|
# Add features to state
|
||||||
|
state.extend(price_changes) # 9 features
|
||||||
|
state.extend(volume_changes) # 9 features
|
||||||
|
state.append(sma_5 / prices[-1] - 1) # 1 feature
|
||||||
|
state.append(sma_10 / prices[-1] - 1) # 1 feature
|
||||||
|
state.append(rsi / 100) # 1 feature
|
||||||
|
|
||||||
|
# Add market regime features
|
||||||
|
# This is a placeholder - in a real implementation, you would use the market_regime_classifier
|
||||||
|
# from the DQN model to predict the current market regime
|
||||||
|
state.extend([0, 0, 0]) # 3 features for market regime (one-hot encoded)
|
||||||
|
|
||||||
|
# Add additional features to reach the expected dimension of 100
|
||||||
|
# Calculate more technical indicators
|
||||||
|
for tf in timeframes:
|
||||||
|
candles = timeframe_data[tf]
|
||||||
|
prices = [candle[4] for candle in candles[-20:]] # Last 20 close prices
|
||||||
|
|
||||||
|
# Bollinger Bands
|
||||||
|
window = 20
|
||||||
|
if len(prices) >= window:
|
||||||
|
sma_20 = sum(prices[-window:]) / window
|
||||||
|
std_dev = (sum((price - sma_20) ** 2 for price in prices[-window:]) / window) ** 0.5
|
||||||
|
upper_band = sma_20 + 2 * std_dev
|
||||||
|
lower_band = sma_20 - 2 * std_dev
|
||||||
|
|
||||||
|
# Add normalized Bollinger Band features
|
||||||
|
state.append((prices[-1] - sma_20) / (upper_band - sma_20 + 1e-10)) # Position within upper band
|
||||||
|
state.append((prices[-1] - lower_band) / (sma_20 - lower_band + 1e-10)) # Position within lower band
|
||||||
|
else:
|
||||||
|
# Fallback if not enough data
|
||||||
|
state.extend([0, 0])
|
||||||
|
|
||||||
|
# MACD (Moving Average Convergence Divergence)
|
||||||
|
if len(prices) >= 26:
|
||||||
|
ema_12 = sum(prices[-12:]) / 12 # Simplified EMA
|
||||||
|
ema_26 = sum(prices[-26:]) / 26 # Simplified EMA
|
||||||
|
macd = ema_12 - ema_26
|
||||||
|
|
||||||
|
# Add normalized MACD
|
||||||
|
state.append(macd / prices[-1])
|
||||||
|
else:
|
||||||
|
# Fallback if not enough data
|
||||||
|
state.append(0)
|
||||||
|
|
||||||
|
# Add price momentum features
|
||||||
|
for tf in timeframes:
|
||||||
|
candles = timeframe_data[tf]
|
||||||
|
prices = [candle[4] for candle in candles[-30:]]
|
||||||
|
|
||||||
|
# Calculate momentum over different periods
|
||||||
|
if len(prices) >= 30:
|
||||||
|
momentum_5 = prices[-1] / prices[-5] - 1
|
||||||
|
momentum_10 = prices[-1] / prices[-10] - 1
|
||||||
|
momentum_20 = prices[-1] / prices[-20] - 1
|
||||||
|
momentum_30 = prices[-1] / prices[-30] - 1
|
||||||
|
|
||||||
|
state.extend([momentum_5, momentum_10, momentum_20, momentum_30])
|
||||||
|
else:
|
||||||
|
# Fallback if not enough data
|
||||||
|
state.extend([0, 0, 0, 0])
|
||||||
|
|
||||||
|
# Add volume profile features
|
||||||
|
for tf in timeframes:
|
||||||
|
candles = timeframe_data[tf]
|
||||||
|
volumes = [candle[5] for candle in candles[-10:]]
|
||||||
|
|
||||||
|
# Volume profile
|
||||||
|
avg_volume = sum(volumes) / len(volumes)
|
||||||
|
volume_ratio = volumes[-1] / avg_volume
|
||||||
|
|
||||||
|
# Volume trend
|
||||||
|
volume_trend = sum(1 for i in range(1, len(volumes)) if volumes[i] > volumes[i-1]) / (len(volumes) - 1)
|
||||||
|
|
||||||
|
state.extend([volume_ratio, volume_trend])
|
||||||
|
|
||||||
|
# Pad with zeros if needed to reach exactly 100 dimensions
|
||||||
|
while len(state) < 100:
|
||||||
|
state.append(0)
|
||||||
|
|
||||||
|
# Ensure state has exactly 100 dimensions
|
||||||
|
if len(state) > 100:
|
||||||
|
state = state[:100]
|
||||||
|
|
||||||
|
assert len(state) == 100, f"State dimension mismatch: {len(state)} != 100"
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
||||||
|
def step_environment(exchange, state, action, price_model, timeframes, device):
|
||||||
|
"""
|
||||||
|
Execute action in the environment and return next state, reward, done flag, and trade info
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exchange: Exchange object to interact with
|
||||||
|
state: Current state
|
||||||
|
action: Action to take (0: Hold, 1: Buy, 2: Sell)
|
||||||
|
price_model: Price prediction model
|
||||||
|
timeframes: List of timeframes to use
|
||||||
|
device: Device to run models on
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
next_state: Next state after taking action
|
||||||
|
reward: Reward received
|
||||||
|
done: Whether episode is done
|
||||||
|
trade_info: Information about the trade (if any)
|
||||||
|
"""
|
||||||
|
# Fetch latest data for each timeframe
|
||||||
|
timeframe_data = {}
|
||||||
|
for tf in timeframes:
|
||||||
|
candles = exchange.fetch_ohlcv(timeframe=tf, limit=30)
|
||||||
|
timeframe_data[tf] = candles
|
||||||
|
|
||||||
|
# Prepare inputs for price prediction model
|
||||||
|
price_inputs = []
|
||||||
|
for tf in timeframes:
|
||||||
|
candles = timeframe_data[tf]
|
||||||
|
# Extract price and volume data
|
||||||
|
input_data = torch.tensor([
|
||||||
|
[candle[4], candle[5]] for candle in candles[-30:] # Last 30 candles
|
||||||
|
], dtype=torch.float32).unsqueeze(0).to(device) # Add batch dimension
|
||||||
|
price_inputs.append(input_data)
|
||||||
|
|
||||||
|
# Get price and extrema predictions
|
||||||
|
with torch.no_grad():
|
||||||
|
price_pred, extrema_logits, volume_pred = price_model(price_inputs)
|
||||||
|
|
||||||
|
# Convert predictions to numpy
|
||||||
|
price_pred = price_pred.cpu().numpy()[0] # Remove batch dimension
|
||||||
|
extrema_probs = torch.sigmoid(extrema_logits).cpu().numpy()[0]
|
||||||
|
volume_pred = volume_pred.cpu().numpy()[0]
|
||||||
|
|
||||||
|
# Execute action
|
||||||
|
current_price = timeframe_data['1m'][-1][4] # Current close price
|
||||||
|
trade_info = None
|
||||||
|
reward = 0
|
||||||
|
|
||||||
|
if action == 1: # Buy
|
||||||
|
# Check if we're at a predicted low point (good time to buy)
|
||||||
|
is_predicted_low = any(extrema_probs[i*2+1] > 0.7 for i in range(5))
|
||||||
|
|
||||||
|
# Calculate entry quality based on predictions
|
||||||
|
entry_quality = 0.5 # Default quality
|
||||||
|
if is_predicted_low:
|
||||||
|
entry_quality += 0.2 # Bonus for buying at predicted low
|
||||||
|
|
||||||
|
# Check volume confirmation
|
||||||
|
volume_increasing = volume_pred[0] > timeframe_data['1m'][-1][5]
|
||||||
|
if volume_increasing:
|
||||||
|
entry_quality += 0.1 # Bonus for increasing volume
|
||||||
|
|
||||||
|
# Execute buy order
|
||||||
|
# In a real implementation, this would interact with the exchange
|
||||||
|
# For now, we'll simulate the trade
|
||||||
|
trade_info = {
|
||||||
|
'action': 'buy',
|
||||||
|
'price': current_price,
|
||||||
|
'size': 100 * entry_quality, # Size based on entry quality
|
||||||
|
'entry_quality': entry_quality,
|
||||||
|
'pnl': 0 # Will be updated later
|
||||||
|
}
|
||||||
|
|
||||||
|
# Calculate reward
|
||||||
|
# Base reward for taking action
|
||||||
|
reward = 1
|
||||||
|
|
||||||
|
# Bonus for buying at predicted low
|
||||||
|
if is_predicted_low:
|
||||||
|
reward += 5
|
||||||
|
print("Trading at predicted low - additional reward")
|
||||||
|
|
||||||
|
# Bonus for volume confirmation
|
||||||
|
if volume_increasing:
|
||||||
|
reward += 2
|
||||||
|
print("Trading with high volume - additional reward")
|
||||||
|
|
||||||
|
elif action == 2: # Sell
|
||||||
|
# Check if we're at a predicted high point (good time to sell)
|
||||||
|
is_predicted_high = any(extrema_probs[i*2] > 0.7 for i in range(5))
|
||||||
|
|
||||||
|
# Calculate entry quality based on predictions
|
||||||
|
entry_quality = 0.5 # Default quality
|
||||||
|
if is_predicted_high:
|
||||||
|
entry_quality += 0.2 # Bonus for selling at predicted high
|
||||||
|
|
||||||
|
# Check volume confirmation
|
||||||
|
volume_increasing = volume_pred[0] > timeframe_data['1m'][-1][5]
|
||||||
|
if volume_increasing:
|
||||||
|
entry_quality += 0.1 # Bonus for increasing volume
|
||||||
|
|
||||||
|
# Execute sell order
|
||||||
|
# In a real implementation, this would interact with the exchange
|
||||||
|
# For now, we'll simulate the trade
|
||||||
|
trade_info = {
|
||||||
|
'action': 'sell',
|
||||||
|
'price': current_price,
|
||||||
|
'size': 100 * entry_quality, # Size based on entry quality
|
||||||
|
'entry_quality': entry_quality,
|
||||||
|
'pnl': 0 # Will be updated later
|
||||||
|
}
|
||||||
|
|
||||||
|
# Calculate reward
|
||||||
|
# Base reward for taking action
|
||||||
|
reward = 1
|
||||||
|
|
||||||
|
# Bonus for selling at predicted high
|
||||||
|
if is_predicted_high:
|
||||||
|
reward += 5
|
||||||
|
print("Trading at predicted high - additional reward")
|
||||||
|
|
||||||
|
# Bonus for volume confirmation
|
||||||
|
if volume_increasing:
|
||||||
|
reward += 2
|
||||||
|
print("Trading with high volume - additional reward")
|
||||||
|
|
||||||
|
else: # Hold
|
||||||
|
# Small reward for holding
|
||||||
|
reward = 0.1
|
||||||
|
|
||||||
|
# Simulate trade outcome
|
||||||
|
if trade_info is not None:
|
||||||
|
# In a real implementation, this would be based on actual market movement
|
||||||
|
# For now, we'll use the price prediction to simulate the outcome
|
||||||
|
future_price = price_pred[0] # Price in the next candle
|
||||||
|
|
||||||
|
if trade_info['action'] == 'buy':
|
||||||
|
# For buy, profit if price goes up
|
||||||
|
pnl_pct = (future_price / current_price - 1) * 100
|
||||||
|
trade_info['pnl'] = pnl_pct * trade_info['size'] / 100
|
||||||
|
else: # sell
|
||||||
|
# For sell, profit if price goes down
|
||||||
|
pnl_pct = (1 - future_price / current_price) * 100
|
||||||
|
trade_info['pnl'] = pnl_pct * trade_info['size'] / 100
|
||||||
|
|
||||||
|
# Adjust reward based on trade outcome
|
||||||
|
reward += trade_info['pnl'] * 10 # Scale PnL for reward
|
||||||
|
|
||||||
|
# Update state
|
||||||
|
next_state = initialize_state(exchange, timeframes)
|
||||||
|
|
||||||
|
# Check if episode is done
|
||||||
|
# In a real implementation, this would be based on episode length or other criteria
|
||||||
|
done = False
|
||||||
|
|
||||||
|
return next_state, reward, done, trade_info
|
||||||
|
|
||||||
|
# Main function to run training
|
||||||
|
def main():
|
||||||
|
from exchange_simulator import ExchangeSimulator
|
||||||
|
|
||||||
|
# Initialize exchange simulator
|
||||||
|
exchange = ExchangeSimulator()
|
||||||
|
|
||||||
|
# Train agent
|
||||||
|
price_model, dqn_model = enhanced_train_agent(
|
||||||
|
exchange=exchange,
|
||||||
|
num_episodes=NUM_EPISODES,
|
||||||
|
continuous=CONTINUOUS_MODE,
|
||||||
|
start_episode=CONTINUOUS_START_EPISODE
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Training complete!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
373
crypto/gogo2/exchange_simulator.py
Normal file
373
crypto/gogo2/exchange_simulator.py
Normal file
@ -0,0 +1,373 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
class ExchangeSimulator:
|
||||||
|
"""
|
||||||
|
A simple exchange simulator that generates realistic market data
|
||||||
|
for testing trading algorithms without connecting to a real exchange.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, symbol="BTC/USDT", seed=42):
|
||||||
|
"""
|
||||||
|
Initialize the exchange simulator
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading pair symbol
|
||||||
|
seed: Random seed for reproducibility
|
||||||
|
"""
|
||||||
|
self.symbol = symbol
|
||||||
|
self.seed = seed
|
||||||
|
np.random.seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
# Initialize data storage
|
||||||
|
self.data = {}
|
||||||
|
self.current_timestamp = datetime.now()
|
||||||
|
|
||||||
|
# Generate initial data for different timeframes
|
||||||
|
self.timeframes = ['1m', '5m', '15m', '30m', '1h', '4h', '1d']
|
||||||
|
self.timeframe_minutes = {
|
||||||
|
'1m': 1,
|
||||||
|
'5m': 5,
|
||||||
|
'15m': 15,
|
||||||
|
'30m': 30,
|
||||||
|
'1h': 60,
|
||||||
|
'4h': 240,
|
||||||
|
'1d': 1440
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate initial price around $50,000 (for BTC/USDT)
|
||||||
|
self.base_price = 50000.0
|
||||||
|
|
||||||
|
# Generate data for each timeframe
|
||||||
|
for tf in self.timeframes:
|
||||||
|
self._generate_initial_data(tf)
|
||||||
|
|
||||||
|
def _generate_initial_data(self, timeframe, num_candles=1000):
|
||||||
|
"""
|
||||||
|
Generate initial historical data for a specific timeframe
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeframe: Timeframe to generate data for
|
||||||
|
num_candles: Number of candles to generate
|
||||||
|
"""
|
||||||
|
# Calculate time delta for this timeframe
|
||||||
|
minutes = self.timeframe_minutes[timeframe]
|
||||||
|
|
||||||
|
# Generate timestamps
|
||||||
|
end_time = self.current_timestamp
|
||||||
|
timestamps = [end_time - timedelta(minutes=minutes * i) for i in range(num_candles)]
|
||||||
|
timestamps.reverse() # Oldest first
|
||||||
|
|
||||||
|
# Generate price data with realistic patterns
|
||||||
|
prices = self._generate_price_series(num_candles)
|
||||||
|
|
||||||
|
# Generate volume data with realistic patterns
|
||||||
|
volumes = self._generate_volume_series(num_candles, timeframe)
|
||||||
|
|
||||||
|
# Create OHLCV data
|
||||||
|
ohlcv_data = []
|
||||||
|
for i in range(num_candles):
|
||||||
|
# Calculate OHLC based on close price
|
||||||
|
close = prices[i]
|
||||||
|
high = close * (1 + np.random.uniform(0, 0.01))
|
||||||
|
low = close * (1 - np.random.uniform(0, 0.01))
|
||||||
|
open_price = prices[i-1] if i > 0 else close * (1 - np.random.uniform(-0.005, 0.005))
|
||||||
|
|
||||||
|
# Create candle
|
||||||
|
candle = [
|
||||||
|
int(timestamps[i].timestamp() * 1000), # Timestamp in milliseconds
|
||||||
|
open_price, # Open
|
||||||
|
high, # High
|
||||||
|
low, # Low
|
||||||
|
close, # Close
|
||||||
|
volumes[i] # Volume
|
||||||
|
]
|
||||||
|
ohlcv_data.append(candle)
|
||||||
|
|
||||||
|
# Store data
|
||||||
|
self.data[timeframe] = ohlcv_data
|
||||||
|
|
||||||
|
def _generate_price_series(self, length):
|
||||||
|
"""
|
||||||
|
Generate a realistic price series with trends, reversals, and volatility
|
||||||
|
|
||||||
|
Args:
|
||||||
|
length: Number of prices to generate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of prices
|
||||||
|
"""
|
||||||
|
# Start with base price
|
||||||
|
prices = [self.base_price]
|
||||||
|
|
||||||
|
# Parameters for price generation
|
||||||
|
trend_strength = 0.001 # Strength of trend
|
||||||
|
volatility = 0.005 # Daily volatility
|
||||||
|
mean_reversion = 0.001 # Mean reversion strength
|
||||||
|
|
||||||
|
# Generate price series
|
||||||
|
for i in range(1, length):
|
||||||
|
# Determine if we're in a trend
|
||||||
|
if i % 100 == 0:
|
||||||
|
# Change trend direction every ~100 candles
|
||||||
|
trend_strength = -trend_strength
|
||||||
|
|
||||||
|
# Calculate price change
|
||||||
|
trend = trend_strength * prices[-1]
|
||||||
|
random_change = np.random.normal(0, volatility) * prices[-1]
|
||||||
|
mean_reversion_change = mean_reversion * (self.base_price - prices[-1])
|
||||||
|
|
||||||
|
# Calculate new price
|
||||||
|
new_price = prices[-1] + trend + random_change + mean_reversion_change
|
||||||
|
|
||||||
|
# Ensure price doesn't go negative
|
||||||
|
new_price = max(new_price, prices[-1] * 0.9)
|
||||||
|
|
||||||
|
prices.append(new_price)
|
||||||
|
|
||||||
|
return prices
|
||||||
|
|
||||||
|
def _generate_volume_series(self, length, timeframe):
|
||||||
|
"""
|
||||||
|
Generate a realistic volume series with patterns
|
||||||
|
|
||||||
|
Args:
|
||||||
|
length: Number of volumes to generate
|
||||||
|
timeframe: Timeframe for volume scaling
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of volumes
|
||||||
|
"""
|
||||||
|
# Base volume depends on timeframe
|
||||||
|
base_volume = {
|
||||||
|
'1m': 10,
|
||||||
|
'5m': 50,
|
||||||
|
'15m': 150,
|
||||||
|
'30m': 300,
|
||||||
|
'1h': 600,
|
||||||
|
'4h': 2400,
|
||||||
|
'1d': 10000
|
||||||
|
}[timeframe]
|
||||||
|
|
||||||
|
# Generate volume series
|
||||||
|
volumes = []
|
||||||
|
for i in range(length):
|
||||||
|
# Volume tends to be higher at trend reversals and during volatile periods
|
||||||
|
cycle_factor = 1 + 0.5 * np.sin(i / 20) # Cyclical pattern
|
||||||
|
random_factor = np.random.lognormal(0, 0.5) # Random spikes
|
||||||
|
|
||||||
|
# Calculate volume
|
||||||
|
volume = base_volume * cycle_factor * random_factor
|
||||||
|
|
||||||
|
# Add some volume spikes
|
||||||
|
if random.random() < 0.05: # 5% chance of volume spike
|
||||||
|
volume *= random.uniform(2, 5)
|
||||||
|
|
||||||
|
volumes.append(volume)
|
||||||
|
|
||||||
|
return volumes
|
||||||
|
|
||||||
|
def fetch_ohlcv(self, timeframe='1m', limit=100, since=None):
|
||||||
|
"""
|
||||||
|
Fetch OHLCV data for a specific timeframe
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeframe: Timeframe to fetch data for
|
||||||
|
limit: Number of candles to fetch
|
||||||
|
since: Timestamp to fetch data since (not used in simulator)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of OHLCV candles
|
||||||
|
"""
|
||||||
|
# Ensure timeframe exists
|
||||||
|
if timeframe not in self.data:
|
||||||
|
if timeframe in self.timeframe_minutes:
|
||||||
|
self._generate_initial_data(timeframe)
|
||||||
|
else:
|
||||||
|
# Default to 1m if timeframe not supported
|
||||||
|
timeframe = '1m'
|
||||||
|
|
||||||
|
# Get data
|
||||||
|
data = self.data[timeframe]
|
||||||
|
|
||||||
|
# Return limited data
|
||||||
|
return data[-limit:]
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
"""
|
||||||
|
Update the exchange data by generating a new candle for each timeframe
|
||||||
|
"""
|
||||||
|
# Update current timestamp
|
||||||
|
self.current_timestamp = datetime.now()
|
||||||
|
|
||||||
|
# Update each timeframe
|
||||||
|
for tf in self.timeframes:
|
||||||
|
self._add_new_candle(tf)
|
||||||
|
|
||||||
|
def _add_new_candle(self, timeframe):
|
||||||
|
"""
|
||||||
|
Add a new candle to the specified timeframe
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeframe: Timeframe to add candle to
|
||||||
|
"""
|
||||||
|
# Get existing data
|
||||||
|
data = self.data[timeframe]
|
||||||
|
|
||||||
|
# Get last close price
|
||||||
|
last_close = data[-1][4]
|
||||||
|
|
||||||
|
# Calculate time delta for this timeframe
|
||||||
|
minutes = self.timeframe_minutes[timeframe]
|
||||||
|
|
||||||
|
# Calculate new timestamp
|
||||||
|
new_timestamp = int((data[-1][0] / 1000 + minutes * 60) * 1000)
|
||||||
|
|
||||||
|
# Generate new price with some randomness
|
||||||
|
price_change = np.random.normal(0, 0.002) * last_close
|
||||||
|
new_close = last_close + price_change
|
||||||
|
|
||||||
|
# Calculate OHLC
|
||||||
|
new_open = last_close
|
||||||
|
new_high = max(new_open, new_close) * (1 + np.random.uniform(0, 0.005))
|
||||||
|
new_low = min(new_open, new_close) * (1 - np.random.uniform(0, 0.005))
|
||||||
|
|
||||||
|
# Generate volume
|
||||||
|
base_volume = data[-1][5]
|
||||||
|
volume_change = np.random.normal(0, 0.2) * base_volume
|
||||||
|
new_volume = max(base_volume + volume_change, base_volume * 0.5)
|
||||||
|
|
||||||
|
# Create new candle
|
||||||
|
new_candle = [
|
||||||
|
new_timestamp,
|
||||||
|
new_open,
|
||||||
|
new_high,
|
||||||
|
new_low,
|
||||||
|
new_close,
|
||||||
|
new_volume
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add to data
|
||||||
|
self.data[timeframe].append(new_candle)
|
||||||
|
|
||||||
|
def get_ticker(self, symbol=None):
|
||||||
|
"""
|
||||||
|
Get current ticker information
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Symbol to get ticker for (defaults to initialized symbol)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with ticker information
|
||||||
|
"""
|
||||||
|
if symbol is None:
|
||||||
|
symbol = self.symbol
|
||||||
|
|
||||||
|
# Get latest 1m candle
|
||||||
|
latest_candle = self.data['1m'][-1]
|
||||||
|
|
||||||
|
return {
|
||||||
|
'symbol': symbol,
|
||||||
|
'bid': latest_candle[4] * 0.9999, # Slightly below last price
|
||||||
|
'ask': latest_candle[4] * 1.0001, # Slightly above last price
|
||||||
|
'last': latest_candle[4],
|
||||||
|
'high': latest_candle[2],
|
||||||
|
'low': latest_candle[3],
|
||||||
|
'volume': latest_candle[5],
|
||||||
|
'timestamp': latest_candle[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_order(self, symbol, type, side, amount, price=None):
|
||||||
|
"""
|
||||||
|
Simulate creating an order
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Symbol to create order for
|
||||||
|
type: Order type (limit, market)
|
||||||
|
side: Order side (buy, sell)
|
||||||
|
amount: Order amount
|
||||||
|
price: Order price (for limit orders)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with order information
|
||||||
|
"""
|
||||||
|
# Get current ticker
|
||||||
|
ticker = self.get_ticker(symbol)
|
||||||
|
|
||||||
|
# Determine execution price
|
||||||
|
if type == 'market':
|
||||||
|
if side == 'buy':
|
||||||
|
execution_price = ticker['ask']
|
||||||
|
else:
|
||||||
|
execution_price = ticker['bid']
|
||||||
|
else: # limit order
|
||||||
|
execution_price = price
|
||||||
|
|
||||||
|
# Create order object
|
||||||
|
order = {
|
||||||
|
'id': f"order_{int(datetime.now().timestamp() * 1000)}",
|
||||||
|
'symbol': symbol,
|
||||||
|
'type': type,
|
||||||
|
'side': side,
|
||||||
|
'amount': amount,
|
||||||
|
'price': execution_price,
|
||||||
|
'cost': amount * execution_price,
|
||||||
|
'filled': amount,
|
||||||
|
'status': 'closed',
|
||||||
|
'timestamp': int(datetime.now().timestamp() * 1000)
|
||||||
|
}
|
||||||
|
|
||||||
|
return order
|
||||||
|
|
||||||
|
def fetch_balance(self):
|
||||||
|
"""
|
||||||
|
Fetch account balance (simulated)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with balance information
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
'total': {
|
||||||
|
'USD': 10000.0,
|
||||||
|
'BTC': 1.0
|
||||||
|
},
|
||||||
|
'free': {
|
||||||
|
'USD': 5000.0,
|
||||||
|
'BTC': 0.5
|
||||||
|
},
|
||||||
|
'used': {
|
||||||
|
'USD': 5000.0,
|
||||||
|
'BTC': 0.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Create exchange simulator
|
||||||
|
exchange = ExchangeSimulator()
|
||||||
|
|
||||||
|
# Fetch some data
|
||||||
|
ohlcv = exchange.fetch_ohlcv(timeframe='1h', limit=10)
|
||||||
|
print("OHLCV data (1h timeframe):")
|
||||||
|
for candle in ohlcv[-5:]:
|
||||||
|
timestamp = datetime.fromtimestamp(candle[0] / 1000)
|
||||||
|
print(f"{timestamp}: Open={candle[1]:.2f}, High={candle[2]:.2f}, Low={candle[3]:.2f}, Close={candle[4]:.2f}, Volume={candle[5]:.2f}")
|
||||||
|
|
||||||
|
# Get current ticker
|
||||||
|
ticker = exchange.get_ticker()
|
||||||
|
print(f"\nCurrent ticker: {ticker['last']:.2f}")
|
||||||
|
|
||||||
|
# Create a market buy order
|
||||||
|
order = exchange.create_order("BTC/USDT", "market", "buy", 0.1)
|
||||||
|
print(f"\nCreated order: {order}")
|
||||||
|
|
||||||
|
# Update the exchange (simulate time passing)
|
||||||
|
exchange.update()
|
||||||
|
|
||||||
|
# Get updated ticker
|
||||||
|
updated_ticker = exchange.get_ticker()
|
||||||
|
print(f"\nUpdated ticker: {updated_ticker['last']:.2f}")
|
1555
crypto/gogo2/main.py
1555
crypto/gogo2/main.py
File diff suppressed because it is too large
Load Diff
305
crypto/gogo2/run_enhanced_training.py
Normal file
305
crypto/gogo2/run_enhanced_training.py
Normal file
@ -0,0 +1,305 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from enhanced_training import enhanced_train_agent
|
||||||
|
from exchange_simulator import ExchangeSimulator
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Parse command line arguments
|
||||||
|
parser = argparse.ArgumentParser(description='Enhanced Trading Bot Training')
|
||||||
|
|
||||||
|
parser.add_argument('--mode', type=str, default='train', choices=['train', 'continuous', 'evaluate', 'live', 'demo'],
|
||||||
|
help='Mode to run the trading bot in')
|
||||||
|
|
||||||
|
parser.add_argument('--episodes', type=int, default=100,
|
||||||
|
help='Number of episodes to train for')
|
||||||
|
|
||||||
|
parser.add_argument('--start-episode', type=int, default=0,
|
||||||
|
help='Episode to start from for continuous training')
|
||||||
|
|
||||||
|
parser.add_argument('--device', type=str, default='auto',
|
||||||
|
help='Device to train on (auto, cuda, cpu)')
|
||||||
|
|
||||||
|
parser.add_argument('--timeframes', type=str, default='1m,15m,1h',
|
||||||
|
help='Comma-separated list of timeframes to use')
|
||||||
|
|
||||||
|
parser.add_argument('--refresh-data', action='store_true',
|
||||||
|
help='Refresh data before training')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Set device
|
||||||
|
if args.device == 'auto':
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
else:
|
||||||
|
device = torch.device(args.device)
|
||||||
|
|
||||||
|
print(f"Using device: {device}")
|
||||||
|
|
||||||
|
# Parse timeframes
|
||||||
|
timeframes = args.timeframes.split(',')
|
||||||
|
print(f"Using timeframes: {timeframes}")
|
||||||
|
|
||||||
|
# Initialize exchange simulator
|
||||||
|
exchange = ExchangeSimulator()
|
||||||
|
|
||||||
|
# Run in specified mode
|
||||||
|
if args.mode == 'train':
|
||||||
|
# Train from scratch
|
||||||
|
print(f"Training for {args.episodes} episodes...")
|
||||||
|
enhanced_train_agent(
|
||||||
|
exchange=exchange,
|
||||||
|
num_episodes=args.episodes,
|
||||||
|
continuous=False,
|
||||||
|
start_episode=0
|
||||||
|
)
|
||||||
|
|
||||||
|
elif args.mode == 'continuous':
|
||||||
|
# Continue training from checkpoint
|
||||||
|
print(f"Continuing training from episode {args.start_episode} for {args.episodes} episodes...")
|
||||||
|
enhanced_train_agent(
|
||||||
|
exchange=exchange,
|
||||||
|
num_episodes=args.episodes,
|
||||||
|
continuous=True,
|
||||||
|
start_episode=args.start_episode
|
||||||
|
)
|
||||||
|
|
||||||
|
elif args.mode == 'evaluate':
|
||||||
|
# Evaluate the model
|
||||||
|
print("Evaluating model...")
|
||||||
|
evaluate_model(exchange, device)
|
||||||
|
|
||||||
|
elif args.mode == 'live' or args.mode == 'demo':
|
||||||
|
# Run in live or demo mode
|
||||||
|
is_demo = args.mode == 'demo'
|
||||||
|
print(f"Running in {'demo' if is_demo else 'live'} mode...")
|
||||||
|
run_live(exchange, device, is_demo=is_demo)
|
||||||
|
|
||||||
|
print("Done!")
|
||||||
|
|
||||||
|
def evaluate_model(exchange, device):
|
||||||
|
"""
|
||||||
|
Evaluate the trained model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exchange: Exchange simulator
|
||||||
|
device: Device to run on
|
||||||
|
"""
|
||||||
|
from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Load the best model
|
||||||
|
model_path = 'models/enhanced_trading_agent_best_pnl.pt'
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
model_path = 'models/enhanced_trading_agent_latest.pt'
|
||||||
|
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
print("No model found to evaluate!")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Loading model from {model_path}")
|
||||||
|
checkpoint = torch.load(model_path, map_location=device)
|
||||||
|
|
||||||
|
# Initialize models
|
||||||
|
state_dim = 100
|
||||||
|
action_dim = 3
|
||||||
|
timeframes = ['1m', '15m', '1h']
|
||||||
|
|
||||||
|
price_model = EnhancedPricePredictionModel(
|
||||||
|
input_dim=2,
|
||||||
|
hidden_dim=256,
|
||||||
|
num_layers=3,
|
||||||
|
output_dim=5,
|
||||||
|
num_timeframes=len(timeframes)
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
dqn_model = EnhancedDQN(
|
||||||
|
state_dim=state_dim,
|
||||||
|
action_dim=action_dim,
|
||||||
|
hidden_dim=512
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
# Load model weights
|
||||||
|
price_model.load_state_dict(checkpoint['price_model_state_dict'])
|
||||||
|
dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
|
||||||
|
|
||||||
|
# Set models to evaluation mode
|
||||||
|
price_model.eval()
|
||||||
|
dqn_model.eval()
|
||||||
|
|
||||||
|
# Run evaluation
|
||||||
|
num_steps = 1000
|
||||||
|
total_reward = 0
|
||||||
|
trades = []
|
||||||
|
|
||||||
|
# Initialize state
|
||||||
|
from enhanced_training import initialize_state, step_environment
|
||||||
|
state = initialize_state(exchange, timeframes)
|
||||||
|
|
||||||
|
for step in range(num_steps):
|
||||||
|
# Select action
|
||||||
|
with torch.no_grad():
|
||||||
|
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
||||||
|
q_values, _, _ = dqn_model(state_tensor)
|
||||||
|
action = q_values.argmax().item()
|
||||||
|
|
||||||
|
# Execute action
|
||||||
|
next_state, reward, done, trade_info = step_environment(
|
||||||
|
exchange, state, action, price_model, timeframes, device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update state and accumulate reward
|
||||||
|
state = next_state
|
||||||
|
total_reward += reward
|
||||||
|
|
||||||
|
# Track trade
|
||||||
|
if trade_info is not None:
|
||||||
|
trades.append(trade_info)
|
||||||
|
print(f"Trade: {trade_info['action']} at {trade_info['price']:.2f}, PnL: {trade_info['pnl']:.2f}")
|
||||||
|
|
||||||
|
# Update exchange (simulate time passing)
|
||||||
|
if step % 10 == 0:
|
||||||
|
exchange.update()
|
||||||
|
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Calculate metrics
|
||||||
|
avg_reward = total_reward / num_steps
|
||||||
|
total_pnl = sum(trade['pnl'] for trade in trades) if trades else 0
|
||||||
|
wins = sum(1 for trade in trades if trade['pnl'] > 0)
|
||||||
|
losses = sum(1 for trade in trades if trade['pnl'] < 0)
|
||||||
|
win_rate = (wins / (wins + losses) * 100) if (wins + losses) > 0 else 0
|
||||||
|
|
||||||
|
print("\nEvaluation Results:")
|
||||||
|
print(f"Average Reward: {avg_reward:.2f}")
|
||||||
|
print(f"Total PnL: ${total_pnl:.2f}")
|
||||||
|
print(f"Win Rate: {win_rate:.1f}% ({wins}/{wins+losses})")
|
||||||
|
|
||||||
|
def run_live(exchange, device, is_demo=True):
|
||||||
|
"""
|
||||||
|
Run the trading bot in live or demo mode
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exchange: Exchange simulator or real exchange
|
||||||
|
device: Device to run on
|
||||||
|
is_demo: Whether to run in demo mode (no real trades)
|
||||||
|
"""
|
||||||
|
from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Load the best model
|
||||||
|
model_path = 'models/enhanced_trading_agent_best_pnl.pt'
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
model_path = 'models/enhanced_trading_agent_latest.pt'
|
||||||
|
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
print("No model found to run in live mode!")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Loading model from {model_path}")
|
||||||
|
checkpoint = torch.load(model_path, map_location=device)
|
||||||
|
|
||||||
|
# Initialize models
|
||||||
|
state_dim = 100
|
||||||
|
action_dim = 3
|
||||||
|
timeframes = ['1m', '15m', '1h']
|
||||||
|
|
||||||
|
price_model = EnhancedPricePredictionModel(
|
||||||
|
input_dim=2,
|
||||||
|
hidden_dim=256,
|
||||||
|
num_layers=3,
|
||||||
|
output_dim=5,
|
||||||
|
num_timeframes=len(timeframes)
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
dqn_model = EnhancedDQN(
|
||||||
|
state_dim=state_dim,
|
||||||
|
action_dim=action_dim,
|
||||||
|
hidden_dim=512
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
# Load model weights
|
||||||
|
price_model.load_state_dict(checkpoint['price_model_state_dict'])
|
||||||
|
dqn_model.load_state_dict(checkpoint['dqn_model_state_dict'])
|
||||||
|
|
||||||
|
# Set models to evaluation mode
|
||||||
|
price_model.eval()
|
||||||
|
dqn_model.eval()
|
||||||
|
|
||||||
|
# Run live trading
|
||||||
|
print(f"Running in {'demo' if is_demo else 'live'} mode...")
|
||||||
|
print("Press Ctrl+C to stop")
|
||||||
|
|
||||||
|
# Initialize state
|
||||||
|
from enhanced_training import initialize_state, step_environment
|
||||||
|
state = initialize_state(exchange, timeframes)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# Select action
|
||||||
|
with torch.no_grad():
|
||||||
|
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
||||||
|
q_values, _, market_regime = dqn_model(state_tensor)
|
||||||
|
action = q_values.argmax().item()
|
||||||
|
|
||||||
|
# Get market regime prediction
|
||||||
|
regime_probs = torch.softmax(market_regime, dim=1).cpu().numpy()[0]
|
||||||
|
regime_names = ['Trending', 'Ranging', 'Volatile']
|
||||||
|
predicted_regime = regime_names[regime_probs.argmax()]
|
||||||
|
|
||||||
|
# Get current price
|
||||||
|
ticker = exchange.get_ticker()
|
||||||
|
current_price = ticker['last']
|
||||||
|
|
||||||
|
# Print state
|
||||||
|
print(f"\nCurrent price: ${current_price:.2f}")
|
||||||
|
print(f"Predicted market regime: {predicted_regime} ({regime_probs.max()*100:.1f}% confidence)")
|
||||||
|
|
||||||
|
# Execute action
|
||||||
|
next_state, reward, _, trade_info = step_environment(
|
||||||
|
exchange, state, action, price_model, timeframes, device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print action
|
||||||
|
action_names = ['Hold', 'Buy', 'Sell']
|
||||||
|
print(f"Action: {action_names[action]}")
|
||||||
|
|
||||||
|
if trade_info is not None:
|
||||||
|
print(f"Trade: {trade_info['action']} at {trade_info['price']:.2f}, Size: {trade_info['size']:.2f}, Entry Quality: {trade_info['entry_quality']:.2f}")
|
||||||
|
|
||||||
|
# Execute real trade if not in demo mode
|
||||||
|
if not is_demo:
|
||||||
|
if trade_info['action'] == 'buy':
|
||||||
|
order = exchange.create_order(
|
||||||
|
symbol="BTC/USDT",
|
||||||
|
type="market",
|
||||||
|
side="buy",
|
||||||
|
amount=trade_info['size'] / current_price
|
||||||
|
)
|
||||||
|
print(f"Executed buy order: {order}")
|
||||||
|
else: # sell
|
||||||
|
order = exchange.create_order(
|
||||||
|
symbol="BTC/USDT",
|
||||||
|
type="market",
|
||||||
|
side="sell",
|
||||||
|
amount=trade_info['size'] / current_price
|
||||||
|
)
|
||||||
|
print(f"Executed sell order: {order}")
|
||||||
|
|
||||||
|
# Update state
|
||||||
|
state = next_state
|
||||||
|
|
||||||
|
# Update exchange (simulate time passing)
|
||||||
|
exchange.update()
|
||||||
|
|
||||||
|
# Wait for next candle
|
||||||
|
time.sleep(5) # In a real implementation, this would wait for the next candle
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nStopping live trading")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
185
crypto/gogo2/test_cache.py
Normal file
185
crypto/gogo2/test_cache.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(sys.stdout)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger('cache_test')
|
||||||
|
|
||||||
|
# Import our cache implementation
|
||||||
|
from data_cache import ohlcv_cache
|
||||||
|
|
||||||
|
def generate_sample_data(num_candles=100):
|
||||||
|
"""Generate sample OHLCV data for testing"""
|
||||||
|
data = []
|
||||||
|
base_timestamp = int(time.time() * 1000) - (num_candles * 60 * 1000) # Start from num_candles minutes ago
|
||||||
|
|
||||||
|
for i in range(num_candles):
|
||||||
|
timestamp = base_timestamp + (i * 60 * 1000) # Add i minutes
|
||||||
|
|
||||||
|
# Generate some random-ish but realistic looking price data
|
||||||
|
base_price = 1900.0 + (i * 0.5) # Slight uptrend
|
||||||
|
open_price = base_price - 0.5 + (i % 3)
|
||||||
|
close_price = base_price + 0.3 + ((i+1) % 4)
|
||||||
|
high_price = max(open_price, close_price) + 1.0 + (i % 2)
|
||||||
|
low_price = min(open_price, close_price) - 0.8 - (i % 2)
|
||||||
|
volume = 10.0 + (i % 10) * 2.0
|
||||||
|
|
||||||
|
data.append({
|
||||||
|
'timestamp': timestamp,
|
||||||
|
'open': open_price,
|
||||||
|
'high': high_price,
|
||||||
|
'low': low_price,
|
||||||
|
'close': close_price,
|
||||||
|
'volume': volume
|
||||||
|
})
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def test_cache_save_load():
|
||||||
|
"""Test saving and loading data from cache"""
|
||||||
|
logger.info("Testing cache save and load...")
|
||||||
|
|
||||||
|
# Generate sample data
|
||||||
|
data = generate_sample_data(100)
|
||||||
|
logger.info(f"Generated {len(data)} sample candles")
|
||||||
|
|
||||||
|
# Save to cache
|
||||||
|
symbol = "ETH/USDT"
|
||||||
|
timeframe = "1m"
|
||||||
|
success = ohlcv_cache.save(data, symbol, timeframe)
|
||||||
|
logger.info(f"Saved to cache: {success}")
|
||||||
|
|
||||||
|
# Load from cache
|
||||||
|
cached_data = ohlcv_cache.load(symbol, timeframe)
|
||||||
|
logger.info(f"Loaded {len(cached_data) if cached_data else 0} candles from cache")
|
||||||
|
|
||||||
|
# Verify data integrity
|
||||||
|
if cached_data:
|
||||||
|
first_original = data[0]
|
||||||
|
first_cached = cached_data[0]
|
||||||
|
logger.info(f"First original candle: {first_original}")
|
||||||
|
logger.info(f"First cached candle: {first_cached}")
|
||||||
|
|
||||||
|
last_original = data[-1]
|
||||||
|
last_cached = cached_data[-1]
|
||||||
|
logger.info(f"Last original candle: {last_original}")
|
||||||
|
logger.info(f"Last cached candle: {last_cached}")
|
||||||
|
|
||||||
|
return success and cached_data and len(cached_data) == len(data)
|
||||||
|
|
||||||
|
def test_cache_append():
|
||||||
|
"""Test appending a new candle to cached data"""
|
||||||
|
logger.info("Testing cache append...")
|
||||||
|
|
||||||
|
# Generate sample data
|
||||||
|
data = generate_sample_data(100)
|
||||||
|
|
||||||
|
# Save to cache
|
||||||
|
symbol = "ETH/USDT"
|
||||||
|
timeframe = "5m"
|
||||||
|
success = ohlcv_cache.save(data, symbol, timeframe)
|
||||||
|
logger.info(f"Saved to cache: {success}")
|
||||||
|
|
||||||
|
# Generate a new candle
|
||||||
|
last_timestamp = data[-1]['timestamp']
|
||||||
|
new_timestamp = last_timestamp + (5 * 60 * 1000) # 5 minutes later
|
||||||
|
new_candle = {
|
||||||
|
'timestamp': new_timestamp,
|
||||||
|
'open': 1950.0,
|
||||||
|
'high': 1955.0,
|
||||||
|
'low': 1948.0,
|
||||||
|
'close': 1952.0,
|
||||||
|
'volume': 15.0
|
||||||
|
}
|
||||||
|
|
||||||
|
# Append to cache
|
||||||
|
success = ohlcv_cache.append(new_candle, symbol, timeframe)
|
||||||
|
logger.info(f"Appended to cache: {success}")
|
||||||
|
|
||||||
|
# Load from cache
|
||||||
|
cached_data = ohlcv_cache.load(symbol, timeframe)
|
||||||
|
logger.info(f"Loaded {len(cached_data) if cached_data else 0} candles from cache")
|
||||||
|
|
||||||
|
# Verify the new candle was appended
|
||||||
|
if cached_data:
|
||||||
|
last_cached = cached_data[-1]
|
||||||
|
logger.info(f"New candle: {new_candle}")
|
||||||
|
logger.info(f"Last cached candle: {last_cached}")
|
||||||
|
|
||||||
|
return success and cached_data and len(cached_data) == len(data) + 1
|
||||||
|
|
||||||
|
def test_cache_dataframe():
|
||||||
|
"""Test converting cached data to a pandas DataFrame"""
|
||||||
|
logger.info("Testing cache to DataFrame conversion...")
|
||||||
|
|
||||||
|
# Generate sample data
|
||||||
|
data = generate_sample_data(100)
|
||||||
|
|
||||||
|
# Save to cache
|
||||||
|
symbol = "ETH/USDT"
|
||||||
|
timeframe = "15m"
|
||||||
|
success = ohlcv_cache.save(data, symbol, timeframe)
|
||||||
|
logger.info(f"Saved to cache: {success}")
|
||||||
|
|
||||||
|
# Convert to DataFrame
|
||||||
|
df = ohlcv_cache.to_dataframe(symbol, timeframe)
|
||||||
|
logger.info(f"Converted to DataFrame with {len(df) if df is not None else 0} rows")
|
||||||
|
|
||||||
|
# Display DataFrame info
|
||||||
|
if df is not None:
|
||||||
|
logger.info(f"DataFrame columns: {df.columns.tolist()}")
|
||||||
|
logger.info(f"DataFrame index: {df.index.name}")
|
||||||
|
logger.info(f"First row: {df.iloc[0].to_dict()}")
|
||||||
|
logger.info(f"Last row: {df.iloc[-1].to_dict()}")
|
||||||
|
|
||||||
|
return success and df is not None and len(df) == len(data)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run all tests"""
|
||||||
|
logger.info("Starting cache tests...")
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
save_load_success = test_cache_save_load()
|
||||||
|
append_success = test_cache_append()
|
||||||
|
dataframe_success = test_cache_dataframe()
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
logger.info("Test results:")
|
||||||
|
logger.info(f" Save/Load: {'PASS' if save_load_success else 'FAIL'}")
|
||||||
|
logger.info(f" Append: {'PASS' if append_success else 'FAIL'}")
|
||||||
|
logger.info(f" DataFrame: {'PASS' if dataframe_success else 'FAIL'}")
|
||||||
|
|
||||||
|
# Check cache directory contents
|
||||||
|
cache_dir = ohlcv_cache.cache_dir
|
||||||
|
logger.info(f"Cache directory: {cache_dir}")
|
||||||
|
if os.path.exists(cache_dir):
|
||||||
|
files = os.listdir(cache_dir)
|
||||||
|
logger.info(f"Cache files: {files}")
|
||||||
|
|
||||||
|
# Print file sizes
|
||||||
|
for file in files:
|
||||||
|
file_path = os.path.join(cache_dir, file)
|
||||||
|
size_kb = os.path.getsize(file_path) / 1024
|
||||||
|
logger.info(f" {file}: {size_kb:.2f} KB")
|
||||||
|
|
||||||
|
# Print first few lines of each file
|
||||||
|
with open(file_path, 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
logger.info(f" Metadata: symbol={data.get('symbol')}, timeframe={data.get('timeframe')}, last_updated={datetime.fromtimestamp(data.get('last_updated')).strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
logger.info(f" Candles: {len(data.get('data', []))}")
|
||||||
|
|
||||||
|
return save_load_success and append_success and dataframe_success
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user