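"""Training script for the enhanced DQN trading agent.

Optionally pre-trains the agent's price-prediction heads on multi-timeframe
candle data, then runs reinforcement-learning episodes in the trading
environment and logs progress to TensorBoard.
"""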
import os
import sys
import time
import logging
import argparse
import contextlib

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import our enhanced agent
from NN.models.dqn_agent_enhanced import EnhancedDQNAgent
from NN.utils.data_interface import DataInterface

# Configure logging (create the log directory first so the FileHandler
# does not fail on a clean checkout)
os.makedirs('logs', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('logs/enhanced_training.log')
    ]
)
logger = logging.getLogger(__name__)

def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description='Train enhanced RL trading agent')
    parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train')
    parser.add_argument('--max-steps', type=int, default=2000, help='Maximum steps per episode')
    parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading symbol')
    parser.add_argument('--no-gpu', action='store_true', help='Disable GPU usage')
    parser.add_argument('--confidence', type=float, default=0.4, help='Confidence threshold')
    parser.add_argument('--load-model', type=str, default='', help='Load existing model')
    parser.add_argument('--batch-size', type=int, default=128, help='Training batch size')
    parser.add_argument('--learning-rate', type=float, default=0.0003, help='Learning rate')
    parser.add_argument('--no-pretrain', action='store_true', help='Skip pre-training')
    parser.add_argument('--pretrain-epochs', type=int, default=20, help='Number of pre-training epochs')
    return parser.parse_args()

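# Example invocation (the script filename and the values shown are illustrative only;
# the flags are the ones defined in parse_args() above):
#
#   python train_enhanced_rl.py --episodes 200 --symbol ETH/USDT --batch-size 128 --pretrain-epochs 10
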
def generate_price_prediction_training_data(data_1m, data_1h, data_1d, window_size=20):
    """
    Generate labeled training data for price prediction pre-training

    Args:
        data_1m: 1-minute candle data
        data_1h: 1-hour candle data
        data_1d: 1-day candle data
        window_size: Size of the observation window

    Returns:
        X, y_immediate, y_midterm, y_longterm, y_values
    """
    logger.info("Generating price prediction training data")

    # Features to use
    ohlcv_columns = ['open', 'high', 'low', 'close', 'volume']

    # Create feature sets
    X = []
    y_immediate = []  # 1m prediction (next 5min)
    y_midterm = []    # 1h prediction (next few hours)
    y_longterm = []   # 1d prediction (next day)
    y_values = []     # % change for each timeframe

    # Need enough data for all timeframes
    if len(data_1m) < window_size + 5 or len(data_1h) < 2 or len(data_1d) < 2:
        logger.error("Not enough data for all timeframes")
        return np.array([]), np.array([]), np.array([]), np.array([]), np.array([])

    # Generate examples
    for i in range(window_size, len(data_1m) - 5):
        # Skip if we can't align with higher timeframes
        if i % 60 != 0:  # Only use minutes that align with hour boundaries
            continue

        try:
            # Get window of 1m data as input (positional row access via .iloc)
            window_1m = data_1m.iloc[i - window_size:i][ohlcv_columns].values

            # Find corresponding indices in higher timeframes
            curr_timestamp = data_1m.index[i]
            h_idx = data_1h.index.get_indexer([curr_timestamp], method='nearest')[0]
            d_idx = data_1d.index.get_indexer([curr_timestamp], method='nearest')[0]

            # Skip if indices are out of bounds
            if h_idx < 0 or h_idx >= len(data_1h) - 1 or d_idx < 0 or d_idx >= len(data_1d) - 1:
                continue

            # Get future prices for label generation
            future_5m = data_1m.iloc[i + 5]['close']
            future_1h = data_1h.iloc[h_idx + 1]['close']
            future_1d = data_1d.iloc[d_idx + 1]['close']

            current_price = data_1m.iloc[i]['close']

            # Calculate % change for each timeframe
            change_5m = (future_5m - current_price) / current_price * 100
            change_1h = (future_1h - current_price) / current_price * 100
            change_1d = (future_1d - current_price) / current_price * 100

            # Determine price direction (0=down, 1=sideways, 2=up)
            def get_direction(change):
                if change < -0.5:    # Down if less than -0.5%
                    return 0
                elif change > 0.5:   # Up if more than 0.5%
                    return 2
                else:                # Sideways if between -0.5% and 0.5%
                    return 1

            direction_5m = get_direction(change_5m)
            direction_1h = get_direction(change_1h)
            direction_1d = get_direction(change_1d)

            # Add to dataset
            X.append(window_1m.flatten())
            y_immediate.append(direction_5m)
            y_midterm.append(direction_1h)
            y_longterm.append(direction_1d)
            y_values.append([change_5m, change_1h, change_1d, 0])  # Last value reserved

        except Exception as e:
            logger.warning(f"Error generating training example at index {i}: {str(e)}")

    # Convert to numpy arrays
    X = np.array(X)
    y_immediate = np.array(y_immediate)
    y_midterm = np.array(y_midterm)
    y_longterm = np.array(y_longterm)
    y_values = np.array(y_values)

    logger.info(f"Generated {len(X)} training examples")
    if len(X) > 0:
        logger.info(f"Class distribution - Immediate: {np.bincount(y_immediate)}, "
                    f"Midterm: {np.bincount(y_midterm)}, Long-term: {np.bincount(y_longterm)}")

    return X, y_immediate, y_midterm, y_longterm, y_values

def pretrain_price_prediction(agent, data_interface, n_epochs=20, batch_size=128, device=None):
    """
    Pre-train the price prediction capabilities of the agent

    Args:
        agent: EnhancedDQNAgent instance
        data_interface: DataInterface instance
        n_epochs: Number of pre-training epochs
        batch_size: Batch size for pre-training
        device: Device to use for pre-training

    Returns:
        The pre-trained agent
    """
    logger.info("Starting price prediction pre-training")

    try:
        # Ensure we have the necessary timeframes
        timeframes_needed = ['1m', '1h', '1d']
        for tf in timeframes_needed:
            if tf not in data_interface.timeframes:
                logger.info(f"Adding timeframe {tf} for pre-training")
                data_interface.timeframes.append(tf)
                data_interface.dataframes[tf] = None

        # Get data for each timeframe
        data_1m = data_interface.get_historical_data(timeframe='1m')
        data_1h = data_interface.get_historical_data(timeframe='1h')
        data_1d = data_interface.get_historical_data(timeframe='1d')

        # Generate labeled training data
        X, y_immediate, y_midterm, y_longterm, y_values = generate_price_prediction_training_data(
            data_1m, data_1h, data_1d, window_size=20
        )

        if len(X) == 0:
            logger.error("No training examples generated. Skipping pre-training.")
            return agent

        # Split data into training and validation sets
        X_train, X_val, y_imm_train, y_imm_val, y_mid_train, y_mid_val, \
            y_long_train, y_long_val, y_val_train, y_val_val = train_test_split(
                X, y_immediate, y_midterm, y_longterm, y_values, test_size=0.2, random_state=42
            )

        # Convert to torch tensors
        X_train_tensor = torch.FloatTensor(X_train).to(device)
        y_imm_train_tensor = torch.LongTensor(y_imm_train).to(device)
        y_mid_train_tensor = torch.LongTensor(y_mid_train).to(device)
        y_long_train_tensor = torch.LongTensor(y_long_train).to(device)
        y_val_train_tensor = torch.FloatTensor(y_val_train).to(device)

        X_val_tensor = torch.FloatTensor(X_val).to(device)
        y_imm_val_tensor = torch.LongTensor(y_imm_val).to(device)
        y_mid_val_tensor = torch.LongTensor(y_mid_val).to(device)
        y_long_val_tensor = torch.LongTensor(y_long_val).to(device)
        y_val_val_tensor = torch.FloatTensor(y_val_val).to(device)

|
# Calculate class weights for imbalanced data
|
|
def get_class_weights(labels):
|
|
counts = np.bincount(labels)
|
|
if len(counts) < 3: # Ensure we have 3 classes
|
|
counts = np.append(counts, [0] * (3 - len(counts)))
|
|
weights = 1.0 / np.array(counts)
|
|
weights = weights / np.sum(weights) # Normalize
|
|
return weights
|
|
|
|
imm_weights = torch.FloatTensor(get_class_weights(y_imm_train)).to(device)
|
|
mid_weights = torch.FloatTensor(get_class_weights(y_mid_train)).to(device)
|
|
long_weights = torch.FloatTensor(get_class_weights(y_long_train)).to(device)
|
|
|
|
# Create DataLoader for batch training
|
|
train_dataset = TensorDataset(
|
|
X_train_tensor, y_imm_train_tensor, y_mid_train_tensor,
|
|
y_long_train_tensor, y_val_train_tensor
|
|
)
|
|
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
|
|
|
# Set up loss functions with class weights
|
|
imm_criterion = nn.CrossEntropyLoss(weight=imm_weights)
|
|
mid_criterion = nn.CrossEntropyLoss(weight=mid_weights)
|
|
long_criterion = nn.CrossEntropyLoss(weight=long_weights)
|
|
value_criterion = nn.MSELoss()
|
|
|
|
# Set up optimizer (separate from agent's optimizer)
|
|
pretrain_optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=0.0002)
|
|
pretrain_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
|
pretrain_optimizer, mode='min', factor=0.5, patience=3, verbose=True
|
|
)
|
|
|
|
# Set model to training mode
|
|
agent.policy_net.train()
|
|
|
|
# Training loop
|
|
best_val_loss = float('inf')
|
|
patience = 5
|
|
patience_counter = 0
|
|
|
|
# Create TensorBoard writer for pre-training
|
|
writer = SummaryWriter(log_dir=f'runs/pretrain_{int(time.time())}')
|
|
|
|
        for epoch in range(n_epochs):
            # Training phase
            train_loss = 0.0
            imm_correct, mid_correct, long_correct = 0, 0, 0
            total = 0

            for X_batch, y_imm_batch, y_mid_batch, y_long_batch, y_val_batch in train_loader:
                # Zero gradients
                pretrain_optimizer.zero_grad()

                # Forward pass (mixed precision if the agent is configured for it)
                with torch.cuda.amp.autocast() if agent.use_mixed_precision else contextlib.nullcontext():
                    q_values, _, price_preds, _ = agent.policy_net(X_batch)

                    # Calculate losses for each prediction head
                    imm_loss = imm_criterion(price_preds['immediate'], y_imm_batch)
                    mid_loss = mid_criterion(price_preds['midterm'], y_mid_batch)
                    long_loss = long_criterion(price_preds['longterm'], y_long_batch)
                    value_loss = value_criterion(price_preds['values'], y_val_batch)

                    # Combined loss (weighted by importance)
                    total_loss = imm_loss + 0.7 * mid_loss + 0.5 * long_loss + 0.3 * value_loss

                # Backward pass and optimize
                if agent.use_mixed_precision:
                    agent.scaler.scale(total_loss).backward()
                    agent.scaler.unscale_(pretrain_optimizer)
                    torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
                    agent.scaler.step(pretrain_optimizer)
                    agent.scaler.update()
                else:
                    total_loss.backward()
                    torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
                    pretrain_optimizer.step()

                # Accumulate metrics
                train_loss += total_loss.item()
                total += X_batch.size(0)

                # Calculate accuracy
                _, imm_pred = torch.max(price_preds['immediate'], 1)
                _, mid_pred = torch.max(price_preds['midterm'], 1)
                _, long_pred = torch.max(price_preds['longterm'], 1)

                imm_correct += (imm_pred == y_imm_batch).sum().item()
                mid_correct += (mid_pred == y_mid_batch).sum().item()
                long_correct += (long_pred == y_long_batch).sum().item()

            # Calculate epoch metrics
            train_loss /= len(train_loader)
            imm_acc = imm_correct / total
            mid_acc = mid_correct / total
            long_acc = long_correct / total

            # Validation phase
            agent.policy_net.eval()
            val_loss = 0.0
            imm_val_correct, mid_val_correct, long_val_correct = 0, 0, 0

            with torch.no_grad():
                # Forward pass on validation data
                q_values, _, val_price_preds, _ = agent.policy_net(X_val_tensor)

                # Calculate validation losses
                val_imm_loss = imm_criterion(val_price_preds['immediate'], y_imm_val_tensor)
                val_mid_loss = mid_criterion(val_price_preds['midterm'], y_mid_val_tensor)
                val_long_loss = long_criterion(val_price_preds['longterm'], y_long_val_tensor)
                val_value_loss = value_criterion(val_price_preds['values'], y_val_val_tensor)

                val_total_loss = val_imm_loss + 0.7 * val_mid_loss + 0.5 * val_long_loss + 0.3 * val_value_loss
                val_loss = val_total_loss.item()

                # Calculate validation accuracy
                _, imm_val_pred = torch.max(val_price_preds['immediate'], 1)
                _, mid_val_pred = torch.max(val_price_preds['midterm'], 1)
                _, long_val_pred = torch.max(val_price_preds['longterm'], 1)

                imm_val_correct = (imm_val_pred == y_imm_val_tensor).sum().item()
                mid_val_correct = (mid_val_pred == y_mid_val_tensor).sum().item()
                long_val_correct = (long_val_pred == y_long_val_tensor).sum().item()

            imm_val_acc = imm_val_correct / len(X_val_tensor)
            mid_val_acc = mid_val_correct / len(X_val_tensor)
            long_val_acc = long_val_correct / len(X_val_tensor)

            # Log to TensorBoard
            writer.add_scalar('pretrain/train_loss', train_loss, epoch)
            writer.add_scalar('pretrain/val_loss', val_loss, epoch)
            writer.add_scalar('pretrain/imm_acc', imm_acc, epoch)
            writer.add_scalar('pretrain/mid_acc', mid_acc, epoch)
            writer.add_scalar('pretrain/long_acc', long_acc, epoch)
            writer.add_scalar('pretrain/imm_val_acc', imm_val_acc, epoch)
            writer.add_scalar('pretrain/mid_val_acc', mid_val_acc, epoch)
            writer.add_scalar('pretrain/long_val_acc', long_val_acc, epoch)

            # Learning rate scheduling
            pretrain_scheduler.step(val_loss)

            # Early stopping check
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # Copy policy_net weights to target_net
                agent.target_net.load_state_dict(agent.policy_net.state_dict())
                logger.info(f"Saved best model with validation loss: {val_loss:.4f}")
                # Save pre-trained model
                agent.save("NN/models/saved/enhanced_dqn_pretrained")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    logger.info(f"Early stopping triggered after {epoch+1} epochs")
                    break

            # Log progress
            logger.info(f"Epoch {epoch+1}/{n_epochs}: "
                        f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                        f"Imm Acc: {imm_acc:.4f}/{imm_val_acc:.4f}, "
                        f"Mid Acc: {mid_acc:.4f}/{mid_val_acc:.4f}, "
                        f"Long Acc: {long_acc:.4f}/{long_val_acc:.4f}")

            # Set model back to training mode for next epoch
            agent.policy_net.train()

        writer.close()
        logger.info("Price prediction pre-training complete")
        return agent

    except Exception as e:
        logger.error(f"Error during price prediction pre-training: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return agent

def train_enhanced_rl(args):
    """
    Train the enhanced RL agent for trading

    Args:
        args: Command line arguments
    """
    # Setup device
    if args.no_gpu:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info(f"Using device: {device}")

    # Set up data interface
    data_interface = DataInterface(symbol=args.symbol, timeframes=['1m', '5m', '15m'])

    # Fetch historical data for each timeframe
    for timeframe in data_interface.timeframes:
        df = data_interface.get_historical_data(timeframe=timeframe)
        logger.info(f"Using data for {args.symbol} {timeframe} ({len(df)} candles)")

    # Create environment for training
    from NN.environments.trading_env import TradingEnvironment
    window_size = 20
    train_env = TradingEnvironment(
        data_interface=data_interface,
        initial_balance=10000.0,
        transaction_fee=0.0002,
        window_size=window_size,
        max_position=1.0,
        reward_scaling=100.0
    )

    # Create agent with improved parameters
    state_shape = train_env.observation_space.shape
    n_actions = train_env.action_space.n

    agent = EnhancedDQNAgent(
        state_shape=state_shape,
        n_actions=n_actions,
        learning_rate=args.learning_rate,
        gamma=0.95,
        epsilon=1.0,
        epsilon_min=0.05,
        epsilon_decay=0.995,
        buffer_size=50000,
        batch_size=args.batch_size,
        target_update=10,
        confidence_threshold=args.confidence,
        device=device
    )

    # Load existing model if specified
    if args.load_model:
        model_path = args.load_model
        if agent.load(model_path):
            logger.info(f"Loaded existing model from {model_path}")
        else:
            logger.error(f"Error loading model from {model_path}")

    # Pre-training for price prediction
    if not args.no_pretrain and not args.load_model:
        logger.info("Starting pre-training phase")
        agent = pretrain_price_prediction(
            agent=agent,
            data_interface=data_interface,
            n_epochs=args.pretrain_epochs,
            batch_size=args.batch_size,
            device=device
        )
        logger.info("Pre-training completed")

    # Setup TensorBoard
    writer = SummaryWriter(log_dir=f'runs/enhanced_rl_{int(time.time())}')

    # Log hardware info
    writer.add_text("hardware/device", str(device), 0)
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            writer.add_text(f"hardware/gpu_{i}", torch.cuda.get_device_name(i), 0)

    # Move agent to device
    agent.move_models_to_device(device)

    # Training loop
    logger.info(f"Starting enhanced training for {args.episodes} episodes")

    total_rewards = []
    episode_losses = []
    trade_win_rates = []
    best_reward = -np.inf

    try:
        for episode in range(args.episodes):
            # Reset environment for new episode
            state = train_env.reset()
            total_reward = 0.0
            done = False
            step = 0
            episode_start_time = time.time()

            # Track trade statistics
            trades = []
            wins = 0
            losses = 0

            # Run episode
            while not done and step < args.max_steps:
                # Choose action
                action, confidence = agent.act(state)

                # Take action in environment
                next_state, reward, done, info = train_env.step(action)

                # Remember experience
                agent.remember(state, action, reward, next_state, done)

                # Track trade results
                if 'trade_result' in info and info['trade_result'] is not None:
                    trade_result = info['trade_result']
                    trade_pnl = trade_result['pnl']
                    trades.append(trade_pnl)

                    if trade_pnl > 0:
                        wins += 1
                        logger.info(f"Profitable trade! {trade_pnl:.2f}% profit, reward: {reward:.4f}")
                    else:
                        losses += 1
                        logger.info(f"Losing trade! {trade_pnl:.2f}% loss, penalty: {reward:.4f}")

                # Update state and counters
                state = next_state
                total_reward += reward
                step += 1

                # Train agent (replay() may return nothing while the buffer is still filling)
                loss = agent.replay()
                if loss is not None and loss > 0:
                    episode_losses.append(loss)

            # Log training metrics for each episode
            episode_time = time.time() - episode_start_time
            total_rewards.append(total_reward)

            # Calculate win rate
            win_rate = wins / max(1, (wins + losses))
            trade_win_rates.append(win_rate)

            # Log to console and TensorBoard
            logger.info(f"Episode {episode}/{args.episodes} - Reward: {total_reward:.4f}, Win Rate: {win_rate:.2f}, "
                        f"Trades: {len(trades)}, Balance: ${train_env.balance:.2f}, Epsilon: {agent.epsilon:.4f}, "
                        f"Time: {episode_time:.2f}s")

            writer.add_scalar('metrics/reward', total_reward, episode)
            writer.add_scalar('metrics/balance', train_env.balance, episode)
            writer.add_scalar('metrics/win_rate', win_rate, episode)
            writer.add_scalar('metrics/trades', len(trades), episode)
            writer.add_scalar('metrics/epsilon', agent.epsilon, episode)

            if episode_losses:
                avg_loss = sum(episode_losses) / len(episode_losses)
                writer.add_scalar('metrics/loss', avg_loss, episode)

            # Check if this is the best model so far
            if total_reward > best_reward:
                best_reward = total_reward
                # Save best model
                agent.save("NN/models/saved/enhanced_dqn_best")
                logger.info(f"New best model saved with reward: {best_reward:.4f}")

            # Save checkpoint every 10 episodes
            if episode % 10 == 0 and episode > 0:
                agent.save("NN/models/saved/enhanced_dqn_checkpoint")
                logger.info(f"Checkpoint saved at episode {episode}")

            # Reset episode losses
            episode_losses = []

        # Final save
        agent.save("NN/models/saved/enhanced_dqn_final")
        logger.info("Enhanced training completed, final model saved")

    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        # Close TensorBoard writer
        writer.close()

    return agent, train_env


if __name__ == "__main__":
    # Create output directories if they don't exist
    os.makedirs("logs", exist_ok=True)
    os.makedirs("NN/models/saved", exist_ok=True)

    # Parse arguments
    args = parse_args()

    # Start training
    train_enhanced_rl(args)