improvements
This commit is contained in:
parent
2b1f00cbfc
commit
715261a3f9
@@ -1 +1 @@
-{"best_reward": 202.7441047517104, "best_pnl": -10.072078721366783, "best_win_rate": 30.864197530864196, "last_episode": 10, "timestamp": "2025-03-10T12:45:27.247997"}
+{"best_reward": 202.7441047517104, "best_pnl": -1.285678343969877, "best_win_rate": 38.70967741935484, "last_episode": 20, "timestamp": "2025-03-10T13:31:02.938465"}
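The metrics file diffed above tracks the best reward, PnL, and win rate seen so far plus the episode they were recorded at. The training code later in this commit only reads this file back on resume; a minimal sketch of how it could be written after a new best result, with save_best_metrics as a hypothetical helper name, might look like this:

    import json
    import os
    from datetime import datetime

    def save_best_metrics(best_reward, best_pnl, best_win_rate, episode,
                          path="checkpoints/best_metrics.json"):
        """Persist the best metrics so far so a later run can resume from them."""
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        with open(path, "w") as f:
            json.dump({
                "best_reward": best_reward,
                "best_pnl": best_pnl,
                "best_win_rate": best_win_rate,
                "last_episode": episode,
                "timestamp": datetime.now().isoformat(),
            }, f)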
@@ -499,6 +499,8 @@ class TradingEnvironment:
                'type': 'long',
                'entry': self.entry_price,
                'exit': self.stop_loss,
                'entry_time': self.data[self.entry_index]['timestamp'],
                'exit_time': self.data[self.current_step]['timestamp'],
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar,
                'duration': self.current_step - self.entry_index,

@@ -542,6 +544,8 @@ class TradingEnvironment:
                'type': 'long',
                'entry': self.entry_price,
                'exit': self.take_profit,
                'entry_time': self.data[self.entry_index]['timestamp'],
                'exit_time': self.data[self.current_step]['timestamp'],
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar,
                'duration': self.current_step - self.entry_index,

@@ -588,6 +592,8 @@ class TradingEnvironment:
                'type': 'short',
                'entry': self.entry_price,
                'exit': self.stop_loss,
                'entry_time': self.data[self.entry_index]['timestamp'],
                'exit_time': self.data[self.current_step]['timestamp'],
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar,
                'duration': self.current_step - self.entry_index,

@@ -631,6 +637,8 @@ class TradingEnvironment:
                'type': 'short',
                'entry': self.entry_price,
                'exit': self.take_profit,
                'entry_time': self.data[self.entry_index]['timestamp'],
                'exit_time': self.data[self.current_step]['timestamp'],
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar,
                'duration': self.current_step - self.entry_index,
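All four hunks above record a closed position (long or short, stopped out or taking profit) with the same dictionary shape. A condensed sketch of that shared shape as a standalone helper; build_trade_record is a hypothetical name introduced here, not a function from the commit:

    def build_trade_record(side, entry_price, exit_price, entry_time, exit_time,
                           pnl_percent, pnl_dollar, entry_index, current_step):
        """Assemble a closed-trade record matching the fields used in the hunks above."""
        return {
            'type': side,                    # 'long' or 'short'
            'entry': entry_price,
            'exit': exit_price,              # stop loss, take profit, or current price
            'entry_time': entry_time,
            'exit_time': exit_time,
            'pnl_percent': pnl_percent,
            'pnl_dollar': pnl_dollar,
            'duration': current_step - entry_index,  # bars the position was held
        }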
@@ -808,6 +816,8 @@ class TradingEnvironment:
                'type': 'short',
                'entry': self.entry_price,
                'exit': self.current_price,
                'entry_time': self.data[self.entry_index]['timestamp'],
                'exit_time': self.data[self.current_step]['timestamp'],
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar,
                'duration': self.current_step - self.entry_index,

@@ -872,6 +882,8 @@ class TradingEnvironment:
                'type': 'long',
                'entry': self.entry_price,
                'exit': self.current_price,
                'entry_time': self.data[self.entry_index]['timestamp'],
                'exit_time': self.data[self.current_step]['timestamp'],
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar
            })

@@ -925,6 +937,8 @@ class TradingEnvironment:
                'type': 'long',
                'entry': self.entry_price,
                'exit': self.current_price,
                'entry_time': self.data[self.entry_index]['timestamp'],
                'exit_time': self.data[self.current_step]['timestamp'],
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar
            })

@@ -970,6 +984,8 @@ class TradingEnvironment:
                'type': 'short',
                'entry': self.entry_price,
                'exit': self.current_price,
                'entry_time': self.data[self.entry_index]['timestamp'],
                'exit_time': self.data[self.current_step]['timestamp'],
                'pnl_percent': pnl_percent,
                'pnl_dollar': pnl_dollar
            })
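Every record above carries pnl_percent and pnl_dollar, but the formulas that produce them are outside the diffed lines. A plausible sketch of those calculations for a fixed-size position; the compute_pnl name, the position_size parameter, and the exact formula are assumptions, not taken from the commit:

    def compute_pnl(side, entry_price, exit_price, position_size):
        """Return (pnl_percent, pnl_dollar) for a position worth position_size dollars at entry."""
        if side == 'long':
            pnl_percent = (exit_price - entry_price) / entry_price * 100
        else:  # 'short' positions profit when the exit price is below the entry
            pnl_percent = (entry_price - exit_price) / entry_price * 100
        pnl_dollar = pnl_percent / 100 * position_size
        return pnl_percent, pnl_dollar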
@@ -1667,6 +1683,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
    # Add early stopping based on performance
    patience = 50  # Episodes to wait for improvement
    best_pnl = -float('inf')
    best_reward = -float('inf')  # Initialize best_reward
    best_win_rate = 0  # Initialize best_win_rate
    episodes_without_improvement = 0

    # Add adaptive learning rate
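The counters above feed an early-stopping check, and the truncated comment points at an adaptive learning rate. A minimal sketch of how the two pieces could fit together, assuming the agent exposes a PyTorch optimizer; ReduceLROnPlateau and the run_episode callable are illustrative choices, not necessarily what the commit uses:

    import torch.optim as optim

    def train_with_early_stopping(optimizer, run_episode, num_episodes=1000, patience=50):
        """Shrink the learning rate when PnL plateaus and stop after `patience` flat episodes."""
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=10)
        best_pnl = -float('inf')
        episodes_without_improvement = 0

        for episode in range(num_episodes):
            episode_pnl = run_episode(episode)   # assumed to return the episode's PnL
            scheduler.step(episode_pnl)          # adaptive learning rate on the PnL signal
            if episode_pnl > best_pnl:
                best_pnl = episode_pnl
                episodes_without_improvement = 0
            else:
                episodes_without_improvement += 1
                if episodes_without_improvement >= patience:
                    break                        # early stopping, as logged later in the diff
        return best_pnl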
@@ -1684,18 +1702,42 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,

    stats = {
        'episode_rewards': [],
        'episode_lengths': [],
        'balances': [],
        'episode_profits': [],
        'win_rates': [],
        'episode_pnls': [],
        'cumulative_pnl': [],
        'drawdowns': [],
        'prediction_accuracy': []
        'trade_counts': [],
        'prediction_accuracies': []
    }

    # Create checkpoint directory if it doesn't exist
    os.makedirs("checkpoints", exist_ok=True)

    # Load best model if it exists (to resume training)
    best_model_path = "models/trading_agent_best_pnl.pt"
    if os.path.exists(best_model_path):
        try:
            # Initialize price predictor
            env.initialize_price_predictor(agent.device)
            logger.info(f"Loading best model from {best_model_path} to resume training")
            agent.load(best_model_path)
            # Try to load best metrics from checkpoint file
            checkpoint_info_path = "checkpoints/best_metrics.json"
            if os.path.exists(checkpoint_info_path):
                with open(checkpoint_info_path, 'r') as f:
                    best_metrics = json.load(f)
                best_reward = best_metrics.get('best_reward', best_reward)
                best_pnl = best_metrics.get('best_pnl', best_pnl)
                best_win_rate = best_metrics.get('best_win_rate', best_win_rate)
                logger.info(f"Resumed with best metrics - Reward: {best_reward:.2f}, PnL: ${best_pnl:.2f}, Win Rate: {best_win_rate:.1f}%")
        except Exception as e:
            logger.warning(f"Could not load best model: {e}")

    try:
        # Initialize price predictor and attach it to the environment
        price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
        price_predictor.to(agent.device)
        price_predictor_optimizer = optim.Adam(price_predictor.parameters(), lr=1e-4)

        # Attach the price predictor to the environment
        env.price_predictor = price_predictor
        env.price_predictor_optimizer = price_predictor_optimizer

        for episode in range(num_episodes):
            try:
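PricePredictionModel is constructed here with input_size=30, hidden_size=128, output_size=5, but its definition is outside this diff. A minimal stand-in with the same constructor interface, purely to illustrate the shapes involved; the feed-forward architecture is an assumption:

    import torch.nn as nn

    class PricePredictionModel(nn.Module):
        """Illustrative stand-in: maps a window of 30 recent prices to 5 predicted prices."""

        def __init__(self, input_size=30, hidden_size=128, output_size=5):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(input_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, output_size),
            )

        def forward(self, x):
            # x: tensor of shape (batch, input_size) holding the recent price window
            return self.net(x)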
@@ -1731,11 +1773,22 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,

                # Reset environment
                state = env.reset()
                episode_reward = 0
                env.episode_pnl = 0.0  # Reset episode PnL

                # Identify optimal trade points for this episode
                env.identify_optimal_trades()
                # Initialize episode variables
                episode_reward = 0
                done = False
                step = 0

                # Initialize trade analysis dictionary
                trade_analysis = {
                    'win_rate': 0,
                    'uptrend_win_rate': 0,
                    'downtrend_win_rate': 0,
                    'sideways_win_rate': 0,
                    'avg_win_pnl': 0,
                    'avg_loss_pnl': 0,
                    'max_drawdown': 0
                }

                # Train price predictor
                prediction_loss = env.train_price_predictor()
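trade_analysis starts out zeroed here and is presumably filled in once trades close. A small sketch of how the overall win rate, average win/loss PnL, and max drawdown could be computed from trade records like the ones appended earlier; analyze_trades is a hypothetical helper, and the per-regime (uptrend/downtrend/sideways) rates are omitted because the record shape shown carries no trend label:

    def analyze_trades(trades):
        """Summarize closed trades in the shape of the trade_analysis dict above."""
        analysis = {'win_rate': 0, 'avg_win_pnl': 0, 'avg_loss_pnl': 0, 'max_drawdown': 0}
        if not trades:
            return analysis
        wins = [t for t in trades if t['pnl_dollar'] > 0]
        losses = [t for t in trades if t['pnl_dollar'] <= 0]
        analysis['win_rate'] = 100.0 * len(wins) / len(trades)
        if wins:
            analysis['avg_win_pnl'] = sum(t['pnl_dollar'] for t in wins) / len(wins)
        if losses:
            analysis['avg_loss_pnl'] = sum(t['pnl_dollar'] for t in losses) / len(losses)
        # Max drawdown of the cumulative trade-by-trade PnL curve
        equity, peak, max_dd = 0.0, 0.0, 0.0
        for t in trades:
            equity += t['pnl_dollar']
            peak = max(peak, equity)
            max_dd = max(max_dd, peak - equity)
        analysis['max_drawdown'] = max_dd
        return analysis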
@@ -1743,7 +1796,7 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,

                # Update price predictions
                env.update_price_predictions()

                for step in range(max_steps_per_episode):
                while not done:
                    # Select action
                    action = agent.select_action(state)

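This hunk trades the bounded for loop for a while not done loop, which depends on the environment eventually setting done. A defensive sketch that keeps the max_steps_per_episode cap while still honoring the flag; the env.step return signature and agent.learn call are assumptions about the surrounding code, and the combined loop condition is a suggestion rather than what the commit does:

    def run_bounded_episode(agent, env, state, max_steps_per_episode):
        """Episode loop that honors `done` but never exceeds the step cap."""
        episode_reward, step, done = 0.0, 0, False
        while not done and step < max_steps_per_episode:
            action = agent.select_action(state)
            next_state, reward, done = env.step(action)           # assumed 3-tuple return
            agent.learn(state, action, reward, next_state, done)  # assumed agent method
            state = next_state
            episode_reward += reward
            step += 1
        return episode_reward, step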
@@ -1782,13 +1835,10 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,

                # Update stats
                stats['episode_rewards'].append(episode_reward)
                stats['episode_lengths'].append(step + 1)
                stats['balances'].append(env.balance)
                stats['episode_profits'].append(env.episode_pnl)
                stats['win_rates'].append(win_rate)
                stats['episode_pnls'].append(env.episode_pnl)
                stats['cumulative_pnl'].append(env.total_pnl)
                stats['drawdowns'].append(env.max_drawdown * 100)
                stats['prediction_accuracy'].append(prediction_accuracy)
                stats['trade_counts'].append(total_trades)
                stats['prediction_accuracies'].append(prediction_accuracy)

                # Log detailed trade analysis
                if trade_analysis:
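win_rate and total_trades are appended here but computed outside the hunk. A plausible sketch of how they could be derived from the environment's list of closed-trade records; episode_trade_summary is a hypothetical helper, and keying wins off pnl_dollar is an assumption:

    def episode_trade_summary(trades):
        """Return (win_rate, total_trades) from trade records like those built earlier."""
        total_trades = len(trades)
        if total_trades == 0:
            return 0.0, 0
        wins = sum(1 for t in trades if t['pnl_dollar'] > 0)
        return 100.0 * wins / total_trades, total_trades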
@@ -1866,6 +1916,10 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,

                    logger.info(f"Early stopping triggered after {episode+1} episodes without improvement")
                    break

                # Create visualization every 10 episodes or on the last episode
                if episode % 10 == 0 or episode == num_episodes - 1:
                    visualize_training_results(env, agent, episode)

            except Exception as e:
                logger.error(f"Error in episode {episode}: {e}")
                logger.error(f"Traceback: {traceback.format_exc()}")
@@ -1903,7 +1957,7 @@ def plot_training_results(stats):

    # Plot balance
    plt.subplot(3, 2, 2)
    plt.plot(stats['balances'])
    plt.plot(stats['episode_profits'])
    plt.title('Account Balance')
    plt.xlabel('Episode')
    plt.ylabel('Balance ($)')
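plot_training_results lays its charts out on a 3x2 grid, of which only the balance panel appears in this hunk. A sketch of how the full figure could be assembled from the stats dict defined earlier; every panel other than the balance subplot is an assumption, and the function is renamed to make clear it is not the commit's implementation:

    import matplotlib.pyplot as plt

    def plot_training_results_sketch(stats, out_path="training_results.png"):
        """Hypothetical 3x2 training-summary figure; only subplot (3, 2, 2) is shown in the diff."""
        plt.figure(figsize=(14, 10))

        plt.subplot(3, 2, 1)
        plt.plot(stats['episode_rewards'])
        plt.title('Episode Rewards')

        plt.subplot(3, 2, 2)
        plt.plot(stats['balances'])
        plt.title('Account Balance')
        plt.xlabel('Episode')
        plt.ylabel('Balance ($)')

        plt.subplot(3, 2, 3)
        plt.plot(stats['episode_pnls'])
        plt.title('Episode PnL ($)')

        plt.subplot(3, 2, 4)
        plt.plot(stats['win_rates'])
        plt.title('Win Rate (%)')

        plt.subplot(3, 2, 5)
        plt.plot(stats['drawdowns'])
        plt.title('Max Drawdown (%)')

        plt.subplot(3, 2, 6)
        plt.plot(stats['prediction_accuracies'])
        plt.title('Prediction Accuracy (%)')

        plt.tight_layout()
        plt.savefig(out_path)
        plt.close()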