diff --git a/crypto/gogo2/_notes.md b/crypto/gogo2/_notes.md
index f782c14..0cbfa6f 100644
--- a/crypto/gogo2/_notes.md
+++ b/crypto/gogo2/_notes.md
@@ -1,3 +1,5 @@
+pip install torch-tb-profiler
+
 ensure we use GPU if available to train faster. during training we need to have an RL loop that looks at streaming data, and retrospective backtesting/training on predictions. since the start of the training we're only losing. implement a robust penalty and analysis when closing a losing trade and improve the reward function.
diff --git a/crypto/gogo2/data_cache.py b/crypto/gogo2/data_cache.py
new file mode 100644
index 0000000..eadcb4c
--- /dev/null
+++ b/crypto/gogo2/data_cache.py
@@ -0,0 +1,319 @@
+import os
+import json
+import time
+import pandas as pd
+import numpy as np
+from datetime import datetime, timedelta
+import logging
+
+# Set up logging
+logger = logging.getLogger('trading_bot')
+
+class OHLCVCache:
+    """
+    A simple cache for OHLCV data from exchanges.
+    Stores data in a structured format and provides a backup when the exchange is unavailable.
+    """
+    def __init__(self, cache_dir="cache", max_age_hours=24):
+        """
+        Initialize the OHLCV cache.
+
+        Args:
+            cache_dir: Directory to store cache files
+            max_age_hours: Maximum age of cached data in hours before it is considered stale
+        """
+        self.cache_dir = cache_dir
+        self.max_age_seconds = max_age_hours * 3600
+
+        # Create cache directory if it doesn't exist
+        os.makedirs(cache_dir, exist_ok=True)
+
+        # In-memory cache for faster access
+        self.memory_cache = {}
+
+    def _get_cache_filename(self, symbol, timeframe):
+        """Generate a standardized filename for the cache file"""
+        # Replace / with _ in symbol name (e.g., ETH/USDT -> ETH_USDT)
+        safe_symbol = symbol.replace('/', '_')
+        return os.path.join(self.cache_dir, f"{safe_symbol}_{timeframe}.json")
+
+    def save(self, data, symbol, timeframe):
+        """
+        Save OHLCV data to cache.
+
+        Args:
+            data: List of dictionaries containing OHLCV data
+            symbol: Trading pair symbol (e.g., 'ETH/USDT')
+            timeframe: Timeframe of the data (e.g., '1m', '5m', '1h')
+        """
+        if not data:
+            logger.warning(f"No data to cache for {symbol} ({timeframe})")
+            return False
+
+        try:
+            # Convert data to a serializable format
+            serializable_data = []
+            for candle in data:
+                serializable_data.append({
+                    'timestamp': candle['timestamp'],
+                    'open': float(candle['open']),
+                    'high': float(candle['high']),
+                    'low': float(candle['low']),
+                    'close': float(candle['close']),
+                    'volume': float(candle['volume'])
+                })
+
+            # Create cache entry with metadata
+            cache_entry = {
+                'symbol': symbol,
+                'timeframe': timeframe,
+                'last_updated': int(time.time()),
+                'data': serializable_data
+            }
+
+            # Save to file
+            filename = self._get_cache_filename(symbol, timeframe)
+            with open(filename, 'w') as f:
+                json.dump(cache_entry, f)
+
+            # Update in-memory cache
+            cache_key = f"{symbol}_{timeframe}"
+            self.memory_cache[cache_key] = cache_entry
+
+            logger.info(f"Cached {len(data)} candles for {symbol} ({timeframe})")
+            return True
+
+        except Exception as e:
+            logger.error(f"Error saving data to cache: {e}")
+            return False
+
+    def load(self, symbol, timeframe, max_age_override=None):
+        """
+        Load OHLCV data from cache.
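+        Freshness is judged against the entry's last_updated field; callers that only need the raw data (append, get_latest_timestamp, to_dataframe) pass max_age_override=float('inf') to bypass the staleness check.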
+ + Args: + symbol: Trading pair symbol (e.g., 'ETH/USDT') + timeframe: Timeframe of the data (e.g., '1m', '5m', '1h') + max_age_override: Override the default max age (in seconds) + + Returns: + List of dictionaries containing OHLCV data, or None if cache is missing or stale + """ + cache_key = f"{symbol}_{timeframe}" + max_age = max_age_override if max_age_override is not None else self.max_age_seconds + + try: + # Check in-memory cache first + if cache_key in self.memory_cache: + cache_entry = self.memory_cache[cache_key] + + # Check if cache is fresh + cache_age = int(time.time()) - cache_entry['last_updated'] + if cache_age <= max_age: + logger.info(f"Using in-memory cache for {symbol} ({timeframe}), age: {cache_age//60} minutes") + return cache_entry['data'] + + # Check file cache + filename = self._get_cache_filename(symbol, timeframe) + if not os.path.exists(filename): + logger.info(f"No cache file found for {symbol} ({timeframe})") + return None + + # Load cache file + with open(filename, 'r') as f: + cache_entry = json.load(f) + + # Check if cache is fresh + cache_age = int(time.time()) - cache_entry['last_updated'] + if cache_age > max_age: + logger.info(f"Cache for {symbol} ({timeframe}) is stale ({cache_age//60} minutes old)") + return None + + # Update in-memory cache + self.memory_cache[cache_key] = cache_entry + + logger.info(f"Loaded {len(cache_entry['data'])} candles from cache for {symbol} ({timeframe})") + return cache_entry['data'] + + except Exception as e: + logger.error(f"Error loading data from cache: {e}") + return None + + def append(self, new_candle, symbol, timeframe): + """ + Append a new candle to the cached data. + + Args: + new_candle: Dictionary containing a single OHLCV candle + symbol: Trading pair symbol (e.g., 'ETH/USDT') + timeframe: Timeframe of the data (e.g., '1m', '5m', '1h') + + Returns: + Boolean indicating success + """ + try: + # Load existing data + data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for append + + if data is None: + data = [] + + # Check if the candle already exists (same timestamp) + for i, candle in enumerate(data): + if candle['timestamp'] == new_candle['timestamp']: + # Update existing candle + data[i] = { + 'timestamp': new_candle['timestamp'], + 'open': float(new_candle['open']), + 'high': float(new_candle['high']), + 'low': float(new_candle['low']), + 'close': float(new_candle['close']), + 'volume': float(new_candle['volume']) + } + # Save updated data + return self.save(data, symbol, timeframe) + + # Append new candle + data.append({ + 'timestamp': new_candle['timestamp'], + 'open': float(new_candle['open']), + 'high': float(new_candle['high']), + 'low': float(new_candle['low']), + 'close': float(new_candle['close']), + 'volume': float(new_candle['volume']) + }) + + # Save updated data + return self.save(data, symbol, timeframe) + + except Exception as e: + logger.error(f"Error appending candle to cache: {e}") + return False + + def get_latest_timestamp(self, symbol, timeframe): + """ + Get the timestamp of the most recent candle in the cache. 
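+        The staleness check is bypassed for this lookup, so a timestamp is returned even when the cached data is older than max_age_hours.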
+ + Args: + symbol: Trading pair symbol (e.g., 'ETH/USDT') + timeframe: Timeframe of the data (e.g., '1m', '5m', '1h') + + Returns: + Timestamp (milliseconds) of the most recent candle, or None if cache is empty + """ + data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for this check + + if not data: + return None + + # Find the most recent timestamp + latest_timestamp = max(candle['timestamp'] for candle in data) + return latest_timestamp + + def clear(self, symbol=None, timeframe=None): + """ + Clear cache for a specific symbol and timeframe, or all cache if not specified. + + Args: + symbol: Trading pair symbol (e.g., 'ETH/USDT'), or None to clear all symbols + timeframe: Timeframe of the data (e.g., '1m', '5m', '1h'), or None to clear all timeframes + + Returns: + Number of cache files deleted + """ + count = 0 + + try: + if symbol and timeframe: + # Clear specific cache + filename = self._get_cache_filename(symbol, timeframe) + if os.path.exists(filename): + os.remove(filename) + count = 1 + + # Clear from memory cache + cache_key = f"{symbol}_{timeframe}" + if cache_key in self.memory_cache: + del self.memory_cache[cache_key] + + else: + # Clear all matching caches + for filename in os.listdir(self.cache_dir): + file_path = os.path.join(self.cache_dir, filename) + + # Skip directories + if not os.path.isfile(file_path): + continue + + # Check if file matches the filter + should_delete = True + + if symbol: + safe_symbol = symbol.replace('/', '_') + if not filename.startswith(f"{safe_symbol}_"): + should_delete = False + + if timeframe: + if not filename.endswith(f"_{timeframe}.json"): + should_delete = False + + # Delete file if it matches the filter + if should_delete: + os.remove(file_path) + count += 1 + + # Clear memory cache + keys_to_delete = [] + for cache_key in self.memory_cache: + should_delete = True + + if symbol: + if not cache_key.startswith(f"{symbol}_"): + should_delete = False + + if timeframe: + if not cache_key.endswith(f"_{timeframe}"): + should_delete = False + + if should_delete: + keys_to_delete.append(cache_key) + + for key in keys_to_delete: + del self.memory_cache[key] + + logger.info(f"Cleared {count} cache files") + return count + + except Exception as e: + logger.error(f"Error clearing cache: {e}") + return 0 + + def to_dataframe(self, symbol, timeframe): + """ + Convert cached OHLCV data to a pandas DataFrame. 
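+        The returned DataFrame is indexed by a 'datetime' column converted from the millisecond timestamps.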
+ + Args: + symbol: Trading pair symbol (e.g., 'ETH/USDT') + timeframe: Timeframe of the data (e.g., '1m', '5m', '1h') + + Returns: + pandas DataFrame with OHLCV data, or None if cache is missing + """ + data = self.load(symbol, timeframe, max_age_override=float('inf')) # Ignore age for conversion + + if not data: + return None + + # Convert to DataFrame + df = pd.DataFrame(data) + + # Convert timestamp to datetime + df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms') + + # Set datetime as index + df.set_index('datetime', inplace=True) + + return df + +# Create a global instance for easy access +ohlcv_cache = OHLCVCache() \ No newline at end of file diff --git a/crypto/gogo2/enhanced_models.py b/crypto/gogo2/enhanced_models.py index e9370aa..5dc3b72 100644 --- a/crypto/gogo2/enhanced_models.py +++ b/crypto/gogo2/enhanced_models.py @@ -291,7 +291,7 @@ class EnhancedReplayBuffer: def update_priorities(self, indices, td_errors): for idx, td_error in zip(indices, td_errors): # Update priority based on TD error - priority = abs(td_error) + 1e-5 # Small constant to ensure non-zero priority + priority = float(abs(td_error) + 1e-5) # Small constant to ensure non-zero priority self.priorities[idx] = priority self.max_priority = max(self.max_priority, priority) diff --git a/crypto/gogo2/enhanced_training.py b/crypto/gogo2/enhanced_training.py new file mode 100644 index 0000000..f4e8dde --- /dev/null +++ b/crypto/gogo2/enhanced_training.py @@ -0,0 +1,765 @@ +import os +import torch +import torch.nn as nn +import torch.optim as optim +from torch.amp import GradScaler, autocast +import numpy as np +import matplotlib.pyplot as plt +from datetime import datetime +from tensorboardX import SummaryWriter + +# Import our enhanced models +from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN, EnhancedReplayBuffer, train_price_predictor, prepare_multi_timeframe_data + +# Constants +TIMEFRAMES = ['1m', '15m', '1h'] +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +LEARNING_RATE = 1e-4 +BATCH_SIZE = 64 +GAMMA = 0.99 +REPLAY_BUFFER_SIZE = 100000 +TARGET_UPDATE = 10 +NUM_EPISODES = 200 +MAX_STEPS_PER_EPISODE = 1000 +EPSILON_START = 1.0 +EPSILON_END = 0.01 +EPSILON_DECAY = 0.995 +SAVE_INTERVAL = 10 +CONTINUOUS_MODE = True +CONTINUOUS_START_EPISODE = 0 + +def setup_tensorboard(): + """Set up TensorBoard for logging training metrics""" + current_time = datetime.now().strftime('%Y%m%d-%H%M%S') + log_dir = os.path.join('runs', current_time) + writer = SummaryWriter(log_dir) + return writer + +def save_models(price_model, dqn_model, optimizer, episode, rewards, profits, win_rates, best_reward, best_pnl, best_winrate): + """Save model checkpoints and clean up old ones to keep only top 5 and best PnL""" + # Create models directory if it doesn't exist + os.makedirs('models', exist_ok=True) + + # Save latest models + torch.save({ + 'price_model_state_dict': price_model.state_dict(), + 'dqn_model_state_dict': dqn_model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'episode': episode, + 'rewards': rewards, + 'profits': profits, + 'win_rates': win_rates + }, 'models/enhanced_trading_agent_latest.pt') + + # Save continuous training checkpoint + continuous_model_path = f'models/enhanced_trading_agent_continuous_{episode}.pt' + torch.save({ + 'price_model_state_dict': price_model.state_dict(), + 'dqn_model_state_dict': dqn_model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'episode': episode, + 'rewards': rewards, + 'profits': profits, + 
'win_rates': win_rates + }, continuous_model_path) + + # Save best models + if rewards[-1] > best_reward: + best_reward = rewards[-1] + torch.save({ + 'price_model_state_dict': price_model.state_dict(), + 'dqn_model_state_dict': dqn_model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'episode': episode, + 'rewards': rewards, + 'profits': profits, + 'win_rates': win_rates + }, 'models/enhanced_trading_agent_best_reward.pt') + + if profits[-1] > best_pnl: + best_pnl = profits[-1] + torch.save({ + 'price_model_state_dict': price_model.state_dict(), + 'dqn_model_state_dict': dqn_model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'episode': episode, + 'rewards': rewards, + 'profits': profits, + 'win_rates': win_rates + }, 'models/enhanced_trading_agent_best_pnl.pt') + + if win_rates[-1] > best_winrate: + best_winrate = win_rates[-1] + torch.save({ + 'price_model_state_dict': price_model.state_dict(), + 'dqn_model_state_dict': dqn_model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'episode': episode, + 'rewards': rewards, + 'profits': profits, + 'win_rates': win_rates + }, 'models/enhanced_trading_agent_best_winrate.pt') + + # Save final model at the end of training + if episode == NUM_EPISODES - 1: + torch.save({ + 'price_model_state_dict': price_model.state_dict(), + 'dqn_model_state_dict': dqn_model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'episode': episode, + 'rewards': rewards, + 'profits': profits, + 'win_rates': win_rates + }, 'models/enhanced_trading_agent_final.pt') + + # Clean up old models - keep only top 5 most recent and best PnL + cleanup_model_files() + + return best_reward, best_pnl, best_winrate + +def cleanup_model_files(): + """Keep only the top 5 most recent continuous models and the best models""" + # Files we always want to keep + essential_files = [ + 'enhanced_trading_agent_latest.pt', + 'enhanced_trading_agent_best_reward.pt', + 'enhanced_trading_agent_best_pnl.pt', + 'enhanced_trading_agent_best_winrate.pt', + 'enhanced_trading_agent_final.pt' + ] + + # Get all continuous training model files + continuous_files = [] + for file in os.listdir('models'): + if file.startswith('enhanced_trading_agent_continuous_') and file.endswith('.pt'): + continuous_files.append(file) + + # Sort continuous files by episode number (newest first) + if continuous_files: + try: + continuous_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]), reverse=True) + # Keep only the 5 most recent continuous files + files_to_keep = essential_files + continuous_files[:5] + except (ValueError, IndexError): + # Handle case where filename format is unexpected + print("Warning: Could not sort continuous files by episode number. 
Keeping all continuous files.") + files_to_keep = essential_files + continuous_files + else: + files_to_keep = essential_files + + # Delete all other model files + for file in os.listdir('models'): + if file.endswith('.pt') and file not in files_to_keep: + try: + os.remove(os.path.join('models', file)) + print(f"Deleted old model file: {file}") + except Exception as e: + print(f"Error deleting {file}: {e}") + +def plot_training_results(rewards, profits, win_rates, episode): + """Plot training metrics""" + plt.figure(figsize=(15, 15)) + + # Plot rewards + plt.subplot(3, 1, 1) + plt.plot(rewards) + plt.title('Average Reward per Episode') + plt.xlabel('Episode') + plt.ylabel('Reward') + + # Plot profits + plt.subplot(3, 1, 2) + plt.plot(profits) + plt.title('Profit/Loss per Episode') + plt.xlabel('Episode') + plt.ylabel('PnL ($)') + + # Plot win rates + plt.subplot(3, 1, 3) + plt.plot(win_rates) + plt.title('Win Rate per Episode') + plt.xlabel('Episode') + plt.ylabel('Win Rate (%)') + plt.ylim(0, 100) + + plt.tight_layout() + plt.savefig('training_results.png') + + # Also save episode-specific plots periodically + if episode % 20 == 0: + os.makedirs('visualizations', exist_ok=True) + plt.savefig(f'visualizations/training_episode_{episode}.png') + + plt.close() + +def load_checkpoint(price_model, dqn_model, optimizer, episode=None): + """Load model checkpoint for continuous training""" + if episode is not None: + checkpoint_path = f'models/enhanced_trading_agent_continuous_{episode}.pt' + else: + checkpoint_path = 'models/enhanced_trading_agent_latest.pt' + + if os.path.exists(checkpoint_path): + print(f"Loading checkpoint from {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=DEVICE) + + price_model.load_state_dict(checkpoint['price_model_state_dict']) + dqn_model.load_state_dict(checkpoint['dqn_model_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + start_episode = checkpoint['episode'] + 1 + rewards = checkpoint['rewards'] + profits = checkpoint['profits'] + win_rates = checkpoint['win_rates'] + + print(f"Resuming training from episode {start_episode}") + return start_episode, rewards, profits, win_rates + else: + print("No checkpoint found, starting training from scratch") + return 0, [], [], [] + +def enhanced_train_agent(exchange, num_episodes=NUM_EPISODES, continuous=CONTINUOUS_MODE, start_episode=CONTINUOUS_START_EPISODE): + """ + Train the enhanced trading agent using multi-timeframe data + + Args: + exchange: Exchange object to fetch data from + num_episodes: Number of episodes to train for + continuous: Whether to continue training from a checkpoint + start_episode: Episode to start from if continuous training + """ + print(f"Training on device: {DEVICE}") + + # Set up TensorBoard + writer = setup_tensorboard() + + # Initialize models + state_dim = 100 # Increased state dimension for multi-timeframe features + action_dim = 3 # Buy, Sell, Hold + + price_model = EnhancedPricePredictionModel( + input_dim=2, # Price and volume + hidden_dim=256, + num_layers=3, + output_dim=5, # Predict next 5 candles + num_timeframes=len(TIMEFRAMES) + ).to(DEVICE) + + dqn_model = EnhancedDQN( + state_dim=state_dim, + action_dim=action_dim, + hidden_dim=512 + ).to(DEVICE) + + target_dqn = EnhancedDQN( + state_dim=state_dim, + action_dim=action_dim, + hidden_dim=512 + ).to(DEVICE) + + # Copy initial weights to target network + target_dqn.load_state_dict(dqn_model.state_dict()) + + # Initialize optimizer + optimizer = 
optim.Adam(list(price_model.parameters()) + list(dqn_model.parameters()), lr=LEARNING_RATE) + + # Initialize replay buffer + replay_buffer = EnhancedReplayBuffer( + capacity=REPLAY_BUFFER_SIZE, + alpha=0.6, + beta=0.4, + beta_increment=0.001, + n_step=3, + gamma=GAMMA + ) + + # Initialize gradient scaler for mixed precision training + scaler = GradScaler(enabled=(DEVICE.type == 'cuda')) + + # Initialize tracking variables + rewards = [] + profits = [] + win_rates = [] + best_reward = float('-inf') + best_pnl = float('-inf') + best_winrate = float('-inf') + + # Load checkpoint if continuous training + if continuous: + start_episode, rewards, profits, win_rates = load_checkpoint( + price_model, dqn_model, optimizer, start_episode + ) + + # Prepare multi-timeframe data for price prediction model training + data_loaders = prepare_multi_timeframe_data(exchange, TIMEFRAMES) + + # Pre-train price prediction model + print("Pre-training price prediction model...") + train_price_predictor(price_model, data_loaders, optimizer, DEVICE, epochs=5) + + # Main training loop + epsilon = EPSILON_START + + for episode in range(start_episode, num_episodes): + print(f"Episode {episode+1}/{num_episodes}") + + # Reset environment + state = initialize_state(exchange, TIMEFRAMES) + total_reward = 0 + trades = [] + wins = 0 + losses = 0 + + # Episode loop + for step in range(MAX_STEPS_PER_EPISODE): + # Epsilon-greedy action selection + if np.random.random() < epsilon: + action = np.random.randint(0, action_dim) + else: + with torch.no_grad(): + state_tensor = torch.FloatTensor(state).unsqueeze(0).to(DEVICE) + q_values, _, _ = dqn_model(state_tensor) + action = q_values.argmax().item() + + # Execute action and get next state and reward + next_state, reward, done, trade_info = step_environment( + exchange, state, action, price_model, TIMEFRAMES, DEVICE + ) + + # Store transition in replay buffer + replay_buffer.push( + torch.FloatTensor(state), + action, + reward, + torch.FloatTensor(next_state), + done + ) + + # Update state and accumulate reward + state = next_state + total_reward += reward + + # Track trade outcomes + if trade_info is not None: + trades.append(trade_info) + if trade_info['pnl'] > 0: + wins += 1 + elif trade_info['pnl'] < 0: + losses += 1 + + # Learn from experiences if enough samples + if len(replay_buffer) > BATCH_SIZE: + learn(dqn_model, target_dqn, replay_buffer, optimizer, scaler, DEVICE) + + if done: + break + + # Update target network + if episode % TARGET_UPDATE == 0: + target_dqn.load_state_dict(dqn_model.state_dict()) + + # Calculate episode metrics + avg_reward = total_reward / (step + 1) + total_pnl = sum(trade['pnl'] for trade in trades) if trades else 0 + win_rate = (wins / (wins + losses) * 100) if (wins + losses) > 0 else 0 + + # Decay epsilon + epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY) + + # Track metrics + rewards.append(avg_reward) + profits.append(total_pnl) + win_rates.append(win_rate) + + # Log to TensorBoard + writer.add_scalar('Training/Reward', avg_reward, episode) + writer.add_scalar('Training/Profit', total_pnl, episode) + writer.add_scalar('Training/WinRate', win_rate, episode) + writer.add_scalar('Training/Epsilon', epsilon, episode) + + # Print episode summary + print(f"Episode {episode+1} - Avg Reward: {avg_reward:.2f}, PnL: ${total_pnl:.2f}, Win Rate: {win_rate:.1f}%") + + # Save models and plot results + if episode % SAVE_INTERVAL == 0 or episode == num_episodes - 1: + best_reward, best_pnl, best_winrate = save_models( + price_model, dqn_model, optimizer, 
episode, + rewards, profits, win_rates, + best_reward, best_pnl, best_winrate + ) + plot_training_results(rewards, profits, win_rates, episode) + + # Close TensorBoard writer + writer.close() + + # Final save and plot + best_reward, best_pnl, best_winrate = save_models( + price_model, dqn_model, optimizer, num_episodes - 1, + rewards, profits, win_rates, + best_reward, best_pnl, best_winrate + ) + plot_training_results(rewards, profits, win_rates, num_episodes - 1) + + print("Training complete!") + return price_model, dqn_model + +def learn(dqn, target_dqn, replay_buffer, optimizer, scaler, device): + """Update the DQN model using experiences from the replay buffer""" + # Sample from replay buffer + states, actions, rewards, next_states, dones, indices, weights = replay_buffer.sample(BATCH_SIZE) + + # Move to device + states = states.to(device) + actions = actions.to(device) + rewards = rewards.to(device) + next_states = next_states.to(device) + dones = dones.to(device) + weights = weights.to(device) + + # Get current Q values + if device.type == 'cuda': + with autocast(device_type='cuda', enabled=True): + current_q_values, _, _ = dqn(states) + current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1) + + # Compute target Q values + with torch.no_grad(): + next_q_values, _, _ = target_dqn(next_states) + max_next_q_values = next_q_values.max(1)[0] + target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values + + # Compute loss with importance sampling weights + td_errors = target_q_values - current_q_values + loss = (weights * td_errors.pow(2)).mean() + else: + # CPU version without autocast + current_q_values, _, _ = dqn(states) + current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1) + + # Compute target Q values + with torch.no_grad(): + next_q_values, _, _ = target_dqn(next_states) + max_next_q_values = next_q_values.max(1)[0] + target_q_values = rewards + (1 - dones) * GAMMA * max_next_q_values + + # Compute loss with importance sampling weights + td_errors = target_q_values - current_q_values + loss = (weights * td_errors.pow(2)).mean() + + # Update priorities in replay buffer + replay_buffer.update_priorities(indices, td_errors.abs().detach().cpu().numpy()) + + # Optimize the model with mixed precision + optimizer.zero_grad() + + if device.type == 'cuda': + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0) + scaler.step(optimizer) + scaler.update() + else: + # CPU version without scaler + loss.backward() + torch.nn.utils.clip_grad_norm_(dqn.parameters(), max_norm=1.0) + optimizer.step() + +def initialize_state(exchange, timeframes): + """Initialize the state with data from multiple timeframes""" + # Fetch data for each timeframe + timeframe_data = {} + for tf in timeframes: + candles = exchange.fetch_ohlcv(timeframe=tf, limit=30) + timeframe_data[tf] = candles + + # Extract features from each timeframe + state = [] + + for tf in timeframes: + candles = timeframe_data[tf] + + # Price features + prices = [candle[4] for candle in candles[-10:]] # Last 10 close prices + price_changes = [prices[i]/prices[i-1] - 1 for i in range(1, len(prices))] + + # Volume features + volumes = [candle[5] for candle in candles[-10:]] # Last 10 volumes + volume_changes = [volumes[i]/volumes[i-1] - 1 for i in range(1, len(volumes))] + + # Technical indicators + # Simple Moving Averages + sma_5 = sum(prices[-5:]) / 5 + sma_10 = sum(prices) / 10 + + # Relative Strength Index 
(simplified) + gains = [max(0, price_changes[i]) for i in range(len(price_changes))] + losses = [max(0, -price_changes[i]) for i in range(len(price_changes))] + avg_gain = sum(gains) / len(gains) + avg_loss = sum(losses) / len(losses) + rs = avg_gain / (avg_loss + 1e-10) # Avoid division by zero + rsi = 100 - (100 / (1 + rs)) + + # Add features to state + state.extend(price_changes) # 9 features + state.extend(volume_changes) # 9 features + state.append(sma_5 / prices[-1] - 1) # 1 feature + state.append(sma_10 / prices[-1] - 1) # 1 feature + state.append(rsi / 100) # 1 feature + + # Add market regime features + # This is a placeholder - in a real implementation, you would use the market_regime_classifier + # from the DQN model to predict the current market regime + state.extend([0, 0, 0]) # 3 features for market regime (one-hot encoded) + + # Add additional features to reach the expected dimension of 100 + # Calculate more technical indicators + for tf in timeframes: + candles = timeframe_data[tf] + prices = [candle[4] for candle in candles[-20:]] # Last 20 close prices + + # Bollinger Bands + window = 20 + if len(prices) >= window: + sma_20 = sum(prices[-window:]) / window + std_dev = (sum((price - sma_20) ** 2 for price in prices[-window:]) / window) ** 0.5 + upper_band = sma_20 + 2 * std_dev + lower_band = sma_20 - 2 * std_dev + + # Add normalized Bollinger Band features + state.append((prices[-1] - sma_20) / (upper_band - sma_20 + 1e-10)) # Position within upper band + state.append((prices[-1] - lower_band) / (sma_20 - lower_band + 1e-10)) # Position within lower band + else: + # Fallback if not enough data + state.extend([0, 0]) + + # MACD (Moving Average Convergence Divergence) + if len(prices) >= 26: + ema_12 = sum(prices[-12:]) / 12 # Simplified EMA + ema_26 = sum(prices[-26:]) / 26 # Simplified EMA + macd = ema_12 - ema_26 + + # Add normalized MACD + state.append(macd / prices[-1]) + else: + # Fallback if not enough data + state.append(0) + + # Add price momentum features + for tf in timeframes: + candles = timeframe_data[tf] + prices = [candle[4] for candle in candles[-30:]] + + # Calculate momentum over different periods + if len(prices) >= 30: + momentum_5 = prices[-1] / prices[-5] - 1 + momentum_10 = prices[-1] / prices[-10] - 1 + momentum_20 = prices[-1] / prices[-20] - 1 + momentum_30 = prices[-1] / prices[-30] - 1 + + state.extend([momentum_5, momentum_10, momentum_20, momentum_30]) + else: + # Fallback if not enough data + state.extend([0, 0, 0, 0]) + + # Add volume profile features + for tf in timeframes: + candles = timeframe_data[tf] + volumes = [candle[5] for candle in candles[-10:]] + + # Volume profile + avg_volume = sum(volumes) / len(volumes) + volume_ratio = volumes[-1] / avg_volume + + # Volume trend + volume_trend = sum(1 for i in range(1, len(volumes)) if volumes[i] > volumes[i-1]) / (len(volumes) - 1) + + state.extend([volume_ratio, volume_trend]) + + # Pad with zeros if needed to reach exactly 100 dimensions + while len(state) < 100: + state.append(0) + + # Ensure state has exactly 100 dimensions + if len(state) > 100: + state = state[:100] + + assert len(state) == 100, f"State dimension mismatch: {len(state)} != 100" + + return state + +def step_environment(exchange, state, action, price_model, timeframes, device): + """ + Execute action in the environment and return next state, reward, done flag, and trade info + + Args: + exchange: Exchange object to interact with + state: Current state + action: Action to take (0: Hold, 1: Buy, 2: Sell) + price_model: 
Price prediction model + timeframes: List of timeframes to use + device: Device to run models on + + Returns: + next_state: Next state after taking action + reward: Reward received + done: Whether episode is done + trade_info: Information about the trade (if any) + """ + # Fetch latest data for each timeframe + timeframe_data = {} + for tf in timeframes: + candles = exchange.fetch_ohlcv(timeframe=tf, limit=30) + timeframe_data[tf] = candles + + # Prepare inputs for price prediction model + price_inputs = [] + for tf in timeframes: + candles = timeframe_data[tf] + # Extract price and volume data + input_data = torch.tensor([ + [candle[4], candle[5]] for candle in candles[-30:] # Last 30 candles + ], dtype=torch.float32).unsqueeze(0).to(device) # Add batch dimension + price_inputs.append(input_data) + + # Get price and extrema predictions + with torch.no_grad(): + price_pred, extrema_logits, volume_pred = price_model(price_inputs) + + # Convert predictions to numpy + price_pred = price_pred.cpu().numpy()[0] # Remove batch dimension + extrema_probs = torch.sigmoid(extrema_logits).cpu().numpy()[0] + volume_pred = volume_pred.cpu().numpy()[0] + + # Execute action + current_price = timeframe_data['1m'][-1][4] # Current close price + trade_info = None + reward = 0 + + if action == 1: # Buy + # Check if we're at a predicted low point (good time to buy) + is_predicted_low = any(extrema_probs[i*2+1] > 0.7 for i in range(5)) + + # Calculate entry quality based on predictions + entry_quality = 0.5 # Default quality + if is_predicted_low: + entry_quality += 0.2 # Bonus for buying at predicted low + + # Check volume confirmation + volume_increasing = volume_pred[0] > timeframe_data['1m'][-1][5] + if volume_increasing: + entry_quality += 0.1 # Bonus for increasing volume + + # Execute buy order + # In a real implementation, this would interact with the exchange + # For now, we'll simulate the trade + trade_info = { + 'action': 'buy', + 'price': current_price, + 'size': 100 * entry_quality, # Size based on entry quality + 'entry_quality': entry_quality, + 'pnl': 0 # Will be updated later + } + + # Calculate reward + # Base reward for taking action + reward = 1 + + # Bonus for buying at predicted low + if is_predicted_low: + reward += 5 + print("Trading at predicted low - additional reward") + + # Bonus for volume confirmation + if volume_increasing: + reward += 2 + print("Trading with high volume - additional reward") + + elif action == 2: # Sell + # Check if we're at a predicted high point (good time to sell) + is_predicted_high = any(extrema_probs[i*2] > 0.7 for i in range(5)) + + # Calculate entry quality based on predictions + entry_quality = 0.5 # Default quality + if is_predicted_high: + entry_quality += 0.2 # Bonus for selling at predicted high + + # Check volume confirmation + volume_increasing = volume_pred[0] > timeframe_data['1m'][-1][5] + if volume_increasing: + entry_quality += 0.1 # Bonus for increasing volume + + # Execute sell order + # In a real implementation, this would interact with the exchange + # For now, we'll simulate the trade + trade_info = { + 'action': 'sell', + 'price': current_price, + 'size': 100 * entry_quality, # Size based on entry quality + 'entry_quality': entry_quality, + 'pnl': 0 # Will be updated later + } + + # Calculate reward + # Base reward for taking action + reward = 1 + + # Bonus for selling at predicted high + if is_predicted_high: + reward += 5 + print("Trading at predicted high - additional reward") + + # Bonus for volume confirmation + if 
volume_increasing: + reward += 2 + print("Trading with high volume - additional reward") + + else: # Hold + # Small reward for holding + reward = 0.1 + + # Simulate trade outcome + if trade_info is not None: + # In a real implementation, this would be based on actual market movement + # For now, we'll use the price prediction to simulate the outcome + future_price = price_pred[0] # Price in the next candle + + if trade_info['action'] == 'buy': + # For buy, profit if price goes up + pnl_pct = (future_price / current_price - 1) * 100 + trade_info['pnl'] = pnl_pct * trade_info['size'] / 100 + else: # sell + # For sell, profit if price goes down + pnl_pct = (1 - future_price / current_price) * 100 + trade_info['pnl'] = pnl_pct * trade_info['size'] / 100 + + # Adjust reward based on trade outcome + reward += trade_info['pnl'] * 10 # Scale PnL for reward + + # Update state + next_state = initialize_state(exchange, timeframes) + + # Check if episode is done + # In a real implementation, this would be based on episode length or other criteria + done = False + + return next_state, reward, done, trade_info + +# Main function to run training +def main(): + from exchange_simulator import ExchangeSimulator + + # Initialize exchange simulator + exchange = ExchangeSimulator() + + # Train agent + price_model, dqn_model = enhanced_train_agent( + exchange=exchange, + num_episodes=NUM_EPISODES, + continuous=CONTINUOUS_MODE, + start_episode=CONTINUOUS_START_EPISODE + ) + + print("Training complete!") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crypto/gogo2/exchange_simulator.py b/crypto/gogo2/exchange_simulator.py new file mode 100644 index 0000000..04aae1b --- /dev/null +++ b/crypto/gogo2/exchange_simulator.py @@ -0,0 +1,373 @@ +import numpy as np +import pandas as pd +import os +import random +from datetime import datetime, timedelta + +class ExchangeSimulator: + """ + A simple exchange simulator that generates realistic market data + for testing trading algorithms without connecting to a real exchange. 
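+    Prices follow a trending random walk with mean reversion around the $50,000 base price, and volumes follow a cyclical lognormal pattern with occasional spikes.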
+ """ + + def __init__(self, symbol="BTC/USDT", seed=42): + """ + Initialize the exchange simulator + + Args: + symbol: Trading pair symbol + seed: Random seed for reproducibility + """ + self.symbol = symbol + self.seed = seed + np.random.seed(seed) + random.seed(seed) + + # Initialize data storage + self.data = {} + self.current_timestamp = datetime.now() + + # Generate initial data for different timeframes + self.timeframes = ['1m', '5m', '15m', '30m', '1h', '4h', '1d'] + self.timeframe_minutes = { + '1m': 1, + '5m': 5, + '15m': 15, + '30m': 30, + '1h': 60, + '4h': 240, + '1d': 1440 + } + + # Generate initial price around $50,000 (for BTC/USDT) + self.base_price = 50000.0 + + # Generate data for each timeframe + for tf in self.timeframes: + self._generate_initial_data(tf) + + def _generate_initial_data(self, timeframe, num_candles=1000): + """ + Generate initial historical data for a specific timeframe + + Args: + timeframe: Timeframe to generate data for + num_candles: Number of candles to generate + """ + # Calculate time delta for this timeframe + minutes = self.timeframe_minutes[timeframe] + + # Generate timestamps + end_time = self.current_timestamp + timestamps = [end_time - timedelta(minutes=minutes * i) for i in range(num_candles)] + timestamps.reverse() # Oldest first + + # Generate price data with realistic patterns + prices = self._generate_price_series(num_candles) + + # Generate volume data with realistic patterns + volumes = self._generate_volume_series(num_candles, timeframe) + + # Create OHLCV data + ohlcv_data = [] + for i in range(num_candles): + # Calculate OHLC based on close price + close = prices[i] + high = close * (1 + np.random.uniform(0, 0.01)) + low = close * (1 - np.random.uniform(0, 0.01)) + open_price = prices[i-1] if i > 0 else close * (1 - np.random.uniform(-0.005, 0.005)) + + # Create candle + candle = [ + int(timestamps[i].timestamp() * 1000), # Timestamp in milliseconds + open_price, # Open + high, # High + low, # Low + close, # Close + volumes[i] # Volume + ] + ohlcv_data.append(candle) + + # Store data + self.data[timeframe] = ohlcv_data + + def _generate_price_series(self, length): + """ + Generate a realistic price series with trends, reversals, and volatility + + Args: + length: Number of prices to generate + + Returns: + List of prices + """ + # Start with base price + prices = [self.base_price] + + # Parameters for price generation + trend_strength = 0.001 # Strength of trend + volatility = 0.005 # Daily volatility + mean_reversion = 0.001 # Mean reversion strength + + # Generate price series + for i in range(1, length): + # Determine if we're in a trend + if i % 100 == 0: + # Change trend direction every ~100 candles + trend_strength = -trend_strength + + # Calculate price change + trend = trend_strength * prices[-1] + random_change = np.random.normal(0, volatility) * prices[-1] + mean_reversion_change = mean_reversion * (self.base_price - prices[-1]) + + # Calculate new price + new_price = prices[-1] + trend + random_change + mean_reversion_change + + # Ensure price doesn't go negative + new_price = max(new_price, prices[-1] * 0.9) + + prices.append(new_price) + + return prices + + def _generate_volume_series(self, length, timeframe): + """ + Generate a realistic volume series with patterns + + Args: + length: Number of volumes to generate + timeframe: Timeframe for volume scaling + + Returns: + List of volumes + """ + # Base volume depends on timeframe + base_volume = { + '1m': 10, + '5m': 50, + '15m': 150, + '30m': 300, + '1h': 600, + '4h': 
2400, + '1d': 10000 + }[timeframe] + + # Generate volume series + volumes = [] + for i in range(length): + # Volume tends to be higher at trend reversals and during volatile periods + cycle_factor = 1 + 0.5 * np.sin(i / 20) # Cyclical pattern + random_factor = np.random.lognormal(0, 0.5) # Random spikes + + # Calculate volume + volume = base_volume * cycle_factor * random_factor + + # Add some volume spikes + if random.random() < 0.05: # 5% chance of volume spike + volume *= random.uniform(2, 5) + + volumes.append(volume) + + return volumes + + def fetch_ohlcv(self, timeframe='1m', limit=100, since=None): + """ + Fetch OHLCV data for a specific timeframe + + Args: + timeframe: Timeframe to fetch data for + limit: Number of candles to fetch + since: Timestamp to fetch data since (not used in simulator) + + Returns: + List of OHLCV candles + """ + # Ensure timeframe exists + if timeframe not in self.data: + if timeframe in self.timeframe_minutes: + self._generate_initial_data(timeframe) + else: + # Default to 1m if timeframe not supported + timeframe = '1m' + + # Get data + data = self.data[timeframe] + + # Return limited data + return data[-limit:] + + def update(self): + """ + Update the exchange data by generating a new candle for each timeframe + """ + # Update current timestamp + self.current_timestamp = datetime.now() + + # Update each timeframe + for tf in self.timeframes: + self._add_new_candle(tf) + + def _add_new_candle(self, timeframe): + """ + Add a new candle to the specified timeframe + + Args: + timeframe: Timeframe to add candle to + """ + # Get existing data + data = self.data[timeframe] + + # Get last close price + last_close = data[-1][4] + + # Calculate time delta for this timeframe + minutes = self.timeframe_minutes[timeframe] + + # Calculate new timestamp + new_timestamp = int((data[-1][0] / 1000 + minutes * 60) * 1000) + + # Generate new price with some randomness + price_change = np.random.normal(0, 0.002) * last_close + new_close = last_close + price_change + + # Calculate OHLC + new_open = last_close + new_high = max(new_open, new_close) * (1 + np.random.uniform(0, 0.005)) + new_low = min(new_open, new_close) * (1 - np.random.uniform(0, 0.005)) + + # Generate volume + base_volume = data[-1][5] + volume_change = np.random.normal(0, 0.2) * base_volume + new_volume = max(base_volume + volume_change, base_volume * 0.5) + + # Create new candle + new_candle = [ + new_timestamp, + new_open, + new_high, + new_low, + new_close, + new_volume + ] + + # Add to data + self.data[timeframe].append(new_candle) + + def get_ticker(self, symbol=None): + """ + Get current ticker information + + Args: + symbol: Symbol to get ticker for (defaults to initialized symbol) + + Returns: + Dictionary with ticker information + """ + if symbol is None: + symbol = self.symbol + + # Get latest 1m candle + latest_candle = self.data['1m'][-1] + + return { + 'symbol': symbol, + 'bid': latest_candle[4] * 0.9999, # Slightly below last price + 'ask': latest_candle[4] * 1.0001, # Slightly above last price + 'last': latest_candle[4], + 'high': latest_candle[2], + 'low': latest_candle[3], + 'volume': latest_candle[5], + 'timestamp': latest_candle[0] + } + + def create_order(self, symbol, type, side, amount, price=None): + """ + Simulate creating an order + + Args: + symbol: Symbol to create order for + type: Order type (limit, market) + side: Order side (buy, sell) + amount: Order amount + price: Order price (for limit orders) + + Returns: + Dictionary with order information + """ + # Get current ticker + 
ticker = self.get_ticker(symbol) + + # Determine execution price + if type == 'market': + if side == 'buy': + execution_price = ticker['ask'] + else: + execution_price = ticker['bid'] + else: # limit order + execution_price = price + + # Create order object + order = { + 'id': f"order_{int(datetime.now().timestamp() * 1000)}", + 'symbol': symbol, + 'type': type, + 'side': side, + 'amount': amount, + 'price': execution_price, + 'cost': amount * execution_price, + 'filled': amount, + 'status': 'closed', + 'timestamp': int(datetime.now().timestamp() * 1000) + } + + return order + + def fetch_balance(self): + """ + Fetch account balance (simulated) + + Returns: + Dictionary with balance information + """ + return { + 'total': { + 'USD': 10000.0, + 'BTC': 1.0 + }, + 'free': { + 'USD': 5000.0, + 'BTC': 0.5 + }, + 'used': { + 'USD': 5000.0, + 'BTC': 0.5 + } + } + +# Example usage +if __name__ == "__main__": + # Create exchange simulator + exchange = ExchangeSimulator() + + # Fetch some data + ohlcv = exchange.fetch_ohlcv(timeframe='1h', limit=10) + print("OHLCV data (1h timeframe):") + for candle in ohlcv[-5:]: + timestamp = datetime.fromtimestamp(candle[0] / 1000) + print(f"{timestamp}: Open={candle[1]:.2f}, High={candle[2]:.2f}, Low={candle[3]:.2f}, Close={candle[4]:.2f}, Volume={candle[5]:.2f}") + + # Get current ticker + ticker = exchange.get_ticker() + print(f"\nCurrent ticker: {ticker['last']:.2f}") + + # Create a market buy order + order = exchange.create_order("BTC/USDT", "market", "buy", 0.1) + print(f"\nCreated order: {order}") + + # Update the exchange (simulate time passing) + exchange.update() + + # Get updated ticker + updated_ticker = exchange.get_ticker() + print(f"\nUpdated ticker: {updated_ticker['last']:.2f}") \ No newline at end of file diff --git a/crypto/gogo2/main.py b/crypto/gogo2/main.py index fb378c3..4ca74cc 100644 --- a/crypto/gogo2/main.py +++ b/crypto/gogo2/main.py @@ -1,27 +1,50 @@ import os +import sys import time import json -import numpy as np -import pandas as pd -import datetime -import random import logging import asyncio +import argparse +import traceback +import datetime +import pandas as pd +import numpy as np import matplotlib.pyplot as plt +import matplotlib.dates as mdates +from matplotlib.ticker import FuncFormatter +import mplfinance as mpf +from collections import deque, namedtuple +import random +from typing import List, Dict, Tuple, Optional, Union, Any +from dotenv import load_dotenv +import torch.nn.functional as F +import math +from mexc_trading import MexcTradingClient + import torch import torch.nn as nn import torch.optim as optim -import torch.nn.functional as F -from collections import deque, namedtuple -from dotenv import load_dotenv -import ccxt -import websockets from torch.utils.tensorboard import SummaryWriter import torch.amp as amp # Update import to use torch.amp instead of torch.cuda.amp from sklearn.preprocessing import MinMaxScaler -import traceback -import math -from mexc_trading import MexcTradingClient + +import ccxt.async_support as ccxt +import websockets +from data_cache import ohlcv_cache + + + +if sys.platform == 'win32': + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + +# Constants +INITIAL_BALANCE = 1000.0 +MAX_LEVERAGE = 1.0 # Max leverage to use +STOP_LOSS_PERCENT = 2.0 # Default stop loss percentage +TAKE_PROFIT_PERCENT = 5.0 # Default take profit percentage +STATE_SIZE = 128 # Size of state representation +LEARNING_RATE = 0.0001 +MODEL_DIR = "models_improved" # New models directory # Load 
environment variables load_dotenv() @@ -224,13 +247,16 @@ class DQN(nn.Module): def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4): super(DQN, self).__init__() - # Feature extraction layers + # Feature extraction layers with increased regularization self.feature_extraction = nn.Sequential( nn.Linear(state_size, hidden_size), nn.LeakyReLU(), - nn.Dropout(0.1), + nn.Dropout(0.2), # Increased dropout + nn.LayerNorm(hidden_size), # Layer normalization for stability nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(), + nn.Dropout(0.2), # Increased dropout + nn.LayerNorm(hidden_size) # Layer normalization for stability ) # LSTM for sequential processing @@ -239,7 +265,7 @@ class DQN(nn.Module): hidden_size=hidden_size, num_layers=lstm_layers, batch_first=True, - dropout=0.1 if lstm_layers > 1 else 0 + dropout=0.2 if lstm_layers > 1 else 0 # Increased dropout ) # Dueling network architecture @@ -247,6 +273,7 @@ class DQN(nn.Module): self.advantage_stream = nn.Sequential( nn.Linear(hidden_size, hidden_size // 2), nn.LeakyReLU(), + nn.Dropout(0.2), # Added dropout nn.Linear(hidden_size // 2, action_size) ) @@ -254,6 +281,7 @@ class DQN(nn.Module): self.value_stream = nn.Sequential( nn.Linear(hidden_size, hidden_size // 2), nn.LeakyReLU(), + nn.Dropout(0.2), # Added dropout nn.Linear(hidden_size // 2, 1) ) @@ -261,6 +289,7 @@ class DQN(nn.Module): self.market_regime_classifier = nn.Sequential( nn.Linear(hidden_size, hidden_size // 2), nn.LeakyReLU(), + nn.Dropout(0.2), # Added dropout nn.Linear(hidden_size // 2, 3) # 3 regimes: trending, ranging, volatile ) @@ -631,28 +660,37 @@ class TradingEnvironment: self.balance = initial_balance self.window_size = window_size self.demo = demo + self.trading_client = trading_client self.data = [] + self.features = {} self.position = 'flat' # 'flat', 'long', or 'short' self.position_size = 0 self.entry_price = 0 self.entry_index = 0 self.stop_loss = 0 self.take_profit = 0 + self.current_step = 0 + self.current_price = 0 self.trades = [] self.win_count = 0 self.loss_count = 0 - self.total_pnl = 0.0 self.episode_pnl = 0.0 + self.total_pnl = 0.0 self.peak_balance = initial_balance self.max_drawdown = 0.0 - self.current_step = 0 - self.current_price = 0 - self.last_action = 2 # Default to HOLD (2) - - # Initialize trading client for live trading - self.trading_client = trading_client + self.action_space = 4 # 0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE + self.last_action = 0 # Initialize with HOLD # Initialize features + self._initialize_features() + + # Initialize price predictor + self.price_predictor = None + self.price_predictions = [] + self.predicted_extrema = [] + + def _initialize_features(self): + """Initialize technical indicators and features""" self.features = { 'price': [], 'volume': [], @@ -670,13 +708,6 @@ class TradingEnvironment: 'atr': [] } - # Initialize price predictor - self.price_predictor = None - self.price_predictions = [] - self.extrema_predictions = [] - self.has_predicted_low = False - self.has_predicted_high = False - # Initialize timeframe data structure self.timeframe_data = { '1m': {'prices': [], 'volumes': []}, @@ -692,100 +723,49 @@ class TradingEnvironment: # Add risk factor for curriculum learning self.risk_factor = 1.0 # Default risk factor - def reset(self): - """Reset the environment to initial state""" - self.balance = self.initial_balance - self.position = 'flat' - self.position_size = 0 - self.entry_price = 0 - self.entry_index = 0 - self.stop_loss = 0 - self.take_profit = 0 - 
self.trades = [] - self.win_count = 0 - self.loss_count = 0 - self.episode_pnl = 0.0 - self.peak_balance = self.initial_balance - self.max_drawdown = 0.0 - self.current_step = 0 - - # Keep data but reset current position - if len(self.data) > self.window_size: - self.current_step = self.window_size - self.current_price = self.data[self.current_step]['close'] - - return self.get_state() - - def add_data(self, candle): - """Add a new candle to the data""" - self.data.append(candle) - self._update_features() - self.current_price = candle['close'] - - def _initialize_features(self): - """Initialize technical indicators and features""" - if len(self.data) < 30: - return - - # Convert data to pandas DataFrame for easier calculation - df = pd.DataFrame(self.data) - - # Basic price and volume - self.features['price'] = df['close'].values - self.features['volume'] = df['volume'].values - - # Calculate RSI (14 periods) - delta = df['close'].diff() - gain = delta.where(delta > 0, 0).rolling(window=14).mean() - loss = -delta.where(delta < 0, 0).rolling(window=14).mean() - rs = gain / loss - self.features['rsi'] = 100 - (100 / (1 + rs)).fillna(50).values - - # Calculate MACD - ema12 = df['close'].ewm(span=12, adjust=False).mean() - ema26 = df['close'].ewm(span=26, adjust=False).mean() - macd = ema12 - ema26 - signal = macd.ewm(span=9, adjust=False).mean() - self.features['macd'] = macd.values - self.features['macd_signal'] = signal.values - self.features['macd_hist'] = (macd - signal).values - - # Calculate Bollinger Bands - sma20 = df['close'].rolling(window=20).mean() - std20 = df['close'].rolling(window=20).std() - self.features['bollinger_upper'] = (sma20 + 2 * std20).values - self.features['bollinger_mid'] = sma20.values - self.features['bollinger_lower'] = (sma20 - 2 * std20).values - - # Calculate Stochastic Oscillator - low_14 = df['low'].rolling(window=14).min() - high_14 = df['high'].rolling(window=14).max() - k = 100 * ((df['close'] - low_14) / (high_14 - low_14)) - self.features['stoch_k'] = k.values - self.features['stoch_d'] = k.rolling(window=3).mean().values - - # Calculate EMAs - self.features['ema_9'] = df['close'].ewm(span=9, adjust=False).mean().values - self.features['ema_21'] = df['close'].ewm(span=21, adjust=False).mean().values - - # Calculate ATR - high_low = df['high'] - df['low'] - high_close = (df['high'] - df['close'].shift()).abs() - low_close = (df['low'] - df['close'].shift()).abs() - tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1) - self.features['atr'] = tr.rolling(window=14).mean().fillna(0).values - def _update_features(self): """Update technical indicators with new data""" self._initialize_features() # Recalculate all features - async def fetch_new_data(env, exchange, symbol="ETH/USDT", timeframe="1m", limit=1000): - """Fetch new data for the environment""" - # Call the environment's fetch_initial_data method - return await env.fetch_initial_data(exchange, symbol, timeframe, limit) - + async def fetch_new_data(self, exchange, symbol="ETH/USDT", timeframe="1m", limit=100): + """Fetch new data from the exchange and update the environment""" + from data_cache import ohlcv_cache + + try: + logger.info(f"Fetching new data for {symbol} with timeframe {timeframe}") + + # Use the refactored fetch method + data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit) + + # Update environment with fetched data + if data and len(data) > 0: + self.data = data + self._initialize_features() + logger.info(f"Updated environment with {len(data)} candles") + 
return True + else: + logger.warning("No new data received from exchange or cache") + + # Try to use existing data if available + if self.data and len(self.data) > 0: + logger.info(f"Using existing data ({len(self.data)} candles)") + return True + + return False + except Exception as e: + logger.error(f"Error fetching new data: {e}") + + # Try to use existing data if available + if self.data and len(self.data) > 0: + logger.info(f"Using existing data ({len(self.data)} candles) after fetch error") + return True + + return False + async def fetch_initial_data(self, exchange, symbol="ETH/USDT", timeframe="1m", limit=1000): """Fetch initial historical data for the environment""" + from data_cache import ohlcv_cache + try: logger.info(f"Fetching initial data for {symbol}") @@ -793,27 +773,39 @@ class TradingEnvironment: data = await fetch_ohlcv_data(exchange, symbol, timeframe, limit) # Update environment with fetched data - if data: + if data and len(data) > 0: self.data = data self._initialize_features() logger.info(f"Initialized environment with {len(data)} candles") + return True else: - logger.warning("No initial data received") - - return len(data) > 0 + logger.warning("No initial data received from exchange or cache") + return False except Exception as e: logger.error(f"Error fetching initial data: {e}") return False def step(self, action): """Take an action in the environment and return the next state, reward, and done flag""" + # Check if we have data and if current_step is within bounds + if not self.data or self.current_step >= len(self.data): + logger.error(f"No data available or current_step ({self.current_step}) out of bounds (data length: {len(self.data) if self.data else 0})") + # Return a default state, negative reward, and done=True to end the episode + return self.get_state(), -10, True, {"error": "No data available"} + # Store current price before taking action self.current_price = self.data[self.current_step]['close'] + # Store last action for reward calculation + self.last_action = action + # Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE) if not self.demo and self.trading_client: # Execute real trades in live mode asyncio.create_task(self._execute_live_action(action)) + else: + # Simulate trades in demo mode + self._simulate_action(action) # Calculate reward (simulation still runs in parallel with live trading) reward, info = self.calculate_reward(action) # Unpack the tuple here @@ -830,6 +822,218 @@ class TradingEnvironment: return next_state, reward, done, info + def _simulate_action(self, action): + """Simulate trading action in demo mode""" + try: + if action == 0: # HOLD + # No action needed + pass + + elif action == 1: # BUY/LONG + if self.position == 'flat': + # Open long position + self.position_size = self.calculate_position_size() + self.entry_price = self.current_price + self.entry_index = self.current_step + self.stop_loss = self.current_price * (1 - STOP_LOSS_PERCENT / 100) + self.take_profit = self.current_price * (1 + TAKE_PROFIT_PERCENT / 100) + self.position = 'long' + + logger.info(f"DEMO: Opened LONG position at {self.current_price}") + + elif self.position == 'short': + # Close short position + pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100 + pnl_dollar = pnl_percent / 100 * self.position_size + + # Apply fees + pnl_dollar -= self.calculate_fees(self.position_size) + + # Update balance + self.balance += pnl_dollar + self.total_pnl += pnl_dollar + self.episode_pnl += pnl_dollar + + # Record trade + 
self.trades.append({ + 'type': 'short', + 'entry': self.entry_price, + 'exit': self.current_price, + 'entry_time': self.data[self.entry_index]['timestamp'], + 'exit_time': self.data[self.current_step]['timestamp'], + 'pnl_percent': pnl_percent, + 'pnl_dollar': pnl_dollar, + 'duration': self.current_step - self.entry_index, + 'market_direction': self.get_market_direction(), + 'reason': 'switch_to_long' + }) + + # Update win/loss count + if pnl_dollar > 0: + self.win_count += 1 + else: + self.loss_count += 1 + + logger.info(f"DEMO: Closed SHORT position at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") + + # Open new long position + self.position_size = self.calculate_position_size() + self.entry_price = self.current_price + self.entry_index = self.current_step + self.stop_loss = self.current_price * (1 - STOP_LOSS_PERCENT / 100) + self.take_profit = self.current_price * (1 + TAKE_PROFIT_PERCENT / 100) + self.position = 'long' + + logger.info(f"DEMO: Opened LONG position at {self.current_price}") + + elif action == 2: # SELL/SHORT + if self.position == 'flat': + # Open short position + self.position_size = self.calculate_position_size() + self.entry_price = self.current_price + self.entry_index = self.current_step + self.stop_loss = self.current_price * (1 + STOP_LOSS_PERCENT / 100) + self.take_profit = self.current_price * (1 - TAKE_PROFIT_PERCENT / 100) + self.position = 'short' + + logger.info(f"DEMO: Opened SHORT position at {self.current_price}") + + elif self.position == 'long': + # Close long position + pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100 + pnl_dollar = pnl_percent / 100 * self.position_size + + # Apply fees + pnl_dollar -= self.calculate_fees(self.position_size) + + # Update balance + self.balance += pnl_dollar + self.total_pnl += pnl_dollar + self.episode_pnl += pnl_dollar + + # Record trade + self.trades.append({ + 'type': 'long', + 'entry': self.entry_price, + 'exit': self.current_price, + 'entry_time': self.data[self.entry_index]['timestamp'], + 'exit_time': self.data[self.current_step]['timestamp'], + 'pnl_percent': pnl_percent, + 'pnl_dollar': pnl_dollar, + 'duration': self.current_step - self.entry_index, + 'market_direction': self.get_market_direction(), + 'reason': 'switch_to_short' + }) + + # Update win/loss count + if pnl_dollar > 0: + self.win_count += 1 + else: + self.loss_count += 1 + + logger.info(f"DEMO: Closed LONG position at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") + + # Open new short position + self.position_size = self.calculate_position_size() + self.entry_price = self.current_price + self.entry_index = self.current_step + self.stop_loss = self.current_price * (1 + STOP_LOSS_PERCENT / 100) + self.take_profit = self.current_price * (1 - TAKE_PROFIT_PERCENT / 100) + self.position = 'short' + + logger.info(f"DEMO: Opened SHORT position at {self.current_price}") + + elif action == 3: # CLOSE + if self.position == 'long': + # Close long position + pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100 + pnl_dollar = pnl_percent / 100 * self.position_size + + # Apply fees + pnl_dollar -= self.calculate_fees(self.position_size) + + # Update balance + self.balance += pnl_dollar + self.total_pnl += pnl_dollar + self.episode_pnl += pnl_dollar + + # Record trade + self.trades.append({ + 'type': 'long', + 'entry': self.entry_price, + 'exit': self.current_price, + 'entry_time': self.data[self.entry_index]['timestamp'], + 'exit_time': 
self.data[self.current_step]['timestamp'], + 'pnl_percent': pnl_percent, + 'pnl_dollar': pnl_dollar, + 'duration': self.current_step - self.entry_index, + 'market_direction': self.get_market_direction(), + 'reason': 'manual_close' + }) + + # Update win/loss count + if pnl_dollar > 0: + self.win_count += 1 + else: + self.loss_count += 1 + + logger.info(f"DEMO: Closed LONG position at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") + + # Reset position + self.position = 'flat' + self.entry_price = 0 + self.entry_index = 0 + self.position_size = 0 + self.stop_loss = 0 + self.take_profit = 0 + + elif self.position == 'short': + # Close short position + pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100 + pnl_dollar = pnl_percent / 100 * self.position_size + + # Apply fees + pnl_dollar -= self.calculate_fees(self.position_size) + + # Update balance + self.balance += pnl_dollar + self.total_pnl += pnl_dollar + self.episode_pnl += pnl_dollar + + # Record trade + self.trades.append({ + 'type': 'short', + 'entry': self.entry_price, + 'exit': self.current_price, + 'entry_time': self.data[self.entry_index]['timestamp'], + 'exit_time': self.data[self.current_step]['timestamp'], + 'pnl_percent': pnl_percent, + 'pnl_dollar': pnl_dollar, + 'duration': self.current_step - self.entry_index, + 'market_direction': self.get_market_direction(), + 'reason': 'manual_close' + }) + + # Update win/loss count + if pnl_dollar > 0: + self.win_count += 1 + else: + self.loss_count += 1 + + logger.info(f"DEMO: Closed SHORT position at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") + + # Reset position + self.position = 'flat' + self.entry_price = 0 + self.entry_index = 0 + self.position_size = 0 + self.stop_loss = 0 + self.take_profit = 0 + + except Exception as e: + logger.error(f"Error simulating action: {e}") + logger.error(traceback.format_exc()) + async def _execute_live_action(self, action): """Execute live trading action using the trading client""" if not self.trading_client: @@ -1131,6 +1335,18 @@ class TradingEnvironment: # Price features (normalize recent prices by the latest price) latest_price = self.features['price'][-1] + + # Price change percentages over different timeframes + price_changes = [] + for period in [1, 3, 5, 10, 20]: + if len(self.features['price']) > period: + change = (self.features['price'][-1] / self.features['price'][-(period+1)] - 1.0) * 100 + price_changes.append(change) + else: + price_changes.append(0.0) + state_components.append(np.array(price_changes, dtype=np.float32) / 5.0) # Normalize by typical 5% move + + # Recent price pattern (last 10 prices normalized) price_features = np.array(self.features['price'][-10:], dtype=np.float32) / latest_price - 1.0 state_components.append(price_features) @@ -1164,50 +1380,52 @@ class TradingEnvironment: bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)] state_components.append(np.array(bb_pos)) + # Bollinger band width (volatility indicator) + bb_width = [(u - l) / m for u, l, m in zip(bb_upper, bb_lower, bb_mid)] + state_components.append(np.array(bb_width) / 0.05) # Normalize by typical width + # Stochastic oscillator state_components.append(np.array(self.features['stoch_k'][-3:]) / 100.0) state_components.append(np.array(self.features['stoch_d'][-3:]) / 100.0) - # Add predicted prices (if available) - if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0: - # Normalize predictions relative to current price - 
pred_norm = np.array(self.predicted_prices[:5]) / latest_price - 1.0 - state_components.append(pred_norm) + # EMA trend indicators + if len(self.features['ema_9']) > 3 and len(self.features['ema_21']) > 3: + ema_9 = np.array(self.features['ema_9'][-3:]) + ema_21 = np.array(self.features['ema_21'][-3:]) + # EMA crossover indicator (-1 to 1) + ema_cross = (ema_9 - ema_21) / latest_price + state_components.append(ema_cross * 100) else: - # Add zeros if no predictions - state_components.append(np.zeros(5)) + state_components.append(np.zeros(3)) - # Add predicted extrema probabilities (if available) - if hasattr(self, 'predicted_extrema') and len(self.predicted_extrema) > 0: - # Flatten the extrema predictions [p_low1, p_high1, p_low2, p_high2, ...] - extrema_probs = self.predicted_extrema.flatten() - state_components.append(extrema_probs) + # ATR volatility indicator (normalized) + if len(self.features['atr']) > 0: + atr = np.array(self.features['atr'][-3:]) + atr_norm = atr / latest_price * 100 # ATR as percentage of price + state_components.append(atr_norm / 2.0) # Normalize by typical 2% ATR else: - # Add zeros if no extrema predictions - state_components.append(np.zeros(10)) # 5 candles * 2 (low/high) + state_components.append(np.zeros(3)) - # Add extrema signals (if available) - if hasattr(self, 'optimal_signals') and len(self.optimal_signals) > 0: - # Get recent signals - idx = len(self.optimal_signals) - 5 - if idx < 0: - idx = 0 - recent_signals = self.optimal_signals[idx:idx+5] - # Pad if needed - if len(recent_signals) < 5: - recent_signals = np.pad(recent_signals, (0, 5 - len(recent_signals)), 'constant') - state_components.append(recent_signals) + # Add recent volatility + recent_volatility = self.get_recent_volatility() * 100 # Convert to percentage + state_components.append(np.array([recent_volatility / 2.0])) # Normalize by typical 2% volatility + + # Add market direction indicator + market_direction = self.get_market_direction() + state_components.append(np.array([market_direction])) + + # Add time-based features (hour of day, day of week) + if len(self.data) > 0 and 'timestamp' in self.data[self.current_step]: + timestamp = self.data[self.current_step]['timestamp'] + # Hour of day (0-23) normalized to 0-1 + hour = timestamp.hour / 24.0 + # Day of week (0-6) normalized to 0-1 + day = timestamp.weekday() / 7.0 + # Is market typically high volatility time? 
(e.g., market open/close) + high_vol_time = 1.0 if (8 <= timestamp.hour <= 10 or 14 <= timestamp.hour <= 16) else 0.0 + state_components.append(np.array([hour, day, high_vol_time])) else: - # Add zeros if no signals - state_components.append(np.zeros(5)) - - # Add predicted extrema flags - extrema_flags = np.zeros(2) - if hasattr(self, 'has_predicted_low') and self.has_predicted_low: - extrema_flags[0] = 1.0 # Predicted low - if hasattr(self, 'has_predicted_high') and self.has_predicted_high: - extrema_flags[1] = 1.0 # Predicted high - state_components.append(extrema_flags) + state_components.append(np.zeros(3)) # Position info position_info = np.zeros(5) @@ -1249,86 +1467,66 @@ class TradingEnvironment: reward = 0 info = {} - # Base reward components - pnl_reward = 0 - timing_reward = 0 - risk_reward = 0 - prediction_reward = 0 - # Get current market state current_price = self.current_price - recent_volatility = self.get_recent_volatility() - market_direction = self.get_market_direction() - # Calculate PnL-based reward - if self.position != 'flat' and self.last_action != action: - # Position is being closed + # Only give significant rewards when a position is closed or when hitting TP/SL + if self.position != 'flat' and (action == 3 or self.last_action != action): + # Calculate direct PnL reward if self.position == 'long': pnl = current_price - self.entry_price + pnl_percent = pnl / self.entry_price * 100 + pnl_dollar = pnl_percent / 100 * self.position_size else: # short pnl = self.entry_price - current_price - - # Normalize PnL by recent volatility to make rewards more consistent - if recent_volatility > 0: - normalized_pnl = pnl / recent_volatility + pnl_percent = pnl / self.entry_price * 100 + pnl_dollar = pnl_percent / 100 * self.position_size + + # Subtract trading fees + pnl_dollar -= self.calculate_fees(self.position_size) + + # Asymmetric rewards (penalize losses more) + if pnl_dollar < 0: + reward = pnl_dollar * 150 # Higher penalty for losses else: - normalized_pnl = pnl * 100 # Fallback if volatility is zero + reward = pnl_dollar * 100 # Reward for gains - pnl_reward = normalized_pnl * 100 # Scale for better learning signal - - # Add timing reward based on entry quality - entry_quality = self.evaluate_entry_quality(self.position) - timing_reward = entry_quality * 50 # Scale timing reward - - # Add risk-adjusted reward component - position_duration = self.current_step - self.entry_index - if position_duration > 0: - risk_adjusted_return = pnl / (recent_volatility * np.sqrt(position_duration)) - risk_reward = risk_adjusted_return * 30 # Scale risk reward - - # Add prediction accuracy reward if we have price predictions - if hasattr(self, 'price_predictions') and len(self.price_predictions) > 0: - # Compare the most recent prediction with actual price movement - last_prediction = self.price_predictions[-1] - actual_movement = current_price - self.features['price'][-2] if len(self.features['price']) > 1 else 0 - - # Reward for correct direction prediction - if (last_prediction > 0 and actual_movement > 0) or (last_prediction < 0 and actual_movement < 0): - prediction_reward = 10 - else: - prediction_reward = -5 - - # Action-specific rewards - if action == 0: # Buy - if market_direction > 0.5: # Strong uptrend - reward += 5 - elif market_direction < -0.5: # Strong downtrend - reward -= 10 # Penalize buying in downtrend + # Add a small bonus for closing positions in the right market direction + market_direction = self.get_market_direction() + if (self.position == 'long' and 
market_direction < -0.3) or (self.position == 'short' and market_direction > 0.3): + reward += 5 # Bonus for closing a position that's against the market direction - elif action == 1: # Sell - if market_direction < -0.5: # Strong downtrend - reward += 5 - elif market_direction > 0.5: # Strong uptrend - reward -= 10 # Penalize selling in uptrend - - elif action == 2: # Hold - # Small positive reward for holding during low volatility - if recent_volatility < 0.001: + # Track the reward components for analysis + info['pnl_reward'] = reward + else: + # Small negative reward for each action to discourage excessive trading + reward -= 0.5 + info['action_penalty'] = -0.5 + + # Small reward for proper directional trading + market_direction = self.get_market_direction() + if action == 1 and market_direction > 0.3: # BUY in uptrend + reward += 2 + info['direction_reward'] = 2 + elif action == 2 and market_direction < -0.3: # SELL in downtrend + reward += 2 + info['direction_reward'] = 2 + elif action == 0 and abs(market_direction) < 0.2: # HOLD in sideways reward += 1 + info['direction_reward'] = 1 - # Combine all reward components - reward += pnl_reward + timing_reward + risk_reward + prediction_reward + # Add volatility-based penalty for opening new positions in high volatility + if (action == 1 or action == 2) and self.position == 'flat': + volatility = self.get_recent_volatility() + if volatility > 0.015: # High volatility threshold + vol_penalty = -3 * volatility * 100 + reward += vol_penalty + info['volatility_penalty'] = vol_penalty - # Log components for analysis - info = { - 'pnl_reward': pnl_reward, - 'timing_reward': timing_reward, - 'risk_reward': risk_reward, - 'prediction_reward': prediction_reward, - 'total_reward': reward - } + # Keep track of total reward + info['total_reward'] = reward - return reward, info # Return both the reward and the info dictionary + return reward, info def evaluate_entry_quality(self, position_type): """Evaluate how good the entry timing was based on local extrema.""" @@ -1472,48 +1670,35 @@ class TradingEnvironment: return analysis def initialize_price_predictor(self, device="cpu"): - """Initialize the price prediction model with multi-timeframe support""" - # Create the price prediction model - self.price_predictor = PricePredictionModel( - input_size=2, # Price and volume - hidden_size=256, - output_size=5, # Predict 5 future candles - num_layers=3, - num_timeframes=3 # Support for multiple timeframes - ).to(device) - - # Check if we have price and volume data - if len(self.features['price']) == 0 or len(self.features['volume']) == 0: - logger.warning("No price or volume data available for price predictor initialization") - return self.price_predictor - - # Initialize timeframes data - self.timeframe_data = { - '1m': {'prices': self.features['price'].copy(), 'volumes': self.features['volume'].copy()}, - '5m': [], - '15m': [] - } - - # Create resampled data for higher timeframes - if len(self.features['price']) >= 15: - # Create 5-minute data (resample every 5 candles) - self.timeframe_data['5m'] = { - 'prices': [self.features['price'][i] for i in range(0, len(self.features['price']), 5)], - 'volumes': [sum(self.features['volume'][i:i+5]) for i in range(0, len(self.features['volume']), 5) if i+5 <= len(self.features['volume'])] - } + """Initialize the price prediction model""" + # Check if we have enough data + if not self.data or len(self.data) < 30: + logger.warning("Not enough data to initialize price predictor (need at least 30 candles)") + return 
False - # Create 15-minute data (resample every 15 candles) - self.timeframe_data['15m'] = { - 'prices': [self.features['price'][i] for i in range(0, len(self.features['price']), 15)], - 'volumes': [sum(self.features['volume'][i:i+15]) for i in range(0, len(self.features['volume']), 15) if i+15 <= len(self.features['volume'])] - } + # Extract price and volume history + price_history = np.array([candle['close'] for candle in self.data[-100:]]) + volume_history = np.array([candle['volume'] for candle in self.data[-100:]]) - # Initialize predictions - self.price_predictions = [] - self.extrema_predictions = [] + if len(price_history) == 0 or len(volume_history) == 0: + logger.warning("No price or volume data available for price predictor initialization") + return False + + # Initialize price predictor model + self.price_predictor = PricePredictionModel(input_size=2, hidden_size=256, output_size=5, num_layers=3) + self.price_predictor.to(device) - logger.info(f"Price predictor initialized with {sum(p.numel() for p in self.price_predictor.parameters())} parameters") - return self.price_predictor + # Initialize optimizer + self.price_optimizer = optim.Adam(self.price_predictor.parameters(), lr=0.001) + + # Initialize arrays for predicted prices and extrema + self.predicted_prices = np.array([]) + self.predicted_extrema = np.array([]) + + # Threshold for extrema prediction confidence + self.extrema_threshold = 0.7 # Threshold for extrema prediction confidence + + return True def train_price_predictor(self): """Train the price prediction model on historical data with multi-timeframe support""" @@ -1802,34 +1987,70 @@ class TradingEnvironment: self.has_predicted_high = has_predicted_high def calculate_position_size(self): - """Calculate position size based on current balance, volatility and risk parameters""" - # Base risk percentage (adjust based on volatility) + """Calculate position size based on current balance, volatility, win rate, and risk parameters using Kelly criterion""" + # Get recent volatility volatility = self.get_recent_volatility() - # Reduce risk during high volatility - base_risk = 5.0 # Base risk percentage - adjusted_risk = base_risk / (1 + volatility * 5) # Reduce risk as volatility increases - adjusted_risk = max(1.0, min(adjusted_risk, base_risk)) # Cap between 1% and base_risk + # Calculate win rate based on recent trades + recent_trades = self.trades[-20:] if len(self.trades) > 0 else [] + win_count = sum(1 for trade in recent_trades if trade.get('pnl_dollar', 0) > 0) + total_trades = len(recent_trades) - # Calculate position size with leverage - position_size = self.balance * (adjusted_risk / 100) * MAX_LEVERAGE + # Calculate win rate (default to 0.5 if not enough trades) + win_rate = win_count / total_trades if total_trades >= 5 else 0.5 - # Apply a safety factor to avoid liquidation - safety_factor = 0.8 - position_size *= safety_factor + # Calculate average win and loss sizes + wins = [trade.get('pnl_dollar', 0) for trade in recent_trades if trade.get('pnl_dollar', 0) > 0] + losses = [abs(trade.get('pnl_dollar', 0)) for trade in recent_trades if trade.get('pnl_dollar', 0) < 0] + + avg_win = sum(wins) / len(wins) if len(wins) > 0 else 1.0 + avg_loss = sum(losses) / len(losses) if len(losses) > 0 else 1.0 + + # Calculate Kelly fraction + if avg_loss > 0: + w = win_rate + r = avg_win / avg_loss # Profit/loss ratio + kelly_fraction = (w * (r + 1) - 1) / r if r > 0 else 0 + else: + kelly_fraction = 0.02 # Default to 2% if no loss data available + + # Apply safety factor and 
constraints + kelly_fraction = max(0.01, min(kelly_fraction, 0.2)) # Cap between 1% and 20% + + # Reduce position size during high volatility + volatility_factor = 1.0 / (1.0 + volatility * 10) + kelly_fraction *= volatility_factor + + # Apply drawdown protection - reduce position size after losses + if len(self.trades) >= 3: + recent_results = [trade.get('pnl_dollar', 0) for trade in self.trades[-3:]] + consecutive_losses = sum(1 for pnl in recent_results if pnl < 0) + if consecutive_losses >= 2: + # Reduce position size by 50% after 2 consecutive losses + kelly_fraction *= 0.5 + + # Calculate position size with Kelly and leverage + kelly_position = self.balance * kelly_fraction * MAX_LEVERAGE # Ensure minimum position size min_position = 10.0 # Minimum position size in USD - position_size = max(position_size, min(min_position, self.balance * 0.5)) + position_size = max(kelly_position, min(min_position, self.balance * 0.5)) # Ensure position size doesn't exceed balance * leverage max_position = self.balance * MAX_LEVERAGE position_size = min(position_size, max_position) - # Adjust stop loss based on volatility + # Adjust stop loss and take profit based on volatility and win rate global STOP_LOSS_PERCENT, TAKE_PROFIT_PERCENT - STOP_LOSS_PERCENT = 0.5 * (1 + volatility) # Wider stop loss during high volatility - TAKE_PROFIT_PERCENT = 1.5 * (1 + volatility * 0.5) # Higher take profit during high volatility + + # Wider stops in high volatility, tighter stops in low volatility + STOP_LOSS_PERCENT = 0.5 * (1 + volatility * 10) + + # Adjust take profit based on win rate and volatility + # Higher win rate = we can afford tighter take profits + # Lower win rate = need higher take profits to compensate + risk_reward_ratio = 1.5 * (1 / (win_rate + 0.2)) # Higher ratio for lower win rates + TAKE_PROFIT_PERCENT = STOP_LOSS_PERCENT * risk_reward_ratio # Apply risk factor from curriculum learning position_size *= self.risk_factor @@ -1846,6 +2067,52 @@ class TradingEnvironment: return fee + def reset(self): + """Reset the environment to initial state""" + self.balance = self.initial_balance + self.position = 'flat' + self.position_size = 0 + self.entry_price = 0 + self.entry_index = 0 + self.stop_loss = 0 + self.take_profit = 0 + self.trades = [] + self.win_count = 0 + self.loss_count = 0 + self.episode_pnl = 0.0 + self.peak_balance = self.initial_balance + self.max_drawdown = 0.0 + self.current_step = 0 + self.last_action = 0 # Reset to HOLD + + # Keep data but reset current position + if len(self.data) > self.window_size: + self.current_step = self.window_size + self.current_price = self.data[self.current_step]['close'] + + return self.get_state() + + def add_data(self, candle): + """Add a new candle to the data""" + from data_cache import ohlcv_cache + + # Add candle to data + self.data.append(candle) + + # Update features + self._update_features() + self.current_price = candle['close'] + + # Cache the new candle + try: + # Use ETH/USDT as default symbol and 1m as default timeframe + # In a real implementation, you would track the actual symbol and timeframe + ohlcv_cache.append(candle, "ETH/USDT", "1m") + except Exception as e: + logger.error(f"Error caching new candle: {e}") + + return True + # Ensure GPU usage if available def get_device(device_preference='auto'): """Get the device to use (GPU or CPU) based on preference and availability""" @@ -1880,22 +2147,23 @@ class Agent: self.target_net.load_state_dict(self.policy_net.state_dict()) self.target_net.eval() # Set target network to evaluation 
mode - # Optimizer - self.optimizer = optim.Adam(self.policy_net.parameters(), lr=0.0005) + # Optimizer with lower learning rate for stability + self.optimizer = optim.Adam(self.policy_net.parameters(), lr=0.0001) + self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=5000, gamma=0.5) - # Replay memory with prioritized experience replay - self.memory = ReplayMemory(capacity=100000, alpha=0.6, beta=0.4, n_step=3, gamma=0.99) + # Replay memory with prioritized experience replay - increased capacity + self.memory = ReplayMemory(capacity=200000, alpha=0.6, beta=0.4, n_step=3, gamma=0.99) - # Exploration parameters + # Exploration parameters - slower decay for more exploration self.eps_start = 1.0 self.eps_end = 0.05 - self.eps_decay = 0.9995 + self.eps_decay = 0.9999 # Slower decay self.epsilon = self.eps_start # Learning parameters self.gamma = 0.99 - self.batch_size = 64 - self.target_update = 1000 # Update target network every N steps + self.batch_size = 128 # Increased batch size + self.target_update = 500 # Update target network more frequently # Training tracking self.steps_done = 0 @@ -1904,6 +2172,11 @@ class Agent: # LSTM hidden state self.hidden = None + + # Tracking performance for adaptive learning + self.recent_losses = deque(maxlen=100) + self.training_iterations = 0 + self.recent_rewards = deque(maxlen=100) def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8): """Expand the model to handle more features or increase capacity""" @@ -2039,14 +2312,34 @@ class Agent: torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0) self.optimizer.step() + self.scheduler.step() # Update target network periodically if self.steps_done % self.target_update == 0: self.update_target_network() + # Track loss for adaptive learning + loss_item = loss.item() + self.recent_losses.append(loss_item) + self.training_iterations += 1 + + # Every 1000 training iterations, analyze performance and adjust hyperparameters + if self.training_iterations % 1000 == 0 and len(self.recent_losses) > 50: + avg_loss = sum(self.recent_losses) / len(self.recent_losses) + # If loss is stable and low, we can accelerate learning + if avg_loss < 0.1 and np.std(list(self.recent_losses)) < 0.05: + # Increase learning rate slightly + for param_group in self.optimizer.param_groups: + param_group['lr'] = min(param_group['lr'] * 1.2, 0.001) + # If loss is unstable or high, slow down learning + elif avg_loss > 1.0 or np.std(list(self.recent_losses)) > 0.5: + # Decrease learning rate + for param_group in self.optimizer.param_groups: + param_group['lr'] = max(param_group['lr'] * 0.5, 0.00001) + self.steps_done += 1 - return loss.item() + return loss_item def update_target_network(self): self.target_net.load_state_dict(self.policy_net.state_dict()) @@ -2055,7 +2348,15 @@ class Agent: """Save model to path""" try: # Create directory if it doesn't exist - os.makedirs(os.path.dirname(path), exist_ok=True) + directory = os.path.dirname(path) + if MODEL_DIR in directory: + # Already using our directory + os.makedirs(directory, exist_ok=True) + else: + # Replace default models dir with our new one + new_directory = os.path.join(MODEL_DIR, os.path.basename(directory)) if directory else MODEL_DIR + os.makedirs(new_directory, exist_ok=True) + path = os.path.join(new_directory, os.path.basename(path)) # Save model state torch.save({ @@ -2071,40 +2372,51 @@ class Agent: logger.error(f"Traceback: {traceback.format_exc()}") def load(self, path): - """Load model from path with 
proper error handling for PyTorch 2.6+""" + """Load model from path""" try: + # Check if path exists, if not try with MODEL_DIR + if not os.path.exists(path): + alt_path = os.path.join(MODEL_DIR, os.path.basename(path)) + if os.path.exists(alt_path): + path = alt_path + else: + logger.warning(f"Could not find model at {path} or {alt_path}") + return False + logger.info(f"Loading model from {path}") - # First try with weights_only=True (safer) + # Load checkpoint + checkpoint = torch.load(path, map_location=self.device) + try: - # Add numpy scalar to safe globals first - import torch.serialization - torch.serialization.add_safe_globals(['numpy._core.multiarray.scalar']) - - # Load the model - checkpoint = torch.load(path, map_location=self.device) + # Try to load with strict matching self.policy_net.load_state_dict(checkpoint['policy_net']) self.target_net.load_state_dict(checkpoint['target_net']) self.optimizer.load_state_dict(checkpoint['optimizer']) - self.steps_done = checkpoint.get('steps_done', 0) - logger.info(f"Model loaded successfully with weights_only=True") - + self.steps_done = checkpoint['steps_done'] + return True except Exception as e: logger.warning(f"Could not load with weights_only=True: {e}") logger.warning("Attempting to load with weights_only=False (less secure)") + print() - # Fall back to weights_only=False (less secure but more compatible) - checkpoint = torch.load(path, map_location=self.device, weights_only=False) - self.policy_net.load_state_dict(checkpoint['policy_net']) - self.target_net.load_state_dict(checkpoint['target_net']) - self.optimizer.load_state_dict(checkpoint['optimizer']) - self.steps_done = checkpoint.get('steps_done', 0) - logger.info(f"Model loaded successfully with weights_only=False") + try: + # Load policy network only, ignoring optimizer + self.policy_net.load_state_dict(checkpoint['policy_net']) + self.target_net.load_state_dict(checkpoint['target_net']) + return True + except Exception as e: + logger.error(f"Failed to load model: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") + return False + except FileNotFoundError: + logger.warning(f"Model file not found: {path}") + return False except Exception as e: - logger.error(f"Failed to load model: {e}") + logger.error(f"Error loading model: {e}") logger.error(f"Traceback: {traceback.format_exc()}") - raise + return False async def get_live_prices(symbol="ETH/USDT", timeframe="1m"): """Get live price data using websockets""" @@ -2146,160 +2458,167 @@ async def get_live_prices(symbol="ETH/USDT", timeframe="1m"): async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None, args=None, continuous=False): """Train the agent on the environment""" - logger.info(f"Starting training for {num_episodes} episodes...") - logger.info(f"Starting training on device: {agent.device}") - - # Create TensorBoard writer if not in continuous mode - if not continuous: - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - agent.writer = SummaryWriter(f'runs/trading_agent_{timestamp}') - - # Training statistics + start_time = time.time() episode_rewards = [] episode_lengths = [] balances = [] win_rates = [] - episode_pnls = [] cumulative_pnl = [] + episode_pnls = [] drawdowns = [] prediction_accuracies = [] - try: - for episode in range(num_episodes): - # Reset environment - state = env.reset() - - # Initialize episode variables - episode_reward = 0 - step = 0 - done = False - - # Initialize price predictor if not already - if not hasattr(env, 
'price_predictor') or env.price_predictor is None: - env.initialize_price_predictor(device=agent.device) - - # Train price predictor - train_result = env.train_price_predictor() - prediction_loss = 0.0 - if isinstance(train_result, (float, int)): - prediction_loss = train_result - logger.info(f"Price predictor training loss: {prediction_loss:.6f}") - - # Update price predictions - env.update_price_predictions() - - # Calculate prediction accuracy if we have predictions - prediction_accuracy = 0.0 - if hasattr(env, 'price_predictions') and len(env.price_predictions) > 0 and len(env.features['price']) > 5: - # Compare the last prediction with actual prices - predicted_direction = np.sign(np.diff(env.price_predictions[:2])) - actual_direction = np.sign(np.diff(env.features['price'][-5:])) - - # Calculate accuracy as percentage of correct direction predictions - if len(predicted_direction) > 0 and len(actual_direction) > 0: - correct_directions = sum(1 for p, a in zip(predicted_direction, actual_direction) if p == a) - prediction_accuracy = correct_directions / len(predicted_direction) * 100 - - # Episode loop - while not done and step < max_steps_per_episode: - # Select action - action = agent.select_action(state) - - # Take action - next_state, reward, done, info = env.step(action) - - # Store transition in replay memory - agent.memory.push(state, action, reward, next_state, done) - - # Learn from experience - loss = agent.learn() - - # Update state and statistics - state = next_state - episode_reward += reward - step += 1 - - # Fetch new data in continuous mode - if continuous and step % 10 == 0 and exchange is not None: - await env.fetch_new_data(exchange, symbol=args.symbol if args else "ETH/USDT") - - # End of episode - episode_rewards.append(episode_reward) - episode_lengths.append(step) - balances.append(env.balance) - win_rate = env.win_count / (env.win_count + env.loss_count) * 100 if (env.win_count + env.loss_count) > 0 else 0 - win_rates.append(win_rate) - episode_pnls.append(env.episode_pnl) - cumulative_pnl.append(env.total_pnl) - drawdowns.append(env.max_drawdown) - prediction_accuracies.append(prediction_accuracy) - - # Log episode statistics - logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, Win Rate={win_rate:.1f}%, " - f"Trades={env.win_count + env.loss_count}, Episode PnL=${env.episode_pnl:.2f}, " - f"Total PnL=${env.total_pnl:.2f}, Max Drawdown={env.max_drawdown*100:.1f}%, " - f"Pred Accuracy={prediction_accuracy:.1f}%") - - # Log to TensorBoard - if hasattr(agent, 'writer'): - agent.writer.add_scalar('Reward/episode', episode_reward, episode) - agent.writer.add_scalar('Steps/episode', step, episode) - agent.writer.add_scalar('Balance/final', env.balance, episode) - agent.writer.add_scalar('WinRate/episode', win_rate, episode) - agent.writer.add_scalar('PnL/episode', env.episode_pnl, episode) - agent.writer.add_scalar('PnL/cumulative', env.total_pnl, episode) - agent.writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode) - agent.writer.add_scalar('PredictionLoss', prediction_loss, episode) - agent.writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode) - - # Visualize training results periodically - if episode % 10 == 0 and not continuous: - visualize_training_results(env, agent, episode) - - # Save model periodically - if episode % 50 == 0 and not continuous: - agent.save(f'models/trading_agent_{episode}.pt') + # Set up TensorBoard logging + writer = 
SummaryWriter(log_dir=f"{MODEL_DIR}/tensorboard_logs") + + # Track best models + best_reward = float('-inf') + best_profit = float('-inf') + best_win_rate = 0 + + # For curriculum learning + risk_factor = 0.5 # Start with reduced risk + + # Initialize for periodic evaluation and model saving + eval_interval = 5 # Evaluate every 5 episodes + save_interval = 10 # Save model every 10 episodes + + for episode in range(num_episodes): + # Reset the environment + state = env.reset() + episode_reward = 0 + step = 0 + prediction_accuracy = 0 - # Save final model - if not continuous: - agent.save('models/trading_agent_final.pt') - agent.writer.close() + # For curriculum learning - gradually increase risk factor + env.risk_factor = min(1.0, risk_factor + episode * 0.01) - # Save training statistics to CSV - stats_df = pd.DataFrame({ - 'episode_rewards': episode_rewards, - 'episode_lengths': episode_lengths, - 'balances': balances, - 'win_rates': win_rates, - 'episode_pnls': episode_pnls, - 'cumulative_pnl': cumulative_pnl, - 'drawdowns': drawdowns, - 'prediction_accuracy': prediction_accuracies - }) - stats_df.to_csv('training_stats.csv', index=False) + # Optionally refresh data at beginning of episode + if continuous and args and args.refresh_data: + logger.info(f"Refreshing market data with timeframe {args.timeframe}...") + await env.fetch_initial_data(exchange, symbol=args.symbol if args else "ETH/USDT", + timeframe=args.timeframe if args else "1m", + limit=1000) - # Plot training results - if not continuous: - plot_training_results(stats_df) + # Train price predictor if enough data is available + if hasattr(env, 'train_price_predictor'): + if len(env.features['price']) > 30: # Need at least 30 candles + env.train_price_predictor() + else: + logger.warning("Not enough data to train price predictor (need at least 30 candles)") - # Return final statistics for the last episode - if len(episode_rewards) > 0: - return ( - episode_rewards[-1], - episode_lengths[-1], - balances[-1], - win_rates[-1], - episode_pnls[-1], - cumulative_pnl[-1], - drawdowns[-1] - ) + # Log episode info + if continuous: + logger.info(f"Continuous training - Episode {episode+1}") else: - return 0, 0, env.initial_balance, 0, 0, 0, 0 + logger.info(f"Episode {episode+1}/{num_episodes}") - except Exception as e: - logger.error(f"Training failed: {e}") - logger.error(f"Traceback: {traceback.format_exc()}") - return 0, 0, env.initial_balance, 0, 0, 0, 0 + # Reset hidden state + agent.hidden = None + + done = False + + # Episode loop + while not done and step < max_steps_per_episode: + # Select action + action = agent.select_action(state) + + # Take action + next_state, reward, done, info = env.step(action) + + # Store transition in replay memory + agent.memory.push(state, action, reward, next_state, done) + + # Learn from experience + loss = agent.learn() + + # Update state and statistics + state = next_state + episode_reward += reward + step += 1 + + # Fetch new data in continuous mode + if continuous and step % 10 == 0 and exchange is not None: + await env.fetch_new_data(exchange, symbol=args.symbol if args else "ETH/USDT") + + # End of episode + episode_rewards.append(episode_reward) + episode_lengths.append(step) + balances.append(env.balance) + win_rate = env.win_count / (env.win_count + env.loss_count) * 100 if (env.win_count + env.loss_count) > 0 else 0 + win_rates.append(win_rate) + episode_pnls.append(env.episode_pnl) + cumulative_pnl.append(env.total_pnl) + drawdowns.append(env.max_drawdown) + 
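# NOTE (illustrative sketch, not part of the diff): the curriculum-learning risk
# schedule used in this training loop, pulled out as one function. It combines the
# two mechanisms around it -- the linear ramp applied at the start of each episode
# (risk_factor + episode * 0.01, capped at 1.0) and the multiplicative +5% / -5%
# adjustment made at the end of an episode depending on whether the balance
# finished above the starting balance. In the actual loop these are applied at
# different points; combining them here is only for illustration.
def next_risk_factor(risk_factor: float, next_episode: int,
                     balance: float, initial_balance: float) -> float:
    # End-of-episode performance adjustment
    if balance > initial_balance:
        risk_factor = min(1.0, risk_factor * 1.05)   # doing well -> allow more risk
    else:
        risk_factor = max(0.2, risk_factor * 0.95)   # doing poorly -> cut risk
    # Start-of-next-episode ramp
    return min(1.0, risk_factor + next_episode * 0.01)

# Example: next_risk_factor(0.5, 4, 95.0, 100.0) -> 0.515  (0.5 * 0.95 + 4 * 0.01)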
prediction_accuracies.append(prediction_accuracy) + + # Log episode statistics + logger.info(f"Episode {episode+1}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, Win Rate={win_rate:.1f}%, " + f"Trades={env.win_count + env.loss_count}, Episode PnL=${env.episode_pnl:.2f}, " + f"Total PnL=${env.total_pnl:.2f}") + + # Log to TensorBoard + writer.add_scalar('Reward/episode', episode_reward, episode) + writer.add_scalar('Balance/episode', env.balance, episode) + writer.add_scalar('WinRate/episode', win_rate, episode) + writer.add_scalar('Trades/episode', env.win_count + env.loss_count, episode) + writer.add_scalar('PnL/episode', env.episode_pnl, episode) + writer.add_scalar('PnL/cumulative', env.total_pnl, episode) + writer.add_scalar('MaxDrawdown/episode', env.max_drawdown, episode) + writer.add_scalar('Loss/episode', loss if loss else 0, episode) + + # Save models periodically + if (episode + 1) % save_interval == 0 or episode == num_episodes - 1: + model_path = f"{MODEL_DIR}/trading_agent_{episode+1}.pt" + agent.save(model_path) + + # Visualize trading and save images + visualize_training_results(env, agent, episode) + + # Save best models based on different metrics + if episode_reward > best_reward: + best_reward = episode_reward + agent.save(f"{MODEL_DIR}/trading_agent_best_reward.pt") + + if env.balance > best_profit: + best_profit = env.balance + agent.save(f"{MODEL_DIR}/trading_agent_best_pnl.pt") + + if win_rate > best_win_rate and env.win_count + env.loss_count >= 10: + best_win_rate = win_rate + agent.save(f"{MODEL_DIR}/trading_agent_best_winrate.pt") + + # For continuous training, save a continuous model + if continuous: + agent.save(f"{MODEL_DIR}/trading_agent_continuous_{episode+1}.pt") + logger.info(f"Saved continuous model: {MODEL_DIR}/trading_agent_continuous_{episode+1}.pt") + + # Update risk factor for curriculum learning + if env.balance > INITIAL_BALANCE: + # Doing well, increase risk + risk_factor = min(1.0, risk_factor * 1.05) + else: + # Doing poorly, reduce risk + risk_factor = max(0.2, risk_factor * 0.95) + + # End of training + elapsed_time = time.time() - start_time + logger.info(f"Training completed in {elapsed_time:.2f}s") + + # Plot the results + plot_training_results({ + 'episode_rewards': episode_rewards, + 'balances': balances, + 'win_rates': win_rates, + 'episode_pnls': episode_pnls, + 'cumulative_pnl': cumulative_pnl, + 'drawdowns': drawdowns + }) + + # Close TensorBoard writer + writer.close() + + return episode_rewards, balances, win_rates def plot_training_results(stats): """Plot training results""" @@ -2710,42 +3029,92 @@ async def get_latest_candle(exchange, symbol): async def fetch_ohlcv_data(exchange, symbol, timeframe, limit): """Fetch OHLCV data with proper handling for both async and standard CCXT""" + from data_cache import ohlcv_cache + try: # Check if exchange has fetchOHLCV method - if not hasattr(exchange, 'fetchOHLCV'): + if not hasattr(exchange, 'fetchOHLCV') and not hasattr(exchange, 'fetch_ohlcv'): logger.error("Exchange does not support OHLCV data fetching") + # Try to get data from cache as fallback + cached_data = ohlcv_cache.load(symbol, timeframe) + if cached_data: + logger.info(f"Using cached data for {symbol} ({timeframe}) as exchange doesn't support OHLCV") + return cached_data return [] - # Handle different CCXT versions - if hasattr(exchange, 'has') and exchange.has.get('fetchOHLCVAsync', False): - # Use async method if available - ohlcv = await exchange.fetchOHLCV(symbol, timeframe, limit=limit) - else: - # Use 
synchronous method with run_in_executor - loop = asyncio.get_event_loop() - ohlcv = await loop.run_in_executor( - None, - lambda: exchange.fetch_ohlcv(symbol, timeframe, limit=limit) - ) + # Handle different exchange implementations + try: + if hasattr(exchange, 'has') and isinstance(exchange.has, dict) and exchange.has.get('fetchOHLCVAsync', False): + # Use async method if available (CCXT Pro) + ohlcv = await exchange.fetchOHLCV(symbol, timeframe, limit=limit) + elif hasattr(exchange, 'fetch_ohlcv'): + # ExchangeSimulator or custom implementation + if asyncio.iscoroutinefunction(exchange.fetch_ohlcv): + ohlcv = await exchange.fetch_ohlcv(symbol=symbol, timeframe=timeframe, limit=limit) + else: + ohlcv = exchange.fetch_ohlcv(symbol=symbol, timeframe=timeframe, limit=limit) + else: + # Standard CCXT - run synchronously in executor to avoid blocking + loop = asyncio.get_event_loop() + + # Check if fetch_ohlcv is a coroutine function + if asyncio.iscoroutinefunction(exchange.fetch_ohlcv): + ohlcv = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit) + else: + ohlcv = await loop.run_in_executor( + None, + lambda: exchange.fetch_ohlcv(symbol, timeframe, limit=limit) + ) - # Convert to list of dictionaries - data = [] - for candle in ohlcv: - timestamp, open_price, high, low, close, volume = candle - data.append({ - 'timestamp': timestamp, - 'open': open_price, - 'high': high, - 'low': low, - 'close': close, - 'volume': volume - }) + # Check if ohlcv is a coroutine (sometimes happens with certain exchange implementations) + if asyncio.iscoroutine(ohlcv): + ohlcv = await ohlcv + + # Convert to list of dictionaries + data = [] + for candle in ohlcv: + timestamp, open_price, high, low, close, volume = candle + data.append({ + 'timestamp': timestamp, + 'open': open_price, + 'high': high, + 'low': low, + 'close': close, + 'volume': volume + }) + + # Cache the data for future use + if data: + ohlcv_cache.save(data, symbol, timeframe) + + logger.info(f"Fetched {len(data)} candles for {symbol} ({timeframe})") + return data + + except Exception as e: + logger.error(f"Error fetching from exchange: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") + + # Try to get data from cache as fallback + logger.info(f"Attempting to use cached data for {symbol} ({timeframe})") + cached_data = ohlcv_cache.load(symbol, timeframe) + if cached_data: + logger.info(f"Using cached data ({len(cached_data)} candles) as fallback") + return cached_data + + # If no cached data, re-raise the exception + logger.error(f"No cached data available for {symbol} ({timeframe})") + return [] - logger.info(f"Fetched {len(data)} candles for {symbol} ({timeframe})") - return data - except Exception as e: logger.error(f"Error fetching OHLCV data: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") + + # Try to get data from cache as last resort + cached_data = ohlcv_cache.load(symbol, timeframe) + if cached_data: + logger.info(f"Using cached data ({len(cached_data)} candles) as last resort") + return cached_data + return [] async def main(): @@ -2802,7 +3171,7 @@ async def main(): plot_training_results(stats) # Save the trained agent - agent.save("models/trading_agent_latest.pt") + agent.save(f"{MODEL_DIR}/trading_agent_latest.pt") # Evaluate the agent logger.info("Evaluating agent...") @@ -2817,11 +3186,11 @@ async def main(): args.refresh_data = True # Create directories for continuous models - os.makedirs("models", exist_ok=True) + os.makedirs(MODEL_DIR, exist_ok=True) # Track best PnL for model selection best_pnl 
= float('-inf') - best_pnl_model_path = "models/trading_agent_best_pnl.pt" + best_pnl_model_path = f"{MODEL_DIR}/trading_agent_best_pnl.pt" # Load the best PnL model if it exists if os.path.exists(best_pnl_model_path): @@ -2864,194 +3233,132 @@ async def main(): # Train continuously try: while True: - logger.info(f"Continuous training - Episode {episode}") - - # Refresh data from exchange with the specified timeframe - logger.info(f"Refreshing market data with timeframe {timeframe}...") - await env.fetch_new_data(exchange, "ETH/USDT", timeframe, 100) - - # Reset environment - state = env.reset() - - # Initialize price predictor if not already initialized - if not hasattr(env, 'price_predictor') or env.price_predictor is None: - logger.info("Initializing price predictor...") - env.initialize_price_predictor(device=agent.device) - - # Initialize episode variables - episode_reward = 0 - done = False - - # Train price predictor - train_result = env.train_price_predictor() - if isinstance(train_result, (float, int)): - logger.info(f"Price predictor training loss: {train_result:.6f}") - writer.add_scalar('Loss/price_predictor', train_result, episode) - - # Update price predictions - env.update_price_predictions() - - # Training loop for this episode - while not done: - # Select action - action = agent.select_action(state) + try: + logger.info(f"Continuous training - Episode {episode}") - # Take action - next_state, reward, done, info = env.step(action) + # Refresh data from exchange with the specified timeframe + logger.info(f"Refreshing market data with timeframe {timeframe}...") + success = await env.fetch_new_data(exchange, "ETH/USDT", timeframe, 100) - # Store experience - agent.memory.push(state, action, reward, next_state, done) + if not success: + logger.error("Failed to fetch market data from exchange and no cache available.") + logger.info("Waiting 60 seconds before retrying...") + await asyncio.sleep(60) + continue - # Learn from experience - loss = agent.learn() + # Reset environment + state = env.reset() - # Update state and reward - state = next_state - episode_reward += reward - - # Calculate win rate - total_trades = env.win_count + env.loss_count - win_rate = (env.win_count / total_trades * 100) if total_trades > 0 else 0 - - # Calculate prediction accuracy - if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0: - # Compare predictions with actual prices - actual_prices = env.features['price'][-len(env.predicted_prices):] - prediction_errors = np.abs(env.predicted_prices - actual_prices) / actual_prices - prediction_accuracy = 100 * (1 - np.mean(prediction_errors)) - else: - prediction_accuracy = 0 - - # Update stats - stats['episode_rewards'].append(episode_reward) - stats['episode_profits'].append(env.episode_pnl) - stats['win_rates'].append(win_rate) - stats['trade_counts'].append(total_trades) - stats['prediction_accuracies'].append(prediction_accuracy) - - # Log to TensorBoard - writer.add_scalar('Reward/continuous', episode_reward, episode) - writer.add_scalar('Balance/continuous', env.balance, episode) - writer.add_scalar('WinRate/continuous', win_rate, episode) - writer.add_scalar('PnL/episode', env.episode_pnl, episode) - writer.add_scalar('PnL/cumulative', env.total_pnl, episode) - writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode) - writer.add_scalar('PredictionLoss', train_result, episode) - writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode) - - # Log OHLCV data to TensorBoard every 5 episodes - if episode % 5 == 
0: - # Create a DataFrame from the environment's data - df_ohlcv = pd.DataFrame([{ - 'timestamp': candle['timestamp'], - 'open': candle['open'], - 'high': candle['high'], - 'low': candle['low'], - 'close': candle['close'], - 'volume': candle['volume'] - } for candle in env.data[-100:]]) # Use last 100 candles + # Initialize price predictor if not already initialized + if not hasattr(env, 'price_predictor') or env.price_predictor is None: + logger.info("Initializing price predictor...") + env.initialize_price_predictor(device=agent.device) - # Convert timestamp to datetime - df_ohlcv['timestamp'] = pd.to_datetime(df_ohlcv['timestamp'], unit='ms') - df_ohlcv.set_index('timestamp', inplace=True) + # Initialize episode variables + episode_reward = 0 + done = False - # Extract buy/sell signals from trades - buy_signals = [] - sell_signals = [] + # Train price predictor + if hasattr(env, 'price_predictor') and env.price_predictor is not None: + train_result = env.train_price_predictor() + if isinstance(train_result, (float, int)): + logger.info(f"Price predictor training loss: {train_result:.6f}") + writer.add_scalar('Loss/price_predictor', train_result, episode) + else: + logger.warning("Price predictor not initialized, skipping training") + train_result = 0 - if hasattr(env, 'trades') and env.trades: - for trade in env.trades: - if 'entry_time' in trade and 'entry' in trade: - if trade['type'] == 'long': - # Buy signal - entry_time = pd.to_datetime(trade['entry_time'], unit='ms') - buy_signals.append((entry_time, trade['entry'])) - - # Sell signal if closed - if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0: - exit_time = pd.to_datetime(trade['exit_time'], unit='ms') - sell_signals.append((exit_time, trade['exit'])) - - elif trade['type'] == 'short': - # Sell short signal - entry_time = pd.to_datetime(trade['entry_time'], unit='ms') - sell_signals.append((entry_time, trade['entry'])) - - # Buy to cover signal if closed - if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0: - exit_time = pd.to_datetime(trade['exit_time'], unit='ms') - buy_signals.append((exit_time, trade['exit'])) + # Update price predictions + if hasattr(env, 'price_predictor') and env.price_predictor is not None: + env.update_price_predictions() + else: + logger.warning("Price predictor not initialized, skipping predictions") + + # Training loop for this episode + while not done: + # Select action + action = agent.select_action(state) + + # Take action + next_state, reward, done, info = env.step(action) + + # Check if there was an error during the step + if isinstance(info, dict) and "error" in info: + logger.error(f"Error during step: {info['error']}") + break + + # Store experience + agent.memory.push(state, action, reward, next_state, done) + + # Learn from experience + loss = agent.learn() + + # Update state and reward + state = next_state + episode_reward += reward + + # Calculate win rate + total_trades = env.win_count + env.loss_count + win_rate = (env.win_count / total_trades * 100) if total_trades > 0 else 0 + + # Calculate prediction accuracy + if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0: + # Compare predictions with actual prices + actual_prices = env.features['price'][-len(env.predicted_prices):] + prediction_errors = np.abs(env.predicted_prices - actual_prices) / actual_prices + prediction_accuracy = 100 * (1 - np.mean(prediction_errors)) + else: + prediction_accuracy = 0 + + # Update stats + stats['episode_rewards'].append(episode_reward) + 
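# NOTE (illustrative sketch, not part of the diff): the prediction-accuracy metric
# computed just above, as a standalone function. It is 100 * (1 - MAPE): the mean
# absolute percentage error between predicted and realised prices, inverted so that
# higher is better. It can go negative when the average error exceeds 100%.
import numpy as np

def prediction_accuracy(predicted, actual) -> float:
    predicted = np.asarray(predicted, dtype=float)
    actual = np.asarray(actual, dtype=float)
    # Align lengths the same way the loop does: compare against the most recent
    # len(predicted) realised prices.
    actual = actual[-len(predicted):]
    errors = np.abs(predicted - actual) / actual
    return 100.0 * (1.0 - float(np.mean(errors)))

# Example: predictions within ~0.5% of the realised prices score about 99.6.
# prediction_accuracy([2010.0, 2020.0], [2000.0, 2015.0]) -> ~99.63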
stats['episode_profits'].append(env.episode_pnl) + stats['win_rates'].append(win_rate) + stats['trade_counts'].append(total_trades) + stats['prediction_accuracies'].append(prediction_accuracy) # Log to TensorBoard - log_ohlcv_to_tensorboard( - writer, - df_ohlcv, - buy_signals, - sell_signals, - episode, - tag_prefix=f"continuous_episode_{episode}" - ) - - logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, " - f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, " - f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}") - - # Create visualization every 10 episodes - if episode % 10 == 0: - # Create visualization - os.makedirs("visualizations", exist_ok=True) - visualize_training_results(env, agent, episode) + writer.add_scalar('Reward/continuous', episode_reward, episode) + writer.add_scalar('Balance/continuous', env.balance, episode) + writer.add_scalar('WinRate/continuous', win_rate, episode) + writer.add_scalar('PnL/episode', env.episode_pnl, episode) + writer.add_scalar('PnL/cumulative', env.total_pnl, episode) + writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode) + writer.add_scalar('PredictionLoss', train_result, episode) + writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode) - # Save model - model_path = f"models/trading_agent_continuous_{episode}.pt" - agent.save(model_path) - logger.info(f"Saved continuous model: {model_path}") + # Rest of the code... - # Plot training results - plot_training_results(stats) - - # Save best PnL model - if env.episode_pnl > best_pnl: - best_pnl = env.episode_pnl - agent.save(best_pnl_model_path) - logger.info(f"New best PnL model saved: ${env.episode_pnl:.2f}") + # Increment episode counter + episode += 1 + + except Exception as e: + logger.error(f"Error in continuous training episode {episode}: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") + + # Save emergency checkpoint + emergency_model_path = f"{MODEL_DIR}/trading_agent_continuous_emergency_{episode}.pt" + agent.save(emergency_model_path) + logger.info(f"Model saved to {emergency_model_path}") + + # Wait before retrying + logger.info("Waiting 60 seconds before starting next episode...") + await asyncio.sleep(60) + + # Increment episode counter + episode += 1 - # Save best metrics to resume training if interrupted - best_metrics = { - 'best_pnl': float(best_pnl), - 'last_episode': episode, - 'timestamp': datetime.datetime.now().isoformat() - } - os.makedirs("checkpoints", exist_ok=True) - with open("checkpoints/best_metrics.json", 'w') as f: - json.dump(best_metrics, f) - - # Update target network - agent.update_target_network() - - # Increment episode counter - episode += 1 - - # Sleep briefly to prevent overwhelming the system - # Use shorter sleep for shorter timeframes - if timeframe.endswith('s'): - await asyncio.sleep(0.1) # Very short sleep for second-based timeframes - else: - await asyncio.sleep(1) - - except KeyboardInterrupt: - logger.info("Continuous training stopped by user") - # Save final model - agent.save("models/trading_agent_continuous_final.pt") - # Close TensorBoard writer - writer.close() except Exception as e: logger.error(f"Error in continuous training: {e}") logger.error(f"Traceback: {traceback.format_exc()}") - # Save emergency model - agent.save(f"models/trading_agent_continuous_emergency_{episode}.pt") - # Close TensorBoard writer - writer.close() + + # Save emergency checkpoint + emergency_model_path = 
f"{MODEL_DIR}/trading_agent_continuous_emergency_{episode}.pt" + agent.save(emergency_model_path) + logger.info(f"Model saved to {emergency_model_path}") + + # Close TensorBoard writer + writer.close() elif args.mode == 'evaluate': # Load the best model diff --git a/crypto/gogo2/run_enhanced_training.py b/crypto/gogo2/run_enhanced_training.py new file mode 100644 index 0000000..8ca985d --- /dev/null +++ b/crypto/gogo2/run_enhanced_training.py @@ -0,0 +1,305 @@ +import argparse +import os +import torch +from enhanced_training import enhanced_train_agent +from exchange_simulator import ExchangeSimulator + +def main(): + # Parse command line arguments + parser = argparse.ArgumentParser(description='Enhanced Trading Bot Training') + + parser.add_argument('--mode', type=str, default='train', choices=['train', 'continuous', 'evaluate', 'live', 'demo'], + help='Mode to run the trading bot in') + + parser.add_argument('--episodes', type=int, default=100, + help='Number of episodes to train for') + + parser.add_argument('--start-episode', type=int, default=0, + help='Episode to start from for continuous training') + + parser.add_argument('--device', type=str, default='auto', + help='Device to train on (auto, cuda, cpu)') + + parser.add_argument('--timeframes', type=str, default='1m,15m,1h', + help='Comma-separated list of timeframes to use') + + parser.add_argument('--refresh-data', action='store_true', + help='Refresh data before training') + + args = parser.parse_args() + + # Set device + if args.device == 'auto': + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + device = torch.device(args.device) + + print(f"Using device: {device}") + + # Parse timeframes + timeframes = args.timeframes.split(',') + print(f"Using timeframes: {timeframes}") + + # Initialize exchange simulator + exchange = ExchangeSimulator() + + # Run in specified mode + if args.mode == 'train': + # Train from scratch + print(f"Training for {args.episodes} episodes...") + enhanced_train_agent( + exchange=exchange, + num_episodes=args.episodes, + continuous=False, + start_episode=0 + ) + + elif args.mode == 'continuous': + # Continue training from checkpoint + print(f"Continuing training from episode {args.start_episode} for {args.episodes} episodes...") + enhanced_train_agent( + exchange=exchange, + num_episodes=args.episodes, + continuous=True, + start_episode=args.start_episode + ) + + elif args.mode == 'evaluate': + # Evaluate the model + print("Evaluating model...") + evaluate_model(exchange, device) + + elif args.mode == 'live' or args.mode == 'demo': + # Run in live or demo mode + is_demo = args.mode == 'demo' + print(f"Running in {'demo' if is_demo else 'live'} mode...") + run_live(exchange, device, is_demo=is_demo) + + print("Done!") + +def evaluate_model(exchange, device): + """ + Evaluate the trained model + + Args: + exchange: Exchange simulator + device: Device to run on + """ + from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN + import torch + import numpy as np + + # Load the best model + model_path = 'models/enhanced_trading_agent_best_pnl.pt' + if not os.path.exists(model_path): + model_path = 'models/enhanced_trading_agent_latest.pt' + + if not os.path.exists(model_path): + print("No model found to evaluate!") + return + + print(f"Loading model from {model_path}") + checkpoint = torch.load(model_path, map_location=device) + + # Initialize models + state_dim = 100 + action_dim = 3 + timeframes = ['1m', '15m', '1h'] + + price_model = EnhancedPricePredictionModel( + 
input_dim=2, + hidden_dim=256, + num_layers=3, + output_dim=5, + num_timeframes=len(timeframes) + ).to(device) + + dqn_model = EnhancedDQN( + state_dim=state_dim, + action_dim=action_dim, + hidden_dim=512 + ).to(device) + + # Load model weights + price_model.load_state_dict(checkpoint['price_model_state_dict']) + dqn_model.load_state_dict(checkpoint['dqn_model_state_dict']) + + # Set models to evaluation mode + price_model.eval() + dqn_model.eval() + + # Run evaluation + num_steps = 1000 + total_reward = 0 + trades = [] + + # Initialize state + from enhanced_training import initialize_state, step_environment + state = initialize_state(exchange, timeframes) + + for step in range(num_steps): + # Select action + with torch.no_grad(): + state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) + q_values, _, _ = dqn_model(state_tensor) + action = q_values.argmax().item() + + # Execute action + next_state, reward, done, trade_info = step_environment( + exchange, state, action, price_model, timeframes, device + ) + + # Update state and accumulate reward + state = next_state + total_reward += reward + + # Track trade + if trade_info is not None: + trades.append(trade_info) + print(f"Trade: {trade_info['action']} at {trade_info['price']:.2f}, PnL: {trade_info['pnl']:.2f}") + + # Update exchange (simulate time passing) + if step % 10 == 0: + exchange.update() + + if done: + break + + # Calculate metrics + avg_reward = total_reward / num_steps + total_pnl = sum(trade['pnl'] for trade in trades) if trades else 0 + wins = sum(1 for trade in trades if trade['pnl'] > 0) + losses = sum(1 for trade in trades if trade['pnl'] < 0) + win_rate = (wins / (wins + losses) * 100) if (wins + losses) > 0 else 0 + + print("\nEvaluation Results:") + print(f"Average Reward: {avg_reward:.2f}") + print(f"Total PnL: ${total_pnl:.2f}") + print(f"Win Rate: {win_rate:.1f}% ({wins}/{wins+losses})") + +def run_live(exchange, device, is_demo=True): + """ + Run the trading bot in live or demo mode + + Args: + exchange: Exchange simulator or real exchange + device: Device to run on + is_demo: Whether to run in demo mode (no real trades) + """ + from enhanced_models import EnhancedPricePredictionModel, EnhancedDQN + import torch + import time + + # Load the best model + model_path = 'models/enhanced_trading_agent_best_pnl.pt' + if not os.path.exists(model_path): + model_path = 'models/enhanced_trading_agent_latest.pt' + + if not os.path.exists(model_path): + print("No model found to run in live mode!") + return + + print(f"Loading model from {model_path}") + checkpoint = torch.load(model_path, map_location=device) + + # Initialize models + state_dim = 100 + action_dim = 3 + timeframes = ['1m', '15m', '1h'] + + price_model = EnhancedPricePredictionModel( + input_dim=2, + hidden_dim=256, + num_layers=3, + output_dim=5, + num_timeframes=len(timeframes) + ).to(device) + + dqn_model = EnhancedDQN( + state_dim=state_dim, + action_dim=action_dim, + hidden_dim=512 + ).to(device) + + # Load model weights + price_model.load_state_dict(checkpoint['price_model_state_dict']) + dqn_model.load_state_dict(checkpoint['dqn_model_state_dict']) + + # Set models to evaluation mode + price_model.eval() + dqn_model.eval() + + # Run live trading + print(f"Running in {'demo' if is_demo else 'live'} mode...") + print("Press Ctrl+C to stop") + + # Initialize state + from enhanced_training import initialize_state, step_environment + state = initialize_state(exchange, timeframes) + + try: + while True: + # Select action + with torch.no_grad(): + 
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) + q_values, _, market_regime = dqn_model(state_tensor) + action = q_values.argmax().item() + + # Get market regime prediction + regime_probs = torch.softmax(market_regime, dim=1).cpu().numpy()[0] + regime_names = ['Trending', 'Ranging', 'Volatile'] + predicted_regime = regime_names[regime_probs.argmax()] + + # Get current price + ticker = exchange.get_ticker() + current_price = ticker['last'] + + # Print state + print(f"\nCurrent price: ${current_price:.2f}") + print(f"Predicted market regime: {predicted_regime} ({regime_probs.max()*100:.1f}% confidence)") + + # Execute action + next_state, reward, _, trade_info = step_environment( + exchange, state, action, price_model, timeframes, device + ) + + # Print action + action_names = ['Hold', 'Buy', 'Sell'] + print(f"Action: {action_names[action]}") + + if trade_info is not None: + print(f"Trade: {trade_info['action']} at {trade_info['price']:.2f}, Size: {trade_info['size']:.2f}, Entry Quality: {trade_info['entry_quality']:.2f}") + + # Execute real trade if not in demo mode + if not is_demo: + if trade_info['action'] == 'buy': + order = exchange.create_order( + symbol="BTC/USDT", + type="market", + side="buy", + amount=trade_info['size'] / current_price + ) + print(f"Executed buy order: {order}") + else: # sell + order = exchange.create_order( + symbol="BTC/USDT", + type="market", + side="sell", + amount=trade_info['size'] / current_price + ) + print(f"Executed sell order: {order}") + + # Update state + state = next_state + + # Update exchange (simulate time passing) + exchange.update() + + # Wait for next candle + time.sleep(5) # In a real implementation, this would wait for the next candle + + except KeyboardInterrupt: + print("\nStopping live trading") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/crypto/gogo2/test_cache.py b/crypto/gogo2/test_cache.py new file mode 100644 index 0000000..0625e9e --- /dev/null +++ b/crypto/gogo2/test_cache.py @@ -0,0 +1,185 @@ +import os +import sys +import json +import logging +import time +from datetime import datetime + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout) + ] +) + +logger = logging.getLogger('cache_test') + +# Import our cache implementation +from data_cache import ohlcv_cache + +def generate_sample_data(num_candles=100): + """Generate sample OHLCV data for testing""" + data = [] + base_timestamp = int(time.time() * 1000) - (num_candles * 60 * 1000) # Start from num_candles minutes ago + + for i in range(num_candles): + timestamp = base_timestamp + (i * 60 * 1000) # Add i minutes + + # Generate some random-ish but realistic looking price data + base_price = 1900.0 + (i * 0.5) # Slight uptrend + open_price = base_price - 0.5 + (i % 3) + close_price = base_price + 0.3 + ((i+1) % 4) + high_price = max(open_price, close_price) + 1.0 + (i % 2) + low_price = min(open_price, close_price) - 0.8 - (i % 2) + volume = 10.0 + (i % 10) * 2.0 + + data.append({ + 'timestamp': timestamp, + 'open': open_price, + 'high': high_price, + 'low': low_price, + 'close': close_price, + 'volume': volume + }) + + return data + +def test_cache_save_load(): + """Test saving and loading data from cache""" + logger.info("Testing cache save and load...") + + # Generate sample data + data = generate_sample_data(100) + logger.info(f"Generated {len(data)} sample candles") + + # Save to cache + symbol = 
"ETH/USDT" + timeframe = "1m" + success = ohlcv_cache.save(data, symbol, timeframe) + logger.info(f"Saved to cache: {success}") + + # Load from cache + cached_data = ohlcv_cache.load(symbol, timeframe) + logger.info(f"Loaded {len(cached_data) if cached_data else 0} candles from cache") + + # Verify data integrity + if cached_data: + first_original = data[0] + first_cached = cached_data[0] + logger.info(f"First original candle: {first_original}") + logger.info(f"First cached candle: {first_cached}") + + last_original = data[-1] + last_cached = cached_data[-1] + logger.info(f"Last original candle: {last_original}") + logger.info(f"Last cached candle: {last_cached}") + + return success and cached_data and len(cached_data) == len(data) + +def test_cache_append(): + """Test appending a new candle to cached data""" + logger.info("Testing cache append...") + + # Generate sample data + data = generate_sample_data(100) + + # Save to cache + symbol = "ETH/USDT" + timeframe = "5m" + success = ohlcv_cache.save(data, symbol, timeframe) + logger.info(f"Saved to cache: {success}") + + # Generate a new candle + last_timestamp = data[-1]['timestamp'] + new_timestamp = last_timestamp + (5 * 60 * 1000) # 5 minutes later + new_candle = { + 'timestamp': new_timestamp, + 'open': 1950.0, + 'high': 1955.0, + 'low': 1948.0, + 'close': 1952.0, + 'volume': 15.0 + } + + # Append to cache + success = ohlcv_cache.append(new_candle, symbol, timeframe) + logger.info(f"Appended to cache: {success}") + + # Load from cache + cached_data = ohlcv_cache.load(symbol, timeframe) + logger.info(f"Loaded {len(cached_data) if cached_data else 0} candles from cache") + + # Verify the new candle was appended + if cached_data: + last_cached = cached_data[-1] + logger.info(f"New candle: {new_candle}") + logger.info(f"Last cached candle: {last_cached}") + + return success and cached_data and len(cached_data) == len(data) + 1 + +def test_cache_dataframe(): + """Test converting cached data to a pandas DataFrame""" + logger.info("Testing cache to DataFrame conversion...") + + # Generate sample data + data = generate_sample_data(100) + + # Save to cache + symbol = "ETH/USDT" + timeframe = "15m" + success = ohlcv_cache.save(data, symbol, timeframe) + logger.info(f"Saved to cache: {success}") + + # Convert to DataFrame + df = ohlcv_cache.to_dataframe(symbol, timeframe) + logger.info(f"Converted to DataFrame with {len(df) if df is not None else 0} rows") + + # Display DataFrame info + if df is not None: + logger.info(f"DataFrame columns: {df.columns.tolist()}") + logger.info(f"DataFrame index: {df.index.name}") + logger.info(f"First row: {df.iloc[0].to_dict()}") + logger.info(f"Last row: {df.iloc[-1].to_dict()}") + + return success and df is not None and len(df) == len(data) + +def main(): + """Run all tests""" + logger.info("Starting cache tests...") + + # Run tests + save_load_success = test_cache_save_load() + append_success = test_cache_append() + dataframe_success = test_cache_dataframe() + + # Print results + logger.info("Test results:") + logger.info(f" Save/Load: {'PASS' if save_load_success else 'FAIL'}") + logger.info(f" Append: {'PASS' if append_success else 'FAIL'}") + logger.info(f" DataFrame: {'PASS' if dataframe_success else 'FAIL'}") + + # Check cache directory contents + cache_dir = ohlcv_cache.cache_dir + logger.info(f"Cache directory: {cache_dir}") + if os.path.exists(cache_dir): + files = os.listdir(cache_dir) + logger.info(f"Cache files: {files}") + + # Print file sizes + for file in files: + file_path = 
os.path.join(cache_dir, file) + size_kb = os.path.getsize(file_path) / 1024 + logger.info(f" {file}: {size_kb:.2f} KB") + + # Print cache metadata from each file + with open(file_path, 'r') as f: + data = json.load(f) + logger.info(f" Metadata: symbol={data.get('symbol')}, timeframe={data.get('timeframe')}, last_updated={datetime.fromtimestamp(data.get('last_updated', 0)).strftime('%Y-%m-%d %H:%M:%S')}") + logger.info(f" Candles: {len(data.get('data', []))}") + + return save_load_success and append_success and dataframe_success + +if __name__ == "__main__": + sys.exit(0 if main() else 1) \ No newline at end of file
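test_cache.py imports a module-level `ohlcv_cache` instance and calls `ohlcv_cache.to_dataframe(symbol, timeframe)`, neither of which is visible in the portion of data_cache.py included in this diff. A minimal sketch of what they might look like, assuming candles are cached as dicts with epoch-millisecond `timestamp` values (as produced by `generate_sample_data`); the actual implementation in data_cache.py may differ:

```python
import pandas as pd
from data_cache import OHLCVCache  # class added earlier in this diff

def to_dataframe(cache, symbol, timeframe):
    """Return cached candles as a DataFrame indexed by timestamp, or None if missing/stale."""
    data = cache.load(symbol, timeframe)
    if not data:
        return None
    df = pd.DataFrame(data)
    # Timestamps are stored as epoch milliseconds, so convert them to a DatetimeIndex
    # (test_cache.py logs df.index.name, which becomes 'timestamp' after set_index).
    df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
    return df.set_index('timestamp')

# Attach as a method so ohlcv_cache.to_dataframe(symbol, timeframe) works as used in
# test_cache.py, and expose the shared instance that
# `from data_cache import ohlcv_cache` appears to expect (assumption).
OHLCVCache.to_dataframe = to_dataframe
ohlcv_cache = OHLCVCache()
```

Keeping the timestamp as the index leaves the cached data directly usable with pandas resampling (for example, aggregating 1m candles into 15m) in the training pipeline. For the new entry point, an invocation such as `python run_enhanced_training.py --mode continuous --start-episode 50 --episodes 100 --device cuda` matches the arguments defined above.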