From bdf6afc6ad52991f00e6f2d82c2dd5f797c6c509 Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Mon, 17 Mar 2025 23:36:44 +0200
Subject: [PATCH] better saves

---
 crypto/gogo2/.vscode/launch.json         |   2 +-
 crypto/gogo2/MODEL_SAVING_FIX.md         |  74 ++++
 crypto/gogo2/live_training.py            | 498 +++++++++++++++++++++++
 crypto/gogo2/simplified_live_training.py | 118 ++++++
 crypto/gogo2/test_save.py                | 182 +++++++++
 5 files changed, 873 insertions(+), 1 deletion(-)
 create mode 100644 crypto/gogo2/MODEL_SAVING_FIX.md
 create mode 100644 crypto/gogo2/live_training.py
 create mode 100644 crypto/gogo2/simplified_live_training.py
 create mode 100644 crypto/gogo2/test_save.py

diff --git a/crypto/gogo2/.vscode/launch.json b/crypto/gogo2/.vscode/launch.json
index 165d15a..7b7b3d5 100644
--- a/crypto/gogo2/.vscode/launch.json
+++ b/crypto/gogo2/.vscode/launch.json
@@ -6,7 +6,7 @@
             "type": "python",
             "request": "launch",
             "program": "main.py",
-            "args": ["--mode", "train", "--episodes", "10"],
+            "args": ["--mode", "train", "--episodes", "100"],
             "console": "integratedTerminal",
             "justMyCode": true
         },
diff --git a/crypto/gogo2/MODEL_SAVING_FIX.md b/crypto/gogo2/MODEL_SAVING_FIX.md
new file mode 100644
index 0000000..62ef726
--- /dev/null
+++ b/crypto/gogo2/MODEL_SAVING_FIX.md
@@ -0,0 +1,74 @@
+# Model Saving Fix
+
+## Issue
+
+During training sessions, PyTorch model saving operations sometimes fail with errors like:
+
+```
+RuntimeError: [enforce fail at inline_container.cc:626] . unexpected pos 18278784 vs 18278680
+```
+
+or
+
+```
+RuntimeError: [enforce fail at inline_container.cc:820] . PytorchStreamWriter failed writing file data/75: file write failed
+```
+
+These errors occur in the PyTorch serialization mechanism when saving models using `torch.save()`.
+
+## Solution
+
+We've implemented a robust model saving approach that uses multiple fallback methods if the primary save operation fails:
+
+1. **Attempt 1**: Save to a backup file first, then copy it to the target path.
+2. **Attempt 2**: Use an older pickle protocol (protocol 2), which can be more compatible.
+3. **Attempt 3**: Save without the optimizer state, which reduces file size and can avoid serialization issues.
+4. **Attempt 4**: Use TorchScript's `torch.jit.save()` instead of `torch.save()`, which uses a different serialization mechanism.
+
+## Implementation
+
+The solution is implemented in two parts:
+
+1. A `robust_save` function that tries multiple saving approaches with fallbacks.
+2. A monkey patch that replaces the Agent's `save` method with our robust version.
+
+### Example Usage
+
+```python
+# Import the robust_save function
+from live_training import robust_save
+
+# Save a model with fallbacks
+success = robust_save(agent, "models/my_model.pt")
+if success:
+    print("Model saved successfully!")
+else:
+    print("All save attempts failed")
+```
+
+## Testing
+
+We've created a test script, `test_save.py`, that demonstrates the robust saving approach and verifies that it works correctly.
+
+To run the test:
+
+```bash
+python test_save.py
+```
+
+This script creates a simple model, attempts to save it using both the standard and robust methods, and reports the results.
+
+## Future Improvements
+
+Possible future improvements to the model saving mechanism:
+
+1. Additional fallback methods, such as serializing individual network layers.
+2. An automatic retry mechanism with exponential backoff (sketched below).
+3. Asynchronous saving to avoid blocking the training loop.
+4. Checksumming saved models to verify integrity.
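+
+As an illustration of item 2, a retry wrapper could sit on top of `robust_save` roughly like this. This is a sketch only, not part of the current implementation; the retry count and delay values are placeholders:
+
+```python
+import time
+from live_training import robust_save
+
+def save_with_backoff(agent, path, retries=3, base_delay=1.0):
+    """Retry robust_save with exponential backoff between attempts (sketch)."""
+    for attempt in range(retries):
+        if robust_save(agent, path):
+            return True
+        # Back off 1s, 2s, 4s, ... before the next attempt
+        time.sleep(base_delay * (2 ** attempt))
+    return False
+```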
+
+## Related Issues
+
+For more information on similar issues with PyTorch model saving, see:
+
+- https://github.com/pytorch/pytorch/issues/27736
+- https://github.com/pytorch/pytorch/issues/24045
\ No newline at end of file
diff --git a/crypto/gogo2/live_training.py b/crypto/gogo2/live_training.py
new file mode 100644
index 0000000..a3a1949
--- /dev/null
+++ b/crypto/gogo2/live_training.py
@@ -0,0 +1,498 @@
+#!/usr/bin/env python
+import asyncio
+import logging
+import sys
+import platform
+import argparse
+import os
+import datetime
+import traceback
+import numpy as np
+import torch
+from main import initialize_exchange, TradingEnvironment, Agent
+from torch.utils.tensorboard import SummaryWriter
+
+# Fix for Windows asyncio issues with aiodns
+if platform.system() == 'Windows':
+    try:
+        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
+    except Exception as e:
+        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")
+
+def setup_logging():
+    """Set up logging configuration for the application"""
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(levelname)s - %(message)s',
+        handlers=[
+            logging.FileHandler("live_training.log"),
+            logging.StreamHandler(sys.stdout)  # stdout handler for immediate feedback
+        ]
+    )
+
+setup_logging()
+logger = logging.getLogger(__name__)
+
+# Implement a robust save function to handle PyTorch serialization errors
+def robust_save(model, path):
+    """
+    Robust model saving with multiple fallback approaches
+
+    Args:
+        model: The Agent model to save
+        path: Path to save the model
+
+    Returns:
+        bool: True if successful, False otherwise
+    """
+    # Create directory if it doesn't exist
+    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
+
+    # Backup path in case the main save fails
+    backup_path = f"{path}.backup"
+
+    # Attempt 1: Try with default settings in a separate file first
+    try:
+        logger.info(f"Saving model to {backup_path} (attempt 1)")
+        checkpoint = {
+            'policy_net': model.policy_net.state_dict(),
+            'target_net': model.target_net.state_dict(),
+            'optimizer': model.optimizer.state_dict(),
+            'epsilon': model.epsilon
+        }
+        torch.save(checkpoint, backup_path)
+        logger.info(f"Successfully saved to {backup_path}")
+
+        # If the backup worked, copy it to the actual path
+        if os.path.exists(backup_path):
+            import shutil
+            shutil.copy(backup_path, path)
+            logger.info(f"Copied backup to {path}")
+            return True
+    except Exception as e:
+        logger.warning(f"First save attempt failed: {e}")
+
+    # Attempt 2: Try with pickle protocol 2 (more compatible)
+    try:
+        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
+        checkpoint = {
+            'policy_net': model.policy_net.state_dict(),
+            'target_net': model.target_net.state_dict(),
+            'optimizer': model.optimizer.state_dict(),
+            'epsilon': model.epsilon
+        }
+        torch.save(checkpoint, path, pickle_protocol=2)
+        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
+        return True
+    except Exception as e:
+        logger.warning(f"Second save attempt failed: {e}")
+
+    # Attempt 3: Try without the optimizer state (which can be large and cause issues)
+    try:
+        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
+        checkpoint = {
+            'policy_net': model.policy_net.state_dict(),
+            'target_net': model.target_net.state_dict(),
+            'epsilon': model.epsilon
+        }
+        torch.save(checkpoint, path)
+        logger.info(f"Successfully saved to {path} without optimizer state")
+        return True
+    except Exception as e:
+        logger.warning(f"Third save attempt failed: {e}")
+
+    # Attempt 4: Try with torch.jit.save instead
+    try:
+        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
+        # Save policy network using jit
+        scripted_policy = torch.jit.script(model.policy_net)
+        torch.jit.save(scripted_policy, f"{path}.policy.jit")
+        # Save target network using jit
+        scripted_target = torch.jit.script(model.target_net)
+        torch.jit.save(scripted_target, f"{path}.target.jit")
+        # Save epsilon value separately
+        with open(f"{path}.epsilon.txt", "w") as f:
+            f.write(str(model.epsilon))
+        logger.info("Successfully saved model components with jit.save")
+        return True
+    except Exception as e:
+        logger.error(f"All save attempts failed: {e}")
+        return False
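+
+# Note: checkpoints written by attempts 1-3 load through the existing
+# Agent.load / torch.load path, but the attempt-4 fallback produces separate
+# .policy.jit / .target.jit / .epsilon.txt files. The helper below is an
+# illustrative sketch (not used elsewhere in this patch) of how such a
+# fallback checkpoint could be loaded back.
+def load_jit_fallback(path):
+    """Load the components written by the jit.save fallback (illustrative)."""
+    policy_net = torch.jit.load(f"{path}.policy.jit")
+    target_net = torch.jit.load(f"{path}.target.jit")
+    with open(f"{path}.epsilon.txt") as f:
+        epsilon = float(f.read())
+    return policy_net, target_net, epsilon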
+
+# Implement fetch_and_update_data function
+async def fetch_and_update_data(exchange, env, symbol, timeframe):
+    """
+    Fetch new candle data and update the environment
+
+    Args:
+        exchange: CCXT exchange instance
+        env: Trading environment instance
+        symbol: Trading pair symbol
+        timeframe: Timeframe for the candles
+
+    Returns:
+        bool: True if the environment was updated, False otherwise
+    """
+    logger.info(f"Fetching new data for {symbol} on {timeframe} timeframe")
+
+    try:
+        # Number of candles to fetch on each update
+        limit = 1000
+
+        # Fetch OHLCV data
+        candles = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
+
+        if not candles or len(candles) == 0:
+            logger.warning(f"No candles returned for {symbol} on {timeframe}")
+            return False
+
+        logger.info(f"Successfully fetched {len(candles)} candles")
+
+        # Convert to the format expected by the environment
+        formatted_candles = []
+        for candle in candles:
+            timestamp, open_price, high, low, close, volume = candle
+            formatted_candles.append({
+                'timestamp': timestamp,
+                'open': open_price,
+                'high': high,
+                'low': low,
+                'close': close,
+                'volume': volume
+            })
+
+        # Update environment data
+        env.data = formatted_candles
+        if hasattr(env, '_initialize_features'):
+            env._initialize_features()
+
+        logger.info(f"Updated environment with {len(formatted_candles)} candles")
+
+        # Print latest candle info
+        if formatted_candles:
+            latest = formatted_candles[-1]
+            dt = datetime.datetime.fromtimestamp(latest['timestamp']/1000).strftime('%Y-%m-%d %H:%M:%S')
+            logger.info(f"Latest candle: Time={dt}, Open={latest['open']}, High={latest['high']}, Low={latest['low']}, Close={latest['close']}, Volume={latest['volume']}")
+
+        return True
+
+    except Exception as e:
+        logger.error(f"Error fetching candle data: {e}")
+        logger.error(traceback.format_exc())
+        return False
+
+async def live_training(
+    symbol="ETH/USDT",
+    timeframe="1m",
+    model_path="models/trading_agent_best_pnl.pt",
+    save_path="models/trading_agent_live_trained.pt",
+    initial_balance=1000,
+    update_interval=60,
+    training_iterations=100,
+    learning_rate=0.0001,
+    batch_size=64,
+    gamma=0.99,
+    window_size=30,
+    max_episodes=0,  # 0 means unlimited
+):
+    """
+    Live training function that uses real market data to improve the model
+    without executing real trades.
+
+    Args:
+        symbol: Trading pair symbol
+        timeframe: Timeframe for training
+        model_path: Path to the initial model to load
+        save_path: Path to save the improved model
+        initial_balance: Initial balance for simulation
+        update_interval: Interval to update data in seconds
+        training_iterations: Number of training iterations per data update
+        learning_rate: Learning rate for training
+        batch_size: Batch size for training
+        gamma: Discount factor for training
+        window_size: Window size for the environment
+        max_episodes: Maximum number of episodes (0 for unlimited)
+    """
+    logger.info(f"Starting live training for {symbol} on {timeframe} timeframe")
+
+    # Initialize exchange (without sandbox mode); keep these as None so the
+    # finally block can tell whether they were ever created
+    exchange = None
+    agent = None
+    writer = None
+    try:
+        exchange = await initialize_exchange()
+        logger.info(f"Exchange initialized: {exchange.id}")
+
+        # Initialize environment
+        env = TradingEnvironment(
+            initial_balance=initial_balance,
+            window_size=window_size,
+            symbol=symbol,
+            timeframe=timeframe,
+        )
+
+        # Fetch initial data
+        logger.info(f"Fetching initial data for {symbol}")
+        success = await fetch_and_update_data(exchange, env, symbol, timeframe)
+        if not success:
+            logger.error("Failed to fetch initial data, exiting")
+            return
+
+        # Initialize agent
+        STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
+        ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
+        agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE, hidden_size=384)
+
+        # Load model if provided
+        if os.path.exists(model_path):
+            try:
+                agent.load(model_path)
+                logger.info(f"Model loaded successfully from {model_path}")
+            except Exception as e:
+                logger.warning(f"Error loading model: {e}")
+                logger.info("Starting with a new model")
+        else:
+            logger.warning(f"Model file {model_path} not found. Starting with a new model.")
+
+        # Initialize TensorBoard writer
+        run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
+        writer = SummaryWriter(log_dir=f"runs/live_training_{run_id}")
+        agent.writer = writer
+
+        # Initialize training statistics
+        total_rewards = 0
+        episode_count = 0
+        best_reward = float('-inf')
+        best_pnl = float('-inf')
+
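+        # The loop below interleaves four activities:
+        #   1. refresh market data every `update_interval` seconds,
+        #   2. run one simulated episode over the refreshed candles,
+        #   3. learn from replay memory as the episode advances,
+        #   4. checkpoint best-reward / best-PnL / periodic models via
+        #      robust_save and periodically sync the target network.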
+        # Start live training loop
+        logger.info("Starting live training loop")
+
+        step_counter = 0
+        last_update_time = datetime.datetime.now()
+
+        while True:
+            # Check if we've reached the maximum number of episodes
+            if max_episodes > 0 and episode_count >= max_episodes:
+                logger.info(f"Reached maximum episodes ({max_episodes}), stopping")
+                break
+
+            # Check if it's time to update data
+            current_time = datetime.datetime.now()
+            time_diff = (current_time - last_update_time).total_seconds()
+
+            if time_diff >= update_interval:
+                logger.info(f"Updating market data after {time_diff:.1f} seconds")
+                success = await fetch_and_update_data(exchange, env, symbol, timeframe)
+                if not success:
+                    logger.warning("Failed to update data, will try again later")
+                    # Wait a bit before trying again
+                    await asyncio.sleep(5)
+                    continue
+
+                last_update_time = current_time
+
+            # Run training iterations on the updated data
+            episode_reward = 0
+            env.reset()
+            done = False
+
+            # Run one simulated episode with the current data
+            steps_in_episode = 0
+            max_steps = len(env.data) - env.window_size - 1
+
+            logger.info(f"Starting episode {episode_count + 1} with {max_steps} steps")
+
+            while not done and steps_in_episode < max_steps:
+                try:
+                    state = env.get_state()
+                    action = agent.select_action(state, training=True)
+
+                    try:
+                        next_state, reward, done, info = env.step(action)
+                    except ValueError as e:
+                        logger.error(f"Error during env.step: {e}")
+                        # A "not enough values to unpack" error means step returned
+                        # only 3 values instead of 4; handle that case explicitly
+                        if "not enough values to unpack" in str(e):
+                            logger.info("Trying alternative step format")
+                            result = env.step(action)
+                            if len(result) == 3:
+                                next_state, reward, done = result
+                                info = {}
+                            else:
+                                raise
+                        else:
+                            raise
+
+                    # Save experience in replay memory
+                    agent.memory.push(state, action, reward, next_state, done)
+
+                    # Move to the next state
+                    state = next_state
+                    episode_reward += reward
+                    step_counter += 1
+                    steps_in_episode += 1
+
+                    # Log action and results every 50 steps
+                    if steps_in_episode % 50 == 0:
+                        logger.info(f"Step {steps_in_episode}/{max_steps} | Action: {action} | Reward: {reward:.2f} | Balance: ${env.balance:.2f}")
+
+                    # Train the agent on a batch of experiences
+                    if len(agent.memory) > batch_size:
+                        agent.learn()
+
+                        # Additional training iterations
+                        if steps_in_episode % 10 == 0 and training_iterations > 1:
+                            for _ in range(training_iterations - 1):
+                                agent.learn()
+
+                    if done:
+                        logger.info(f"Episode done after {steps_in_episode} steps")
+                        break
+
+                except Exception as e:
+                    logger.error(f"Error during episode step: {e}")
+                    logger.error(traceback.format_exc())
+                    break
+
+            # Update training statistics
+            episode_count += 1
+            total_rewards += episode_reward
+            avg_reward = total_rewards / episode_count
+
+            # Track metrics
+            writer.add_scalar('LiveTraining/Reward', episode_reward, episode_count)
+            writer.add_scalar('LiveTraining/AvgReward', avg_reward, episode_count)
+            writer.add_scalar('LiveTraining/Balance', env.balance, episode_count)
+            writer.add_scalar('LiveTraining/PnL', env.total_pnl, episode_count)
+
+            # Report progress
+            logger.info(f"""
+            Episode: {episode_count}
+            Reward: {episode_reward:.2f}
+            Avg Reward: {avg_reward:.2f}
+            Balance: ${env.balance:.2f}
+            PnL: ${env.total_pnl:.2f}
+            Memory Size: {len(agent.memory)}
+            Total Steps: {step_counter}
+            """)
+
+            # Save the model if it's the best so far (by reward or PnL)
+            if episode_reward > best_reward:
+                best_reward = episode_reward
+                reward_model_path = f"models/trading_agent_best_reward_{run_id}.pt"
+                if robust_save(agent, reward_model_path):
+                    logger.info(f"New best reward model saved: {episode_reward:.2f} to {reward_model_path}")
+                else:
+                    logger.error("Failed to save best reward model")
+
+            if env.total_pnl > best_pnl:
+                best_pnl = env.total_pnl
+                pnl_model_path = f"models/trading_agent_best_pnl_{run_id}.pt"
+                if robust_save(agent, pnl_model_path):
+                    logger.info(f"New best PnL model saved: ${env.total_pnl:.2f} to {pnl_model_path}")
+                else:
+                    logger.error("Failed to save best PnL model")
+
+            # Regularly save the model
+            if episode_count % 5 == 0:
+                if robust_save(agent, save_path):
+                    logger.info(f"Model checkpoint saved to {save_path}")
+                else:
+                    logger.error("Failed to save checkpoint")
+
+            # Update target network periodically
+            if episode_count % 5 == 0:
+                try:
+                    agent.update_target_network()
+                    logger.info("Target network updated")
+                except Exception as e:
+                    logger.error(f"Error updating target network: {e}")
+
+            # Sleep to avoid excessive API calls
+            await asyncio.sleep(1)
+
+    except asyncio.CancelledError:
+        logger.info("Live training cancelled")
+    except KeyboardInterrupt:
+        logger.info("Live training stopped by user")
+    except Exception as e:
+        logger.error(f"Error in live training: {e}")
+        logger.error(traceback.format_exc())
+    finally:
+        # Save final model (only if the agent was actually created)
+        if agent is not None:
+            if robust_save(agent, save_path):
+                logger.info(f"Final model saved to {save_path}")
+            else:
+                logger.error("Failed to save final model")
+
+        # Close TensorBoard writer
+        if writer is not None:
+            try:
+                writer.close()
+                logger.info("TensorBoard writer closed")
+            except Exception as e:
+                logger.error(f"Error closing TensorBoard writer: {e}")
+
+        # Close exchange connection
+        if exchange:
+            try:
+                await exchange.close()
+                logger.info("Exchange connection closed")
+            except Exception as e:
+                logger.error(f"Error closing exchange connection: {e}")
+
+    logger.info("Live training completed")
+
+async def main():
+    """Main function to parse arguments and start live training"""
+    parser = argparse.ArgumentParser(description='Live Training with Real Market Data')
+    parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
+    parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for training')
+    parser.add_argument('--model_path', type=str, default='models/trading_agent_best_pnl.pt', help='Path to initial model')
+    parser.add_argument('--save_path', type=str, default='models/trading_agent_live_trained.pt', help='Path to save improved model')
+    parser.add_argument('--initial_balance', type=float, default=1000, help='Initial balance for simulation')
+    parser.add_argument('--update_interval', type=int, default=60, help='Interval to update data in seconds')
+    parser.add_argument('--training_iterations', type=int, default=100, help='Training iterations per update')
+    parser.add_argument('--max_episodes', type=int, default=0, help='Maximum number of episodes (0 for unlimited)')
+
+    args = parser.parse_args()
+
+    logger.info(f"Starting live training with {args.symbol} on {args.timeframe} timeframe")
+
+    await live_training(
+        symbol=args.symbol,
+        timeframe=args.timeframe,
+        model_path=args.model_path,
+        save_path=args.save_path,
+        initial_balance=args.initial_balance,
+        update_interval=args.update_interval,
+        training_iterations=args.training_iterations,
+        max_episodes=args.max_episodes,
+    )
+
+# Override Agent's save method with our robust save function
+def monkey_patch_agent_save():
+    """Replace Agent's save method with our robust save approach"""
+    original_save = Agent.save
+
+    def patched_save(self, path):
+        return robust_save(self, path)
+
+    # Apply the patch
+    Agent.save = patched_save
+    logger.info("Monkey patched Agent.save with robust_save")
+
+    # Return the original method in case we need to restore it
+    return original_save
+
+if __name__ == "__main__":
+    try:
+        print("Starting live training script")
+        # Apply the monkey patch before running the main function
+        original_save = monkey_patch_agent_save()
+        asyncio.run(main())
+    except KeyboardInterrupt:
+        logger.info("Live training stopped by user")
+    except Exception as e:
+        logger.error(f"Error in main function: {e}")
+        logger.error(traceback.format_exc())
\ No newline at end of file
diff --git a/crypto/gogo2/simplified_live_training.py b/crypto/gogo2/simplified_live_training.py
new file mode 100644
index 0000000..ab0136e
--- /dev/null
+++ b/crypto/gogo2/simplified_live_training.py
@@ -0,0 +1,118 @@
+#!/usr/bin/env python
+import asyncio
+import logging
+import sys
+import platform
+import ccxt.async_support as ccxt
+import os
+import datetime
+
+# Fix for Windows asyncio issues with aiodns
+if platform.system() == 'Windows':
+    try:
+        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
+    except Exception as e:
+        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")
+
+# Setup direct console logging for immediate feedback
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.StreamHandler(sys.stdout)
+    ]
+)
+logger = logging.getLogger(__name__)
+
+async def initialize_exchange():
+    """Initialize the exchange with API credentials from environment variables"""
+    exchange_id = 'mexc'
+    try:
+        # Get API credentials from environment variables
+        api_key = os.getenv('MEXC_API_KEY', '')
+        secret_key = os.getenv('MEXC_SECRET_KEY', '')
+
+        # Initialize the exchange
+        exchange_class = getattr(ccxt, exchange_id)
+        exchange = exchange_class({
+            'apiKey': api_key,
+            'secret': secret_key,
+            'enableRateLimit': True,
+        })
+
+        logger.info(f"Exchange initialized with standard CCXT: {exchange_id}")
+        return exchange
+    except Exception as e:
+        logger.error(f"Error initializing exchange: {e}")
+        raise
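+
+# Note: OHLCV candles are public market data, so this test should work even
+# when MEXC_API_KEY / MEXC_SECRET_KEY are unset (the credentials above default
+# to empty strings); authenticated endpoints would require real keys.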
+
+async def fetch_ohlcv_data(exchange, symbol, timeframe, limit=1000):
+    """Fetch OHLCV data from the exchange"""
+    logger.info(f"Fetching {limit} {timeframe} candles for {symbol}")
+
+    try:
+        candles = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
+        if not candles or len(candles) == 0:
+            logger.warning(f"No candles returned for {symbol} on {timeframe}")
+            return None
+
+        logger.info(f"Successfully fetched {len(candles)} candles")
+        return candles
+    except Exception as e:
+        logger.error(f"Error fetching candle data: {e}")
+        return None
+
+async def main():
+    """Main function to test live data fetching"""
+    symbol = "ETH/USDT"
+    timeframe = "1m"
+
+    logger.info(f"Starting simplified live training test for {symbol} on {timeframe}")
+
+    exchange = None
+    try:
+        # Initialize exchange
+        exchange = await initialize_exchange()
+
+        # Fetch data every 10 seconds
+        for i in range(5):
+            logger.info(f"Fetch attempt {i+1}/5")
+            candles = await fetch_ohlcv_data(exchange, symbol, timeframe)
+
+            if candles:
+                # Print the latest candle
+                latest = candles[-1]
+                timestamp, open_price, high, low, close, volume = latest
+                dt = datetime.datetime.fromtimestamp(timestamp/1000).strftime('%Y-%m-%d %H:%M:%S')
+                logger.info(f"Latest candle: Time={dt}, Open={open_price}, High={high}, Low={low}, Close={close}, Volume={volume}")
+
+            # Wait 10 seconds before next fetch
+            if i < 4:  # Don't wait after the last fetch
+                logger.info("Waiting 10 seconds before next fetch...")
+                await asyncio.sleep(10)
+
+    except Exception as e:
+        logger.error(f"Error in simplified live training test: {e}")
+        import traceback
+        logger.error(traceback.format_exc())
+    finally:
+        # Close the exchange connection exactly once, if it was created
+        if exchange:
+            try:
+                await exchange.close()
+                logger.info("Exchange connection closed")
+            except Exception:
+                pass
+        logger.info("Test completed")
+
+if __name__ == "__main__":
+    try:
+        asyncio.run(main())
+    except KeyboardInterrupt:
+        logger.info("Test stopped by user")
+    except Exception as e:
+        logger.error(f"Error in main function: {e}")
+        import traceback
+        logger.error(traceback.format_exc())
\ No newline at end of file
diff --git a/crypto/gogo2/test_save.py b/crypto/gogo2/test_save.py
new file mode 100644
index 0000000..8bdb4cf
--- /dev/null
+++ b/crypto/gogo2/test_save.py
@@ -0,0 +1,182 @@
+#!/usr/bin/env python
+import torch
+import torch.nn as nn
+import os
+import logging
+import sys
+import platform
+import asyncio
+
+# Fix for Windows asyncio issues with aiodns
+if platform.system() == 'Windows':
+    try:
+        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
+    except Exception as e:
+        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")
+
+# Setup logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler("test_save.log"),
+        logging.StreamHandler(sys.stdout)
+    ]
+)
+logger = logging.getLogger(__name__)
+
+# Define a simple model for testing
+class SimpleModel(nn.Module):
+    def __init__(self):
+        super(SimpleModel, self).__init__()
+        self.fc1 = nn.Linear(10, 50)
+        self.fc2 = nn.Linear(50, 20)
+        self.fc3 = nn.Linear(20, 5)
+
+    def forward(self, x):
+        x = torch.relu(self.fc1(x))
+        x = torch.relu(self.fc2(x))
+        return self.fc3(x)
+
+# Create a simple Agent class for testing
+class TestAgent:
+    def __init__(self):
+        self.policy_net = SimpleModel()
+        self.target_net = SimpleModel()
+        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001)
+        self.epsilon = 0.1
+
+    def save(self, path):
+        """Standard save method that might fail"""
+        checkpoint = {
+            'policy_net': self.policy_net.state_dict(),
+            'target_net': self.target_net.state_dict(),
+            'optimizer': self.optimizer.state_dict(),
+            'epsilon': self.epsilon
+        }
+        torch.save(checkpoint, path)
+        logger.info(f"Model saved to {path}")
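+
+# Note: robust_save below intentionally duplicates the implementation in
+# live_training.py so this test script stays self-contained; if the two
+# copies drift apart, importing it from live_training.py would be safer.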
+
+# Robust save function with multiple fallback approaches
+def robust_save(model, path):
+    """
+    Robust model saving with multiple fallback approaches
+
+    Args:
+        model: The Agent model to save
+        path: Path to save the model
+
+    Returns:
+        bool: True if successful, False otherwise
+    """
+    # Create directory if it doesn't exist
+    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
+
+    # Backup path in case the main save fails
+    backup_path = f"{path}.backup"
+
+    # Attempt 1: Try with default settings in a separate file first
+    try:
+        logger.info(f"Saving model to {backup_path} (attempt 1)")
+        checkpoint = {
+            'policy_net': model.policy_net.state_dict(),
+            'target_net': model.target_net.state_dict(),
+            'optimizer': model.optimizer.state_dict(),
+            'epsilon': model.epsilon
+        }
+        torch.save(checkpoint, backup_path)
+        logger.info(f"Successfully saved to {backup_path}")
+
+        # If the backup worked, copy it to the actual path
+        if os.path.exists(backup_path):
+            import shutil
+            shutil.copy(backup_path, path)
+            logger.info(f"Copied backup to {path}")
+            return True
+    except Exception as e:
+        logger.warning(f"First save attempt failed: {e}")
+
+    # Attempt 2: Try with pickle protocol 2 (more compatible)
+    try:
+        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
+        checkpoint = {
+            'policy_net': model.policy_net.state_dict(),
+            'target_net': model.target_net.state_dict(),
+            'optimizer': model.optimizer.state_dict(),
+            'epsilon': model.epsilon
+        }
+        torch.save(checkpoint, path, pickle_protocol=2)
+        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
+        return True
+    except Exception as e:
+        logger.warning(f"Second save attempt failed: {e}")
+
+    # Attempt 3: Try without the optimizer state (which can be large and cause issues)
+    try:
+        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
+        checkpoint = {
+            'policy_net': model.policy_net.state_dict(),
+            'target_net': model.target_net.state_dict(),
+            'epsilon': model.epsilon
+        }
+        torch.save(checkpoint, path)
+        logger.info(f"Successfully saved to {path} without optimizer state")
+        return True
+    except Exception as e:
+        logger.warning(f"Third save attempt failed: {e}")
+
+    # Attempt 4: Try with torch.jit.save instead
+    try:
+        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
+        # Save policy network using jit
+        scripted_policy = torch.jit.script(model.policy_net)
+        torch.jit.save(scripted_policy, f"{path}.policy.jit")
+        # Save target network using jit
+        scripted_target = torch.jit.script(model.target_net)
+        torch.jit.save(scripted_target, f"{path}.target.jit")
+        # Save epsilon value separately
+        with open(f"{path}.epsilon.txt", "w") as f:
+            f.write(str(model.epsilon))
+        logger.info("Successfully saved model components with jit.save")
+        return True
+    except Exception as e:
+        logger.error(f"All save attempts failed: {e}")
+        return False
+
+def main():
+    # Create a test directory
+    save_dir = "test_models"
+    os.makedirs(save_dir, exist_ok=True)
+
+    # Create a test agent
+    agent = TestAgent()
+
+    # Test the regular save method (might fail)
+    try:
+        logger.info("Testing regular save method...")
+        save_path = os.path.join(save_dir, "regular_save.pt")
+        agent.save(save_path)
+        logger.info("Regular save succeeded")
+    except Exception as e:
+        logger.error(f"Regular save failed: {e}")
+
+    # Test our robust save method
+    logger.info("Testing robust save method...")
+    save_path = os.path.join(save_dir, "robust_save.pt")
+    success = robust_save(agent, save_path)
+
+    if success:
+        logger.info("Robust save succeeded!")
+    else:
+        logger.error("Robust save failed!")
+
+    # Check which files were created
+    logger.info("Files created:")
+    for file in os.listdir(save_dir):
+        file_path = os.path.join(save_dir, file)
+        file_size = os.path.getsize(file_path)
+        logger.info(f"  - {file} ({file_size} bytes)")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file