improvements

This commit is contained in:
Dobromir Popov 2025-03-10 13:32:35 +02:00
parent 2b1f00cbfc
commit 715261a3f9
2 changed files with 75 additions and 21 deletions


@@ -1 +1 @@
-{"best_reward": 202.7441047517104, "best_pnl": -10.072078721366783, "best_win_rate": 30.864197530864196, "last_episode": 10, "timestamp": "2025-03-10T12:45:27.247997"}
+{"best_reward": 202.7441047517104, "best_pnl": -1.285678343969877, "best_win_rate": 38.70967741935484, "last_episode": 20, "timestamp": "2025-03-10T13:31:02.938465"}


@@ -499,6 +499,8 @@ class TradingEnvironment:
     'type': 'long',
     'entry': self.entry_price,
     'exit': self.stop_loss,
+    'entry_time': self.data[self.entry_index]['timestamp'],
+    'exit_time': self.data[self.current_step]['timestamp'],
     'pnl_percent': pnl_percent,
     'pnl_dollar': pnl_dollar,
     'duration': self.current_step - self.entry_index,
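Every trade-closing path below (stop-loss, take-profit, and market-price exits for both long and short positions) gains the same entry_time/exit_time pair. As a hedged sketch of consuming those fields, assuming the records are collected as a list of dicts shaped like the one above and that the timestamps support subtraction (datetime objects or epoch numbers); the helper and the sample record are illustrative only:

from datetime import datetime, timedelta

def summarize_trades(trades):
    # Summarize closed trades using the new entry_time/exit_time fields.
    for t in trades:
        hold = t['exit_time'] - t['entry_time']
        print(f"{t['type']:<5} entry={t['entry']:.2f} exit={t['exit']:.2f} "
              f"pnl=${t['pnl_dollar']:.2f} held={hold}")

# Minimal usage with a fabricated record shaped like the ones appended above.
summarize_trades([{
    'type': 'long', 'entry': 2001.5, 'exit': 2010.0,
    'entry_time': datetime(2025, 3, 10, 12, 0),
    'exit_time': datetime(2025, 3, 10, 12, 0) + timedelta(minutes=7),
    'pnl_percent': 0.42, 'pnl_dollar': 4.2, 'duration': 7,
}])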
@@ -542,6 +544,8 @@ class TradingEnvironment:
     'type': 'long',
     'entry': self.entry_price,
     'exit': self.take_profit,
+    'entry_time': self.data[self.entry_index]['timestamp'],
+    'exit_time': self.data[self.current_step]['timestamp'],
     'pnl_percent': pnl_percent,
     'pnl_dollar': pnl_dollar,
     'duration': self.current_step - self.entry_index,
@@ -588,6 +592,8 @@ class TradingEnvironment:
     'type': 'short',
     'entry': self.entry_price,
     'exit': self.stop_loss,
+    'entry_time': self.data[self.entry_index]['timestamp'],
+    'exit_time': self.data[self.current_step]['timestamp'],
     'pnl_percent': pnl_percent,
     'pnl_dollar': pnl_dollar,
     'duration': self.current_step - self.entry_index,
@@ -631,6 +637,8 @@ class TradingEnvironment:
     'type': 'short',
     'entry': self.entry_price,
     'exit': self.take_profit,
+    'entry_time': self.data[self.entry_index]['timestamp'],
+    'exit_time': self.data[self.current_step]['timestamp'],
     'pnl_percent': pnl_percent,
     'pnl_dollar': pnl_dollar,
     'duration': self.current_step - self.entry_index,
@@ -808,6 +816,8 @@ class TradingEnvironment:
     'type': 'short',
     'entry': self.entry_price,
     'exit': self.current_price,
+    'entry_time': self.data[self.entry_index]['timestamp'],
+    'exit_time': self.data[self.current_step]['timestamp'],
     'pnl_percent': pnl_percent,
     'pnl_dollar': pnl_dollar,
     'duration': self.current_step - self.entry_index,
@@ -872,6 +882,8 @@ class TradingEnvironment:
     'type': 'long',
     'entry': self.entry_price,
     'exit': self.current_price,
+    'entry_time': self.data[self.entry_index]['timestamp'],
+    'exit_time': self.data[self.current_step]['timestamp'],
     'pnl_percent': pnl_percent,
     'pnl_dollar': pnl_dollar
 })
@@ -925,6 +937,8 @@ class TradingEnvironment:
     'type': 'long',
     'entry': self.entry_price,
     'exit': self.current_price,
+    'entry_time': self.data[self.entry_index]['timestamp'],
+    'exit_time': self.data[self.current_step]['timestamp'],
     'pnl_percent': pnl_percent,
     'pnl_dollar': pnl_dollar
 })
@@ -970,6 +984,8 @@ class TradingEnvironment:
     'type': 'short',
     'entry': self.entry_price,
     'exit': self.current_price,
+    'entry_time': self.data[self.entry_index]['timestamp'],
+    'exit_time': self.data[self.current_step]['timestamp'],
     'pnl_percent': pnl_percent,
     'pnl_dollar': pnl_dollar
 })
@@ -1667,6 +1683,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
     # Add early stopping based on performance
     patience = 50  # Episodes to wait for improvement
     best_pnl = -float('inf')
+    best_reward = -float('inf')  # Initialize best_reward
+    best_win_rate = 0  # Initialize best_win_rate
     episodes_without_improvement = 0
 
     # Add adaptive learning rate
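best_reward and best_win_rate are now initialized alongside best_pnl so they exist before the resume logic below reads them. For reference, a minimal self-contained sketch of the patience-based early-stopping pattern these counters feed; the pnl_per_episode list is a placeholder, not data from the repository:

patience = 50
best_pnl = -float('inf')
episodes_without_improvement = 0

pnl_per_episode = [-12.0, -10.1, -9.5, -9.8]  # placeholder episode results
for episode_pnl in pnl_per_episode:
    if episode_pnl > best_pnl:
        best_pnl = episode_pnl
        episodes_without_improvement = 0  # reset on improvement
    else:
        episodes_without_improvement += 1
    if episodes_without_improvement >= patience:
        break  # stop training after `patience` episodes without improvement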
@@ -1684,18 +1702,42 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
     stats = {
         'episode_rewards': [],
-        'episode_lengths': [],
-        'balances': [],
+        'episode_profits': [],
         'win_rates': [],
-        'episode_pnls': [],
-        'cumulative_pnl': [],
-        'drawdowns': [],
-        'prediction_accuracy': []
+        'trade_counts': [],
+        'prediction_accuracies': []
     }
+
+    # Create checkpoint directory if it doesn't exist
+    os.makedirs("checkpoints", exist_ok=True)
+
+    # Load best model if it exists (to resume training)
+    best_model_path = "models/trading_agent_best_pnl.pt"
+    if os.path.exists(best_model_path):
+        try:
+            logger.info(f"Loading best model from {best_model_path} to resume training")
+            agent.load(best_model_path)
+
+            # Try to load best metrics from checkpoint file
+            checkpoint_info_path = "checkpoints/best_metrics.json"
+            if os.path.exists(checkpoint_info_path):
+                with open(checkpoint_info_path, 'r') as f:
+                    best_metrics = json.load(f)
+                    best_reward = best_metrics.get('best_reward', best_reward)
+                    best_pnl = best_metrics.get('best_pnl', best_pnl)
+                    best_win_rate = best_metrics.get('best_win_rate', best_win_rate)
+                logger.info(f"Resumed with best metrics - Reward: {best_reward:.2f}, PnL: ${best_pnl:.2f}, Win Rate: {best_win_rate:.1f}%")
+        except Exception as e:
+            logger.warning(f"Could not load best model: {e}")
 
     try:
-        # Initialize price predictor
-        env.initialize_price_predictor(agent.device)
+        # Initialize price predictor and attach it to the environment
+        price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
+        price_predictor.to(agent.device)
+        price_predictor_optimizer = optim.Adam(price_predictor.parameters(), lr=1e-4)
+
+        # Attach the price predictor to the environment
+        env.price_predictor = price_predictor
+        env.price_predictor_optimizer = price_predictor_optimizer
 
         for episode in range(num_episodes):
             try:
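The resume block above reads checkpoints/best_metrics.json, whose format matches the JSON file in the first hunk. The writer side is not shown in this hunk; purely as an assumption mirroring the keys and paths above, a plausible counterpart would persist the metrics whenever a new best is recorded:

import json
import os
from datetime import datetime

def save_best_metrics(best_reward, best_pnl, best_win_rate, episode,
                      path="checkpoints/best_metrics.json"):
    # Persist the running best metrics so training can resume later.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    snapshot = {
        "best_reward": best_reward,
        "best_pnl": best_pnl,
        "best_win_rate": best_win_rate,
        "last_episode": episode,
        "timestamp": datetime.now().isoformat(),
    }
    with open(path, "w") as f:
        json.dump(snapshot, f)

# Example (hypothetical values): save_best_metrics(202.74, -1.29, 38.7, episode=20)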
@@ -1731,11 +1773,22 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
             # Reset environment
             state = env.reset()
-            episode_reward = 0
-            env.episode_pnl = 0.0  # Reset episode PnL
 
-            # Identify optimal trade points for this episode
-            env.identify_optimal_trades()
+            # Initialize episode variables
+            episode_reward = 0
+            done = False
+            step = 0
+
+            # Initialize trade analysis dictionary
+            trade_analysis = {
+                'win_rate': 0,
+                'uptrend_win_rate': 0,
+                'downtrend_win_rate': 0,
+                'sideways_win_rate': 0,
+                'avg_win_pnl': 0,
+                'avg_loss_pnl': 0,
+                'max_drawdown': 0
+            }
 
             # Train price predictor
             prediction_loss = env.train_price_predictor()
@@ -1743,7 +1796,7 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
             # Update price predictions
             env.update_price_predictions()
 
-            for step in range(max_steps_per_episode):
+            while not done:
                 # Select action
                 action = agent.select_action(state)
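Replacing for step in range(max_steps_per_episode) with while not done means an episode now ends only when the environment signals done. A self-contained sketch of that loop shape, with toy stand-ins for env and agent and an optional step cap; the cap is an assumption for safety, not something shown in the diff:

# Toy versions of env/agent just to exercise the loop shape.
class ToyEnv:
    def __init__(self, length=5):
        self.length, self.t = length, 0
    def reset(self):
        self.t = 0
        return 0
    def step(self, action):
        self.t += 1
        done = self.t >= self.length
        return self.t, 0.0, done, {}

class ToyAgent:
    def select_action(self, state):
        return 0

env, agent, max_steps_per_episode = ToyEnv(), ToyAgent(), 1000
state, done, step = env.reset(), False, 0
while not done:
    action = agent.select_action(state)
    next_state, reward, done, info = env.step(action)  # gym-style step signature assumed
    state = next_state
    step += 1
    if step >= max_steps_per_episode:
        done = True  # optional safety cap; the diff relies on env signalling done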
@@ -1782,13 +1835,10 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
             # Update stats
             stats['episode_rewards'].append(episode_reward)
-            stats['episode_lengths'].append(step + 1)
-            stats['balances'].append(env.balance)
+            stats['episode_profits'].append(env.episode_pnl)
             stats['win_rates'].append(win_rate)
-            stats['episode_pnls'].append(env.episode_pnl)
-            stats['cumulative_pnl'].append(env.total_pnl)
-            stats['drawdowns'].append(env.max_drawdown * 100)
-            stats['prediction_accuracy'].append(prediction_accuracy)
+            stats['trade_counts'].append(total_trades)
+            stats['prediction_accuracies'].append(prediction_accuracy)
 
             # Log detailed trade analysis
             if trade_analysis:
@@ -1866,6 +1916,10 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
                     logger.info(f"Early stopping triggered after {episode+1} episodes without improvement")
                     break
 
+                # Create visualization every 10 episodes or on the last episode
+                if episode % 10 == 0 or episode == num_episodes - 1:
+                    visualize_training_results(env, agent, episode)
+
             except Exception as e:
                 logger.error(f"Error in episode {episode}: {e}")
                 logger.error(f"Traceback: {traceback.format_exc()}")
@@ -1903,7 +1957,7 @@ def plot_training_results(stats):
     # Plot balance
     plt.subplot(3, 2, 2)
-    plt.plot(stats['balances'])
+    plt.plot(stats['episode_profits'])
     plt.title('Account Balance')
     plt.xlabel('Episode')
     plt.ylabel('Balance ($)')
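This subplot now plots stats['episode_profits'] while its title and y-label still describe an account balance. A hedged sketch of how the labels might look if episode_profits holds per-episode PnL in dollars (an assumption about the metric's meaning, not a change made by the commit):

import matplotlib.pyplot as plt

def plot_episode_profits(stats):
    # Label the subplot to match the data now being plotted.
    plt.subplot(3, 2, 2)
    plt.plot(stats['episode_profits'])
    plt.title('Episode PnL')
    plt.xlabel('Episode')
    plt.ylabel('PnL ($)')

# Minimal usage with placeholder values.
plot_episode_profits({'episode_profits': [-10.1, -6.3, -1.3]})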