"""
|
|
Real-Time RL Training System
|
|
|
|
This module implements continuous learning from live trading decisions.
|
|
The RL agent learns from every trade signal and position closure to improve
|
|
decision-making over time.
|
|
"""
|
|
|
|
import logging
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from collections import deque
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
import threading
|
|
import time
|
|
import json
|
|
import os
|
|
|
|
# Import existing DQN agent
|
|
import sys
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'NN'))
|
|
from NN.models.dqn_agent import DQNAgent
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class TradingExperience:
    """Represents a single trading experience for RL learning"""

    def __init__(self,
                 pre_trade_state: np.ndarray,
                 action: int,  # 0=SELL, 1=HOLD, 2=BUY
                 entry_price: float,
                 exit_price: float,
                 holding_time: float,  # seconds
                 pnl: float,
                 fees: float,
                 confidence: float,
                 market_conditions: Dict[str, Any],
                 timestamp: datetime):
        self.pre_trade_state = pre_trade_state
        self.action = action
        self.entry_price = entry_price
        self.exit_price = exit_price
        self.holding_time = holding_time
        self.pnl = pnl
        self.fees = fees
        self.confidence = confidence
        self.market_conditions = market_conditions
        self.timestamp = timestamp

        # Calculate reward
        self.reward = self._calculate_reward()

    def _calculate_reward(self) -> float:
        """Calculate reward for this trading experience"""
        # Net PnL after fees
        net_pnl = self.pnl - self.fees

        # Base reward from PnL (normalized by entry price)
        base_reward = net_pnl / self.entry_price

        # Time penalty - discourage long holding times
        time_penalty = 0.0
        if self.holding_time > 300:  # 5 minutes
            time_penalty = -0.001 * (self.holding_time / 60)  # -0.001 per minute held

        # Confidence bonus - reward high-confidence correct decisions
        confidence_bonus = 0.0
        if net_pnl > 0 and self.confidence > 0.7:
            confidence_bonus = 0.01 * self.confidence

        # Volume consideration (magnitude of the move, capped at 5%);
        # currently computed but not included in the total reward
        volume_factor = min(abs(base_reward) * 10, 0.05)

        total_reward = base_reward + time_penalty + confidence_bonus

        # Scale reward to reasonable range
        return np.tanh(total_reward * 100) * 10  # Scale and bound reward


