RL trainer

Dobromir Popov
2025-05-28 13:20:15 +03:00
parent d6a71c2b1a
commit a6eaa01735
8 changed files with 1476 additions and 132 deletions

core/realtime_rl_trainer.py (new file, +469 lines)

@@ -0,0 +1,469 @@
"""
Real-Time RL Training System
This module implements continuous learning from live trading decisions.
The RL agent learns from every trade signal and position closure to improve
decision-making over time.
"""
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
import threading
import time
import json
import os
# Import existing DQN agent
import sys
# Add the project root (the directory containing both core/ and NN/) so the NN package imports below resolve
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from NN.models.dqn_agent import DQNAgent
logger = logging.getLogger(__name__)
class TradingExperience:
"""Represents a single trading experience for RL learning"""
def __init__(self,
pre_trade_state: np.ndarray,
action: int, # 0=SELL, 1=HOLD, 2=BUY
entry_price: float,
exit_price: float,
holding_time: float, # seconds
pnl: float,
fees: float,
confidence: float,
market_conditions: Dict[str, Any],
timestamp: datetime):
self.pre_trade_state = pre_trade_state
self.action = action
self.entry_price = entry_price
self.exit_price = exit_price
self.holding_time = holding_time
self.pnl = pnl
self.fees = fees
self.confidence = confidence
self.market_conditions = market_conditions
self.timestamp = timestamp
# Calculate reward
self.reward = self._calculate_reward()
def _calculate_reward(self) -> float:
"""Calculate reward for this trading experience"""
# Net PnL after fees
net_pnl = self.pnl - self.fees
# Base reward from PnL (normalized by entry price)
base_reward = net_pnl / self.entry_price
# Time penalty - prefer faster profitable trades
time_penalty = 0.0
if self.holding_time > 300: # 5 minutes
time_penalty = -0.001 * (self.holding_time / 60) # -0.001 per minute
# Confidence bonus - reward high-confidence correct decisions
confidence_bonus = 0.0
if net_pnl > 0 and self.confidence > 0.7:
confidence_bonus = 0.01 * self.confidence
# Volume consideration (prefer trades that move significant amounts)
volume_factor = min(abs(base_reward) * 10, 0.05) # Cap at 5%; computed for reference but not yet added to the reward
total_reward = base_reward + time_penalty + confidence_bonus
# Scale reward to reasonable range
return np.tanh(total_reward * 100) * 10 # Scale and bound reward
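To make the reward shaping above concrete, here is a small worked example (hypothetical trade numbers, not taken from the repository) that simply re-applies the formula in _calculate_reward:

import numpy as np

# Hypothetical trade: entry $3,000, exit $3,006, $0.50 fees, held 8 minutes, confidence 0.8
entry_price, pnl, fees = 3000.0, 6.0, 0.5
holding_time, confidence = 480.0, 0.8

net_pnl = pnl - fees                                                                # 5.5
base_reward = net_pnl / entry_price                                                 # ~0.00183
time_penalty = -0.001 * (holding_time / 60) if holding_time > 300 else 0.0          # -0.008
confidence_bonus = 0.01 * confidence if net_pnl > 0 and confidence > 0.7 else 0.0   # 0.008
total_reward = base_reward + time_penalty + confidence_bonus                        # ~0.00183
reward = np.tanh(total_reward * 100) * 10                                           # ~1.81, bounded to (-10, 10)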
class MarketStateBuilder:
"""Builds state representations for RL agent from market data"""
def __init__(self, state_size: int = 100):
self.state_size = state_size
self.price_history = deque(maxlen=50)
self.volume_history = deque(maxlen=50)
self.rsi_history = deque(maxlen=14)
self.macd_history = deque(maxlen=26)
def update_market_data(self, price: float, volume: float,
rsi: Optional[float] = None, macd: Optional[float] = None):
"""Update market data buffers"""
self.price_history.append(price)
self.volume_history.append(volume)
if rsi is not None:
self.rsi_history.append(rsi)
if macd is not None:
self.macd_history.append(macd)
def build_state(self, current_position: str = 'NONE',
position_pnl: float = 0.0,
account_balance: float = 1000.0) -> np.ndarray:
"""Build state vector for RL agent"""
state = np.zeros(self.state_size)
try:
# Price features (normalized returns)
if len(self.price_history) >= 2:
prices = np.array(list(self.price_history))
returns = np.diff(prices) / prices[:-1]
# Recent returns (last 20)
recent_returns = returns[-20:] if len(returns) >= 20 else returns
state[:len(recent_returns)] = recent_returns
# Price momentum features
state[20] = np.mean(returns[-5:]) if len(returns) >= 5 else 0 # 5-bar momentum
state[21] = np.mean(returns[-10:]) if len(returns) >= 10 else 0 # 10-bar momentum
state[22] = np.std(returns[-10:]) if len(returns) >= 10 else 0 # Volatility
# Volume features
if len(self.volume_history) >= 2:
volumes = np.array(list(self.volume_history))
volume_changes = np.diff(volumes) / volumes[:-1]
recent_volume_changes = volume_changes[-10:] if len(volume_changes) >= 10 else volume_changes
state[30:30+len(recent_volume_changes)] = recent_volume_changes
# Volume momentum
state[40] = np.mean(volume_changes[-5:]) if len(volume_changes) >= 5 else 0
# Technical indicators
if len(self.rsi_history) >= 1:
state[50] = (list(self.rsi_history)[-1] - 50) / 50 # Normalized RSI
if len(self.macd_history) >= 2:
macd_values = list(self.macd_history)
state[51] = macd_values[-1] / 100 # Normalized MACD
state[52] = (macd_values[-1] - macd_values[-2]) / 100 # MACD change
# Position information
position_encoding = {'NONE': 0, 'LONG': 1, 'SHORT': -1}
state[60] = position_encoding.get(current_position, 0)
state[61] = position_pnl / 100 # Normalized PnL
state[62] = account_balance / 1000 # Normalized balance
# Market regime features
if len(self.price_history) >= 20:
prices = np.array(list(self.price_history))
# Trend strength
state[70] = (prices[-1] - prices[-20]) / prices[-20] # 20-bar trend
# Market volatility regime
returns = np.diff(prices) / prices[:-1]
state[71] = np.std(returns[-20:]) if len(returns) >= 20 else 0
# Support/resistance levels
high_20 = np.max(prices[-20:])
low_20 = np.min(prices[-20:])
current_price = prices[-1]
state[72] = (current_price - low_20) / (high_20 - low_20) if high_20 != low_20 else 0.5
except Exception as e:
logger.error(f"Error building state: {e}")
return state
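For orientation, build_state fills a sparse, fixed-index 100-dimensional vector (unused slots stay at zero). Below is a minimal usage sketch of the builder with illustrative prices and indicator values:

# Index layout used by build_state:
#   0-19  recent price returns          30-39  recent volume changes    40  volume momentum
#   20-22 momentum and volatility       50-52  RSI / MACD features
#   60-62 position side, PnL, balance   70-72  trend, volatility regime, range position

builder = MarketStateBuilder(state_size=100)
for price, volume in [(3000.0, 12.5), (3001.5, 9.8), (2999.0, 15.2)]:
    builder.update_market_data(price, volume, rsi=55.0, macd=1.2)

state = builder.build_state(current_position='LONG', position_pnl=4.2, account_balance=1000.0)
assert state.shape == (100,)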
class RealTimeRLTrainer:
"""Real-time RL trainer that learns from live trading decisions"""
def __init__(self, config: Dict[str, Any] = None):
"""Initialize the real-time RL trainer"""
self.config = config or {}
# RL Agent configuration
state_size = self.config.get('state_size', 100)
action_size = 3 # 3 discrete actions: 0=SELL, 1=HOLD, 2=BUY
# Initialize RL agent
self.agent = DQNAgent(
state_shape=(state_size,),
n_actions=action_size,
learning_rate=self.config.get('learning_rate', 0.0001),
gamma=self.config.get('gamma', 0.95),
epsilon=self.config.get('epsilon', 0.1), # Low epsilon for live trading
epsilon_min=0.05,
epsilon_decay=0.999,
buffer_size=self.config.get('buffer_size', 10000),
batch_size=self.config.get('batch_size', 32)
)
# Market state builder
self.state_builder = MarketStateBuilder(state_size)
# Training data storage
self.pending_trades = {} # symbol -> trade info
self.completed_experiences = deque(maxlen=1000)
self.learning_history = []
# Training controls
self.training_enabled = self.config.get('training_enabled', True)
self.min_experiences_for_training = self.config.get('min_experiences', 10)
self.training_frequency = self.config.get('training_frequency', 5) # Train every N experiences
self.experience_count = 0
# Model saving
self.model_save_path = self.config.get('model_save_path', 'models/realtime_rl')
self.save_frequency = self.config.get('save_frequency', 100) # Save every N experiences
# Performance tracking
self.performance_history = []
self.recent_rewards = deque(maxlen=100)
self.trade_count = 0
self.win_count = 0
# Threading primitives for async training (reserved; training currently runs synchronously via _maybe_train)
self.training_thread = None
self.training_queue = deque()
self.training_lock = threading.Lock()
logger.info("Real-time RL trainer initialized")
logger.info(f"State size: {state_size}, Action size: {action_size}")
logger.info(f"Training enabled: {self.training_enabled}")
def update_market_data(self, symbol: str, price: float, volume: float,
rsi: Optional[float] = None, macd: Optional[float] = None):
"""Update market data for state building"""
self.state_builder.update_market_data(price, volume, rsi, macd)
def record_trade_signal(self, symbol: str, action: str, confidence: float,
current_price: float, position_info: Dict[str, Any] = None):
"""Record a trade signal for future learning"""
try:
# Build current state
current_position = 'NONE'
position_pnl = 0.0
account_balance = 1000.0
if position_info:
current_position = position_info.get('side', 'NONE')
position_pnl = position_info.get('unrealized_pnl', 0.0)
account_balance = position_info.get('account_balance', 1000.0)
state = self.state_builder.build_state(current_position, position_pnl, account_balance)
# Convert action to numeric
action_map = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
action_num = action_map.get(action.upper(), 1)
# Store pending trade
trade_info = {
'pre_trade_state': state.copy(),
'action': action_num,
'entry_price': current_price,
'confidence': confidence,
'entry_time': datetime.now(),
'market_conditions': {
'volatility': np.std(list(self.state_builder.price_history)[-10:]) if len(self.state_builder.price_history) >= 10 else 0,
'trend': state[70] if len(state) > 70 else 0,
'volume_trend': state[40] if len(state) > 40 else 0
}
}
if action.upper() in ['BUY', 'SELL']:
self.pending_trades[symbol] = trade_info
logger.info(f"Recorded {action} signal for {symbol} at ${current_price:.2f} (confidence: {confidence:.2f})")
except Exception as e:
logger.error(f"Error recording trade signal: {e}")
def record_position_closure(self, symbol: str, exit_price: float,
pnl: float, fees: float):
"""Record position closure and create learning experience"""
try:
if symbol not in self.pending_trades:
logger.warning(f"No pending trade found for {symbol}")
return
trade_info = self.pending_trades.pop(symbol)
# Calculate holding time
holding_time = (datetime.now() - trade_info['entry_time']).total_seconds()
# Create trading experience
experience = TradingExperience(
pre_trade_state=trade_info['pre_trade_state'],
action=trade_info['action'],
entry_price=trade_info['entry_price'],
exit_price=exit_price,
holding_time=holding_time,
pnl=pnl,
fees=fees,
confidence=trade_info['confidence'],
market_conditions=trade_info['market_conditions'],
timestamp=datetime.now()
)
# Add to completed experiences
self.completed_experiences.append(experience)
self.recent_rewards.append(experience.reward)
self.experience_count += 1
self.trade_count += 1
if experience.reward > 0:
self.win_count += 1
# Log the experience
logger.info(f"Recorded experience: {symbol} PnL=${pnl:.4f} Reward={experience.reward:.4f} "
f"(Win rate: {self.win_count/self.trade_count*100:.1f}%)")
# Create next state (current market state after trade)
current_state = self.state_builder.build_state('NONE', 0.0, 1000.0)
# Store in agent memory for learning
self.agent.remember(
state=trade_info['pre_trade_state'],
action=trade_info['action'],
reward=experience.reward,
next_state=current_state,
done=True # Each trade is a complete episode
)
# Trigger training if conditions are met
if self.training_enabled:
self._maybe_train()
# Save model periodically
if self.experience_count % self.save_frequency == 0:
self._save_model()
except Exception as e:
logger.error(f"Error recording position closure: {e}")
def _maybe_train(self):
"""Train the agent if conditions are met"""
try:
if (len(self.agent.memory) >= self.min_experiences_for_training and
self.experience_count % self.training_frequency == 0):
# Perform training step
loss = self.agent.replay()
if loss is not None:
self.learning_history.append({
'timestamp': datetime.now().isoformat(),
'experience_count': self.experience_count,
'loss': loss,
'epsilon': self.agent.epsilon,
'avg_reward': np.mean(list(self.recent_rewards)) if self.recent_rewards else 0,
'win_rate': self.win_count / self.trade_count if self.trade_count > 0 else 0,
'memory_size': len(self.agent.memory)
})
logger.info(f"RL Training: Loss={loss:.4f}, Epsilon={self.agent.epsilon:.3f}, "
f"Avg Reward={np.mean(list(self.recent_rewards)):.4f}, "
f"Memory Size={len(self.agent.memory)}")
except Exception as e:
logger.error(f"Error in training: {e}")
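DQNAgent.replay() lives in NN/models/dqn_agent.py and is not part of this diff. For readers unfamiliar with what _maybe_train is invoking, a generic DQN replay step looks roughly like the sketch below (standard PyTorch; the names and the single-output network are illustrative assumptions, not the repository's implementation):

import random
import numpy as np
import torch
import torch.nn.functional as F

def replay_step(policy_net, target_net, optimizer, memory, batch_size=32, gamma=0.95):
    """One generic DQN update over a random minibatch (illustrative only)."""
    if len(memory) < batch_size:
        return None
    batch = random.sample(memory, batch_size)
    states, actions, rewards, next_states, dones = [
        torch.as_tensor(np.array(x), dtype=torch.float32) for x in zip(*batch)]
    q = policy_net(states).gather(1, actions.long().unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        target = rewards + gamma * target_net(next_states).max(1).values * (1 - dones)
    loss = F.smooth_l1_loss(q, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()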
def get_action_prediction(self, symbol: str, current_position: str = 'NONE',
position_pnl: float = 0.0, account_balance: float = 1000.0) -> Tuple[str, float]:
"""Get action prediction from trained RL agent"""
try:
# Build current state
state = self.state_builder.build_state(current_position, position_pnl, account_balance)
# Get prediction from agent
with torch.no_grad():
q_values, _, _, _, _ = self.agent.policy_net(
torch.FloatTensor(state).unsqueeze(0).to(self.agent.device)
)
# Get action with highest Q-value
action_idx = q_values.argmax().item()
confidence = torch.softmax(q_values, dim=1).max().item()
# Convert to action string
action_map = {0: 'SELL', 1: 'HOLD', 2: 'BUY'}
action = action_map[action_idx]
return action, confidence
except Exception as e:
logger.error(f"Error getting action prediction: {e}")
return 'HOLD', 0.5
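Note that the confidence above is the maximum softmax probability over raw Q-values, so it is sensitive to the overall Q-value scale: two states with the same action ranking can report quite different confidences. A standalone illustration:

import torch

q_flat = torch.tensor([[0.10, 0.05, 0.12]])    # similar Q-values -> confidence ~0.34
q_spread = torch.tensor([[1.00, 0.50, 1.20]])  # same ordering, larger gaps -> confidence ~0.43

for q in (q_flat, q_spread):
    print(int(q.argmax()), float(torch.softmax(q, dim=1).max()))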
def get_training_stats(self) -> Dict[str, Any]:
"""Get current training statistics"""
try:
return {
'total_experiences': self.experience_count,
'total_trades': self.trade_count,
'win_count': self.win_count,
'win_rate': self.win_count / self.trade_count if self.trade_count > 0 else 0,
'avg_reward': np.mean(list(self.recent_rewards)) if self.recent_rewards else 0,
'memory_size': len(self.agent.memory),
'epsilon': self.agent.epsilon,
'recent_loss': self.learning_history[-1]['loss'] if self.learning_history else 0,
'training_enabled': self.training_enabled,
'pending_trades': len(self.pending_trades)
}
except Exception as e:
logger.error(f"Error getting training stats: {e}")
return {}
def _save_model(self):
"""Save the trained model"""
try:
os.makedirs(self.model_save_path, exist_ok=True)
# Save RL agent
self.agent.save(os.path.join(self.model_save_path, 'rl_agent'))
# Save training history
history_path = os.path.join(self.model_save_path, 'training_history.json')
with open(history_path, 'w') as f:
json.dump(self.learning_history, f, indent=2)
# Save performance stats
stats_path = os.path.join(self.model_save_path, 'performance_stats.json')
with open(stats_path, 'w') as f:
json.dump(self.get_training_stats(), f, indent=2)
logger.info(f"Saved RL model and training data to {self.model_save_path}")
except Exception as e:
logger.error(f"Error saving model: {e}")
def load_model(self):
"""Load a previously saved model"""
try:
model_path = os.path.join(self.model_save_path, 'rl_agent')
if os.path.exists(f"{model_path}_policy_model.pt"):
self.agent.load(model_path)
logger.info(f"Loaded RL model from {model_path}")
# Load training history if available
history_path = os.path.join(self.model_save_path, 'training_history.json')
if os.path.exists(history_path):
with open(history_path, 'r') as f:
self.learning_history = json.load(f)
return True
else:
logger.info("No saved model found, starting with fresh model")
return False
except Exception as e:
logger.error(f"Error loading model: {e}")
return False
def enable_training(self, enabled: bool = True):
"""Enable or disable training"""
self.training_enabled = enabled
logger.info(f"RL training {'enabled' if enabled else 'disabled'}")
def reset_performance_stats(self):
"""Reset performance tracking statistics"""
self.trade_count = 0
self.win_count = 0
self.recent_rewards.clear()
logger.info("Reset RL performance statistics")
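Putting the new module together, the intended lifecycle is: stream market ticks into the state builder, record a signal when a position opens, record the closure when it exits (which stores the transition and triggers training and saving on their configured schedules), and query the agent as needed. The executor diff that follows wires exactly this flow into live trade handling. A condensed usage sketch; the symbol and prices are placeholders:

from core.realtime_rl_trainer import RealTimeRLTrainer

trainer = RealTimeRLTrainer({'training_enabled': True, 'min_experiences': 10,
                             'training_frequency': 3, 'model_save_path': 'models/realtime_rl'})
trainer.load_model()  # False on first run; starts with a fresh model

# 1. Stream market data
trainer.update_market_data('ETH/USDT', price=3000.0, volume=12.5, rsi=55.0, macd=1.2)

# 2. A BUY/SELL signal opens a position and is remembered as a pending trade
trainer.record_trade_signal('ETH/USDT', action='BUY', confidence=0.82, current_price=3000.0)

# 3. Closing the position creates the learning experience
trainer.record_position_closure('ETH/USDT', exit_price=3006.0, pnl=6.0, fees=0.5)

# 4. The agent can also be queried directly for a suggestion
action, confidence = trainer.get_action_prediction('ETH/USDT')
print(action, confidence, trainer.get_training_stats())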


@@ -9,7 +9,7 @@ import logging
import time
import os
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from threading import Lock
import sys
@@ -20,6 +20,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'NN'))
from NN.exchanges import MEXCInterface
from .config import get_config
from .config_sync import ConfigSynchronizer
from .realtime_rl_trainer import RealTimeRLTrainer
logger = logging.getLogger(__name__)
@@ -119,6 +120,29 @@ class TradingExecutor:
mexc_interface=self.exchange if self.trading_enabled else None
)
# Initialize real-time RL trainer for continuous learning
rl_config = {
'state_size': 100,
'learning_rate': 0.0001,
'gamma': 0.95,
'epsilon': 0.1, # Low exploration for live trading
'buffer_size': 10000,
'batch_size': 32,
'training_enabled': self.mexc_config.get('rl_learning_enabled', True),
'min_experiences': 10,
'training_frequency': 3, # Train every 3 trades
'save_frequency': 50, # Save every 50 trades
'model_save_path': 'models/realtime_rl'
}
self.rl_trainer = RealTimeRLTrainer(rl_config)
# Try to load existing RL model
if self.rl_trainer.load_model():
logger.info("TRADING EXECUTOR: Loaded existing RL model for continuous learning")
else:
logger.info("TRADING EXECUTOR: Starting with fresh RL model")
# Perform initial fee sync on startup if trading is enabled
if self.trading_enabled and self.exchange:
try:
@@ -189,6 +213,29 @@ class TradingExecutor:
return False
current_price = ticker['last']
# Update RL trainer with market data (volume here is a crude price-derived placeholder, not real traded volume)
estimated_volume = abs(current_price) * 1000 # Rough volume proxy; replace with exchange-reported volume when available
self.rl_trainer.update_market_data(symbol, current_price, estimated_volume)
# Get position info for RL trainer
position_info = None
if symbol in self.positions:
position = self.positions[symbol]
position_info = {
'side': position.side,
'unrealized_pnl': position.unrealized_pnl,
'account_balance': 1000.0 # Could get from exchange
}
# Record trade signal with RL trainer for learning
self.rl_trainer.record_trade_signal(
symbol=symbol,
action=action,
confidence=confidence,
current_price=current_price,
position_info=position_info
)
with self.lock:
try:
if action == 'BUY':
@@ -348,6 +395,14 @@ class TradingExecutor:
self.trade_history.append(trade_record)
self.daily_loss += max(0, -pnl) # Add to daily loss if negative
# Record position closure with RL trainer for learning
self.rl_trainer.record_position_closure(
symbol=symbol,
exit_price=current_price,
pnl=pnl,
fees=0.0 # No fees in simulation
)
# Remove position
del self.positions[symbol]
self.last_trade_time[symbol] = datetime.now()
@@ -397,6 +452,14 @@ class TradingExecutor:
self.trade_history.append(trade_record)
self.daily_loss += max(0, -(pnl - fees)) # Add to daily loss if negative
# Record position closure with RL trainer for learning
self.rl_trainer.record_position_closure(
symbol=symbol,
exit_price=current_price,
pnl=pnl,
fees=fees
)
# Remove position
del self.positions[symbol]
self.last_trade_time[symbol] = datetime.now()
@@ -464,6 +527,9 @@ class TradingExecutor:
effective_fee_rate = (total_fees / max(0.01, total_volume)) if total_volume > 0 else 0
fee_impact_on_pnl = (total_fees / max(0.01, abs(gross_pnl))) * 100 if gross_pnl != 0 else 0
# Get RL training statistics
rl_stats = self.rl_trainer.get_training_stats() if hasattr(self, 'rl_trainer') else {}
return {
'daily_trades': self.daily_trades,
'daily_loss': self.daily_loss,
@@ -490,6 +556,15 @@ class TradingExecutor:
'fee_impact_percent': fee_impact_on_pnl,
'is_fee_efficient': fee_impact_on_pnl < 5.0, # Less than 5% impact is good
'fee_savings_vs_market': (0.001 - effective_fee_rate) * total_volume if effective_fee_rate < 0.001 else 0
},
'rl_learning': {
'enabled': rl_stats.get('training_enabled', False),
'total_experiences': rl_stats.get('total_experiences', 0),
'rl_win_rate': rl_stats.get('win_rate', 0),
'avg_reward': rl_stats.get('avg_reward', 0),
'memory_size': rl_stats.get('memory_size', 0),
'epsilon': rl_stats.get('epsilon', 0),
'pending_trades': rl_stats.get('pending_trades', 0)
}
}
@@ -803,3 +878,71 @@ class TradingExecutor:
'sync_available': False,
'error': str(e)
}
def get_rl_prediction(self, symbol: str) -> Tuple[str, float]:
"""Get RL agent prediction for the current market state
Args:
symbol: Trading symbol
Returns:
tuple: (action, confidence) where action is BUY/SELL/HOLD
"""
if not hasattr(self, 'rl_trainer'):
return 'HOLD', 0.5
try:
# Get current position info
current_position = 'NONE'
position_pnl = 0.0
account_balance = 1000.0
if symbol in self.positions:
position = self.positions[symbol]
current_position = position.side
position_pnl = position.unrealized_pnl
# Get RL prediction
action, confidence = self.rl_trainer.get_action_prediction(
symbol=symbol,
current_position=current_position,
position_pnl=position_pnl,
account_balance=account_balance
)
return action, confidence
except Exception as e:
logger.error(f"TRADING EXECUTOR: Error getting RL prediction: {e}")
return 'HOLD', 0.5
def enable_rl_training(self, enabled: bool = True):
"""Enable or disable real-time RL training
Args:
enabled: Whether to enable RL training
"""
if hasattr(self, 'rl_trainer'):
self.rl_trainer.enable_training(enabled)
logger.info(f"TRADING EXECUTOR: RL training {'enabled' if enabled else 'disabled'}")
else:
logger.warning("TRADING EXECUTOR: RL trainer not initialized")
def get_rl_training_stats(self) -> Dict[str, Any]:
"""Get comprehensive RL training statistics
Returns:
dict: RL training statistics and performance metrics
"""
if hasattr(self, 'rl_trainer'):
return self.rl_trainer.get_training_stats()
else:
return {'error': 'RL trainer not initialized'}
def save_rl_model(self):
"""Manually save the current RL model"""
if hasattr(self, 'rl_trainer'):
self.rl_trainer._save_model()
logger.info("TRADING EXECUTOR: RL model saved manually")
else:
logger.warning("TRADING EXECUTOR: RL trainer not initialized")
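The new public hooks on TradingExecutor (get_rl_prediction, enable_rl_training, get_rl_training_stats, save_rl_model) can also be exercised from monitoring or dashboard code. A small sketch, assuming an already-constructed TradingExecutor instance is passed in (its constructor is outside this diff):

from typing import Any, Dict

def review_rl_state(executor: Any, symbol: str = 'ETH/USDT') -> Dict[str, Any]:
    """Summarize the executor's RL learning state using the hooks added in this commit."""
    action, confidence = executor.get_rl_prediction(symbol)
    if confidence < 0.6:   # example policy: ignore low-confidence RL suggestions
        action = 'HOLD'
    stats = executor.get_rl_training_stats()
    return {'suggested_action': action, 'confidence': confidence,
            'win_rate': stats.get('win_rate', 0), 'epsilon': stats.get('epsilon', 0),
            'memory_size': stats.get('memory_size', 0)}

# Pausing learning during unusual market conditions and snapshotting on demand:
#   executor.enable_rl_training(False)
#   executor.save_rl_model()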