Compare commits
4 Commits
kiro
...
small-prof
Author | SHA1 | Date | |
---|---|---|---|
6c91bf0b93 | |||
64678bd8d3 | |||
4ab7bc1846 | |||
9cd2d5d8a4 |
Binary file not shown.
@ -5,6 +5,7 @@ import requests
|
|||||||
import hmac
|
import hmac
|
||||||
import hashlib
|
import hashlib
|
||||||
from urllib.parse import urlencode, quote_plus
|
from urllib.parse import urlencode, quote_plus
|
||||||
|
import json # Added for json.dumps
|
||||||
|
|
||||||
from .exchange_interface import ExchangeInterface
|
from .exchange_interface import ExchangeInterface
|
||||||
|
|
||||||
@ -85,37 +86,40 @@ class MEXCInterface(ExchangeInterface):
|
|||||||
return symbol.replace('/', '_').upper()
|
return symbol.replace('/', '_').upper()
|
||||||
|
|
||||||
def _generate_signature(self, timestamp: str, method: str, endpoint: str, params: Dict[str, Any]) -> str:
|
def _generate_signature(self, timestamp: str, method: str, endpoint: str, params: Dict[str, Any]) -> str:
|
||||||
"""Generate signature for private API calls using MEXC's expected parameter order"""
|
"""Generate signature for private API calls using MEXC's official method"""
|
||||||
# MEXC requires specific parameter ordering, not alphabetical
|
# MEXC signature format varies by method:
|
||||||
# Based on successful test: symbol, side, type, quantity, timestamp, then other params
|
# For GET/DELETE: URL-encoded query string of alphabetically sorted parameters.
|
||||||
mexc_param_order = ['symbol', 'side', 'type', 'quantity', 'timestamp', 'recvWindow']
|
# For POST: JSON string of parameters (no sorting needed).
|
||||||
|
# The API-Secret is used as the HMAC SHA256 key.
|
||||||
|
|
||||||
# Build ordered parameter list
|
# Remove signature from params to avoid circular inclusion
|
||||||
ordered_params = []
|
clean_params = {k: v for k, v in params.items() if k != 'signature'}
|
||||||
|
|
||||||
# Add parameters in MEXC's expected order
|
parameter_string: str
|
||||||
for param_name in mexc_param_order:
|
|
||||||
if param_name in params and param_name != 'signature':
|
|
||||||
ordered_params.append(f"{param_name}={params[param_name]}")
|
|
||||||
|
|
||||||
# Add any remaining parameters not in the standard order (alphabetically)
|
if method.upper() == "POST":
|
||||||
remaining_params = {k: v for k, v in params.items() if k not in mexc_param_order and k != 'signature'}
|
# For POST requests, the signature parameter is a JSON string
|
||||||
for key in sorted(remaining_params.keys()):
|
# Ensure sorting keys for consistent JSON string generation across runs
|
||||||
ordered_params.append(f"{key}={remaining_params[key]}")
|
# even though MEXC says sorting is not required for POST params, it's good practice.
|
||||||
|
parameter_string = json.dumps(clean_params, sort_keys=True, separators=(',', ':'))
|
||||||
|
else:
|
||||||
|
# For GET/DELETE requests, parameters are spliced in dictionary order with & interval
|
||||||
|
sorted_params = sorted(clean_params.items())
|
||||||
|
parameter_string = '&'.join(f"{key}={str(value)}" for key, value in sorted_params)
|
||||||
|
|
||||||
# Create query string (MEXC doesn't use the api_key + timestamp prefix)
|
# The string to be signed is: accessKey + timestamp + obtained parameter string.
|
||||||
query_string = '&'.join(ordered_params)
|
string_to_sign = f"{self.api_key}{timestamp}{parameter_string}"
|
||||||
|
|
||||||
logger.debug(f"MEXC signature query string: {query_string}")
|
logger.debug(f"MEXC string to sign (method {method}): {string_to_sign}")
|
||||||
|
|
||||||
# Generate HMAC SHA256 signature
|
# Generate HMAC SHA256 signature
|
||||||
signature = hmac.new(
|
signature = hmac.new(
|
||||||
self.api_secret.encode('utf-8'),
|
self.api_secret.encode('utf-8'),
|
||||||
query_string.encode('utf-8'),
|
string_to_sign.encode('utf-8'),
|
||||||
hashlib.sha256
|
hashlib.sha256
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
logger.debug(f"MEXC signature: {signature}")
|
logger.debug(f"MEXC generated signature: {signature}")
|
||||||
return signature
|
return signature
|
||||||
|
|
||||||
def _send_public_request(self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
def _send_public_request(self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||||
@ -145,7 +149,7 @@ class MEXCInterface(ExchangeInterface):
|
|||||||
logger.error(f"Error in public request to {endpoint}: {e}")
|
logger.error(f"Error in public request to {endpoint}: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _send_private_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Optional[Dict[str, Any]]:
|
def _send_private_request(self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
|
||||||
"""Send a private request to the exchange with proper signature"""
|
"""Send a private request to the exchange with proper signature"""
|
||||||
if params is None:
|
if params is None:
|
||||||
params = {}
|
params = {}
|
||||||
@ -170,8 +174,11 @@ class MEXCInterface(ExchangeInterface):
|
|||||||
if method.upper() == "GET":
|
if method.upper() == "GET":
|
||||||
response = self.session.get(url, headers=headers, params=params, timeout=10)
|
response = self.session.get(url, headers=headers, params=params, timeout=10)
|
||||||
elif method.upper() == "POST":
|
elif method.upper() == "POST":
|
||||||
# MEXC expects POST parameters as query string, not in body
|
# MEXC expects POST parameters as JSON in the request body, not as query string
|
||||||
response = self.session.post(url, headers=headers, params=params, timeout=10)
|
# The signature is generated from the JSON string of parameters.
|
||||||
|
# We need to exclude 'signature' from the JSON body sent, as it's for the header.
|
||||||
|
params_for_body = {k: v for k, v in params.items() if k != 'signature'}
|
||||||
|
response = self.session.post(url, headers=headers, json=params_for_body, timeout=10)
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unsupported method: {method}")
|
logger.error(f"Unsupported method: {method}")
|
||||||
return None
|
return None
|
||||||
@ -217,12 +224,9 @@ class MEXCInterface(ExchangeInterface):
|
|||||||
|
|
||||||
response = self._send_public_request('GET', endpoint, params)
|
response = self._send_public_request('GET', endpoint, params)
|
||||||
|
|
||||||
if response:
|
|
||||||
# MEXC ticker returns a dictionary if single symbol, list if all symbols
|
|
||||||
if isinstance(response, dict):
|
if isinstance(response, dict):
|
||||||
ticker_data = response
|
ticker_data: Dict[str, Any] = response
|
||||||
elif isinstance(response, list) and len(response) > 0:
|
elif isinstance(response, list) and len(response) > 0:
|
||||||
# If the response is a list, try to find the specific symbol
|
|
||||||
found_ticker = next((item for item in response if item.get('symbol') == formatted_symbol), None)
|
found_ticker = next((item for item in response if item.get('symbol') == formatted_symbol), None)
|
||||||
if found_ticker:
|
if found_ticker:
|
||||||
ticker_data = found_ticker
|
ticker_data = found_ticker
|
||||||
@ -233,6 +237,9 @@ class MEXCInterface(ExchangeInterface):
|
|||||||
logger.error(f"Unexpected ticker response format: {response}")
|
logger.error(f"Unexpected ticker response format: {response}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# At this point, ticker_data is guaranteed to be a Dict[str, Any] due to the above logic
|
||||||
|
# If it was None, we would have returned early.
|
||||||
|
|
||||||
# Extract relevant info and format for universal use
|
# Extract relevant info and format for universal use
|
||||||
last_price = float(ticker_data.get('lastPrice', 0))
|
last_price = float(ticker_data.get('lastPrice', 0))
|
||||||
bid_price = float(ticker_data.get('bidPrice', 0))
|
bid_price = float(ticker_data.get('bidPrice', 0))
|
||||||
@ -257,8 +264,6 @@ class MEXCInterface(ExchangeInterface):
|
|||||||
'exchange': 'MEXC',
|
'exchange': 'MEXC',
|
||||||
'raw_data': ticker_data
|
'raw_data': ticker_data
|
||||||
}
|
}
|
||||||
logger.error(f"Failed to get ticker for {symbol}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_api_symbols(self) -> List[str]:
|
def get_api_symbols(self) -> List[str]:
|
||||||
"""Get list of symbols supported for API trading"""
|
"""Get list of symbols supported for API trading"""
|
||||||
@ -293,40 +298,90 @@ class MEXCInterface(ExchangeInterface):
|
|||||||
logger.info(f"Supported symbols include: {supported_symbols[:10]}...") # Show first 10
|
logger.info(f"Supported symbols include: {supported_symbols[:10]}...") # Show first 10
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
# Format quantity according to symbol precision requirements
|
||||||
|
formatted_quantity = self._format_quantity_for_symbol(formatted_symbol, quantity)
|
||||||
|
if formatted_quantity is None:
|
||||||
|
logger.error(f"MEXC: Failed to format quantity {quantity} for {formatted_symbol}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Handle order type restrictions for specific symbols
|
||||||
|
final_order_type = self._adjust_order_type_for_symbol(formatted_symbol, order_type.upper())
|
||||||
|
|
||||||
|
# Get price for limit orders
|
||||||
|
final_price = price
|
||||||
|
if final_order_type == 'LIMIT' and price is None:
|
||||||
|
# Get current market price
|
||||||
|
ticker = self.get_ticker(symbol)
|
||||||
|
if ticker and 'last' in ticker:
|
||||||
|
final_price = ticker['last']
|
||||||
|
logger.info(f"MEXC: Using market price ${final_price:.2f} for LIMIT order")
|
||||||
|
else:
|
||||||
|
logger.error(f"MEXC: Could not get market price for LIMIT order on {formatted_symbol}")
|
||||||
|
return {}
|
||||||
|
|
||||||
endpoint = "order"
|
endpoint = "order"
|
||||||
|
|
||||||
params: Dict[str, Any] = {
|
params: Dict[str, Any] = {
|
||||||
'symbol': formatted_symbol,
|
'symbol': formatted_symbol,
|
||||||
'side': side.upper(),
|
'side': side.upper(),
|
||||||
'type': order_type.upper(),
|
'type': final_order_type,
|
||||||
'quantity': str(quantity) # Quantity must be a string
|
'quantity': str(formatted_quantity) # Quantity must be a string
|
||||||
}
|
}
|
||||||
if price is not None:
|
if final_price is not None:
|
||||||
params['price'] = str(price) # Price must be a string for limit orders
|
params['price'] = str(final_price) # Price must be a string for limit orders
|
||||||
|
|
||||||
logger.info(f"MEXC: Placing {side.upper()} {order_type.upper()} order for {quantity} {formatted_symbol} at price {price}")
|
logger.info(f"MEXC: Placing {side.upper()} {final_order_type} order for {formatted_quantity} {formatted_symbol} at price {final_price}")
|
||||||
|
|
||||||
# For market orders, some parameters might be optional or handled differently.
|
|
||||||
# Check MEXC API docs for market order specifics (e.g., quoteOrderQty for buy market orders)
|
|
||||||
if order_type.upper() == 'MARKET' and side.upper() == 'BUY':
|
|
||||||
# If it's a market buy order, MEXC often expects quoteOrderQty instead of quantity
|
|
||||||
# Assuming quantity here refers to the base asset, if quoteOrderQty is needed, adjust.
|
|
||||||
# For now, we will stick to quantity and let MEXC handle the conversion if possible
|
|
||||||
pass # No specific change needed based on the current params structure
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# MEXC API endpoint for placing orders is /api/v3/order (POST)
|
# MEXC API endpoint for placing orders is /api/v3/order (POST)
|
||||||
order_result = self._send_private_request('POST', endpoint, params)
|
order_result = self._send_private_request('POST', endpoint, params)
|
||||||
if order_result:
|
if order_result is not None:
|
||||||
logger.info(f"MEXC: Order placed successfully: {order_result}")
|
logger.info(f"MEXC: Order placed successfully: {order_result}")
|
||||||
return order_result
|
return order_result
|
||||||
else:
|
else:
|
||||||
logger.error(f"MEXC: Error placing order: {order_result}")
|
logger.error(f"MEXC: Error placing order: request returned None")
|
||||||
return {}
|
return {}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"MEXC: Exception placing order: {e}")
|
logger.error(f"MEXC: Exception placing order: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def _format_quantity_for_symbol(self, formatted_symbol: str, quantity: float) -> Optional[float]:
|
||||||
|
"""Format quantity according to symbol precision requirements"""
|
||||||
|
try:
|
||||||
|
# Symbol-specific precision rules
|
||||||
|
if formatted_symbol == 'ETHUSDC':
|
||||||
|
# ETHUSDC requires max 5 decimal places, step size 0.000001
|
||||||
|
formatted_qty = round(quantity, 5)
|
||||||
|
# Ensure it meets minimum step size
|
||||||
|
step_size = 0.000001
|
||||||
|
formatted_qty = round(formatted_qty / step_size) * step_size
|
||||||
|
# Round again to remove floating point errors
|
||||||
|
formatted_qty = round(formatted_qty, 6)
|
||||||
|
logger.info(f"MEXC: Formatted ETHUSDC quantity {quantity} -> {formatted_qty}")
|
||||||
|
return formatted_qty
|
||||||
|
elif formatted_symbol == 'BTCUSDC':
|
||||||
|
# Assume similar precision for BTC
|
||||||
|
formatted_qty = round(quantity, 6)
|
||||||
|
step_size = 0.000001
|
||||||
|
formatted_qty = round(formatted_qty / step_size) * step_size
|
||||||
|
formatted_qty = round(formatted_qty, 6)
|
||||||
|
return formatted_qty
|
||||||
|
else:
|
||||||
|
# Default formatting - 6 decimal places
|
||||||
|
return round(quantity, 6)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error formatting quantity for {formatted_symbol}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _adjust_order_type_for_symbol(self, formatted_symbol: str, order_type: str) -> str:
|
||||||
|
"""Adjust order type based on symbol restrictions"""
|
||||||
|
if formatted_symbol == 'ETHUSDC':
|
||||||
|
# ETHUSDC only supports LIMIT and LIMIT_MAKER orders
|
||||||
|
if order_type == 'MARKET':
|
||||||
|
logger.info(f"MEXC: Converting MARKET order to LIMIT for {formatted_symbol} (MARKET not supported)")
|
||||||
|
return 'LIMIT'
|
||||||
|
return order_type
|
||||||
|
|
||||||
def cancel_order(self, symbol: str, order_id: str) -> Dict[str, Any]:
|
def cancel_order(self, symbol: str, order_id: str) -> Dict[str, Any]:
|
||||||
"""Cancel an existing order on MEXC."""
|
"""Cancel an existing order on MEXC."""
|
||||||
formatted_symbol = self._format_spot_symbol(symbol)
|
formatted_symbol = self._format_spot_symbol(symbol)
|
||||||
|
@ -229,8 +229,8 @@ class COBRLModelInterface(ModelInterface):
|
|||||||
Interface for the COB RL model that handles model management, training, and inference
|
Interface for the COB RL model that handles model management, training, and inference
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None):
|
def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None, name=None, **kwargs):
|
||||||
super().__init__(name="cob_rl_model") # Initialize ModelInterface with a name
|
super().__init__(name=name) # Initialize ModelInterface with a name
|
||||||
self.model_checkpoint_dir = model_checkpoint_dir
|
self.model_checkpoint_dir = model_checkpoint_dir
|
||||||
self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
|
self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
import random
|
import random
|
||||||
from typing import Tuple, List
|
from typing import Tuple, List
|
||||||
import osvu
|
import os
|
||||||
import sys
|
import sys
|
||||||
import logging
|
import logging
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
Binary file not shown.
Binary file not shown.
10
config.yaml
10
config.yaml
@ -162,11 +162,11 @@ mexc_trading:
|
|||||||
trading_mode: simulation # simulation, testnet, live
|
trading_mode: simulation # simulation, testnet, live
|
||||||
|
|
||||||
# Position sizing as percentage of account balance
|
# Position sizing as percentage of account balance
|
||||||
base_position_percent: 5.0 # 5% base position of account
|
base_position_percent: 1 # 0.5% base position of account (MUCH SAFER)
|
||||||
max_position_percent: 20.0 # 20% max position of account
|
max_position_percent: 5.0 # 2% max position of account (REDUCED)
|
||||||
min_position_percent: 2.0 # 2% min position of account
|
min_position_percent: 0.5 # 0.2% min position of account (REDUCED)
|
||||||
leverage: 50.0 # 50x leverage (adjustable in UI)
|
leverage: 1.0 # 1x leverage (NO LEVERAGE FOR TESTING)
|
||||||
simulation_account_usd: 100.0 # $100 simulation account balance
|
simulation_account_usd: 99.9 # $100 simulation account balance
|
||||||
|
|
||||||
# Risk management
|
# Risk management
|
||||||
max_daily_loss_usd: 200.0
|
max_daily_loss_usd: 200.0
|
||||||
|
@ -34,7 +34,7 @@ class COBIntegration:
|
|||||||
Integration layer for Multi-Exchange COB data with gogo2 trading system
|
Integration layer for Multi-Exchange COB data with gogo2 trading system
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None):
|
def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None, initial_data_limit=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Initialize COB Integration
|
Initialize COB Integration
|
||||||
|
|
||||||
|
@ -1007,6 +1007,17 @@ class TradingOrchestrator:
|
|||||||
if enhanced_features is not None:
|
if enhanced_features is not None:
|
||||||
# Get CNN prediction - use the actual underlying model
|
# Get CNN prediction - use the actual underlying model
|
||||||
try:
|
try:
|
||||||
|
# Ensure features are properly shaped and limited
|
||||||
|
if isinstance(enhanced_features, np.ndarray):
|
||||||
|
# Flatten and limit features to prevent shape mismatches
|
||||||
|
enhanced_features = enhanced_features.flatten()
|
||||||
|
if len(enhanced_features) > 100: # Limit to 100 features
|
||||||
|
enhanced_features = enhanced_features[:100]
|
||||||
|
elif len(enhanced_features) < 100: # Pad with zeros
|
||||||
|
padded = np.zeros(100)
|
||||||
|
padded[:len(enhanced_features)] = enhanced_features
|
||||||
|
enhanced_features = padded
|
||||||
|
|
||||||
if hasattr(model.model, 'act'):
|
if hasattr(model.model, 'act'):
|
||||||
# Use the CNN's act method
|
# Use the CNN's act method
|
||||||
action_result = model.model.act(enhanced_features, explore=False)
|
action_result = model.model.act(enhanced_features, explore=False)
|
||||||
@ -1138,6 +1149,17 @@ class TradingOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if feature_matrix is not None:
|
if feature_matrix is not None:
|
||||||
|
# Ensure feature_matrix is properly shaped and limited
|
||||||
|
if isinstance(feature_matrix, np.ndarray):
|
||||||
|
# Flatten and limit features to prevent shape mismatches
|
||||||
|
feature_matrix = feature_matrix.flatten()
|
||||||
|
if len(feature_matrix) > 2000: # Limit to 2000 features for generic models
|
||||||
|
feature_matrix = feature_matrix[:2000]
|
||||||
|
elif len(feature_matrix) < 2000: # Pad with zeros
|
||||||
|
padded = np.zeros(2000)
|
||||||
|
padded[:len(feature_matrix)] = feature_matrix
|
||||||
|
feature_matrix = padded
|
||||||
|
|
||||||
prediction_result = model.predict(feature_matrix)
|
prediction_result = model.predict(feature_matrix)
|
||||||
|
|
||||||
# Handle different return formats from model.predict()
|
# Handle different return formats from model.predict()
|
||||||
@ -1834,3 +1856,100 @@ class TradingOrchestrator:
|
|||||||
"""Set the trading executor for position tracking"""
|
"""Set the trading executor for position tracking"""
|
||||||
self.trading_executor = trading_executor
|
self.trading_executor = trading_executor
|
||||||
logger.info("Trading executor set for position tracking and P&L feedback")
|
logger.info("Trading executor set for position tracking and P&L feedback")
|
||||||
|
|
||||||
|
def _get_current_price(self, symbol: str) -> float:
|
||||||
|
"""Get current price for symbol"""
|
||||||
|
try:
|
||||||
|
# Try to get from data provider
|
||||||
|
if self.data_provider:
|
||||||
|
try:
|
||||||
|
# Try different methods to get current price
|
||||||
|
if hasattr(self.data_provider, 'get_latest_data'):
|
||||||
|
latest_data = self.data_provider.get_latest_data(symbol)
|
||||||
|
if latest_data and 'price' in latest_data:
|
||||||
|
return float(latest_data['price'])
|
||||||
|
elif latest_data and 'close' in latest_data:
|
||||||
|
return float(latest_data['close'])
|
||||||
|
elif hasattr(self.data_provider, 'get_current_price'):
|
||||||
|
return float(self.data_provider.get_current_price(symbol))
|
||||||
|
elif hasattr(self.data_provider, 'get_latest_candle'):
|
||||||
|
latest_candle = self.data_provider.get_latest_candle(symbol, '1m')
|
||||||
|
if latest_candle and 'close' in latest_candle:
|
||||||
|
return float(latest_candle['close'])
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not get price from data provider: {e}")
|
||||||
|
# Try to get from universal adapter
|
||||||
|
if self.universal_adapter:
|
||||||
|
try:
|
||||||
|
data_stream = self.universal_adapter.get_latest_data(symbol)
|
||||||
|
if data_stream and hasattr(data_stream, 'current_price'):
|
||||||
|
return float(data_stream.current_price)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not get price from universal adapter: {e}")
|
||||||
|
# Fallback to default prices
|
||||||
|
default_prices = {
|
||||||
|
'ETH/USDT': 2500.0,
|
||||||
|
'BTC/USDT': 108000.0
|
||||||
|
}
|
||||||
|
return default_prices.get(symbol, 1000.0)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting current price for {symbol}: {e}")
|
||||||
|
# Return default price based on symbol
|
||||||
|
if 'ETH' in symbol:
|
||||||
|
return 2500.0
|
||||||
|
elif 'BTC' in symbol:
|
||||||
|
return 108000.0
|
||||||
|
else:
|
||||||
|
return 1000.0
|
||||||
|
|
||||||
|
def _generate_fallback_prediction(self, symbol: str) -> Dict[str, Any]:
|
||||||
|
"""Generate fallback prediction when models fail"""
|
||||||
|
try:
|
||||||
|
return {
|
||||||
|
'action': 'HOLD',
|
||||||
|
'confidence': 0.5,
|
||||||
|
'price': self._get_current_price(symbol) or 2500.0,
|
||||||
|
'timestamp': datetime.now(),
|
||||||
|
'model': 'fallback'
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error generating fallback prediction: {e}")
|
||||||
|
return {
|
||||||
|
'action': 'HOLD',
|
||||||
|
'confidence': 0.5,
|
||||||
|
'price': 2500.0,
|
||||||
|
'timestamp': datetime.now(),
|
||||||
|
'model': 'fallback'
|
||||||
|
}
|
||||||
|
|
||||||
|
def capture_dqn_prediction(self, symbol: str, action_idx: int, confidence: float, price: float, q_values: List[float] = None):
|
||||||
|
"""Capture DQN prediction for dashboard visualization"""
|
||||||
|
try:
|
||||||
|
if symbol not in self.recent_dqn_predictions:
|
||||||
|
self.recent_dqn_predictions[symbol] = deque(maxlen=100)
|
||||||
|
prediction_data = {
|
||||||
|
'timestamp': datetime.now(),
|
||||||
|
'action': ['SELL', 'HOLD', 'BUY'][action_idx],
|
||||||
|
'confidence': confidence,
|
||||||
|
'price': price,
|
||||||
|
'q_values': q_values or [0.33, 0.33, 0.34]
|
||||||
|
}
|
||||||
|
self.recent_dqn_predictions[symbol].append(prediction_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error capturing DQN prediction: {e}")
|
||||||
|
|
||||||
|
def capture_cnn_prediction(self, symbol: str, direction: int, confidence: float, current_price: float, predicted_price: float):
|
||||||
|
"""Capture CNN prediction for dashboard visualization"""
|
||||||
|
try:
|
||||||
|
if symbol not in self.recent_cnn_predictions:
|
||||||
|
self.recent_cnn_predictions[symbol] = deque(maxlen=50)
|
||||||
|
prediction_data = {
|
||||||
|
'timestamp': datetime.now(),
|
||||||
|
'direction': ['DOWN', 'SAME', 'UP'][direction],
|
||||||
|
'confidence': confidence,
|
||||||
|
'current_price': current_price,
|
||||||
|
'predicted_price': predicted_price
|
||||||
|
}
|
||||||
|
self.recent_cnn_predictions[symbol].append(prediction_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error capturing CNN prediction: {e}")
|
@ -114,8 +114,13 @@ class TradingExecutor:
|
|||||||
# Thread safety
|
# Thread safety
|
||||||
self.lock = Lock()
|
self.lock = Lock()
|
||||||
|
|
||||||
# Connect to exchange
|
# Connect to exchange - skip connection check in simulation mode
|
||||||
if self.trading_enabled:
|
if self.trading_enabled:
|
||||||
|
if self.simulation_mode:
|
||||||
|
logger.info("TRADING EXECUTOR: Simulation mode - skipping exchange connection check")
|
||||||
|
# In simulation mode, we don't need a real exchange connection
|
||||||
|
# Trading should remain enabled for simulation trades
|
||||||
|
else:
|
||||||
logger.info("TRADING EXECUTOR: Attempting to connect to exchange...")
|
logger.info("TRADING EXECUTOR: Attempting to connect to exchange...")
|
||||||
if not self._connect_exchange():
|
if not self._connect_exchange():
|
||||||
logger.error("TRADING EXECUTOR: Failed initial exchange connection. Trading will be disabled.")
|
logger.error("TRADING EXECUTOR: Failed initial exchange connection. Trading will be disabled.")
|
||||||
@ -230,15 +235,25 @@ class TradingExecutor:
|
|||||||
required_capital = self._calculate_position_size(confidence, current_price)
|
required_capital = self._calculate_position_size(confidence, current_price)
|
||||||
|
|
||||||
# Get available balance for the quote asset
|
# Get available balance for the quote asset
|
||||||
available_balance = self.exchange.get_balance(quote_asset)
|
# For MEXC, prioritize USDT over USDC since most accounts have USDT
|
||||||
|
if quote_asset == 'USDC':
|
||||||
# If USDC balance is insufficient, check USDT as fallback (for MEXC compatibility)
|
# Check USDT first (most common balance)
|
||||||
if available_balance < required_capital and quote_asset == 'USDC':
|
|
||||||
usdt_balance = self.exchange.get_balance('USDT')
|
usdt_balance = self.exchange.get_balance('USDT')
|
||||||
|
usdc_balance = self.exchange.get_balance('USDC')
|
||||||
|
|
||||||
if usdt_balance >= required_capital:
|
if usdt_balance >= required_capital:
|
||||||
available_balance = usdt_balance
|
available_balance = usdt_balance
|
||||||
quote_asset = 'USDT' # Use USDT instead
|
quote_asset = 'USDT' # Use USDT for trading
|
||||||
logger.info(f"BALANCE CHECK: Using USDT fallback balance for {symbol}")
|
logger.info(f"BALANCE CHECK: Using USDT balance for {symbol} (preferred)")
|
||||||
|
elif usdc_balance >= required_capital:
|
||||||
|
available_balance = usdc_balance
|
||||||
|
logger.info(f"BALANCE CHECK: Using USDC balance for {symbol}")
|
||||||
|
else:
|
||||||
|
# Use the larger balance for reporting
|
||||||
|
available_balance = max(usdt_balance, usdc_balance)
|
||||||
|
quote_asset = 'USDT' if usdt_balance > usdc_balance else 'USDC'
|
||||||
|
else:
|
||||||
|
available_balance = self.exchange.get_balance(quote_asset)
|
||||||
|
|
||||||
logger.info(f"BALANCE CHECK: Symbol: {symbol}, Action: {action}, Required: ${required_capital:.2f} {quote_asset}, Available: ${available_balance:.2f} {quote_asset}")
|
logger.info(f"BALANCE CHECK: Symbol: {symbol}, Action: {action}, Required: ${required_capital:.2f} {quote_asset}, Available: ${available_balance:.2f} {quote_asset}")
|
||||||
|
|
||||||
|
@ -229,9 +229,12 @@ class TrainingIntegration:
|
|||||||
# Truncate
|
# Truncate
|
||||||
features = features[:50]
|
features = features[:50]
|
||||||
|
|
||||||
|
# Get the model's device to ensure tensors are on the same device
|
||||||
|
model_device = next(cnn_model.parameters()).device
|
||||||
|
|
||||||
# Create tensors
|
# Create tensors
|
||||||
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device)
|
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(model_device)
|
||||||
target_tensor = torch.LongTensor([target]).to(device)
|
target_tensor = torch.LongTensor([target]).to(model_device)
|
||||||
|
|
||||||
# Training step
|
# Training step
|
||||||
cnn_model.train()
|
cnn_model.train()
|
||||||
|
@ -1454,9 +1454,10 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
model.train()
|
model.train()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# Convert numpy arrays to PyTorch tensors
|
# Convert numpy arrays to PyTorch tensors and move to device
|
||||||
features_tensor = torch.from_numpy(features).float()
|
device = next(model.parameters()).device
|
||||||
targets_tensor = torch.from_numpy(targets).long()
|
features_tensor = torch.from_numpy(features).float().to(device)
|
||||||
|
targets_tensor = torch.from_numpy(targets).long().to(device)
|
||||||
|
|
||||||
# Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width)
|
# Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width)
|
||||||
# Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20)
|
# Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20)
|
||||||
@ -1471,10 +1472,37 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
# If the CNN expects (batch_size, channels, sequence_length)
|
# If the CNN expects (batch_size, channels, sequence_length)
|
||||||
# features_tensor = features_tensor.view(features_tensor.shape[0], 1, 15 * 20) # Example for 1D CNN
|
# features_tensor = features_tensor.view(features_tensor.shape[0], 1, 15 * 20) # Example for 1D CNN
|
||||||
|
|
||||||
# Let's assume the CNN expects 2D input (batch_size, flattened_features)
|
# Ensure proper shape for CNN input
|
||||||
|
if len(features_tensor.shape) == 2:
|
||||||
|
# If it's (batch_size, features), keep as is for 1D CNN
|
||||||
|
pass
|
||||||
|
elif len(features_tensor.shape) == 1:
|
||||||
|
# If it's (features), add batch dimension
|
||||||
|
features_tensor = features_tensor.unsqueeze(0)
|
||||||
|
else:
|
||||||
|
# Reshape to (batch_size, features) if needed
|
||||||
|
features_tensor = features_tensor.view(features_tensor.shape[0], -1)
|
||||||
|
|
||||||
|
# Limit input size to prevent shape mismatches
|
||||||
|
if features_tensor.shape[1] > 1000: # Limit to 1000 features
|
||||||
|
features_tensor = features_tensor[:, :1000]
|
||||||
|
|
||||||
outputs = model(features_tensor)
|
outputs = model(features_tensor)
|
||||||
|
|
||||||
loss = criterion(outputs, targets_tensor)
|
# Extract logits from model output (model returns a dictionary)
|
||||||
|
if isinstance(outputs, dict):
|
||||||
|
logits = outputs['logits']
|
||||||
|
elif isinstance(outputs, tuple):
|
||||||
|
logits = outputs[0] # First element is usually logits
|
||||||
|
else:
|
||||||
|
logits = outputs
|
||||||
|
|
||||||
|
# Ensure logits is a tensor
|
||||||
|
if not isinstance(logits, torch.Tensor):
|
||||||
|
logger.error(f"CNN output is not a tensor: {type(logits)}")
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
loss = criterion(logits, targets_tensor)
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
@ -1856,13 +1884,19 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
|
if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
|
||||||
and self.orchestrator.rl_agent):
|
and self.orchestrator.rl_agent):
|
||||||
|
|
||||||
# Get Q-values from model
|
# Use RL agent to make prediction
|
||||||
q_values = self.orchestrator.rl_agent.act(current_state, return_q_values=True)
|
current_state = self._get_dqn_state(symbol)
|
||||||
if isinstance(q_values, tuple):
|
if current_state is None:
|
||||||
action, q_vals = q_values
|
return
|
||||||
q_values = q_vals.tolist() if hasattr(q_vals, 'tolist') else [0, 0, 0]
|
action = self.orchestrator.rl_agent.act(current_state, explore=False)
|
||||||
|
# Get Q-values separately if available
|
||||||
|
if hasattr(self.orchestrator.rl_agent, 'policy_net'):
|
||||||
|
with torch.no_grad():
|
||||||
|
state_tensor = torch.FloatTensor(current_state).unsqueeze(0).to(self.orchestrator.rl_agent.device)
|
||||||
|
q_values_tensor = self.orchestrator.rl_agent.policy_net(state_tensor)
|
||||||
|
if isinstance(q_values_tensor, tuple):
|
||||||
|
q_values = q_values_tensor[0].cpu().numpy()[0].tolist()
|
||||||
else:
|
else:
|
||||||
action = q_values
|
|
||||||
q_values = [0.33, 0.33, 0.34] # Default uniform distribution
|
q_values = [0.33, 0.33, 0.34] # Default uniform distribution
|
||||||
|
|
||||||
confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33
|
confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33
|
||||||
|
@ -1,201 +1,121 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Run Clean Trading Dashboard with Full Training Pipeline
|
Clean Trading Dashboard Runner with Enhanced Stability and Error Handling
|
||||||
Integrated system with both training loop and clean web dashboard
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
# Fix OpenMP library conflicts before importing other modules
|
|
||||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
|
|
||||||
os.environ['OMP_NUM_THREADS'] = '4'
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import logging
|
||||||
|
import traceback
|
||||||
|
import gc
|
||||||
import time
|
import time
|
||||||
|
import psutil
|
||||||
|
import torch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Add project root to path
|
|
||||||
project_root = Path(__file__).parent
|
|
||||||
sys.path.insert(0, str(project_root))
|
|
||||||
|
|
||||||
from core.config import get_config, setup_logging
|
|
||||||
from core.data_provider import DataProvider
|
|
||||||
|
|
||||||
# Import checkpoint management
|
|
||||||
from utils.checkpoint_manager import get_checkpoint_manager
|
|
||||||
from utils.training_integration import get_training_integration
|
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
setup_logging()
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def start_training_pipeline(orchestrator, trading_executor):
|
def clear_gpu_memory():
|
||||||
"""Start the training pipeline in the background"""
|
"""Clear GPU memory cache"""
|
||||||
logger.info("=" * 70)
|
if torch.cuda.is_available():
|
||||||
logger.info("STARTING TRAINING PIPELINE WITH CLEAN DASHBOARD")
|
torch.cuda.empty_cache()
|
||||||
logger.info("=" * 70)
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
# Initialize checkpoint management
|
def check_system_resources():
|
||||||
checkpoint_manager = get_checkpoint_manager()
|
"""Check if system has enough resources"""
|
||||||
training_integration = get_training_integration()
|
available_ram = psutil.virtual_memory().available / 1024**3
|
||||||
|
if available_ram < 2.0: # Less than 2GB available
|
||||||
|
logger.warning(f"Low RAM: {available_ram:.1f} GB available")
|
||||||
|
gc.collect()
|
||||||
|
clear_gpu_memory()
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
# Training statistics
|
def run_dashboard_with_recovery():
|
||||||
training_stats = {
|
"""Run dashboard with automatic error recovery"""
|
||||||
'iteration_count': 0,
|
max_retries = 3
|
||||||
'total_decisions': 0,
|
retry_count = 0
|
||||||
'successful_trades': 0,
|
|
||||||
'best_performance': 0.0,
|
|
||||||
'last_checkpoint_iteration': 0
|
|
||||||
}
|
|
||||||
|
|
||||||
|
while retry_count < max_retries:
|
||||||
try:
|
try:
|
||||||
# Start real-time processing (available in Enhanced orchestrator)
|
logger.info(f"Starting Clean Trading Dashboard (attempt {retry_count + 1}/{max_retries})")
|
||||||
if hasattr(orchestrator, 'start_realtime_processing'):
|
|
||||||
await orchestrator.start_realtime_processing()
|
|
||||||
logger.info("Real-time processing started")
|
|
||||||
|
|
||||||
# Start COB integration (available in Enhanced orchestrator)
|
# Check system resources
|
||||||
if hasattr(orchestrator, 'start_cob_integration'):
|
if not check_system_resources():
|
||||||
await orchestrator.start_cob_integration()
|
logger.warning("System resources low, waiting 30 seconds...")
|
||||||
logger.info("COB integration started - 5-minute data matrix active")
|
time.sleep(30)
|
||||||
else:
|
continue
|
||||||
logger.info("COB integration not available")
|
|
||||||
|
|
||||||
# Main training loop
|
# Import here to avoid memory issues on restart
|
||||||
iteration = 0
|
|
||||||
last_checkpoint_time = time.time()
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
iteration += 1
|
|
||||||
training_stats['iteration_count'] = iteration
|
|
||||||
|
|
||||||
# Get symbols to process
|
|
||||||
symbols = orchestrator.symbols if hasattr(orchestrator, 'symbols') else ['ETH/USDT']
|
|
||||||
|
|
||||||
# Process each symbol
|
|
||||||
for symbol in symbols:
|
|
||||||
try:
|
|
||||||
# Make trading decision (this triggers model training)
|
|
||||||
decision = await orchestrator.make_trading_decision(symbol)
|
|
||||||
if decision:
|
|
||||||
training_stats['total_decisions'] += 1
|
|
||||||
logger.debug(f"[{symbol}] Decision: {decision.action} @ {decision.confidence:.1%}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error processing {symbol}: {e}")
|
|
||||||
|
|
||||||
# Status logging every 100 iterations
|
|
||||||
if iteration % 100 == 0:
|
|
||||||
current_time = time.time()
|
|
||||||
elapsed = current_time - last_checkpoint_time
|
|
||||||
|
|
||||||
logger.info(f"[TRAINING] Iteration {iteration}, Decisions: {training_stats['total_decisions']}, Time: {elapsed:.1f}s")
|
|
||||||
|
|
||||||
# Models will save their own checkpoints when performance improves
|
|
||||||
training_stats['last_checkpoint_iteration'] = iteration
|
|
||||||
last_checkpoint_time = current_time
|
|
||||||
|
|
||||||
# Brief pause to prevent overwhelming the system
|
|
||||||
await asyncio.sleep(0.1) # 100ms between iterations
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Training loop error: {e}")
|
|
||||||
await asyncio.sleep(5) # Wait longer on error
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Training pipeline error: {e}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
def start_clean_dashboard_with_training():
|
|
||||||
"""Start clean dashboard with full training pipeline"""
|
|
||||||
try:
|
|
||||||
logger.info("=" * 80)
|
|
||||||
logger.info("CLEAN TRADING DASHBOARD + FULL TRAINING PIPELINE")
|
|
||||||
logger.info("=" * 80)
|
|
||||||
logger.info("Features: Real-time Training, COB Integration, Clean UI")
|
|
||||||
logger.info("Universal Data Stream: ENABLED")
|
|
||||||
logger.info("Neural Decision Fusion: ENABLED")
|
|
||||||
logger.info("COB Integration: ENABLED")
|
|
||||||
logger.info("GPU Training: ENABLED")
|
|
||||||
logger.info("Multi-symbol: ETH/USDT, BTC/USDT")
|
|
||||||
|
|
||||||
# Get port from environment or use default
|
|
||||||
dashboard_port = int(os.environ.get('DASHBOARD_PORT', '8051'))
|
|
||||||
logger.info(f"Dashboard: http://127.0.0.1:{dashboard_port}")
|
|
||||||
logger.info("=" * 80)
|
|
||||||
|
|
||||||
# Check environment variables
|
|
||||||
enable_universal_stream = os.environ.get('ENABLE_UNIVERSAL_DATA_STREAM', '1') == '1'
|
|
||||||
enable_nn_fusion = os.environ.get('ENABLE_NN_DECISION_FUSION', '1') == '1'
|
|
||||||
enable_cob = os.environ.get('ENABLE_COB_INTEGRATION', '1') == '1'
|
|
||||||
|
|
||||||
logger.info(f"Universal Data Stream: {'ENABLED' if enable_universal_stream else 'DISABLED'}")
|
|
||||||
logger.info(f"Neural Decision Fusion: {'ENABLED' if enable_nn_fusion else 'DISABLED'}")
|
|
||||||
logger.info(f"COB Integration: {'ENABLED' if enable_cob else 'DISABLED'}")
|
|
||||||
|
|
||||||
# Get configuration
|
|
||||||
config = get_config()
|
|
||||||
|
|
||||||
# Initialize core components
|
|
||||||
from core.data_provider import DataProvider
|
from core.data_provider import DataProvider
|
||||||
from core.orchestrator import TradingOrchestrator
|
from core.orchestrator import TradingOrchestrator
|
||||||
from core.trading_executor import TradingExecutor
|
from core.trading_executor import TradingExecutor
|
||||||
|
|
||||||
# Create data provider
|
|
||||||
data_provider = DataProvider()
|
|
||||||
|
|
||||||
# Create enhanced orchestrator with COB integration - stable and efficient
|
|
||||||
orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
|
|
||||||
logger.info("Enhanced Trading Orchestrator created with COB integration")
|
|
||||||
|
|
||||||
# Create trading executor
|
|
||||||
trading_executor = TradingExecutor()
|
|
||||||
|
|
||||||
# Import clean dashboard
|
|
||||||
from web.clean_dashboard import create_clean_dashboard
|
from web.clean_dashboard import create_clean_dashboard
|
||||||
|
|
||||||
# Create clean dashboard
|
logger.info("Creating data provider...")
|
||||||
dashboard = create_clean_dashboard(
|
data_provider = DataProvider()
|
||||||
|
|
||||||
|
logger.info("Creating trading orchestrator...")
|
||||||
|
orchestrator = TradingOrchestrator(
|
||||||
data_provider=data_provider,
|
data_provider=data_provider,
|
||||||
orchestrator=orchestrator,
|
enhanced_rl_training=True
|
||||||
trading_executor=trading_executor
|
|
||||||
)
|
)
|
||||||
logger.info("Clean Trading Dashboard created")
|
|
||||||
|
|
||||||
# Start training pipeline in background thread
|
logger.info("Creating trading executor...")
|
||||||
def training_worker():
|
trading_executor = TradingExecutor()
|
||||||
"""Run training pipeline in background"""
|
|
||||||
|
logger.info("Creating clean dashboard...")
|
||||||
|
dashboard = create_clean_dashboard(data_provider, orchestrator, trading_executor)
|
||||||
|
|
||||||
|
logger.info("Dashboard created successfully")
|
||||||
|
logger.info("=== Clean Trading Dashboard Status ===")
|
||||||
|
logger.info("- Data Provider: Active")
|
||||||
|
logger.info("- Trading Orchestrator: Active")
|
||||||
|
logger.info("- Trading Executor: Active")
|
||||||
|
logger.info("- Enhanced Training: Active")
|
||||||
|
logger.info("- Dashboard: Ready")
|
||||||
|
logger.info("=======================================")
|
||||||
|
|
||||||
|
# Start the dashboard server with error handling
|
||||||
try:
|
try:
|
||||||
asyncio.run(start_training_pipeline(orchestrator, trading_executor))
|
logger.info("Starting dashboard server on http://127.0.0.1:8050")
|
||||||
except Exception as e:
|
dashboard.run_server(host='127.0.0.1', port=8050, debug=False)
|
||||||
logger.error(f"Training worker error: {e}")
|
|
||||||
|
|
||||||
training_thread = threading.Thread(target=training_worker, daemon=True)
|
|
||||||
training_thread.start()
|
|
||||||
logger.info("Training pipeline started in background")
|
|
||||||
|
|
||||||
# Wait a moment for training to initialize
|
|
||||||
time.sleep(3)
|
|
||||||
|
|
||||||
# Start dashboard server (this blocks)
|
|
||||||
logger.info(" Starting Clean Dashboard Server...")
|
|
||||||
dashboard.run_server(host='127.0.0.1', port=dashboard_port, debug=False)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("System stopped by user")
|
logger.info("Dashboard stopped by user")
|
||||||
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error running clean dashboard with training: {e}")
|
logger.error(f"Dashboard server error: {e}")
|
||||||
import traceback
|
logger.error(traceback.format_exc())
|
||||||
traceback.print_exc()
|
raise
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Critical error in dashboard: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
retry_count += 1
|
||||||
|
if retry_count < max_retries:
|
||||||
|
logger.info(f"Attempting recovery... ({retry_count}/{max_retries})")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
gc.collect()
|
||||||
|
clear_gpu_memory()
|
||||||
|
|
||||||
|
# Wait before retry
|
||||||
|
wait_time = 30 * retry_count # Exponential backoff
|
||||||
|
logger.info(f"Waiting {wait_time} seconds before retry...")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
logger.error("Max retries reached. Exiting.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Main function"""
|
|
||||||
start_clean_dashboard_with_training()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
try:
|
||||||
|
run_dashboard_with_recovery()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Application stopped by user")
|
||||||
|
sys.exit(0)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Fatal error: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
sys.exit(1)
|
@ -205,6 +205,9 @@ class CleanTradingDashboard:
|
|||||||
# Start signal generation loop to ensure continuous trading signals
|
# Start signal generation loop to ensure continuous trading signals
|
||||||
self._start_signal_generation_loop()
|
self._start_signal_generation_loop()
|
||||||
|
|
||||||
|
# Start live balance sync for trading
|
||||||
|
self._start_live_balance_sync()
|
||||||
|
|
||||||
# Start training sessions if models are showing FRESH status
|
# Start training sessions if models are showing FRESH status
|
||||||
threading.Thread(target=self._delayed_training_check, daemon=True).start()
|
threading.Thread(target=self._delayed_training_check, daemon=True).start()
|
||||||
|
|
||||||
@ -319,6 +322,66 @@ class CleanTradingDashboard:
|
|||||||
logger.warning(f"Error getting balance: {e}")
|
logger.warning(f"Error getting balance: {e}")
|
||||||
return 100.0 # Default balance
|
return 100.0 # Default balance
|
||||||
|
|
||||||
|
def _get_live_balance(self) -> float:
|
||||||
|
"""Get real-time balance from exchange when in live trading mode"""
|
||||||
|
try:
|
||||||
|
if self.trading_executor:
|
||||||
|
# Check if we're in live trading mode
|
||||||
|
is_live = (hasattr(self.trading_executor, 'trading_enabled') and
|
||||||
|
self.trading_executor.trading_enabled and
|
||||||
|
hasattr(self.trading_executor, 'simulation_mode') and
|
||||||
|
not self.trading_executor.simulation_mode)
|
||||||
|
|
||||||
|
if is_live and hasattr(self.trading_executor, 'exchange'):
|
||||||
|
# Get real balance from exchange (throttled to avoid API spam)
|
||||||
|
import time
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# Cache balance for 5 seconds for more frequent updates in live trading
|
||||||
|
if not hasattr(self, '_last_balance_check') or current_time - self._last_balance_check > 5:
|
||||||
|
exchange = self.trading_executor.exchange
|
||||||
|
if hasattr(exchange, 'get_balance'):
|
||||||
|
live_balance = exchange.get_balance('USDC')
|
||||||
|
if live_balance is not None and live_balance > 0:
|
||||||
|
self._cached_live_balance = live_balance
|
||||||
|
self._last_balance_check = current_time
|
||||||
|
logger.info(f"LIVE BALANCE: Retrieved ${live_balance:.2f} USDC from MEXC")
|
||||||
|
return live_balance
|
||||||
|
else:
|
||||||
|
logger.warning(f"LIVE BALANCE: Retrieved ${live_balance:.2f} USDC - checking USDT as fallback")
|
||||||
|
# Also try USDT as fallback since user might have USDT
|
||||||
|
usdt_balance = exchange.get_balance('USDT')
|
||||||
|
if usdt_balance is not None and usdt_balance > 0:
|
||||||
|
self._cached_live_balance = usdt_balance
|
||||||
|
self._last_balance_check = current_time
|
||||||
|
logger.info(f"LIVE BALANCE: Using USDT balance ${usdt_balance:.2f}")
|
||||||
|
return usdt_balance
|
||||||
|
else:
|
||||||
|
logger.warning("LIVE BALANCE: Exchange does not have get_balance method")
|
||||||
|
else:
|
||||||
|
# Return cached balance if within 10 second window
|
||||||
|
if hasattr(self, '_cached_live_balance'):
|
||||||
|
return self._cached_live_balance
|
||||||
|
elif hasattr(self.trading_executor, 'simulation_mode') and self.trading_executor.simulation_mode:
|
||||||
|
# In simulation mode, show dynamic balance based on P&L
|
||||||
|
initial_balance = self._get_initial_balance()
|
||||||
|
realized_pnl = sum(trade.get('pnl', 0) for trade in self.closed_trades)
|
||||||
|
simulation_balance = initial_balance + realized_pnl
|
||||||
|
logger.debug(f"SIMULATION BALANCE: ${simulation_balance:.2f} (Initial: ${initial_balance:.2f} + P&L: ${realized_pnl:.2f})")
|
||||||
|
return simulation_balance
|
||||||
|
else:
|
||||||
|
logger.debug("LIVE BALANCE: Not in live trading mode, using initial balance")
|
||||||
|
|
||||||
|
# Fallback to initial balance for simulation mode
|
||||||
|
return self._get_initial_balance()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting live balance: {e}")
|
||||||
|
# Return cached balance if available, otherwise fallback
|
||||||
|
if hasattr(self, '_cached_live_balance'):
|
||||||
|
return self._cached_live_balance
|
||||||
|
return self._get_initial_balance()
|
||||||
|
|
||||||
def _setup_layout(self):
|
def _setup_layout(self):
|
||||||
"""Setup the dashboard layout using layout manager"""
|
"""Setup the dashboard layout using layout manager"""
|
||||||
self.app.layout = self.layout_manager.create_main_layout()
|
self.app.layout = self.layout_manager.create_main_layout()
|
||||||
@ -411,17 +474,48 @@ class CleanTradingDashboard:
|
|||||||
trade_count = len(self.closed_trades)
|
trade_count = len(self.closed_trades)
|
||||||
trade_str = f"{trade_count} Trades"
|
trade_str = f"{trade_count} Trades"
|
||||||
|
|
||||||
# Portfolio value
|
# Portfolio value - use live balance for live trading
|
||||||
initial_balance = self._get_initial_balance()
|
current_balance = self._get_live_balance()
|
||||||
portfolio_value = initial_balance + total_session_pnl # Use total P&L including unrealized
|
portfolio_value = current_balance + total_session_pnl # Use total P&L including unrealized
|
||||||
portfolio_str = f"${portfolio_value:.2f}"
|
|
||||||
|
|
||||||
# MEXC status
|
# Show live balance indicator for live trading
|
||||||
|
balance_indicator = ""
|
||||||
|
if self.trading_executor:
|
||||||
|
is_live = (hasattr(self.trading_executor, 'trading_enabled') and
|
||||||
|
self.trading_executor.trading_enabled and
|
||||||
|
hasattr(self.trading_executor, 'simulation_mode') and
|
||||||
|
not self.trading_executor.simulation_mode)
|
||||||
|
if is_live:
|
||||||
|
balance_indicator = " (LIVE)"
|
||||||
|
|
||||||
|
portfolio_str = f"${portfolio_value:.2f}{balance_indicator}"
|
||||||
|
|
||||||
|
# MEXC status with balance info
|
||||||
mexc_status = "SIM"
|
mexc_status = "SIM"
|
||||||
if self.trading_executor:
|
if self.trading_executor:
|
||||||
if hasattr(self.trading_executor, 'trading_enabled') and self.trading_executor.trading_enabled:
|
if hasattr(self.trading_executor, 'trading_enabled') and self.trading_executor.trading_enabled:
|
||||||
if hasattr(self.trading_executor, 'simulation_mode') and not self.trading_executor.simulation_mode:
|
if hasattr(self.trading_executor, 'simulation_mode') and self.trading_executor.simulation_mode:
|
||||||
mexc_status = "LIVE"
|
# Show simulation mode status with simulated balance
|
||||||
|
mexc_status = f"SIM - ${current_balance:.2f}"
|
||||||
|
elif hasattr(self.trading_executor, 'simulation_mode') and not self.trading_executor.simulation_mode:
|
||||||
|
# Show live balance in MEXC status - detect currency
|
||||||
|
try:
|
||||||
|
exchange = self.trading_executor.exchange
|
||||||
|
usdc_balance = exchange.get_balance('USDC') if hasattr(exchange, 'get_balance') else 0
|
||||||
|
usdt_balance = exchange.get_balance('USDT') if hasattr(exchange, 'get_balance') else 0
|
||||||
|
|
||||||
|
if usdc_balance > 0:
|
||||||
|
mexc_status = f"LIVE - ${usdc_balance:.2f} USDC"
|
||||||
|
elif usdt_balance > 0:
|
||||||
|
mexc_status = f"LIVE - ${usdt_balance:.2f} USDT"
|
||||||
|
else:
|
||||||
|
mexc_status = f"LIVE - ${current_balance:.2f}"
|
||||||
|
except:
|
||||||
|
mexc_status = f"LIVE - ${current_balance:.2f}"
|
||||||
|
else:
|
||||||
|
mexc_status = "SIM"
|
||||||
|
else:
|
||||||
|
mexc_status = "DISABLED"
|
||||||
|
|
||||||
return price_str, session_pnl_str, position_str, trade_str, portfolio_str, mexc_status
|
return price_str, session_pnl_str, position_str, trade_str, portfolio_str, mexc_status
|
||||||
|
|
||||||
@ -2877,6 +2971,39 @@ class CleanTradingDashboard:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error starting signal generation loop: {e}")
|
logger.error(f"Error starting signal generation loop: {e}")
|
||||||
|
|
||||||
|
def _start_live_balance_sync(self):
|
||||||
|
"""Start continuous live balance synchronization for trading"""
|
||||||
|
def balance_sync_worker():
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
if self.trading_executor:
|
||||||
|
is_live = (hasattr(self.trading_executor, 'trading_enabled') and
|
||||||
|
self.trading_executor.trading_enabled and
|
||||||
|
hasattr(self.trading_executor, 'simulation_mode') and
|
||||||
|
not self.trading_executor.simulation_mode)
|
||||||
|
|
||||||
|
if is_live and hasattr(self.trading_executor, 'exchange'):
|
||||||
|
# Force balance refresh every 15 seconds in live mode
|
||||||
|
if hasattr(self, '_last_balance_check'):
|
||||||
|
del self._last_balance_check # Force refresh
|
||||||
|
|
||||||
|
balance = self._get_live_balance()
|
||||||
|
if balance > 0:
|
||||||
|
logger.debug(f"BALANCE SYNC: Live balance: ${balance:.2f}")
|
||||||
|
else:
|
||||||
|
logger.warning("BALANCE SYNC: Could not retrieve live balance")
|
||||||
|
|
||||||
|
# Sync balance every 15 seconds for live trading
|
||||||
|
time.sleep(15)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error in balance sync loop: {e}")
|
||||||
|
time.sleep(30) # Wait longer on error
|
||||||
|
|
||||||
|
# Start balance sync thread only if we have trading enabled
|
||||||
|
if self.trading_executor:
|
||||||
|
threading.Thread(target=balance_sync_worker, daemon=True).start()
|
||||||
|
logger.info("BALANCE SYNC: Background balance synchronization started")
|
||||||
|
|
||||||
def _generate_dqn_signal(self, symbol: str, current_price: float) -> Optional[Dict]:
|
def _generate_dqn_signal(self, symbol: str, current_price: float) -> Optional[Dict]:
|
||||||
"""Generate trading signal using DQN agent - NOT AVAILABLE IN BASIC ORCHESTRATOR"""
|
"""Generate trading signal using DQN agent - NOT AVAILABLE IN BASIC ORCHESTRATOR"""
|
||||||
# Basic orchestrator doesn't have DQN features
|
# Basic orchestrator doesn't have DQN features
|
||||||
@ -4519,28 +4646,35 @@ class CleanTradingDashboard:
|
|||||||
imbalance = cob_snapshot['stats']['imbalance']
|
imbalance = cob_snapshot['stats']['imbalance']
|
||||||
abs_imbalance = abs(imbalance)
|
abs_imbalance = abs(imbalance)
|
||||||
|
|
||||||
# Dynamic threshold based on imbalance strength
|
# Dynamic threshold based on imbalance strength with realistic confidence
|
||||||
if abs_imbalance > 0.8: # Very strong imbalance (>80%)
|
if abs_imbalance > 0.8: # Very strong imbalance (>80%)
|
||||||
threshold = 0.05 # 5% threshold for very strong signals
|
threshold = 0.05 # 5% threshold for very strong signals
|
||||||
confidence_multiplier = 3.0
|
base_confidence = 0.85 # High but not perfect confidence
|
||||||
|
confidence_boost = (abs_imbalance - 0.8) * 0.75 # Scale remaining 15%
|
||||||
elif abs_imbalance > 0.5: # Strong imbalance (>50%)
|
elif abs_imbalance > 0.5: # Strong imbalance (>50%)
|
||||||
threshold = 0.1 # 10% threshold for strong signals
|
threshold = 0.1 # 10% threshold for strong signals
|
||||||
confidence_multiplier = 2.5
|
base_confidence = 0.70 # Good confidence
|
||||||
|
confidence_boost = (abs_imbalance - 0.5) * 0.50 # Scale up to 85%
|
||||||
elif abs_imbalance > 0.3: # Moderate imbalance (>30%)
|
elif abs_imbalance > 0.3: # Moderate imbalance (>30%)
|
||||||
threshold = 0.15 # 15% threshold for moderate signals
|
threshold = 0.15 # 15% threshold for moderate signals
|
||||||
confidence_multiplier = 2.0
|
base_confidence = 0.55 # Moderate confidence
|
||||||
|
confidence_boost = (abs_imbalance - 0.3) * 0.75 # Scale up to 70%
|
||||||
else: # Weak imbalance
|
else: # Weak imbalance
|
||||||
threshold = 0.2 # 20% threshold for weak signals
|
threshold = 0.2 # 20% threshold for weak signals
|
||||||
confidence_multiplier = 1.5
|
base_confidence = 0.35 # Low confidence
|
||||||
|
confidence_boost = abs_imbalance * 0.67 # Scale up to 55%
|
||||||
|
|
||||||
# Generate signal if imbalance exceeds threshold
|
# Generate signal if imbalance exceeds threshold
|
||||||
if abs_imbalance > threshold:
|
if abs_imbalance > threshold:
|
||||||
|
# Calculate more realistic confidence (never exactly 1.0)
|
||||||
|
final_confidence = min(0.95, base_confidence + confidence_boost)
|
||||||
|
|
||||||
signal = {
|
signal = {
|
||||||
'timestamp': datetime.now(),
|
'timestamp': datetime.now(),
|
||||||
'type': 'cob_liquidity_imbalance',
|
'type': 'cob_liquidity_imbalance',
|
||||||
'action': 'BUY' if imbalance > 0 else 'SELL',
|
'action': 'BUY' if imbalance > 0 else 'SELL',
|
||||||
'symbol': symbol,
|
'symbol': symbol,
|
||||||
'confidence': min(1.0, abs_imbalance * confidence_multiplier),
|
'confidence': final_confidence,
|
||||||
'strength': abs_imbalance,
|
'strength': abs_imbalance,
|
||||||
'threshold_used': threshold,
|
'threshold_used': threshold,
|
||||||
'signal_strength': 'very_strong' if abs_imbalance > 0.8 else 'strong' if abs_imbalance > 0.5 else 'moderate' if abs_imbalance > 0.3 else 'weak',
|
'signal_strength': 'very_strong' if abs_imbalance > 0.8 else 'strong' if abs_imbalance > 0.5 else 'moderate' if abs_imbalance > 0.3 else 'weak',
|
||||||
@ -5478,15 +5612,18 @@ class CleanTradingDashboard:
|
|||||||
import torch
|
import torch
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
# Get the model's device to ensure tensors are on the same device
|
||||||
|
model_device = next(model.parameters()).device
|
||||||
|
|
||||||
# Handle different input shapes for different CNN models
|
# Handle different input shapes for different CNN models
|
||||||
if hasattr(model, 'input_shape'):
|
if hasattr(model, 'input_shape'):
|
||||||
# EnhancedCNN model
|
# EnhancedCNN model
|
||||||
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device)
|
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(model_device)
|
||||||
else:
|
else:
|
||||||
# Basic CNN model - reshape appropriately
|
# Basic CNN model - reshape appropriately
|
||||||
features_tensor = torch.FloatTensor(features).unsqueeze(0).unsqueeze(0).to(device)
|
features_tensor = torch.FloatTensor(features).unsqueeze(0).unsqueeze(0).to(model_device)
|
||||||
|
|
||||||
target_tensor = torch.LongTensor([target]).to(device)
|
target_tensor = torch.LongTensor([target]).to(model_device)
|
||||||
|
|
||||||
# Set model to training mode and zero gradients
|
# Set model to training mode and zero gradients
|
||||||
model.train()
|
model.train()
|
||||||
@ -5605,10 +5742,11 @@ class CleanTradingDashboard:
|
|||||||
if hasattr(network, 'forward'):
|
if hasattr(network, 'forward'):
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
# Get the model's device to ensure tensors are on the same device
|
||||||
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device)
|
model_device = next(network.parameters()).device
|
||||||
action_target_tensor = torch.LongTensor([action_target]).to(device)
|
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(model_device)
|
||||||
confidence_target_tensor = torch.FloatTensor([confidence_target]).to(device)
|
action_target_tensor = torch.LongTensor([action_target]).to(model_device)
|
||||||
|
confidence_target_tensor = torch.FloatTensor([confidence_target]).to(model_device)
|
||||||
|
|
||||||
network.train()
|
network.train()
|
||||||
network_output = network(features_tensor)
|
network_output = network(features_tensor)
|
||||||
|
Reference in New Issue
Block a user