better saves

Dobromir Popov 2025-03-17 23:36:44 +02:00
parent 2e7a242ac7
commit bdf6afc6ad
5 changed files with 873 additions and 1 deletion

View File

@@ -6,7 +6,7 @@
             "type": "python",
             "request": "launch",
             "program": "main.py",
-            "args": ["--mode", "train", "--episodes", "10"],
+            "args": ["--mode", "train", "--episodes", "100"],
             "console": "integratedTerminal",
             "justMyCode": true
         },

View File

@@ -0,0 +1,74 @@
# Model Saving Fix
## Issue
During training sessions, PyTorch model saving operations sometimes fail with errors like:
```
RuntimeError: [enforce fail at inline_container.cc:626] . unexpected pos 18278784 vs 18278680
```
or
```
RuntimeError: [enforce fail at inline_container.cc:820] . PytorchStreamWriter failed writing file data/75: file write failed
```
These errors occur in the PyTorch serialization mechanism when saving models using `torch.save()`.
## Solution
We've implemented a robust model saving approach that uses multiple fallback methods if the primary save operation fails:
1. **Attempt 1**: Save to a backup file first, then copy to the target path.
2. **Attempt 2**: Use an older pickle protocol (protocol 2), which can be more compatible across environments.
3. **Attempt 3**: Save without the optimizer state, which can reduce file size and avoid serialization issues.
4. **Attempt 4**: Use TorchScript's `torch.jit.save()` instead of `torch.save()`, which uses a different serialization mechanism.
## Implementation
The solution is implemented in two parts:
1. A `robust_save` function that tries multiple saving approaches with fallbacks.
2. A monkey patch that replaces the Agent's `save` method with our robust version.
### Example Usage
```python
# Import the robust_save function
from live_training import robust_save
# Save a model with fallbacks
success = robust_save(agent, "models/my_model.pt")
if success:
    print("Model saved successfully!")
else:
    print("All save attempts failed")
```
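Checkpoints written by attempts 1–3 load with a plain `torch.load()` as usual. If only the TorchScript fallback (attempt 4) succeeded, the model instead lands in separate `.policy.jit`, `.target.jit`, and `.epsilon.txt` files next to the target path. A minimal loading sketch for that case (the path is illustrative; the suffixes are the ones `robust_save` writes):
```python
import torch

path = "models/my_model.pt"  # same path that was passed to robust_save
policy_net = torch.jit.load(f"{path}.policy.jit")
target_net = torch.jit.load(f"{path}.target.jit")
with open(f"{path}.epsilon.txt") as f:
    epsilon = float(f.read())
```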
## Testing
We've created a test script `test_save.py` that demonstrates the robust saving approach and verifies that it works correctly.
To run the test:
```bash
python test_save.py
```
This script creates a simple model, attempts to save it using both the standard and robust methods, and reports on the results.
## Future Improvements
Possible future improvements to the model saving mechanism:
1. Additional fallback methods like serializing individual neural network layers.
2. Automatic retry mechanism with exponential backoff (sketched below).
3. Asynchronous saving to avoid blocking the training loop.
4. Checksumming saved models to verify integrity (sketched below).
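For illustration, here is a minimal sketch of ideas 2 and 4. The function names and retry parameters are hypothetical, and `robust_save` is assumed importable as in the usage example above:
```python
import hashlib
import time

from live_training import robust_save

def save_with_backoff(agent, path, retries=3, base_delay=1.0):
    """Retry robust_save, doubling the wait after each failed attempt (idea 2)."""
    for attempt in range(retries):
        if robust_save(agent, path):
            return True
        delay = base_delay * (2 ** attempt)  # 1s, 2s, 4s, ...
        print(f"Save failed (attempt {attempt + 1}/{retries}), retrying in {delay:.0f}s")
        time.sleep(delay)
    return False

def checksum(path):
    """Return the SHA-256 digest of a saved model file (idea 4)."""
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()
```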
## Related Issues
For more information on similar issues with PyTorch model saving, see:
- https://github.com/pytorch/pytorch/issues/27736
- https://github.com/pytorch/pytorch/issues/24045

View File

@@ -0,0 +1,498 @@
#!/usr/bin/env python
import asyncio
import logging
import sys
import platform
import argparse
import os
import datetime
import traceback
import numpy as np
import torch
from main import initialize_exchange, TradingEnvironment, Agent
from torch.utils.tensorboard import SummaryWriter

# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
    try:
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
    except Exception as e:
        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")

def setup_logging():
    """Setup logging configuration for the application"""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler("live_training.log"),
            logging.StreamHandler(sys.stdout)  # stdout handler for immediate feedback
        ]
    )

# Set up logging
setup_logging()
logger = logging.getLogger(__name__)
# Implement a robust save function to handle PyTorch serialization errors
def robust_save(model, path):
    """
    Robust model saving with multiple fallback approaches

    Args:
        model: The Agent model to save
        path: Path to save the model

    Returns:
        bool: True if successful, False otherwise
    """
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # Backup path in case the main save fails
    backup_path = f"{path}.backup"

    # Attempt 1: Try with default settings in a separate file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'optimizer': model.optimizer.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, backup_path)
        logger.info(f"Successfully saved to {backup_path}")

        # If backup worked, copy to the actual path
        if os.path.exists(backup_path):
            import shutil
            shutil.copy(backup_path, path)
            logger.info(f"Copied backup to {path}")
        return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: Try with pickle protocol 2 (more compatible)
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'optimizer': model.optimizer.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: Try without optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: Try with torch.jit.save instead
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")

        # Save policy network using jit
        scripted_policy = torch.jit.script(model.policy_net)
        torch.jit.save(scripted_policy, f"{path}.policy.jit")

        # Save target network using jit
        scripted_target = torch.jit.script(model.target_net)
        torch.jit.save(scripted_target, f"{path}.target.jit")

        # Save epsilon value separately
        with open(f"{path}.epsilon.txt", "w") as f:
            f.write(str(model.epsilon))

        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False
async def fetch_and_update_data(exchange, env, symbol, timeframe):
    """
    Fetch new candle data and update the environment

    Args:
        exchange: CCXT exchange instance
        env: Trading environment instance
        symbol: Trading pair symbol
        timeframe: Timeframe for the candles
    """
    logger.info(f"Fetching new data for {symbol} on {timeframe} timeframe")
    try:
        # Fetch up to 1000 OHLCV candles
        limit = 1000
        candles = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
        if not candles or len(candles) == 0:
            logger.warning(f"No candles returned for {symbol} on {timeframe}")
            return False
        logger.info(f"Successfully fetched {len(candles)} candles")

        # Convert to format expected by environment
        formatted_candles = []
        for candle in candles:
            timestamp, open_price, high, low, close, volume = candle
            formatted_candles.append({
                'timestamp': timestamp,
                'open': open_price,
                'high': high,
                'low': low,
                'close': close,
                'volume': volume
            })

        # Update environment data
        env.data = formatted_candles
        if hasattr(env, '_initialize_features'):
            env._initialize_features()
        logger.info(f"Updated environment with {len(formatted_candles)} candles")

        # Print latest candle info
        if formatted_candles:
            latest = formatted_candles[-1]
            dt = datetime.datetime.fromtimestamp(latest['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
            logger.info(f"Latest candle: Time={dt}, Open={latest['open']}, High={latest['high']}, Low={latest['low']}, Close={latest['close']}, Volume={latest['volume']}")
        return True
    except Exception as e:
        logger.error(f"Error fetching candle data: {e}")
        logger.error(traceback.format_exc())
        return False
async def live_training(
    symbol="ETH/USDT",
    timeframe="1m",
    model_path="models/trading_agent_best_pnl.pt",
    save_path="models/trading_agent_live_trained.pt",
    initial_balance=1000,
    update_interval=60,
    training_iterations=100,
    learning_rate=0.0001,
    batch_size=64,
    gamma=0.99,
    window_size=30,
    max_episodes=0,  # 0 means unlimited
):
    """
    Live training function that uses real market data to improve the model
    without executing real trades.

    Args:
        symbol: Trading pair symbol
        timeframe: Timeframe for training
        model_path: Path to the initial model to load
        save_path: Path to save the improved model
        initial_balance: Initial balance for simulation
        update_interval: Interval to update data in seconds
        training_iterations: Number of training iterations per data update
        learning_rate: Learning rate for training
        batch_size: Batch size for training
        gamma: Discount factor for training
        window_size: Window size for the environment
        max_episodes: Maximum number of episodes (0 for unlimited)
    """
    logger.info(f"Starting live training for {symbol} on {timeframe} timeframe")

    # Initialize these before the try block so the finally block can always reference them
    exchange = None
    agent = None
    writer = None
    try:
        # Initialize exchange (without sandbox mode)
        exchange = await initialize_exchange()
        logger.info(f"Exchange initialized: {exchange.id}")

        # Initialize environment
        env = TradingEnvironment(
            initial_balance=initial_balance,
            window_size=window_size,
            symbol=symbol,
            timeframe=timeframe,
        )

        # Fetch initial data
        logger.info(f"Fetching initial data for {symbol}")
        success = await fetch_and_update_data(exchange, env, symbol, timeframe)
        if not success:
            logger.error("Failed to fetch initial data, exiting")
            return

        # Initialize agent
        STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
        ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
        agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE, hidden_size=384)

        # Load model if provided
        if os.path.exists(model_path):
            try:
                agent.load(model_path)
                logger.info(f"Model loaded successfully from {model_path}")
            except Exception as e:
                logger.warning(f"Error loading model: {e}")
                logger.info("Starting with a new model")
        else:
            logger.warning(f"Model file {model_path} not found. Starting with a new model.")

        # Initialize TensorBoard writer
        run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        writer = SummaryWriter(log_dir=f"runs/live_training_{run_id}")
        agent.writer = writer

        # Initialize training statistics
        total_rewards = 0
        episode_count = 0
        best_reward = float('-inf')
        best_pnl = float('-inf')

        # Start live training loop
        logger.info("Starting live training loop")
        step_counter = 0
        last_update_time = datetime.datetime.now()

        while True:
            # Check if we've reached the maximum number of episodes
            if max_episodes > 0 and episode_count >= max_episodes:
                logger.info(f"Reached maximum episodes ({max_episodes}), stopping")
                break

            # Check if it's time to update data
            current_time = datetime.datetime.now()
            time_diff = (current_time - last_update_time).total_seconds()
            if time_diff >= update_interval:
                logger.info(f"Updating market data after {time_diff:.1f} seconds")
                success = await fetch_and_update_data(exchange, env, symbol, timeframe)
                if not success:
                    logger.warning("Failed to update data, will try again later")
                    # Wait a bit before trying again
                    await asyncio.sleep(5)
                    continue
                last_update_time = current_time

            # Run one simulated episode with the current data
            episode_reward = 0
            env.reset()
            done = False
            steps_in_episode = 0
            max_steps = len(env.data) - env.window_size - 1
            logger.info(f"Starting episode {episode_count + 1} with {max_steps} steps")

            while not done and steps_in_episode < max_steps:
                try:
                    state = env.get_state()
                    action = agent.select_action(state, training=True)

                    try:
                        next_state, reward, done, info = env.step(action)
                    except ValueError as e:
                        logger.error(f"Error during env.step: {e}")
                        # A ValueError here may mean step returned 3 values instead of 4;
                        # try to handle that case
                        if "too many values to unpack" in str(e):
                            logger.info("Trying alternative step format")
                            result = env.step(action)
                            if len(result) == 3:
                                next_state, reward, done = result
                                info = {}
                            else:
                                raise
                        else:
                            raise

                    # Save experience in replay memory
                    agent.memory.push(state, action, reward, next_state, done)

                    # Move to the next state
                    state = next_state
                    episode_reward += reward
                    step_counter += 1
                    steps_in_episode += 1

                    # Log action and results every 50 steps
                    if steps_in_episode % 50 == 0:
                        logger.info(f"Step {steps_in_episode}/{max_steps} | Action: {action} | Reward: {reward:.2f} | Balance: ${env.balance:.2f}")

                    # Train the agent on a batch of experiences
                    if len(agent.memory) > batch_size:
                        agent.learn()

                        # Additional training iterations
                        if steps_in_episode % 10 == 0 and training_iterations > 1:
                            for _ in range(training_iterations - 1):
                                agent.learn()

                    if done:
                        logger.info(f"Episode done after {steps_in_episode} steps")
                        break
                except Exception as e:
                    logger.error(f"Error during episode step: {e}")
                    logger.error(traceback.format_exc())
                    break

            # Update training statistics
            episode_count += 1
            total_rewards += episode_reward
            avg_reward = total_rewards / episode_count

            # Track metrics
            writer.add_scalar('LiveTraining/Reward', episode_reward, episode_count)
            writer.add_scalar('LiveTraining/AvgReward', avg_reward, episode_count)
            writer.add_scalar('LiveTraining/Balance', env.balance, episode_count)
            writer.add_scalar('LiveTraining/PnL', env.total_pnl, episode_count)

            # Report progress
            logger.info(f"""
            Episode: {episode_count}
            Reward: {episode_reward:.2f}
            Avg Reward: {avg_reward:.2f}
            Balance: ${env.balance:.2f}
            PnL: ${env.total_pnl:.2f}
            Memory Size: {len(agent.memory)}
            Total Steps: {step_counter}
            """)

            # Save the model if it's the best so far (by reward or PnL)
            if episode_reward > best_reward:
                best_reward = episode_reward
                reward_model_path = f"models/trading_agent_best_reward_{run_id}.pt"
                if robust_save(agent, reward_model_path):
                    logger.info(f"New best reward model saved: {episode_reward:.2f} to {reward_model_path}")
                else:
                    logger.error("Failed to save best reward model")

            if env.total_pnl > best_pnl:
                best_pnl = env.total_pnl
                pnl_model_path = f"models/trading_agent_best_pnl_{run_id}.pt"
                if robust_save(agent, pnl_model_path):
                    logger.info(f"New best PnL model saved: ${env.total_pnl:.2f} to {pnl_model_path}")
                else:
                    logger.error("Failed to save best PnL model")

            # Regularly save the model
            if episode_count % 5 == 0:
                if robust_save(agent, save_path):
                    logger.info(f"Model checkpoint saved to {save_path}")
                else:
                    logger.error("Failed to save checkpoint")

            # Update target network periodically
            if episode_count % 5 == 0:
                try:
                    agent.update_target_network()
                    logger.info("Target network updated")
                except Exception as e:
                    logger.error(f"Error updating target network: {e}")

            # Sleep to avoid excessive API calls
            await asyncio.sleep(1)
    except asyncio.CancelledError:
        logger.info("Live training cancelled")
    except KeyboardInterrupt:
        logger.info("Live training stopped by user")
    except Exception as e:
        logger.error(f"Error in live training: {e}")
        logger.error(traceback.format_exc())
    finally:
        # Save final model (only if the agent was actually created)
        if agent is not None:
            if robust_save(agent, save_path):
                logger.info(f"Final model saved to {save_path}")
            else:
                logger.error("Failed to save final model")

        # Close TensorBoard writer
        if writer is not None:
            try:
                writer.close()
                logger.info("TensorBoard writer closed")
            except Exception as e:
                logger.error(f"Error closing TensorBoard writer: {e}")

        # Close exchange connection
        if exchange:
            try:
                await exchange.close()
                logger.info("Exchange connection closed")
            except Exception as e:
                logger.error(f"Error closing exchange connection: {e}")

        logger.info("Live training completed")
async def main():
    """Main function to parse arguments and start live training"""
    parser = argparse.ArgumentParser(description='Live Training with Real Market Data')
    parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
    parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for training')
    parser.add_argument('--model_path', type=str, default='models/trading_agent_best_pnl.pt', help='Path to initial model')
    parser.add_argument('--save_path', type=str, default='models/trading_agent_live_trained.pt', help='Path to save improved model')
    parser.add_argument('--initial_balance', type=float, default=1000, help='Initial balance for simulation')
    parser.add_argument('--update_interval', type=int, default=60, help='Interval to update data in seconds')
    parser.add_argument('--training_iterations', type=int, default=100, help='Training iterations per update')
    parser.add_argument('--max_episodes', type=int, default=0, help='Maximum number of episodes (0 for unlimited)')
    args = parser.parse_args()

    logger.info(f"Starting live training with {args.symbol} on {args.timeframe} timeframe")
    await live_training(
        symbol=args.symbol,
        timeframe=args.timeframe,
        model_path=args.model_path,
        save_path=args.save_path,
        initial_balance=args.initial_balance,
        update_interval=args.update_interval,
        training_iterations=args.training_iterations,
        max_episodes=args.max_episodes,
    )

# Override Agent's save method with our robust save function
def monkey_patch_agent_save():
    """Replace Agent's save method with our robust save approach"""
    original_save = Agent.save

    def patched_save(self, path):
        return robust_save(self, path)

    # Apply the patch
    Agent.save = patched_save
    logger.info("Monkey patched Agent.save with robust_save")

    # Return the original method in case we need to restore it
    return original_save

if __name__ == "__main__":
    try:
        print("Starting live training script")
        # Apply the monkey patch before running the main function
        original_save = monkey_patch_agent_save()
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Live training stopped by user")
    except Exception as e:
        logger.error(f"Error in main function: {e}")
        logger.error(traceback.format_exc())

View File

@@ -0,0 +1,118 @@
#!/usr/bin/env python
import asyncio
import logging
import sys
import platform
import ccxt.async_support as ccxt
import os
import datetime

# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
    try:
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
    except Exception as e:
        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")

# Set up direct console logging for immediate feedback
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

async def initialize_exchange():
    """Initialize the exchange with API credentials from environment variables"""
    exchange_id = 'mexc'
    try:
        # Get API credentials from environment variables
        api_key = os.getenv('MEXC_API_KEY', '')
        secret_key = os.getenv('MEXC_SECRET_KEY', '')

        # Initialize the exchange
        exchange_class = getattr(ccxt, exchange_id)
        exchange = exchange_class({
            'apiKey': api_key,
            'secret': secret_key,
            'enableRateLimit': True,
        })
        logger.info(f"Exchange initialized with standard CCXT: {exchange_id}")
        return exchange
    except Exception as e:
        logger.error(f"Error initializing exchange: {e}")
        raise

async def fetch_ohlcv_data(exchange, symbol, timeframe, limit=1000):
    """Fetch OHLCV data from the exchange"""
    logger.info(f"Fetching {limit} {timeframe} candles for {symbol} (attempt 1/3)")
    try:
        candles = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
        if not candles or len(candles) == 0:
            logger.warning(f"No candles returned for {symbol} on {timeframe}")
            return None
        logger.info(f"Successfully fetched {len(candles)} candles")
        return candles
    except Exception as e:
        logger.error(f"Error fetching candle data: {e}")
        return None

async def main():
    """Main function to test live data fetching"""
    symbol = "ETH/USDT"
    timeframe = "1m"
    logger.info(f"Starting simplified live training test for {symbol} on {timeframe}")

    exchange = None
    try:
        # Initialize exchange
        exchange = await initialize_exchange()

        # Fetch data five times, 10 seconds apart
        for i in range(5):
            logger.info(f"Fetch attempt {i + 1}/5")
            candles = await fetch_ohlcv_data(exchange, symbol, timeframe)
            if candles:
                # Print the latest candle
                latest = candles[-1]
                timestamp, open_price, high, low, close, volume = latest
                dt = datetime.datetime.fromtimestamp(timestamp / 1000).strftime('%Y-%m-%d %H:%M:%S')
                logger.info(f"Latest candle: Time={dt}, Open={open_price}, High={high}, Low={low}, Close={close}, Volume={volume}")

            # Wait 10 seconds before the next fetch (skip after the last one)
            if i < 4:
                logger.info("Waiting 10 seconds before next fetch...")
                await asyncio.sleep(10)
    except Exception as e:
        logger.error(f"Error in simplified live training test: {e}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        # Close exchange connection if it was opened
        if exchange:
            try:
                await exchange.close()
                logger.info("Exchange connection closed")
            except Exception:
                pass
        logger.info("Test completed")

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Test stopped by user")
    except Exception as e:
        logger.error(f"Error in main function: {e}")
        import traceback
        logger.error(traceback.format_exc())

crypto/gogo2/test_save.py (new file, 182 additions)

View File

@@ -0,0 +1,182 @@
#!/usr/bin/env python
import torch
import torch.nn as nn
import os
import logging
import sys
import platform

# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
    try:
        import asyncio
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
    except Exception as e:
        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("test_save.log"),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Define a simple model for testing
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 20)
        self.fc3 = nn.Linear(20, 5)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Create a simple Agent class for testing
class TestAgent:
    def __init__(self):
        self.policy_net = SimpleModel()
        self.target_net = SimpleModel()
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001)
        self.epsilon = 0.1

    def save(self, path):
        """Standard save method that might fail"""
        checkpoint = {
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epsilon': self.epsilon
        }
        torch.save(checkpoint, path)
        logger.info(f"Model saved to {path}")

# Robust save function with multiple fallback approaches
def robust_save(model, path):
    """
    Robust model saving with multiple fallback approaches

    Args:
        model: The Agent model to save
        path: Path to save the model

    Returns:
        bool: True if successful, False otherwise
    """
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # Backup path in case the main save fails
    backup_path = f"{path}.backup"

    # Attempt 1: Try with default settings in a separate file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'optimizer': model.optimizer.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, backup_path)
        logger.info(f"Successfully saved to {backup_path}")

        # If backup worked, copy to the actual path
        if os.path.exists(backup_path):
            import shutil
            shutil.copy(backup_path, path)
            logger.info(f"Copied backup to {path}")
        return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: Try with pickle protocol 2 (more compatible)
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'optimizer': model.optimizer.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: Try without optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: Try with torch.jit.save instead
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")

        # Save policy network using jit
        scripted_policy = torch.jit.script(model.policy_net)
        torch.jit.save(scripted_policy, f"{path}.policy.jit")

        # Save target network using jit
        scripted_target = torch.jit.script(model.target_net)
        torch.jit.save(scripted_target, f"{path}.target.jit")

        # Save epsilon value separately
        with open(f"{path}.epsilon.txt", "w") as f:
            f.write(str(model.epsilon))

        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False

def main():
    # Create a test directory
    save_dir = "test_models"
    os.makedirs(save_dir, exist_ok=True)

    # Create a test agent
    agent = TestAgent()

    # Test the regular save method (might fail)
    try:
        logger.info("Testing regular save method...")
        save_path = os.path.join(save_dir, "regular_save.pt")
        agent.save(save_path)
        logger.info("Regular save succeeded")
    except Exception as e:
        logger.error(f"Regular save failed: {e}")

    # Test our robust save method
    logger.info("Testing robust save method...")
    save_path = os.path.join(save_dir, "robust_save.pt")
    success = robust_save(agent, save_path)
    if success:
        logger.info("Robust save succeeded!")
    else:
        logger.error("Robust save failed!")

    # Check which files were created
    logger.info("Files created:")
    for file in os.listdir(save_dir):
        file_path = os.path.join(save_dir, file)
        file_size = os.path.getsize(file_path)
        logger.info(f"  - {file} ({file_size} bytes)")

if __name__ == "__main__":
    main()