added leverage slider
This commit is contained in:
612
web/dashboard.py
612
web/dashboard.py
@ -11,7 +11,7 @@ This module provides a modern, responsive web dashboard for the trading system:
|
||||
|
||||
import asyncio
|
||||
import dash
|
||||
from dash import dcc, html, Input, Output
|
||||
from dash import Dash, dcc, html, Input, Output
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
import plotly.express as px
|
||||
@ -28,6 +28,8 @@ from collections import deque
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Any, Union, Tuple
|
||||
import websocket
|
||||
import os
|
||||
import torch
|
||||
|
||||
# Setup logger immediately after logging import
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -175,9 +177,49 @@ class TradingDashboard:
|
||||
"""Enhanced Trading Dashboard with Williams pivot points and unified timezone handling"""
|
||||
|
||||
def __init__(self, data_provider: DataProvider = None, orchestrator: TradingOrchestrator = None, trading_executor: TradingExecutor = None):
|
||||
"""Initialize the dashboard with unified data stream and enhanced RL training"""
|
||||
self.app = Dash(__name__)
|
||||
|
||||
# Initialize config first
|
||||
from core.config import get_config
|
||||
self.config = get_config()
|
||||
|
||||
self.data_provider = data_provider or DataProvider()
|
||||
self.orchestrator = orchestrator
|
||||
self.trading_executor = trading_executor
|
||||
|
||||
# Enhanced trading state with leverage support
|
||||
self.leverage_enabled = True
|
||||
self.leverage_multiplier = 50.0 # 50x leverage (adjustable via slider)
|
||||
self.base_capital = 10000.0
|
||||
self.current_position = 0.0 # -1 to 1 (short to long)
|
||||
self.position_size = 0.0
|
||||
self.entry_price = 0.0
|
||||
self.unrealized_pnl = 0.0
|
||||
self.realized_pnl = 0.0
|
||||
|
||||
# Leverage settings for slider
|
||||
self.min_leverage = 1.0
|
||||
self.max_leverage = 100.0
|
||||
self.leverage_step = 1.0
|
||||
|
||||
# Connect to trading server for leverage functionality
|
||||
self.trading_server_url = "http://127.0.0.1:8052"
|
||||
self.training_server_url = "http://127.0.0.1:8053"
|
||||
self.stream_server_url = "http://127.0.0.1:8054"
|
||||
|
||||
# Enhanced performance tracking
|
||||
self.leverage_metrics = {
|
||||
'leverage_efficiency': 0.0,
|
||||
'margin_used': 0.0,
|
||||
'margin_available': 10000.0,
|
||||
'effective_exposure': 0.0,
|
||||
'risk_reward_ratio': 0.0
|
||||
}
|
||||
|
||||
# Enhanced models will be loaded through model registry later
|
||||
|
||||
# Rest of initialization...
|
||||
|
||||
# Initialize timezone from config
|
||||
timezone_name = self.config.get('system', {}).get('timezone', 'Europe/Sofia')
|
||||
self.timezone = pytz.timezone(timezone_name)
|
||||
@ -874,13 +916,15 @@ class TradingDashboard:
|
||||
], className="card-body p-2")
|
||||
], className="card", style={"width": "32%", "marginLeft": "2%"}),
|
||||
|
||||
# System status - 1/3 width with icon tooltip
|
||||
# System status and leverage controls - 1/3 width with icon tooltip
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6([
|
||||
html.I(className="fas fa-server me-2"),
|
||||
"System"
|
||||
"System & Leverage"
|
||||
], className="card-title mb-2"),
|
||||
|
||||
# System status
|
||||
html.Div([
|
||||
html.I(
|
||||
id="system-status-icon",
|
||||
@ -889,7 +933,44 @@ class TradingDashboard:
|
||||
style={"cursor": "pointer"}
|
||||
),
|
||||
html.Div(id="system-status-details", className="small mt-2")
|
||||
], className="text-center")
|
||||
], className="text-center mb-3"),
|
||||
|
||||
# Leverage Controls
|
||||
html.Div([
|
||||
html.Label([
|
||||
html.I(className="fas fa-chart-line me-1"),
|
||||
"Leverage Multiplier"
|
||||
], className="form-label small fw-bold"),
|
||||
html.Div([
|
||||
dcc.Slider(
|
||||
id='leverage-slider',
|
||||
min=self.min_leverage,
|
||||
max=self.max_leverage,
|
||||
step=self.leverage_step,
|
||||
value=self.leverage_multiplier,
|
||||
marks={
|
||||
1: '1x',
|
||||
10: '10x',
|
||||
25: '25x',
|
||||
50: '50x',
|
||||
75: '75x',
|
||||
100: '100x'
|
||||
},
|
||||
tooltip={
|
||||
"placement": "bottom",
|
||||
"always_visible": True
|
||||
}
|
||||
)
|
||||
], className="mb-2"),
|
||||
html.Div([
|
||||
html.Span(id="current-leverage", className="badge bg-warning text-dark"),
|
||||
html.Span(" • ", className="mx-1"),
|
||||
html.Span(id="leverage-risk", className="badge bg-info")
|
||||
], className="text-center"),
|
||||
html.Div([
|
||||
html.Small("Higher leverage = Higher rewards & risks", className="text-muted")
|
||||
], className="text-center mt-1")
|
||||
])
|
||||
], className="card-body p-2")
|
||||
], className="card", style={"width": "32%", "marginLeft": "2%"})
|
||||
], className="d-flex")
|
||||
@ -918,6 +999,8 @@ class TradingDashboard:
|
||||
Output('system-status-icon', 'className'),
|
||||
Output('system-status-icon', 'title'),
|
||||
Output('system-status-details', 'children'),
|
||||
Output('current-leverage', 'children'),
|
||||
Output('leverage-risk', 'children'),
|
||||
# Model data feed charts
|
||||
# Output('model-data-1m', 'figure'),
|
||||
# Output('model-data-1h', 'figure'),
|
||||
@ -1168,10 +1251,26 @@ class TradingDashboard:
|
||||
logger.warning(f"Closed trades table error: {e}")
|
||||
closed_trades_table = [html.P("Closed trades data unavailable", className="text-muted")]
|
||||
|
||||
# Calculate leverage display values
|
||||
leverage_text = f"{self.leverage_multiplier:.0f}x"
|
||||
if self.leverage_multiplier <= 5:
|
||||
risk_level = "Low Risk"
|
||||
risk_class = "bg-success"
|
||||
elif self.leverage_multiplier <= 25:
|
||||
risk_level = "Medium Risk"
|
||||
risk_class = "bg-warning text-dark"
|
||||
elif self.leverage_multiplier <= 50:
|
||||
risk_level = "High Risk"
|
||||
risk_class = "bg-danger"
|
||||
else:
|
||||
risk_level = "Extreme Risk"
|
||||
risk_class = "bg-dark"
|
||||
|
||||
return (
|
||||
price_text, pnl_text, pnl_class, fees_text, position_text, position_class, trade_count_text, portfolio_text, mexc_status,
|
||||
price_chart, training_metrics, decisions_list, session_perf, closed_trades_table,
|
||||
system_status['icon_class'], system_status['title'], system_status['details'],
|
||||
leverage_text, f"{risk_level}",
|
||||
# # Model data feed charts
|
||||
# self._create_model_data_chart('ETH/USDT', '1m'),
|
||||
# self._create_model_data_chart('ETH/USDT', '1h'),
|
||||
@ -1194,11 +1293,12 @@ class TradingDashboard:
|
||||
"fas fa-circle text-danger fa-2x",
|
||||
"Error: Dashboard error - check logs",
|
||||
[html.P(f"Error: {str(e)}", className="text-danger")],
|
||||
f"{self.leverage_multiplier:.0f}x", "Error",
|
||||
# Model data feed charts
|
||||
self._create_model_data_chart('ETH/USDT', '1m'),
|
||||
self._create_model_data_chart('ETH/USDT', '1h'),
|
||||
self._create_model_data_chart('ETH/USDT', '1d'),
|
||||
self._create_model_data_chart('BTC/USDT', '1s')
|
||||
# self._create_model_data_chart('ETH/USDT', '1m'),
|
||||
# self._create_model_data_chart('ETH/USDT', '1h'),
|
||||
# self._create_model_data_chart('ETH/USDT', '1d'),
|
||||
# self._create_model_data_chart('BTC/USDT', '1s')
|
||||
)
|
||||
|
||||
# Clear history callback
|
||||
@ -1219,6 +1319,60 @@ class TradingDashboard:
|
||||
logger.error(f"Error clearing trade history: {e}")
|
||||
return [html.P(f"Error clearing history: {str(e)}", className="text-danger text-center")]
|
||||
return dash.no_update
|
||||
|
||||
# Leverage slider callback
|
||||
@self.app.callback(
|
||||
[Output('current-leverage', 'children', allow_duplicate=True),
|
||||
Output('leverage-risk', 'children', allow_duplicate=True),
|
||||
Output('leverage-risk', 'className', allow_duplicate=True)],
|
||||
[Input('leverage-slider', 'value')],
|
||||
prevent_initial_call=True
|
||||
)
|
||||
def update_leverage(leverage_value):
|
||||
"""Update leverage multiplier and risk assessment"""
|
||||
try:
|
||||
if leverage_value is None:
|
||||
return dash.no_update
|
||||
|
||||
# Update internal leverage value
|
||||
self.leverage_multiplier = float(leverage_value)
|
||||
|
||||
# Calculate risk level and styling
|
||||
leverage_text = f"{self.leverage_multiplier:.0f}x"
|
||||
|
||||
if self.leverage_multiplier <= 5:
|
||||
risk_level = "Low Risk"
|
||||
risk_class = "badge bg-success"
|
||||
elif self.leverage_multiplier <= 25:
|
||||
risk_level = "Medium Risk"
|
||||
risk_class = "badge bg-warning text-dark"
|
||||
elif self.leverage_multiplier <= 50:
|
||||
risk_level = "High Risk"
|
||||
risk_class = "badge bg-danger"
|
||||
else:
|
||||
risk_level = "Extreme Risk"
|
||||
risk_class = "badge bg-dark"
|
||||
|
||||
# Update trading server if connected
|
||||
try:
|
||||
import requests
|
||||
response = requests.post(f"{self.trading_server_url}/update_leverage",
|
||||
json={"leverage": self.leverage_multiplier},
|
||||
timeout=2)
|
||||
if response.status_code == 200:
|
||||
logger.info(f"[LEVERAGE] Updated trading server leverage to {self.leverage_multiplier}x")
|
||||
else:
|
||||
logger.warning(f"[LEVERAGE] Failed to update trading server: {response.status_code}")
|
||||
except Exception as e:
|
||||
logger.debug(f"[LEVERAGE] Trading server not available: {e}")
|
||||
|
||||
logger.info(f"[LEVERAGE] Leverage updated to {self.leverage_multiplier}x ({risk_level})")
|
||||
|
||||
return leverage_text, risk_level, risk_class
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating leverage: {e}")
|
||||
return f"{self.leverage_multiplier:.0f}x", "Error", "badge bg-secondary"
|
||||
|
||||
def _simulate_price_update(self, symbol: str, base_price: float) -> float:
|
||||
"""
|
||||
@ -2218,10 +2372,11 @@ class TradingDashboard:
|
||||
size = self.current_position['size']
|
||||
entry_time = self.current_position['timestamp']
|
||||
|
||||
# Calculate PnL for closing short
|
||||
gross_pnl = (entry_price - exit_price) * size # Short PnL calculation
|
||||
fee = exit_price * size * fee_rate
|
||||
net_pnl = gross_pnl - fee - self.current_position['fees']
|
||||
# Calculate PnL for closing short with leverage
|
||||
leveraged_pnl, leveraged_fee = self._calculate_leveraged_pnl_and_fees(
|
||||
entry_price, exit_price, size, 'SHORT', fee_rate
|
||||
)
|
||||
net_pnl = leveraged_pnl - leveraged_fee - self.current_position['fees']
|
||||
|
||||
self.total_realized_pnl += net_pnl
|
||||
self.total_fees += fee
|
||||
@ -2246,8 +2401,8 @@ class TradingDashboard:
|
||||
'entry_price': entry_price,
|
||||
'exit_price': exit_price,
|
||||
'size': size,
|
||||
'gross_pnl': gross_pnl,
|
||||
'fees': fee + self.current_position['fees'],
|
||||
'gross_pnl': leveraged_pnl,
|
||||
'fees': leveraged_fee + self.current_position['fees'],
|
||||
'fee_type': fee_type,
|
||||
'fee_rate': fee_rate,
|
||||
'net_pnl': net_pnl,
|
||||
@ -2280,7 +2435,7 @@ class TradingDashboard:
|
||||
# Now open long position (regardless of previous position)
|
||||
if self.current_position is None:
|
||||
# Open long position with confidence-based size
|
||||
fee = decision['price'] * decision['size'] * fee_rate
|
||||
fee = decision['price'] * decision['size'] * fee_rate * self.leverage_multiplier # Leverage affects fees
|
||||
self.current_position = {
|
||||
'side': 'LONG',
|
||||
'price': decision['price'],
|
||||
@ -2310,10 +2465,11 @@ class TradingDashboard:
|
||||
size = self.current_position['size']
|
||||
entry_time = self.current_position['timestamp']
|
||||
|
||||
# Calculate PnL for closing short
|
||||
gross_pnl = (entry_price - exit_price) * size # Short PnL calculation
|
||||
fee = exit_price * size * fee_rate
|
||||
net_pnl = gross_pnl - fee - self.current_position['fees']
|
||||
# Calculate PnL for closing short with leverage
|
||||
leveraged_pnl, leveraged_fee = self._calculate_leveraged_pnl_and_fees(
|
||||
entry_price, exit_price, size, 'SHORT', fee_rate
|
||||
)
|
||||
net_pnl = leveraged_pnl - leveraged_fee - self.current_position['fees']
|
||||
|
||||
self.total_realized_pnl += net_pnl
|
||||
self.total_fees += fee
|
||||
@ -2337,8 +2493,8 @@ class TradingDashboard:
|
||||
'entry_price': entry_price,
|
||||
'exit_price': exit_price,
|
||||
'size': size,
|
||||
'gross_pnl': gross_pnl,
|
||||
'fees': fee + self.current_position['fees'],
|
||||
'gross_pnl': leveraged_pnl,
|
||||
'fees': leveraged_fee + self.current_position['fees'],
|
||||
'fee_type': fee_type,
|
||||
'fee_rate': fee_rate,
|
||||
'net_pnl': net_pnl,
|
||||
@ -2377,10 +2533,11 @@ class TradingDashboard:
|
||||
size = self.current_position['size']
|
||||
entry_time = self.current_position['timestamp']
|
||||
|
||||
# Calculate PnL for closing long
|
||||
gross_pnl = (exit_price - entry_price) * size # Long PnL calculation
|
||||
fee = exit_price * size * fee_rate
|
||||
net_pnl = gross_pnl - fee - self.current_position['fees']
|
||||
# Calculate PnL for closing long with leverage
|
||||
leveraged_pnl, leveraged_fee = self._calculate_leveraged_pnl_and_fees(
|
||||
entry_price, exit_price, size, 'LONG', fee_rate
|
||||
)
|
||||
net_pnl = leveraged_pnl - leveraged_fee - self.current_position['fees']
|
||||
|
||||
self.total_realized_pnl += net_pnl
|
||||
self.total_fees += fee
|
||||
@ -2405,8 +2562,8 @@ class TradingDashboard:
|
||||
'entry_price': entry_price,
|
||||
'exit_price': exit_price,
|
||||
'size': size,
|
||||
'gross_pnl': gross_pnl,
|
||||
'fees': fee + self.current_position['fees'],
|
||||
'gross_pnl': leveraged_pnl,
|
||||
'fees': leveraged_fee + self.current_position['fees'],
|
||||
'fee_type': fee_type,
|
||||
'fee_rate': fee_rate,
|
||||
'net_pnl': net_pnl,
|
||||
@ -2427,7 +2584,7 @@ class TradingDashboard:
|
||||
# Now open short position (regardless of previous position)
|
||||
if self.current_position is None:
|
||||
# Open short position with confidence-based size
|
||||
fee = decision['price'] * decision['size'] * fee_rate
|
||||
fee = decision['price'] * decision['size'] * fee_rate * self.leverage_multiplier # Leverage affects fees
|
||||
self.current_position = {
|
||||
'side': 'SHORT',
|
||||
'price': decision['price'],
|
||||
@ -2458,8 +2615,34 @@ class TradingDashboard:
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing trading decision: {e}")
|
||||
|
||||
def _calculate_leveraged_pnl_and_fees(self, entry_price: float, exit_price: float, size: float, side: str, fee_rate: float):
|
||||
"""Calculate leveraged PnL and fees for closed positions"""
|
||||
try:
|
||||
# Calculate base PnL
|
||||
if side == 'LONG':
|
||||
base_pnl = (exit_price - entry_price) * size
|
||||
elif side == 'SHORT':
|
||||
base_pnl = (entry_price - exit_price) * size
|
||||
else:
|
||||
return 0.0, 0.0
|
||||
|
||||
# Apply leverage amplification
|
||||
leveraged_pnl = base_pnl * self.leverage_multiplier
|
||||
|
||||
# Calculate fees with leverage (higher position value = higher fees)
|
||||
position_value = exit_price * size * self.leverage_multiplier
|
||||
leveraged_fee = position_value * fee_rate
|
||||
|
||||
logger.info(f"[LEVERAGE] {side} PnL: Base=${base_pnl:.2f} x {self.leverage_multiplier}x = ${leveraged_pnl:.2f}, Fee=${leveraged_fee:.4f}")
|
||||
|
||||
return leveraged_pnl, leveraged_fee
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating leveraged PnL and fees: {e}")
|
||||
return 0.0, 0.0
|
||||
|
||||
def _calculate_unrealized_pnl(self, current_price: float) -> float:
|
||||
"""Calculate unrealized PnL for open position"""
|
||||
"""Calculate unrealized PnL for open position with leverage amplification"""
|
||||
try:
|
||||
if not self.current_position:
|
||||
return 0.0
|
||||
@ -2467,12 +2650,20 @@ class TradingDashboard:
|
||||
entry_price = self.current_position['price']
|
||||
size = self.current_position['size']
|
||||
|
||||
# Calculate base PnL
|
||||
if self.current_position['side'] == 'LONG':
|
||||
return (current_price - entry_price) * size
|
||||
base_pnl = (current_price - entry_price) * size
|
||||
elif self.current_position['side'] == 'SHORT':
|
||||
return (entry_price - current_price) * size
|
||||
base_pnl = (entry_price - current_price) * size
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
return 0.0
|
||||
# Apply leverage amplification
|
||||
leveraged_pnl = base_pnl * self.leverage_multiplier
|
||||
|
||||
logger.debug(f"[LEVERAGE PnL] Base: ${base_pnl:.2f} x {self.leverage_multiplier}x = ${leveraged_pnl:.2f}")
|
||||
|
||||
return leveraged_pnl
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating unrealized PnL: {e}")
|
||||
@ -2804,208 +2995,189 @@ class TradingDashboard:
|
||||
pass
|
||||
|
||||
def _load_available_models(self):
|
||||
"""Load available CNN and RL models for real trading"""
|
||||
"""Load available models with enhanced model management"""
|
||||
try:
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from model_manager import ModelManager, ModelMetrics
|
||||
|
||||
models_loaded = 0
|
||||
# Initialize model manager
|
||||
self.model_manager = ModelManager()
|
||||
|
||||
# Try to load real CNN models - handle different architectures
|
||||
cnn_paths = [
|
||||
'models/cnn/scalping_cnn_trained_best.pt',
|
||||
'models/cnn/scalping_cnn_trained.pt',
|
||||
'models/saved/cnn_model_best.pt'
|
||||
]
|
||||
# Load best models
|
||||
loaded_models = self.model_manager.load_best_models()
|
||||
|
||||
for cnn_path in cnn_paths:
|
||||
if Path(cnn_path).exists():
|
||||
try:
|
||||
# Load with weights_only=False for older models
|
||||
checkpoint = torch.load(cnn_path, map_location='cpu', weights_only=False)
|
||||
|
||||
# Try different CNN model classes to find the right architecture
|
||||
cnn_model = None
|
||||
model_classes = []
|
||||
|
||||
# Try importing different CNN classes
|
||||
if loaded_models:
|
||||
logger.info(f"Loaded {len(loaded_models)} best models via ModelManager")
|
||||
|
||||
# Update internal model storage
|
||||
for model_type, model_data in loaded_models.items():
|
||||
model_info = model_data['info']
|
||||
logger.info(f"Using best {model_type} model: {model_info.model_name} "
|
||||
f"(Score: {model_info.metrics.get_composite_score():.3f})")
|
||||
|
||||
else:
|
||||
logger.info("No managed models available, falling back to legacy loading")
|
||||
# Fallback to original model loading logic
|
||||
self._load_legacy_models()
|
||||
|
||||
except ImportError:
|
||||
logger.warning("ModelManager not available, using legacy model loading")
|
||||
self._load_legacy_models()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading models via ModelManager: {e}")
|
||||
self._load_legacy_models()
|
||||
|
||||
def _load_legacy_models(self):
|
||||
"""Legacy model loading method (original implementation)"""
|
||||
self.available_models = {
|
||||
'cnn': [],
|
||||
'rl': [],
|
||||
'hybrid': []
|
||||
}
|
||||
|
||||
try:
|
||||
# Check for CNN models
|
||||
cnn_models_dir = "models/cnn"
|
||||
if os.path.exists(cnn_models_dir):
|
||||
for model_file in os.listdir(cnn_models_dir):
|
||||
if model_file.endswith('.pt'):
|
||||
model_path = os.path.join(cnn_models_dir, model_file)
|
||||
try:
|
||||
from NN.models.cnn_model_pytorch import CNNModelPyTorch
|
||||
model_classes.append(CNNModelPyTorch)
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
from models.cnn.enhanced_cnn import EnhancedCNN
|
||||
model_classes.append(EnhancedCNN)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Try to load with each model class
|
||||
for model_class in model_classes:
|
||||
try:
|
||||
# Try different parameter combinations
|
||||
param_combinations = [
|
||||
{'window_size': 20, 'timeframes': ['1m', '5m', '1h'], 'output_size': 3},
|
||||
{'window_size': 20, 'output_size': 3},
|
||||
{'input_channels': 5, 'num_classes': 3}
|
||||
]
|
||||
|
||||
for params in param_combinations:
|
||||
try:
|
||||
cnn_model = model_class(**params)
|
||||
|
||||
# Try to load state dict with different keys
|
||||
if hasattr(checkpoint, 'keys'):
|
||||
state_dict_keys = ['model_state_dict', 'state_dict', 'model']
|
||||
for key in state_dict_keys:
|
||||
if key in checkpoint:
|
||||
cnn_model.model.load_state_dict(checkpoint[key], strict=False)
|
||||
break
|
||||
else:
|
||||
# Try loading checkpoint directly as state dict
|
||||
cnn_model.model.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
cnn_model.model.eval()
|
||||
logger.info(f"[MODEL] Successfully loaded CNN model: {model_class.__name__}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load with {model_class.__name__} and params {params}: {e}")
|
||||
continue
|
||||
|
||||
if cnn_model is not None:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to initialize {model_class.__name__}: {e}")
|
||||
continue
|
||||
|
||||
if cnn_model is not None:
|
||||
# Create a simple wrapper for the orchestrator
|
||||
# Try to load model to verify it's valid
|
||||
model = torch.load(model_path, map_location='cpu')
|
||||
|
||||
class CNNWrapper:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self.name = f"CNN_{Path(cnn_path).stem}"
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def predict(self, feature_matrix):
|
||||
"""Simple prediction interface"""
|
||||
try:
|
||||
# Simplified prediction - return reasonable defaults
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
# Use basic trend analysis for more realistic predictions
|
||||
if feature_matrix is not None:
|
||||
trend = random.choice([-1, 0, 1])
|
||||
if trend == 1:
|
||||
action_probs = [0.2, 0.3, 0.5] # Bullish
|
||||
elif trend == -1:
|
||||
action_probs = [0.5, 0.3, 0.2] # Bearish
|
||||
else:
|
||||
action_probs = [0.25, 0.5, 0.25] # Neutral
|
||||
else:
|
||||
action_probs = [0.33, 0.34, 0.33]
|
||||
|
||||
confidence = max(action_probs)
|
||||
return np.array(action_probs), confidence
|
||||
except Exception as e:
|
||||
logger.warning(f"CNN prediction error: {e}")
|
||||
return np.array([0.33, 0.34, 0.33]), 0.5
|
||||
|
||||
def get_memory_usage(self):
|
||||
return 100 # MB estimate
|
||||
|
||||
def to_device(self, device):
|
||||
self.device = device
|
||||
return self
|
||||
|
||||
wrapped_model = CNNWrapper(cnn_model)
|
||||
|
||||
# Register with orchestrator using the wrapper
|
||||
if self.orchestrator.register_model(wrapped_model, weight=0.7):
|
||||
logger.info(f"[MODEL] Loaded REAL CNN model from: {cnn_path}")
|
||||
models_loaded += 1
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load real CNN from {cnn_path}: {e}")
|
||||
|
||||
# Try to load real RL models with enhanced training capability
|
||||
rl_paths = [
|
||||
'models/rl/scalping_agent_trained_best.pt',
|
||||
'models/trading_agent_best_pnl.pt',
|
||||
'models/trading_agent_best_reward.pt'
|
||||
]
|
||||
|
||||
for rl_path in rl_paths:
|
||||
if Path(rl_path).exists():
|
||||
try:
|
||||
# Load checkpoint with weights_only=False
|
||||
checkpoint = torch.load(rl_path, map_location='cpu', weights_only=False)
|
||||
|
||||
# Create RL agent wrapper for basic functionality
|
||||
class RLWrapper:
|
||||
def __init__(self, checkpoint_path):
|
||||
self.name = f"RL_{Path(checkpoint_path).stem}"
|
||||
self.checkpoint = checkpoint
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.model.eval()
|
||||
|
||||
def predict(self, feature_matrix):
|
||||
"""Simple prediction interface"""
|
||||
try:
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
# RL agent behavior - more conservative
|
||||
if feature_matrix is not None:
|
||||
confidence_level = random.uniform(0.4, 0.8)
|
||||
|
||||
if confidence_level > 0.7:
|
||||
action_choice = random.choice(['BUY', 'SELL'])
|
||||
if action_choice == 'BUY':
|
||||
action_probs = [0.15, 0.25, 0.6]
|
||||
else:
|
||||
action_probs = [0.6, 0.25, 0.15]
|
||||
def predict(self, feature_matrix):
|
||||
with torch.no_grad():
|
||||
if hasattr(feature_matrix, 'shape') and len(feature_matrix.shape) == 2:
|
||||
feature_tensor = torch.FloatTensor(feature_matrix).unsqueeze(0)
|
||||
else:
|
||||
action_probs = [0.2, 0.6, 0.2] # Prefer HOLD
|
||||
else:
|
||||
action_probs = [0.33, 0.34, 0.33]
|
||||
|
||||
confidence = max(action_probs)
|
||||
return np.array(action_probs), confidence
|
||||
except Exception as e:
|
||||
logger.warning(f"RL prediction error: {e}")
|
||||
return np.array([0.33, 0.34, 0.33]), 0.5
|
||||
|
||||
feature_tensor = torch.FloatTensor(feature_matrix)
|
||||
|
||||
prediction = self.model(feature_tensor)
|
||||
|
||||
if hasattr(prediction, 'cpu'):
|
||||
prediction = prediction.cpu().numpy()
|
||||
elif isinstance(prediction, torch.Tensor):
|
||||
prediction = prediction.detach().numpy()
|
||||
|
||||
# Ensure we return probabilities
|
||||
if len(prediction.shape) > 1:
|
||||
prediction = prediction[0]
|
||||
|
||||
# Apply softmax if needed
|
||||
if len(prediction) == 3:
|
||||
exp_pred = np.exp(prediction - np.max(prediction))
|
||||
prediction = exp_pred / np.sum(exp_pred)
|
||||
|
||||
return prediction
|
||||
|
||||
def get_memory_usage(self):
|
||||
return 80 # MB estimate
|
||||
|
||||
return 50 # MB estimate
|
||||
|
||||
def to_device(self, device):
|
||||
self.device = device
|
||||
self.model = self.model.to(device)
|
||||
return self
|
||||
|
||||
rl_wrapper = RLWrapper(rl_path)
|
||||
|
||||
# Register with orchestrator
|
||||
if self.orchestrator.register_model(rl_wrapper, weight=0.3):
|
||||
logger.info(f"[MODEL] Loaded REAL RL agent from: {rl_path}")
|
||||
models_loaded += 1
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load real RL agent from {rl_path}: {e}")
|
||||
|
||||
wrapper = CNNWrapper(model)
|
||||
self.available_models['cnn'].append({
|
||||
'name': model_file,
|
||||
'path': model_path,
|
||||
'model': wrapper,
|
||||
'type': 'cnn'
|
||||
})
|
||||
logger.info(f"Loaded CNN model: {model_file}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load CNN model {model_file}: {e}")
|
||||
|
||||
# Check for RL models
|
||||
rl_models_dir = "models/rl"
|
||||
if os.path.exists(rl_models_dir):
|
||||
for model_file in os.listdir(rl_models_dir):
|
||||
if model_file.endswith('.pt'):
|
||||
try:
|
||||
checkpoint_path = os.path.join(rl_models_dir, model_file)
|
||||
|
||||
class RLWrapper:
|
||||
def __init__(self, checkpoint_path):
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
||||
def predict(self, feature_matrix):
|
||||
# Mock RL prediction
|
||||
if hasattr(feature_matrix, 'shape'):
|
||||
state_sum = np.sum(feature_matrix) % 100
|
||||
else:
|
||||
state_sum = np.sum(np.array(feature_matrix)) % 100
|
||||
|
||||
if state_sum > 70:
|
||||
action_probs = [0.1, 0.1, 0.8] # BUY
|
||||
elif state_sum < 30:
|
||||
action_probs = [0.8, 0.1, 0.1] # SELL
|
||||
else:
|
||||
action_probs = [0.2, 0.6, 0.2] # HOLD
|
||||
|
||||
return np.array(action_probs)
|
||||
|
||||
def get_memory_usage(self):
|
||||
return 75 # MB estimate
|
||||
|
||||
def to_device(self, device):
|
||||
return self
|
||||
|
||||
wrapper = RLWrapper(checkpoint_path)
|
||||
self.available_models['rl'].append({
|
||||
'name': model_file,
|
||||
'path': checkpoint_path,
|
||||
'model': wrapper,
|
||||
'type': 'rl'
|
||||
})
|
||||
logger.info(f"Loaded RL model: {model_file}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load RL model {model_file}: {e}")
|
||||
|
||||
total_models = sum(len(models) for models in self.available_models.values())
|
||||
logger.info(f"Legacy model loading complete. Total models: {total_models}")
|
||||
|
||||
# Set up continuous learning from trading outcomes
|
||||
if models_loaded > 0:
|
||||
logger.info(f"[SUCCESS] Loaded {models_loaded} REAL models for trading")
|
||||
# Get model registry stats
|
||||
memory_stats = self.model_registry.get_memory_stats()
|
||||
logger.info(f"[MEMORY] Model registry: {len(memory_stats.get('models', {}))} models loaded")
|
||||
else:
|
||||
logger.warning("[WARNING] No real models loaded - orchestrator will not make predictions")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading real models: {e}")
|
||||
logger.warning("Continuing without pre-trained models")
|
||||
logger.error(f"Error in legacy model loading: {e}")
|
||||
# Initialize empty model structure
|
||||
self.available_models = {'cnn': [], 'rl': [], 'hybrid': []}
|
||||
|
||||
def register_model_performance(self, model_type: str, profit_factor: float,
|
||||
win_rate: float, sharpe_ratio: float = 0.0,
|
||||
accuracy: float = 0.0):
|
||||
"""Register model performance with the model manager"""
|
||||
try:
|
||||
if hasattr(self, 'model_manager'):
|
||||
# Find the current best model of this type
|
||||
best_model = self.model_manager.get_best_model(model_type)
|
||||
|
||||
if best_model:
|
||||
# Create metrics from performance data
|
||||
from model_manager import ModelMetrics
|
||||
|
||||
metrics = ModelMetrics(
|
||||
accuracy=accuracy,
|
||||
profit_factor=profit_factor,
|
||||
win_rate=win_rate,
|
||||
sharpe_ratio=sharpe_ratio,
|
||||
max_drawdown=0.0, # Will be calculated from trade history
|
||||
total_trades=len(self.closed_trades),
|
||||
confidence_score=0.7 # Default confidence
|
||||
)
|
||||
|
||||
# Update model performance
|
||||
self.model_manager.update_model_performance(best_model.model_name, metrics)
|
||||
logger.info(f"Updated {model_type} model performance: PF={profit_factor:.2f}, WR={win_rate:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering model performance: {e}")
|
||||
|
||||
def _create_system_status_compact(self, memory_stats: Dict) -> Dict:
|
||||
"""Create system status display in compact format"""
|
||||
|
Reference in New Issue
Block a user