#!/usr/bin/env python
import asyncio
import logging
import sys
import platform
import argparse
import os
import datetime
import traceback
import numpy as np
import torch
import gc
from functools import partial
from main import initialize_exchange, TradingEnvironment, Agent
from torch.utils.tensorboard import SummaryWriter

# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
    try:
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
    except Exception as e:
        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")

# Setup logging function
def setup_logging():
    """Setup logging configuration for the application"""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler("live_training.log"),
            logging.StreamHandler(sys.stdout)  # stdout handler for immediate feedback
        ]
    )

# Set up logging
setup_logging()
logger = logging.getLogger(__name__)

# Implement a robust save function to handle PyTorch serialization errors
def robust_save(model, path):
    """
    Robust model saving with multiple fallback approaches

    Args:
        model: The Agent model to save
        path: Path to save the model

    Returns:
        bool: True if successful, False otherwise
    """
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # Backup path in case the main save fails
    backup_path = f"{path}.backup"

    # Clean up GPU memory before saving
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Attempt 1: default settings, written to a separate backup file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'optimizer': model.optimizer.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, backup_path)
        logger.info(f"Successfully saved to {backup_path}")

        # If the backup worked, copy it to the actual path
        if os.path.exists(backup_path):
            import shutil
            shutil.copy(backup_path, path)
            logger.info(f"Copied backup to {path}")
            return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: pickle protocol 2 (more compatible)
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'optimizer': model.optimizer.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: without optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: torch.jit.save as a last resort
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")

        # Save policy network using jit
        scripted_policy = torch.jit.script(model.policy_net)
        torch.jit.save(scripted_policy, f"{path}.policy.jit")

        # Save target network using jit
        scripted_target = torch.jit.script(model.target_net)
        torch.jit.save(scripted_target, f"{path}.target.jit")

        # Save epsilon value separately
        with open(f"{path}.epsilon.txt", "w") as f:
            f.write(str(model.epsilon))

        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False
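
# Note that the jit fallback (attempt 4) writes three separate files rather
# than a single checkpoint, so it would likely need its own loader. This
# script does not include one; a minimal sketch (assuming the file suffixes
# used above) would be:
#
#     policy_net = torch.jit.load(f"{path}.policy.jit")
#     target_net = torch.jit.load(f"{path}.target.jit")
#     with open(f"{path}.epsilon.txt") as f:
#         epsilon = float(f.read())
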
# Implement timeout wrapper for exchange operations
async def with_timeout(coroutine, timeout=30, default=None):
    """
    Execute a coroutine with a timeout

    Args:
        coroutine: The coroutine to execute
        timeout: Timeout in seconds
        default: Default value to return on timeout

    Returns:
        The result of the coroutine or the default value on timeout
    """
    try:
        return await asyncio.wait_for(coroutine, timeout=timeout)
    except asyncio.TimeoutError:
        logger.warning(f"Operation timed out after {timeout} seconds")
        return default
    except Exception as e:
        logger.error(f"Operation failed: {e}")
        return default

# Implement fetch_and_update_data function
async def fetch_and_update_data(exchange, env, symbol, timeframe):
    """
    Fetch new candle data and update the environment

    Args:
        exchange: CCXT exchange instance
        env: Trading environment instance
        symbol: Trading pair symbol
        timeframe: Timeframe for the candles
    """
    logger.info(f"Fetching new data for {symbol} on {timeframe} timeframe")
    try:
        # Fetch up to 1000 recent candles per update
        limit = 1000

        # Fetch OHLCV data with timeout
        candles = await with_timeout(
            exchange.fetch_ohlcv(symbol, timeframe, limit=limit),
            timeout=30,
            default=[]
        )

        if not candles or len(candles) == 0:
            logger.warning(f"No candles returned for {symbol} on {timeframe}")
            return False

        logger.info(f"Successfully fetched {len(candles)} candles")

        # Convert to the format expected by the environment
        formatted_candles = []
        for candle in candles:
            timestamp, open_price, high, low, close, volume = candle
            formatted_candles.append({
                'timestamp': timestamp,
                'open': open_price,
                'high': high,
                'low': low,
                'close': close,
                'volume': volume
            })

        # Update environment data
        env.data = formatted_candles
        if hasattr(env, '_initialize_features'):
            env._initialize_features()

        logger.info(f"Updated environment with {len(formatted_candles)} candles")

        # Log the latest candle info
        if formatted_candles:
            latest = formatted_candles[-1]
            dt = datetime.datetime.fromtimestamp(latest['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
            logger.info(
                f"Latest candle: Time={dt}, Open={latest['open']}, High={latest['high']}, "
                f"Low={latest['low']}, Close={latest['close']}, Volume={latest['volume']}"
            )

        return True
    except Exception as e:
        logger.error(f"Error fetching candle data: {e}")
        logger.error(traceback.format_exc())
        return False

# Implement memory management function
def manage_memory():
    """Clean up memory to avoid memory leaks during long running sessions"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    logger.debug("Memory cleaned")
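
# Usage sketch for the helpers above (illustrative only, inside an async
# function; assumes `exchange` is the async ccxt exchange returned by
# initialize_exchange() in main.py and `env` is a TradingEnvironment):
#
#     ticker = await with_timeout(exchange.fetch_ticker("ETH/USDT"), timeout=10)
#     updated = await fetch_and_update_data(exchange, env, "ETH/USDT", "1m")
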
async def live_training(
    symbol="ETH/USDT",
    timeframe="1m",
    model_path="models/trading_agent_best_pnl.pt",
    save_path="models/trading_agent_live_trained.pt",
    initial_balance=1000,
    update_interval=60,
    training_iterations=100,
    learning_rate=0.0001,
    batch_size=64,
    gamma=0.99,
    window_size=30,
    max_episodes=0,   # 0 means unlimited
    retry_delay=5,    # Seconds to wait before retrying after an error
    max_retries=3,    # Maximum number of retries for operations
):
    """
    Live training function that uses real market data to improve the model
    without executing real trades.

    Args:
        symbol: Trading pair symbol
        timeframe: Timeframe for training
        model_path: Path to the initial model to load
        save_path: Path to save the improved model
        initial_balance: Initial balance for simulation
        update_interval: Interval to update data in seconds
        training_iterations: Number of training iterations per data update
        learning_rate: Learning rate for training
        batch_size: Batch size for training
        gamma: Discount factor for training
        window_size: Window size for the environment
        max_episodes: Maximum number of episodes (0 for unlimited)
        retry_delay: Seconds to wait before retrying after an error
        max_retries: Maximum number of retries for operations
    """
    logger.info(f"Starting live training for {symbol} on {timeframe} timeframe")

    # Initialize exchange (without sandbox mode)
    exchange = None

    # Retry loop for exchange initialization
    for retry in range(max_retries):
        try:
            exchange = await initialize_exchange()
            logger.info(f"Exchange initialized: {exchange.id}")
            break
        except Exception as e:
            logger.error(f"Error initializing exchange (attempt {retry + 1}/{max_retries}): {e}")
            if retry < max_retries - 1:
                logger.info(f"Retrying in {retry_delay} seconds...")
                await asyncio.sleep(retry_delay)
            else:
                logger.error("Max retries reached. Could not initialize exchange.")
                return

    try:
        # Initialize environment
        env = TradingEnvironment(
            initial_balance=initial_balance,
            window_size=window_size,
            symbol=symbol,
            timeframe=timeframe,
        )

        # Fetch initial data (with retries)
        logger.info(f"Fetching initial data for {symbol}")
        success = False
        for retry in range(max_retries):
            success = await fetch_and_update_data(exchange, env, symbol, timeframe)
            if success:
                break
            logger.warning(f"Failed to fetch initial data (attempt {retry + 1}/{max_retries})")
            if retry < max_retries - 1:
                logger.info(f"Retrying in {retry_delay} seconds...")
                await asyncio.sleep(retry_delay)

        if not success:
            logger.error("Failed to fetch initial data after multiple attempts, exiting")
            return

        # Initialize agent (fall back to defaults if the environment does not expose sizes)
        STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
        ACTION_SIZE = env.action_space.n if hasattr(env, 'action_space') and hasattr(env.action_space, 'n') else 4
        agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE, hidden_size=384)

        # Load model if provided
        if os.path.exists(model_path):
            try:
                agent.load(model_path)
                logger.info(f"Model loaded successfully from {model_path}")
            except Exception as e:
                logger.warning(f"Error loading model: {e}")
                logger.info("Starting with a new model")
        else:
            logger.warning(f"Model file {model_path} not found. Starting with a new model.")
        # Initialize TensorBoard writer
        run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        writer = SummaryWriter(log_dir=f"runs/live_training_{run_id}")
        agent.writer = writer

        # Initialize training statistics
        total_rewards = 0
        episode_count = 0
        best_reward = float('-inf')
        best_pnl = float('-inf')

        # Start live training loop
        logger.info("Starting live training loop")
        step_counter = 0
        last_update_time = datetime.datetime.now()

        # Track consecutive errors to enable a circuit breaker
        consecutive_errors = 0
        max_consecutive_errors = 5

        while True:
            # Check if we've reached the maximum number of episodes
            if max_episodes > 0 and episode_count >= max_episodes:
                logger.info(f"Reached maximum episodes ({max_episodes}), stopping")
                break

            # Check if it's time to update data
            current_time = datetime.datetime.now()
            time_diff = (current_time - last_update_time).total_seconds()

            if time_diff >= update_interval:
                logger.info(f"Updating market data after {time_diff:.1f} seconds")
                success = await fetch_and_update_data(exchange, env, symbol, timeframe)
                if not success:
                    logger.warning("Failed to update data, will try again later")
                    # Wait a bit before trying again
                    await asyncio.sleep(retry_delay)
                    continue
                last_update_time = current_time

            # Clean up memory before running an episode
            manage_memory()

            # Run training iterations on the updated data
            episode_reward = 0
            env.reset()
            done = False

            # Run one simulated episode with the current data
            steps_in_episode = 0
            max_steps = len(env.data) - env.window_size - 1

            logger.info(f"Starting episode {episode_count + 1} with {max_steps} steps")

            while not done and steps_in_episode < max_steps:
                try:
                    state = env.get_state()
                    action = agent.select_action(state, training=True)

                    try:
                        next_state, reward, done, info = env.step(action)
                    except ValueError as e:
                        logger.error(f"Error during env.step: {e}")
                        # A ValueError here usually means env.step returned 3 values
                        # instead of 4 ("not enough values to unpack"); retry with
                        # the 3-value format.
                        if "values to unpack" in str(e):
                            logger.info("Trying alternative step format")
                            result = env.step(action)
                            if len(result) == 3:
                                next_state, reward, done = result
                                info = {}
                            else:
                                raise
                        else:
                            raise

                    # Save experience in replay memory
                    agent.memory.push(state, action, reward, next_state, done)

                    # Move to the next state
                    state = next_state
                    episode_reward += reward
                    step_counter += 1
                    steps_in_episode += 1

                    # Log action and results every 50 steps
                    if steps_in_episode % 50 == 0:
                        logger.info(
                            f"Step {steps_in_episode}/{max_steps} | Action: {action} | "
                            f"Reward: {reward:.2f} | Balance: ${env.balance:.2f}"
                        )

                    # Train the agent on a batch of experiences
                    if len(agent.memory) > batch_size:
                        try:
                            agent.learn()

                            # Additional training iterations
                            if steps_in_episode % 10 == 0 and training_iterations > 1:
                                for _ in range(training_iterations - 1):
                                    agent.learn()

                            # Reset the consecutive errors counter on successful learning
                            consecutive_errors = 0
                        except Exception as e:
                            logger.error(f"Error during learning: {e}")
                            consecutive_errors += 1
                            if consecutive_errors >= max_consecutive_errors:
                                logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
                                break

                    if done:
                        logger.info(f"Episode done after {steps_in_episode} steps")
                        break
                except Exception as e:
                    logger.error(f"Error during episode step: {e}")
                    logger.error(traceback.format_exc())
                    consecutive_errors += 1
                    if consecutive_errors >= max_consecutive_errors:
                        logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
                        break
            # Update training statistics
            episode_count += 1
            total_rewards += episode_reward
            avg_reward = total_rewards / episode_count

            # Track metrics
            writer.add_scalar('LiveTraining/Reward', episode_reward, episode_count)
            writer.add_scalar('LiveTraining/AvgReward', avg_reward, episode_count)
            writer.add_scalar('LiveTraining/Balance', env.balance, episode_count)
            writer.add_scalar('LiveTraining/PnL', env.total_pnl, episode_count)

            # Report progress
            logger.info(f"""
            Episode: {episode_count}
            Reward: {episode_reward:.2f}
            Avg Reward: {avg_reward:.2f}
            Balance: ${env.balance:.2f}
            PnL: ${env.total_pnl:.2f}
            Memory Size: {len(agent.memory)}
            Total Steps: {step_counter}
            """)

            # Save the model if it's the best so far (by reward or PnL)
            if episode_reward > best_reward:
                best_reward = episode_reward
                reward_model_path = f"models/trading_agent_best_reward_{run_id}.pt"
                if robust_save(agent, reward_model_path):
                    logger.info(f"New best reward model saved: {episode_reward:.2f} to {reward_model_path}")
                else:
                    logger.error("Failed to save best reward model")

            if env.total_pnl > best_pnl:
                best_pnl = env.total_pnl
                pnl_model_path = f"models/trading_agent_best_pnl_{run_id}.pt"
                if robust_save(agent, pnl_model_path):
                    logger.info(f"New best PnL model saved: ${env.total_pnl:.2f} to {pnl_model_path}")
                else:
                    logger.error("Failed to save best PnL model")

            # Regularly save the model
            if episode_count % 5 == 0:
                if robust_save(agent, save_path):
                    logger.info(f"Model checkpoint saved to {save_path}")
                else:
                    logger.error("Failed to save checkpoint")

            # Update the target network periodically
            if episode_count % 5 == 0:
                try:
                    agent.update_target_network()
                    logger.info("Target network updated")
                except Exception as e:
                    logger.error(f"Error updating target network: {e}")

            # Sleep to avoid excessive API calls
            await asyncio.sleep(1)

    except asyncio.CancelledError:
        logger.info("Live training cancelled")
    except KeyboardInterrupt:
        logger.info("Live training stopped by user")
    except Exception as e:
        logger.error(f"Error in live training: {e}")
        logger.error(traceback.format_exc())
    finally:
        # Save the final model
        if 'agent' in locals():
            if robust_save(agent, save_path):
                logger.info(f"Final model saved to {save_path}")
            else:
                logger.error("Failed to save final model")

        # Close the TensorBoard writer (it may not exist if setup failed early)
        if 'writer' in locals():
            try:
                writer.close()
                logger.info("TensorBoard writer closed")
            except Exception as e:
                logger.error(f"Error closing TensorBoard writer: {e}")

        # Close the exchange connection
        if exchange:
            try:
                await with_timeout(exchange.close(), timeout=10)
                logger.info("Exchange connection closed")
            except Exception as e:
                logger.error(f"Error closing exchange connection: {e}")

        # Final memory cleanup
        manage_memory()

        logger.info("Live training completed")
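
# Programmatic usage sketch (illustrative; bypasses the CLI defined in main()
# below and relies on the defaults of the function signature above):
#
#     asyncio.run(live_training(symbol="ETH/USDT", timeframe="1m",
#                               update_interval=120, max_episodes=10))
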
async def main():
    """Main function to parse arguments and start live training"""
    parser = argparse.ArgumentParser(description='Live Training with Real Market Data')
    parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
    parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for training')
    parser.add_argument('--model_path', type=str, default='models/trading_agent_best_pnl.pt', help='Path to initial model')
    parser.add_argument('--save_path', type=str, default='models/trading_agent_live_trained.pt', help='Path to save improved model')
    parser.add_argument('--initial_balance', type=float, default=1000, help='Initial balance for simulation')
    parser.add_argument('--update_interval', type=int, default=60, help='Interval to update data in seconds')
    parser.add_argument('--training_iterations', type=int, default=100, help='Training iterations per update')
    parser.add_argument('--max_episodes', type=int, default=0, help='Maximum number of episodes (0 for unlimited)')
    parser.add_argument('--retry_delay', type=int, default=5, help='Seconds to wait before retrying after an error')
    parser.add_argument('--max_retries', type=int, default=3, help='Maximum number of retries for operations')

    args = parser.parse_args()

    logger.info(f"Starting live training with {args.symbol} on {args.timeframe} timeframe")

    await live_training(
        symbol=args.symbol,
        timeframe=args.timeframe,
        model_path=args.model_path,
        save_path=args.save_path,
        initial_balance=args.initial_balance,
        update_interval=args.update_interval,
        training_iterations=args.training_iterations,
        max_episodes=args.max_episodes,
        retry_delay=args.retry_delay,
        max_retries=args.max_retries,
    )

# Override Agent's save method with our robust save function
def monkey_patch_agent_save():
    """Replace Agent's save method with our robust save approach"""
    original_save = Agent.save

    def patched_save(self, path):
        return robust_save(self, path)

    # Apply the patch
    Agent.save = patched_save
    logger.info("Monkey patched Agent.save with robust_save")

    # Return the original method in case we need to restore it
    return original_save

# Apply the monkey patch before running the main function
if __name__ == "__main__":
    try:
        print("Starting live training script")
        original_save = monkey_patch_agent_save()
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Live training stopped by user")
    except Exception as e:
        logger.error(f"Error in main function: {e}")
        logger.error(traceback.format_exc())
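
# Example invocation (illustrative; assumes this file is saved as
# live_training.py and run from the project root next to main.py):
#
#     python live_training.py --symbol ETH/USDT --timeframe 1m \
#         --update_interval 60 --training_iterations 100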