fixed training again

Dobromir Popov 2025-03-18 02:04:58 +02:00
parent bdf6afc6ad
commit 2c03675f3c
3 changed files with 1182 additions and 271 deletions

@@ -9,6 +9,8 @@ import datetime
import traceback
import numpy as np
import torch
import gc
from functools import partial
from main import initialize_exchange, TradingEnvironment, Agent
from torch.utils.tensorboard import SummaryWriter
@@ -54,6 +56,11 @@ def robust_save(model, path):
# Backup path in case the main save fails
backup_path = f"{path}.backup"
# Clean up GPU memory before saving
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Attempt 1: Try with default settings in a separate file first
try:
logger.info(f"Saving model to {backup_path} (attempt 1)")
@@ -122,6 +129,28 @@ def robust_save(model, path):
logger.error(f"All save attempts failed: {e}")
return False
# Timeout wrapper for exchange operations
async def with_timeout(coroutine, timeout=30, default=None):
"""
Execute a coroutine with a timeout
Args:
coroutine: The coroutine to execute
timeout: Timeout in seconds
default: Default value to return on timeout
Returns:
The result of the coroutine or default value on timeout
"""
try:
return await asyncio.wait_for(coroutine, timeout=timeout)
except asyncio.TimeoutError:
logger.warning(f"Operation timed out after {timeout} seconds")
return default
except Exception as e:
logger.error(f"Operation failed: {e}")
return default
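A typical call site wraps a slow exchange coroutine and falls back to a safe default, as fetch_and_update_data does below. A minimal usage sketch (the ticker fetch is illustrative, not part of this commit):
# Hypothetical usage: give a ticker fetch a 10-second budget.
ticker = await with_timeout(exchange.fetch_ticker(symbol), timeout=10, default=None)
if ticker is None:
    logger.warning(f"Ticker fetch for {symbol} timed out or failed; skipping update")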
# Fetch new market data and update the environment
async def fetch_and_update_data(exchange, env, symbol, timeframe):
"""
@@ -139,8 +168,12 @@ async def fetch_and_update_data(exchange, env, symbol, timeframe):
# Default to 1000 candles if not specified
limit = 1000
# Fetch OHLCV data
candles = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
# Fetch OHLCV data with timeout
candles = await with_timeout(
exchange.fetch_ohlcv(symbol, timeframe, limit=limit),
timeout=30,
default=[]
)
if not candles or len(candles) == 0:
logger.warning(f"No candles returned for {symbol} on {timeframe}")
@@ -181,6 +214,16 @@ async def fetch_and_update_data(exchange, env, symbol, timeframe):
logger.error(traceback.format_exc())
return False
# Memory management helper
def manage_memory():
"""
Clean up memory to avoid memory leaks during long-running sessions
"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
logger.debug("Memory cleaned")
async def live_training(
symbol="ETH/USDT",
timeframe="1m",
@@ -194,6 +237,8 @@ async def live_training(
gamma=0.99,
window_size=30,
max_episodes=0, # 0 means unlimited
retry_delay=5, # Seconds to wait before retrying after an error
max_retries=3, # Maximum number of retries for operations
):
"""
Live training function that uses real market data to improve the model without executing real trades.
@@ -211,15 +256,30 @@ async def live_training(
gamma: Discount factor for training
window_size: Window size for the environment
max_episodes: Maximum number of episodes (0 for unlimited)
retry_delay: Seconds to wait before retrying after an error
max_retries: Maximum number of retries for operations
"""
logger.info(f"Starting live training for {symbol} on {timeframe} timeframe")
# Initialize exchange (without sandbox mode)
exchange = None
try:
exchange = await initialize_exchange()
logger.info(f"Exchange initialized: {exchange.id}")
# Retry loop for exchange initialization
for retry in range(max_retries):
try:
exchange = await initialize_exchange()
logger.info(f"Exchange initialized: {exchange.id}")
break
except Exception as e:
logger.error(f"Error initializing exchange (attempt {retry+1}/{max_retries}): {e}")
if retry < max_retries - 1:
logger.info(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached. Could not initialize exchange.")
return
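This initialization retry and the data-fetch retry below repeat the same attempt/sleep pattern; it could be factored into a helper along these lines (a sketch; retry_async is a hypothetical name, not part of the commit):
async def retry_async(op, max_retries=3, retry_delay=5, name="operation"):
    """Run an async callable, retrying on failure; return None if every attempt fails."""
    for attempt in range(max_retries):
        try:
            return await op()
        except Exception as e:
            logger.error(f"{name} failed (attempt {attempt+1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
    return None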
try:
# Initialize environment
env = TradingEnvironment(
initial_balance=initial_balance,
@@ -228,11 +288,20 @@ async def live_training(
timeframe=timeframe,
)
# Fetch initial data
# Fetch initial data (with retries)
logger.info(f"Fetching initial data for {symbol}")
success = await fetch_and_update_data(exchange, env, symbol, timeframe)
success = False
for retry in range(max_retries):
success = await fetch_and_update_data(exchange, env, symbol, timeframe)
if success:
break
logger.warning(f"Failed to fetch initial data (attempt {retry+1}/{max_retries})")
if retry < max_retries - 1:
logger.info(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
if not success:
logger.error("Failed to fetch initial data, exiting")
logger.error("Failed to fetch initial data after multiple attempts, exiting")
return
# Initialize agent
@@ -268,6 +337,10 @@ async def live_training(
step_counter = 0
last_update_time = datetime.datetime.now()
# Track consecutive errors to enable circuit breaker
consecutive_errors = 0
max_consecutive_errors = 5
while True:
# Check if we've reached the maximum number of episodes
if max_episodes > 0 and episode_count >= max_episodes:
@@ -284,11 +357,14 @@ async def live_training(
if not success:
logger.warning("Failed to update data, will try again later")
# Wait a bit before trying again
await asyncio.sleep(5)
await asyncio.sleep(retry_delay)
continue
last_update_time = current_time
# Clean up memory before running an episode
manage_memory()
# Run training iterations on the updated data
episode_reward = 0
env.reset()
@@ -337,12 +413,22 @@ async def live_training(
# Train the agent on a batch of experiences
if len(agent.memory) > batch_size:
agent.learn()
try:
agent.learn()
# Additional training iterations
if steps_in_episode % 10 == 0 and training_iterations > 1:
for _ in range(training_iterations - 1):
agent.learn()
# Additional training iterations
if steps_in_episode % 10 == 0 and training_iterations > 1:
for _ in range(training_iterations - 1):
agent.learn()
# Reset consecutive errors counter on successful learning
consecutive_errors = 0
except Exception as e:
logger.error(f"Error during learning: {e}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
break
if done:
logger.info(f"Episode done after {steps_in_episode} steps")
@@ -351,7 +437,10 @@ async def live_training(
except Exception as e:
logger.error(f"Error during episode step: {e}")
logger.error(traceback.format_exc())
break
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
break
# Update training statistics
episode_count += 1
@@ -419,26 +508,29 @@ async def live_training(
logger.error(traceback.format_exc())
finally:
# Save final model
if robust_save(agent, save_path):
logger.info(f"Final model saved to {save_path}")
else:
logger.error(f"Failed to save final model")
if 'agent' in locals():
if robust_save(agent, save_path):
logger.info(f"Final model saved to {save_path}")
else:
logger.error(f"Failed to save final model")
# Close TensorBoard writer
try:
writer.close()
logger.info("TensorBoard writer closed")
except Exception as e:
logger.error(f"Error closing TensorBoard writer: {e}")
# Close TensorBoard writer
try:
writer.close()
logger.info("TensorBoard writer closed")
except Exception as e:
logger.error(f"Error closing TensorBoard writer: {e}")
# Close exchange connection
if exchange:
try:
await exchange.close()
await with_timeout(exchange.close(), timeout=10)
logger.info("Exchange connection closed")
except Exception as e:
logger.error(f"Error closing exchange connection: {e}")
# Final memory cleanup
manage_memory()
logger.info("Live training completed")
async def main():
@@ -452,6 +544,8 @@ async def main():
parser.add_argument('--update_interval', type=int, default=60, help='Interval to update data in seconds')
parser.add_argument('--training_iterations', type=int, default=100, help='Training iterations per update')
parser.add_argument('--max_episodes', type=int, default=0, help='Maximum number of episodes (0 for unlimited)')
parser.add_argument('--retry_delay', type=int, default=5, help='Seconds to wait before retrying after an error')
parser.add_argument('--max_retries', type=int, default=3, help='Maximum number of retries for operations')
args = parser.parse_args()
@@ -466,9 +560,10 @@ async def main():
update_interval=args.update_interval,
training_iterations=args.training_iterations,
max_episodes=args.max_episodes,
retry_delay=args.retry_delay,
max_retries=args.max_retries,
)
# At the beginning of the file, after the module imports
# Override Agent's save method with our robust save function
def monkey_patch_agent_save():
"""Replace Agent's save method with our robust save approach"""

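The diff ends before the body of monkey_patch_agent_save. Given its docstring, a plausible body (a hypothetical sketch, not the commit's actual code) would simply rebind the method:
# Hypothetical body: route every Agent.save(path) call through robust_save.
Agent.save = lambda self, path: robust_save(self, path)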
File diff suppressed because it is too large

@@ -0,0 +1,227 @@
#!/usr/bin/env python
import os
import logging
import torch
import argparse
import gc
import json
import traceback
import shutil
from main import Agent, robust_save
# Set up logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler("test_model_save_load.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def create_test_directory():
"""Create a test directory for saving models"""
test_dir = "test_models"
os.makedirs(test_dir, exist_ok=True)
return test_dir
def test_save_load_cycle(state_size=64, action_size=4, hidden_size=384):
"""Test a full cycle of saving and loading models"""
test_dir = create_test_directory()
# Create a test agent
logger.info(f"Creating test agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}")
agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
# Define paths for testing
save_path = os.path.join(test_dir, "test_agent.pt")
# Test saving
logger.info(f"Testing save to {save_path}")
save_success = agent.save(save_path)
if save_success:
logger.info(f"Save successful, model size: {os.path.getsize(save_path)} bytes")
else:
logger.error("Save failed!")
return False
# Memory cleanup
del agent
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Test loading
logger.info(f"Testing load from {save_path}")
try:
new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
new_agent.load(save_path)
logger.info("Load successful")
# Verify model architecture
logger.info("Verifying model architecture")
assert new_agent.state_size == state_size, f"Expected state_size={state_size}, got {new_agent.state_size}"
assert new_agent.action_size == action_size, f"Expected action_size={action_size}, got {new_agent.action_size}"
assert new_agent.hidden_size == hidden_size, f"Expected hidden_size={hidden_size}, got {new_agent.hidden_size}"
logger.info("Model architecture verified correctly")
return True
except Exception as e:
logger.error(f"Error during load or verification: {e}")
logger.error(traceback.format_exc())
return False
def test_robust_save_methods(state_size=64, action_size=4, hidden_size=384):
"""Test all the robust save methods"""
test_dir = create_test_directory()
# Create a test agent
logger.info("Creating test agent for robust save testing")
agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
# Test each robust save method
methods = [
("regular", os.path.join(test_dir, "regular_save.pt")),
("backup", os.path.join(test_dir, "backup_save.pt")),
("pickle2", os.path.join(test_dir, "pickle2_save.pt")),
("no_optimizer", os.path.join(test_dir, "no_optimizer_save.pt")),
("jit", os.path.join(test_dir, "jit_save.pt"))
]
results = {}
for method_name, save_path in methods:
logger.info(f"Testing {method_name} save method to {save_path}")
try:
if method_name == "regular":
# Use regular save
success = agent.save(save_path)
elif method_name == "backup":
# Use backup method directly
backup_path = f"{save_path}.backup"
checkpoint = {
'policy_net': agent.policy_net.state_dict(),
'target_net': agent.target_net.state_dict(),
'optimizer': agent.optimizer.state_dict(),
'epsilon': agent.epsilon,
'state_size': agent.state_size,
'action_size': agent.action_size,
'hidden_size': agent.hidden_size
}
torch.save(checkpoint, backup_path)
shutil.copy(backup_path, save_path)
success = os.path.exists(save_path)
elif method_name == "pickle2":
# Use pickle protocol 2
checkpoint = {
'policy_net': agent.policy_net.state_dict(),
'target_net': agent.target_net.state_dict(),
'optimizer': agent.optimizer.state_dict(),
'epsilon': agent.epsilon,
'state_size': agent.state_size,
'action_size': agent.action_size,
'hidden_size': agent.hidden_size
}
torch.save(checkpoint, save_path, pickle_protocol=2)
success = os.path.exists(save_path)
elif method_name == "no_optimizer":
# Save without optimizer
checkpoint = {
'policy_net': agent.policy_net.state_dict(),
'target_net': agent.target_net.state_dict(),
'epsilon': agent.epsilon,
'state_size': agent.state_size,
'action_size': agent.action_size,
'hidden_size': agent.hidden_size
}
torch.save(checkpoint, save_path)
success = os.path.exists(save_path)
elif method_name == "jit":
# Use JIT save
try:
scripted_policy = torch.jit.script(agent.policy_net)
torch.jit.save(scripted_policy, f"{save_path}.policy.jit")
scripted_target = torch.jit.script(agent.target_net)
torch.jit.save(scripted_target, f"{save_path}.target.jit")
# Save parameters
with open(f"{save_path}.params.json", "w") as f:
params = {
'epsilon': float(agent.epsilon),
'state_size': int(agent.state_size),
'action_size': int(agent.action_size),
'hidden_size': int(agent.hidden_size)
}
json.dump(params, f)
success = (os.path.exists(f"{save_path}.policy.jit") and
os.path.exists(f"{save_path}.target.jit") and
os.path.exists(f"{save_path}.params.json"))
except Exception as e:
logger.error(f"JIT save failed: {e}")
success = False
if success:
if method_name != "jit":
file_size = os.path.getsize(save_path)
logger.info(f"{method_name} save successful, size: {file_size} bytes")
else:
logger.info(f"{method_name} save successful")
results[method_name] = True
else:
logger.error(f"{method_name} save failed")
results[method_name] = False
except Exception as e:
logger.error(f"Error during {method_name} save: {e}")
logger.error(traceback.format_exc())
results[method_name] = False
# Test loading each saved model
for method_name, save_path in methods:
if not results[method_name]:
logger.info(f"Skipping load test for {method_name} (save failed)")
continue
if method_name == "jit":
logger.info(f"Skipping load test for {method_name} (requires special loading)")
continue
logger.info(f"Testing load from {save_path}")
try:
new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
new_agent.load(save_path)
logger.info(f"Load successful for {method_name} save")
except Exception as e:
logger.error(f"Error loading from {method_name} save: {e}")
logger.error(traceback.format_exc())
# results[method_name] holds a bool here; assign a string instead of using
# `+=`, which would raise TypeError on a bool
results[method_name] = "save OK, load failed"
# Return summary of results
return results
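The load loop above skips the JIT method because torch.jit.save writes TorchScript modules plus a separate params file rather than a single checkpoint. Reloading them might look like this (a sketch assuming the file layout written above; load_jit_agent is a hypothetical helper):
def load_jit_agent(save_path):
    """Sketch: reload the TorchScript networks and parameters written by the 'jit' method."""
    policy_net = torch.jit.load(f"{save_path}.policy.jit")
    target_net = torch.jit.load(f"{save_path}.target.jit")
    with open(f"{save_path}.params.json") as f:
        params = json.load(f)
    return policy_net, target_net, params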
def main():
parser = argparse.ArgumentParser(description='Test model saving and loading')
parser.add_argument('--state_size', type=int, default=64, help='State size for test model')
parser.add_argument('--action_size', type=int, default=4, help='Action size for test model')
parser.add_argument('--hidden_size', type=int, default=384, help='Hidden size for test model')
parser.add_argument('--test_robust', action='store_true', help='Test all robust save methods')
args = parser.parse_args()
logger.info("Starting model save/load test")
if args.test_robust:
results = test_robust_save_methods(args.state_size, args.action_size, args.hidden_size)
logger.info(f"Robust save method results: {results}")
else:
success = test_save_load_cycle(args.state_size, args.action_size, args.hidden_size)
logger.info(f"Save/load cycle {'successful' if success else 'failed'}")
logger.info("Test completed")
if __name__ == "__main__":
main()
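Assuming the script file is named test_model_save_load.py (inferred from its log file name; the actual filename is not shown in this diff), it can be run in either mode:
# python test_model_save_load.py                 # single save/load cycle with defaults
# python test_model_save_load.py --test_robust   # exercise all five robust save methods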