class MarketStateBuilder:
    """Builds state representations for the RL agent from market data"""

    def __init__(self, state_size: int = 100):
        self.state_size = state_size
        self.price_history = deque(maxlen=50)
        self.volume_history = deque(maxlen=50)
        self.rsi_history = deque(maxlen=14)
        self.macd_history = deque(maxlen=26)

    def update_market_data(self, price: float, volume: float,
                           rsi: float = None, macd: float = None):
        """Update market data buffers"""
        self.price_history.append(price)
        self.volume_history.append(volume)
        if rsi is not None:
            self.rsi_history.append(rsi)
        if macd is not None:
            self.macd_history.append(macd)

    def build_state(self, current_position: str = 'NONE',
                    position_pnl: float = 0.0,
                    account_balance: float = 1000.0) -> np.ndarray:
        """Build state vector for RL agent"""
        state = np.zeros(self.state_size)

        try:
            # Price features (normalized returns)
            if len(self.price_history) >= 2:
                prices = np.array(list(self.price_history))
                returns = np.diff(prices) / prices[:-1]

                # Recent returns (last 20)
                recent_returns = returns[-20:] if len(returns) >= 20 else returns
                state[:len(recent_returns)] = recent_returns

                # Price momentum features
                state[20] = np.mean(returns[-5:]) if len(returns) >= 5 else 0  # 5-bar momentum
                state[21] = np.mean(returns[-10:]) if len(returns) >= 10 else 0  # 10-bar momentum
                state[22] = np.std(returns[-10:]) if len(returns) >= 10 else 0  # Volatility

            # Volume features
            if len(self.volume_history) >= 2:
                volumes = np.array(list(self.volume_history))
                volume_changes = np.diff(volumes) / volumes[:-1]
                recent_volume_changes = volume_changes[-10:] if len(volume_changes) >= 10 else volume_changes
                state[30:30 + len(recent_volume_changes)] = recent_volume_changes

                # Volume momentum
                state[40] = np.mean(volume_changes[-5:]) if len(volume_changes) >= 5 else 0

            # Technical indicators
            if len(self.rsi_history) >= 1:
                state[50] = (list(self.rsi_history)[-1] - 50) / 50  # Normalized RSI

            if len(self.macd_history) >= 2:
                macd_values = list(self.macd_history)
                state[51] = macd_values[-1] / 100  # Normalized MACD
                state[52] = (macd_values[-1] - macd_values[-2]) / 100  # MACD change

            # Position information
            position_encoding = {'NONE': 0, 'LONG': 1, 'SHORT': -1}
            state[60] = position_encoding.get(current_position, 0)
            state[61] = position_pnl / 100  # Normalized PnL
            state[62] = account_balance / 1000  # Normalized balance

            # Market regime features
            if len(self.price_history) >= 20:
                prices = np.array(list(self.price_history))

                # Trend strength
                state[70] = (prices[-1] - prices[-20]) / prices[-20]  # 20-bar trend

                # Market volatility regime
                returns = np.diff(prices) / prices[:-1]
                state[71] = np.std(returns[-20:]) if len(returns) >= 20 else 0

                # Support/resistance levels
                high_20 = np.max(prices[-20:])
                low_20 = np.min(prices[-20:])
                current_price = prices[-1]
                state[72] = (current_price - low_20) / (high_20 - low_20) if high_20 != low_20 else 0.5

        except Exception as e:
            logger.error(f"Error building state: {e}")

        return state


