added leverage slider
@@ -1,477 +1,477 @@
#!/usr/bin/env python3
"""
Enhanced RL Training Launcher with Real Data Integration

This script launches the comprehensive RL training system that uses:
- Real-time tick data (300s window for momentum detection)
- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d)
- BTC reference data for correlation
- CNN hidden features and predictions
- Williams Market Structure pivot points
- Market microstructure analysis

The RL model will receive ~13,400 features instead of the previous ~100 basic features.
"""

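# NOTE: The comprehensive state vector itself is assembled by
# EnhancedRLStateBuilder (imported below); the ~13,400 figure presumably
# combines the tick window, multi-timeframe OHLCV, BTC reference, CNN,
# and pivot-point features, with the exact layout defined in that module.
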
import asyncio
import logging
import time
import signal
import sys
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('enhanced_rl_training.log'),
        logging.StreamHandler(sys.stdout)
    ]
)

logger = logging.getLogger(__name__)

# Import our enhanced components
from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from training.enhanced_rl_trainer import EnhancedRLTrainer
from training.enhanced_rl_state_builder import EnhancedRLStateBuilder
from training.williams_market_structure import WilliamsMarketStructure
from training.cnn_rl_bridge import CNNRLBridge


class EnhancedRLTrainingSystem:
    """Comprehensive RL training system with real data integration"""

    def __init__(self):
        """Initialize the enhanced RL training system"""
        self.config = get_config()
        self.running = False
        self.data_provider = None
        self.orchestrator = None
        self.rl_trainer = None

        # Performance tracking
        self.training_stats = {
            'training_sessions': 0,
            'total_experiences': 0,
            'avg_state_size': 0,
            'data_quality_score': 0.0,
            'last_training_time': None
        }

        logger.info("Enhanced RL Training System initialized")
        logger.info("Features:")
        logger.info("- Real-time tick data processing (300s window)")
        logger.info("- Multi-timeframe OHLCV analysis (1s, 1m, 1h, 1d)")
        logger.info("- BTC correlation analysis")
        logger.info("- CNN feature integration")
        logger.info("- Williams Market Structure pivot points")
        logger.info("- ~13,400 feature state vector (vs previous ~100)")

    async def initialize(self):
        """Initialize all components"""
        try:
            logger.info("Initializing enhanced RL training components...")

            # Initialize data provider with real-time streaming
            logger.info("Setting up data provider with real-time streaming...")
            self.data_provider = DataProvider(
                symbols=self.config.symbols,
                timeframes=self.config.timeframes
            )

            # Start real-time data streaming
            await self.data_provider.start_real_time_streaming()
            logger.info("Real-time data streaming started")

            # Wait for initial data collection
            logger.info("Collecting initial market data...")
            await asyncio.sleep(30)  # Allow 30 seconds for data collection

            # Initialize enhanced orchestrator
            logger.info("Initializing enhanced orchestrator...")
            self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)

            # Initialize enhanced RL trainer with comprehensive state building
            logger.info("Initializing enhanced RL trainer...")
            self.rl_trainer = EnhancedRLTrainer(
                config=self.config,
                orchestrator=self.orchestrator
            )

            # Verify data availability
            data_status = await self._verify_data_availability()
            if not data_status['has_sufficient_data']:
                logger.warning("Insufficient data detected. Continuing with limited training.")
                logger.warning(f"Data status: {data_status}")
            else:
                logger.info("Sufficient data available for comprehensive RL training")
                logger.info(f"Tick data: {data_status['tick_count']} ticks")
                logger.info(f"OHLCV data: {data_status['ohlcv_bars']} bars")

            self.running = True
            logger.info("Enhanced RL training system initialized successfully")

        except Exception as e:
            logger.error(f"Error during initialization: {e}")
            raise

    async def _verify_data_availability(self) -> Dict[str, Any]:
        """Verify that we have sufficient data for training"""
        try:
            data_status = {
                'has_sufficient_data': False,
                'tick_count': 0,
                'ohlcv_bars': 0,
                'symbols_with_data': [],
                'missing_data': []
            }

            for symbol in self.config.symbols:
                # Check tick data
                recent_ticks = self.data_provider.get_recent_ticks(symbol, count=100)
                tick_count = len(recent_ticks)

                # Check OHLCV data
                ohlcv_bars = 0
                for timeframe in ['1s', '1m', '1h', '1d']:
                    try:
                        df = self.data_provider.get_historical_data(
                            symbol=symbol,
                            timeframe=timeframe,
                            limit=50,
                            refresh=True
                        )
                        if df is not None and not df.empty:
                            ohlcv_bars += len(df)
                    except Exception as e:
                        logger.warning(f"Error checking {timeframe} data for {symbol}: {e}")

                data_status['tick_count'] += tick_count
                data_status['ohlcv_bars'] += ohlcv_bars

                if tick_count >= 50 and ohlcv_bars >= 100:
                    data_status['symbols_with_data'].append(symbol)
                else:
                    data_status['missing_data'].append(f"{symbol}: {tick_count} ticks, {ohlcv_bars} bars")

            # Consider data sufficient if we have at least one symbol with good data
            data_status['has_sufficient_data'] = len(data_status['symbols_with_data']) > 0

            return data_status

        except Exception as e:
            logger.error(f"Error verifying data availability: {e}")
            return {'has_sufficient_data': False, 'error': str(e)}

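    # Illustrative shape of the dict returned above (hypothetical values):
    #   {'has_sufficient_data': True, 'tick_count': 240, 'ohlcv_bars': 180,
    #    'symbols_with_data': ['ETH/USDT'], 'missing_data': []}
    # A symbol counts as "with data" once it has >= 50 ticks and >= 100 bars.
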
    async def run_training_loop(self):
        """Run the main training loop with real data"""
        logger.info("Starting enhanced RL training loop...")

        training_cycle = 0
        last_state_size_log = time.time()

        try:
            while self.running:
                training_cycle += 1
                cycle_start_time = time.time()

                logger.info(f"Training cycle {training_cycle} started")

                # Get comprehensive market states with real data
                market_states = await self._get_comprehensive_market_states()

                if not market_states:
                    logger.warning("No market states available. Waiting for data...")
                    await asyncio.sleep(60)
                    continue

                # Train RL agents with comprehensive states
                training_results = await self._train_rl_agents(market_states)

                # Update performance tracking
                self._update_training_stats(training_results, market_states)

                # Log training progress
                cycle_duration = time.time() - cycle_start_time
                logger.info(f"Training cycle {training_cycle} completed in {cycle_duration:.2f}s")

                # Log state size periodically
                if time.time() - last_state_size_log > 300:  # Every 5 minutes
                    self._log_state_size_info(market_states)
                    last_state_size_log = time.time()

                # Save models periodically
                if training_cycle % 10 == 0:
                    await self._save_training_progress()

                # Wait before next training cycle
                await asyncio.sleep(300)  # Train every 5 minutes

        except Exception as e:
            logger.error(f"Error in training loop: {e}")
            raise

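    # Cadence note: each cycle sleeps 300s, so with 10 cycles between saves
    # models are checkpointed roughly every 50 minutes (plus per-cycle work),
    # and state-size details are logged at most once every 5 minutes.
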
    async def _get_comprehensive_market_states(self) -> Dict[str, Any]:
        """Get comprehensive market states with all required data"""
        try:
            # Get market states from orchestrator
            universal_stream = self.orchestrator.universal_adapter.get_universal_stream()
            market_states = await self.orchestrator._get_all_market_states_universal(universal_stream)

            # Verify data quality
            quality_score = self._calculate_data_quality(market_states)
            self.training_stats['data_quality_score'] = quality_score

            if quality_score < 0.5:
                logger.warning(f"Low data quality detected: {quality_score:.2f}")

            return market_states

        except Exception as e:
            logger.error(f"Error getting comprehensive market states: {e}")
            return {}

    def _calculate_data_quality(self, market_states: Dict[str, Any]) -> float:
        """Calculate data quality score based on available data"""
        try:
            if not market_states:
                return 0.0

            total_score = 0.0
            total_symbols = len(market_states)

            for symbol, state in market_states.items():
                symbol_score = 0.0

                # Score based on tick data availability
                if hasattr(state, 'raw_ticks') and state.raw_ticks:
                    tick_score = min(len(state.raw_ticks) / 100, 1.0)  # Max score for 100+ ticks
                    symbol_score += tick_score * 0.3

                # Score based on OHLCV data availability
                if hasattr(state, 'ohlcv_data') and state.ohlcv_data:
                    ohlcv_score = len(state.ohlcv_data) / 4.0  # Max score for all 4 timeframes
                    symbol_score += min(ohlcv_score, 1.0) * 0.4

                # Score based on CNN features
                if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
                    symbol_score += 0.15

                # Score based on pivot points
                if hasattr(state, 'pivot_points') and state.pivot_points:
                    symbol_score += 0.15

                total_score += symbol_score

            return total_score / total_symbols if total_symbols > 0 else 0.0

        except Exception as e:
            logger.warning(f"Error calculating data quality: {e}")
            return 0.5  # Default to medium quality

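    # Worked example of the weighting above: a symbol with 100+ ticks (0.3),
    # all four OHLCV timeframes (0.4), CNN features (0.15), and pivot points
    # (0.15) scores the maximum 1.0; a symbol with only full OHLCV data
    # scores 0.4.
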
    async def _train_rl_agents(self, market_states: Dict[str, Any]) -> Dict[str, Any]:
        """Train RL agents with comprehensive market states"""
        try:
            training_results = {
                'symbols_trained': [],
                'total_experiences': 0,
                'avg_state_size': 0,
                'training_errors': []
            }

            for symbol, market_state in market_states.items():
                try:
                    # Convert market state to comprehensive RL state
                    rl_state = self.rl_trainer._market_state_to_rl_state(market_state)

                    if rl_state is not None and len(rl_state) > 0:
                        # Record state size
                        training_results['avg_state_size'] += len(rl_state)

                        # Simulate trading action for experience generation
                        # In real implementation, this would be actual trading decisions
                        action = self._simulate_trading_action(symbol, rl_state)

                        # Generate reward based on market outcome
                        reward = self._calculate_training_reward(symbol, market_state, action)

                        # Add experience to RL agent
                        agent = self.rl_trainer.agents.get(symbol)
                        if agent:
                            # Create next state (would be actual next market state in real scenario)
                            next_state = rl_state  # Simplified for now

                            agent.remember(
                                state=rl_state,
                                action=action,
                                reward=reward,
                                next_state=next_state,
                                done=False
                            )

                            # Train agent if enough experiences
                            if len(agent.replay_buffer) >= agent.batch_size:
                                loss = agent.replay()
                                if loss is not None:
                                    logger.debug(f"Agent {symbol} training loss: {loss:.4f}")

                        training_results['symbols_trained'].append(symbol)
                        training_results['total_experiences'] += 1

                except Exception as e:
                    error_msg = f"Error training {symbol}: {e}"
                    logger.warning(error_msg)
                    training_results['training_errors'].append(error_msg)

            # Calculate average state size
            if len(training_results['symbols_trained']) > 0:
                training_results['avg_state_size'] /= len(training_results['symbols_trained'])

            return training_results

        except Exception as e:
            logger.error(f"Error training RL agents: {e}")
            return {'error': str(e)}

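    # The experience tuple follows the standard DQN replay format
    # (state, action, reward, next_state, done). Because next_state is
    # currently the same as state, transitions are self-loops; one possible
    # refinement (not implemented here) is caching each symbol's previous
    # cycle state and emitting (prev_state, action, reward, current_state).
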
    def _simulate_trading_action(self, symbol: str, rl_state) -> int:
        """Simulate trading action for training (would be real decision in production)"""
        # Simple simulation based on state features
        if len(rl_state) > 100:
            # Use momentum features to decide action
            momentum_features = rl_state[:100]  # First 100 features assumed to be momentum
            avg_momentum = sum(momentum_features) / len(momentum_features)

            if avg_momentum > 0.6:
                return 1  # BUY
            elif avg_momentum < 0.4:
                return 2  # SELL
            else:
                return 0  # HOLD
        else:
            return 0  # HOLD as default

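    # Action encoding used throughout: 0 = HOLD, 1 = BUY, 2 = SELL.
    # Average momentum in [0.4, 0.6] is a dead zone that yields HOLD, and any
    # state of 100 features or fewer always yields HOLD.
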
    def _calculate_training_reward(self, symbol: str, market_state, action: int) -> float:
        """Calculate training reward based on market state and action"""
        try:
            # Simple reward calculation based on market conditions
            base_reward = 0.0

            # Reward based on volatility alignment
            if hasattr(market_state, 'volatility'):
                if action == 0 and market_state.volatility > 0.02:  # HOLD in high volatility
                    base_reward += 0.1
                elif action != 0 and market_state.volatility < 0.01:  # Trade in low volatility
                    base_reward += 0.1

            # Reward based on trend alignment
            if hasattr(market_state, 'trend_strength'):
                if action == 1 and market_state.trend_strength > 0.6:  # BUY in uptrend
                    base_reward += 0.2
                elif action == 2 and market_state.trend_strength < 0.4:  # SELL in downtrend
                    base_reward += 0.2

            return base_reward

        except Exception as e:
            logger.warning(f"Error calculating reward for {symbol}: {e}")
            return 0.0

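    # Worked example: a BUY (action=1) with trend_strength=0.7 (> 0.6) and
    # volatility=0.005 (< 0.01) earns 0.2 + 0.1 = 0.3; the same BUY in a
    # weak trend with high volatility earns 0.0.
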
    def _update_training_stats(self, training_results: Dict[str, Any], market_states: Dict[str, Any]):
        """Update training statistics"""
        self.training_stats['training_sessions'] += 1
        self.training_stats['total_experiences'] += training_results.get('total_experiences', 0)
        self.training_stats['avg_state_size'] = training_results.get('avg_state_size', 0)
        self.training_stats['last_training_time'] = datetime.now()

        # Log statistics periodically
        if self.training_stats['training_sessions'] % 10 == 0:
            logger.info("Training Statistics:")
            logger.info(f"  Sessions: {self.training_stats['training_sessions']}")
            logger.info(f"  Total Experiences: {self.training_stats['total_experiences']}")
            logger.info(f"  Avg State Size: {self.training_stats['avg_state_size']:.0f}")
            logger.info(f"  Data Quality: {self.training_stats['data_quality_score']:.2f}")

    def _log_state_size_info(self, market_states: Dict[str, Any]):
        """Log information about state sizes for debugging"""
        for symbol, state in market_states.items():
            info = []

            if hasattr(state, 'raw_ticks'):
                info.append(f"ticks: {len(state.raw_ticks)}")

            if hasattr(state, 'ohlcv_data'):
                total_bars = sum(len(bars) for bars in state.ohlcv_data.values())
                info.append(f"OHLCV bars: {total_bars}")

            if hasattr(state, 'cnn_hidden_features') and state.cnn_hidden_features:
                info.append("CNN features: available")

            if hasattr(state, 'pivot_points') and state.pivot_points:
                info.append("pivot points: available")

            logger.info(f"{symbol} state data: {', '.join(info)}")

    async def _save_training_progress(self):
        """Save training progress and models"""
        try:
            if self.rl_trainer:
                self.rl_trainer._save_all_models()
                logger.info("Training progress saved")
        except Exception as e:
            logger.error(f"Error saving training progress: {e}")

    async def shutdown(self):
        """Graceful shutdown"""
        logger.info("Shutting down enhanced RL training system...")
        self.running = False

        # Save final state
        await self._save_training_progress()

        # Stop data provider
        if self.data_provider:
            await self.data_provider.stop_real_time_streaming()

        logger.info("Enhanced RL training system shutdown complete")


async def main():
    """Main function to run enhanced RL training"""
    system = None

    def signal_handler(signum, frame):
        logger.info("Received shutdown signal")
        if system:
            asyncio.create_task(system.shutdown())

    # Set up signal handlers
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        # Create and initialize the training system
        system = EnhancedRLTrainingSystem()
        await system.initialize()

        logger.info("Enhanced RL Training System is now running...")
        logger.info("The RL model now receives ~13,400 features instead of ~100!")
        logger.info("Press Ctrl+C to stop")

        # Run the training loop
        await system.run_training_loop()

    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
    except Exception as e:
        logger.error(f"Error in main training loop: {e}")
        raise
    finally:
        if system:
            await system.shutdown()


if __name__ == "__main__":
    asyncio.run(main())