fixed training again

This commit is contained in:
Dobromir Popov 2025-03-18 02:04:58 +02:00
parent bdf6afc6ad
commit 2c03675f3c
3 changed files with 1182 additions and 271 deletions

View File

@ -9,6 +9,8 @@ import datetime
import traceback import traceback
import numpy as np import numpy as np
import torch import torch
import gc
from functools import partial
from main import initialize_exchange, TradingEnvironment, Agent from main import initialize_exchange, TradingEnvironment, Agent
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -54,6 +56,11 @@ def robust_save(model, path):
# Backup path in case the main save fails # Backup path in case the main save fails
backup_path = f"{path}.backup" backup_path = f"{path}.backup"
# Clean up GPU memory before saving
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Attempt 1: Try with default settings in a separate file first # Attempt 1: Try with default settings in a separate file first
try: try:
logger.info(f"Saving model to {backup_path} (attempt 1)") logger.info(f"Saving model to {backup_path} (attempt 1)")
@ -122,6 +129,28 @@ def robust_save(model, path):
logger.error(f"All save attempts failed: {e}") logger.error(f"All save attempts failed: {e}")
return False return False
# Implement timeout wrapper for exchange operations
async def with_timeout(coroutine, timeout=30, default=None):
"""
Execute a coroutine with a timeout
Args:
coroutine: The coroutine to execute
timeout: Timeout in seconds
default: Default value to return on timeout
Returns:
The result of the coroutine or default value on timeout
"""
try:
return await asyncio.wait_for(coroutine, timeout=timeout)
except asyncio.TimeoutError:
logger.warning(f"Operation timed out after {timeout} seconds")
return default
except Exception as e:
logger.error(f"Operation failed: {e}")
return default
# Implement fetch_and_update_data function # Implement fetch_and_update_data function
async def fetch_and_update_data(exchange, env, symbol, timeframe): async def fetch_and_update_data(exchange, env, symbol, timeframe):
""" """
@ -139,8 +168,12 @@ async def fetch_and_update_data(exchange, env, symbol, timeframe):
# Default to 100 candles if not specified # Default to 100 candles if not specified
limit = 1000 limit = 1000
# Fetch OHLCV data # Fetch OHLCV data with timeout
candles = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit) candles = await with_timeout(
exchange.fetch_ohlcv(symbol, timeframe, limit=limit),
timeout=30,
default=[]
)
if not candles or len(candles) == 0: if not candles or len(candles) == 0:
logger.warning(f"No candles returned for {symbol} on {timeframe}") logger.warning(f"No candles returned for {symbol} on {timeframe}")
@ -181,6 +214,16 @@ async def fetch_and_update_data(exchange, env, symbol, timeframe):
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False return False
# Implement memory management function
def manage_memory():
"""
Clean up memory to avoid memory leaks during long running sessions
"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
logger.debug("Memory cleaned")
async def live_training( async def live_training(
symbol="ETH/USDT", symbol="ETH/USDT",
timeframe="1m", timeframe="1m",
@ -194,6 +237,8 @@ async def live_training(
gamma=0.99, gamma=0.99,
window_size=30, window_size=30,
max_episodes=0, # 0 means unlimited max_episodes=0, # 0 means unlimited
retry_delay=5, # Seconds to wait before retrying after an error
max_retries=3, # Maximum number of retries for operations
): ):
""" """
Live training function that uses real market data to improve the model without executing real trades. Live training function that uses real market data to improve the model without executing real trades.
@ -211,15 +256,30 @@ async def live_training(
gamma: Discount factor for training gamma: Discount factor for training
window_size: Window size for the environment window_size: Window size for the environment
max_episodes: Maximum number of episodes (0 for unlimited) max_episodes: Maximum number of episodes (0 for unlimited)
retry_delay: Seconds to wait before retrying after an error
max_retries: Maximum number of retries for operations
""" """
logger.info(f"Starting live training for {symbol} on {timeframe} timeframe") logger.info(f"Starting live training for {symbol} on {timeframe} timeframe")
# Initialize exchange (without sandbox mode) # Initialize exchange (without sandbox mode)
exchange = None exchange = None
# Retry loop for exchange initialization
for retry in range(max_retries):
try:
exchange = await initialize_exchange()
logger.info(f"Exchange initialized: {exchange.id}")
break
except Exception as e:
logger.error(f"Error initializing exchange (attempt {retry+1}/{max_retries}): {e}")
if retry < max_retries - 1:
logger.info(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached. Could not initialize exchange.")
return
try: try:
exchange = await initialize_exchange()
logger.info(f"Exchange initialized: {exchange.id}")
# Initialize environment # Initialize environment
env = TradingEnvironment( env = TradingEnvironment(
initial_balance=initial_balance, initial_balance=initial_balance,
@ -228,11 +288,20 @@ async def live_training(
timeframe=timeframe, timeframe=timeframe,
) )
# Fetch initial data # Fetch initial data (with retries)
logger.info(f"Fetching initial data for {symbol}") logger.info(f"Fetching initial data for {symbol}")
success = await fetch_and_update_data(exchange, env, symbol, timeframe) success = False
for retry in range(max_retries):
success = await fetch_and_update_data(exchange, env, symbol, timeframe)
if success:
break
logger.warning(f"Failed to fetch initial data (attempt {retry+1}/{max_retries})")
if retry < max_retries - 1:
logger.info(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
if not success: if not success:
logger.error("Failed to fetch initial data, exiting") logger.error("Failed to fetch initial data after multiple attempts, exiting")
return return
# Initialize agent # Initialize agent
@ -268,6 +337,10 @@ async def live_training(
step_counter = 0 step_counter = 0
last_update_time = datetime.datetime.now() last_update_time = datetime.datetime.now()
# Track consecutive errors to enable circuit breaker
consecutive_errors = 0
max_consecutive_errors = 5
while True: while True:
# Check if we've reached the maximum number of episodes # Check if we've reached the maximum number of episodes
if max_episodes > 0 and episode_count >= max_episodes: if max_episodes > 0 and episode_count >= max_episodes:
@ -284,11 +357,14 @@ async def live_training(
if not success: if not success:
logger.warning("Failed to update data, will try again later") logger.warning("Failed to update data, will try again later")
# Wait a bit before trying again # Wait a bit before trying again
await asyncio.sleep(5) await asyncio.sleep(retry_delay)
continue continue
last_update_time = current_time last_update_time = current_time
# Clean up memory before running an episode
manage_memory()
# Run training iterations on the updated data # Run training iterations on the updated data
episode_reward = 0 episode_reward = 0
env.reset() env.reset()
@ -337,12 +413,22 @@ async def live_training(
# Train the agent on a batch of experiences # Train the agent on a batch of experiences
if len(agent.memory) > batch_size: if len(agent.memory) > batch_size:
agent.learn() try:
agent.learn()
# Additional training iterations
if steps_in_episode % 10 == 0 and training_iterations > 1: # Additional training iterations
for _ in range(training_iterations - 1): if steps_in_episode % 10 == 0 and training_iterations > 1:
agent.learn() for _ in range(training_iterations - 1):
agent.learn()
# Reset consecutive errors counter on successful learning
consecutive_errors = 0
except Exception as e:
logger.error(f"Error during learning: {e}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
break
if done: if done:
logger.info(f"Episode done after {steps_in_episode} steps") logger.info(f"Episode done after {steps_in_episode} steps")
@ -351,7 +437,10 @@ async def live_training(
except Exception as e: except Exception as e:
logger.error(f"Error during episode step: {e}") logger.error(f"Error during episode step: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
break consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
break
# Update training statistics # Update training statistics
episode_count += 1 episode_count += 1
@ -419,26 +508,29 @@ async def live_training(
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
finally: finally:
# Save final model # Save final model
if robust_save(agent, save_path): if 'agent' in locals():
logger.info(f"Final model saved to {save_path}") if robust_save(agent, save_path):
else: logger.info(f"Final model saved to {save_path}")
logger.error(f"Failed to save final model") else:
logger.error(f"Failed to save final model")
# Close TensorBoard writer
try: # Close TensorBoard writer
writer.close() try:
logger.info("TensorBoard writer closed") writer.close()
except Exception as e: logger.info("TensorBoard writer closed")
logger.error(f"Error closing TensorBoard writer: {e}") except Exception as e:
logger.error(f"Error closing TensorBoard writer: {e}")
# Close exchange connection # Close exchange connection
if exchange: if exchange:
try: try:
await exchange.close() await with_timeout(exchange.close(), timeout=10)
logger.info("Exchange connection closed") logger.info("Exchange connection closed")
except Exception as e: except Exception as e:
logger.error(f"Error closing exchange connection: {e}") logger.error(f"Error closing exchange connection: {e}")
# Final memory cleanup
manage_memory()
logger.info("Live training completed") logger.info("Live training completed")
async def main(): async def main():
@ -452,6 +544,8 @@ async def main():
parser.add_argument('--update_interval', type=int, default=60, help='Interval to update data in seconds') parser.add_argument('--update_interval', type=int, default=60, help='Interval to update data in seconds')
parser.add_argument('--training_iterations', type=int, default=100, help='Training iterations per update') parser.add_argument('--training_iterations', type=int, default=100, help='Training iterations per update')
parser.add_argument('--max_episodes', type=int, default=0, help='Maximum number of episodes (0 for unlimited)') parser.add_argument('--max_episodes', type=int, default=0, help='Maximum number of episodes (0 for unlimited)')
parser.add_argument('--retry_delay', type=int, default=5, help='Seconds to wait before retrying after an error')
parser.add_argument('--max_retries', type=int, default=3, help='Maximum number of retries for operations')
args = parser.parse_args() args = parser.parse_args()
@ -466,9 +560,10 @@ async def main():
update_interval=args.update_interval, update_interval=args.update_interval,
training_iterations=args.training_iterations, training_iterations=args.training_iterations,
max_episodes=args.max_episodes, max_episodes=args.max_episodes,
retry_delay=args.retry_delay,
max_retries=args.max_retries,
) )
# At the beginning of the file, after importing the modules
# Override Agent's save method with our robust save function # Override Agent's save method with our robust save function
def monkey_patch_agent_save(): def monkey_patch_agent_save():
"""Replace Agent's save method with our robust save approach""" """Replace Agent's save method with our robust save approach"""

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,227 @@
#!/usr/bin/env python
import os
import logging
import torch
import argparse
import gc
import traceback
import shutil
from main import Agent, robust_save
# Set up logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler("test_model_save_load.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def create_test_directory():
"""Create a test directory for saving models"""
test_dir = "test_models"
os.makedirs(test_dir, exist_ok=True)
return test_dir
def test_save_load_cycle(state_size=64, action_size=4, hidden_size=384):
"""Test a full cycle of saving and loading models"""
test_dir = create_test_directory()
# Create a test agent
logger.info(f"Creating test agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}")
agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
# Define paths for testing
save_path = os.path.join(test_dir, "test_agent.pt")
# Test saving
logger.info(f"Testing save to {save_path}")
save_success = agent.save(save_path)
if save_success:
logger.info(f"Save successful, model size: {os.path.getsize(save_path)} bytes")
else:
logger.error("Save failed!")
return False
# Memory cleanup
del agent
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Test loading
logger.info(f"Testing load from {save_path}")
try:
new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
new_agent.load(save_path)
logger.info("Load successful")
# Verify model architecture
logger.info(f"Verifying model architecture")
assert new_agent.state_size == state_size, f"Expected state_size={state_size}, got {new_agent.state_size}"
assert new_agent.action_size == action_size, f"Expected action_size={action_size}, got {new_agent.action_size}"
assert new_agent.hidden_size == hidden_size, f"Expected hidden_size={hidden_size}, got {new_agent.hidden_size}"
logger.info("Model architecture verified correctly")
return True
except Exception as e:
logger.error(f"Error during load or verification: {e}")
logger.error(traceback.format_exc())
return False
def test_robust_save_methods(state_size=64, action_size=4, hidden_size=384):
"""Test all the robust save methods"""
test_dir = create_test_directory()
# Create a test agent
logger.info(f"Creating test agent for robust save testing")
agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
# Test each robust save method
methods = [
("regular", os.path.join(test_dir, "regular_save.pt")),
("backup", os.path.join(test_dir, "backup_save.pt")),
("pickle2", os.path.join(test_dir, "pickle2_save.pt")),
("no_optimizer", os.path.join(test_dir, "no_optimizer_save.pt")),
("jit", os.path.join(test_dir, "jit_save.pt"))
]
results = {}
for method_name, save_path in methods:
logger.info(f"Testing {method_name} save method to {save_path}")
try:
if method_name == "regular":
# Use regular save
success = agent.save(save_path)
elif method_name == "backup":
# Use backup method directly
backup_path = f"{save_path}.backup"
checkpoint = {
'policy_net': agent.policy_net.state_dict(),
'target_net': agent.target_net.state_dict(),
'optimizer': agent.optimizer.state_dict(),
'epsilon': agent.epsilon,
'state_size': agent.state_size,
'action_size': agent.action_size,
'hidden_size': agent.hidden_size
}
torch.save(checkpoint, backup_path)
shutil.copy(backup_path, save_path)
success = os.path.exists(save_path)
elif method_name == "pickle2":
# Use pickle protocol 2
checkpoint = {
'policy_net': agent.policy_net.state_dict(),
'target_net': agent.target_net.state_dict(),
'optimizer': agent.optimizer.state_dict(),
'epsilon': agent.epsilon,
'state_size': agent.state_size,
'action_size': agent.action_size,
'hidden_size': agent.hidden_size
}
torch.save(checkpoint, save_path, pickle_protocol=2)
success = os.path.exists(save_path)
elif method_name == "no_optimizer":
# Save without optimizer
checkpoint = {
'policy_net': agent.policy_net.state_dict(),
'target_net': agent.target_net.state_dict(),
'epsilon': agent.epsilon,
'state_size': agent.state_size,
'action_size': agent.action_size,
'hidden_size': agent.hidden_size
}
torch.save(checkpoint, save_path)
success = os.path.exists(save_path)
elif method_name == "jit":
# Use JIT save
try:
scripted_policy = torch.jit.script(agent.policy_net)
torch.jit.save(scripted_policy, f"{save_path}.policy.jit")
scripted_target = torch.jit.script(agent.target_net)
torch.jit.save(scripted_target, f"{save_path}.target.jit")
# Save parameters
with open(f"{save_path}.params.json", "w") as f:
import json
params = {
'epsilon': float(agent.epsilon),
'state_size': int(agent.state_size),
'action_size': int(agent.action_size),
'hidden_size': int(agent.hidden_size)
}
json.dump(params, f)
success = (os.path.exists(f"{save_path}.policy.jit") and
os.path.exists(f"{save_path}.target.jit") and
os.path.exists(f"{save_path}.params.json"))
except Exception as e:
logger.error(f"JIT save failed: {e}")
success = False
if success:
if method_name != "jit":
file_size = os.path.getsize(save_path)
logger.info(f"{method_name} save successful, size: {file_size} bytes")
else:
logger.info(f"{method_name} save successful")
results[method_name] = True
else:
logger.error(f"{method_name} save failed")
results[method_name] = False
except Exception as e:
logger.error(f"Error during {method_name} save: {e}")
logger.error(traceback.format_exc())
results[method_name] = False
# Test loading each saved model
for method_name, save_path in methods:
if not results[method_name]:
logger.info(f"Skipping load test for {method_name} (save failed)")
continue
if method_name == "jit":
logger.info(f"Skipping load test for {method_name} (requires special loading)")
continue
logger.info(f"Testing load from {save_path}")
try:
new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
new_agent.load(save_path)
logger.info(f"Load successful for {method_name} save")
except Exception as e:
logger.error(f"Error loading from {method_name} save: {e}")
logger.error(traceback.format_exc())
results[method_name] += " (load failed)"
# Return summary of results
return results
def main():
parser = argparse.ArgumentParser(description='Test model saving and loading')
parser.add_argument('--state_size', type=int, default=64, help='State size for test model')
parser.add_argument('--action_size', type=int, default=4, help='Action size for test model')
parser.add_argument('--hidden_size', type=int, default=384, help='Hidden size for test model')
parser.add_argument('--test_robust', action='store_true', help='Test all robust save methods')
args = parser.parse_args()
logger.info("Starting model save/load test")
if args.test_robust:
results = test_robust_save_methods(args.state_size, args.action_size, args.hidden_size)
logger.info(f"Robust save method results: {results}")
else:
success = test_save_load_cycle(args.state_size, args.action_size, args.hidden_size)
logger.info(f"Save/load cycle {'successful' if success else 'failed'}")
logger.info("Test completed")
if __name__ == "__main__":
main()