wip on the RL training pipeline and data collection
web/dashboard.py | 602
@@ -41,6 +41,33 @@ from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator, TradingDecision
from core.trading_executor import TradingExecutor

# Enhanced RL Training Integration
try:
    from core.unified_data_stream import UnifiedDataStream, TrainingDataPacket, UIDataPacket
    from core.enhanced_orchestrator import EnhancedTradingOrchestrator, MarketState, TradingAction
    from training.enhanced_rl_trainer import EnhancedRLTrainer
    ENHANCED_RL_AVAILABLE = True
    logger = logging.getLogger(__name__)
    logger.info("Enhanced RL training components available")
except ImportError as e:
    ENHANCED_RL_AVAILABLE = False
    logger = logging.getLogger(__name__)
    logger.warning(f"Enhanced RL training not available: {e}")
    # Fallback classes
    class UnifiedDataStream:
        def __init__(self, *args, **kwargs): pass
        def register_consumer(self, *args, **kwargs): return "fallback_consumer"
        def start_streaming(self): pass
        def stop_streaming(self): pass
        def get_latest_training_data(self): return None
        def get_latest_ui_data(self): return None

    class TrainingDataPacket:
        def __init__(self, *args, **kwargs): pass

    class UIDataPacket:
        def __init__(self, *args, **kwargs): pass

# Try to import model registry, fallback if not available
try:
    from models import get_model_registry
@@ -73,16 +100,40 @@ except ImportError:
    logger = logging.getLogger(__name__)

class TradingDashboard:
    """Modern trading dashboard with real-time updates"""
    """Modern trading dashboard with real-time updates and enhanced RL training integration"""

    def __init__(self, data_provider: DataProvider = None, orchestrator: TradingOrchestrator = None, trading_executor: TradingExecutor = None):
        """Initialize the dashboard"""
        """Initialize the dashboard with unified data stream and enhanced RL training"""
        self.config = get_config()
        self.data_provider = data_provider or DataProvider()
        self.orchestrator = orchestrator or TradingOrchestrator(self.data_provider)

        # Enhanced orchestrator support
        if ENHANCED_RL_AVAILABLE and isinstance(orchestrator, EnhancedTradingOrchestrator):
            self.orchestrator = orchestrator
            self.enhanced_rl_enabled = True
            logger.info("Enhanced RL training orchestrator detected")
        else:
            self.orchestrator = orchestrator or TradingOrchestrator(self.data_provider)
            self.enhanced_rl_enabled = False
            logger.info("Using standard orchestrator")

        self.trading_executor = trading_executor or TradingExecutor()
        self.model_registry = get_model_registry()

        # Initialize unified data stream for comprehensive training data
        if ENHANCED_RL_AVAILABLE:
            self.unified_stream = UnifiedDataStream(self.data_provider, self.orchestrator)
            self.stream_consumer_id = self.unified_stream.register_consumer(
                consumer_name="TradingDashboard",
                callback=self._handle_unified_stream_data,
                data_types=['ticks', 'ohlcv', 'training_data', 'ui_data']
            )
            logger.info(f"Unified data stream initialized with consumer ID: {self.stream_consumer_id}")
        else:
            self.unified_stream = UnifiedDataStream()  # Fallback
            self.stream_consumer_id = "fallback"
            logger.warning("Using fallback unified data stream")

        # Dashboard state
        self.recent_decisions = []
        self.recent_signals = []  # Track all signals (not just executed trades)
@@ -126,21 +177,29 @@ class TradingDashboard:
        self.ws_thread = None
        self.is_streaming = False

        # Load available models for real trading
        self._load_available_models()

        # RL Training System - Train on closed trades
        # Enhanced RL Training System - Train on closed trades with comprehensive data
        self.rl_training_enabled = True
        self.enhanced_rl_training_enabled = ENHANCED_RL_AVAILABLE and self.enhanced_rl_enabled
        self.rl_training_stats = {
            'total_training_episodes': 0,
            'profitable_trades_trained': 0,
            'unprofitable_trades_trained': 0,
            'last_training_time': None,
            'training_rewards': deque(maxlen=100),  # Last 100 training rewards
            'model_accuracy_trend': deque(maxlen=50)  # Track accuracy over time
            'model_accuracy_trend': deque(maxlen=50),  # Track accuracy over time
            'enhanced_rl_episodes': 0,
            'comprehensive_data_packets': 0
        }
        self.rl_training_queue = deque(maxlen=1000)  # Queue of trades to train on

        # Enhanced training data tracking
        self.latest_training_data = None
        self.latest_ui_data = None
        self.training_data_available = False

        # Load available models for real trading
        self._load_available_models()

        # Create Dash app
        self.app = dash.Dash(__name__, external_stylesheets=[
            'https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css',
@@ -151,13 +210,244 @@ class TradingDashboard:
        self._setup_layout()
        self._setup_callbacks()

        # Start WebSocket tick streaming
        self._start_websocket_stream()
        # Start unified data streaming
        self._initialize_streaming()

        # Start continuous training
        # Start continuous training with enhanced RL support
        self.start_continuous_training()

        logger.info("Trading Dashboard initialized with continuous training")
        logger.info("Trading Dashboard initialized with enhanced RL training integration")
        logger.info(f"Enhanced RL enabled: {self.enhanced_rl_training_enabled}")
        logger.info(f"Stream consumer ID: {self.stream_consumer_id}")

    def _initialize_streaming(self):
        """Initialize unified data streaming and WebSocket fallback"""
        try:
            if ENHANCED_RL_AVAILABLE:
                # Start unified data stream
                asyncio.run(self.unified_stream.start_streaming())
                logger.info("Unified data stream started")

            # Start WebSocket as backup/additional data source
            self._start_websocket_stream()

            # Start background data collection
            self._start_enhanced_training_data_collection()

            logger.info("All data streaming initialized")

        except Exception as e:
            logger.error(f"Error initializing streaming: {e}")
            # Fallback to WebSocket only
            self._start_websocket_stream()

    def _start_enhanced_training_data_collection(self):
        """Start enhanced training data collection using unified stream"""
        def enhanced_training_loop():
            try:
                logger.info("Enhanced training data collection started with unified stream")

                while True:
                    try:
                        if ENHANCED_RL_AVAILABLE and self.enhanced_rl_training_enabled:
                            # Get latest comprehensive training data from unified stream
                            training_data = self.unified_stream.get_latest_training_data()

                            if training_data:
                                # Send comprehensive training data to enhanced RL pipeline
                                self._send_comprehensive_training_data_to_enhanced_rl(training_data)

                                # Update training statistics
                                self.rl_training_stats['comprehensive_data_packets'] += 1
                                self.training_data_available = True

                            # Update context data in orchestrator
                            if hasattr(self.orchestrator, 'update_context_data'):
                                self.orchestrator.update_context_data()

                            # Initialize extrema trainer if not done
                            if hasattr(self.orchestrator, 'extrema_trainer'):
                                if not hasattr(self.orchestrator.extrema_trainer, '_initialized'):
                                    self.orchestrator.extrema_trainer.initialize_context_data()
                                    self.orchestrator.extrema_trainer._initialized = True
                                    logger.info("Extrema trainer context data initialized")

                            # Run extrema detection with real data
                            if hasattr(self.orchestrator, 'extrema_trainer'):
                                for symbol in self.orchestrator.symbols:
                                    detected = self.orchestrator.extrema_trainer.detect_local_extrema(symbol)
                                    if detected:
                                        logger.debug(f"Detected {len(detected)} extrema for {symbol}")
                        else:
                            # Fallback to basic training data collection
                            self._collect_basic_training_data()

                        time.sleep(10)  # Update every 10 seconds for enhanced training

                    except Exception as e:
                        logger.error(f"Error in enhanced training loop: {e}")
                        time.sleep(30)  # Wait before retrying

            except Exception as e:
                logger.error(f"Enhanced training loop failed: {e}")

        # Start enhanced training thread
        training_thread = Thread(target=enhanced_training_loop, daemon=True)
        training_thread.start()
        logger.info("Enhanced training data collection thread started")

    def _handle_unified_stream_data(self, data_packet: Dict[str, Any]):
        """Handle data from unified stream for dashboard and training"""
        try:
            # Extract UI data for dashboard display
            if 'ui_data' in data_packet:
                self.latest_ui_data = data_packet['ui_data']
                if hasattr(self.latest_ui_data, 'current_prices'):
                    self.current_prices.update(self.latest_ui_data.current_prices)
                if hasattr(self.latest_ui_data, 'streaming_status'):
                    self.is_streaming = self.latest_ui_data.streaming_status == 'LIVE'
                if hasattr(self.latest_ui_data, 'training_data_available'):
                    self.training_data_available = self.latest_ui_data.training_data_available

            # Extract training data for enhanced RL
            if 'training_data' in data_packet:
                self.latest_training_data = data_packet['training_data']
                logger.debug("Received comprehensive training data from unified stream")

            # Extract tick data for dashboard charts
            if 'ticks' in data_packet:
                ticks = data_packet['ticks']
                for tick in ticks[-100:]:  # Keep last 100 ticks
                    self.tick_cache.append(tick)

            # Extract OHLCV data for dashboard charts
            if 'one_second_bars' in data_packet:
                bars = data_packet['one_second_bars']
                for bar in bars[-100:]:  # Keep last 100 bars
                    self.one_second_bars.append(bar)

            logger.debug(f"Processed unified stream data packet with keys: {list(data_packet.keys())}")

        except Exception as e:
            logger.error(f"Error handling unified stream data: {e}")

    def _send_comprehensive_training_data_to_enhanced_rl(self, training_data: TrainingDataPacket):
        """Send comprehensive training data to enhanced RL training pipeline"""
        try:
            if not self.enhanced_rl_training_enabled:
                logger.debug("Enhanced RL training not enabled, skipping comprehensive data send")
                return

            # Extract comprehensive training data components
            market_state = training_data.market_state if hasattr(training_data, 'market_state') else None
            universal_stream = training_data.universal_stream if hasattr(training_data, 'universal_stream') else None
            cnn_features = training_data.cnn_features if hasattr(training_data, 'cnn_features') else None
            cnn_predictions = training_data.cnn_predictions if hasattr(training_data, 'cnn_predictions') else None

            if market_state and universal_stream:
                # Send to enhanced RL trainer if available
                if hasattr(self.orchestrator, 'enhanced_rl_trainer'):
                    try:
                        # Create comprehensive training step with ~13,400 features
                        asyncio.run(self.orchestrator.enhanced_rl_trainer.training_step(universal_stream))
                        self.rl_training_stats['enhanced_rl_episodes'] += 1
                        logger.debug("Sent comprehensive data to enhanced RL trainer")
                    except Exception as e:
                        logger.warning(f"Error in enhanced RL training step: {e}")

                # Send to extrema trainer for CNN training with perfect moves
                if hasattr(self.orchestrator, 'extrema_trainer'):
                    try:
                        extrema_data = self.orchestrator.extrema_trainer.get_extrema_training_data(count=50)
                        perfect_moves = self.orchestrator.extrema_trainer.get_perfect_moves_for_cnn(count=100)

                        if extrema_data:
                            logger.debug(f"Enhanced RL: {len(extrema_data)} extrema training samples available")

                        if perfect_moves:
                            logger.debug(f"Enhanced RL: {len(perfect_moves)} perfect moves for CNN training")
                    except Exception as e:
                        logger.warning(f"Error getting extrema training data: {e}")

                # Send to sensitivity learning DQN for outcome-based learning
                if hasattr(self.orchestrator, 'sensitivity_learning_queue'):
                    try:
                        if len(self.orchestrator.sensitivity_learning_queue) > 0:
                            logger.debug("Enhanced RL: Sensitivity learning data available for DQN training")
                    except Exception as e:
                        logger.warning(f"Error accessing sensitivity learning queue: {e}")

                # Get context features for models with real market data
                if hasattr(self.orchestrator, 'extrema_trainer'):
                    try:
                        for symbol in self.orchestrator.symbols:
                            context_features = self.orchestrator.extrema_trainer.get_context_features_for_model(symbol)
                            if context_features is not None:
                                logger.debug(f"Enhanced RL: Context features available for {symbol}: {context_features.shape}")
                    except Exception as e:
                        logger.warning(f"Error getting context features: {e}")

            # Log comprehensive training data statistics
            tick_count = len(training_data.tick_cache) if hasattr(training_data, 'tick_cache') else 0
            bars_count = len(training_data.one_second_bars) if hasattr(training_data, 'one_second_bars') else 0
            timeframe_count = len(training_data.multi_timeframe_data) if hasattr(training_data, 'multi_timeframe_data') else 0

            logger.info(f"Enhanced RL Comprehensive Training Data:")
            logger.info(f"  Tick cache: {tick_count} ticks")
            logger.info(f"  1s bars: {bars_count} bars")
            logger.info(f"  Multi-timeframe data: {timeframe_count} symbols")
            logger.info(f"  CNN features: {'Available' if cnn_features else 'Not available'}")
            logger.info(f"  CNN predictions: {'Available' if cnn_predictions else 'Not available'}")
            logger.info(f"  Market state: {'Available (~13,400 features)' if market_state else 'Not available'}")
            logger.info(f"  Universal stream: {'Available' if universal_stream else 'Not available'}")

        except Exception as e:
            logger.error(f"Error sending comprehensive training data to enhanced RL: {e}")

    def _collect_basic_training_data(self):
        """Fallback method to collect basic training data when enhanced RL is not available"""
        try:
            # Get real tick data from data provider subscribers
            for symbol in ['ETH/USDT', 'BTC/USDT']:
                try:
                    # Get recent ticks from data provider
                    if hasattr(self.data_provider, 'get_recent_ticks'):
                        recent_ticks = self.data_provider.get_recent_ticks(symbol, count=10)

                        for tick in recent_ticks:
                            # Create tick data from real market data
                            tick_data = {
                                'symbol': tick.symbol,
                                'price': tick.price,
                                'timestamp': tick.timestamp,
                                'volume': tick.volume
                            }

                            # Add to tick cache
                            self.tick_cache.append(tick_data)

                            # Create 1s bar data from real tick
                            bar_data = {
                                'symbol': tick.symbol,
                                'open': tick.price,
                                'high': tick.price,
                                'low': tick.price,
                                'close': tick.price,
                                'volume': tick.volume,
                                'timestamp': tick.timestamp
                            }

                            # Add to 1s bars cache
                            self.one_second_bars.append(bar_data)

                except Exception as e:
                    logger.debug(f"No recent tick data available for {symbol}: {e}")

            # Set streaming status based on real data availability
            self.is_streaming = len(self.tick_cache) > 0

        except Exception as e:
            logger.warning(f"Error in basic training data collection: {e}")

    def _get_initial_balance(self) -> float:
        """Get initial USDT balance from MEXC or return default"""
@@ -2240,12 +2530,12 @@ class TradingDashboard:
                logger.warning(f"RL prediction error: {e}")
                return np.array([0.33, 0.34, 0.33]), 0.5

            def get_memory_usage(self):
                return 80  # MB estimate

            def to_device(self, device):
                self.device = device
                return self
            def get_memory_usage(self):
                return 80  # MB estimate

            def to_device(self, device):
                self.device = device
                return self

        rl_wrapper = RLWrapper(rl_path)
@@ -2511,19 +2801,20 @@ class TradingDashboard:
            return pd.DataFrame()

    def _create_training_metrics(self) -> List:
        """Create comprehensive model training metrics display"""
        """Create comprehensive model training metrics display with enhanced RL integration"""
        try:
            training_items = []

            # Training Data Streaming Status
            # Enhanced Training Data Streaming Status
            tick_cache_size = len(self.tick_cache)
            bars_cache_size = len(self.one_second_bars)
            enhanced_data_available = self.training_data_available and self.enhanced_rl_training_enabled

            training_items.append(
                html.Div([
                    html.H6([
                        html.I(className="fas fa-database me-2 text-info"),
                        "Training Data Stream"
                        "Enhanced Training Data Stream"
                    ], className="mb-2"),
                    html.Div([
                        html.Small([
@@ -2538,11 +2829,58 @@ class TradingDashboard:
                            html.Strong("Stream: "),
                            html.Span("LIVE" if self.is_streaming else "OFFLINE",
                                     className="text-success" if self.is_streaming else "text-danger")
                        ], className="d-block"),
                        html.Small([
                            html.Strong("Enhanced RL: "),
                            html.Span("ENABLED" if self.enhanced_rl_training_enabled else "DISABLED",
                                     className="text-success" if self.enhanced_rl_training_enabled else "text-warning")
                        ], className="d-block"),
                        html.Small([
                            html.Strong("Comprehensive Data: "),
                            html.Span("AVAILABLE" if enhanced_data_available else "WAITING",
                                     className="text-success" if enhanced_data_available else "text-warning")
                        ], className="d-block")
                    ])
                ], className="mb-3 p-2 border border-info rounded")
            )

            # Enhanced RL Training Statistics
            if self.enhanced_rl_training_enabled:
                enhanced_episodes = self.rl_training_stats.get('enhanced_rl_episodes', 0)
                comprehensive_packets = self.rl_training_stats.get('comprehensive_data_packets', 0)

                training_items.append(
                    html.Div([
                        html.H6([
                            html.I(className="fas fa-brain me-2 text-success"),
                            "Enhanced RL Training"
                        ], className="mb-2"),
                        html.Div([
                            html.Small([
                                html.Strong("Status: "),
                                html.Span("ACTIVE" if enhanced_episodes > 0 else "WAITING",
                                         className="text-success" if enhanced_episodes > 0 else "text-warning")
                            ], className="d-block"),
                            html.Small([
                                html.Strong("Episodes: "),
                                html.Span(f"{enhanced_episodes}", className="text-info")
                            ], className="d-block"),
                            html.Small([
                                html.Strong("Data Packets: "),
                                html.Span(f"{comprehensive_packets}", className="text-info")
                            ], className="d-block"),
                            html.Small([
                                html.Strong("Features: "),
                                html.Span("~13,400 (Market State)", className="text-success")
                            ], className="d-block"),
                            html.Small([
                                html.Strong("Training Mode: "),
                                html.Span("Comprehensive", className="text-success")
                            ], className="d-block")
                        ])
                    ], className="mb-3 p-2 border border-success rounded")
                )

            # Model Training Status
            try:
                # Try to get real training metrics from orchestrator
@@ -2553,7 +2891,7 @@ class TradingDashboard:
                    html.Div([
                        html.H6([
                            html.I(className="fas fa-brain me-2 text-warning"),
                            "CNN Model"
                            "CNN Model (Extrema Detection)"
                        ], className="mb-2"),
                        html.Div([
                            html.Small([
@@ -2570,59 +2908,58 @@ class TradingDashboard:
                                html.Span(f"{training_status['cnn']['loss']:.4f}", className="text-muted")
                            ], className="d-block"),
                            html.Small([
                                html.Strong("Epochs: "),
                                html.Span(f"{training_status['cnn']['epochs']}", className="text-muted")
                            ], className="d-block"),
                            html.Small([
                                html.Strong("Learning Rate: "),
                                html.Span(f"{training_status['cnn']['learning_rate']:.6f}", className="text-muted")
                                html.Strong("Perfect Moves: "),
                                html.Span("Available" if hasattr(self.orchestrator, 'extrema_trainer') else "N/A",
                                         className="text-success" if hasattr(self.orchestrator, 'extrema_trainer') else "text-muted")
                            ], className="d-block")
                        ])
                    ], className="mb-3 p-2 border border-warning rounded")
                )

                # RL Training Metrics
                # RL Training Metrics (Enhanced)
                total_episodes = self.rl_training_stats.get('total_training_episodes', 0)
                profitable_trades = self.rl_training_stats.get('profitable_trades_trained', 0)
                win_rate = (profitable_trades / total_episodes * 100) if total_episodes > 0 else 0

                training_items.append(
                    html.Div([
                        html.H6([
                            html.I(className="fas fa-robot me-2 text-success"),
                            "RL Agent (DQN)"
                            html.I(className="fas fa-robot me-2 text-primary"),
                            "RL Agent (DQN + Sensitivity Learning)"
                        ], className="mb-2"),
                        html.Div([
                            html.Small([
                                html.Strong("Status: "),
                                html.Span(training_status['rl']['status'],
                                         className=f"text-{training_status['rl']['status_color']}")
                                html.Span("ENHANCED" if self.enhanced_rl_training_enabled else "BASIC",
                                         className="text-success" if self.enhanced_rl_training_enabled else "text-warning")
                            ], className="d-block"),
                            html.Small([
                                html.Strong("Win Rate: "),
                                html.Span(f"{training_status['rl']['win_rate']:.1%}", className="text-info")
                                html.Span(f"{win_rate:.1f}%", className="text-info")
                            ], className="d-block"),
                            html.Small([
                                html.Strong("Avg Reward: "),
                                html.Span(f"{training_status['rl']['avg_reward']:.2f}", className="text-muted")
                                html.Strong("Total Episodes: "),
                                html.Span(f"{total_episodes}", className="text-muted")
                            ], className="d-block"),
                            html.Small([
                                html.Strong("Episodes: "),
                                html.Span(f"{training_status['rl']['episodes']}", className="text-muted")
                                html.Strong("Enhanced Episodes: "),
                                html.Span(f"{enhanced_episodes}" if self.enhanced_rl_training_enabled else "N/A",
                                         className="text-success" if self.enhanced_rl_training_enabled else "text-muted")
                            ], className="d-block"),
                            html.Small([
                                html.Strong("Epsilon: "),
                                html.Span(f"{training_status['rl']['epsilon']:.3f}", className="text-muted")
                            ], className="d-block"),
                            html.Small([
                                html.Strong("Memory: "),
                                html.Span(f"{training_status['rl']['memory_size']:,}", className="text-muted")
                                html.Strong("Sensitivity Learning: "),
                                html.Span("ACTIVE" if hasattr(self.orchestrator, 'sensitivity_learning_queue') else "N/A",
                                         className="text-success" if hasattr(self.orchestrator, 'sensitivity_learning_queue') else "text-muted")
                            ], className="d-block")
                        ])
                    ], className="mb-3 p-2 border border-success rounded")
                    ], className="mb-3 p-2 border border-primary rounded")
                )

                # Training Progress Chart (Mini)
                training_items.append(
                    html.Div([
                        html.H6([
                            html.I(className="fas fa-chart-line me-2 text-primary"),
                            html.I(className="fas fa-chart-line me-2 text-secondary"),
                            "Training Progress"
                        ], className="mb-2"),
                        dcc.Graph(
@@ -2630,7 +2967,7 @@ class TradingDashboard:
                            style={"height": "150px"},
                            config={'displayModeBar': False}
                        )
                    ], className="mb-3 p-2 border border-primary rounded")
                    ], className="mb-3 p-2 border border-secondary rounded")
                )

            except Exception as e:
@@ -3365,7 +3702,7 @@ class TradingDashboard:
            logger.error(f"Error stopping continuous training: {e}")

    def _trigger_rl_training_on_closed_trade(self, closed_trade):
        """Trigger RL training based on a closed trade's profitability"""
        """Trigger enhanced RL training based on a closed trade's profitability with comprehensive data"""
        try:
            if not self.rl_training_enabled:
                return
@@ -3375,7 +3712,7 @@ class TradingDashboard:
            is_profitable = net_pnl > 0
            trade_duration = closed_trade.get('duration', timedelta(0))

            # Create training episode data
            # Create enhanced training episode data
            training_episode = {
                'trade_id': closed_trade.get('trade_id'),
                'side': closed_trade.get('side'),
@@ -3386,7 +3723,8 @@ class TradingDashboard:
                'duration_seconds': trade_duration.total_seconds(),
                'symbol': closed_trade.get('symbol', 'ETH/USDT'),
                'timestamp': closed_trade.get('exit_time', datetime.now()),
                'reward': self._calculate_rl_reward(closed_trade)
                'reward': self._calculate_rl_reward(closed_trade),
                'enhanced_data_available': self.enhanced_rl_training_enabled
            }

            # Add to training queue
@@ -3402,16 +3740,126 @@ class TradingDashboard:
            self.rl_training_stats['last_training_time'] = datetime.now()
            self.rl_training_stats['training_rewards'].append(training_episode['reward'])

            # Trigger actual RL model training
            self._execute_rl_training_step(training_episode)
            # Enhanced RL training with comprehensive data
            if self.enhanced_rl_training_enabled:
                self._execute_enhanced_rl_training_step(training_episode)
            else:
                # Fallback to basic RL training
                self._execute_rl_training_step(training_episode)

            logger.info(f"[RL_TRAINING] Trade #{training_episode['trade_id']} added to training: "
            logger.info(f"[RL_TRAINING] Trade #{training_episode['trade_id']} added to {'ENHANCED' if self.enhanced_rl_training_enabled else 'BASIC'} training: "
                        f"{'PROFITABLE' if is_profitable else 'LOSS'} "
                        f"PnL: ${net_pnl:.2f}, Reward: {training_episode['reward']:.3f}")

        except Exception as e:
            logger.error(f"Error in RL training trigger: {e}")

    def _execute_enhanced_rl_training_step(self, training_episode):
        """Execute enhanced RL training step with comprehensive market data"""
        try:
            # Get comprehensive training data from unified stream
            training_data = self.unified_stream.get_latest_training_data() if ENHANCED_RL_AVAILABLE else None

            if training_data and hasattr(training_data, 'market_state') and training_data.market_state:
                # Enhanced RL training with ~13,400 features
                market_state = training_data.market_state
                universal_stream = training_data.universal_stream

                # Create comprehensive training context
                enhanced_context = {
                    'trade_outcome': training_episode,
                    'market_state': market_state,
                    'universal_stream': universal_stream,
                    'tick_cache': training_data.tick_cache if hasattr(training_data, 'tick_cache') else [],
                    'multi_timeframe_data': training_data.multi_timeframe_data if hasattr(training_data, 'multi_timeframe_data') else {},
                    'cnn_features': training_data.cnn_features if hasattr(training_data, 'cnn_features') else None,
                    'cnn_predictions': training_data.cnn_predictions if hasattr(training_data, 'cnn_predictions') else None
                }

                # Send to enhanced RL trainer
                if hasattr(self.orchestrator, 'enhanced_rl_trainer'):
                    try:
                        # Add trading experience with comprehensive context
                        symbol = training_episode['symbol']
                        action = TradingAction(
                            action=training_episode['side'],
                            symbol=symbol,
                            confidence=0.8,  # Inferred from executed trade
                            price=training_episode['exit_price'],
                            size=0.1,  # Default size
                            timestamp=training_episode['timestamp']
                        )

                        # Create initial and final market states for RL learning
                        initial_state = market_state  # State at trade entry
                        final_state = market_state  # State at trade exit (simplified)
                        reward = training_episode['reward']

                        # Add comprehensive trading experience
                        self.orchestrator.enhanced_rl_trainer.add_trading_experience(
                            symbol=symbol,
                            action=action,
                            initial_state=initial_state,
                            final_state=final_state,
                            reward=reward
                        )

                        logger.info(f"[ENHANCED_RL] Added comprehensive trading experience for trade #{training_episode['trade_id']}")
                        logger.info(f"[ENHANCED_RL] Market state features: ~13,400, Reward: {reward:.3f}")

                        # Update enhanced RL statistics
                        self.rl_training_stats['enhanced_rl_episodes'] += 1

                        return True

                    except Exception as e:
                        logger.error(f"Error in enhanced RL trainer: {e}")
                        return False

                # Send to extrema trainer for CNN learning
                if hasattr(self.orchestrator, 'extrema_trainer'):
                    try:
                        # Mark this trade outcome for CNN training
                        trade_context = {
                            'symbol': training_episode['symbol'],
                            'entry_price': training_episode['entry_price'],
                            'exit_price': training_episode['exit_price'],
                            'is_profitable': training_episode['is_profitable'],
                            'timestamp': training_episode['timestamp']
                        }

                        # Add to extrema training if this was a good/bad move
                        if abs(training_episode['net_pnl']) > 0.5:  # Significant move
                            self.orchestrator.extrema_trainer.add_trade_outcome_for_learning(trade_context)
                            logger.debug(f"[EXTREMA_CNN] Added trade outcome for CNN learning")

                    except Exception as e:
                        logger.warning(f"Error adding to extrema trainer: {e}")

                # Send to sensitivity learning DQN
                if hasattr(self.orchestrator, 'sensitivity_learning_queue'):
                    try:
                        sensitivity_data = {
                            'trade_outcome': training_episode,
                            'market_context': enhanced_context,
                            'learning_priority': 'high' if abs(training_episode['net_pnl']) > 1.0 else 'normal'
                        }

                        self.orchestrator.sensitivity_learning_queue.append(sensitivity_data)
                        logger.debug(f"[SENSITIVITY_DQN] Added trade outcome for sensitivity learning")

                    except Exception as e:
                        logger.warning(f"Error adding to sensitivity learning: {e}")

                return True
            else:
                logger.warning(f"[ENHANCED_RL] No comprehensive training data available, falling back to basic training")
                return self._execute_rl_training_step(training_episode)

        except Exception as e:
            logger.error(f"Error executing enhanced RL training step: {e}")
            return False

    def _calculate_rl_reward(self, closed_trade):
        """Calculate reward for RL training based on trade performance"""
        try:
@@ -3658,6 +4106,54 @@ class TradingDashboard:
        """Get current RL training statistics"""
        return self.rl_training_stats.copy()

    def stop_streaming(self):
        """Stop all streaming and training components"""
        try:
            logger.info("Stopping dashboard streaming and training components...")

            # Stop unified data stream
            if ENHANCED_RL_AVAILABLE and hasattr(self, 'unified_stream'):
                try:
                    asyncio.run(self.unified_stream.stop_streaming())
                    if hasattr(self, 'stream_consumer_id'):
                        self.unified_stream.unregister_consumer(self.stream_consumer_id)
                    logger.info("Unified data stream stopped")
                except Exception as e:
                    logger.warning(f"Error stopping unified stream: {e}")

            # Stop WebSocket streaming
            self.is_streaming = False
            if self.ws_connection:
                try:
                    self.ws_connection.close()
                    logger.info("WebSocket connection closed")
                except Exception as e:
                    logger.warning(f"Error closing WebSocket: {e}")

            if self.ws_thread and self.ws_thread.is_alive():
                try:
                    self.ws_thread.join(timeout=5)
                    logger.info("WebSocket thread stopped")
                except Exception as e:
                    logger.warning(f"Error stopping WebSocket thread: {e}")

            # Stop continuous training
            self.stop_continuous_training()

            # Stop enhanced RL training if available
            if self.enhanced_rl_training_enabled and hasattr(self.orchestrator, 'enhanced_rl_trainer'):
                try:
                    if hasattr(self.orchestrator.enhanced_rl_trainer, 'stop_training'):
                        asyncio.run(self.orchestrator.enhanced_rl_trainer.stop_training())
                    logger.info("Enhanced RL training stopped")
                except Exception as e:
                    logger.warning(f"Error stopping enhanced RL training: {e}")

            logger.info("All streaming and training components stopped")

        except Exception as e:
            logger.error(f"Error stopping streaming: {e}")


def create_dashboard(data_provider: DataProvider = None, orchestrator: TradingOrchestrator = None, trading_executor: TradingExecutor = None) -> TradingDashboard:
    """Factory function to create a trading dashboard"""
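For context, a minimal usage sketch of the new wiring. Only create_dashboard, the orchestrator fallback behavior, and stop_streaming are taken from the diff above; the EnhancedTradingOrchestrator constructor call and the Dash run_server entry point are assumptions, not part of this commit.

# Hypothetical wiring sketch, not part of this commit.
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator  # optional dependency
from web.dashboard import create_dashboard

data_provider = DataProvider()

# Passing an EnhancedTradingOrchestrator enables the enhanced RL path
# (unified data stream + comprehensive training data); any other orchestrator,
# or a failed import, makes the dashboard fall back to the basic pipeline.
orchestrator = EnhancedTradingOrchestrator(data_provider)  # assumed constructor signature

dashboard = create_dashboard(data_provider=data_provider, orchestrator=orchestrator)
try:
    dashboard.app.run_server(host="127.0.0.1", port=8050)  # standard Dash entry point (pre-Dash 3)
finally:
    dashboard.stop_streaming()  # stops unified stream, WebSocket, and training threads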