fixed training again

parent bdf6afc6ad
commit 2c03675f3c
@@ -9,6 +9,8 @@ import datetime
 import traceback
 import numpy as np
 import torch
+import gc
+from functools import partial
 from main import initialize_exchange, TradingEnvironment, Agent
 from torch.utils.tensorboard import SummaryWriter
 
@@ -54,6 +56,11 @@ def robust_save(model, path):
     # Backup path in case the main save fails
     backup_path = f"{path}.backup"
 
+    # Clean up GPU memory before saving
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
+
     # Attempt 1: Try with default settings in a separate file first
     try:
         logger.info(f"Saving model to {backup_path} (attempt 1)")
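For context, the hunks in this file only show fragments of `robust_save`; a minimal sketch of the save-with-fallbacks pattern it implements (write a backup file first, then fall back to a checkpoint without optimizer state). The checkpoint keys follow the test script later in this commit; treat the exact structure and the helper name as assumptions, not the committed implementation:

```python
import torch

def robust_save_sketch(agent, path):
    """Hypothetical condensed robust_save: full checkpoint first,
    then a smaller checkpoint without optimizer state on failure."""
    checkpoint = {
        'policy_net': agent.policy_net.state_dict(),
        'target_net': agent.target_net.state_dict(),
        'optimizer': agent.optimizer.state_dict(),
        'epsilon': agent.epsilon,
    }
    try:
        torch.save(checkpoint, f"{path}.backup")  # attempt 1: separate file
        torch.save(checkpoint, path)
        return True
    except Exception:
        checkpoint.pop('optimizer', None)  # fallback: drop optimizer state
        try:
            torch.save(checkpoint, path)
            return True
        except Exception:
            return False
```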
@@ -122,6 +129,28 @@ def robust_save(model, path):
         logger.error(f"All save attempts failed: {e}")
         return False
 
+# Implement timeout wrapper for exchange operations
+async def with_timeout(coroutine, timeout=30, default=None):
+    """
+    Execute a coroutine with a timeout
+
+    Args:
+        coroutine: The coroutine to execute
+        timeout: Timeout in seconds
+        default: Default value to return on timeout
+
+    Returns:
+        The result of the coroutine or default value on timeout
+    """
+    try:
+        return await asyncio.wait_for(coroutine, timeout=timeout)
+    except asyncio.TimeoutError:
+        logger.warning(f"Operation timed out after {timeout} seconds")
+        return default
+    except Exception as e:
+        logger.error(f"Operation failed: {e}")
+        return default
+
 # Implement fetch_and_update_data function
 async def fetch_and_update_data(exchange, env, symbol, timeframe):
     """
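A minimal usage sketch of the `with_timeout` wrapper added above, assuming it is in scope; the `fetch_ticker` coroutine is purely illustrative:

```python
import asyncio

async def fetch_ticker():
    await asyncio.sleep(60)  # stand-in for a slow exchange call
    return {"last": 3000.0}

async def demo():
    # Returns the default ({}) after 2 seconds instead of hanging
    ticker = await with_timeout(fetch_ticker(), timeout=2, default={})
    print(ticker)

asyncio.run(demo())
```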
@@ -139,8 +168,12 @@ async def fetch_and_update_data(exchange, env, symbol, timeframe):
         # Default to 100 candles if not specified
         limit = 1000
 
-        # Fetch OHLCV data
-        candles = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
+        # Fetch OHLCV data with timeout
+        candles = await with_timeout(
+            exchange.fetch_ohlcv(symbol, timeframe, limit=limit),
+            timeout=30,
+            default=[]
+        )
 
         if not candles or len(candles) == 0:
             logger.warning(f"No candles returned for {symbol} on {timeframe}")
@@ -181,6 +214,16 @@ async def fetch_and_update_data(exchange, env, symbol, timeframe):
         logger.error(traceback.format_exc())
         return False
 
+# Implement memory management function
+def manage_memory():
+    """
+    Clean up memory to avoid memory leaks during long running sessions
+    """
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
+    logger.debug("Memory cleaned")
+
 async def live_training(
     symbol="ETH/USDT",
     timeframe="1m",
@@ -194,6 +237,8 @@ async def live_training(
     gamma=0.99,
     window_size=30,
     max_episodes=0,  # 0 means unlimited
+    retry_delay=5,  # Seconds to wait before retrying after an error
+    max_retries=3,  # Maximum number of retries for operations
 ):
     """
     Live training function that uses real market data to improve the model without executing real trades.
@@ -211,15 +256,30 @@ async def live_training(
         gamma: Discount factor for training
         window_size: Window size for the environment
         max_episodes: Maximum number of episodes (0 for unlimited)
+        retry_delay: Seconds to wait before retrying after an error
+        max_retries: Maximum number of retries for operations
     """
     logger.info(f"Starting live training for {symbol} on {timeframe} timeframe")
 
     # Initialize exchange (without sandbox mode)
     exchange = None
 
+    # Retry loop for exchange initialization
+    for retry in range(max_retries):
         try:
             exchange = await initialize_exchange()
             logger.info(f"Exchange initialized: {exchange.id}")
+            break
+        except Exception as e:
+            logger.error(f"Error initializing exchange (attempt {retry+1}/{max_retries}): {e}")
+            if retry < max_retries - 1:
+                logger.info(f"Retrying in {retry_delay} seconds...")
+                await asyncio.sleep(retry_delay)
+            else:
+                logger.error("Max retries reached. Could not initialize exchange.")
+                return
+
+    try:
         # Initialize environment
         env = TradingEnvironment(
             initial_balance=initial_balance,
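The same retry shape (bounded attempts, fixed delay) appears again in the next hunk for the initial data fetch; a generic helper along these lines could factor it out. This is a hypothetical refactor, not part of the commit:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)

async def retry_async(op, max_retries=3, retry_delay=5, label="operation"):
    """Run the async factory `op` up to max_retries times, sleeping between attempts."""
    for attempt in range(1, max_retries + 1):
        try:
            return await op()
        except Exception as e:
            logger.error(f"{label} failed (attempt {attempt}/{max_retries}): {e}")
            if attempt < max_retries:
                await asyncio.sleep(retry_delay)
    return None

# e.g. exchange = await retry_async(initialize_exchange, label="exchange init")
```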
@@ -228,11 +288,20 @@ async def live_training(
             timeframe=timeframe,
         )
 
-        # Fetch initial data
+        # Fetch initial data (with retries)
         logger.info(f"Fetching initial data for {symbol}")
+        success = False
+        for retry in range(max_retries):
             success = await fetch_and_update_data(exchange, env, symbol, timeframe)
+            if success:
+                break
+            logger.warning(f"Failed to fetch initial data (attempt {retry+1}/{max_retries})")
+            if retry < max_retries - 1:
+                logger.info(f"Retrying in {retry_delay} seconds...")
+                await asyncio.sleep(retry_delay)
+
         if not success:
-            logger.error("Failed to fetch initial data, exiting")
+            logger.error("Failed to fetch initial data after multiple attempts, exiting")
             return
 
         # Initialize agent
@@ -268,6 +337,10 @@ async def live_training(
         step_counter = 0
         last_update_time = datetime.datetime.now()
 
+        # Track consecutive errors to enable circuit breaker
+        consecutive_errors = 0
+        max_consecutive_errors = 5
+
         while True:
             # Check if we've reached the maximum number of episodes
             if max_episodes > 0 and episode_count >= max_episodes:
@@ -284,11 +357,14 @@ async def live_training(
                 if not success:
                     logger.warning("Failed to update data, will try again later")
                     # Wait a bit before trying again
-                    await asyncio.sleep(5)
+                    await asyncio.sleep(retry_delay)
                     continue
 
                 last_update_time = current_time
 
+            # Clean up memory before running an episode
+            manage_memory()
+
             # Run training iterations on the updated data
             episode_reward = 0
             env.reset()
@@ -337,6 +413,7 @@ async def live_training(
 
                     # Train the agent on a batch of experiences
                     if len(agent.memory) > batch_size:
+                        try:
                             agent.learn()
 
                             # Additional training iterations
@@ -344,6 +421,15 @@ async def live_training(
                             for _ in range(training_iterations - 1):
                                 agent.learn()
 
+                            # Reset consecutive errors counter on successful learning
+                            consecutive_errors = 0
+                        except Exception as e:
+                            logger.error(f"Error during learning: {e}")
+                            consecutive_errors += 1
+                            if consecutive_errors >= max_consecutive_errors:
+                                logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
+                                break
+
                     if done:
                         logger.info(f"Episode done after {steps_in_episode} steps")
                         break
@@ -351,6 +437,9 @@ async def live_training(
                 except Exception as e:
                     logger.error(f"Error during episode step: {e}")
                     logger.error(traceback.format_exc())
+                    consecutive_errors += 1
+                    if consecutive_errors >= max_consecutive_errors:
+                        logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
                     break
 
             # Update training statistics
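The consecutive-error counter added in the last two hunks is a simple circuit breaker: any successful learning step resets it, and a run of failures aborts the episode. A self-contained sketch of the same logic, with names mirroring the diff:

```python
class CircuitBreaker:
    """Minimal sketch of the consecutive-error circuit breaker used above."""

    def __init__(self, max_consecutive_errors=5):
        self.max_consecutive_errors = max_consecutive_errors
        self.consecutive_errors = 0

    def record_success(self):
        # Any success resets the failure streak
        self.consecutive_errors = 0

    def record_failure(self):
        # Returns True when the breaker should trip (caller breaks the loop)
        self.consecutive_errors += 1
        return self.consecutive_errors >= self.max_consecutive_errors
```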
@@ -419,6 +508,7 @@ async def live_training(
         logger.error(traceback.format_exc())
     finally:
         # Save final model
+        if 'agent' in locals():
             if robust_save(agent, save_path):
                 logger.info(f"Final model saved to {save_path}")
             else:
@@ -434,11 +524,13 @@ async def live_training(
         # Close exchange connection
         if exchange:
             try:
-                await exchange.close()
+                await with_timeout(exchange.close(), timeout=10)
                 logger.info("Exchange connection closed")
             except Exception as e:
                 logger.error(f"Error closing exchange connection: {e}")
 
+        # Final memory cleanup
+        manage_memory()
         logger.info("Live training completed")
 
 async def main():
@@ -452,6 +544,8 @@ async def main():
     parser.add_argument('--update_interval', type=int, default=60, help='Interval to update data in seconds')
     parser.add_argument('--training_iterations', type=int, default=100, help='Training iterations per update')
     parser.add_argument('--max_episodes', type=int, default=0, help='Maximum number of episodes (0 for unlimited)')
+    parser.add_argument('--retry_delay', type=int, default=5, help='Seconds to wait before retrying after an error')
+    parser.add_argument('--max_retries', type=int, default=3, help='Maximum number of retries for operations')
 
     args = parser.parse_args()
 
@@ -466,9 +560,10 @@ async def main():
         update_interval=args.update_interval,
         training_iterations=args.training_iterations,
         max_episodes=args.max_episodes,
+        retry_delay=args.retry_delay,
+        max_retries=args.max_retries,
     )
 
-# At the beginning of the file, after importing the modules
 # Override Agent's save method with our robust save function
 def monkey_patch_agent_save():
     """Replace Agent's save method with our robust save approach"""
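The body of `monkey_patch_agent_save` is cut off in this view. The usual shape of such a patch is to rebind the method on the class so every agent instance picks it up; a hypothetical sketch, assuming `robust_save(model, path)` has the signature shown above and `Agent` is importable:

```python
def monkey_patch_agent_save():
    """Replace Agent's save method with our robust save approach"""
    # Hypothetical body: rebind Agent.save so `self` becomes the model argument
    Agent.save = lambda self, path: robust_save(self, path)
```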
|
File diff suppressed because it is too large
Load Diff
227
crypto/gogo2/test_model_save_load.py
Normal file
227
crypto/gogo2/test_model_save_load.py
Normal file
@ -0,0 +1,227 @@
+#!/usr/bin/env python
+import os
+import logging
+import torch
+import argparse
+import gc
+import traceback
+import shutil
+from main import Agent, robust_save
+
+# Set up logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(levelname)s - %(message)s",
+    handlers=[
+        logging.FileHandler("test_model_save_load.log"),
+        logging.StreamHandler()
+    ]
+)
+logger = logging.getLogger(__name__)
+
+def create_test_directory():
+    """Create a test directory for saving models"""
+    test_dir = "test_models"
+    os.makedirs(test_dir, exist_ok=True)
+    return test_dir
+
+def test_save_load_cycle(state_size=64, action_size=4, hidden_size=384):
+    """Test a full cycle of saving and loading models"""
+    test_dir = create_test_directory()
+
+    # Create a test agent
+    logger.info(f"Creating test agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}")
+    agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
+
+    # Define paths for testing
+    save_path = os.path.join(test_dir, "test_agent.pt")
+
+    # Test saving
+    logger.info(f"Testing save to {save_path}")
+    save_success = agent.save(save_path)
+
+    if save_success:
+        logger.info(f"Save successful, model size: {os.path.getsize(save_path)} bytes")
+    else:
+        logger.error("Save failed!")
+        return False
+
+    # Memory cleanup
+    del agent
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
+
+    # Test loading
+    logger.info(f"Testing load from {save_path}")
+    try:
+        new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
+        new_agent.load(save_path)
+        logger.info("Load successful")
+
+        # Verify model architecture
+        logger.info("Verifying model architecture")
+        assert new_agent.state_size == state_size, f"Expected state_size={state_size}, got {new_agent.state_size}"
+        assert new_agent.action_size == action_size, f"Expected action_size={action_size}, got {new_agent.action_size}"
+        assert new_agent.hidden_size == hidden_size, f"Expected hidden_size={hidden_size}, got {new_agent.hidden_size}"
+
+        logger.info("Model architecture verified correctly")
+        return True
+    except Exception as e:
+        logger.error(f"Error during load or verification: {e}")
+        logger.error(traceback.format_exc())
+        return False
+
+def test_robust_save_methods(state_size=64, action_size=4, hidden_size=384):
+    """Test all the robust save methods"""
+    test_dir = create_test_directory()
+
+    # Create a test agent
+    logger.info("Creating test agent for robust save testing")
+    agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
+
+    # Test each robust save method
+    methods = [
+        ("regular", os.path.join(test_dir, "regular_save.pt")),
+        ("backup", os.path.join(test_dir, "backup_save.pt")),
+        ("pickle2", os.path.join(test_dir, "pickle2_save.pt")),
+        ("no_optimizer", os.path.join(test_dir, "no_optimizer_save.pt")),
+        ("jit", os.path.join(test_dir, "jit_save.pt"))
+    ]
+
+    results = {}
+
+    for method_name, save_path in methods:
+        logger.info(f"Testing {method_name} save method to {save_path}")
+
+        try:
+            if method_name == "regular":
+                # Use regular save
+                success = agent.save(save_path)
+            elif method_name == "backup":
+                # Use backup method directly
+                backup_path = f"{save_path}.backup"
+                checkpoint = {
+                    'policy_net': agent.policy_net.state_dict(),
+                    'target_net': agent.target_net.state_dict(),
+                    'optimizer': agent.optimizer.state_dict(),
+                    'epsilon': agent.epsilon,
+                    'state_size': agent.state_size,
+                    'action_size': agent.action_size,
+                    'hidden_size': agent.hidden_size
+                }
+                torch.save(checkpoint, backup_path)
+                shutil.copy(backup_path, save_path)
+                success = os.path.exists(save_path)
+            elif method_name == "pickle2":
+                # Use pickle protocol 2
+                checkpoint = {
+                    'policy_net': agent.policy_net.state_dict(),
+                    'target_net': agent.target_net.state_dict(),
+                    'optimizer': agent.optimizer.state_dict(),
+                    'epsilon': agent.epsilon,
+                    'state_size': agent.state_size,
+                    'action_size': agent.action_size,
+                    'hidden_size': agent.hidden_size
+                }
+                torch.save(checkpoint, save_path, pickle_protocol=2)
+                success = os.path.exists(save_path)
+            elif method_name == "no_optimizer":
+                # Save without optimizer
+                checkpoint = {
+                    'policy_net': agent.policy_net.state_dict(),
+                    'target_net': agent.target_net.state_dict(),
+                    'epsilon': agent.epsilon,
+                    'state_size': agent.state_size,
+                    'action_size': agent.action_size,
+                    'hidden_size': agent.hidden_size
+                }
+                torch.save(checkpoint, save_path)
+                success = os.path.exists(save_path)
+            elif method_name == "jit":
+                # Use JIT save
+                try:
+                    scripted_policy = torch.jit.script(agent.policy_net)
+                    torch.jit.save(scripted_policy, f"{save_path}.policy.jit")
+
+                    scripted_target = torch.jit.script(agent.target_net)
+                    torch.jit.save(scripted_target, f"{save_path}.target.jit")
+
+                    # Save parameters
+                    with open(f"{save_path}.params.json", "w") as f:
+                        import json
+                        params = {
+                            'epsilon': float(agent.epsilon),
+                            'state_size': int(agent.state_size),
+                            'action_size': int(agent.action_size),
+                            'hidden_size': int(agent.hidden_size)
+                        }
+                        json.dump(params, f)
+
+                    success = (os.path.exists(f"{save_path}.policy.jit") and
+                               os.path.exists(f"{save_path}.target.jit") and
+                               os.path.exists(f"{save_path}.params.json"))
+                except Exception as e:
+                    logger.error(f"JIT save failed: {e}")
+                    success = False
+
+            if success:
+                if method_name != "jit":
+                    file_size = os.path.getsize(save_path)
+                    logger.info(f"{method_name} save successful, size: {file_size} bytes")
+                else:
+                    logger.info(f"{method_name} save successful")
+                results[method_name] = True
+            else:
+                logger.error(f"{method_name} save failed")
+                results[method_name] = False
+
+        except Exception as e:
+            logger.error(f"Error during {method_name} save: {e}")
+            logger.error(traceback.format_exc())
+            results[method_name] = False
+
+    # Test loading each saved model
+    for method_name, save_path in methods:
+        if not results[method_name]:
+            logger.info(f"Skipping load test for {method_name} (save failed)")
+            continue
+
+        if method_name == "jit":
+            logger.info(f"Skipping load test for {method_name} (requires special loading)")
+            continue
+
+        logger.info(f"Testing load from {save_path}")
+        try:
+            new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
+            new_agent.load(save_path)
+            logger.info(f"Load successful for {method_name} save")
+        except Exception as e:
+            logger.error(f"Error loading from {method_name} save: {e}")
+            logger.error(traceback.format_exc())
+            results[method_name] = "saved (load failed)"  # save succeeded but load failed
+
+    # Return summary of results
+    return results
+
+def main():
+    parser = argparse.ArgumentParser(description='Test model saving and loading')
+    parser.add_argument('--state_size', type=int, default=64, help='State size for test model')
+    parser.add_argument('--action_size', type=int, default=4, help='Action size for test model')
+    parser.add_argument('--hidden_size', type=int, default=384, help='Hidden size for test model')
+    parser.add_argument('--test_robust', action='store_true', help='Test all robust save methods')
+    args = parser.parse_args()
+
+    logger.info("Starting model save/load test")
+
+    if args.test_robust:
+        results = test_robust_save_methods(args.state_size, args.action_size, args.hidden_size)
+        logger.info(f"Robust save method results: {results}")
+    else:
+        success = test_save_load_cycle(args.state_size, args.action_size, args.hidden_size)
+        logger.info(f"Save/load cycle {'successful' if success else 'failed'}")
+
+    logger.info("Test completed")
+
+if __name__ == "__main__":
+    main()
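Given the argparse flags defined above, the test script has two entry points (run from crypto/gogo2 so `from main import ...` resolves):

    python test_model_save_load.py                  # single save/load cycle
    python test_model_save_load.py --test_robust    # exercise every robust save method

The second form logs a per-method result summary from `test_robust_save_methods`.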