fixed training again
This commit is contained in:
parent bdf6afc6ad
commit 2c03675f3c
@@ -9,6 +9,8 @@ import datetime
import traceback
import numpy as np
import torch
import gc
from functools import partial
from main import initialize_exchange, TradingEnvironment, Agent
from torch.utils.tensorboard import SummaryWriter

@@ -54,6 +56,11 @@ def robust_save(model, path):
    # Backup path in case the main save fails
    backup_path = f"{path}.backup"

    # Clean up GPU memory before saving
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Attempt 1: Try with default settings in a separate file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
@@ -122,6 +129,28 @@ def robust_save(model, path):
        logger.error(f"All save attempts failed: {e}")
        return False

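The middle of robust_save is elided by this diff. As a minimal sketch only, here is the shape of the fallback chain it appears to implement, inferred from the save methods exercised in test_model_save_load.py later in this commit; the helper name and the exact ordering of attempts are assumptions, not the commit's actual code.

import os
import shutil
import torch

def robust_save_sketch(model, path):
    """Try progressively simpler save strategies until one succeeds (assumed flow)."""
    checkpoint = {
        'policy_net': model.policy_net.state_dict(),
        'target_net': model.target_net.state_dict(),
        'optimizer': model.optimizer.state_dict(),
        'epsilon': model.epsilon,
    }
    backup_path = f"{path}.backup"
    try:
        # Attempt: write to a backup file first, then promote it to the main path
        torch.save(checkpoint, backup_path)
        shutil.copy(backup_path, path)
        return True
    except Exception:
        pass
    try:
        # Attempt: older pickle protocol for better portability
        torch.save(checkpoint, path, pickle_protocol=2)
        return True
    except Exception:
        pass
    try:
        # Attempt: drop the optimizer state, the most fragile part of the checkpoint
        checkpoint.pop('optimizer', None)
        torch.save(checkpoint, path)
        return os.path.exists(path)
    except Exception:
        return False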
# Implement timeout wrapper for exchange operations
async def with_timeout(coroutine, timeout=30, default=None):
    """
    Execute a coroutine with a timeout

    Args:
        coroutine: The coroutine to execute
        timeout: Timeout in seconds
        default: Default value to return on timeout

    Returns:
        The result of the coroutine or default value on timeout
    """
    try:
        return await asyncio.wait_for(coroutine, timeout=timeout)
    except asyncio.TimeoutError:
        logger.warning(f"Operation timed out after {timeout} seconds")
        return default
    except Exception as e:
        logger.error(f"Operation failed: {e}")
        return default

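As a usage illustration (a minimal sketch; slow_op and the 5-second limit are made up for the example), the wrapper turns a hung call into a default value instead of an exception:

import asyncio

async def slow_op():
    await asyncio.sleep(60)  # simulates a stalled exchange call
    return [1, 2, 3]

async def demo():
    # Returns [] after 5 seconds instead of raising asyncio.TimeoutError
    result = await with_timeout(slow_op(), timeout=5, default=[])
    print(result)  # []

asyncio.run(demo())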
# Implement fetch_and_update_data function
async def fetch_and_update_data(exchange, env, symbol, timeframe):
    """
@@ -139,8 +168,12 @@ async def fetch_and_update_data(exchange, env, symbol, timeframe):
        # Default to 100 candles if not specified
        limit = 1000

        # Fetch OHLCV data
        candles = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
        # Fetch OHLCV data with timeout
        candles = await with_timeout(
            exchange.fetch_ohlcv(symbol, timeframe, limit=limit),
            timeout=30,
            default=[]
        )

        if not candles or len(candles) == 0:
            logger.warning(f"No candles returned for {symbol} on {timeframe}")
@@ -181,6 +214,16 @@ async def fetch_and_update_data(exchange, env, symbol, timeframe):
        logger.error(traceback.format_exc())
        return False

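For reference (standard ccxt behaviour, not shown in this hunk), each candle returned by fetch_ohlcv is a list of [timestamp, open, high, low, close, volume], so the result converts naturally to the numpy array type imported above; the values below are illustrative only.

import numpy as np

# candles = [[1700000000000, 2000.1, 2005.0, 1998.7, 2003.2, 12.5], ...]
def candles_to_array(candles):
    """Turn ccxt OHLCV rows into a float array with columns ts, open, high, low, close, volume."""
    return np.array(candles, dtype=np.float64)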
# Implement memory management function
def manage_memory():
    """
    Clean up memory to avoid memory leaks during long running sessions
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    logger.debug("Memory cleaned")

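For context (not part of the commit): torch.cuda.empty_cache() only returns cached blocks to the driver, while gc.collect() frees unreferenced Python objects. A small sketch to check the effect, using standard torch.cuda APIs:

import gc
import torch

def log_gpu_memory(tag=""):
    """Print allocated vs. reserved CUDA memory (no-op on CPU-only machines)."""
    if torch.cuda.is_available():
        alloc = torch.cuda.memory_allocated() / 1024**2
        reserved = torch.cuda.memory_reserved() / 1024**2
        print(f"{tag} allocated={alloc:.1f}MiB reserved={reserved:.1f}MiB")

log_gpu_memory("before cleanup")
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
log_gpu_memory("after cleanup")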
async def live_training(
    symbol="ETH/USDT",
    timeframe="1m",
@@ -194,6 +237,8 @@ async def live_training(
    gamma=0.99,
    window_size=30,
    max_episodes=0,  # 0 means unlimited
    retry_delay=5,  # Seconds to wait before retrying after an error
    max_retries=3,  # Maximum number of retries for operations
):
    """
    Live training function that uses real market data to improve the model without executing real trades.
@@ -211,15 +256,30 @@ async def live_training(
        gamma: Discount factor for training
        window_size: Window size for the environment
        max_episodes: Maximum number of episodes (0 for unlimited)
        retry_delay: Seconds to wait before retrying after an error
        max_retries: Maximum number of retries for operations
    """
    logger.info(f"Starting live training for {symbol} on {timeframe} timeframe")

    # Initialize exchange (without sandbox mode)
    exchange = None

    # Retry loop for exchange initialization
    for retry in range(max_retries):
        try:
            exchange = await initialize_exchange()
            logger.info(f"Exchange initialized: {exchange.id}")
            break
        except Exception as e:
            logger.error(f"Error initializing exchange (attempt {retry+1}/{max_retries}): {e}")
            if retry < max_retries - 1:
                logger.info(f"Retrying in {retry_delay} seconds...")
                await asyncio.sleep(retry_delay)
            else:
                logger.error("Max retries reached. Could not initialize exchange.")
                return

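The same retry-with-delay pattern is repeated below for the initial data fetch. As a design note only (the helper name retry_async is made up and not part of the commit), it could be factored into a small coroutine helper:

import asyncio

async def retry_async(operation, max_retries=3, retry_delay=5, logger=None):
    """Run an async callable, retrying with a fixed delay; returns None if all attempts fail."""
    for attempt in range(max_retries):
        try:
            return await operation()
        except Exception as e:
            if logger:
                logger.error(f"Attempt {attempt+1}/{max_retries} failed: {e}")
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
    return None

# e.g. exchange = await retry_async(initialize_exchange, max_retries=max_retries, retry_delay=retry_delay, logger=logger)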
    try:
        exchange = await initialize_exchange()
        logger.info(f"Exchange initialized: {exchange.id}")

        # Initialize environment
        env = TradingEnvironment(
            initial_balance=initial_balance,
@@ -228,11 +288,20 @@ async def live_training(
            timeframe=timeframe,
        )

        # Fetch initial data
        # Fetch initial data (with retries)
        logger.info(f"Fetching initial data for {symbol}")
        success = await fetch_and_update_data(exchange, env, symbol, timeframe)
        success = False
        for retry in range(max_retries):
            success = await fetch_and_update_data(exchange, env, symbol, timeframe)
            if success:
                break
            logger.warning(f"Failed to fetch initial data (attempt {retry+1}/{max_retries})")
            if retry < max_retries - 1:
                logger.info(f"Retrying in {retry_delay} seconds...")
                await asyncio.sleep(retry_delay)

        if not success:
            logger.error("Failed to fetch initial data, exiting")
            logger.error("Failed to fetch initial data after multiple attempts, exiting")
            return

        # Initialize agent
@@ -268,6 +337,10 @@ async def live_training(
        step_counter = 0
        last_update_time = datetime.datetime.now()

        # Track consecutive errors to enable circuit breaker
        consecutive_errors = 0
        max_consecutive_errors = 5

        while True:
            # Check if we've reached the maximum number of episodes
            if max_episodes > 0 and episode_count >= max_episodes:
@@ -284,11 +357,14 @@ async def live_training(
                if not success:
                    logger.warning("Failed to update data, will try again later")
                    # Wait a bit before trying again
                    await asyncio.sleep(5)
                    await asyncio.sleep(retry_delay)
                    continue

                last_update_time = current_time

            # Clean up memory before running an episode
            manage_memory()

            # Run training iterations on the updated data
            episode_reward = 0
            env.reset()
@@ -337,12 +413,22 @@ async def live_training(

                # Train the agent on a batch of experiences
                if len(agent.memory) > batch_size:
                    agent.learn()

                    # Additional training iterations
                    if steps_in_episode % 10 == 0 and training_iterations > 1:
                        for _ in range(training_iterations - 1):
                            agent.learn()
                    try:
                        agent.learn()

                        # Additional training iterations
                        if steps_in_episode % 10 == 0 and training_iterations > 1:
                            for _ in range(training_iterations - 1):
                                agent.learn()

                        # Reset consecutive errors counter on successful learning
                        consecutive_errors = 0
                    except Exception as e:
                        logger.error(f"Error during learning: {e}")
                        consecutive_errors += 1
                        if consecutive_errors >= max_consecutive_errors:
                            logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
                            break

                if done:
                    logger.info(f"Episode done after {steps_in_episode} steps")
@@ -351,7 +437,10 @@ async def live_training(
            except Exception as e:
                logger.error(f"Error during episode step: {e}")
                logger.error(traceback.format_exc())
                break
                consecutive_errors += 1
                if consecutive_errors >= max_consecutive_errors:
                    logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
                    break

            # Update training statistics
            episode_count += 1
@@ -419,26 +508,29 @@ async def live_training(
        logger.error(traceback.format_exc())
    finally:
        # Save final model
        if robust_save(agent, save_path):
            logger.info(f"Final model saved to {save_path}")
        else:
            logger.error(f"Failed to save final model")

        # Close TensorBoard writer
        try:
            writer.close()
            logger.info("TensorBoard writer closed")
        except Exception as e:
            logger.error(f"Error closing TensorBoard writer: {e}")
        if 'agent' in locals():
            if robust_save(agent, save_path):
                logger.info(f"Final model saved to {save_path}")
            else:
                logger.error(f"Failed to save final model")

        # Close TensorBoard writer
        try:
            writer.close()
            logger.info("TensorBoard writer closed")
        except Exception as e:
            logger.error(f"Error closing TensorBoard writer: {e}")

        # Close exchange connection
        if exchange:
            try:
                await exchange.close()
                await with_timeout(exchange.close(), timeout=10)
                logger.info("Exchange connection closed")
            except Exception as e:
                logger.error(f"Error closing exchange connection: {e}")

        # Final memory cleanup
        manage_memory()
        logger.info("Live training completed")

async def main():
@@ -452,6 +544,8 @@ async def main():
    parser.add_argument('--update_interval', type=int, default=60, help='Interval to update data in seconds')
    parser.add_argument('--training_iterations', type=int, default=100, help='Training iterations per update')
    parser.add_argument('--max_episodes', type=int, default=0, help='Maximum number of episodes (0 for unlimited)')
    parser.add_argument('--retry_delay', type=int, default=5, help='Seconds to wait before retrying after an error')
    parser.add_argument('--max_retries', type=int, default=3, help='Maximum number of retries for operations')

    args = parser.parse_args()

@@ -466,9 +560,10 @@ async def main():
        update_interval=args.update_interval,
        training_iterations=args.training_iterations,
        max_episodes=args.max_episodes,
        retry_delay=args.retry_delay,
        max_retries=args.max_retries,
    )

# At the beginning of the file, after importing the modules
# Override Agent's save method with our robust save function
def monkey_patch_agent_save():
    """Replace Agent's save method with our robust save approach"""
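The hunk is cut off right after this docstring. As a hedged sketch of what such a patch typically looks like (the body below is an assumption, not the commit's actual code; the functools.partial import added above would also fit this kind of wrapper):

def monkey_patch_agent_save():
    """Replace Agent's save method with the robust save approach (illustrative sketch)."""
    original_save = Agent.save  # keep a reference to the original behaviour

    def patched_save(self, path):
        # Delegate to the module-level robust_save defined earlier in this file
        return robust_save(self, path)

    Agent.save = patched_save
    logger.info("Agent.save patched to use robust_save")
    return original_save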
crypto/gogo2/main.py (1073): file diff suppressed because it is too large
crypto/gogo2/test_model_save_load.py (227): new file
@@ -0,0 +1,227 @@
#!/usr/bin/env python
import os
import logging
import torch
import argparse
import gc
import traceback
import shutil
from main import Agent, robust_save

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("test_model_save_load.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

def create_test_directory():
    """Create a test directory for saving models"""
    test_dir = "test_models"
    os.makedirs(test_dir, exist_ok=True)
    return test_dir

def test_save_load_cycle(state_size=64, action_size=4, hidden_size=384):
    """Test a full cycle of saving and loading models"""
    test_dir = create_test_directory()

    # Create a test agent
    logger.info(f"Creating test agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}")
    agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)

    # Define paths for testing
    save_path = os.path.join(test_dir, "test_agent.pt")

    # Test saving
    logger.info(f"Testing save to {save_path}")
    save_success = agent.save(save_path)

    if save_success:
        logger.info(f"Save successful, model size: {os.path.getsize(save_path)} bytes")
    else:
        logger.error("Save failed!")
        return False

    # Memory cleanup
    del agent
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Test loading
    logger.info(f"Testing load from {save_path}")
    try:
        new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
        new_agent.load(save_path)
        logger.info("Load successful")

        # Verify model architecture
        logger.info(f"Verifying model architecture")
        assert new_agent.state_size == state_size, f"Expected state_size={state_size}, got {new_agent.state_size}"
        assert new_agent.action_size == action_size, f"Expected action_size={action_size}, got {new_agent.action_size}"
        assert new_agent.hidden_size == hidden_size, f"Expected hidden_size={hidden_size}, got {new_agent.hidden_size}"

        logger.info("Model architecture verified correctly")
        return True
    except Exception as e:
        logger.error(f"Error during load or verification: {e}")
        logger.error(traceback.format_exc())
        return False

def test_robust_save_methods(state_size=64, action_size=4, hidden_size=384):
    """Test all the robust save methods"""
    test_dir = create_test_directory()

    # Create a test agent
    logger.info(f"Creating test agent for robust save testing")
    agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)

    # Test each robust save method
    methods = [
        ("regular", os.path.join(test_dir, "regular_save.pt")),
        ("backup", os.path.join(test_dir, "backup_save.pt")),
        ("pickle2", os.path.join(test_dir, "pickle2_save.pt")),
        ("no_optimizer", os.path.join(test_dir, "no_optimizer_save.pt")),
        ("jit", os.path.join(test_dir, "jit_save.pt"))
    ]

    results = {}

    for method_name, save_path in methods:
        logger.info(f"Testing {method_name} save method to {save_path}")

        try:
            if method_name == "regular":
                # Use regular save
                success = agent.save(save_path)
            elif method_name == "backup":
                # Use backup method directly
                backup_path = f"{save_path}.backup"
                checkpoint = {
                    'policy_net': agent.policy_net.state_dict(),
                    'target_net': agent.target_net.state_dict(),
                    'optimizer': agent.optimizer.state_dict(),
                    'epsilon': agent.epsilon,
                    'state_size': agent.state_size,
                    'action_size': agent.action_size,
                    'hidden_size': agent.hidden_size
                }
                torch.save(checkpoint, backup_path)
                shutil.copy(backup_path, save_path)
                success = os.path.exists(save_path)
            elif method_name == "pickle2":
                # Use pickle protocol 2
                checkpoint = {
                    'policy_net': agent.policy_net.state_dict(),
                    'target_net': agent.target_net.state_dict(),
                    'optimizer': agent.optimizer.state_dict(),
                    'epsilon': agent.epsilon,
                    'state_size': agent.state_size,
                    'action_size': agent.action_size,
                    'hidden_size': agent.hidden_size
                }
                torch.save(checkpoint, save_path, pickle_protocol=2)
                success = os.path.exists(save_path)
            elif method_name == "no_optimizer":
                # Save without optimizer
                checkpoint = {
                    'policy_net': agent.policy_net.state_dict(),
                    'target_net': agent.target_net.state_dict(),
                    'epsilon': agent.epsilon,
                    'state_size': agent.state_size,
                    'action_size': agent.action_size,
                    'hidden_size': agent.hidden_size
                }
                torch.save(checkpoint, save_path)
                success = os.path.exists(save_path)
            elif method_name == "jit":
                # Use JIT save
                try:
                    scripted_policy = torch.jit.script(agent.policy_net)
                    torch.jit.save(scripted_policy, f"{save_path}.policy.jit")

                    scripted_target = torch.jit.script(agent.target_net)
                    torch.jit.save(scripted_target, f"{save_path}.target.jit")

                    # Save parameters
                    with open(f"{save_path}.params.json", "w") as f:
                        import json
                        params = {
                            'epsilon': float(agent.epsilon),
                            'state_size': int(agent.state_size),
                            'action_size': int(agent.action_size),
                            'hidden_size': int(agent.hidden_size)
                        }
                        json.dump(params, f)

                    success = (os.path.exists(f"{save_path}.policy.jit") and
                               os.path.exists(f"{save_path}.target.jit") and
                               os.path.exists(f"{save_path}.params.json"))
                except Exception as e:
                    logger.error(f"JIT save failed: {e}")
                    success = False
            if success:
                if method_name != "jit":
                    file_size = os.path.getsize(save_path)
                    logger.info(f"{method_name} save successful, size: {file_size} bytes")
                else:
                    logger.info(f"{method_name} save successful")
                results[method_name] = True
            else:
                logger.error(f"{method_name} save failed")
                results[method_name] = False

        except Exception as e:
            logger.error(f"Error during {method_name} save: {e}")
            logger.error(traceback.format_exc())
            results[method_name] = False

    # Test loading each saved model
    for method_name, save_path in methods:
        if not results[method_name]:
            logger.info(f"Skipping load test for {method_name} (save failed)")
            continue

        if method_name == "jit":
            logger.info(f"Skipping load test for {method_name} (requires special loading)")
            continue

        logger.info(f"Testing load from {save_path}")
        try:
            new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
            new_agent.load(save_path)
            logger.info(f"Load successful for {method_name} save")
        except Exception as e:
            logger.error(f"Error loading from {method_name} save: {e}")
            logger.error(traceback.format_exc())
            # Save succeeded but loading failed; record that in the summary as a string
            results[method_name] = "saved, but load failed"

    # Return summary of results
    return results

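The "jit" variant is skipped above because the scripted modules need torch.jit.load rather than Agent.load. A minimal sketch of that special-case loading (reconstructing an Agent around the loaded modules is left out and would be an assumption):

import json
import torch

def load_jit_save(save_path):
    """Load the pieces written by the 'jit' save method above."""
    policy_net = torch.jit.load(f"{save_path}.policy.jit")
    target_net = torch.jit.load(f"{save_path}.target.jit")
    with open(f"{save_path}.params.json") as f:
        params = json.load(f)
    # params carries epsilon, state_size, action_size and hidden_size
    return policy_net, target_net, params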
def main():
    parser = argparse.ArgumentParser(description='Test model saving and loading')
    parser.add_argument('--state_size', type=int, default=64, help='State size for test model')
    parser.add_argument('--action_size', type=int, default=4, help='Action size for test model')
    parser.add_argument('--hidden_size', type=int, default=384, help='Hidden size for test model')
    parser.add_argument('--test_robust', action='store_true', help='Test all robust save methods')
    args = parser.parse_args()

    logger.info("Starting model save/load test")

    if args.test_robust:
        results = test_robust_save_methods(args.state_size, args.action_size, args.hidden_size)
        logger.info(f"Robust save method results: {results}")
    else:
        success = test_save_load_cycle(args.state_size, args.action_size, args.hidden_size)
        logger.info(f"Save/load cycle {'successful' if success else 'failed'}")

    logger.info("Test completed")

if __name__ == "__main__":
    main()
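Typical invocations, based on the arguments defined above (run from the crypto/gogo2 directory where the script lives):

python test_model_save_load.py
python test_model_save_load.py --state_size 64 --action_size 4 --hidden_size 384
python test_model_save_load.py --test_robust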