# gogo2/core/realtime_rl_trainer.py
"""
Real-Time RL Training System
This module implements continuous learning from live trading decisions.
The RL agent learns from every trade signal and position closure to improve
decision-making over time.
"""
import logging
import numpy as np
import torch
from collections import deque
from datetime import datetime
from typing import Dict, Optional, Tuple, Any
import threading
import json
import os
# Import existing DQN agent
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))  # project root, so the NN package import below resolves
from NN.models.dqn_agent import DQNAgent
logger = logging.getLogger(__name__)
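# Example configuration (an illustrative sketch, not a tuned setup): every key
# below is optional and is read with .get() in RealTimeRLTrainer.__init__; the
# values shown mirror the defaults used there.
EXAMPLE_TRAINER_CONFIG = {
    'state_size': 100,                        # length of the state vector from MarketStateBuilder
    'learning_rate': 0.0001,
    'gamma': 0.95,
    'epsilon': 0.1,                           # low exploration for live trading
    'buffer_size': 10000,
    'batch_size': 32,
    'training_enabled': True,
    'min_experiences': 10,                    # stored experiences required before replay training
    'training_frequency': 5,                  # train every N completed experiences
    'model_save_path': 'models/realtime_rl',
    'save_frequency': 100,                    # save every N completed experiences
}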
class TradingExperience:
"""Represents a single trading experience for RL learning"""
def __init__(self,
pre_trade_state: np.ndarray,
action: int, # 0=SELL, 1=HOLD, 2=BUY
entry_price: float,
exit_price: float,
holding_time: float, # seconds
pnl: float,
fees: float,
confidence: float,
market_conditions: Dict[str, Any],
timestamp: datetime):
self.pre_trade_state = pre_trade_state
self.action = action
self.entry_price = entry_price
self.exit_price = exit_price
self.holding_time = holding_time
self.pnl = pnl
self.fees = fees
self.confidence = confidence
self.market_conditions = market_conditions
self.timestamp = timestamp
# Calculate reward
self.reward = self._calculate_reward()
def _calculate_reward(self) -> float:
"""Calculate reward for this trading experience"""
# Net PnL after fees
net_pnl = self.pnl - self.fees
# Base reward from PnL (normalized by entry price)
base_reward = net_pnl / self.entry_price
# Time penalty - prefer faster profitable trades
time_penalty = 0.0
if self.holding_time > 300: # 5 minutes
time_penalty = -0.001 * (self.holding_time / 60) # -0.001 per minute
# Confidence bonus - reward high-confidence correct decisions
confidence_bonus = 0.0
if net_pnl > 0 and self.confidence > 0.7:
confidence_bonus = 0.01 * self.confidence
# Volume consideration (prefer trades that move significant amounts)
volume_factor = min(abs(base_reward) * 10, 0.05) # Cap at 5%; currently informational only and not added to the reward below
total_reward = base_reward + time_penalty + confidence_bonus
# Scale reward to reasonable range
return np.tanh(total_reward * 100) * 10 # Scale and bound reward
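# Worked example (illustrative numbers only, not live-trading values): a quick
# sanity check of the reward shaping above for a small, fast, high-confidence
# winning trade.
def _example_reward_calculation() -> float:
    """Return the shaped reward for a hypothetical winning BUY trade."""
    exp = TradingExperience(
        pre_trade_state=np.zeros(100),
        action=2,                  # BUY
        entry_price=100.0,
        exit_price=100.5,
        holding_time=120.0,        # 2 minutes -> no time penalty
        pnl=0.5,
        fees=0.1,
        confidence=0.8,            # > 0.7 and profitable -> confidence bonus
        market_conditions={},
        timestamp=datetime.now(),
    )
    # base_reward = (0.5 - 0.1) / 100 = 0.004, bonus = 0.01 * 0.8 = 0.008,
    # total = 0.012 -> tanh(1.2) * 10 ≈ 8.34
    return exp.reward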
class MarketStateBuilder:
"""Builds state representations for RL agent from market data"""
def __init__(self, state_size: int = 100):
self.state_size = state_size
self.price_history = deque(maxlen=50)
self.volume_history = deque(maxlen=50)
self.rsi_history = deque(maxlen=14)
self.macd_history = deque(maxlen=26)
def update_market_data(self, price: float, volume: float,
rsi: Optional[float] = None, macd: Optional[float] = None):
"""Update market data buffers"""
self.price_history.append(price)
self.volume_history.append(volume)
if rsi is not None:
self.rsi_history.append(rsi)
if macd is not None:
self.macd_history.append(macd)
def build_state(self, current_position: str = 'NONE',
position_pnl: float = 0.0,
account_balance: float = 1000.0) -> np.ndarray:
"""Build state vector for RL agent"""
state = np.zeros(self.state_size)
try:
# Price features (normalized returns)
if len(self.price_history) >= 2:
prices = np.array(list(self.price_history))
returns = np.diff(prices) / prices[:-1]
# Recent returns (last 20)
recent_returns = returns[-20:] if len(returns) >= 20 else returns
state[:len(recent_returns)] = recent_returns
# Price momentum features
state[20] = np.mean(returns[-5:]) if len(returns) >= 5 else 0 # 5-bar momentum
state[21] = np.mean(returns[-10:]) if len(returns) >= 10 else 0 # 10-bar momentum
state[22] = np.std(returns[-10:]) if len(returns) >= 10 else 0 # Volatility
# Volume features
if len(self.volume_history) >= 2:
volumes = np.array(list(self.volume_history))
volume_changes = np.diff(volumes) / volumes[:-1]
recent_volume_changes = volume_changes[-10:] if len(volume_changes) >= 10 else volume_changes
state[30:30+len(recent_volume_changes)] = recent_volume_changes
# Volume momentum
state[40] = np.mean(volume_changes[-5:]) if len(volume_changes) >= 5 else 0
# Technical indicators
if len(self.rsi_history) >= 1:
state[50] = (list(self.rsi_history)[-1] - 50) / 50 # Normalized RSI
if len(self.macd_history) >= 2:
macd_values = list(self.macd_history)
state[51] = macd_values[-1] / 100 # Normalized MACD
state[52] = (macd_values[-1] - macd_values[-2]) / 100 # MACD change
# Position information
position_encoding = {'NONE': 0, 'LONG': 1, 'SHORT': -1}
state[60] = position_encoding.get(current_position, 0)
state[61] = position_pnl / 100 # Normalized PnL
state[62] = account_balance / 1000 # Normalized balance
# Market regime features
if len(self.price_history) >= 20:
prices = np.array(list(self.price_history))
# Trend strength
state[70] = (prices[-1] - prices[-20]) / prices[-20] # 20-bar trend
# Market volatility regime
returns = np.diff(prices) / prices[:-1]
state[71] = np.std(returns[-20:]) if len(returns) >= 20 else 0
# Support/resistance levels
high_20 = np.max(prices[-20:])
low_20 = np.min(prices[-20:])
current_price = prices[-1]
state[72] = (current_price - low_20) / (high_20 - low_20) if high_20 != low_20 else 0.5
except Exception as e:
logger.error(f"Error building state: {e}")
return state
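# Minimal sketch (assumes synthetic prices/volumes rather than a real feed)
# showing how MarketStateBuilder is fed tick by tick and where features land
# in the state vector.
def _example_state_build() -> np.ndarray:
    builder = MarketStateBuilder(state_size=100)
    price = 100.0
    for i in range(30):
        price *= 1.0 + 0.001 * np.sin(i)  # synthetic price path
        builder.update_market_data(price=price, volume=10.0 + i, rsi=55.0, macd=0.2)
    # Layout per build_state: 0-19 recent returns, 20-22 momentum/volatility,
    # 30-40 volume features, 50-52 RSI/MACD, 60-62 position info, 70-72 regime.
    return builder.build_state(current_position='NONE', position_pnl=0.0, account_balance=1000.0)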
class RealTimeRLTrainer:
"""Real-time RL trainer that learns from live trading decisions"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""Initialize the real-time RL trainer"""
self.config = config or {}
# RL Agent configuration
state_size = self.config.get('state_size', 100)
action_size = 3 # 0=SELL, 1=HOLD, 2=BUY
# Initialize RL agent
self.agent = DQNAgent(
state_shape=(state_size,),
n_actions=action_size,
learning_rate=self.config.get('learning_rate', 0.0001),
gamma=self.config.get('gamma', 0.95),
epsilon=self.config.get('epsilon', 0.1), # Low epsilon for live trading
epsilon_min=0.05,
epsilon_decay=0.999,
buffer_size=self.config.get('buffer_size', 10000),
batch_size=self.config.get('batch_size', 32)
)
# Market state builder
self.state_builder = MarketStateBuilder(state_size)
# Training data storage
self.pending_trades = {} # symbol -> trade info
self.completed_experiences = deque(maxlen=1000)
self.learning_history = []
# Training controls
self.training_enabled = self.config.get('training_enabled', True)
self.min_experiences_for_training = self.config.get('min_experiences', 10)
self.training_frequency = self.config.get('training_frequency', 5) # Train every N experiences
self.experience_count = 0
# Model saving
self.model_save_path = self.config.get('model_save_path', 'models/realtime_rl')
self.save_frequency = self.config.get('save_frequency', 100) # Save every N experiences
# Performance tracking
self.performance_history = []
self.recent_rewards = deque(maxlen=100)
self.trade_count = 0
self.win_count = 0
# Threading for async training (reserved for future use; training currently runs synchronously in _maybe_train)
self.training_thread = None
self.training_queue = deque()
self.training_lock = threading.Lock()
logger.info(f"Real-time RL trainer initialized")
logger.info(f"State size: {state_size}, Action size: {action_size}")
logger.info(f"Training enabled: {self.training_enabled}")
def update_market_data(self, symbol: str, price: float, volume: float,
rsi: Optional[float] = None, macd: Optional[float] = None):
"""Update market data for state building"""
self.state_builder.update_market_data(price, volume, rsi, macd)
def record_trade_signal(self, symbol: str, action: str, confidence: float,
current_price: float, position_info: Dict[str, Any] = None):
"""Record a trade signal for future learning"""
try:
# Build current state
current_position = 'NONE'
position_pnl = 0.0
account_balance = 1000.0
if position_info:
current_position = position_info.get('side', 'NONE')
position_pnl = position_info.get('unrealized_pnl', 0.0)
account_balance = position_info.get('account_balance', 1000.0)
state = self.state_builder.build_state(current_position, position_pnl, account_balance)
# Convert action to numeric
action_map = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
action_num = action_map.get(action.upper(), 1)
# Store pending trade
trade_info = {
'pre_trade_state': state.copy(),
'action': action_num,
'entry_price': current_price,
'confidence': confidence,
'entry_time': datetime.now(),
'market_conditions': {
'volatility': np.std(list(self.state_builder.price_history)[-10:]) if len(self.state_builder.price_history) >= 10 else 0,
'trend': state[70] if len(state) > 70 else 0,
'volume_trend': state[40] if len(state) > 40 else 0
}
}
if action.upper() in ['BUY', 'SELL']:
self.pending_trades[symbol] = trade_info
logger.info(f"Recorded {action} signal for {symbol} at ${current_price:.2f} (confidence: {confidence:.2f})")
except Exception as e:
logger.error(f"Error recording trade signal: {e}")
def record_position_closure(self, symbol: str, exit_price: float,
pnl: float, fees: float):
"""Record position closure and create learning experience"""
try:
if symbol not in self.pending_trades:
logger.warning(f"No pending trade found for {symbol}")
return
trade_info = self.pending_trades.pop(symbol)
# Calculate holding time
holding_time = (datetime.now() - trade_info['entry_time']).total_seconds()
# Create trading experience
experience = TradingExperience(
pre_trade_state=trade_info['pre_trade_state'],
action=trade_info['action'],
entry_price=trade_info['entry_price'],
exit_price=exit_price,
holding_time=holding_time,
pnl=pnl,
fees=fees,
confidence=trade_info['confidence'],
market_conditions=trade_info['market_conditions'],
timestamp=datetime.now()
)
# Add to completed experiences
self.completed_experiences.append(experience)
self.recent_rewards.append(experience.reward)
self.experience_count += 1
self.trade_count += 1
if experience.reward > 0: # a "win" is a positive shaped reward, not raw PnL
self.win_count += 1
# Log the experience
logger.info(f"Recorded experience: {symbol} PnL=${pnl:.4f} Reward={experience.reward:.4f} "
f"(Win rate: {self.win_count/self.trade_count*100:.1f}%)")
# Create next state (current market state after trade)
current_state = self.state_builder.build_state('NONE', 0.0, 1000.0)
# Store in agent memory for learning
self.agent.remember(
state=trade_info['pre_trade_state'],
action=trade_info['action'],
reward=experience.reward,
next_state=current_state,
done=True # Each trade is a complete episode
)
# Trigger training if conditions are met
if self.training_enabled:
self._maybe_train()
# Save model periodically
if self.experience_count % self.save_frequency == 0:
self._save_model()
except Exception as e:
logger.error(f"Error recording position closure: {e}")
def _maybe_train(self):
"""Train the agent if conditions are met"""
try:
if (len(self.agent.memory) >= self.min_experiences_for_training and
self.experience_count % self.training_frequency == 0):
# Perform training step
loss = self.agent.replay()
if loss is not None:
self.learning_history.append({
'timestamp': datetime.now().isoformat(),
'experience_count': self.experience_count,
'loss': loss,
'epsilon': self.agent.epsilon,
'avg_reward': np.mean(list(self.recent_rewards)) if self.recent_rewards else 0,
'win_rate': self.win_count / self.trade_count if self.trade_count > 0 else 0,
'memory_size': len(self.agent.memory)
})
logger.info(f"RL Training: Loss={loss:.4f}, Epsilon={self.agent.epsilon:.3f}, "
f"Avg Reward={np.mean(list(self.recent_rewards)):.4f}, "
f"Memory Size={len(self.agent.memory)}")
except Exception as e:
logger.error(f"Error in training: {e}")
def get_action_prediction(self, symbol: str, current_position: str = 'NONE',
position_pnl: float = 0.0, account_balance: float = 1000.0) -> Tuple[str, float]:
"""Get action prediction from trained RL agent"""
try:
# Build current state
state = self.state_builder.build_state(current_position, position_pnl, account_balance)
# Get prediction from agent
with torch.no_grad():
q_values, _, _, _, _ = self.agent.policy_net(
torch.FloatTensor(state).unsqueeze(0).to(self.agent.device)
)
# Get action with highest Q-value
action_idx = q_values.argmax().item()
confidence = torch.softmax(q_values, dim=1).max().item()
# Convert to action string
action_map = {0: 'SELL', 1: 'HOLD', 2: 'BUY'}
action = action_map[action_idx]
return action, confidence
except Exception as e:
logger.error(f"Error getting action prediction: {e}")
return 'HOLD', 0.5
def get_training_stats(self) -> Dict[str, Any]:
"""Get current training statistics"""
try:
return {
'total_experiences': self.experience_count,
'total_trades': self.trade_count,
'win_count': self.win_count,
'win_rate': self.win_count / self.trade_count if self.trade_count > 0 else 0,
'avg_reward': np.mean(list(self.recent_rewards)) if self.recent_rewards else 0,
'memory_size': len(self.agent.memory),
'epsilon': self.agent.epsilon,
'recent_loss': self.learning_history[-1]['loss'] if self.learning_history else 0,
'training_enabled': self.training_enabled,
'pending_trades': len(self.pending_trades)
}
except Exception as e:
logger.error(f"Error getting training stats: {e}")
return {}
def _save_model(self):
"""Save the trained model"""
try:
os.makedirs(self.model_save_path, exist_ok=True)
# Save RL agent
self.agent.save(os.path.join(self.model_save_path, 'rl_agent'))
# Save training history
history_path = os.path.join(self.model_save_path, 'training_history.json')
with open(history_path, 'w') as f:
json.dump(self.learning_history, f, indent=2)
# Save performance stats
stats_path = os.path.join(self.model_save_path, 'performance_stats.json')
with open(stats_path, 'w') as f:
json.dump(self.get_training_stats(), f, indent=2)
logger.info(f"Saved RL model and training data to {self.model_save_path}")
except Exception as e:
logger.error(f"Error saving model: {e}")
def load_model(self):
"""Load a previously saved model"""
try:
model_path = os.path.join(self.model_save_path, 'rl_agent')
if os.path.exists(f"{model_path}_policy_model.pt"):
self.agent.load(model_path)
logger.info(f"Loaded RL model from {model_path}")
# Load training history if available
history_path = os.path.join(self.model_save_path, 'training_history.json')
if os.path.exists(history_path):
with open(history_path, 'r') as f:
self.learning_history = json.load(f)
return True
else:
logger.info("No saved model found, starting with fresh model")
return False
except Exception as e:
logger.error(f"Error loading model: {e}")
return False
def enable_training(self, enabled: bool = True):
"""Enable or disable training"""
self.training_enabled = enabled
logger.info(f"RL training {'enabled' if enabled else 'disabled'}")
def reset_performance_stats(self):
"""Reset performance tracking statistics"""
self.trade_count = 0
self.win_count = 0
self.recent_rewards.clear()
logger.info("Reset RL performance statistics")