gogo2/live_training.py
#!/usr/bin/env python
import asyncio
import logging
import sys
import platform
import argparse
import os
import datetime
import traceback
import numpy as np
import torch
import gc
from functools import partial
from main import initialize_exchange, TradingEnvironment, Agent
from torch.utils.tensorboard import SummaryWriter

# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
    try:
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
    except Exception as e:
        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")

# Set up logging
def setup_logging():
    """Set up logging configuration for the application."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler("live_training.log"),
            logging.StreamHandler(sys.stdout)  # stdout handler for immediate feedback
        ]
    )


setup_logging()
logger = logging.getLogger(__name__)

# Robust save function to handle PyTorch serialization errors
def robust_save(model, path):
    """
    Robust model saving with multiple fallback approaches.

    Args:
        model: The Agent model to save
        path: Path to save the model

    Returns:
        bool: True if successful, False otherwise
    """
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # Backup path in case the main save fails
    backup_path = f"{path}.backup"

    # Clean up GPU memory before saving
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Attempt 1: default settings, written to a separate backup file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'optimizer': model.optimizer.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, backup_path)
        logger.info(f"Successfully saved to {backup_path}")

        # If the backup worked, copy it to the actual path
        if os.path.exists(backup_path):
            import shutil
            shutil.copy(backup_path, path)
            logger.info(f"Copied backup to {path}")
        return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: pickle protocol 2 (more compatible)
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'optimizer': model.optimizer.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: without the optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: torch.jit.save as a last resort
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")

        # Save policy network using TorchScript
        scripted_policy = torch.jit.script(model.policy_net)
        torch.jit.save(scripted_policy, f"{path}.policy.jit")

        # Save target network using TorchScript
        scripted_target = torch.jit.script(model.target_net)
        torch.jit.save(scripted_target, f"{path}.target.jit")

        # Save epsilon value separately
        with open(f"{path}.epsilon.txt", "w") as f:
            f.write(str(model.epsilon))
        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False
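
# Note (illustrative sketch, not part of the original training flow): the jit
# fallback above writes "<path>.policy.jit", "<path>.target.jit" and
# "<path>.epsilon.txt" instead of a single checkpoint file, so Agent.load()
# cannot read it back directly. A hypothetical loader for those artifacts,
# assuming the Agent exposes policy_net/target_net/epsilon, could look like:
#
#     def load_jit_fallback(model, path):
#         model.policy_net = torch.jit.load(f"{path}.policy.jit")
#         model.target_net = torch.jit.load(f"{path}.target.jit")
#         with open(f"{path}.epsilon.txt") as f:
#             model.epsilon = float(f.read())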

# Timeout wrapper for exchange operations
async def with_timeout(coroutine, timeout=30, default=None):
    """
    Execute a coroutine with a timeout.

    Args:
        coroutine: The coroutine to execute
        timeout: Timeout in seconds
        default: Default value to return on timeout

    Returns:
        The result of the coroutine, or the default value on timeout or error
    """
    try:
        return await asyncio.wait_for(coroutine, timeout=timeout)
    except asyncio.TimeoutError:
        logger.warning(f"Operation timed out after {timeout} seconds")
        return default
    except Exception as e:
        logger.error(f"Operation failed: {e}")
        return default
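
# Example usage (illustrative): any slow exchange call can be wrapped so a hung
# request cannot stall the training loop, e.g.
#     ticker = await with_timeout(exchange.fetch_ticker(symbol), timeout=10, default=None)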

async def fetch_and_update_data(exchange, env, symbol, timeframe):
    """
    Fetch new candle data and update the environment.

    Args:
        exchange: CCXT exchange instance
        env: Trading environment instance
        symbol: Trading pair symbol
        timeframe: Timeframe for the candles

    Returns:
        bool: True if the environment was updated, False otherwise
    """
    logger.info(f"Fetching new data for {symbol} on {timeframe} timeframe")
    try:
        # Fetch up to 1000 candles per update
        limit = 1000

        # Fetch OHLCV data with timeout
        candles = await with_timeout(
            exchange.fetch_ohlcv(symbol, timeframe, limit=limit),
            timeout=30,
            default=[]
        )

        if not candles or len(candles) == 0:
            logger.warning(f"No candles returned for {symbol} on {timeframe}")
            return False
        logger.info(f"Successfully fetched {len(candles)} candles")

        # Convert to the format expected by the environment
        formatted_candles = []
        for candle in candles:
            timestamp, open_price, high, low, close, volume = candle
            formatted_candles.append({
                'timestamp': timestamp,
                'open': open_price,
                'high': high,
                'low': low,
                'close': close,
                'volume': volume
            })

        # Update environment data
        env.data = formatted_candles
        if hasattr(env, '_initialize_features'):
            env._initialize_features()
        logger.info(f"Updated environment with {len(formatted_candles)} candles")

        # Log the latest candle (timestamps are in milliseconds)
        if formatted_candles:
            latest = formatted_candles[-1]
            dt = datetime.datetime.fromtimestamp(latest['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
            logger.info(f"Latest candle: Time={dt}, Open={latest['open']}, High={latest['high']}, Low={latest['low']}, Close={latest['close']}, Volume={latest['volume']}")
        return True
    except Exception as e:
        logger.error(f"Error fetching candle data: {e}")
        logger.error(traceback.format_exc())
        return False

def manage_memory():
    """Clean up memory to avoid leaks during long-running sessions."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    logger.debug("Memory cleaned")

async def live_training(
    symbol="ETH/USDT",
    timeframe="1m",
    model_path="models/trading_agent_best_pnl.pt",
    save_path="models/trading_agent_live_trained.pt",
    initial_balance=1000,
    update_interval=60,
    training_iterations=100,
    learning_rate=0.0001,
    batch_size=64,
    gamma=0.99,
    window_size=30,
    max_episodes=0,   # 0 means unlimited
    retry_delay=5,    # Seconds to wait before retrying after an error
    max_retries=3,    # Maximum number of retries for operations
):
    """
    Live training function that uses real market data to improve the model
    without executing real trades.

    Args:
        symbol: Trading pair symbol
        timeframe: Timeframe for training
        model_path: Path to the initial model to load
        save_path: Path to save the improved model
        initial_balance: Initial balance for simulation
        update_interval: Interval to update data in seconds
        training_iterations: Number of training iterations per data update
        learning_rate: Learning rate for training
        batch_size: Batch size for training
        gamma: Discount factor for training
        window_size: Window size for the environment
        max_episodes: Maximum number of episodes (0 for unlimited)
        retry_delay: Seconds to wait before retrying after an error
        max_retries: Maximum number of retries for operations
    """
    logger.info(f"Starting live training for {symbol} on {timeframe} timeframe")

    # Initialize exchange (without sandbox mode), with retries
    exchange = None
    for retry in range(max_retries):
        try:
            exchange = await initialize_exchange()
            logger.info(f"Exchange initialized: {exchange.id}")
            break
        except Exception as e:
            logger.error(f"Error initializing exchange (attempt {retry+1}/{max_retries}): {e}")
            if retry < max_retries - 1:
                logger.info(f"Retrying in {retry_delay} seconds...")
                await asyncio.sleep(retry_delay)
            else:
                logger.error("Max retries reached. Could not initialize exchange.")
                return

    try:
        # Initialize environment
        env = TradingEnvironment(
            initial_balance=initial_balance,
            window_size=window_size,
            symbol=symbol,
            timeframe=timeframe,
        )

        # Fetch initial data (with retries)
        logger.info(f"Fetching initial data for {symbol}")
        success = False
        for retry in range(max_retries):
            success = await fetch_and_update_data(exchange, env, symbol, timeframe)
            if success:
                break
            logger.warning(f"Failed to fetch initial data (attempt {retry+1}/{max_retries})")
            if retry < max_retries - 1:
                logger.info(f"Retrying in {retry_delay} seconds...")
                await asyncio.sleep(retry_delay)

        if not success:
            logger.error("Failed to fetch initial data after multiple attempts, exiting")
            return

        # Initialize agent
        STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
        ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
        agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE, hidden_size=384)

        # Load model if provided
        if os.path.exists(model_path):
            try:
                agent.load(model_path)
                logger.info(f"Model loaded successfully from {model_path}")
            except Exception as e:
                logger.warning(f"Error loading model: {e}")
                logger.info("Starting with a new model")
        else:
            logger.warning(f"Model file {model_path} not found. Starting with a new model.")

        # Initialize TensorBoard writer
        run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        writer = SummaryWriter(log_dir=f"runs/live_training_{run_id}")
        agent.writer = writer

        # Initialize training statistics
        total_rewards = 0
        episode_count = 0
        best_reward = float('-inf')
        best_pnl = float('-inf')

        # Start live training loop
        logger.info("Starting live training loop")
        step_counter = 0
        last_update_time = datetime.datetime.now()

        # Track consecutive errors to enable a circuit breaker
        consecutive_errors = 0
        max_consecutive_errors = 5

        while True:
            # Check if we've reached the maximum number of episodes
            if max_episodes > 0 and episode_count >= max_episodes:
                logger.info(f"Reached maximum episodes ({max_episodes}), stopping")
                break

            # Check if it's time to update data
            current_time = datetime.datetime.now()
            time_diff = (current_time - last_update_time).total_seconds()
            if time_diff >= update_interval:
                logger.info(f"Updating market data after {time_diff:.1f} seconds")
                success = await fetch_and_update_data(exchange, env, symbol, timeframe)
                if not success:
                    logger.warning("Failed to update data, will try again later")
                    # Wait a bit before trying again
                    await asyncio.sleep(retry_delay)
                    continue
                last_update_time = current_time

            # Clean up memory before running an episode
            manage_memory()

            # Run training iterations on the updated data
            episode_reward = 0
            env.reset()
            done = False

            # Run one simulated episode with the current data
            steps_in_episode = 0
            max_steps = len(env.data) - env.window_size - 1
            logger.info(f"Starting episode {episode_count + 1} with {max_steps} steps")

            while not done and steps_in_episode < max_steps:
                try:
                    state = env.get_state()
                    action = agent.select_action(state, training=True)

                    try:
                        next_state, reward, done, info = env.step(action)
                    except ValueError as e:
                        logger.error(f"Error during env.step: {e}")
                        # env.step may return 3 values instead of 4; handle that case
                        if "too many values to unpack" in str(e):
                            logger.info("Trying alternative step format")
                            result = env.step(action)
                            if len(result) == 3:
                                next_state, reward, done = result
                                info = {}
                            else:
                                raise
                        else:
                            raise

                    # Save experience in replay memory
                    agent.memory.push(state, action, reward, next_state, done)

                    # Move to the next state
                    state = next_state
                    episode_reward += reward
                    step_counter += 1
                    steps_in_episode += 1

                    # Log action and results every 50 steps
                    if steps_in_episode % 50 == 0:
                        logger.info(f"Step {steps_in_episode}/{max_steps} | Action: {action} | Reward: {reward:.2f} | Balance: ${env.balance:.2f}")

                    # Train the agent on a batch of experiences
                    if len(agent.memory) > batch_size:
                        try:
                            agent.learn()
                            # Additional training iterations
                            if steps_in_episode % 10 == 0 and training_iterations > 1:
                                for _ in range(training_iterations - 1):
                                    agent.learn()
                            # Reset the consecutive-errors counter on successful learning
                            consecutive_errors = 0
                        except Exception as e:
                            logger.error(f"Error during learning: {e}")
                            consecutive_errors += 1
                            if consecutive_errors >= max_consecutive_errors:
                                logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
                                break

                    if done:
                        logger.info(f"Episode done after {steps_in_episode} steps")
                        break
                except Exception as e:
                    logger.error(f"Error during episode step: {e}")
                    logger.error(traceback.format_exc())
                    consecutive_errors += 1
                    if consecutive_errors >= max_consecutive_errors:
                        logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
                        break

            # Update training statistics
            episode_count += 1
            total_rewards += episode_reward
            avg_reward = total_rewards / episode_count

            # Track metrics in TensorBoard
            writer.add_scalar('LiveTraining/Reward', episode_reward, episode_count)
            writer.add_scalar('LiveTraining/AvgReward', avg_reward, episode_count)
            writer.add_scalar('LiveTraining/Balance', env.balance, episode_count)
            writer.add_scalar('LiveTraining/PnL', env.total_pnl, episode_count)

            # Report progress
            logger.info(f"""
            Episode: {episode_count}
            Reward: {episode_reward:.2f}
            Avg Reward: {avg_reward:.2f}
            Balance: ${env.balance:.2f}
            PnL: ${env.total_pnl:.2f}
            Memory Size: {len(agent.memory)}
            Total Steps: {step_counter}
            """)

            # Save the model if it's the best so far (by reward or PnL)
            if episode_reward > best_reward:
                best_reward = episode_reward
                reward_model_path = f"models/trading_agent_best_reward_{run_id}.pt"
                if robust_save(agent, reward_model_path):
                    logger.info(f"New best reward model saved: {episode_reward:.2f} to {reward_model_path}")
                else:
                    logger.error("Failed to save best reward model")

            if env.total_pnl > best_pnl:
                best_pnl = env.total_pnl
                pnl_model_path = f"models/trading_agent_best_pnl_{run_id}.pt"
                if robust_save(agent, pnl_model_path):
                    logger.info(f"New best PnL model saved: ${env.total_pnl:.2f} to {pnl_model_path}")
                else:
                    logger.error("Failed to save best PnL model")

            # Regularly save a checkpoint
            if episode_count % 5 == 0:
                if robust_save(agent, save_path):
                    logger.info(f"Model checkpoint saved to {save_path}")
                else:
                    logger.error("Failed to save checkpoint")

            # Update target network periodically
            if episode_count % 5 == 0:
                try:
                    agent.update_target_network()
                    logger.info("Target network updated")
                except Exception as e:
                    logger.error(f"Error updating target network: {e}")

            # Sleep to avoid excessive API calls
            await asyncio.sleep(1)

    except asyncio.CancelledError:
        logger.info("Live training cancelled")
    except KeyboardInterrupt:
        logger.info("Live training stopped by user")
    except Exception as e:
        logger.error(f"Error in live training: {e}")
        logger.error(traceback.format_exc())
    finally:
        # Save final model
        if 'agent' in locals():
            if robust_save(agent, save_path):
                logger.info(f"Final model saved to {save_path}")
            else:
                logger.error("Failed to save final model")

        # Close TensorBoard writer
        try:
            writer.close()
            logger.info("TensorBoard writer closed")
        except Exception as e:
            logger.error(f"Error closing TensorBoard writer: {e}")

        # Close exchange connection
        if exchange:
            try:
                await with_timeout(exchange.close(), timeout=10)
                logger.info("Exchange connection closed")
            except Exception as e:
                logger.error(f"Error closing exchange connection: {e}")

        # Final memory cleanup
        manage_memory()
        logger.info("Live training completed")

async def main():
    """Parse arguments and start live training."""
    parser = argparse.ArgumentParser(description='Live Training with Real Market Data')
    parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
    parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for training')
    parser.add_argument('--model_path', type=str, default='models/trading_agent_best_pnl.pt', help='Path to initial model')
    parser.add_argument('--save_path', type=str, default='models/trading_agent_live_trained.pt', help='Path to save improved model')
    parser.add_argument('--initial_balance', type=float, default=1000, help='Initial balance for simulation')
    parser.add_argument('--update_interval', type=int, default=60, help='Interval to update data in seconds')
    parser.add_argument('--training_iterations', type=int, default=100, help='Training iterations per update')
    parser.add_argument('--max_episodes', type=int, default=0, help='Maximum number of episodes (0 for unlimited)')
    parser.add_argument('--retry_delay', type=int, default=5, help='Seconds to wait before retrying after an error')
    parser.add_argument('--max_retries', type=int, default=3, help='Maximum number of retries for operations')
    args = parser.parse_args()

    logger.info(f"Starting live training with {args.symbol} on {args.timeframe} timeframe")

    await live_training(
        symbol=args.symbol,
        timeframe=args.timeframe,
        model_path=args.model_path,
        save_path=args.save_path,
        initial_balance=args.initial_balance,
        update_interval=args.update_interval,
        training_iterations=args.training_iterations,
        max_episodes=args.max_episodes,
        retry_delay=args.retry_delay,
        max_retries=args.max_retries,
    )

# Override Agent's save method with our robust save function
def monkey_patch_agent_save():
    """Replace Agent's save method with the robust_save approach."""
    original_save = Agent.save

    def patched_save(self, path):
        return robust_save(self, path)

    # Apply the patch
    Agent.save = patched_save
    logger.info("Monkey patched Agent.save with robust_save")

    # Return the original method in case we need to restore it
    return original_save
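
# Example (illustrative): the returned original method can be kept and restored
# later if the unpatched behaviour is needed again:
#     original_save = monkey_patch_agent_save()
#     # ... run training ...
#     Agent.save = original_save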

if __name__ == "__main__":
    try:
        print("Starting live training script")
        # Apply the monkey patch before running the main function
        original_save = monkey_patch_agent_save()
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Live training stopped by user")
    except Exception as e:
        logger.error(f"Error in main function: {e}")
        logger.error(traceback.format_exc())