RL module possibly working
@@ -10,14 +10,16 @@ This module implements sophisticated RL training with:
 
 import asyncio
 import logging
 import time
 import numpy as np
 import pandas as pd
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from collections import deque, namedtuple
 import random
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Tuple, Any
+from typing import Dict, List, Optional, Tuple, Any, Union
+import matplotlib.pyplot as plt
+from pathlib import Path
 
@@ -26,6 +28,9 @@ from core.data_provider import DataProvider
 from core.enhanced_orchestrator import EnhancedTradingOrchestrator, MarketState, TradingAction
 from models import RLAgentInterface
 import models
+from training.enhanced_rl_state_builder import EnhancedRLStateBuilder
+from training.williams_market_structure import WilliamsMarketStructure
+from training.cnn_rl_bridge import CNNRLBridge
 
 logger = logging.getLogger(__name__)
 
@@ -318,42 +323,66 @@ class EnhancedDQNAgent(nn.Module, RLAgentInterface):
         return (param_count * 4 + buffer_size) // (1024 * 1024)
 
 class EnhancedRLTrainer:
-    """Enhanced RL trainer with continuous learning from market feedback"""
+    """Enhanced RL trainer with comprehensive state representation and real data integration"""
 
     def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
-        """Initialize the enhanced RL trainer"""
+        """Initialize enhanced RL trainer with comprehensive state building"""
         self.config = config or get_config()
         self.orchestrator = orchestrator
         self.data_provider = DataProvider(self.config)
 
-        # Create RL agents for each symbol
+        # Initialize comprehensive state builder (replaces mock code)
+        self.state_builder = EnhancedRLStateBuilder(self.config)
+        self.williams_structure = WilliamsMarketStructure()
+        self.cnn_rl_bridge = CNNRLBridge(self.config) if hasattr(self.config, 'cnn_models') else None
+
+        # Enhanced RL agents with much larger state space
         self.agents = {}
-        for symbol in self.config.symbols:
-            agent_config = self.config.rl.copy()
-            agent_config['name'] = f'RL_{symbol}'
-            self.agents[symbol] = EnhancedDQNAgent(agent_config)
+        self.initialize_agents()
 
-        # Training parameters
-        self.training_interval = 3600  # Train every hour
-        self.evaluation_window = 24 * 3600  # Evaluate actions after 24 hours
-        self.min_experiences = 100  # Minimum experiences before training
-
-        # Performance tracking
-        self.performance_history = {symbol: [] for symbol in self.config.symbols}
-        self.training_metrics = {
-            'total_episodes': 0,
-            'total_rewards': {symbol: [] for symbol in self.config.symbols},
-            'losses': {symbol: [] for symbol in self.config.symbols},
-            'epsilon_values': {symbol: [] for symbol in self.config.symbols}
-        }
-
-        # Create save directory
-        models_path = self.config.rl.get('model_dir', "models/enhanced_rl")
-        self.save_dir = Path(models_path)
+        # Training configuration
+        self.symbols = self.config.symbols
+        self.save_dir = Path(self.config.rl.get('save_dir', 'models/rl/saved'))
         self.save_dir.mkdir(parents=True, exist_ok=True)
 
-        logger.info(f"Enhanced RL trainer initialized for symbols: {self.config.symbols}")
-
+        # Performance tracking
+        self.training_metrics = {
+            'total_episodes': 0,
+            'total_rewards': {symbol: [] for symbol in self.symbols},
+            'losses': {symbol: [] for symbol in self.symbols},
+            'epsilon_values': {symbol: [] for symbol in self.symbols}
+        }
+
+        self.performance_history = {symbol: [] for symbol in self.symbols}
+
+        # Real-time learning parameters
+        self.learning_active = False
+        self.experience_buffer_size = 1000
+        self.min_experiences_for_training = 100
+
+        logger.info("Enhanced RL Trainer initialized with comprehensive state representation")
+        logger.info(f"State builder total size: {self.state_builder.total_state_size} features")
+        logger.info(f"Symbols: {self.symbols}")
+
+    def initialize_agents(self):
+        """Initialize RL agents with enhanced state size"""
+        for symbol in self.symbols:
+            agent_config = {
+                'state_size': self.state_builder.total_state_size,  # ~13,400 features
+                'action_space': 3,  # BUY, SELL, HOLD
+                'hidden_size': 1024,  # Larger hidden layers for complex state
+                'learning_rate': 0.0001,
+                'gamma': 0.99,
+                'epsilon': 1.0,
+                'epsilon_decay': 0.995,
+                'epsilon_min': 0.01,
+                'buffer_size': 50000,  # Larger replay buffer
+                'batch_size': 128,
+                'target_update_freq': 1000
+            }
+
+            self.agents[symbol] = EnhancedDQNAgent(agent_config)
+            logger.info(f"Initialized {symbol} RL agent with state size: {agent_config['state_size']}")
 
     async def continuous_learning_loop(self):
         """Main continuous learning loop"""
         logger.info("Starting continuous RL learning loop")
@@ -378,7 +407,7 @@ class EnhancedRLTrainer:
                     self._save_all_models()
 
                 # Wait before next training cycle
-                await asyncio.sleep(self.training_interval)
+                await asyncio.sleep(3600)  # Train every hour
 
             except Exception as e:
                 logger.error(f"Error in continuous learning loop: {e}")
@@ -388,7 +417,7 @@ class EnhancedRLTrainer:
         """Train all RL agents with their experiences"""
         for symbol, agent in self.agents.items():
             try:
-                if len(agent.replay_buffer) >= self.min_experiences:
+                if len(agent.replay_buffer) >= self.min_experiences_for_training:
                     # Train for multiple steps
                     losses = []
                     for _ in range(10):  # Train 10 steps per cycle
@@ -411,7 +440,7 @@ class EnhancedRLTrainer:
         if not self.orchestrator:
             return
 
-        for symbol in self.config.symbols:
+        for symbol in self.symbols:
             try:
                 # Get recent market states
                 recent_states = list(self.orchestrator.market_states[symbol])[-10:]  # Last 10 states
@@ -471,11 +500,150 @@ class EnhancedRLTrainer:
                 logger.error(f"Error adding experience for {symbol}: {e}")
 
     def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
-        """Convert market state to RL state vector"""
-        if hasattr(self.orchestrator, '_market_state_to_rl_state'):
-            return self.orchestrator._market_state_to_rl_state(market_state)
+        """Convert market state to comprehensive RL state vector using real data"""
+        try:
+            # Extract data from market state and orchestrator
+            if not self.orchestrator:
+                logger.warning("No orchestrator available for comprehensive state building")
+                return self._fallback_state_conversion(market_state)
+
+            # Get real tick data from orchestrator's data provider
+            symbol = market_state.symbol
+            eth_ticks = self._get_recent_tick_data(symbol, seconds=300)
+
+            # Get multi-timeframe OHLCV data
+            eth_ohlcv = self._get_multiframe_ohlcv_data(symbol)
+            btc_ohlcv = self._get_multiframe_ohlcv_data('BTC/USDT')
+
+            # Get CNN features if available
+            cnn_hidden_features = None
+            cnn_predictions = None
+            if self.cnn_rl_bridge:
+                cnn_data = self.cnn_rl_bridge.get_latest_features_for_symbol(symbol)
+                if cnn_data:
+                    cnn_hidden_features = cnn_data.get('hidden_features', {})
+                    cnn_predictions = cnn_data.get('predictions', {})
+
+            # Get pivot point data
+            pivot_data = self._calculate_pivot_points(eth_ohlcv)
+
+            # Build comprehensive state using enhanced state builder
+            comprehensive_state = self.state_builder.build_rl_state(
+                eth_ticks=eth_ticks,
+                eth_ohlcv=eth_ohlcv,
+                btc_ohlcv=btc_ohlcv,
+                cnn_hidden_features=cnn_hidden_features,
+                cnn_predictions=cnn_predictions,
+                pivot_data=pivot_data
+            )
+
+            logger.debug(f"Built comprehensive RL state: {len(comprehensive_state)} features")
+            return comprehensive_state
+
+        except Exception as e:
+            logger.error(f"Error building comprehensive RL state: {e}")
+            return self._fallback_state_conversion(market_state)
+
+    def _get_recent_tick_data(self, symbol: str, seconds: int = 300) -> List:
+        """Get recent tick data from orchestrator's data provider"""
+        try:
+            if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
+                # Get recent ticks from data provider
+                recent_ticks = self.orchestrator.data_provider.get_recent_ticks(symbol, count=seconds*10)
+
+                # Convert to required format
+                tick_data = []
+                for tick in recent_ticks[-300:]:  # Last 300 ticks max
+                    tick_data.append({
+                        'timestamp': tick.timestamp,
+                        'price': tick.price,
+                        'volume': tick.volume,
+                        'quantity': getattr(tick, 'quantity', tick.volume),
+                        'side': getattr(tick, 'side', 'unknown'),
+                        'trade_id': getattr(tick, 'trade_id', 'unknown')
+                    })
+
+                return tick_data
+
+            return []
+
+        except Exception as e:
+            logger.warning(f"Error getting tick data for {symbol}: {e}")
+            return []
+
+    def _get_multiframe_ohlcv_data(self, symbol: str) -> Dict[str, List]:
+        """Get multi-timeframe OHLCV data"""
+        try:
+            if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
+                ohlcv_data = {}
+                timeframes = ['1s', '1m', '1h', '1d']
+
+                for tf in timeframes:
+                    try:
+                        # Get historical data for timeframe
+                        df = self.orchestrator.data_provider.get_historical_data(
+                            symbol=symbol,
+                            timeframe=tf,
+                            limit=300,
+                            refresh=True
+                        )
+
+                        if df is not None and not df.empty:
+                            # Convert to list of dictionaries
+                            bars = []
+                            for _, row in df.tail(300).iterrows():
+                                bar = {
+                                    'timestamp': row.name if hasattr(row, 'name') else datetime.now(),
+                                    'open': float(row.get('open', 0)),
+                                    'high': float(row.get('high', 0)),
+                                    'low': float(row.get('low', 0)),
+                                    'close': float(row.get('close', 0)),
+                                    'volume': float(row.get('volume', 0))
+                                }
+                                bars.append(bar)
+
+                            ohlcv_data[tf] = bars
+                        else:
+                            ohlcv_data[tf] = []
+
+                    except Exception as e:
+                        logger.warning(f"Error getting {tf} data for {symbol}: {e}")
+                        ohlcv_data[tf] = []
+
+                return ohlcv_data
+
+            return {}
+
+        except Exception as e:
+            logger.warning(f"Error getting OHLCV data for {symbol}: {e}")
+            return {}
+
+    def _calculate_pivot_points(self, eth_ohlcv: Dict[str, List]) -> Dict[str, Any]:
+        """Calculate Williams pivot points from OHLCV data"""
+        try:
+            if '1m' in eth_ohlcv and eth_ohlcv['1m']:
+                # Convert to numpy array for Williams calculation
+                bars = eth_ohlcv['1m']
+                if len(bars) >= 50:  # Need minimum data for pivot calculation
+                    ohlc_array = np.array([
+                        [bar['timestamp'].timestamp() if hasattr(bar['timestamp'], 'timestamp') else time.time(),
+                         bar['open'], bar['high'], bar['low'], bar['close'], bar['volume']]
+                        for bar in bars[-200:]  # Last 200 bars
+                    ])
+
+                    pivot_data = self.williams_structure.calculate_recursive_pivot_points(ohlc_array)
+                    return pivot_data
+
+            return {}
+
+        except Exception as e:
+            logger.warning(f"Error calculating pivot points: {e}")
+            return {}
+
+    def _fallback_state_conversion(self, market_state: MarketState) -> np.ndarray:
+        """Fallback to basic state conversion if comprehensive state building fails"""
+        logger.warning("Using fallback state conversion - limited features")
 
         # Fallback implementation
         state_components = [
             market_state.volatility,
             market_state.volume,
@@ -486,8 +654,8 @@ class EnhancedRLTrainer:
         for timeframe in sorted(market_state.prices.keys()):
             state_components.append(market_state.prices[timeframe])
 
-        # Pad or truncate to expected state size
-        expected_size = self.config.rl.get('state_size', 100)
+        # Pad to match expected state size
+        expected_size = self.state_builder.total_state_size
         if len(state_components) < expected_size:
             state_components.extend([0.0] * (expected_size - len(state_components)))
         else:
@@ -545,7 +713,7 @@ class EnhancedRLTrainer:
         timestamp = max(timestamps)
 
         loaded_count = 0
-        for symbol in self.config.symbols:
+        for symbol in self.symbols:
             filename = f"rl_agent_{symbol}_{timestamp}.pt"
             filepath = self.save_dir / filename
 
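For reference, below is a minimal usage sketch of the trainer as it looks after this commit, based only on the interfaces visible in the diff above. The trainer's module path, the location of get_config(), and the orchestrator's constructor signature are assumptions and may differ in the actual repository.

# Hypothetical wiring sketch; only EnhancedRLTrainer's signature is taken from the diff.
import asyncio

from core.config import get_config  # assumed home of get_config()
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from training.enhanced_rl_trainer import EnhancedRLTrainer  # assumed module path

async def main():
    config = get_config()
    orchestrator = EnhancedTradingOrchestrator(config)  # assumed constructor signature
    trainer = EnhancedRLTrainer(config=config, orchestrator=orchestrator)

    # Trains all agents, evaluates recent actions, saves models, then sleeps
    # for an hour between cycles (see continuous_learning_loop above).
    await trainer.continuous_learning_loop()

if __name__ == "__main__":
    asyncio.run(main())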