Add a real-time RL trainer and wire it into the trading executor
core/realtime_rl_trainer.py (new file, 469 lines)
@@ -0,0 +1,469 @@
"""
Real-Time RL Training System

This module implements continuous learning from live trading decisions.
The RL agent learns from every trade signal and position closure to improve
decision-making over time.
"""

import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
import threading
import time
import json
import os

# Import existing DQN agent
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'NN'))
from NN.models.dqn_agent import DQNAgent

logger = logging.getLogger(__name__)

class TradingExperience:
    """Represents a single trading experience for RL learning"""

    def __init__(self,
                 pre_trade_state: np.ndarray,
                 action: int,  # 0=SELL, 1=HOLD, 2=BUY
                 entry_price: float,
                 exit_price: float,
                 holding_time: float,  # seconds
                 pnl: float,
                 fees: float,
                 confidence: float,
                 market_conditions: Dict[str, Any],
                 timestamp: datetime):
        self.pre_trade_state = pre_trade_state
        self.action = action
        self.entry_price = entry_price
        self.exit_price = exit_price
        self.holding_time = holding_time
        self.pnl = pnl
        self.fees = fees
        self.confidence = confidence
        self.market_conditions = market_conditions
        self.timestamp = timestamp

        # Calculate reward
        self.reward = self._calculate_reward()

    def _calculate_reward(self) -> float:
        """Calculate reward for this trading experience"""
        # Net PnL after fees
        net_pnl = self.pnl - self.fees

        # Base reward from PnL (normalized by entry price)
        base_reward = net_pnl / self.entry_price

        # Time penalty - prefer faster profitable trades
        time_penalty = 0.0
        if self.holding_time > 300:  # 5 minutes
            time_penalty = -0.001 * (self.holding_time / 60)  # -0.001 per minute

        # Confidence bonus - reward high-confidence correct decisions
        confidence_bonus = 0.0
        if net_pnl > 0 and self.confidence > 0.7:
            confidence_bonus = 0.01 * self.confidence

        # Volume consideration (prefer trades that move significant amounts)
        volume_factor = min(abs(base_reward) * 10, 0.05)  # Cap at 5% (note: computed but not yet folded into total_reward)

        total_reward = base_reward + time_penalty + confidence_bonus

        # Scale reward to reasonable range
        return np.tanh(total_reward * 100) * 10  # Scale and bound reward


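To make the reward shaping above concrete, here is a small worked example (an editorial illustration with invented trade numbers, not part of the committed file; np is numpy as imported in the module):

# A BUY filled at 3000.0 exits at 3003.0 after 120 s with 0.6 in fees and confidence 0.8
net_pnl = 3.0 - 0.6                       # 2.4
base_reward = net_pnl / 3000.0            # 0.0008
time_penalty = 0.0                        # holding_time 120 s <= 300 s, so no penalty
confidence_bonus = 0.01 * 0.8             # 0.008 (profitable and confidence > 0.7)
total_reward = base_reward + time_penalty + confidence_bonus    # 0.0088
reward = float(np.tanh(total_reward * 100) * 10)                # tanh(0.88) * 10 ≈ 7.06
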
class MarketStateBuilder:
    """Builds state representations for RL agent from market data"""

    def __init__(self, state_size: int = 100):
        self.state_size = state_size
        self.price_history = deque(maxlen=50)
        self.volume_history = deque(maxlen=50)
        self.rsi_history = deque(maxlen=14)
        self.macd_history = deque(maxlen=26)

    def update_market_data(self, price: float, volume: float,
                           rsi: float = None, macd: float = None):
        """Update market data buffers"""
        self.price_history.append(price)
        self.volume_history.append(volume)
        if rsi is not None:
            self.rsi_history.append(rsi)
        if macd is not None:
            self.macd_history.append(macd)

    def build_state(self, current_position: str = 'NONE',
                    position_pnl: float = 0.0,
                    account_balance: float = 1000.0) -> np.ndarray:
        """Build state vector for RL agent"""
        state = np.zeros(self.state_size)

        try:
            # Price features (normalized returns)
            if len(self.price_history) >= 2:
                prices = np.array(list(self.price_history))
                returns = np.diff(prices) / prices[:-1]

                # Recent returns (last 20)
                recent_returns = returns[-20:] if len(returns) >= 20 else returns
                state[:len(recent_returns)] = recent_returns

                # Price momentum features
                state[20] = np.mean(returns[-5:]) if len(returns) >= 5 else 0    # 5-bar momentum
                state[21] = np.mean(returns[-10:]) if len(returns) >= 10 else 0  # 10-bar momentum
                state[22] = np.std(returns[-10:]) if len(returns) >= 10 else 0   # Volatility

            # Volume features
            if len(self.volume_history) >= 2:
                volumes = np.array(list(self.volume_history))
                volume_changes = np.diff(volumes) / volumes[:-1]
                recent_volume_changes = volume_changes[-10:] if len(volume_changes) >= 10 else volume_changes
                state[30:30+len(recent_volume_changes)] = recent_volume_changes

                # Volume momentum
                state[40] = np.mean(volume_changes[-5:]) if len(volume_changes) >= 5 else 0

            # Technical indicators
            if len(self.rsi_history) >= 1:
                state[50] = (list(self.rsi_history)[-1] - 50) / 50  # Normalized RSI

            if len(self.macd_history) >= 2:
                macd_values = list(self.macd_history)
                state[51] = macd_values[-1] / 100  # Normalized MACD
                state[52] = (macd_values[-1] - macd_values[-2]) / 100  # MACD change

            # Position information
            position_encoding = {'NONE': 0, 'LONG': 1, 'SHORT': -1}
            state[60] = position_encoding.get(current_position, 0)
            state[61] = position_pnl / 100  # Normalized PnL
            state[62] = account_balance / 1000  # Normalized balance

            # Market regime features
            if len(self.price_history) >= 20:
                prices = np.array(list(self.price_history))

                # Trend strength
                state[70] = (prices[-1] - prices[-20]) / prices[-20]  # 20-bar trend

                # Market volatility regime
                returns = np.diff(prices) / prices[:-1]
                state[71] = np.std(returns[-20:]) if len(returns) >= 20 else 0

                # Support/resistance levels
                high_20 = np.max(prices[-20:])
                low_20 = np.min(prices[-20:])
                current_price = prices[-1]
                state[72] = (current_price - low_20) / (high_20 - low_20) if high_20 != low_20 else 0.5

        except Exception as e:
            logger.error(f"Error building state: {e}")

        return state


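A quick illustration of how the builder is meant to be fed and how the 100-slot state vector is laid out (editorial sketch, not part of the committed file; all values are invented):

builder = MarketStateBuilder(state_size=100)
for price, volume in [(3000.0, 120.0), (3001.5, 98.0), (2999.0, 150.0), (3002.2, 110.0)]:
    builder.update_market_data(price, volume, rsi=55.0, macd=1.2)

state = builder.build_state(current_position='LONG', position_pnl=2.5, account_balance=1000.0)
# state[0:20]  recent normalized returns          state[20:23] momentum and volatility
# state[30:41] volume changes and volume momentum state[50:53] normalized RSI / MACD features
# state[60:63] position side, PnL, balance        state[70:73] trend, vol regime, range position (needs 20+ prices)
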
class RealTimeRLTrainer:
    """Real-time RL trainer that learns from live trading decisions"""

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize the real-time RL trainer"""
        self.config = config or {}

        # RL Agent configuration
        state_size = self.config.get('state_size', 100)
        action_size = 3  # BUY, HOLD, SELL

        # Initialize RL agent
        self.agent = DQNAgent(
            state_shape=(state_size,),
            n_actions=action_size,
            learning_rate=self.config.get('learning_rate', 0.0001),
            gamma=self.config.get('gamma', 0.95),
            epsilon=self.config.get('epsilon', 0.1),  # Low epsilon for live trading
            epsilon_min=0.05,
            epsilon_decay=0.999,
            buffer_size=self.config.get('buffer_size', 10000),
            batch_size=self.config.get('batch_size', 32)
        )

        # Market state builder
        self.state_builder = MarketStateBuilder(state_size)

        # Training data storage
        self.pending_trades = {}  # symbol -> trade info
        self.completed_experiences = deque(maxlen=1000)
        self.learning_history = []

        # Training controls
        self.training_enabled = self.config.get('training_enabled', True)
        self.min_experiences_for_training = self.config.get('min_experiences', 10)
        self.training_frequency = self.config.get('training_frequency', 5)  # Train every N experiences
        self.experience_count = 0

        # Model saving
        self.model_save_path = self.config.get('model_save_path', 'models/realtime_rl')
        self.save_frequency = self.config.get('save_frequency', 100)  # Save every N experiences

        # Performance tracking
        self.performance_history = []
        self.recent_rewards = deque(maxlen=100)
        self.trade_count = 0
        self.win_count = 0

        # Threading for async training
        self.training_thread = None
        self.training_queue = deque()
        self.training_lock = threading.Lock()

        logger.info("Real-time RL trainer initialized")
        logger.info(f"State size: {state_size}, Action size: {action_size}")
        logger.info(f"Training enabled: {self.training_enabled}")

    def update_market_data(self, symbol: str, price: float, volume: float,
                           rsi: float = None, macd: float = None):
        """Update market data for state building"""
        self.state_builder.update_market_data(price, volume, rsi, macd)

    def record_trade_signal(self, symbol: str, action: str, confidence: float,
                            current_price: float, position_info: Dict[str, Any] = None):
        """Record a trade signal for future learning"""
        try:
            # Build current state
            current_position = 'NONE'
            position_pnl = 0.0
            account_balance = 1000.0

            if position_info:
                current_position = position_info.get('side', 'NONE')
                position_pnl = position_info.get('unrealized_pnl', 0.0)
                account_balance = position_info.get('account_balance', 1000.0)

            state = self.state_builder.build_state(current_position, position_pnl, account_balance)

            # Convert action to numeric
            action_map = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
            action_num = action_map.get(action.upper(), 1)

            # Store pending trade
            trade_info = {
                'pre_trade_state': state.copy(),
                'action': action_num,
                'entry_price': current_price,
                'confidence': confidence,
                'entry_time': datetime.now(),
                'market_conditions': {
                    'volatility': np.std(list(self.state_builder.price_history)[-10:]) if len(self.state_builder.price_history) >= 10 else 0,
                    'trend': state[70] if len(state) > 70 else 0,
                    'volume_trend': state[40] if len(state) > 40 else 0
                }
            }

            if action.upper() in ['BUY', 'SELL']:
                self.pending_trades[symbol] = trade_info
                logger.info(f"Recorded {action} signal for {symbol} at ${current_price:.2f} (confidence: {confidence:.2f})")

        except Exception as e:
            logger.error(f"Error recording trade signal: {e}")

    def record_position_closure(self, symbol: str, exit_price: float,
                                pnl: float, fees: float):
        """Record position closure and create learning experience"""
        try:
            if symbol not in self.pending_trades:
                logger.warning(f"No pending trade found for {symbol}")
                return

            trade_info = self.pending_trades.pop(symbol)

            # Calculate holding time
            holding_time = (datetime.now() - trade_info['entry_time']).total_seconds()

            # Create trading experience
            experience = TradingExperience(
                pre_trade_state=trade_info['pre_trade_state'],
                action=trade_info['action'],
                entry_price=trade_info['entry_price'],
                exit_price=exit_price,
                holding_time=holding_time,
                pnl=pnl,
                fees=fees,
                confidence=trade_info['confidence'],
                market_conditions=trade_info['market_conditions'],
                timestamp=datetime.now()
            )

            # Add to completed experiences
            self.completed_experiences.append(experience)
            self.recent_rewards.append(experience.reward)
            self.experience_count += 1
            self.trade_count += 1

            if experience.reward > 0:
                self.win_count += 1

            # Log the experience
            logger.info(f"Recorded experience: {symbol} PnL=${pnl:.4f} Reward={experience.reward:.4f} "
                        f"(Win rate: {self.win_count/self.trade_count*100:.1f}%)")

            # Create next state (current market state after trade)
            current_state = self.state_builder.build_state('NONE', 0.0, 1000.0)

            # Store in agent memory for learning
            self.agent.remember(
                state=trade_info['pre_trade_state'],
                action=trade_info['action'],
                reward=experience.reward,
                next_state=current_state,
                done=True  # Each trade is a complete episode
            )

            # Trigger training if conditions are met
            if self.training_enabled:
                self._maybe_train()

            # Save model periodically
            if self.experience_count % self.save_frequency == 0:
                self._save_model()

        except Exception as e:
            logger.error(f"Error recording position closure: {e}")

    def _maybe_train(self):
        """Train the agent if conditions are met"""
        try:
            if (len(self.agent.memory) >= self.min_experiences_for_training and
                    self.experience_count % self.training_frequency == 0):

                # Perform training step
                loss = self.agent.replay()

                if loss is not None:
                    self.learning_history.append({
                        'timestamp': datetime.now().isoformat(),
                        'experience_count': self.experience_count,
                        'loss': loss,
                        'epsilon': self.agent.epsilon,
                        'avg_reward': np.mean(list(self.recent_rewards)) if self.recent_rewards else 0,
                        'win_rate': self.win_count / self.trade_count if self.trade_count > 0 else 0,
                        'memory_size': len(self.agent.memory)
                    })

                    logger.info(f"RL Training: Loss={loss:.4f}, Epsilon={self.agent.epsilon:.3f}, "
                                f"Avg Reward={np.mean(list(self.recent_rewards)):.4f}, "
                                f"Memory Size={len(self.agent.memory)}")

        except Exception as e:
            logger.error(f"Error in training: {e}")

    def get_action_prediction(self, symbol: str, current_position: str = 'NONE',
                              position_pnl: float = 0.0, account_balance: float = 1000.0) -> Tuple[str, float]:
        """Get action prediction from trained RL agent"""
        try:
            # Build current state
            state = self.state_builder.build_state(current_position, position_pnl, account_balance)

            # Get prediction from agent
            with torch.no_grad():
                q_values, _, _, _, _ = self.agent.policy_net(
                    torch.FloatTensor(state).unsqueeze(0).to(self.agent.device)
                )

            # Get action with highest Q-value
            action_idx = q_values.argmax().item()
            confidence = torch.softmax(q_values, dim=1).max().item()

            # Convert to action string
            action_map = {0: 'SELL', 1: 'HOLD', 2: 'BUY'}
            action = action_map[action_idx]

            return action, confidence

        except Exception as e:
            logger.error(f"Error getting action prediction: {e}")
            return 'HOLD', 0.5

    def get_training_stats(self) -> Dict[str, Any]:
        """Get current training statistics"""
        try:
            return {
                'total_experiences': self.experience_count,
                'total_trades': self.trade_count,
                'win_count': self.win_count,
                'win_rate': self.win_count / self.trade_count if self.trade_count > 0 else 0,
                'avg_reward': np.mean(list(self.recent_rewards)) if self.recent_rewards else 0,
                'memory_size': len(self.agent.memory),
                'epsilon': self.agent.epsilon,
                'recent_loss': self.learning_history[-1]['loss'] if self.learning_history else 0,
                'training_enabled': self.training_enabled,
                'pending_trades': len(self.pending_trades)
            }

        except Exception as e:
            logger.error(f"Error getting training stats: {e}")
            return {}

    def _save_model(self):
        """Save the trained model"""
        try:
            os.makedirs(self.model_save_path, exist_ok=True)

            # Save RL agent
            self.agent.save(os.path.join(self.model_save_path, 'rl_agent'))

            # Save training history
            history_path = os.path.join(self.model_save_path, 'training_history.json')
            with open(history_path, 'w') as f:
                json.dump(self.learning_history, f, indent=2)

            # Save performance stats
            stats_path = os.path.join(self.model_save_path, 'performance_stats.json')
            with open(stats_path, 'w') as f:
                json.dump(self.get_training_stats(), f, indent=2)

            logger.info(f"Saved RL model and training data to {self.model_save_path}")

        except Exception as e:
            logger.error(f"Error saving model: {e}")

    def load_model(self):
        """Load a previously saved model"""
        try:
            model_path = os.path.join(self.model_save_path, 'rl_agent')
            if os.path.exists(f"{model_path}_policy_model.pt"):
                self.agent.load(model_path)
                logger.info(f"Loaded RL model from {model_path}")

                # Load training history if available
                history_path = os.path.join(self.model_save_path, 'training_history.json')
                if os.path.exists(history_path):
                    with open(history_path, 'r') as f:
                        self.learning_history = json.load(f)

                return True
            else:
                logger.info("No saved model found, starting with fresh model")
                return False

        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False

    def enable_training(self, enabled: bool = True):
        """Enable or disable training"""
        self.training_enabled = enabled
        logger.info(f"RL training {'enabled' if enabled else 'disabled'}")

    def reset_performance_stats(self):
        """Reset performance tracking statistics"""
        self.trade_count = 0
        self.win_count = 0
        self.recent_rewards.clear()
        logger.info("Reset RL performance statistics")

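That completes core/realtime_rl_trainer.py. A minimal usage sketch of the trainer's lifecycle follows (illustrative only, not part of the commit; symbol, prices and config values are invented). The hunks below wire this same flow into the existing TradingExecutor:

from core.realtime_rl_trainer import RealTimeRLTrainer

trainer = RealTimeRLTrainer({'training_enabled': True, 'min_experiences': 10})
trainer.load_model()  # falls back to a fresh model if nothing has been saved yet

# Stream ticks so the state builder has context
trainer.update_market_data('ETH/USDT', price=3000.0, volume=120.0, rsi=55.0, macd=1.2)

# A BUY/SELL signal freezes the pre-trade state as a pending trade
trainer.record_trade_signal('ETH/USDT', action='BUY', confidence=0.82, current_price=3000.0)

# Closing the position turns the pending trade into a replay-buffer experience
# and may trigger a training step and a periodic model save
trainer.record_position_closure('ETH/USDT', exit_price=3003.0, pnl=3.0, fees=0.6)

print(trainer.get_training_stats())               # win rate, avg reward, epsilon, memory size, ...
print(trainer.get_action_prediction('ETH/USDT'))  # e.g. ('HOLD', 0.41) once the network has learned
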
@@ -9,7 +9,7 @@ import logging
 import time
 import os
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Any
+from typing import Dict, List, Optional, Any, Tuple
 from dataclasses import dataclass
 from threading import Lock
 import sys
@@ -20,6 +20,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'NN'))
 from NN.exchanges import MEXCInterface
 from .config import get_config
 from .config_sync import ConfigSynchronizer
+from .realtime_rl_trainer import RealTimeRLTrainer

 logger = logging.getLogger(__name__)

@@ -119,6 +120,29 @@ class TradingExecutor:
             mexc_interface=self.exchange if self.trading_enabled else None
         )

+        # Initialize real-time RL trainer for continuous learning
+        rl_config = {
+            'state_size': 100,
+            'learning_rate': 0.0001,
+            'gamma': 0.95,
+            'epsilon': 0.1,  # Low exploration for live trading
+            'buffer_size': 10000,
+            'batch_size': 32,
+            'training_enabled': self.mexc_config.get('rl_learning_enabled', True),
+            'min_experiences': 10,
+            'training_frequency': 3,  # Train every 3 trades
+            'save_frequency': 50,  # Save every 50 trades
+            'model_save_path': 'models/realtime_rl'
+        }
+
+        self.rl_trainer = RealTimeRLTrainer(rl_config)
+
+        # Try to load existing RL model
+        if self.rl_trainer.load_model():
+            logger.info("TRADING EXECUTOR: Loaded existing RL model for continuous learning")
+        else:
+            logger.info("TRADING EXECUTOR: Starting with fresh RL model")
+
         # Perform initial fee sync on startup if trading is enabled
         if self.trading_enabled and self.exchange:
             try:
@@ -189,6 +213,29 @@
             return False
         current_price = ticker['last']

+        # Update RL trainer with market data (estimate volume from price movement)
+        estimated_volume = abs(current_price) * 1000  # Simple volume estimate
+        self.rl_trainer.update_market_data(symbol, current_price, estimated_volume)
+
+        # Get position info for RL trainer
+        position_info = None
+        if symbol in self.positions:
+            position = self.positions[symbol]
+            position_info = {
+                'side': position.side,
+                'unrealized_pnl': position.unrealized_pnl,
+                'account_balance': 1000.0  # Could get from exchange
+            }
+
+        # Record trade signal with RL trainer for learning
+        self.rl_trainer.record_trade_signal(
+            symbol=symbol,
+            action=action,
+            confidence=confidence,
+            current_price=current_price,
+            position_info=position_info
+        )
+
         with self.lock:
             try:
                 if action == 'BUY':
@@ -348,6 +395,14 @@
             self.trade_history.append(trade_record)
             self.daily_loss += max(0, -pnl)  # Add to daily loss if negative

+            # Record position closure with RL trainer for learning
+            self.rl_trainer.record_position_closure(
+                symbol=symbol,
+                exit_price=current_price,
+                pnl=pnl,
+                fees=0.0  # No fees in simulation
+            )
+
             # Remove position
             del self.positions[symbol]
             self.last_trade_time[symbol] = datetime.now()
@@ -397,6 +452,14 @@
             self.trade_history.append(trade_record)
             self.daily_loss += max(0, -(pnl - fees))  # Add to daily loss if negative

+            # Record position closure with RL trainer for learning
+            self.rl_trainer.record_position_closure(
+                symbol=symbol,
+                exit_price=current_price,
+                pnl=pnl,
+                fees=fees
+            )
+
             # Remove position
             del self.positions[symbol]
             self.last_trade_time[symbol] = datetime.now()
@@ -464,6 +527,9 @@
         effective_fee_rate = (total_fees / max(0.01, total_volume)) if total_volume > 0 else 0
         fee_impact_on_pnl = (total_fees / max(0.01, abs(gross_pnl))) * 100 if gross_pnl != 0 else 0

+        # Get RL training statistics
+        rl_stats = self.rl_trainer.get_training_stats() if hasattr(self, 'rl_trainer') else {}
+
         return {
             'daily_trades': self.daily_trades,
             'daily_loss': self.daily_loss,
@@ -490,6 +556,15 @@
                 'fee_impact_percent': fee_impact_on_pnl,
                 'is_fee_efficient': fee_impact_on_pnl < 5.0,  # Less than 5% impact is good
                 'fee_savings_vs_market': (0.001 - effective_fee_rate) * total_volume if effective_fee_rate < 0.001 else 0
             },
+            'rl_learning': {
+                'enabled': rl_stats.get('training_enabled', False),
+                'total_experiences': rl_stats.get('total_experiences', 0),
+                'rl_win_rate': rl_stats.get('win_rate', 0),
+                'avg_reward': rl_stats.get('avg_reward', 0),
+                'memory_size': rl_stats.get('memory_size', 0),
+                'epsilon': rl_stats.get('epsilon', 0),
+                'pending_trades': rl_stats.get('pending_trades', 0)
+            }
         }

@@ -803,3 +878,71 @@
                 'sync_available': False,
                 'error': str(e)
             }
+
+    def get_rl_prediction(self, symbol: str) -> Tuple[str, float]:
+        """Get RL agent prediction for the current market state
+
+        Args:
+            symbol: Trading symbol
+
+        Returns:
+            tuple: (action, confidence) where action is BUY/SELL/HOLD
+        """
+        if not hasattr(self, 'rl_trainer'):
+            return 'HOLD', 0.5
+
+        try:
+            # Get current position info
+            current_position = 'NONE'
+            position_pnl = 0.0
+            account_balance = 1000.0
+
+            if symbol in self.positions:
+                position = self.positions[symbol]
+                current_position = position.side
+                position_pnl = position.unrealized_pnl
+
+            # Get RL prediction
+            action, confidence = self.rl_trainer.get_action_prediction(
+                symbol=symbol,
+                current_position=current_position,
+                position_pnl=position_pnl,
+                account_balance=account_balance
+            )
+
+            return action, confidence
+
+        except Exception as e:
+            logger.error(f"TRADING EXECUTOR: Error getting RL prediction: {e}")
+            return 'HOLD', 0.5
+
+    def enable_rl_training(self, enabled: bool = True):
+        """Enable or disable real-time RL training
+
+        Args:
+            enabled: Whether to enable RL training
+        """
+        if hasattr(self, 'rl_trainer'):
+            self.rl_trainer.enable_training(enabled)
+            logger.info(f"TRADING EXECUTOR: RL training {'enabled' if enabled else 'disabled'}")
+        else:
+            logger.warning("TRADING EXECUTOR: RL trainer not initialized")
+
+    def get_rl_training_stats(self) -> Dict[str, Any]:
+        """Get comprehensive RL training statistics
+
+        Returns:
+            dict: RL training statistics and performance metrics
+        """
+        if hasattr(self, 'rl_trainer'):
+            return self.rl_trainer.get_training_stats()
+        else:
+            return {'error': 'RL trainer not initialized'}
+
+    def save_rl_model(self):
+        """Manually save the current RL model"""
+        if hasattr(self, 'rl_trainer'):
+            self.rl_trainer._save_model()
+            logger.info("TRADING EXECUTOR: RL model saved manually")
+        else:
+            logger.warning("TRADING EXECUTOR: RL trainer not initialized")
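Finally, a sketch of how orchestration or dashboard code might consume the new executor-level hooks (illustrative; assumes an already-constructed TradingExecutor instance named executor):

rl_action, rl_confidence = executor.get_rl_prediction('ETH/USDT')
if rl_action != 'HOLD' and rl_confidence > 0.7:
    print(f"RL agent suggests {rl_action} with confidence {rl_confidence:.2f}")

stats = executor.get_rl_training_stats()  # same payload as RealTimeRLTrainer.get_training_stats()
executor.enable_rl_training(False)        # pause learning, e.g. around unusual market events
executor.save_rl_model()                  # manual checkpoint on top of the periodic saves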