RL module possibly working

Dobromir Popov
2025-05-28 23:42:06 +03:00
parent de01d3665c
commit 6b7d7aec81
16 changed files with 5118 additions and 580 deletions


@@ -10,14 +10,16 @@ This module implements sophisticated RL training with:
import asyncio
import logging
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import random
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from typing import Dict, List, Optional, Tuple, Any, Union
import matplotlib.pyplot as plt
from pathlib import Path
@@ -26,6 +28,9 @@ from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator, MarketState, TradingAction
from models import RLAgentInterface
import models
from training.enhanced_rl_state_builder import EnhancedRLStateBuilder
from training.williams_market_structure import WilliamsMarketStructure
from training.cnn_rl_bridge import CNNRLBridge
logger = logging.getLogger(__name__)
@@ -318,42 +323,66 @@ class EnhancedDQNAgent(nn.Module, RLAgentInterface):
return (param_count * 4 + buffer_size) // (1024 * 1024)
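As a rough sanity check on this estimate (illustrative numbers only; assumes float32 parameters at 4 bytes each and a buffer_size already expressed in bytes):

param_count = 14_000_000                         # hypothetical network size
buffer_size = 50_000 * 13_400 * 4                # hypothetical replay buffer, in bytes
estimated_mb = (param_count * 4 + buffer_size) // (1024 * 1024)   # -> 2609 MB
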
class EnhancedRLTrainer:
"""Enhanced RL trainer with continuous learning from market feedback"""
"""Enhanced RL trainer with comprehensive state representation and real data integration"""
def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
"""Initialize the enhanced RL trainer"""
"""Initialize enhanced RL trainer with comprehensive state building"""
self.config = config or get_config()
self.orchestrator = orchestrator
self.data_provider = DataProvider(self.config)
# Create RL agents for each symbol
# Initialize comprehensive state builder (replaces mock code)
self.state_builder = EnhancedRLStateBuilder(self.config)
self.williams_structure = WilliamsMarketStructure()
self.cnn_rl_bridge = CNNRLBridge(self.config) if hasattr(self.config, 'cnn_models') else None
# Enhanced RL agents with much larger state space
self.agents = {}
for symbol in self.config.symbols:
agent_config = self.config.rl.copy()
agent_config['name'] = f'RL_{symbol}'
self.agents[symbol] = EnhancedDQNAgent(agent_config)
self.initialize_agents()
# Training parameters
self.training_interval = 3600 # Train every hour
self.evaluation_window = 24 * 3600 # Evaluate actions after 24 hours
self.min_experiences = 100 # Minimum experiences before training
# Performance tracking
self.performance_history = {symbol: [] for symbol in self.config.symbols}
self.training_metrics = {
'total_episodes': 0,
'total_rewards': {symbol: [] for symbol in self.config.symbols},
'losses': {symbol: [] for symbol in self.config.symbols},
'epsilon_values': {symbol: [] for symbol in self.config.symbols}
}
# Create save directory
models_path = self.config.rl.get('model_dir', "models/enhanced_rl")
self.save_dir = Path(models_path)
# Training configuration
self.symbols = self.config.symbols
self.save_dir = Path(self.config.rl.get('save_dir', 'models/rl/saved'))
self.save_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Enhanced RL trainer initialized for symbols: {self.config.symbols}")
# Performance tracking
self.training_metrics = {
'total_episodes': 0,
'total_rewards': {symbol: [] for symbol in self.symbols},
'losses': {symbol: [] for symbol in self.symbols},
'epsilon_values': {symbol: [] for symbol in self.symbols}
}
self.performance_history = {symbol: [] for symbol in self.symbols}
# Real-time learning parameters
self.learning_active = False
self.experience_buffer_size = 1000
self.min_experiences_for_training = 100
logger.info("Enhanced RL Trainer initialized with comprehensive state representation")
logger.info(f"State builder total size: {self.state_builder.total_state_size} features")
logger.info(f"Symbols: {self.symbols}")
def initialize_agents(self):
"""Initialize RL agents with enhanced state size"""
for symbol in self.symbols:
agent_config = {
'state_size': self.state_builder.total_state_size, # ~13,400 features
'action_space': 3, # BUY, SELL, HOLD
'hidden_size': 1024, # Larger hidden layers for complex state
'learning_rate': 0.0001,
'gamma': 0.99,
'epsilon': 1.0,
'epsilon_decay': 0.995,
'epsilon_min': 0.01,
'buffer_size': 50000, # Larger replay buffer
'batch_size': 128,
'target_update_freq': 1000
}
self.agents[symbol] = EnhancedDQNAgent(agent_config)
logger.info(f"Initialized {symbol} RL agent with state size: {agent_config['state_size']}")
async def continuous_learning_loop(self):
"""Main continuous learning loop"""
logger.info("Starting continuous RL learning loop")
@@ -378,7 +407,7 @@ class EnhancedRLTrainer:
self._save_all_models()
# Wait before next training cycle
await asyncio.sleep(self.training_interval)
await asyncio.sleep(3600) # Train every hour
except Exception as e:
logger.error(f"Error in continuous learning loop: {e}")
@@ -388,7 +417,7 @@ class EnhancedRLTrainer:
"""Train all RL agents with their experiences"""
for symbol, agent in self.agents.items():
try:
if len(agent.replay_buffer) >= self.min_experiences:
if len(agent.replay_buffer) >= self.min_experiences_for_training:
# Train for multiple steps
losses = []
for _ in range(10): # Train 10 steps per cycle
@@ -411,7 +440,7 @@ class EnhancedRLTrainer:
if not self.orchestrator:
return
for symbol in self.config.symbols:
for symbol in self.symbols:
try:
# Get recent market states
recent_states = list(self.orchestrator.market_states[symbol])[-10:] # Last 10 states
@@ -471,11 +500,150 @@ class EnhancedRLTrainer:
logger.error(f"Error adding experience for {symbol}: {e}")
def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
"""Convert market state to RL state vector"""
if hasattr(self.orchestrator, '_market_state_to_rl_state'):
return self.orchestrator._market_state_to_rl_state(market_state)
"""Convert market state to comprehensive RL state vector using real data"""
try:
# Extract data from market state and orchestrator
if not self.orchestrator:
logger.warning("No orchestrator available for comprehensive state building")
return self._fallback_state_conversion(market_state)
# Get real tick data from orchestrator's data provider
symbol = market_state.symbol
eth_ticks = self._get_recent_tick_data(symbol, seconds=300)
# Get multi-timeframe OHLCV data
eth_ohlcv = self._get_multiframe_ohlcv_data(symbol)
btc_ohlcv = self._get_multiframe_ohlcv_data('BTC/USDT')
# Get CNN features if available
cnn_hidden_features = None
cnn_predictions = None
if self.cnn_rl_bridge:
cnn_data = self.cnn_rl_bridge.get_latest_features_for_symbol(symbol)
if cnn_data:
cnn_hidden_features = cnn_data.get('hidden_features', {})
cnn_predictions = cnn_data.get('predictions', {})
# Get pivot point data
pivot_data = self._calculate_pivot_points(eth_ohlcv)
# Build comprehensive state using enhanced state builder
comprehensive_state = self.state_builder.build_rl_state(
eth_ticks=eth_ticks,
eth_ohlcv=eth_ohlcv,
btc_ohlcv=btc_ohlcv,
cnn_hidden_features=cnn_hidden_features,
cnn_predictions=cnn_predictions,
pivot_data=pivot_data
)
logger.debug(f"Built comprehensive RL state: {len(comprehensive_state)} features")
return comprehensive_state
except Exception as e:
logger.error(f"Error building comprehensive RL state: {e}")
return self._fallback_state_conversion(market_state)
def _get_recent_tick_data(self, symbol: str, seconds: int = 300) -> List:
"""Get recent tick data from orchestrator's data provider"""
try:
if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
# Get recent ticks from data provider
recent_ticks = self.orchestrator.data_provider.get_recent_ticks(symbol, count=seconds*10)
# Convert to required format
tick_data = []
for tick in recent_ticks[-300:]: # Last 300 ticks max
tick_data.append({
'timestamp': tick.timestamp,
'price': tick.price,
'volume': tick.volume,
'quantity': getattr(tick, 'quantity', tick.volume),
'side': getattr(tick, 'side', 'unknown'),
'trade_id': getattr(tick, 'trade_id', 'unknown')
})
return tick_data
return []
except Exception as e:
logger.warning(f"Error getting tick data for {symbol}: {e}")
return []
def _get_multiframe_ohlcv_data(self, symbol: str) -> Dict[str, List]:
"""Get multi-timeframe OHLCV data"""
try:
if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
ohlcv_data = {}
timeframes = ['1s', '1m', '1h', '1d']
for tf in timeframes:
try:
# Get historical data for timeframe
df = self.orchestrator.data_provider.get_historical_data(
symbol=symbol,
timeframe=tf,
limit=300,
refresh=True
)
if df is not None and not df.empty:
# Convert to list of dictionaries
bars = []
for _, row in df.tail(300).iterrows():
bar = {
'timestamp': row.name if hasattr(row, 'name') else datetime.now(),
'open': float(row.get('open', 0)),
'high': float(row.get('high', 0)),
'low': float(row.get('low', 0)),
'close': float(row.get('close', 0)),
'volume': float(row.get('volume', 0))
}
bars.append(bar)
ohlcv_data[tf] = bars
else:
ohlcv_data[tf] = []
except Exception as e:
logger.warning(f"Error getting {tf} data for {symbol}: {e}")
ohlcv_data[tf] = []
return ohlcv_data
return {}
except Exception as e:
logger.warning(f"Error getting OHLCV data for {symbol}: {e}")
return {}
def _calculate_pivot_points(self, eth_ohlcv: Dict[str, List]) -> Dict[str, Any]:
"""Calculate Williams pivot points from OHLCV data"""
try:
if '1m' in eth_ohlcv and eth_ohlcv['1m']:
# Convert to numpy array for Williams calculation
bars = eth_ohlcv['1m']
if len(bars) >= 50: # Need minimum data for pivot calculation
ohlc_array = np.array([
[bar['timestamp'].timestamp() if hasattr(bar['timestamp'], 'timestamp') else time.time(),
bar['open'], bar['high'], bar['low'], bar['close'], bar['volume']]
for bar in bars[-200:] # Last 200 bars
])
pivot_data = self.williams_structure.calculate_recursive_pivot_points(ohlc_array)
return pivot_data
return {}
except Exception as e:
logger.warning(f"Error calculating pivot points: {e}")
return {}
def _fallback_state_conversion(self, market_state: MarketState) -> np.ndarray:
"""Fallback to basic state conversion if comprehensive state building fails"""
logger.warning("Using fallback state conversion - limited features")
# Fallback implementation
state_components = [
market_state.volatility,
market_state.volume,
@@ -486,8 +654,8 @@ class EnhancedRLTrainer:
for timeframe in sorted(market_state.prices.keys()):
state_components.append(market_state.prices[timeframe])
# Pad or truncate to expected state size
expected_size = self.config.rl.get('state_size', 100)
# Pad or truncate to match expected state size
expected_size = self.state_builder.total_state_size
if len(state_components) < expected_size:
state_components.extend([0.0] * (expected_size - len(state_components)))
else:
@@ -545,7 +713,7 @@ class EnhancedRLTrainer:
timestamp = max(timestamps)
loaded_count = 0
for symbol in self.config.symbols:
for symbol in self.symbols:
filename = f"rl_agent_{symbol}_{timestamp}.pt"
filepath = self.save_dir / filename
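
For context, a minimal usage sketch of the new trainer (illustrative only; the trainer's module path, the get_config import location, and the EnhancedTradingOrchestrator constructor signature are assumptions):

import asyncio
from core.config import get_config                            # assumed location of get_config
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from training.enhanced_rl_trainer import EnhancedRLTrainer    # assumed module path

async def main():
    config = get_config()
    orchestrator = EnhancedTradingOrchestrator(config)         # assumed constructor signature
    trainer = EnhancedRLTrainer(config=config, orchestrator=orchestrator)
    await trainer.continuous_learning_loop()                   # trains each symbol's agent hourly

if __name__ == "__main__":
    asyncio.run(main())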