Compare commits
5 Commits
97ea27ea84
...
296e1be422
Author | SHA1 | Date | |
---|---|---|---|
296e1be422 | |||
4c53871014 | |||
fab25ffe6f | |||
601e44de25 | |||
d791ab8b14 |
@ -130,7 +130,7 @@ class DQNAgent:
|
||||
result = load_best_checkpoint(self.model_name)
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
checkpoint = torch.load(file_path, map_location=self.device)
|
||||
checkpoint = torch.load(file_path, map_location=self.device, weights_only=False)
|
||||
|
||||
# Load model states
|
||||
if 'policy_net_state_dict' in checkpoint:
|
||||
@ -1212,7 +1212,7 @@ class DQNAgent:
|
||||
|
||||
# Load agent state
|
||||
try:
|
||||
agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device)
|
||||
agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device, weights_only=False)
|
||||
self.epsilon = agent_state['epsilon']
|
||||
self.update_count = agent_state['update_count']
|
||||
self.losses = agent_state['losses']
|
||||
|
@ -224,5 +224,49 @@
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
}
|
||||
],
|
||||
"dqn_agent": [
|
||||
{
|
||||
"checkpoint_id": "dqn_agent_20250627_030115",
|
||||
"model_name": "dqn_agent",
|
||||
"model_type": "dqn",
|
||||
"file_path": "models\\saved\\dqn_agent\\dqn_agent_20250627_030115.pt",
|
||||
"created_at": "2025-06-27T03:01:15.021842",
|
||||
"file_size_mb": 57.57266807556152,
|
||||
"performance_score": 95.0,
|
||||
"accuracy": 0.85,
|
||||
"loss": 0.0145,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
}
|
||||
],
|
||||
"enhanced_cnn": [
|
||||
{
|
||||
"checkpoint_id": "enhanced_cnn_20250627_030115",
|
||||
"model_name": "enhanced_cnn",
|
||||
"model_type": "cnn",
|
||||
"file_path": "models\\saved\\enhanced_cnn\\enhanced_cnn_20250627_030115.pt",
|
||||
"created_at": "2025-06-27T03:01:15.024856",
|
||||
"file_size_mb": 0.7184391021728516,
|
||||
"performance_score": 92.0,
|
||||
"accuracy": 0.88,
|
||||
"loss": 0.0187,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
}
|
||||
]
|
||||
}
|
@ -2193,135 +2193,24 @@ class DataProvider:
|
||||
logger.error(f"Error getting BOM matrix for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def generate_synthetic_bom_features(self, symbol: str) -> List[float]:
|
||||
def get_real_bom_features(self, symbol: str) -> Optional[List[float]]:
|
||||
"""
|
||||
Generate synthetic BOM features when real COB data is not available
|
||||
Get REAL BOM features from actual market data ONLY
|
||||
|
||||
This creates realistic-looking order book features based on current market data
|
||||
NO SYNTHETIC DATA - Returns None if real data is not available
|
||||
"""
|
||||
try:
|
||||
features = []
|
||||
# Try to get real COB data from integration
|
||||
if hasattr(self, 'cob_integration') and self.cob_integration:
|
||||
return self._extract_real_bom_features(symbol, self.cob_integration)
|
||||
|
||||
# Get current price for context
|
||||
current_price = self.get_current_price(symbol)
|
||||
if current_price is None:
|
||||
current_price = 3000.0 # Fallback price
|
||||
|
||||
# === 1. CONSOLIDATED ORDER BOOK DATA (40 features) ===
|
||||
# Top 10 bid levels (price offset + volume)
|
||||
for i in range(10):
|
||||
price_offset = -0.001 * (i + 1) * (1 + np.random.normal(0, 0.1)) # Negative for bids
|
||||
volume_normalized = np.random.exponential(0.5) * (1.0 - i * 0.1) # Decreasing with depth
|
||||
features.extend([price_offset, volume_normalized])
|
||||
|
||||
# Top 10 ask levels (price offset + volume)
|
||||
for i in range(10):
|
||||
price_offset = 0.001 * (i + 1) * (1 + np.random.normal(0, 0.1)) # Positive for asks
|
||||
volume_normalized = np.random.exponential(0.5) * (1.0 - i * 0.1) # Decreasing with depth
|
||||
features.extend([price_offset, volume_normalized])
|
||||
|
||||
# === 2. VOLUME PROFILE FEATURES (30 features) ===
|
||||
# Top 10 volume levels (buy%, sell%, total volume)
|
||||
for i in range(10):
|
||||
buy_percent = 0.3 + np.random.normal(0, 0.2) # Around 30-70% buy
|
||||
buy_percent = max(0.0, min(1.0, buy_percent))
|
||||
sell_percent = 1.0 - buy_percent
|
||||
total_volume = np.random.exponential(1.0) * (1.0 - i * 0.05)
|
||||
features.extend([buy_percent, sell_percent, total_volume])
|
||||
|
||||
# === 3. ORDER FLOW INTENSITY (25 features) ===
|
||||
# Aggressive order flow
|
||||
features.extend([
|
||||
0.5 + np.random.normal(0, 0.1), # Aggressive buy ratio
|
||||
0.5 + np.random.normal(0, 0.1), # Aggressive sell ratio
|
||||
0.4 + np.random.normal(0, 0.1), # Buy volume ratio
|
||||
0.4 + np.random.normal(0, 0.1), # Sell volume ratio
|
||||
np.random.exponential(100), # Avg aggressive buy size
|
||||
np.random.exponential(100), # Avg aggressive sell size
|
||||
])
|
||||
|
||||
# Block trade detection
|
||||
features.extend([
|
||||
0.1 + np.random.exponential(0.05), # Large trade ratio
|
||||
0.2 + np.random.exponential(0.1), # Large trade volume ratio
|
||||
np.random.exponential(1000), # Avg large trade size
|
||||
])
|
||||
|
||||
# Flow velocity metrics
|
||||
features.extend([
|
||||
1.0 + np.random.normal(0, 0.2), # Avg time delta
|
||||
0.1 + np.random.exponential(0.05), # Time velocity variance
|
||||
0.5 + np.random.normal(0, 0.1), # Trade clustering
|
||||
])
|
||||
|
||||
# Institutional activity indicators
|
||||
features.extend([
|
||||
0.05 + np.random.exponential(0.02), # Iceberg detection
|
||||
0.3 + np.random.normal(0, 0.1), # Hidden order ratio
|
||||
0.2 + np.random.normal(0, 0.05), # Smart money flow
|
||||
0.1 + np.random.exponential(0.03), # Algorithmic activity
|
||||
])
|
||||
|
||||
# Market maker behavior
|
||||
features.extend([
|
||||
0.6 + np.random.normal(0, 0.1), # MM provision ratio
|
||||
0.4 + np.random.normal(0, 0.1), # MM take ratio
|
||||
0.02 + np.random.normal(0, 0.005), # Spread tightening
|
||||
1.0 + np.random.normal(0, 0.2), # Quote update frequency
|
||||
0.8 + np.random.normal(0, 0.1), # Quote stability
|
||||
])
|
||||
|
||||
# === 4. MARKET MICROSTRUCTURE SIGNALS (25 features) ===
|
||||
# Order book pressure
|
||||
features.extend([
|
||||
0.5 + np.random.normal(0, 0.1), # Bid pressure
|
||||
0.5 + np.random.normal(0, 0.1), # Ask pressure
|
||||
0.0 + np.random.normal(0, 0.05), # Pressure imbalance
|
||||
1.0 + np.random.normal(0, 0.2), # Pressure intensity
|
||||
0.5 + np.random.normal(0, 0.1), # Depth stability
|
||||
])
|
||||
|
||||
# Price level concentration
|
||||
features.extend([
|
||||
0.3 + np.random.normal(0, 0.1), # Bid concentration
|
||||
0.3 + np.random.normal(0, 0.1), # Ask concentration
|
||||
0.8 + np.random.normal(0, 0.1), # Top level dominance
|
||||
0.2 + np.random.normal(0, 0.05), # Fragmentation index
|
||||
0.6 + np.random.normal(0, 0.1), # Liquidity clustering
|
||||
])
|
||||
|
||||
# Temporal dynamics
|
||||
features.extend([
|
||||
0.1 + np.random.normal(0, 0.02), # Volatility factor
|
||||
1.0 + np.random.normal(0, 0.1), # Momentum factor
|
||||
0.0 + np.random.normal(0, 0.05), # Mean reversion
|
||||
0.5 + np.random.normal(0, 0.1), # Trend alignment
|
||||
0.8 + np.random.normal(0, 0.1), # Pattern consistency
|
||||
])
|
||||
|
||||
# Exchange-specific patterns
|
||||
features.extend([
|
||||
0.4 + np.random.normal(0, 0.1), # Cross-exchange correlation
|
||||
0.3 + np.random.normal(0, 0.1), # Exchange arbitrage
|
||||
0.2 + np.random.normal(0, 0.05), # Latency patterns
|
||||
0.8 + np.random.normal(0, 0.1), # Sync quality
|
||||
0.6 + np.random.normal(0, 0.1), # Data freshness
|
||||
])
|
||||
|
||||
# Ensure exactly 120 features
|
||||
if len(features) > 120:
|
||||
features = features[:120]
|
||||
elif len(features) < 120:
|
||||
features.extend([0.0] * (120 - len(features)))
|
||||
|
||||
# Clamp all values to reasonable ranges
|
||||
features = [max(-5.0, min(5.0, f)) for f in features]
|
||||
|
||||
return features
|
||||
# No real data available - return None instead of synthetic
|
||||
logger.warning(f"No real BOM data available for {symbol} - waiting for real market data")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating synthetic BOM features for {symbol}: {e}")
|
||||
return [0.0] * 120
|
||||
logger.error(f"Error getting real BOM features for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def start_bom_cache_updates(self, cob_integration=None):
|
||||
"""
|
||||
@ -2342,17 +2231,14 @@ class DataProvider:
|
||||
if bom_features:
|
||||
self.update_bom_cache(symbol, bom_features, cob_integration)
|
||||
else:
|
||||
# Fallback to synthetic
|
||||
synthetic_features = self.generate_synthetic_bom_features(symbol)
|
||||
self.update_bom_cache(symbol, synthetic_features)
|
||||
# NO SYNTHETIC FALLBACK - Wait for real data
|
||||
logger.warning(f"No real BOM features available for {symbol} - waiting for real data")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting real BOM features for {symbol}: {e}")
|
||||
synthetic_features = self.generate_synthetic_bom_features(symbol)
|
||||
self.update_bom_cache(symbol, synthetic_features)
|
||||
logger.warning(f"Waiting for real data instead of using synthetic")
|
||||
else:
|
||||
# Generate synthetic BOM features
|
||||
synthetic_features = self.generate_synthetic_bom_features(symbol)
|
||||
self.update_bom_cache(symbol, synthetic_features)
|
||||
# NO SYNTHETIC FEATURES - Wait for real COB integration
|
||||
logger.warning(f"No COB integration available for {symbol} - waiting for real data")
|
||||
|
||||
time.sleep(1.0) # Update every second
|
||||
|
||||
@ -2470,7 +2356,9 @@ class DataProvider:
|
||||
"""Extract flow and microstructure features"""
|
||||
try:
|
||||
# For now, return synthetic features since full implementation would be complex
|
||||
return self.generate_synthetic_bom_features(symbol)[70:] # Last 50 features
|
||||
# NO SYNTHETIC DATA - Return None if no real microstructure data
|
||||
logger.warning(f"No real microstructure data available for {symbol}")
|
||||
return None
|
||||
except:
|
||||
return [0.0] * 50
|
||||
|
||||
|
1107
core/orchestrator.py
1107
core/orchestrator.py
File diff suppressed because it is too large
Load Diff
@ -1183,7 +1183,7 @@ class EnhancedRealtimeTrainingSystem:
|
||||
if symbol in self.recent_dqn_predictions:
|
||||
self.recent_dqn_predictions[symbol].append(display_prediction)
|
||||
|
||||
self.last_prediction_time[symbol] = current_time
|
||||
self.last_prediction_time[symbol] = int(current_time)
|
||||
|
||||
logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
|
||||
|
||||
|
@ -51,15 +51,17 @@ async def start_training_pipeline(orchestrator, trading_executor):
|
||||
}
|
||||
|
||||
try:
|
||||
# Start real-time processing (if available in Basic orchestrator)
|
||||
# Start real-time processing (available in Enhanced orchestrator)
|
||||
if hasattr(orchestrator, 'start_realtime_processing'):
|
||||
await orchestrator.start_realtime_processing()
|
||||
logger.info("Real-time processing started")
|
||||
else:
|
||||
logger.info("Real-time processing not available in Basic orchestrator")
|
||||
|
||||
# COB integration not available in Basic orchestrator
|
||||
logger.info("COB integration not available - using Basic orchestrator")
|
||||
# Start COB integration (available in Enhanced orchestrator)
|
||||
if hasattr(orchestrator, 'start_cob_integration'):
|
||||
await orchestrator.start_cob_integration()
|
||||
logger.info("COB integration started - 5-minute data matrix active")
|
||||
else:
|
||||
logger.info("COB integration not available")
|
||||
|
||||
# Main training loop
|
||||
iteration = 0
|
||||
@ -146,9 +148,9 @@ def start_clean_dashboard_with_training():
|
||||
# Create data provider
|
||||
data_provider = DataProvider()
|
||||
|
||||
# Create basic orchestrator - stable and efficient
|
||||
orchestrator = TradingOrchestrator(data_provider)
|
||||
logger.info("Basic Trading Orchestrator created for stability")
|
||||
# Create enhanced orchestrator with COB integration - stable and efficient
|
||||
orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
|
||||
logger.info("Enhanced Trading Orchestrator created with COB integration")
|
||||
|
||||
# Create trading executor
|
||||
trading_executor = TradingExecutor()
|
||||
|
@ -1,309 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Model Predictions Visualization
|
||||
|
||||
This script demonstrates the enhanced model prediction visualization system
|
||||
that shows DQN actions, CNN price predictions, and accuracy feedback on the price chart.
|
||||
|
||||
Features tested:
|
||||
- DQN action predictions (BUY/SELL/HOLD) as directional arrows with confidence-based sizing
|
||||
- CNN price direction predictions as trend lines with target markers
|
||||
- Prediction accuracy feedback with color-coded results
|
||||
- Real-time prediction tracking and storage
|
||||
- Mock prediction generation for demonstration
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.trading_executor import TradingExecutor
|
||||
from web.clean_dashboard import create_clean_dashboard
|
||||
from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelPredictionTester:
|
||||
"""Test model prediction visualization capabilities"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = get_config()
|
||||
self.data_provider = DataProvider()
|
||||
self.trading_executor = TradingExecutor()
|
||||
self.orchestrator = TradingOrchestrator(
|
||||
data_provider=self.data_provider,
|
||||
enhanced_rl_training=True,
|
||||
model_registry={}
|
||||
)
|
||||
|
||||
# Initialize enhanced training system
|
||||
self.training_system = EnhancedRealtimeTrainingSystem(
|
||||
orchestrator=self.orchestrator,
|
||||
data_provider=self.data_provider,
|
||||
dashboard=None # Will be set after dashboard creation
|
||||
)
|
||||
|
||||
# Create dashboard with enhanced prediction visualization
|
||||
self.dashboard = create_clean_dashboard(
|
||||
data_provider=self.data_provider,
|
||||
orchestrator=self.orchestrator,
|
||||
trading_executor=self.trading_executor
|
||||
)
|
||||
|
||||
# Connect training system to dashboard
|
||||
self.training_system.dashboard = self.dashboard
|
||||
self.dashboard.training_system = self.training_system
|
||||
|
||||
# Test data
|
||||
self.test_symbols = ['ETH/USDT', 'BTC/USDT']
|
||||
self.prediction_count = 0
|
||||
|
||||
logger.info("Model Prediction Tester initialized")
|
||||
|
||||
def generate_mock_dqn_predictions(self, symbol: str, count: int = 10):
|
||||
"""Generate mock DQN predictions for testing"""
|
||||
try:
|
||||
current_price = self.data_provider.get_current_price(symbol) or 2400.0
|
||||
|
||||
for i in range(count):
|
||||
# Generate realistic state vector
|
||||
state = np.random.random(100) # 100-dimensional state
|
||||
|
||||
# Generate Q-values with some logic
|
||||
q_values = [np.random.random(), np.random.random(), np.random.random()]
|
||||
action = np.argmax(q_values) # Best action
|
||||
confidence = max(q_values) / sum(q_values) # Confidence based on Q-value distribution
|
||||
|
||||
# Add some price variation
|
||||
pred_price = current_price + np.random.normal(0, 20)
|
||||
|
||||
# Capture prediction
|
||||
self.training_system.capture_dqn_prediction(
|
||||
symbol=symbol,
|
||||
state=state,
|
||||
q_values=q_values,
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
price=pred_price
|
||||
)
|
||||
|
||||
self.prediction_count += 1
|
||||
|
||||
logger.info(f"Generated DQN prediction {i+1}/{count}: {symbol} action={['BUY', 'SELL', 'HOLD'][action]} confidence={confidence:.2f}")
|
||||
|
||||
# Small delay between predictions
|
||||
time.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating DQN predictions: {e}")
|
||||
|
||||
def generate_mock_cnn_predictions(self, symbol: str, count: int = 8):
|
||||
"""Generate mock CNN predictions for testing"""
|
||||
try:
|
||||
current_price = self.data_provider.get_current_price(symbol) or 2400.0
|
||||
|
||||
for i in range(count):
|
||||
# Generate direction with some logic
|
||||
direction = np.random.choice([0, 1, 2], p=[0.3, 0.2, 0.5]) # Slightly bullish
|
||||
confidence = 0.4 + np.random.random() * 0.5 # 0.4-0.9 confidence
|
||||
|
||||
# Calculate predicted price based on direction
|
||||
if direction == 2: # UP
|
||||
price_change = np.random.uniform(5, 50)
|
||||
elif direction == 0: # DOWN
|
||||
price_change = -np.random.uniform(5, 50)
|
||||
else: # SAME
|
||||
price_change = np.random.uniform(-5, 5)
|
||||
|
||||
predicted_price = current_price + price_change
|
||||
|
||||
# Generate features
|
||||
features = np.random.random((15, 20)).flatten() # Flattened CNN features
|
||||
|
||||
# Capture prediction
|
||||
self.training_system.capture_cnn_prediction(
|
||||
symbol=symbol,
|
||||
current_price=current_price,
|
||||
predicted_price=predicted_price,
|
||||
direction=direction,
|
||||
confidence=confidence,
|
||||
features=features
|
||||
)
|
||||
|
||||
self.prediction_count += 1
|
||||
|
||||
logger.info(f"Generated CNN prediction {i+1}/{count}: {symbol} direction={['DOWN', 'SAME', 'UP'][direction]} confidence={confidence:.2f}")
|
||||
|
||||
# Small delay between predictions
|
||||
time.sleep(0.2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating CNN predictions: {e}")
|
||||
|
||||
def generate_mock_accuracy_data(self, symbol: str, count: int = 15):
|
||||
"""Generate mock prediction accuracy data for testing"""
|
||||
try:
|
||||
current_price = self.data_provider.get_current_price(symbol) or 2400.0
|
||||
|
||||
for i in range(count):
|
||||
# Randomly choose prediction type
|
||||
prediction_type = np.random.choice(['DQN', 'CNN'])
|
||||
predicted_action = np.random.choice([0, 1, 2])
|
||||
confidence = 0.3 + np.random.random() * 0.6
|
||||
|
||||
# Generate realistic price change
|
||||
actual_price_change = np.random.normal(0, 0.01) # ±1% typical change
|
||||
|
||||
# Validate accuracy
|
||||
self.training_system.validate_prediction_accuracy(
|
||||
symbol=symbol,
|
||||
prediction_type=prediction_type,
|
||||
predicted_action=predicted_action,
|
||||
actual_price_change=actual_price_change,
|
||||
confidence=confidence
|
||||
)
|
||||
|
||||
logger.info(f"Generated accuracy data {i+1}/{count}: {symbol} {prediction_type} action={predicted_action}")
|
||||
|
||||
# Small delay
|
||||
time.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating accuracy data: {e}")
|
||||
|
||||
def run_prediction_generation_test(self):
|
||||
"""Run comprehensive prediction generation test"""
|
||||
try:
|
||||
logger.info("Starting Model Prediction Visualization Test")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Test for each symbol
|
||||
for symbol in self.test_symbols:
|
||||
logger.info(f"\nGenerating predictions for {symbol}...")
|
||||
|
||||
# Generate DQN predictions
|
||||
logger.info(f"Generating DQN predictions for {symbol}...")
|
||||
self.generate_mock_dqn_predictions(symbol, count=12)
|
||||
|
||||
# Generate CNN predictions
|
||||
logger.info(f"Generating CNN predictions for {symbol}...")
|
||||
self.generate_mock_cnn_predictions(symbol, count=8)
|
||||
|
||||
# Generate accuracy data
|
||||
logger.info(f"Generating accuracy data for {symbol}...")
|
||||
self.generate_mock_accuracy_data(symbol, count=20)
|
||||
|
||||
# Get prediction summary
|
||||
summary = self.training_system.get_prediction_summary(symbol)
|
||||
logger.info(f"Prediction summary for {symbol}: {summary}")
|
||||
|
||||
# Log total statistics
|
||||
training_stats = self.training_system.get_training_statistics()
|
||||
logger.info("\nTraining System Statistics:")
|
||||
logger.info(f"Total predictions generated: {self.prediction_count}")
|
||||
logger.info(f"Prediction stats: {training_stats.get('prediction_stats', {})}")
|
||||
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("Prediction generation test completed successfully!")
|
||||
logger.info("Dashboard should now show enhanced model predictions on the price chart:")
|
||||
logger.info("- Green/Red arrows for DQN BUY/SELL predictions")
|
||||
logger.info("- Gray circles for DQN HOLD predictions")
|
||||
logger.info("- Colored trend lines for CNN price direction predictions")
|
||||
logger.info("- Diamond markers for CNN prediction targets")
|
||||
logger.info("- Green/Red X marks for correct/incorrect prediction feedback")
|
||||
logger.info("- Hover tooltips showing confidence, Q-values, and accuracy scores")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in prediction generation test: {e}")
|
||||
|
||||
def start_dashboard_with_predictions(self, host='127.0.0.1', port=8051):
|
||||
"""Start dashboard with enhanced prediction visualization"""
|
||||
try:
|
||||
logger.info(f"Starting dashboard with model predictions at http://{host}:{port}")
|
||||
|
||||
# Run prediction generation in background
|
||||
import threading
|
||||
pred_thread = threading.Thread(target=self.run_prediction_generation_test, daemon=True)
|
||||
pred_thread.start()
|
||||
|
||||
# Start training system
|
||||
self.training_system.start_training()
|
||||
|
||||
# Start dashboard
|
||||
self.dashboard.run_server(host=host, port=port, debug=False)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting dashboard with predictions: {e}")
|
||||
|
||||
def test_prediction_accuracy_validation(self):
|
||||
"""Test prediction accuracy validation logic"""
|
||||
try:
|
||||
logger.info("Testing prediction accuracy validation...")
|
||||
|
||||
# Test DQN accuracy validation
|
||||
test_cases = [
|
||||
('DQN', 0, 0.01, 0.8, True), # BUY + price up = correct
|
||||
('DQN', 1, -0.01, 0.7, True), # SELL + price down = correct
|
||||
('DQN', 2, 0.0005, 0.6, True), # HOLD + no change = correct
|
||||
('DQN', 0, -0.01, 0.8, False), # BUY + price down = incorrect
|
||||
('CNN', 2, 0.01, 0.9, True), # UP + price up = correct
|
||||
('CNN', 0, -0.01, 0.8, True), # DOWN + price down = correct
|
||||
('CNN', 1, 0.0005, 0.7, True), # SAME + no change = correct
|
||||
('CNN', 2, -0.01, 0.9, False), # UP + price down = incorrect
|
||||
]
|
||||
|
||||
for prediction_type, action, price_change, confidence, expected_correct in test_cases:
|
||||
self.training_system.validate_prediction_accuracy(
|
||||
symbol='ETH/USDT',
|
||||
prediction_type=prediction_type,
|
||||
predicted_action=action,
|
||||
actual_price_change=price_change,
|
||||
confidence=confidence
|
||||
)
|
||||
|
||||
# Check if validation worked correctly
|
||||
if self.training_system.prediction_accuracy_history['ETH/USDT']:
|
||||
latest = list(self.training_system.prediction_accuracy_history['ETH/USDT'])[-1]
|
||||
actual_correct = latest['correct']
|
||||
|
||||
status = "✓" if actual_correct == expected_correct else "✗"
|
||||
logger.info(f"{status} {prediction_type} action={action} change={price_change:.4f} -> correct={actual_correct}")
|
||||
|
||||
logger.info("Prediction accuracy validation test completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing prediction accuracy validation: {e}")
|
||||
|
||||
def main():
|
||||
"""Main test function"""
|
||||
try:
|
||||
# Create tester
|
||||
tester = ModelPredictionTester()
|
||||
|
||||
# Run accuracy validation test first
|
||||
tester.test_prediction_accuracy_validation()
|
||||
|
||||
# Start dashboard with enhanced predictions
|
||||
logger.info("\nStarting dashboard with enhanced model prediction visualization...")
|
||||
logger.info("Visit http://127.0.0.1:8051 to see the enhanced price chart with model predictions")
|
||||
|
||||
tester.start_dashboard_with_predictions()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Test interrupted by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in main test: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
File diff suppressed because it is too large
Load Diff
@ -4,8 +4,10 @@ Manages the formatting and creation of dashboard components
|
||||
"""
|
||||
|
||||
from dash import html, dcc
|
||||
import dash_bootstrap_components as dbc
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -245,139 +247,108 @@ class DashboardComponentManager:
|
||||
logger.error(f"Error formatting system status: {e}")
|
||||
return [html.P(f"Error: {str(e)}", className="text-danger small")]
|
||||
|
||||
def _format_cnn_pivot_prediction(self, model_info):
|
||||
"""Format CNN pivot prediction for display"""
|
||||
try:
|
||||
pivot_prediction = model_info.get('pivot_prediction')
|
||||
if not pivot_prediction:
|
||||
return html.Div()
|
||||
|
||||
pivot_type = pivot_prediction.get('pivot_type', 'UNKNOWN')
|
||||
predicted_price = pivot_prediction.get('predicted_price', 0)
|
||||
confidence = pivot_prediction.get('confidence', 0)
|
||||
time_horizon = pivot_prediction.get('time_horizon_minutes', 0)
|
||||
|
||||
# Color coding for pivot types
|
||||
if 'RESISTANCE' in pivot_type:
|
||||
pivot_color = "text-danger"
|
||||
pivot_icon = "fas fa-arrow-up"
|
||||
elif 'SUPPORT' in pivot_type:
|
||||
pivot_color = "text-success"
|
||||
pivot_icon = "fas fa-arrow-down"
|
||||
else:
|
||||
pivot_color = "text-warning"
|
||||
pivot_icon = "fas fa-arrows-alt-h"
|
||||
|
||||
return html.Div([
|
||||
html.Div([
|
||||
html.I(className=f"{pivot_icon} me-1 {pivot_color}"),
|
||||
html.Span("Next Pivot: ", className="text-muted small"),
|
||||
html.Span(f"${predicted_price:.2f}", className=f"small fw-bold {pivot_color}")
|
||||
], className="mb-1"),
|
||||
html.Div([
|
||||
html.Span(f"{pivot_type.replace('_', ' ')}", className=f"small {pivot_color}"),
|
||||
html.Span(f" ({confidence:.0%}) in ~{time_horizon}m", className="text-muted small")
|
||||
])
|
||||
], className="mt-1 p-1", style={"backgroundColor": "rgba(255,255,255,0.02)", "borderRadius": "3px"})
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error formatting CNN pivot prediction: {e}")
|
||||
return html.Div()
|
||||
|
||||
def format_cob_data(self, cob_snapshot, symbol):
|
||||
"""Format COB data for display"""
|
||||
"""Format COB data into a ladder display with volume bars"""
|
||||
try:
|
||||
if not cob_snapshot:
|
||||
return [html.P("No COB data", className="text-muted small")]
|
||||
if not cob_snapshot or not hasattr(cob_snapshot, 'stats'):
|
||||
return html.Div([
|
||||
html.H6(f"{symbol} COB", className="mb-2"),
|
||||
html.P("No COB data available", className="text-muted small")
|
||||
])
|
||||
|
||||
# Real COB data display
|
||||
cob_info = []
|
||||
stats = cob_snapshot.stats if hasattr(cob_snapshot, 'stats') else {}
|
||||
mid_price = stats.get('mid_price', 0)
|
||||
spread_bps = stats.get('spread_bps', 0)
|
||||
total_bid_liquidity = stats.get('total_bid_liquidity', 0)
|
||||
total_ask_liquidity = stats.get('total_ask_liquidity', 0)
|
||||
imbalance = stats.get('imbalance', 0)
|
||||
bids = getattr(cob_snapshot, 'consolidated_bids', [])
|
||||
asks = getattr(cob_snapshot, 'consolidated_asks', [])
|
||||
|
||||
if mid_price == 0 or not bids or not asks:
|
||||
return html.Div([
|
||||
html.H6(f"{symbol} COB", className="mb-2"),
|
||||
html.P("Awaiting valid order book data...", className="text-muted small")
|
||||
])
|
||||
|
||||
# Header with summary stats
|
||||
imbalance_text = f"Bid Heavy ({imbalance:.3f})" if imbalance > 0 else f"Ask Heavy ({imbalance:.3f})"
|
||||
imbalance_color = "text-success" if imbalance > 0 else "text-danger"
|
||||
|
||||
# Symbol header
|
||||
cob_info.append(html.Div([
|
||||
html.Strong(f"{symbol}", className="text-info"),
|
||||
html.Span(" - COB Snapshot", className="small text-muted")
|
||||
], className="mb-2"))
|
||||
|
||||
# Check if we have a real COB snapshot object
|
||||
if hasattr(cob_snapshot, 'volume_weighted_mid'):
|
||||
# Real COB snapshot data
|
||||
mid_price = getattr(cob_snapshot, 'volume_weighted_mid', 0)
|
||||
spread_bps = getattr(cob_snapshot, 'spread_bps', 0)
|
||||
bid_liquidity = getattr(cob_snapshot, 'total_bid_liquidity', 0)
|
||||
ask_liquidity = getattr(cob_snapshot, 'total_ask_liquidity', 0)
|
||||
imbalance = getattr(cob_snapshot, 'liquidity_imbalance', 0)
|
||||
bid_levels = len(getattr(cob_snapshot, 'consolidated_bids', []))
|
||||
ask_levels = len(getattr(cob_snapshot, 'consolidated_asks', []))
|
||||
|
||||
# Price and spread
|
||||
cob_info.append(html.Div([
|
||||
html.Div([
|
||||
html.I(className="fas fa-dollar-sign text-success me-2"),
|
||||
html.Span(f"Mid: ${mid_price:.2f}", className="small fw-bold")
|
||||
], className="mb-1"),
|
||||
html.Div([
|
||||
html.I(className="fas fa-arrows-alt-h text-warning me-2"),
|
||||
html.Span(f"Spread: {spread_bps:.1f} bps", className="small")
|
||||
], className="mb-1")
|
||||
header = html.Div([
|
||||
html.H6(f"{symbol} - COB Ladder", className="mb-1"),
|
||||
html.Div([
|
||||
html.Span(f"Mid: ${mid_price:,.2f}", className="me-3"),
|
||||
html.Span(f"Spread: {spread_bps:.1f} bps", className="me-3"),
|
||||
html.Span(f"Imbalance: ", className="small"),
|
||||
html.Span(imbalance_text, className=f"fw-bold {imbalance_color}")
|
||||
], className="small text-muted")
|
||||
], className="mb-2")
|
||||
|
||||
# --- Ladder Creation ---
|
||||
bucket_size = 10 # $10 price buckets
|
||||
num_levels = 5 # 5 levels above and below
|
||||
|
||||
# Aggregate bids and asks into buckets
|
||||
def aggregate_buckets(orders, mid_price, bucket_size):
|
||||
buckets = {}
|
||||
for order in orders:
|
||||
price = order.get('price', 0)
|
||||
size = order.get('size', 0)
|
||||
if price > 0:
|
||||
bucket_key = round(price / bucket_size) * bucket_size
|
||||
if bucket_key not in buckets:
|
||||
buckets[bucket_key] = 0
|
||||
buckets[bucket_key] += size * price # Volume in quote currency (USD)
|
||||
return buckets
|
||||
|
||||
bid_buckets = aggregate_buckets(bids, mid_price, bucket_size)
|
||||
ask_buckets = aggregate_buckets(asks, mid_price, bucket_size)
|
||||
|
||||
all_volumes = list(bid_buckets.values()) + list(ask_buckets.values())
|
||||
max_volume = max(all_volumes) if all_volumes else 1
|
||||
|
||||
# Determine ladder price levels
|
||||
center_bucket = round(mid_price / bucket_size) * bucket_size
|
||||
ask_levels = [center_bucket + i * bucket_size for i in range(1, num_levels + 1)]
|
||||
bid_levels = [center_bucket - i * bucket_size for i in range(num_levels)]
|
||||
|
||||
# Create ladder rows
|
||||
ask_rows = []
|
||||
for price in sorted(ask_levels, reverse=True):
|
||||
volume = ask_buckets.get(price, 0)
|
||||
progress = (volume / max_volume) * 100
|
||||
ask_rows.append(html.Tr([
|
||||
html.Td(f"${price:,.2f}", className="text-danger price-level"),
|
||||
html.Td(f"${volume:,.0f}", className="volume-level"),
|
||||
html.Td(dbc.Progress(value=progress, color="danger", className="vh-25"), className="progress-cell")
|
||||
]))
|
||||
|
||||
# Liquidity info
|
||||
total_liquidity = bid_liquidity + ask_liquidity
|
||||
bid_pct = (bid_liquidity / total_liquidity * 100) if total_liquidity > 0 else 0
|
||||
ask_pct = (ask_liquidity / total_liquidity * 100) if total_liquidity > 0 else 0
|
||||
|
||||
cob_info.append(html.Div([
|
||||
html.Div([
|
||||
html.I(className="fas fa-layer-group text-info me-2"),
|
||||
html.Span(f"Liquidity: ${total_liquidity:,.0f}", className="small")
|
||||
], className="mb-1"),
|
||||
html.Div([
|
||||
html.Span(f"Bids: {bid_pct:.0f}% ", className="small text-success"),
|
||||
html.Span(f"Asks: {ask_pct:.0f}%", className="small text-danger")
|
||||
], className="mb-1")
|
||||
]))
|
||||
|
||||
# Order book depth
|
||||
cob_info.append(html.Div([
|
||||
html.Div([
|
||||
html.I(className="fas fa-list text-secondary me-2"),
|
||||
html.Span(f"Levels: {bid_levels} bids, {ask_levels} asks", className="small")
|
||||
], className="mb-1")
|
||||
]))
|
||||
|
||||
# Imbalance indicator
|
||||
imbalance_color = "text-success" if imbalance > 0.1 else "text-danger" if imbalance < -0.1 else "text-muted"
|
||||
imbalance_text = "Bid Heavy" if imbalance > 0.1 else "Ask Heavy" if imbalance < -0.1 else "Balanced"
|
||||
|
||||
cob_info.append(html.Div([
|
||||
html.I(className="fas fa-balance-scale me-2"),
|
||||
html.Span(f"Imbalance: ", className="small text-muted"),
|
||||
html.Span(f"{imbalance_text} ({imbalance:.3f})", className=f"small {imbalance_color}")
|
||||
], className="mb-1"))
|
||||
|
||||
else:
|
||||
# Fallback display for other data formats
|
||||
cob_info.append(html.Div([
|
||||
html.Div([
|
||||
html.I(className="fas fa-chart-bar text-success me-2"),
|
||||
html.Span("Order Book: Active", className="small")
|
||||
], className="mb-1"),
|
||||
html.Div([
|
||||
html.I(className="fas fa-coins text-warning me-2"),
|
||||
html.Span("Liquidity: Good", className="small")
|
||||
], className="mb-1"),
|
||||
html.Div([
|
||||
html.I(className="fas fa-balance-scale text-info me-2"),
|
||||
html.Span("Imbalance: Neutral", className="small")
|
||||
])
|
||||
|
||||
bid_rows = []
|
||||
for price in sorted(bid_levels, reverse=True):
|
||||
volume = bid_buckets.get(price, 0)
|
||||
progress = (volume / max_volume) * 100
|
||||
bid_rows.append(html.Tr([
|
||||
html.Td(f"${price:,.2f}", className="text-success price-level"),
|
||||
html.Td(f"${volume:,.0f}", className="volume-level"),
|
||||
html.Td(dbc.Progress(value=progress, color="success", className="vh-25"), className="progress-cell")
|
||||
]))
|
||||
|
||||
return cob_info
|
||||
|
||||
# Mid-price separator
|
||||
mid_row = html.Tr([
|
||||
html.Td(f"${mid_price:,.2f}", colSpan=3, className="text-center fw-bold text-white bg-secondary")
|
||||
])
|
||||
|
||||
ladder_table = html.Table([
|
||||
html.Thead(html.Tr([html.Th("Price (USD)"), html.Th("Volume (USD)"), html.Th("Total")])),
|
||||
html.Tbody(ask_rows + [mid_row] + bid_rows)
|
||||
], className="table table-sm table-dark cob-ladder-table")
|
||||
|
||||
return html.Div([header, ladder_table])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting COB data: {e}")
|
||||
return [html.P(f"Error: {str(e)}", className="text-danger small")]
|
||||
logger.error(f"Error formatting COB data ladder: {e}")
|
||||
return html.P(f"Error: {str(e)}", className="text-danger small")
|
||||
|
||||
def format_cob_data_with_buckets(self, cob_snapshot, symbol, price_buckets, memory_stats, bucket_size=1.0):
|
||||
"""Format COB data with price buckets for high-frequency display"""
|
||||
@ -607,6 +578,18 @@ class DashboardComponentManager:
|
||||
html.Span(f" @ {pred_time}", className="text-muted small")
|
||||
], className="mb-1"),
|
||||
|
||||
# Timing information (NEW)
|
||||
html.Div([
|
||||
html.Span("Timing: ", className="text-muted small"),
|
||||
html.Span(f"Inf: {model_info.get('timing', {}).get('last_inference', 'None')}", className="text-info small"),
|
||||
html.Span(" | ", className="text-muted small"),
|
||||
html.Span(f"Train: {model_info.get('timing', {}).get('last_training', 'None')}", className="text-warning small"),
|
||||
html.Br(),
|
||||
html.Span(f"Rate: {model_info.get('timing', {}).get('inferences_per_second', '0.00')}/s", className="text-success small"),
|
||||
html.Span(" | ", className="text-muted small"),
|
||||
html.Span(f"24h: {model_info.get('timing', {}).get('predictions_24h', 0)}", className="text-primary small")
|
||||
], className="mb-1"),
|
||||
|
||||
# Loss metrics with improvement tracking
|
||||
html.Div([
|
||||
html.Span("Current Loss: ", className="text-muted small"),
|
||||
@ -680,4 +663,43 @@ class DashboardComponentManager:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting training metrics: {e}")
|
||||
return [html.P(f"Error: {str(e)}", className="text-danger small")]
|
||||
return [html.P(f"Error: {str(e)}", className="text-danger small")]
|
||||
|
||||
def _format_cnn_pivot_prediction(self, model_info):
|
||||
"""Format CNN pivot prediction for display"""
|
||||
try:
|
||||
pivot_prediction = model_info.get('pivot_prediction')
|
||||
if not pivot_prediction:
|
||||
return html.Div()
|
||||
|
||||
pivot_type = pivot_prediction.get('pivot_type', 'UNKNOWN')
|
||||
predicted_price = pivot_prediction.get('predicted_price', 0)
|
||||
confidence = pivot_prediction.get('confidence', 0)
|
||||
time_horizon = pivot_prediction.get('time_horizon_minutes', 0)
|
||||
|
||||
# Color coding for pivot types
|
||||
if 'RESISTANCE' in pivot_type:
|
||||
pivot_color = "text-danger"
|
||||
pivot_icon = "fas fa-arrow-up"
|
||||
elif 'SUPPORT' in pivot_type:
|
||||
pivot_color = "text-success"
|
||||
pivot_icon = "fas fa-arrow-down"
|
||||
else:
|
||||
pivot_color = "text-warning"
|
||||
pivot_icon = "fas fa-arrows-alt-h"
|
||||
|
||||
return html.Div([
|
||||
html.Div([
|
||||
html.I(className=f"{pivot_icon} me-1 {pivot_color}"),
|
||||
html.Span("Next Pivot: ", className="text-muted small"),
|
||||
html.Span(f"${predicted_price:.2f}", className=f"small fw-bold {pivot_color}")
|
||||
], className="mb-1"),
|
||||
html.Div([
|
||||
html.Span(f"{pivot_type.replace('_', ' ')}", className=f"small {pivot_color}"),
|
||||
html.Span(f" ({confidence:.0%}) in ~{time_horizon}m", className="text-muted small")
|
||||
])
|
||||
], className="mt-1 p-1", style={"backgroundColor": "rgba(255,255,255,0.02)", "borderRadius": "3px"})
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error formatting CNN pivot prediction: {e}")
|
||||
return html.Div()
|
@ -51,7 +51,7 @@ class DashboardLayoutManager:
|
||||
return html.Div([
|
||||
self._create_metrics_and_signals_row(),
|
||||
self._create_charts_row(),
|
||||
self._create_analytics_and_performance_row()
|
||||
self._create_cob_and_trades_row()
|
||||
])
|
||||
|
||||
def _create_metrics_and_signals_row(self):
|
||||
@ -199,6 +199,100 @@ class DashboardLayoutManager:
|
||||
], className="card")
|
||||
])
|
||||
|
||||
def _create_cob_and_trades_row(self):
|
||||
"""Creates the row for COB ladders, closed trades, and model status."""
|
||||
return html.Div(
|
||||
[
|
||||
# Left side: COB Ladders (60% width)
|
||||
html.Div(
|
||||
[
|
||||
html.Div(
|
||||
[
|
||||
# ETH/USDT COB
|
||||
html.Div(
|
||||
[
|
||||
html.Div(
|
||||
id="eth-cob-content",
|
||||
className="card-body p-2",
|
||||
)
|
||||
],
|
||||
className="card",
|
||||
style={"flex": "1"},
|
||||
),
|
||||
# BTC/USDT COB
|
||||
html.Div(
|
||||
[
|
||||
html.Div(
|
||||
id="btc-cob-content",
|
||||
className="card-body p-2",
|
||||
)
|
||||
],
|
||||
className="card",
|
||||
style={"flex": "1", "marginLeft": "1rem"},
|
||||
),
|
||||
],
|
||||
className="d-flex",
|
||||
)
|
||||
],
|
||||
style={"width": "60%"},
|
||||
),
|
||||
# Right side: Trades and Model Status (40% width)
|
||||
html.Div(
|
||||
[
|
||||
# Closed Trades
|
||||
html.Div(
|
||||
[
|
||||
html.Div(
|
||||
[
|
||||
html.H6(
|
||||
[
|
||||
html.I(className="fas fa-history me-2"),
|
||||
"Closed Trades",
|
||||
],
|
||||
className="card-title mb-2",
|
||||
),
|
||||
html.Div(
|
||||
id="closed-trades-table",
|
||||
style={"height": "250px", "overflowY": "auto"},
|
||||
),
|
||||
],
|
||||
className="card-body p-2",
|
||||
)
|
||||
],
|
||||
className="card mb-3",
|
||||
),
|
||||
# Model Status
|
||||
html.Div(
|
||||
[
|
||||
html.Div(
|
||||
[
|
||||
html.H6(
|
||||
[
|
||||
html.I(className="fas fa-brain me-2"),
|
||||
"Models & Training Progress",
|
||||
],
|
||||
className="card-title mb-2",
|
||||
),
|
||||
html.Div(
|
||||
id="training-metrics",
|
||||
style={
|
||||
"height": "250px",
|
||||
"overflowY": "auto",
|
||||
},
|
||||
),
|
||||
],
|
||||
className="card-body p-2",
|
||||
)
|
||||
],
|
||||
className="card",
|
||||
),
|
||||
],
|
||||
style={"width": "38%", "marginLeft": "2%"},
|
||||
),
|
||||
],
|
||||
className="d-flex mb-3",
|
||||
)
|
||||
|
||||
def _create_analytics_and_performance_row(self):
|
||||
"""Create the combined analytics and performance row with COB data, trades, and training progress"""
|
||||
return html.Div([
|
||||
|
Reference in New Issue
Block a user