class RealTimeRLTrainer:
    """Real-time RL trainer that learns from live trading decisions"""

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize the real-time RL trainer"""
        self.config = config or {}

        # RL Agent configuration
        state_size = self.config.get('state_size', 100)
        action_size = 3  # BUY, HOLD, SELL

        # Initialize RL agent
        self.agent = DQNAgent(
            state_shape=(state_size,),
            n_actions=action_size,
            learning_rate=self.config.get('learning_rate', 0.0001),
            gamma=self.config.get('gamma', 0.95),
            epsilon=self.config.get('epsilon', 0.1),  # Low epsilon for live trading
            epsilon_min=0.05,
            epsilon_decay=0.999,
            buffer_size=self.config.get('buffer_size', 10000),
            batch_size=self.config.get('batch_size', 32)
        )

        # Market state builder
        self.state_builder = MarketStateBuilder(state_size)

        # Training data storage
        self.pending_trades = {}  # symbol -> trade info
        self.completed_experiences = deque(maxlen=1000)
        self.learning_history = []

        # Training controls
        self.training_enabled = self.config.get('training_enabled', True)
        self.min_experiences_for_training = self.config.get('min_experiences', 10)
        self.training_frequency = self.config.get('training_frequency', 5)  # Train every N experiences
        self.experience_count = 0

        # Model saving
        self.model_save_path = self.config.get('model_save_path', 'models/realtime_rl')
        self.save_frequency = self.config.get('save_frequency', 100)  # Save every N experiences

        # Performance tracking
        self.performance_history = []
        self.recent_rewards = deque(maxlen=100)
        self.trade_count = 0
        self.win_count = 0

        # Threading for async training (placeholders; no background thread is
        # started in this implementation)
        self.training_thread = None
        self.training_queue = deque()
        self.training_lock = threading.Lock()

        logger.info("Real-time RL trainer initialized")
        logger.info(f"State size: {state_size}, Action size: {action_size}")
        logger.info(f"Training enabled: {self.training_enabled}")

    def update_market_data(self, symbol: str, price: float, volume: float,
                           rsi: float = None, macd: float = None):
        """Update market data for state building.

        Note: a single shared state builder is used, so `symbol` is not
        currently used to keep separate per-symbol histories.
        """
        self.state_builder.update_market_data(price, volume, rsi, macd)

    def record_trade_signal(self, symbol: str, action: str, confidence: float,
                            current_price: float, position_info: Dict[str, Any] = None):
        """Record a trade signal for future learning"""
        try:
            # Build current state
            current_position = 'NONE'
            position_pnl = 0.0
            account_balance = 1000.0

            if position_info:
                current_position = position_info.get('side', 'NONE')
                position_pnl = position_info.get('unrealized_pnl', 0.0)
                account_balance = position_info.get('account_balance', 1000.0)

            state = self.state_builder.build_state(current_position, position_pnl, account_balance)

            # Convert action to numeric
            action_map = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
            action_num = action_map.get(action.upper(), 1)

            # Store pending trade
            trade_info = {
                'pre_trade_state': state.copy(),
                'action': action_num,
                'entry_price': current_price,
                'confidence': confidence,
                'entry_time': datetime.now(),
                'market_conditions': {
                    # Standard deviation of the last 10 raw prices (not returns)
                    'volatility': np.std(list(self.state_builder.price_history)[-10:]) if len(self.state_builder.price_history) >= 10 else 0,
                    'trend': state[70] if len(state) > 70 else 0,
                    'volume_trend': state[40] if len(state) > 40 else 0
                }
            }

            # Only BUY/SELL signals open a position worth learning from
            if action.upper() in ['BUY', 'SELL']:
                self.pending_trades[symbol] = trade_info
                logger.info(f"Recorded {action} signal for {symbol} at ${current_price:.2f} (confidence: {confidence:.2f})")

        except Exception as e:
            logger.error(f"Error recording trade signal: {e}")

    def record_position_closure(self, symbol: str, exit_price: float,
                                pnl: float, fees: float):
        """Record position closure and create learning experience"""
        try:
            if symbol not in self.pending_trades:
                logger.warning(f"No pending trade found for {symbol}")
                return

            trade_info = self.pending_trades.pop(symbol)

            # Calculate holding time
            holding_time = (datetime.now() - trade_info['entry_time']).total_seconds()

            # Create trading experience
            experience = TradingExperience(
                pre_trade_state=trade_info['pre_trade_state'],
                action=trade_info['action'],
                entry_price=trade_info['entry_price'],
                exit_price=exit_price,
                holding_time=holding_time,
                pnl=pnl,
                fees=fees,
                confidence=trade_info['confidence'],
                market_conditions=trade_info['market_conditions'],
                timestamp=datetime.now()
            )

            # Add to completed experiences
            self.completed_experiences.append(experience)
            self.recent_rewards.append(experience.reward)
            self.experience_count += 1
            self.trade_count += 1

            if experience.reward > 0:
                self.win_count += 1

            # Log the experience
            logger.info(f"Recorded experience: {symbol} PnL=${pnl:.4f} Reward={experience.reward:.4f} "
                        f"(Win rate: {self.win_count / self.trade_count * 100:.1f}%)")

            # Create next state (current market state after trade)
            current_state = self.state_builder.build_state('NONE', 0.0, 1000.0)

            # Store in agent memory for learning
            self.agent.remember(
                state=trade_info['pre_trade_state'],
                action=trade_info['action'],
                reward=experience.reward,
                next_state=current_state,
                done=True  # Each trade is a complete episode
            )

            # Trigger training if conditions are met
            if self.training_enabled:
                self._maybe_train()

            # Save model periodically
            if self.experience_count % self.save_frequency == 0:
                self._save_model()

        except Exception as e:
            logger.error(f"Error recording position closure: {e}")

    def _maybe_train(self):
        """Train the agent if conditions are met"""
        try:
            if (len(self.agent.memory) >= self.min_experiences_for_training and
                    self.experience_count % self.training_frequency == 0):

                # Perform training step
                loss = self.agent.replay()

                if loss is not None:
                    self.learning_history.append({
                        'timestamp': datetime.now().isoformat(),
                        'experience_count': self.experience_count,
                        'loss': loss,
                        'epsilon': self.agent.epsilon,
                        'avg_reward': np.mean(list(self.recent_rewards)) if self.recent_rewards else 0,
                        'win_rate': self.win_count / self.trade_count if self.trade_count > 0 else 0,
                        'memory_size': len(self.agent.memory)
                    })

                    logger.info(f"RL Training: Loss={loss:.4f}, Epsilon={self.agent.epsilon:.3f}, "
                                f"Avg Reward={np.mean(list(self.recent_rewards)):.4f}, "
                                f"Memory Size={len(self.agent.memory)}")

        except Exception as e:
            logger.error(f"Error in training: {e}")

    def get_action_prediction(self, symbol: str, current_position: str = 'NONE',
                              position_pnl: float = 0.0, account_balance: float = 1000.0) -> Tuple[str, float]:
        """Get action prediction from trained RL agent"""
        try:
            # Build current state
            state = self.state_builder.build_state(current_position, position_pnl, account_balance)

            # Get prediction from agent
            with torch.no_grad():
                q_values, _, _, _, _ = self.agent.policy_net(
                    torch.FloatTensor(state).unsqueeze(0).to(self.agent.device)
                )

            # Get action with highest Q-value
            action_idx = q_values.argmax().item()
            confidence = torch.softmax(q_values, dim=1).max().item()

            # Convert to action string
            action_map = {0: 'SELL', 1: 'HOLD', 2: 'BUY'}
            action = action_map[action_idx]

            return action, confidence

        except Exception as e:
            logger.error(f"Error getting action prediction: {e}")
            return 'HOLD', 0.5

    def get_training_stats(self) -> Dict[str, Any]:
        """Get current training statistics"""
        try:
            return {
                'total_experiences': self.experience_count,
                'total_trades': self.trade_count,
                'win_count': self.win_count,
                'win_rate': self.win_count / self.trade_count if self.trade_count > 0 else 0,
                'avg_reward': np.mean(list(self.recent_rewards)) if self.recent_rewards else 0,
                'memory_size': len(self.agent.memory),
                'epsilon': self.agent.epsilon,
                'recent_loss': self.learning_history[-1]['loss'] if self.learning_history else 0,
                'training_enabled': self.training_enabled,
                'pending_trades': len(self.pending_trades)
            }

        except Exception as e:
            logger.error(f"Error getting training stats: {e}")
            return {}

    def _save_model(self):
        """Save the trained model"""
        try:
            os.makedirs(self.model_save_path, exist_ok=True)

            # Save RL agent
            self.agent.save(os.path.join(self.model_save_path, 'rl_agent'))

            # Save training history
            history_path = os.path.join(self.model_save_path, 'training_history.json')
            with open(history_path, 'w') as f:
                json.dump(self.learning_history, f, indent=2)

            # Save performance stats
            stats_path = os.path.join(self.model_save_path, 'performance_stats.json')
            with open(stats_path, 'w') as f:
                json.dump(self.get_training_stats(), f, indent=2)

            logger.info(f"Saved RL model and training data to {self.model_save_path}")

        except Exception as e:
            logger.error(f"Error saving model: {e}")

    def load_model(self):
        """Load a previously saved model"""
        try:
            model_path = os.path.join(self.model_save_path, 'rl_agent')
            if os.path.exists(f"{model_path}_policy_model.pt"):
                self.agent.load(model_path)
                logger.info(f"Loaded RL model from {model_path}")

                # Load training history if available
                history_path = os.path.join(self.model_save_path, 'training_history.json')
                if os.path.exists(history_path):
                    with open(history_path, 'r') as f:
                        self.learning_history = json.load(f)

                return True
            else:
                logger.info("No saved model found, starting with fresh model")
                return False

        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False

    def enable_training(self, enabled: bool = True):
        """Enable or disable training"""
        self.training_enabled = enabled
        logger.info(f"RL training {'enabled' if enabled else 'disabled'}")

    def reset_performance_stats(self):
        """Reset performance tracking statistics"""
        self.trade_count = 0
        self.win_count = 0
        self.recent_rewards.clear()
        logger.info("Reset RL performance statistics")