diff --git a/crypto/gogo2/live_training.py b/crypto/gogo2/live_training.py index a3a1949..f3d0685 100644 --- a/crypto/gogo2/live_training.py +++ b/crypto/gogo2/live_training.py @@ -9,6 +9,8 @@ import datetime import traceback import numpy as np import torch +import gc +from functools import partial from main import initialize_exchange, TradingEnvironment, Agent from torch.utils.tensorboard import SummaryWriter @@ -54,6 +56,11 @@ def robust_save(model, path): # Backup path in case the main save fails backup_path = f"{path}.backup" + # Clean up GPU memory before saving + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + # Attempt 1: Try with default settings in a separate file first try: logger.info(f"Saving model to {backup_path} (attempt 1)") @@ -122,6 +129,28 @@ def robust_save(model, path): logger.error(f"All save attempts failed: {e}") return False +# Implement timeout wrapper for exchange operations +async def with_timeout(coroutine, timeout=30, default=None): + """ + Execute a coroutine with a timeout + + Args: + coroutine: The coroutine to execute + timeout: Timeout in seconds + default: Default value to return on timeout + + Returns: + The result of the coroutine or default value on timeout + """ + try: + return await asyncio.wait_for(coroutine, timeout=timeout) + except asyncio.TimeoutError: + logger.warning(f"Operation timed out after {timeout} seconds") + return default + except Exception as e: + logger.error(f"Operation failed: {e}") + return default + # Implement fetch_and_update_data function async def fetch_and_update_data(exchange, env, symbol, timeframe): """ @@ -139,8 +168,12 @@ async def fetch_and_update_data(exchange, env, symbol, timeframe): # Default to 100 candles if not specified limit = 1000 - # Fetch OHLCV data - candles = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit) + # Fetch OHLCV data with timeout + candles = await with_timeout( + exchange.fetch_ohlcv(symbol, timeframe, limit=limit), + timeout=30, + default=[] + ) if not candles or len(candles) == 0: logger.warning(f"No candles returned for {symbol} on {timeframe}") @@ -181,6 +214,16 @@ async def fetch_and_update_data(exchange, env, symbol, timeframe): logger.error(traceback.format_exc()) return False +# Implement memory management function +def manage_memory(): + """ + Clean up memory to avoid memory leaks during long running sessions + """ + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + logger.debug("Memory cleaned") + async def live_training( symbol="ETH/USDT", timeframe="1m", @@ -194,6 +237,8 @@ async def live_training( gamma=0.99, window_size=30, max_episodes=0, # 0 means unlimited + retry_delay=5, # Seconds to wait before retrying after an error + max_retries=3, # Maximum number of retries for operations ): """ Live training function that uses real market data to improve the model without executing real trades. 
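The `with_timeout` helper added above guards a single await, and the retry loops added below repeat the same sleep-`retry_delay`-between-`max_retries`-attempts pattern inline. A minimal sketch combining the two — the `retry_call` name and its factory-callable argument are illustrative, not part of this diff; a coroutine object can only be awaited once, so each attempt has to build a fresh one:

```python
import asyncio

async def retry_call(factory, max_retries=3, retry_delay=5, timeout=30, default=None):
    """Retry an async operation; `factory` builds a fresh coroutine per attempt."""
    for attempt in range(max_retries):
        result = await with_timeout(factory(), timeout=timeout, default=None)
        if result is not None:
            return result
        if attempt < max_retries - 1:
            await asyncio.sleep(retry_delay)
    return default

# e.g.:
# candles = await retry_call(lambda: exchange.fetch_ohlcv(symbol, timeframe, limit=1000), default=[])
```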
@@ -211,15 +256,30 @@ async def live_training( gamma: Discount factor for training window_size: Window size for the environment max_episodes: Maximum number of episodes (0 for unlimited) + retry_delay: Seconds to wait before retrying after an error + max_retries: Maximum number of retries for operations """ logger.info(f"Starting live training for {symbol} on {timeframe} timeframe") # Initialize exchange (without sandbox mode) exchange = None + + # Retry loop for exchange initialization + for retry in range(max_retries): + try: + exchange = await initialize_exchange() + logger.info(f"Exchange initialized: {exchange.id}") + break + except Exception as e: + logger.error(f"Error initializing exchange (attempt {retry+1}/{max_retries}): {e}") + if retry < max_retries - 1: + logger.info(f"Retrying in {retry_delay} seconds...") + await asyncio.sleep(retry_delay) + else: + logger.error("Max retries reached. Could not initialize exchange.") + return + try: - exchange = await initialize_exchange() - logger.info(f"Exchange initialized: {exchange.id}") - # Initialize environment env = TradingEnvironment( initial_balance=initial_balance, @@ -228,11 +288,20 @@ async def live_training( timeframe=timeframe, ) - # Fetch initial data + # Fetch initial data (with retries) logger.info(f"Fetching initial data for {symbol}") - success = await fetch_and_update_data(exchange, env, symbol, timeframe) + success = False + for retry in range(max_retries): + success = await fetch_and_update_data(exchange, env, symbol, timeframe) + if success: + break + logger.warning(f"Failed to fetch initial data (attempt {retry+1}/{max_retries})") + if retry < max_retries - 1: + logger.info(f"Retrying in {retry_delay} seconds...") + await asyncio.sleep(retry_delay) + if not success: - logger.error("Failed to fetch initial data, exiting") + logger.error("Failed to fetch initial data after multiple attempts, exiting") return # Initialize agent @@ -268,6 +337,10 @@ async def live_training( step_counter = 0 last_update_time = datetime.datetime.now() + # Track consecutive errors to enable circuit breaker + consecutive_errors = 0 + max_consecutive_errors = 5 + while True: # Check if we've reached the maximum number of episodes if max_episodes > 0 and episode_count >= max_episodes: @@ -284,11 +357,14 @@ async def live_training( if not success: logger.warning("Failed to update data, will try again later") # Wait a bit before trying again - await asyncio.sleep(5) + await asyncio.sleep(retry_delay) continue last_update_time = current_time + # Clean up memory before running an episode + manage_memory() + # Run training iterations on the updated data episode_reward = 0 env.reset() @@ -337,12 +413,22 @@ async def live_training( # Train the agent on a batch of experiences if len(agent.memory) > batch_size: - agent.learn() - - # Additional training iterations - if steps_in_episode % 10 == 0 and training_iterations > 1: - for _ in range(training_iterations - 1): - agent.learn() + try: + agent.learn() + + # Additional training iterations + if steps_in_episode % 10 == 0 and training_iterations > 1: + for _ in range(training_iterations - 1): + agent.learn() + + # Reset consecutive errors counter on successful learning + consecutive_errors = 0 + except Exception as e: + logger.error(f"Error during learning: {e}") + consecutive_errors += 1 + if consecutive_errors >= max_consecutive_errors: + logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") + break if done: logger.info(f"Episode done after {steps_in_episode} 
steps") @@ -351,7 +437,10 @@ async def live_training( except Exception as e: logger.error(f"Error during episode step: {e}") logger.error(traceback.format_exc()) - break + consecutive_errors += 1 + if consecutive_errors >= max_consecutive_errors: + logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") + break # Update training statistics episode_count += 1 @@ -419,26 +508,29 @@ async def live_training( logger.error(traceback.format_exc()) finally: # Save final model - if robust_save(agent, save_path): - logger.info(f"Final model saved to {save_path}") - else: - logger.error(f"Failed to save final model") - - # Close TensorBoard writer - try: - writer.close() - logger.info("TensorBoard writer closed") - except Exception as e: - logger.error(f"Error closing TensorBoard writer: {e}") + if 'agent' in locals(): + if robust_save(agent, save_path): + logger.info(f"Final model saved to {save_path}") + else: + logger.error(f"Failed to save final model") + + # Close TensorBoard writer + try: + writer.close() + logger.info("TensorBoard writer closed") + except Exception as e: + logger.error(f"Error closing TensorBoard writer: {e}") # Close exchange connection if exchange: try: - await exchange.close() + await with_timeout(exchange.close(), timeout=10) logger.info("Exchange connection closed") except Exception as e: logger.error(f"Error closing exchange connection: {e}") + # Final memory cleanup + manage_memory() logger.info("Live training completed") async def main(): @@ -452,6 +544,8 @@ async def main(): parser.add_argument('--update_interval', type=int, default=60, help='Interval to update data in seconds') parser.add_argument('--training_iterations', type=int, default=100, help='Training iterations per update') parser.add_argument('--max_episodes', type=int, default=0, help='Maximum number of episodes (0 for unlimited)') + parser.add_argument('--retry_delay', type=int, default=5, help='Seconds to wait before retrying after an error') + parser.add_argument('--max_retries', type=int, default=3, help='Maximum number of retries for operations') args = parser.parse_args() @@ -466,9 +560,10 @@ async def main(): update_interval=args.update_interval, training_iterations=args.training_iterations, max_episodes=args.max_episodes, + retry_delay=args.retry_delay, + max_retries=args.max_retries, ) -# At the beginning of the file, after importing the modules # Override Agent's save method with our robust save function def monkey_patch_agent_save(): """Replace Agent's save method with our robust save approach""" diff --git a/crypto/gogo2/main.py b/crypto/gogo2/main.py index 25ac5dc..13e1ade 100644 --- a/crypto/gogo2/main.py +++ b/crypto/gogo2/main.py @@ -33,6 +33,11 @@ from datetime import datetime as dt from collections import defaultdict from gym.spaces import Discrete, Box import csv +import gc +import shutil +import math +import platform +import ctypes # Configure logging logging.basicConfig( @@ -1917,6 +1922,8 @@ class Agent: # Initialize exploration parameters self.epsilon = EPSILON_START + self.epsilon_start = EPSILON_START + self.epsilon_end = EPSILON_END self.epsilon_decay = EPSILON_DECAY self.epsilon_min = EPSILON_END @@ -2142,55 +2149,38 @@ class Agent: def update_epsilon(self, episode): """Update epsilon value based on episode number""" - self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay) + # Calculate epsilon using a linear decay formula + epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \ + max(0, (self.epsilon_decay 
- episode)) / self.epsilon_decay + + # Update self.epsilon with the calculated value + self.epsilon = max(self.epsilon_min, epsilon) + return self.epsilon def update_target_network(self): self.target_net.load_state_dict(self.policy_net.state_dict()) def save(self, path="models/trading_agent_best_pnl.pt"): - """Save the model in a format compatible with PyTorch 2.6+""" + """Save the model using a robust saving approach with multiple fallbacks""" try: # Create directory if it doesn't exist os.makedirs(os.path.dirname(path), exist_ok=True) - # Ensure architecture parameters are set - if not hasattr(self, 'hidden_size'): - self.hidden_size = 256 # Default value - logger.warning("Setting default hidden_size=256 for saving") + # Call robust save function + success = robust_save(self, path) + + if success: + logger.info(f"Model saved successfully to {path}") + return True + else: + logger.error(f"All save attempts failed for path: {path}") + return False - if not hasattr(self, 'lstm_layers'): - self.lstm_layers = 2 # Default value - logger.warning("Setting default lstm_layers=2 for saving") - - if not hasattr(self, 'attention_heads'): - self.attention_heads = 4 # Default value - logger.warning("Setting default attention_heads=4 for saving") - - # Save model state - checkpoint = { - 'policy_net': self.policy_net.state_dict(), - 'target_net': self.target_net.state_dict(), - 'optimizer': self.optimizer.state_dict(), - 'epsilon': self.epsilon, - 'state_size': self.state_size, - 'action_size': self.action_size, - 'hidden_size': self.hidden_size, - 'lstm_layers': self.lstm_layers, - 'attention_heads': self.attention_heads - } - - # Save scaler state if it exists - if hasattr(self, 'scaler') and self.scaler is not None: - checkpoint['scaler'] = self.scaler.state_dict() - - # Save with pickle_protocol=4 for better compatibility - torch.save(checkpoint, path, _use_new_zipfile_serialization=True, pickle_protocol=4) - logger.info(f"Model saved to {path}") except Exception as e: - logger.error(f"Error saving model: {e}") - import traceback + logger.error(f"Error in save method: {e}") logger.error(traceback.format_exc()) + return False def load(self, path="models/trading_agent_best_pnl.pt"): """Load a trained model with improved error handling for PyTorch 2.6 compatibility""" @@ -2412,15 +2402,16 @@ async def get_live_prices(symbol="ETH/USDT", timeframe="1m"): await asyncio.sleep(5) break -async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000): +async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, use_compact_save=False): """ - Train the agent using reinforcement learning with multi-timeframe data and CNN pattern recognition. + Train the agent in the environment. 
Args: agent: The agent to train env: The trading environment num_episodes: Number of episodes to train for max_steps_per_episode: Maximum steps per episode + use_compact_save: Whether to use compact save (for low disk space) Returns: Training statistics @@ -2467,9 +2458,19 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000) # Make directory for models if it doesn't exist os.makedirs('models', exist_ok=True) + # Memory management function + def clean_memory(): + """Clean up memory to avoid memory leaks""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + # Start training loop for episode in range(num_episodes): try: + # Clean up memory before starting a new episode + clean_memory() + # Reset environment state = env.reset() episode_reward = 0 @@ -2488,6 +2489,10 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000) except Exception as e: logging.error(f"Failed to fetch candle data: {e}") + # Track consecutive errors + consecutive_errors = 0 + max_consecutive_errors = 5 + # Episode loop for step in range(max_steps_per_episode): try: @@ -2516,8 +2521,15 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000) global_step = episode * max_steps_per_episode + step if writer: writer.add_scalar('Loss/step', loss, global_step) + + # Reset consecutive errors counter on successful learning + consecutive_errors = 0 except Exception as e: logging.error(f"Error during learning: {e}") + consecutive_errors += 1 + if consecutive_errors >= max_consecutive_errors: + logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") + break # Update target network periodically if step % TARGET_UPDATE == 0: @@ -2547,6 +2559,10 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000) except Exception as e: logging.warning(f"Error updating predictions: {e}") + # Clean memory periodically during long episodes + if step % 200 == 0 and step > 0: + clean_memory() + # Add chart to TensorBoard periodically if step % 100 == 0 or (step == max_steps_per_episode - 1) or done: try: @@ -2561,7 +2577,10 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000) except Exception as e: logging.error(f"Error in training step: {e}") - break + consecutive_errors += 1 + if consecutive_errors >= max_consecutive_errors: + logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") + break # Calculate statistics from this episode balance = env.balance @@ -2619,23 +2638,50 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000) # Save model if this is the best reward or PnL if episode_reward > best_reward: best_reward = episode_reward - agent.save('models/trading_agent_best_reward.pt') - logging.info(f"New best reward: {best_reward:.2f}") + try: + if use_compact_save: + success = compact_save(agent, 'models/trading_agent_best_reward.pt') + else: + success = agent.save('models/trading_agent_best_reward.pt') + if success: + logging.info(f"New best reward: {best_reward:.2f}") + except Exception as e: + logging.error(f"Error saving best reward model: {e}") if pnl > best_pnl: best_pnl = pnl - agent.save('models/trading_agent_best_pnl.pt') - logging.info(f"New best PnL: ${best_pnl:.2f}") + try: + if use_compact_save: + success = compact_save(agent, 'models/trading_agent_best_pnl.pt') + else: + success = agent.save('models/trading_agent_best_pnl.pt') + if success: + logging.info(f"New 
best PnL: ${best_pnl:.2f}") + except Exception as e: + logging.error(f"Error saving best PnL model: {e}") # Save model if this is the best net PnL (after fees) if net_pnl > best_net_pnl: best_net_pnl = net_pnl - agent.save('models/trading_agent_best_net_pnl.pt') - logging.info(f"New best Net PnL: ${best_net_pnl:.2f}") + try: + if use_compact_save: + success = compact_save(agent, 'models/trading_agent_best_net_pnl.pt') + else: + success = agent.save('models/trading_agent_best_net_pnl.pt') + if success: + logging.info(f"New best Net PnL: ${best_net_pnl:.2f}") + except Exception as e: + logging.error(f"Error saving best net PnL model: {e}") # Save checkpoint periodically if episode % 10 == 0: - agent.save(f'models/trading_agent_checkpoint_{episode}.pt') + try: + if use_compact_save: + compact_save(agent, f'models/trading_agent_checkpoint_{episode}.pt') + else: + agent.save(f'models/trading_agent_checkpoint_{episode}.pt') + except Exception as e: + logging.error(f"Error saving checkpoint model: {e}") # Update epsilon agent.update_epsilon(episode) @@ -2654,23 +2700,50 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000) except Exception as e: logging.error(f"Error in episode {episode}: {e}") - import traceback logging.error(traceback.format_exc()) continue + # Clean memory before saving final model + clean_memory() + # Save final model - agent.save('models/trading_agent_final.pt') + try: + if use_compact_save: + compact_save(agent, 'models/trading_agent_final.pt') + else: + agent.save('models/trading_agent_final.pt') + except Exception as e: + logging.error(f"Error saving final model: {e}") # Save training statistics to file try: import pandas as pd + + # Make sure all arrays in stats are the same length by padding with NaN + max_length = max(len(v) for k, v in stats.items() if isinstance(v, list)) + for k, v in stats.items(): + if isinstance(v, list) and len(v) < max_length: + stats[k] = v + [float('nan')] * (max_length - len(v)) + + # Create dataframe and save stats_df = pd.DataFrame(stats) stats_df.to_csv('training_stats.csv', index=False) logging.info(f"Training statistics saved to training_stats.csv") except Exception as e: logging.error(f"Failed to save training statistics: {e}") - # Fallback to numpy save - np.save('training_stats.npy', stats) + logging.error(traceback.format_exc()) + + # Close exchange if it's still open + if exchange: + try: + # Check if exchange has the close method (ccxt.async_support) + if hasattr(exchange, 'close'): + await exchange.close() + logging.info("Closed exchange connection") + else: + logging.info("Exchange doesn't have close method (standard ccxt), skipping close") + except Exception as e: + logging.error(f"Error closing exchange: {e}") return stats @@ -3467,6 +3540,8 @@ async def main(): help='Operation mode: train, eval, or live') parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes for training or evaluation') + parser.add_argument('--max_steps', type=int, default=1000, + help='Maximum steps per episode for training') parser.add_argument('--demo', type=str, choices=['true', 'false'], default='true', help='Run in demo mode (paper trading) if true') parser.add_argument('--symbol', type=str, default='ETH/USDT', @@ -3477,6 +3552,8 @@ async def main(): help='Leverage for futures trading') parser.add_argument('--model', type=str, default=None, help='Path to model file for evaluation or live trading') + parser.add_argument('--compact_save', action='store_true', + help='Use compact model saving (for 
low disk space)') args = parser.parse_args() @@ -3512,7 +3589,9 @@ async def main(): # Train the agent logger.info(f"Starting training for {args.episodes} episodes...") - stats = await train_agent(agent, env, num_episodes=args.episodes) + stats = await train_agent(agent, env, num_episodes=args.episodes, + max_steps_per_episode=args.max_steps, + use_compact_save=args.compact_save) elif args.mode == 'eval' or args.mode == 'live': # Fetch initial data for the specified symbol and timeframe @@ -3698,92 +3777,73 @@ def create_candlestick_figure(data, trades=None, title="Trading Chart"): return None class CandlePatternCNN(nn.Module): - """ - Multi-timeframe CNN for candle pattern recognition. - Extracts features from 1s, 1m, 1h, and 1d candle data. - """ + """Convolutional neural network for detecting candlestick patterns""" + def __init__(self, input_channels=5, feature_dimension=512): super(CandlePatternCNN, self).__init__() + self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(kernel_size=2) + self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(kernel_size=2) + self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) + self.relu3 = nn.ReLU() + self.pool3 = nn.MaxPool2d(kernel_size=2) - # Base convolutional network for each timeframe - self.base_conv = nn.Sequential( - nn.Conv2d(input_channels, 64, kernel_size=(1, 3), padding=(0, 1)), - nn.BatchNorm2d(64), - nn.ReLU(), - nn.Conv2d(64, 128, kernel_size=(1, 5), padding=(0, 2)), - nn.BatchNorm2d(128), - nn.ReLU(), - nn.MaxPool2d(kernel_size=(1, 2)), - nn.Conv2d(128, 256, kernel_size=(1, 5), padding=(0, 2)), - nn.BatchNorm2d(256), - nn.ReLU(), - nn.MaxPool2d(kernel_size=(1, 2)), - nn.Conv2d(256, 512, kernel_size=(1, 3), padding=(0, 1)), - nn.BatchNorm2d(512), - nn.ReLU() - ) - - # Feature fusion layers - self.fusion = nn.Sequential( - nn.Linear(512 * 4 * 75, 2048), # 4 timeframes, assume 300/4=75 candles after pooling - nn.ReLU(), - nn.Dropout(0.3), - nn.Linear(2048, 1024), - nn.ReLU(), - nn.Dropout(0.2), - nn.Linear(1024, feature_dimension) - ) - - # Store intermediate activations - self.intermediate_features = {} + # Projection layers + self.fc1 = nn.Linear(128 * 4 * 4, 1024) + self.relu4 = nn.ReLU() + self.fc2 = nn.Linear(1024, feature_dimension) + # Initialize intermediate features as empty tensors, not as a dict + # This makes the model TorchScript compatible + self.feature_1m = torch.zeros(1, feature_dimension) + self.feature_1h = torch.zeros(1, feature_dimension) + self.feature_1d = torch.zeros(1, feature_dimension) + def forward(self, x_1m, x_1h, x_1d): - """ - Process candle data from multiple timeframes. 
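The rewritten `CandlePatternCNN` fixes the extractor at three conv/pool stages, so `fc1 = nn.Linear(128 * 4 * 4, 1024)` only lines up if each timeframe arrives as a 32×32 grid (halved three times: 32 → 16 → 8 → 4). A quick shape check under that assumption — the 32×32 input size is inferred from `fc1`, not stated in the diff:

```python
import torch

# Assumes CandlePatternCNN as defined in this hunk and 32x32 inputs per timeframe.
cnn = CandlePatternCNN(input_channels=5, feature_dimension=512)
x = torch.randn(2, 5, 32, 32)          # [batch, channels, height, width]

feat = cnn.process_timeframe(x)
assert feat.shape == (2, 512)          # per-timeframe feature vector

combined = cnn(x, x, x)                # 1m/1h/1d stand-ins
assert combined.shape == (2, 3 * 512)  # concatenated across timeframes
```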
+ # Process 1m data + feat_1m = self.process_timeframe(x_1m) - Args: - x_1m: Tensor of shape [batch, channels, history_len] for 1-minute candles - x_1h: Tensor of shape [batch, channels, history_len] for 1-hour candles - x_1d: Tensor of shape [batch, channels, history_len] for 1-day candles - - Returns: - Tensor of extracted features - """ - # Add a dimension for the conv2d to work properly - x_1m = x_1m.unsqueeze(2) - x_1h = x_1h.unsqueeze(2) - x_1d = x_1d.unsqueeze(2) + # Process 1h data + feat_1h = self.process_timeframe(x_1h) - # Extract features from each timeframe - feat_1m = self.base_conv(x_1m) - feat_1h = self.base_conv(x_1h) - feat_1d = self.base_conv(x_1d) + # Process 1d data + feat_1d = self.process_timeframe(x_1d) - # Store intermediate features - self.intermediate_features['1m'] = feat_1m - self.intermediate_features['1h'] = feat_1h - self.intermediate_features['1d'] = feat_1d + # Store features as attributes instead of in a dictionary + self.feature_1m = feat_1m + self.feature_1h = feat_1h + self.feature_1d = feat_1d - # Flatten and concatenate features - batch_size = x_1m.size(0) - feat_1m = feat_1m.view(batch_size, -1) - feat_1h = feat_1h.view(batch_size, -1) - feat_1d = feat_1d.view(batch_size, -1) - - # Combine features for all timeframes + # Concatenate features from different timeframes combined_features = torch.cat([feat_1m, feat_1h, feat_1d], dim=1) - # Process through fusion layers - output = self.fusion(combined_features) + return combined_features + + def process_timeframe(self, x): + """Process a single timeframe batch of data""" + # Ensure proper shape for input, handle both batched and single inputs + if len(x.shape) == 3: # Single input, shape: [channels, height, width] + x = x.unsqueeze(0) # Add batch dimension - # Store final layer features - self.intermediate_features['fusion'] = output + x = self.pool1(self.relu1(self.conv1(x))) + x = self.pool2(self.relu2(self.conv2(x))) + x = self.pool3(self.relu3(self.conv3(x))) - return output + # Flatten the spatial dimensions for the fully connected layer + x = x.view(x.size(0), -1) + + x = self.relu4(self.fc1(x)) + x = self.fc2(x) + + return x def get_features(self): - """Returns dictionary of intermediate features for use by the agent""" - return self.intermediate_features + """Return features for each timeframe""" + # Use properties instead of dict for TorchScript compatibility + return self.feature_1m, self.feature_1h, self.feature_1d # Add candle cache system class CandleCache: @@ -3893,16 +3953,15 @@ async def fetch_multi_timeframe_data(exchange, symbol, candle_cache): class LSTMAttentionDQN(nn.Module): def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4): super(LSTMAttentionDQN, self).__init__() - - # CNN for pattern recognition - self.cnn = CandlePatternCNN(input_channels=5, feature_dimension=512) - - # Calculate expanded state size with CNN features - self.expanded_state_size = state_size + 512 + self.state_size = state_size + self.action_size = action_size + self.hidden_size = hidden_size + self.lstm_layers = lstm_layers + self.attention_heads = attention_heads # LSTM layer self.lstm = nn.LSTM( - input_size=self.expanded_state_size, + input_size=state_size, hidden_size=hidden_size, num_layers=lstm_layers, batch_first=True, @@ -3916,84 +3975,199 @@ class LSTMAttentionDQN(nn.Module): dropout=0.1 ) - # Advantage stream (dueling architecture) - self.advantage_stream = nn.Sequential( - nn.Linear(hidden_size, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, 
action_size) - ) - - # Value stream (dueling architecture) + # Value stream self.value_stream = nn.Sequential( - nn.Linear(hidden_size, hidden_size // 2), + nn.Linear(hidden_size, 128), nn.ReLU(), - nn.Linear(hidden_size // 2, 1) + nn.Linear(128, 1) ) + # Advantage stream + self.advantage_stream = nn.Sequential( + nn.Linear(hidden_size, 128), + nn.ReLU(), + nn.Linear(128, action_size) + ) + + # Fusion for multi-timeframe data + self.cnn_fusion = nn.Sequential( + nn.Linear(512 * 3, 1024), # 512 features from each of the 3 timeframes + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(1024, hidden_size) + ) + + # Initialize weights + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.LSTM): + for name, param in module.named_parameters(): + if 'weight' in name: + nn.init.xavier_uniform_(param) + elif 'bias' in name: + nn.init.constant_(param, 0) + def forward(self, state, x_1m=None, x_1h=None, x_1d=None): - # Handle different input shapes - if len(state.shape) == 1: - # Add batch dimension if missing - state = state.unsqueeze(0) - - if len(state.shape) == 2: - # Add sequence dimension if missing - state = state.unsqueeze(1) + """ + Forward pass handling different input shapes and optional CNN features + Args: + state: Primary state vector (batch_size, sequence_length, state_size) + x_1m, x_1h, x_1d: Optional CNN features from different timeframes + + Returns: + Q-values for each action + """ batch_size = state.size(0) - seq_len = state.size(1) - # If CNN inputs are provided, process them and concatenate with state + # Handle CNN features if provided if x_1m is not None and x_1h is not None and x_1d is not None: - # Note: x_1s is not used for now but kept in interface for future WebSocket implementation - cnn_features = self.cnn(x_1m, x_1h, x_1d) + # Ensure all CNN features have batch dimension + if len(x_1m.shape) == 2: + x_1m = x_1m.unsqueeze(0) + if len(x_1h.shape) == 2: + x_1h = x_1h.unsqueeze(0) + if len(x_1d.shape) == 2: + x_1d = x_1d.unsqueeze(0) + + # Ensure batch dimensions match + if x_1m.size(0) != batch_size: + x_1m = x_1m.expand(batch_size, -1, -1) if x_1m.size(0) == 1 else x_1m[:batch_size] + if x_1h.size(0) != batch_size: + x_1h = x_1h.expand(batch_size, -1, -1) if x_1h.size(0) == 1 else x_1h[:batch_size] + if x_1d.size(0) != batch_size: + x_1d = x_1d.expand(batch_size, -1, -1) if x_1d.size(0) == 1 else x_1d[:batch_size] - # Expand CNN features to match sequence length of state - cnn_features = cnn_features.unsqueeze(1).expand(-1, seq_len, -1) + # Check dimensions before concatenation + if x_1m.dim() == 3 and x_1m.size(1) == 512 and x_1h.size(1) == 512 and x_1d.size(1) == 512: + # Already in correct format [batch, features] + cnn_combined = torch.cat([x_1m, x_1h, x_1d], dim=1) + elif x_1m.dim() == 2 and x_1m.size(1) == 512 and x_1h.size(1) == 512 and x_1d.size(1) == 512: + # Dimensions correct but missing batch dimension + cnn_combined = torch.cat([x_1m, x_1h, x_1d], dim=1).unsqueeze(0) + else: + # Reshape to ensure correct dimensions + x_1m_flat = x_1m.reshape(batch_size, -1) + x_1h_flat = x_1h.reshape(batch_size, -1) + x_1d_flat = x_1d.reshape(batch_size, -1) + + # Handle variable dimensions more gracefully + needed_features = 512 + if x_1m_flat.size(1) < needed_features: + x_1m_flat = F.pad(x_1m_flat, (0, needed_features - x_1m_flat.size(1))) + else: + x_1m_flat = x_1m_flat[:, :needed_features] + + 
if x_1h_flat.size(1) < needed_features: + x_1h_flat = F.pad(x_1h_flat, (0, needed_features - x_1h_flat.size(1))) + else: + x_1h_flat = x_1h_flat[:, :needed_features] + + if x_1d_flat.size(1) < needed_features: + x_1d_flat = F.pad(x_1d_flat, (0, needed_features - x_1d_flat.size(1))) + else: + x_1d_flat = x_1d_flat[:, :needed_features] + + # Concatenate + cnn_combined = torch.cat([x_1m_flat, x_1h_flat, x_1d_flat], dim=1) - # Concatenate state with CNN features - state = torch.cat([state, cnn_features], dim=2) + # Use CNN fusion network to reduce dimension + cnn_features = self.cnn_fusion(cnn_combined) + + # Reshape to match LSTM input shape + cnn_features = cnn_features.view(batch_size, 1, self.hidden_size) + + # Combine with state input by concatenating along sequence dimension + if state.dim() < 3: + # If state is 2D [batch, features], reshape to 3D [batch, 1, features] + state = state.unsqueeze(1) + + # Ensure state has proper dimensions + if state.size(2) != self.state_size: + # If state dimension doesn't match, reshape or pad + if state.size(2) > self.state_size: + state = state[:, :, :self.state_size] + else: + state = F.pad(state, (0, self.state_size - state.size(2))) + + # Concatenate along sequence dimension + combined_input = torch.cat([state, cnn_features], dim=1) else: - # If CNN inputs not provided, pad with zeros - padding = torch.zeros(batch_size, seq_len, 512, device=state.device) - state = torch.cat([state, padding], dim=2) + # Use only state input if CNN features not provided + combined_input = state + if combined_input.dim() < 3: + # If state is 2D [batch, features], reshape to 3D [batch, 1, features] + combined_input = combined_input.unsqueeze(1) + + # Ensure state has proper dimensions + if combined_input.size(2) != self.state_size: + # If state dimension doesn't match, reshape or pad + if combined_input.size(2) > self.state_size: + combined_input = combined_input[:, :, :self.state_size] + else: + combined_input = F.pad(combined_input, (0, self.state_size - combined_input.size(2))) - # Process through LSTM - lstm_out, _ = self.lstm(state) + # Pass through LSTM + lstm_out, _ = self.lstm(combined_input) - # Apply attention - # Reshape for attention: [seq_len, batch_size, hidden_size] - lstm_out_permuted = lstm_out.permute(1, 0, 2) - attn_output, _ = self.attention(lstm_out_permuted, lstm_out_permuted, lstm_out_permuted) + # Apply self-attention to LSTM output + # Transform to shape required by MultiheadAttention (seq_len, batch, hidden) + attn_input = lstm_out.transpose(0, 1) + attn_output, _ = self.attention(attn_input, attn_input, attn_input) - # Reshape back: [batch_size, seq_len, hidden_size] - attn_output = attn_output.permute(1, 0, 2) + # Transform back to (batch, seq_len, hidden) + attn_output = attn_output.transpose(0, 1) - # Use the output of the last timestep - features = attn_output[:, -1, :] + # Use last output after attention + attn_out = attn_output[:, -1] - # Dueling architecture - advantage = self.advantage_stream(features) - value = self.value_stream(features) + # Value and advantage streams (dueling architecture) + value = self.value_stream(attn_out) + advantage = self.advantage_stream(attn_out) - # Combine value and advantage streams + # Combine value and advantage for Q-values q_values = value + advantage - advantage.mean(dim=1, keepdim=True) return q_values - + def forward_realtime(self, x): - """ - Optimized forward pass for real-time trading with minimal latency. 
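The padding/truncation branch above normalizes each flattened timeframe tensor to exactly 512 values before fusion. The same logic as a standalone helper (the `fit_width` name is illustrative, not part of the diff):

```python
import torch
import torch.nn.functional as F

def fit_width(t: torch.Tensor, width: int = 512) -> torch.Tensor:
    """Zero-pad on the right or truncate the last dimension to `width`."""
    if t.size(1) < width:
        return F.pad(t, (0, width - t.size(1)))
    return t[:, :width]

assert fit_width(torch.randn(4, 300)).shape == (4, 512)
assert fit_width(torch.randn(4, 700)).shape == (4, 512)
```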
+ """Simplified forward pass for realtime inference""" + # Adapt x to the right format if needed + if isinstance(x, np.ndarray): + x = torch.FloatTensor(x) - TODO: Implement streamlined forward pass that prioritizes speed - """ - # For now, just use the regular forward pass - # This could be optimized later with techniques like: - # - Using a smaller model for real-time decisions - # - Skipping certain layers or calculations - # - Using quantized weights or other optimizations + # Add batch dimension if not present + if x.dim() == 1: + x = x.unsqueeze(0) - return self.forward(x) + # Add sequence dimension if not present + if x.dim() == 2: + x = x.unsqueeze(1) + + # Basic forward pass + lstm_out, _ = self.lstm(x) + + # Apply attention + attn_input = lstm_out.transpose(0, 1) + attn_output, _ = self.attention(attn_input, attn_input, attn_input) + attn_output = attn_output.transpose(0, 1) + + # Get last output after attention + features = attn_output[:, -1] + + # Value and advantage streams + value = self.value_stream(features) + advantage = self.advantage_stream(features) + + # Combine for Q-values + q_values = value + advantage - advantage.mean(dim=1, keepdim=True) + + return q_values # Add this class after the CandleCache class @@ -4354,20 +4528,20 @@ async def train_with_backtesting(agent, env, symbol="ETH/USDT", num_episodes=10, max_steps_per_episode=1000, period_name=None): """ - Train the agent using historical data from a specific time period. + Train agent with backtesting on historical data. Args: agent: The agent to train - env: The trading environment + env: Trading environment symbol: Trading pair symbol - since_timestamp: Start timestamp for backtesting (milliseconds) - until_timestamp: End timestamp for backtesting (milliseconds) - num_episodes: Number of episodes to train for + since_timestamp: Start timestamp for backtesting + until_timestamp: End timestamp for backtesting + num_episodes: Number of episodes to train max_steps_per_episode: Maximum steps per episode - period_name: Optional name for the backtesting period (for logging) + period_name: Name of the backtest period Returns: - Training statistics for the backtesting period + Training statistics dictionary """ # Create a backtesting candle cache backtest_cache = BacktestCandles(since_timestamp, until_timestamp) @@ -4376,6 +4550,7 @@ async def train_with_backtesting(agent, env, symbol="ETH/USDT", logging.info(f"Starting backtesting for period: {period_name}") # Initialize exchange for data fetching + exchange = None try: exchange = await initialize_exchange() logging.info("Initialized exchange for backtesting") @@ -4401,16 +4576,32 @@ async def train_with_backtesting(agent, env, symbol="ETH/USDT", 'net_pnl_after_fees': [] } + # Memory management function + def clean_memory(): + """Clean up memory to avoid memory leaks""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + # Fetch historical data for all timeframes try: + clean_memory() # Clean memory before fetching data candle_data = await backtest_cache.fetch_all_timeframes(exchange, symbol) if not candle_data or not candle_data['1m']: logging.error(f"No historical data available for backtesting period: {period_name}") + try: + await exchange.close() + except Exception as e: + logging.error(f"Error closing exchange: {e}") return None logging.info(f"Fetched historical data for backtesting: {len(candle_data['1m'])} minute candles") except Exception as e: logging.error(f"Failed to fetch historical data for backtesting: {e}") + try: + await 
exchange.close() + except Exception as exchange_err: + logging.error(f"Error closing exchange: {exchange_err}") return None # Track best models @@ -4424,6 +4615,9 @@ async def train_with_backtesting(agent, env, symbol="ETH/USDT", # Start backtesting training loop for episode in range(num_episodes): try: + # Clean memory before starting a new episode + clean_memory() + # Reset environment state = env.reset() episode_reward = 0 @@ -4432,36 +4626,61 @@ async def train_with_backtesting(agent, env, symbol="ETH/USDT", # Update CNN patterns with historical data env.update_cnn_patterns(candle_data) + # Track consecutive errors for circuit breaker + consecutive_errors = 0 + max_consecutive_errors = 5 + # Episode loop for step in range(max_steps_per_episode): - # Select action using CNN-enhanced policy - action = agent.select_action(state, training=True, candle_data=candle_data) + try: + # Select action using CNN-enhanced policy + action = agent.select_action(state, training=True, candle_data=candle_data) + + # Take action + next_state, reward, done, info = env.step(action) + + # Store transition in replay memory + agent.memory.push(state, action, reward, next_state, done) + + # Move to the next state + state = next_state + + # Update episode reward + episode_reward += reward + + # Learn from experience + if len(agent.memory) > BATCH_SIZE: + try: + loss = agent.learn() + if loss is not None: + episode_losses.append(loss) + # Reset consecutive errors counter on successful learning + consecutive_errors = 0 + except Exception as e: + logging.error(f"Error during learning: {e}") + consecutive_errors += 1 + if consecutive_errors >= max_consecutive_errors: + logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") + break + + # Update target network periodically + if step % TARGET_UPDATE == 0: + agent.update_target_network() + + # Clean memory periodically during long episodes + if step % 200 == 0 and step > 0: + clean_memory() + + # End episode if done + if done: + break - # Take action - next_state, reward, done, info = env.step(action) - - # Store transition in replay memory - agent.memory.push(state, action, reward, next_state, done) - - # Move to the next state - state = next_state - - # Update episode reward - episode_reward += reward - - # Learn from experience - if len(agent.memory) > BATCH_SIZE: - loss = agent.learn() - if loss is not None: - episode_losses.append(loss) - - # Update target network periodically - if step % TARGET_UPDATE == 0: - agent.update_target_network() - - # End episode if done - if done: - break + except Exception as e: + logging.error(f"Error in training step: {e}") + consecutive_errors += 1 + if consecutive_errors >= max_consecutive_errors: + logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") + break # Calculate statistics mean_loss = np.mean(episode_losses) if episode_losses else 0 @@ -4481,12 +4700,23 @@ async def train_with_backtesting(agent, env, symbol="ETH/USDT", stats['balances'].append(balance) stats['win_rates'].append(win_rate) stats['episode_pnls'].append(pnl) + stats['drawdowns'].append(env.max_drawdown) + stats['trade_counts'].append(trade_count) + stats['loss_values'].append(mean_loss) stats['fees'].append(fees) stats['net_pnl_after_fees'].append(net_pnl) - stats['loss_values'].append(mean_loss) - stats['trade_counts'].append(trade_count) - # Track best model + # Calculate and update cumulative PnL + if len(stats['episode_pnls']) > 0: + cumulative_pnl = 
sum(stats['episode_pnls'])
+                if 'cumulative_pnl' not in stats:
+                    stats['cumulative_pnl'] = []
+                stats['cumulative_pnl'].append(cumulative_pnl)
+                if writer:
+                    writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode)
+                    writer.add_scalar('CumulativeNetPnL/episode', sum(stats['net_pnl_after_fees']), episode)
+
+            # Save model if this is the best reward or PnL
             if episode_reward > best_reward:
                 best_reward = episode_reward
                 model_path = f"models/backtest/{period_name}_best_reward.pt" if period_name else "models/backtest/best_reward.pt"
@@ -4507,28 +4737,50 @@
                     logging.error(f"Error saving best PnL model: {e}")
                     logging.info(f"New best PnL: ${best_pnl:.2f} (model not saved)")
 
+            # Save model if this is the best net PnL (after fees)
             if net_pnl > best_net_pnl:
                 best_net_pnl = net_pnl
-                logging.info(f"New best Net PnL: ${best_net_pnl:.2f}")
+                model_path = f"models/backtest/{period_name}_best_net_pnl.pt" if period_name else "models/backtest/best_net_pnl.pt"
+                try:
+                    agent.save(model_path)
+                    logging.info(f"New best Net PnL: ${best_net_pnl:.2f}")
+                except Exception as e:
+                    logging.error(f"Error saving best net PnL model: {e}")
+                    logging.info(f"New best Net PnL: ${best_net_pnl:.2f} (model not saved)")
 
-            # Log episode results
-            logging.info(
-                f"Episode {episode+1}/{num_episodes} | " +
-                f"Reward: {episode_reward:.2f} | " +
-                f"Balance: ${balance:.2f} | " +
-                f"PnL: ${pnl:.2f} | " +
-                f"Fees: ${fees:.2f} | " +
-                f"Net PnL: ${net_pnl:.2f} | " +
-                f"Win Rate: {win_rate:.2f} | " +
-                f"Trades: {trade_count} | " +
-                f"Loss: {mean_loss:.5f} | " +
-                f"Epsilon: {epsilon:.4f}"
-            )
+            # Save checkpoint periodically (train_with_backtesting has no
+            # use_compact_save flag, so always use the agent's own save)
+            if episode % 10 == 0:
+                try:
+                    agent.save(f'models/trading_agent_checkpoint_{episode}.pt')
+                except Exception as e:
+                    logging.error(f"Error saving checkpoint model: {e}")
+
+            # Update epsilon
+            agent.update_epsilon(episode)
+
+            # Log training progress
+            logging.info(f"Episode {episode+1}/{num_episodes} | " +
+                        f"Reward: {episode_reward:.2f} | " +
+                        f"Balance: ${balance:.2f} | " +
+                        f"PnL: ${pnl:.2f} | " +
+                        f"Fees: ${fees:.2f} | " +
+                        f"Net PnL: ${net_pnl:.2f} | " +
+                        f"Win Rate: {win_rate:.2f} | " +
+                        f"Trades: {trade_count} | " +
+                        f"Loss: {mean_loss:.5f} | " +
+                        f"Epsilon: {agent.epsilon:.4f}")
 
         except Exception as e:
-            logging.error(f"Error during backtesting episode {episode+1}: {e}")
+            logging.error(f"Error in episode {episode}: {e}")
+            logging.error(traceback.format_exc())
             continue
 
+    # Clean memory before saving final model
+    clean_memory()
+
     # Save final model
     if period_name:
         try:
@@ -4560,17 +4812,354 @@
         logging.error(f"Error saving backtesting statistics: {e}")
 
     # Close exchange connection
-    try:
-        await exchange.close()
-    except AttributeError:
-        # Some exchanges don't have a close method
-        logging.info("Exchange doesn't have a close method, skipping")
-    except Exception as e:
-        logging.error(f"Error closing exchange connection: {e}")
+    if exchange:
+        try:
+            await exchange.close()
+            logging.info("Exchange connection closed successfully")
+        except AttributeError:
+            # Some exchanges don't have a close method
+            logging.info("Exchange doesn't have a close method, skipping")
+        except Exception as e:
+            logging.error(f"Error closing exchange connection: {e}")
 
     return stats
 
+# Implement a robust save function to handle PyTorch serialization errors
+def robust_save(model, 
path): + """ + Save a model with multiple fallback approaches to ensure file is saved + even in low disk space conditions. + """ + logger.info(f"Saving model to {path}.backup (attempt 1)") + backup_path = f"{path}.backup" + + # Attempt 1: Regular save to backup file + try: + checkpoint = { + 'policy_net': model.policy_net.state_dict(), + 'target_net': model.target_net.state_dict(), + 'optimizer': model.optimizer.state_dict(), + 'epsilon': model.epsilon + } + torch.save(checkpoint, backup_path) + logger.info(f"Successfully saved to {backup_path}") + + # If successful, copy to final path + try: + shutil.copy2(backup_path, path) + logger.info(f"Copied backup to {path}") + logger.info(f"Model saved successfully to {path}") + return True + except Exception as e: + logger.warning(f"Failed to copy backup to main file: {str(e)}") + logger.info(f"Using backup file as the main save") + return True + except Exception as e: + logger.warning(f"First save attempt failed: {str(e)}") + + # Attempt 2: Try with older pickle protocol + logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)") + try: + checkpoint = { + 'policy_net': model.policy_net.state_dict(), + 'target_net': model.target_net.state_dict(), + 'optimizer': model.optimizer.state_dict(), + 'epsilon': model.epsilon + } + torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2) + logger.info(f"Successfully saved to {path} with protocol 2") + return True + except Exception as e: + logger.warning(f"Second save attempt failed: {str(e)}") + + # Attempt 3: Try without optimizer + logger.info(f"Saving model to {path} (attempt 3 - without optimizer)") + try: + checkpoint = { + 'policy_net': model.policy_net.state_dict(), + 'target_net': model.target_net.state_dict(), + 'epsilon': model.epsilon + } + torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2) + logger.info(f"Successfully saved to {path} without optimizer") + return True + except Exception as e: + logger.warning(f"Third save attempt failed: {str(e)}") + + # Attempt 4: Save model structure (as JSON) and parameters separately + logger.info(f"Saving model to {path} (attempt 4 - model structure as JSON)") + try: + # Save only essential model parameters as JSON + model_params = { + 'epsilon': float(model.epsilon), + 'state_size': model.state_size, + 'action_size': model.action_size, + 'hidden_size': model.hidden_size, + 'lstm_layers': model.policy_net.lstm_layers if hasattr(model.policy_net, 'lstm_layers') else 2, + 'attention_heads': model.policy_net.attention_heads if hasattr(model.policy_net, 'attention_heads') else 4 + } + + params_path = f"{path}.params.json" + with open(params_path, 'w') as f: + json.dump(model_params, f) + logger.info(f"Successfully saved model parameters to {params_path}") + + # Now try to save a smaller version of the model without CNN components + # This is a more minimal save for recovery purposes + try: + # Create stripped down checkpoint with minimal components + minimal_checkpoint = { + 'epsilon': model.epsilon, + 'state_size': model.state_size, + 'action_size': model.action_size, + 'hidden_size': model.hidden_size + } + + minimal_path = f"{path}.minimal" + torch.save(minimal_checkpoint, minimal_path, _use_new_zipfile_serialization=False, pickle_protocol=2) + logger.info(f"Successfully saved minimal checkpoint to {minimal_path}") + except Exception as e: + logger.warning(f"Minimal checkpoint save failed: {str(e)}") + + logger.info(f"Model saved successfully to {path}") + return True + except 
Exception as e: + logger.error(f"All save attempts failed for {path}: {str(e)}") + return False + +def cleanup_model_files(keep_best=True, keep_latest_n=5, aggressive=False): + """ + Delete old model files to free up disk space. + + Args: + keep_best (bool): Whether to keep the best model files (reward, pnl, net_pnl) + keep_latest_n (int): Number of latest checkpoint files to keep + aggressive (bool): If True, apply more aggressive cleanup in very low disk scenarios + """ + try: + logging.info(f"Running model file cleanup: keep_best={keep_best}, keep_latest_n={keep_latest_n}, aggressive={aggressive}") + models_dir = "models" + + # Get all files in the models directory + all_files = os.listdir(models_dir) + + # Files to potentially delete + checkpoint_files = [] + backup_files = [] + params_files = [] + dated_files = [] + + # Best files to keep if keep_best is True + best_patterns = [ + "trading_agent_best_reward.pt", + "trading_agent_best_pnl.pt", + "trading_agent_best_net_pnl.pt", + "trading_agent_final.pt" + ] + + # Categorize files for potential deletion + for filename in all_files: + file_path = os.path.join(models_dir, filename) + + # Skip directories + if os.path.isdir(file_path): + continue + + # Skip current best files if keep_best is True + if keep_best and any(filename == pattern for pattern in best_patterns): + continue + + # Check for different file types + if "checkpoint" in filename and filename.endswith(".pt"): + checkpoint_files.append((filename, os.path.getmtime(file_path), file_path)) + elif filename.endswith(".backup"): + backup_files.append((filename, os.path.getmtime(file_path), file_path)) + elif filename.endswith(".params.json"): + params_files.append((filename, os.path.getmtime(file_path), file_path)) + elif "_2025" in filename or "_2024" in filename: # Files with date stamps + dated_files.append((filename, os.path.getmtime(file_path), file_path)) + + bytes_freed = 0 + files_deleted = 0 + + # Process checkpoint files - keep the newest N + if len(checkpoint_files) > keep_latest_n: + # Sort by modification time (newest first) + checkpoint_files.sort(key=lambda x: x[1], reverse=True) + + # Keep the newest N files + files_to_delete = checkpoint_files[keep_latest_n:] + + # Delete old checkpoint files + for _, _, file_path in files_to_delete: + try: + file_size = os.path.getsize(file_path) + os.remove(file_path) + bytes_freed += file_size + files_deleted += 1 + logging.info(f"Deleted old checkpoint file: {file_path}") + except Exception as e: + logging.error(f"Failed to delete file {file_path}: {str(e)}") + + # If aggressive cleanup is enabled, remove more files + if aggressive: + # Delete all backup files except the newest one + if backup_files: + backup_files.sort(key=lambda x: x[1], reverse=True) + for _, _, file_path in backup_files[1:]: # Keep only newest backup + try: + file_size = os.path.getsize(file_path) + os.remove(file_path) + bytes_freed += file_size + files_deleted += 1 + logging.info(f"Deleted old backup file: {file_path}") + except Exception as e: + logging.error(f"Failed to delete file {file_path}: {str(e)}") + + # Delete all dated files (these are typically archived models) + for _, _, file_path in dated_files: + try: + file_size = os.path.getsize(file_path) + os.remove(file_path) + bytes_freed += file_size + files_deleted += 1 + logging.info(f"Deleted dated model file: {file_path}") + except Exception as e: + logging.error(f"Failed to delete file {file_path}: {str(e)}") + + logging.info(f"Cleanup complete. 
Deleted {files_deleted} files, freed {bytes_freed / (1024*1024):.2f} MB")
+
+        # Check available disk space after cleanup
+        try:
+            if platform.system() == 'Windows':
+                free_bytes = ctypes.c_ulonglong(0)
+                ctypes.windll.kernel32.GetDiskFreeSpaceExW(ctypes.c_wchar_p(os.path.abspath(models_dir)), None, None, ctypes.pointer(free_bytes))
+                free_mb = free_bytes.value / (1024 * 1024)
+            else:
+                st = os.statvfs(os.path.abspath(models_dir))
+                free_mb = (st.f_bavail * st.f_frsize) / (1024 * 1024)
+
+            logging.info(f"Available disk space after cleanup: {free_mb:.2f} MB")
+
+            # If space is still low, recommend aggressive cleanup
+            if free_mb < 200 and not aggressive:  # Less than 200MB available
+                logging.warning("Disk space still critically low. Consider using aggressive cleanup.")
+        except Exception as e:
+            logging.error(f"Error checking disk space: {str(e)}")
+
+    except Exception as e:
+        logging.error(f"Error during file cleanup: {str(e)}")
+        logging.error(traceback.format_exc())
+
+def compact_save(agent, path, use_quantization=False):
+    """
+    Save an agent's model in a compact format suitable for low disk space environments.
+    Includes fallbacks if the primary save method fails. Architecture metadata
+    (epsilon, state/action/hidden sizes) is read from the agent itself, so the
+    function can be called as compact_save(agent, path) from the training loops.
+
+    Args:
+        agent: The agent whose networks should be saved
+        path: The path to save to
+        use_quantization: Whether to use quantization to reduce model size
+
+    Returns:
+        bool: Whether the save was successful
+    """
+    try:
+        # Use the policy network when the agent wraps one, otherwise assume a bare module
+        model = agent.policy_net if hasattr(agent, 'policy_net') else agent
+
+        # Create minimal checkpoint with essential data only
+        checkpoint = {
+            'model_state_dict': model.state_dict(),
+            'epsilon': agent.epsilon,
+            'state_size': agent.state_size,
+            'action_size': agent.action_size,
+            'hidden_size': agent.hidden_size
+        }
+
+        # Apply quantization if requested
+        if use_quantization:
+            try:
+                logging.info(f"Attempting quantized save to {path}")
+                # Quantize model to int8
+                quantized_model = torch.quantization.quantize_dynamic(
+                    model,  # the original model
+                    {torch.nn.Linear},  # a set of layers to dynamically quantize
+                    dtype=torch.qint8  # the target dtype for quantized weights
+                )
+
+                # Create quantized checkpoint
+                quantized_checkpoint = {
+                    'model_state_dict': quantized_model.state_dict(),
+                    'epsilon': agent.epsilon,
+                    'state_size': agent.state_size,
+                    'action_size': agent.action_size,
+                    'hidden_size': agent.hidden_size,
+                    'is_quantized': True
+                }
+
+                # Save with older pickle protocol and disable new zipfile serialization
+                torch.save(quantized_checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2)
+                logging.info(f"Quantized compact save successful to {path}")
+                return True
+            except Exception as e:
+                logging.warning(f"Quantized save failed, falling back to regular save: {str(e)}")
+                # Fall back to regular save if quantization fails
+
+        # Regular save with older pickle protocol and no zipfile serialization
+        torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2)
+        logging.info(f"Compact save successful to {path}")
+        return True
+    except Exception as e:
+        logging.error(f"Compact save failed: {str(e)}")
+        logging.error(traceback.format_exc())
+
+        # Fallback: Save just the parameters as JSON if we can't save the full model
+        try:
+            params = {
+                'epsilon': agent.epsilon,
+                'state_size': agent.state_size,
+                'action_size': agent.action_size,
+                'hidden_size': agent.hidden_size
+            }
+            json_path = f"{path}.params.json"
+            with open(json_path, 'w') as f:
+                json.dump(params, f)
+            
logging.info(f"Saved minimal parameters to {json_path}") + return False + except Exception as json_e: + logging.error(f"JSON parameter save failed: {str(json_e)}") + return False + if __name__ == "__main__": + # Parse command line arguments + parser = argparse.ArgumentParser(description='Trading Bot') + parser.add_argument('--mode', type=str, default='train', help='Mode: train, test, live') + parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train') + parser.add_argument('--max_steps', type=int, default=1000, help='Maximum steps per episode') + parser.add_argument('--update_interval', type=int, default=10, help='Target network update interval') + parser.add_argument('--training_iterations', type=int, default=10, help='Number of training iterations per step') + parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading symbol') + parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for candlestick data') + parser.add_argument('--compact_save', action='store_true', help='Use compact save to reduce disk usage') + parser.add_argument('--use_quantization', action='store_true', help='Use model quantization for even smaller file sizes') + parser.add_argument('--cleanup', action='store_true', help='Clean up old model files before training') + parser.add_argument('--aggressive_cleanup', action='store_true', help='Perform aggressive cleanup to free more space') + parser.add_argument('--keep_latest', type=int, default=5, help='Number of latest checkpoint files to keep when cleaning up') + + args = parser.parse_args() + + # Import platform and ctypes for disk space checking + import platform + import ctypes + + # Run cleanup if requested + if args.cleanup: + cleanup_model_files(keep_best=True, keep_latest_n=args.keep_latest, aggressive=args.aggressive_cleanup) + try: asyncio.run(main()) except KeyboardInterrupt: diff --git a/crypto/gogo2/test_model_save_load.py b/crypto/gogo2/test_model_save_load.py new file mode 100644 index 0000000..a4815a9 --- /dev/null +++ b/crypto/gogo2/test_model_save_load.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python +import os +import logging +import torch +import argparse +import gc +import traceback +import shutil +from main import Agent, robust_save + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler("test_model_save_load.log"), + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +def create_test_directory(): + """Create a test directory for saving models""" + test_dir = "test_models" + os.makedirs(test_dir, exist_ok=True) + return test_dir + +def test_save_load_cycle(state_size=64, action_size=4, hidden_size=384): + """Test a full cycle of saving and loading models""" + test_dir = create_test_directory() + + # Create a test agent + logger.info(f"Creating test agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}") + agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size) + + # Define paths for testing + save_path = os.path.join(test_dir, "test_agent.pt") + + # Test saving + logger.info(f"Testing save to {save_path}") + save_success = agent.save(save_path) + + if save_success: + logger.info(f"Save successful, model size: {os.path.getsize(save_path)} bytes") + else: + logger.error("Save failed!") + return False + + # Memory cleanup + del agent + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + 
+    # Test loading
+    logger.info(f"Testing load from {save_path}")
+    try:
+        new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
+        new_agent.load(save_path)
+        logger.info("Load successful")
+
+        # Verify model architecture
+        logger.info("Verifying model architecture")
+        assert new_agent.state_size == state_size, f"Expected state_size={state_size}, got {new_agent.state_size}"
+        assert new_agent.action_size == action_size, f"Expected action_size={action_size}, got {new_agent.action_size}"
+        assert new_agent.hidden_size == hidden_size, f"Expected hidden_size={hidden_size}, got {new_agent.hidden_size}"
+
+        logger.info("Model architecture verified correctly")
+        return True
+    except Exception as e:
+        logger.error(f"Error during load or verification: {e}")
+        logger.error(traceback.format_exc())
+        return False
+
+def test_robust_save_methods(state_size=64, action_size=4, hidden_size=384):
+    """Test all the robust save methods"""
+    test_dir = create_test_directory()
+
+    # Create a test agent
+    logger.info("Creating test agent for robust save testing")
+    agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
+
+    # Test each robust save method
+    methods = [
+        ("regular", os.path.join(test_dir, "regular_save.pt")),
+        ("backup", os.path.join(test_dir, "backup_save.pt")),
+        ("pickle2", os.path.join(test_dir, "pickle2_save.pt")),
+        ("no_optimizer", os.path.join(test_dir, "no_optimizer_save.pt")),
+        ("jit", os.path.join(test_dir, "jit_save.pt"))
+    ]
+
+    results = {}
+
+    for method_name, save_path in methods:
+        logger.info(f"Testing {method_name} save method to {save_path}")
+
+        try:
+            if method_name == "regular":
+                # Use regular save
+                success = agent.save(save_path)
+            elif method_name == "backup":
+                # Use backup method directly: write a backup file first, then copy it
+                backup_path = f"{save_path}.backup"
+                checkpoint = {
+                    'policy_net': agent.policy_net.state_dict(),
+                    'target_net': agent.target_net.state_dict(),
+                    'optimizer': agent.optimizer.state_dict(),
+                    'epsilon': agent.epsilon,
+                    'state_size': agent.state_size,
+                    'action_size': agent.action_size,
+                    'hidden_size': agent.hidden_size
+                }
+                torch.save(checkpoint, backup_path)
+                shutil.copy(backup_path, save_path)
+                success = os.path.exists(save_path)
+            elif method_name == "pickle2":
+                # Use pickle protocol 2 for wider compatibility
+                checkpoint = {
+                    'policy_net': agent.policy_net.state_dict(),
+                    'target_net': agent.target_net.state_dict(),
+                    'optimizer': agent.optimizer.state_dict(),
+                    'epsilon': agent.epsilon,
+                    'state_size': agent.state_size,
+                    'action_size': agent.action_size,
+                    'hidden_size': agent.hidden_size
+                }
+                torch.save(checkpoint, save_path, pickle_protocol=2)
+                success = os.path.exists(save_path)
+            elif method_name == "no_optimizer":
+                # Save without optimizer state to shrink the checkpoint
+                checkpoint = {
+                    'policy_net': agent.policy_net.state_dict(),
+                    'target_net': agent.target_net.state_dict(),
+                    'epsilon': agent.epsilon,
+                    'state_size': agent.state_size,
+                    'action_size': agent.action_size,
+                    'hidden_size': agent.hidden_size
+                }
+                torch.save(checkpoint, save_path)
+                success = os.path.exists(save_path)
+            elif method_name == "jit":
+                # Use JIT save: script both networks and store parameters separately
+                try:
+                    scripted_policy = torch.jit.script(agent.policy_net)
+                    torch.jit.save(scripted_policy, f"{save_path}.policy.jit")
+
+                    scripted_target = torch.jit.script(agent.target_net)
+                    torch.jit.save(scripted_target, f"{save_path}.target.jit")
+
+                    # Save parameters as a JSON side-car file
+                    with open(f"{save_path}.params.json", "w") as f:
+                        params = {
+                            'epsilon': float(agent.epsilon),
+                            'state_size': int(agent.state_size),
+                            'action_size': int(agent.action_size),
+                            'hidden_size': int(agent.hidden_size)
+                        }
+                        json.dump(params, f)
+
+                    success = (os.path.exists(f"{save_path}.policy.jit") and
+                               os.path.exists(f"{save_path}.target.jit") and
+                               os.path.exists(f"{save_path}.params.json"))
+                except Exception as e:
+                    logger.error(f"JIT save failed: {e}")
+                    success = False
+
+            if success:
+                if method_name != "jit":
+                    file_size = os.path.getsize(save_path)
+                    logger.info(f"{method_name} save successful, size: {file_size} bytes")
+                else:
+                    logger.info(f"{method_name} save successful")
+                results[method_name] = True
+            else:
+                logger.error(f"{method_name} save failed")
+                results[method_name] = False
+
+        except Exception as e:
+            logger.error(f"Error during {method_name} save: {e}")
+            logger.error(traceback.format_exc())
+            results[method_name] = False
+
+    # Test loading each saved model
+    for method_name, save_path in methods:
+        if not results[method_name]:
+            logger.info(f"Skipping load test for {method_name} (save failed)")
+            continue
+
+        if method_name == "jit":
+            logger.info(f"Skipping load test for {method_name} (requires special loading)")
+            continue
+
+        logger.info(f"Testing load from {save_path}")
+        try:
+            new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
+            new_agent.load(save_path)
+            logger.info(f"Load successful for {method_name} save")
+        except Exception as e:
+            logger.error(f"Error loading from {method_name} save: {e}")
+            logger.error(traceback.format_exc())
+            results[method_name] = "saved, but load failed"  # replace the boolean; += on a bool would raise TypeError
+
+    # Return summary of results
+    return results
+
+def main():
+    parser = argparse.ArgumentParser(description='Test model saving and loading')
+    parser.add_argument('--state_size', type=int, default=64, help='State size for test model')
+    parser.add_argument('--action_size', type=int, default=4, help='Action size for test model')
+    parser.add_argument('--hidden_size', type=int, default=384, help='Hidden size for test model')
+    parser.add_argument('--test_robust', action='store_true', help='Test all robust save methods')
+    args = parser.parse_args()
+
+    logger.info("Starting model save/load test")
+
+    if args.test_robust:
+        results = test_robust_save_methods(args.state_size, args.action_size, args.hidden_size)
+        logger.info(f"Robust save method results: {results}")
+    else:
+        success = test_save_load_cycle(args.state_size, args.action_size, args.hidden_size)
+        logger.info(f"Save/load cycle {'successful' if success else 'failed'}")
+
+    logger.info("Test completed")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
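
Note on the "jit" artifacts: the load loop in test_robust_save_methods skips them because scripted networks cannot be read back with Agent.load. The sketch below shows how they could be loaded; it is a minimal illustration, not code from the repository. Only the three file suffixes (.policy.jit, .target.jit, .params.json) come from the test script above; the helper name load_jit_checkpoint and its return shape are hypothetical.

#!/usr/bin/env python
# Sketch: companion loader for the "jit" save method in test_model_save_load.py.
# The function name and return value are illustrative assumptions.
import json
import torch

def load_jit_checkpoint(save_path):
    """Load the scripted policy/target nets plus the JSON side-car parameters."""
    policy_net = torch.jit.load(f"{save_path}.policy.jit")  # returns a ScriptModule
    target_net = torch.jit.load(f"{save_path}.target.jit")
    with open(f"{save_path}.params.json") as f:
        params = json.load(f)  # epsilon, state_size, action_size, hidden_size
    return policy_net, target_net, params

if __name__ == "__main__":
    # Assumes `python test_model_save_load.py --test_robust` has already
    # produced the artifacts under test_models/.
    policy, target, params = load_jit_checkpoint("test_models/jit_save.pt")
    print(f"Loaded scripted networks, params: {params}")

Because torch.jit.load returns self-contained ScriptModules, this path can recover a runnable policy even when the full Agent checkpoint (optimizer state included) fails to save, which is the scenario the robust-save fallbacks are designed for.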