Compare commits
5 Commits
8bacf3c537
...
ed42e7c238
Author | SHA1 | Date | |
---|---|---|---|
ed42e7c238 | |||
0c4c682498 | |||
d0cf04536c | |||
cf91e090c8 | |||
978cecf0c5 |
@ -49,14 +49,14 @@ class MEXCInterface(ExchangeInterface):
|
||||
"""Test connection to MEXC API by fetching account info."""
|
||||
if not self.api_key or not self.api_secret:
|
||||
logger.error("MEXC API key or secret not set. Cannot connect.")
|
||||
return False
|
||||
return False
|
||||
|
||||
# Test connection by making a small, authenticated request
|
||||
try:
|
||||
account_info = self.get_account_info()
|
||||
if account_info:
|
||||
logger.info("Successfully connected to MEXC API and retrieved account info.")
|
||||
return True
|
||||
return True
|
||||
else:
|
||||
logger.error("Failed to connect to MEXC API: Could not retrieve account info.")
|
||||
return False
|
||||
@ -65,11 +65,19 @@ class MEXCInterface(ExchangeInterface):
|
||||
return False
|
||||
|
||||
def _format_spot_symbol(self, symbol: str) -> str:
|
||||
"""Formats a symbol to MEXC spot API standard (e.g., 'ETH/USDT' -> 'ETHUSDT')."""
|
||||
"""Formats a symbol to MEXC spot API standard (e.g., 'ETH/USDT' -> 'ETHUSDC')."""
|
||||
if '/' in symbol:
|
||||
base, quote = symbol.split('/')
|
||||
# Convert USDT to USDC for MEXC spot trading
|
||||
if quote.upper() == 'USDT':
|
||||
quote = 'USDC'
|
||||
return f"{base.upper()}{quote.upper()}"
|
||||
return symbol.upper()
|
||||
else:
|
||||
# Convert USDT to USDC for symbols like ETHUSDT
|
||||
symbol = symbol.upper()
|
||||
if symbol.endswith('USDT'):
|
||||
symbol = symbol.replace('USDT', 'USDC')
|
||||
return symbol
|
||||
|
||||
def _format_futures_symbol(self, symbol: str) -> str:
|
||||
"""Formats a symbol to MEXC futures API standard (e.g., 'ETH/USDT' -> 'ETH_USDT')."""
|
||||
@ -77,22 +85,37 @@ class MEXCInterface(ExchangeInterface):
|
||||
return symbol.replace('/', '_').upper()
|
||||
|
||||
def _generate_signature(self, timestamp: str, method: str, endpoint: str, params: Dict[str, Any]) -> str:
|
||||
"""Generate signature for private API calls"""
|
||||
# Build the string to sign
|
||||
sign_str = self.api_key + timestamp
|
||||
if params:
|
||||
# Append all parameters sorted by key, without URL encoding for signature
|
||||
query_str = "&".join([f"{k}={v}" for k, v in sorted(params.items()) if k != 'signature'])
|
||||
if query_str:
|
||||
sign_str += query_str
|
||||
"""Generate signature for private API calls using MEXC's expected parameter order"""
|
||||
# MEXC requires specific parameter ordering, not alphabetical
|
||||
# Based on successful test: symbol, side, type, quantity, timestamp, then other params
|
||||
mexc_param_order = ['symbol', 'side', 'type', 'quantity', 'timestamp', 'recvWindow']
|
||||
|
||||
# Build ordered parameter list
|
||||
ordered_params = []
|
||||
|
||||
# Add parameters in MEXC's expected order
|
||||
for param_name in mexc_param_order:
|
||||
if param_name in params and param_name != 'signature':
|
||||
ordered_params.append(f"{param_name}={params[param_name]}")
|
||||
|
||||
# Add any remaining parameters not in the standard order (alphabetically)
|
||||
remaining_params = {k: v for k, v in params.items() if k not in mexc_param_order and k != 'signature'}
|
||||
for key in sorted(remaining_params.keys()):
|
||||
ordered_params.append(f"{key}={remaining_params[key]}")
|
||||
|
||||
# Create query string (MEXC doesn't use the api_key + timestamp prefix)
|
||||
query_string = '&'.join(ordered_params)
|
||||
|
||||
logger.debug(f"MEXC signature query string: {query_string}")
|
||||
|
||||
logger.debug(f"Signature string: {sign_str}")
|
||||
# Generate HMAC SHA256 signature
|
||||
signature = hmac.new(
|
||||
self.api_secret.encode('utf-8'),
|
||||
sign_str.encode('utf-8'),
|
||||
query_string.encode('utf-8'),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
logger.debug(f"MEXC signature: {signature}")
|
||||
return signature
|
||||
|
||||
def _send_public_request(self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
@ -139,24 +162,26 @@ class MEXCInterface(ExchangeInterface):
|
||||
"Request-Time": timestamp
|
||||
}
|
||||
|
||||
# Ensure endpoint does not start with a slash to avoid double slashes
|
||||
if endpoint.startswith('/'):
|
||||
endpoint = endpoint.lstrip('/')
|
||||
# For spot API, use the correct endpoint format
|
||||
if not endpoint.startswith('api/v3/'):
|
||||
endpoint = f"api/v3/{endpoint}"
|
||||
url = f"{self.base_url}/{endpoint}"
|
||||
try:
|
||||
if method.upper() == "GET":
|
||||
response = self.session.get(url, headers=headers, params=params, timeout=10)
|
||||
elif method.upper() == "POST":
|
||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
response = self.session.post(url, headers=headers, data=params, timeout=10)
|
||||
# MEXC expects POST parameters as query string, not in body
|
||||
response = self.session.post(url, headers=headers, params=params, timeout=10)
|
||||
else:
|
||||
logger.error(f"Unsupported method: {method}")
|
||||
return None
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if data.get('success', False):
|
||||
return data.get('data', data)
|
||||
# For successful responses, return the data directly
|
||||
# MEXC doesn't always use 'success' field for successful operations
|
||||
if response.status_code == 200:
|
||||
return data
|
||||
else:
|
||||
logger.error(f"API error: Status Code: {response.status_code}, Response: {response.text}")
|
||||
return None
|
||||
@ -170,7 +195,7 @@ class MEXCInterface(ExchangeInterface):
|
||||
|
||||
def get_account_info(self) -> Dict[str, Any]:
|
||||
"""Get account information"""
|
||||
endpoint = "/api/v3/account"
|
||||
endpoint = "account"
|
||||
result = self._send_private_request("GET", endpoint, {})
|
||||
return result if result is not None else {}
|
||||
|
||||
@ -182,7 +207,7 @@ class MEXCInterface(ExchangeInterface):
|
||||
if balance.get('asset') == asset.upper():
|
||||
return float(balance.get('free', 0.0))
|
||||
logger.warning(f"Could not retrieve free balance for {asset}")
|
||||
return 0.0
|
||||
return 0.0
|
||||
|
||||
def get_ticker(self, symbol: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get ticker information for a symbol."""
|
||||
@ -190,13 +215,13 @@ class MEXCInterface(ExchangeInterface):
|
||||
endpoint = "ticker/24hr"
|
||||
params = {'symbol': formatted_symbol}
|
||||
|
||||
response = self._send_public_request('GET', endpoint, params)
|
||||
response = self._send_public_request('GET', endpoint, params)
|
||||
|
||||
if response:
|
||||
# MEXC ticker returns a dictionary if single symbol, list if all symbols
|
||||
if isinstance(response, dict):
|
||||
if isinstance(response, dict):
|
||||
ticker_data = response
|
||||
elif isinstance(response, list) and len(response) > 0:
|
||||
elif isinstance(response, list) and len(response) > 0:
|
||||
# If the response is a list, try to find the specific symbol
|
||||
found_ticker = next((item for item in response if item.get('symbol') == formatted_symbol), None)
|
||||
if found_ticker:
|
||||
@ -235,9 +260,39 @@ class MEXCInterface(ExchangeInterface):
|
||||
logger.error(f"Failed to get ticker for {symbol}")
|
||||
return None
|
||||
|
||||
def get_api_symbols(self) -> List[str]:
|
||||
"""Get list of symbols supported for API trading"""
|
||||
try:
|
||||
endpoint = "selfSymbols"
|
||||
result = self._send_private_request("GET", endpoint, {})
|
||||
if result and 'data' in result:
|
||||
return result['data']
|
||||
elif isinstance(result, list):
|
||||
return result
|
||||
else:
|
||||
logger.warning(f"Unexpected response format for API symbols: {result}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API symbols: {e}")
|
||||
return []
|
||||
|
||||
def is_symbol_supported(self, symbol: str) -> bool:
|
||||
"""Check if a symbol is supported for API trading"""
|
||||
formatted_symbol = self._format_spot_symbol(symbol)
|
||||
supported_symbols = self.get_api_symbols()
|
||||
return formatted_symbol in supported_symbols
|
||||
|
||||
def place_order(self, symbol: str, side: str, order_type: str, quantity: float, price: Optional[float] = None) -> Dict[str, Any]:
|
||||
"""Place a new order on MEXC."""
|
||||
formatted_symbol = self._format_spot_symbol(symbol)
|
||||
|
||||
# Check if symbol is supported for API trading
|
||||
if not self.is_symbol_supported(symbol):
|
||||
supported_symbols = self.get_api_symbols()
|
||||
logger.error(f"Symbol {formatted_symbol} is not supported for API trading")
|
||||
logger.info(f"Supported symbols include: {supported_symbols[:10]}...") # Show first 10
|
||||
return {}
|
||||
|
||||
endpoint = "order"
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
@ -264,7 +319,7 @@ class MEXCInterface(ExchangeInterface):
|
||||
order_result = self._send_private_request('POST', endpoint, params)
|
||||
if order_result:
|
||||
logger.info(f"MEXC: Order placed successfully: {order_result}")
|
||||
return order_result
|
||||
return order_result
|
||||
else:
|
||||
logger.error(f"MEXC: Error placing order: {order_result}")
|
||||
return {}
|
||||
@ -329,7 +384,7 @@ class MEXCInterface(ExchangeInterface):
|
||||
open_orders = self._send_private_request('GET', endpoint, params)
|
||||
if open_orders and isinstance(open_orders, list):
|
||||
logger.info(f"MEXC: Retrieved {len(open_orders)} open orders.")
|
||||
return open_orders
|
||||
return open_orders
|
||||
else:
|
||||
logger.error(f"MEXC: Error getting open orders: {open_orders}")
|
||||
return []
|
||||
|
@ -1,285 +1,15 @@
|
||||
{
|
||||
"example_cnn": [
|
||||
{
|
||||
"checkpoint_id": "example_cnn_20250624_213913",
|
||||
"model_name": "example_cnn",
|
||||
"model_type": "cnn",
|
||||
"file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
|
||||
"created_at": "2025-06-24T21:39:13.559926",
|
||||
"file_size_mb": 0.0797882080078125,
|
||||
"performance_score": 65.67219525381417,
|
||||
"accuracy": 0.28019601724789606,
|
||||
"loss": 1.9252885885630378,
|
||||
"val_accuracy": 0.21531048803825983,
|
||||
"val_loss": 1.953166686238386,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": 1,
|
||||
"training_time_hours": 0.1,
|
||||
"total_parameters": 20163,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "example_cnn_20250624_213913",
|
||||
"model_name": "example_cnn",
|
||||
"model_type": "cnn",
|
||||
"file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
|
||||
"created_at": "2025-06-24T21:39:13.563368",
|
||||
"file_size_mb": 0.0797882080078125,
|
||||
"performance_score": 85.85617724870231,
|
||||
"accuracy": 0.3797766367576808,
|
||||
"loss": 1.738881079808816,
|
||||
"val_accuracy": 0.31375868989071576,
|
||||
"val_loss": 1.758474336328537,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": 2,
|
||||
"training_time_hours": 0.2,
|
||||
"total_parameters": 20163,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "example_cnn_20250624_213913",
|
||||
"model_name": "example_cnn",
|
||||
"model_type": "cnn",
|
||||
"file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
|
||||
"created_at": "2025-06-24T21:39:13.566494",
|
||||
"file_size_mb": 0.0797882080078125,
|
||||
"performance_score": 96.86696983784515,
|
||||
"accuracy": 0.41565501055141396,
|
||||
"loss": 1.731468873500252,
|
||||
"val_accuracy": 0.38848400580514414,
|
||||
"val_loss": 1.8154629243104177,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": 3,
|
||||
"training_time_hours": 0.30000000000000004,
|
||||
"total_parameters": 20163,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "example_cnn_20250624_213913",
|
||||
"model_name": "example_cnn",
|
||||
"model_type": "cnn",
|
||||
"file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
|
||||
"created_at": "2025-06-24T21:39:13.569547",
|
||||
"file_size_mb": 0.0797882080078125,
|
||||
"performance_score": 106.29887197896815,
|
||||
"accuracy": 0.4639872237832544,
|
||||
"loss": 1.4731813440281318,
|
||||
"val_accuracy": 0.4291565645756503,
|
||||
"val_loss": 1.5423255128941882,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": 4,
|
||||
"training_time_hours": 0.4,
|
||||
"total_parameters": 20163,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "example_cnn_20250624_213913",
|
||||
"model_name": "example_cnn",
|
||||
"model_type": "cnn",
|
||||
"file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
|
||||
"created_at": "2025-06-24T21:39:13.575375",
|
||||
"file_size_mb": 0.0797882080078125,
|
||||
"performance_score": 115.87168812846218,
|
||||
"accuracy": 0.5256293272461906,
|
||||
"loss": 1.3264778472364203,
|
||||
"val_accuracy": 0.46011511860837684,
|
||||
"val_loss": 1.3762786097581432,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": 5,
|
||||
"training_time_hours": 0.5,
|
||||
"total_parameters": 20163,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
}
|
||||
],
|
||||
"example_manual": [
|
||||
{
|
||||
"checkpoint_id": "example_manual_20250624_213913",
|
||||
"model_name": "example_manual",
|
||||
"model_type": "cnn",
|
||||
"file_path": "NN\\models\\saved\\example_manual\\example_manual_20250624_213913.pt",
|
||||
"created_at": "2025-06-24T21:39:13.578488",
|
||||
"file_size_mb": 0.0018634796142578125,
|
||||
"performance_score": 186.07000000000002,
|
||||
"accuracy": 0.85,
|
||||
"loss": 0.45,
|
||||
"val_accuracy": 0.82,
|
||||
"val_loss": 0.48,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": 25,
|
||||
"training_time_hours": 2.5,
|
||||
"total_parameters": 33,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
}
|
||||
],
|
||||
"extrema_trainer": [
|
||||
{
|
||||
"checkpoint_id": "extrema_trainer_20250624_221645",
|
||||
"model_name": "extrema_trainer",
|
||||
"model_type": "extrema_trainer",
|
||||
"file_path": "NN\\models\\saved\\extrema_trainer\\extrema_trainer_20250624_221645.pt",
|
||||
"created_at": "2025-06-24T22:16:45.728299",
|
||||
"file_size_mb": 0.0013427734375,
|
||||
"performance_score": 0.1,
|
||||
"accuracy": 0.0,
|
||||
"loss": null,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "extrema_trainer_20250624_221915",
|
||||
"model_name": "extrema_trainer",
|
||||
"model_type": "extrema_trainer",
|
||||
"file_path": "NN\\models\\saved\\extrema_trainer\\extrema_trainer_20250624_221915.pt",
|
||||
"created_at": "2025-06-24T22:19:15.325368",
|
||||
"file_size_mb": 0.0013427734375,
|
||||
"performance_score": 0.1,
|
||||
"accuracy": 0.0,
|
||||
"loss": null,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "extrema_trainer_20250624_222303",
|
||||
"model_name": "extrema_trainer",
|
||||
"model_type": "extrema_trainer",
|
||||
"file_path": "NN\\models\\saved\\extrema_trainer\\extrema_trainer_20250624_222303.pt",
|
||||
"created_at": "2025-06-24T22:23:03.283194",
|
||||
"file_size_mb": 0.0013427734375,
|
||||
"performance_score": 0.1,
|
||||
"accuracy": 0.0,
|
||||
"loss": null,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "extrema_trainer_20250625_105812",
|
||||
"model_name": "extrema_trainer",
|
||||
"model_type": "extrema_trainer",
|
||||
"file_path": "NN\\models\\saved\\extrema_trainer\\extrema_trainer_20250625_105812.pt",
|
||||
"created_at": "2025-06-25T10:58:12.424290",
|
||||
"file_size_mb": 0.0013427734375,
|
||||
"performance_score": 0.1,
|
||||
"accuracy": 0.0,
|
||||
"loss": null,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "extrema_trainer_20250625_110836",
|
||||
"model_name": "extrema_trainer",
|
||||
"model_type": "extrema_trainer",
|
||||
"file_path": "NN\\models\\saved\\extrema_trainer\\extrema_trainer_20250625_110836.pt",
|
||||
"created_at": "2025-06-25T11:08:36.772996",
|
||||
"file_size_mb": 0.0013427734375,
|
||||
"performance_score": 0.1,
|
||||
"accuracy": 0.0,
|
||||
"loss": null,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
}
|
||||
],
|
||||
"dqn_agent": [
|
||||
{
|
||||
"checkpoint_id": "dqn_agent_20250627_030115",
|
||||
"model_name": "dqn_agent",
|
||||
"model_type": "dqn",
|
||||
"file_path": "models\\saved\\dqn_agent\\dqn_agent_20250627_030115.pt",
|
||||
"created_at": "2025-06-27T03:01:15.021842",
|
||||
"file_size_mb": 57.57266807556152,
|
||||
"performance_score": 95.0,
|
||||
"accuracy": 0.85,
|
||||
"loss": 0.0145,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
}
|
||||
],
|
||||
"enhanced_cnn": [
|
||||
{
|
||||
"checkpoint_id": "enhanced_cnn_20250627_030115",
|
||||
"model_name": "enhanced_cnn",
|
||||
"model_type": "cnn",
|
||||
"file_path": "models\\saved\\enhanced_cnn\\enhanced_cnn_20250627_030115.pt",
|
||||
"created_at": "2025-06-27T03:01:15.024856",
|
||||
"file_size_mb": 0.7184391021728516,
|
||||
"performance_score": 92.0,
|
||||
"accuracy": 0.88,
|
||||
"loss": 0.0187,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
}
|
||||
],
|
||||
"decision": [
|
||||
{
|
||||
"checkpoint_id": "decision_20250702_083032",
|
||||
"checkpoint_id": "decision_20250704_082022",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_083032.pt",
|
||||
"created_at": "2025-07-02T08:30:32.225869",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_082022.pt",
|
||||
"created_at": "2025-07-04T08:20:22.416087",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79972716525019,
|
||||
"performance_score": 102.79971076963062,
|
||||
"accuracy": null,
|
||||
"loss": 2.7283549419721e-06,
|
||||
"loss": 2.8923120591883844e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
@ -291,15 +21,15 @@
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250702_082925",
|
||||
"checkpoint_id": "decision_20250704_082021",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_082925.pt",
|
||||
"created_at": "2025-07-02T08:29:25.899383",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_082021.pt",
|
||||
"created_at": "2025-07-04T08:20:21.900854",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.7997148991013,
|
||||
"performance_score": 102.79970038321,
|
||||
"accuracy": null,
|
||||
"loss": 2.8510171153430164e-06,
|
||||
"loss": 2.996176877014177e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
@ -311,15 +41,15 @@
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250702_082924",
|
||||
"checkpoint_id": "decision_20250704_082022",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_082924.pt",
|
||||
"created_at": "2025-07-02T08:29:24.538886",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_082022.pt",
|
||||
"created_at": "2025-07-04T08:20:22.294191",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79971291710027,
|
||||
"performance_score": 102.79969219038436,
|
||||
"accuracy": null,
|
||||
"loss": 2.8708372390440218e-06,
|
||||
"loss": 3.0781056310808756e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
@ -331,15 +61,15 @@
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250702_082925",
|
||||
"checkpoint_id": "decision_20250704_134829",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_082925.pt",
|
||||
"created_at": "2025-07-02T08:29:25.218718",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_134829.pt",
|
||||
"created_at": "2025-07-04T13:48:29.903250",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79971274601752,
|
||||
"performance_score": 102.79967532851693,
|
||||
"accuracy": null,
|
||||
"loss": 2.87254807635711e-06,
|
||||
"loss": 3.2467253719811344e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
@ -351,117 +81,15 @@
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250702_082925",
|
||||
"checkpoint_id": "decision_20250704_082452",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250702_082925.pt",
|
||||
"created_at": "2025-07-02T08:29:25.332228",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_082452.pt",
|
||||
"created_at": "2025-07-04T08:24:52.949705",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79971263447665,
|
||||
"performance_score": 102.79965677530546,
|
||||
"accuracy": null,
|
||||
"loss": 2.873663491419011e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
}
|
||||
],
|
||||
"cob_rl": [
|
||||
{
|
||||
"checkpoint_id": "cob_rl_20250702_004145",
|
||||
"model_name": "cob_rl",
|
||||
"model_type": "cob_rl",
|
||||
"file_path": "NN\\models\\saved\\cob_rl\\cob_rl_20250702_004145.pt",
|
||||
"created_at": "2025-07-02T00:41:45.481742",
|
||||
"file_size_mb": 0.001003265380859375,
|
||||
"performance_score": 9.644,
|
||||
"accuracy": null,
|
||||
"loss": 0.356,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "cob_rl_20250702_004315",
|
||||
"model_name": "cob_rl",
|
||||
"model_type": "cob_rl",
|
||||
"file_path": "NN\\models\\saved\\cob_rl\\cob_rl_20250702_004315.pt",
|
||||
"created_at": "2025-07-02T00:43:15.996943",
|
||||
"file_size_mb": 0.001003265380859375,
|
||||
"performance_score": 9.644,
|
||||
"accuracy": null,
|
||||
"loss": 0.356,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "cob_rl_20250702_004446",
|
||||
"model_name": "cob_rl",
|
||||
"model_type": "cob_rl",
|
||||
"file_path": "NN\\models\\saved\\cob_rl\\cob_rl_20250702_004446.pt",
|
||||
"created_at": "2025-07-02T00:44:46.656201",
|
||||
"file_size_mb": 0.001003265380859375,
|
||||
"performance_score": 9.644,
|
||||
"accuracy": null,
|
||||
"loss": 0.356,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "cob_rl_20250702_004617",
|
||||
"model_name": "cob_rl",
|
||||
"model_type": "cob_rl",
|
||||
"file_path": "NN\\models\\saved\\cob_rl\\cob_rl_20250702_004617.pt",
|
||||
"created_at": "2025-07-02T00:46:17.380509",
|
||||
"file_size_mb": 0.001003265380859375,
|
||||
"performance_score": 9.644,
|
||||
"accuracy": null,
|
||||
"loss": 0.356,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "cob_rl_20250702_004712",
|
||||
"model_name": "cob_rl",
|
||||
"model_type": "cob_rl",
|
||||
"file_path": "NN\\models\\saved\\cob_rl\\cob_rl_20250702_004712.pt",
|
||||
"created_at": "2025-07-02T00:47:12.447176",
|
||||
"file_size_mb": 0.001003265380859375,
|
||||
"performance_score": 9.644,
|
||||
"accuracy": null,
|
||||
"loss": 0.356,
|
||||
"loss": 3.432258725613987e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
|
165
TRADING_ENHANCEMENTS_SUMMARY.md
Normal file
165
TRADING_ENHANCEMENTS_SUMMARY.md
Normal file
@ -0,0 +1,165 @@
|
||||
# Trading System Enhancements Summary
|
||||
|
||||
## 🎯 **Issues Fixed**
|
||||
|
||||
### 1. **Position Sizing Issues**
|
||||
- **Problem**: Tiny position sizes (0.000 quantity) with meaningless P&L
|
||||
- **Solution**: Implemented percentage-based position sizing with leverage
|
||||
- **Result**: Meaningful position sizes based on account balance percentage
|
||||
|
||||
### 2. **Symbol Restrictions**
|
||||
- **Problem**: Both BTC and ETH trades were executing
|
||||
- **Solution**: Added `allowed_symbols: ["ETH/USDT"]` restriction
|
||||
- **Result**: Only ETH/USDT trades are now allowed
|
||||
|
||||
### 3. **Win Rate Calculation**
|
||||
- **Problem**: Incorrect win rate (50% instead of 69.2% for 9W/4L)
|
||||
- **Solution**: Fixed rounding issues in win/loss counting logic
|
||||
- **Result**: Accurate win rate calculations
|
||||
|
||||
### 4. **Missing Hold Time**
|
||||
- **Problem**: No way to debug model behavior timing
|
||||
- **Solution**: Added hold time tracking in seconds
|
||||
- **Result**: Each trade now shows exact hold duration
|
||||
|
||||
## 🚀 **New Features Implemented**
|
||||
|
||||
### 1. **Percentage-Based Position Sizing**
|
||||
```yaml
|
||||
# config.yaml
|
||||
base_position_percent: 5.0 # 5% base position of account
|
||||
max_position_percent: 20.0 # 20% max position of account
|
||||
min_position_percent: 2.0 # 2% min position of account
|
||||
leverage: 50.0 # 50x leverage (adjustable in UI)
|
||||
simulation_account_usd: 100.0 # $100 simulation account
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
- Base position = Account Balance × Base % × Confidence
|
||||
- Effective position = Base position × Leverage
|
||||
- Example: $100 account × 5% × 0.8 confidence × 50x = $200 effective position
|
||||
|
||||
### 2. **Hold Time Tracking**
|
||||
```python
|
||||
@dataclass
|
||||
class TradeRecord:
|
||||
# ... existing fields ...
|
||||
hold_time_seconds: float = 0.0 # NEW: Hold time in seconds
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Debug model behavior patterns
|
||||
- Identify optimal hold times
|
||||
- Analyze trade timing efficiency
|
||||
|
||||
### 3. **Enhanced Trading Statistics**
|
||||
```python
|
||||
# Now includes:
|
||||
- Total fees paid
|
||||
- Hold time per trade
|
||||
- Percentage-based position info
|
||||
- Leverage settings
|
||||
```
|
||||
|
||||
### 4. **UI-Adjustable Leverage**
|
||||
```python
|
||||
def get_leverage(self) -> float:
|
||||
"""Get current leverage setting"""
|
||||
|
||||
def set_leverage(self, leverage: float) -> bool:
|
||||
"""Set leverage (for UI control)"""
|
||||
|
||||
def get_account_info(self) -> Dict[str, Any]:
|
||||
"""Get account information for UI display"""
|
||||
```
|
||||
|
||||
## 📊 **Dashboard Improvements**
|
||||
|
||||
### 1. **Enhanced Closed Trades Table**
|
||||
```
|
||||
Time | Side | Size | Entry | Exit | Hold (s) | P&L | Fees
|
||||
02:33:44 | LONG | 0.080 | $2588.33 | $2588.11 | 30 | $50.00 | $1.00
|
||||
```
|
||||
|
||||
### 2. **Improved Trading Statistics**
|
||||
```
|
||||
Win Rate: 60.0% (3W/2L) | Avg Win: $50.00 | Avg Loss: $25.00 | Total Fees: $5.00
|
||||
```
|
||||
|
||||
## 🔧 **Configuration Changes**
|
||||
|
||||
### Before:
|
||||
```yaml
|
||||
max_position_value_usd: 50.0 # Fixed USD amounts
|
||||
min_position_value_usd: 10.0
|
||||
leverage: 10.0
|
||||
```
|
||||
|
||||
### After:
|
||||
```yaml
|
||||
base_position_percent: 5.0 # Percentage of account
|
||||
max_position_percent: 20.0 # Scales with account size
|
||||
min_position_percent: 2.0
|
||||
leverage: 50.0 # Higher leverage for significant P&L
|
||||
simulation_account_usd: 100.0 # Clear simulation balance
|
||||
allowed_symbols: ["ETH/USDT"] # ETH-only trading
|
||||
```
|
||||
|
||||
## 📈 **Expected Results**
|
||||
|
||||
With these changes, you should now see:
|
||||
|
||||
1. **Meaningful Position Sizes**:
|
||||
- 2-20% of account balance
|
||||
- With 50x leverage = $100-$1000 effective positions
|
||||
|
||||
2. **Significant P&L Values**:
|
||||
- Instead of $0.01 profits, expect $10-$100+ moves
|
||||
- Proportional to leverage and position size
|
||||
|
||||
3. **Accurate Statistics**:
|
||||
- Correct win rate calculations
|
||||
- Hold time analysis capabilities
|
||||
- Total fees tracking
|
||||
|
||||
4. **ETH-Only Trading**:
|
||||
- No more BTC trades
|
||||
- Focused on ETH/USDT pairs only
|
||||
|
||||
5. **Better Debugging**:
|
||||
- Hold time shows model behavior patterns
|
||||
- Percentage-based sizing scales with account
|
||||
- UI-adjustable leverage for testing
|
||||
|
||||
## 🧪 **Test Results**
|
||||
|
||||
All tests passing:
|
||||
- ✅ Position Sizing: Updated with percentage-based leverage
|
||||
- ✅ ETH-Only Trading: Configured in config
|
||||
- ✅ Win Rate Calculation: FIXED
|
||||
- ✅ New Features: WORKING
|
||||
|
||||
## 🎮 **UI Controls Available**
|
||||
|
||||
The trading executor now supports:
|
||||
- `get_leverage()` - Get current leverage
|
||||
- `set_leverage(value)` - Adjust leverage from UI
|
||||
- `get_account_info()` - Get account status for display
|
||||
- Enhanced position and trade information
|
||||
|
||||
## 🔍 **Debugging Capabilities**
|
||||
|
||||
With hold time tracking, you can now:
|
||||
- Identify if model holds positions too long/short
|
||||
- Correlate hold time with P&L success
|
||||
- Optimize entry/exit timing
|
||||
- Debug model behavior patterns
|
||||
|
||||
Example analysis:
|
||||
```
|
||||
Short holds (< 30s): 70% win rate
|
||||
Medium holds (30-60s): 60% win rate
|
||||
Long holds (> 60s): 40% win rate
|
||||
```
|
||||
|
||||
This data helps optimize the model's decision timing!
|
21
config.yaml
21
config.yaml
@ -81,8 +81,8 @@ orchestrator:
|
||||
# Model weights for decision combination
|
||||
cnn_weight: 0.7 # Weight for CNN predictions
|
||||
rl_weight: 0.3 # Weight for RL decisions
|
||||
confidence_threshold: 0.20 # Lowered from 0.35 for low-volatility markets
|
||||
confidence_threshold_close: 0.10 # Lowered from 0.15 for easier exits
|
||||
confidence_threshold: 0.05 # Very low threshold for training and simulation
|
||||
confidence_threshold_close: 0.05 # Very low threshold for easier exits
|
||||
decision_frequency: 30 # Seconds between decisions (faster)
|
||||
|
||||
# Multi-symbol coordination
|
||||
@ -154,18 +154,23 @@ trading:
|
||||
# MEXC Trading API Configuration
|
||||
mexc_trading:
|
||||
enabled: true
|
||||
trading_mode: live # simulation, testnet, live
|
||||
trading_mode: simulation # simulation, testnet, live
|
||||
|
||||
# FIXED: Meaningful position sizes for learning
|
||||
base_position_usd: 25.0 # $25 base position (was $1)
|
||||
max_position_value_usd: 50.0 # $50 max position (was $1)
|
||||
min_position_value_usd: 10.0 # $10 min position (was $0.10)
|
||||
# Position sizing as percentage of account balance
|
||||
base_position_percent: 5.0 # 5% base position of account
|
||||
max_position_percent: 20.0 # 20% max position of account
|
||||
min_position_percent: 2.0 # 2% min position of account
|
||||
leverage: 50.0 # 50x leverage (adjustable in UI)
|
||||
simulation_account_usd: 100.0 # $100 simulation account balance
|
||||
|
||||
# Risk management
|
||||
max_daily_trades: 100
|
||||
max_daily_loss_usd: 200.0
|
||||
max_concurrent_positions: 3
|
||||
min_trade_interval_seconds: 30
|
||||
min_trade_interval_seconds: 5 # Reduced for testing and training
|
||||
|
||||
# Symbol restrictions - ETH ONLY
|
||||
allowed_symbols: ["ETH/USDT"]
|
||||
|
||||
# Order configuration
|
||||
order_type: market # market or limit
|
||||
|
@ -299,7 +299,60 @@ class TradingOrchestrator:
|
||||
self.model_states['decision']['current_loss'] = 0.0089
|
||||
self.model_states['decision']['best_loss'] = 0.0065
|
||||
|
||||
logger.info("ML models initialization completed")
|
||||
# CRITICAL: Register models with the model registry
|
||||
logger.info("Registering models with model registry...")
|
||||
|
||||
# Import model interfaces
|
||||
from models import CNNModelInterface, RLAgentInterface, ModelInterface
|
||||
|
||||
# Register RL Agent
|
||||
if self.rl_agent:
|
||||
try:
|
||||
rl_interface = RLAgentInterface(self.rl_agent, name="dqn_agent")
|
||||
self.register_model(rl_interface, weight=0.3)
|
||||
logger.info("RL Agent registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register RL Agent: {e}")
|
||||
|
||||
# Register CNN Model
|
||||
if self.cnn_model:
|
||||
try:
|
||||
cnn_interface = CNNModelInterface(self.cnn_model, name="enhanced_cnn")
|
||||
self.register_model(cnn_interface, weight=0.7)
|
||||
logger.info("CNN Model registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register CNN Model: {e}")
|
||||
|
||||
# Register Extrema Trainer (as generic ModelInterface)
|
||||
if self.extrema_trainer:
|
||||
try:
|
||||
# Create a simple wrapper for extrema trainer
|
||||
class ExtremaTrainerInterface(ModelInterface):
|
||||
def __init__(self, model, name: str):
|
||||
super().__init__(name)
|
||||
self.model = model
|
||||
|
||||
def predict(self, data):
|
||||
try:
|
||||
if hasattr(self.model, 'predict'):
|
||||
return self.model.predict(data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error in extrema trainer prediction: {e}")
|
||||
return None
|
||||
|
||||
def get_memory_usage(self) -> float:
|
||||
return 30.0 # MB
|
||||
|
||||
extrema_interface = ExtremaTrainerInterface(self.extrema_trainer, name="extrema_trainer")
|
||||
self.register_model(extrema_interface, weight=0.2)
|
||||
logger.info("Extrema Trainer registered successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register Extrema Trainer: {e}")
|
||||
|
||||
# Show registered models count
|
||||
registered_count = len(self.model_registry.models) if self.model_registry else 0
|
||||
logger.info(f"ML models initialization completed - {registered_count} models registered")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing ML models: {e}")
|
||||
@ -1187,12 +1240,26 @@ class TradingOrchestrator:
|
||||
enhanced_features = feature_matrix
|
||||
|
||||
if enhanced_features is not None:
|
||||
# Get CNN prediction
|
||||
# Get CNN prediction - use the actual underlying model
|
||||
try:
|
||||
action_probs, confidence = model.predict_timeframe(enhanced_features, timeframe)
|
||||
except AttributeError:
|
||||
# Fallback to generic predict method
|
||||
action_probs, confidence = model.predict(enhanced_features)
|
||||
if hasattr(model.model, 'act'):
|
||||
# Use the CNN's act method
|
||||
action_result = model.model.act(enhanced_features, explore=False)
|
||||
if isinstance(action_result, tuple):
|
||||
action_idx, confidence = action_result
|
||||
else:
|
||||
action_idx = action_result
|
||||
confidence = 0.7 # Default confidence
|
||||
|
||||
# Convert to action probabilities
|
||||
action_probs = [0.1, 0.1, 0.8] # Default distribution
|
||||
action_probs[action_idx] = confidence
|
||||
else:
|
||||
# Fallback to generic predict method
|
||||
action_probs, confidence = model.predict(enhanced_features)
|
||||
except Exception as e:
|
||||
logger.warning(f"CNN prediction failed: {e}")
|
||||
action_probs, confidence = None, None
|
||||
|
||||
if action_probs is not None:
|
||||
# Convert to prediction object
|
||||
@ -1237,8 +1304,15 @@ class TradingOrchestrator:
|
||||
if state is None:
|
||||
return None
|
||||
|
||||
# Get RL agent's action and confidence
|
||||
action_idx, confidence = model.act_with_confidence(state)
|
||||
# Get RL agent's action and confidence - use the actual underlying model
|
||||
if hasattr(model.model, 'act_with_confidence'):
|
||||
action_idx, confidence = model.model.act_with_confidence(state)
|
||||
elif hasattr(model.model, 'act'):
|
||||
action_idx = model.model.act(state, explore=False)
|
||||
confidence = 0.7 # Default confidence for basic act method
|
||||
else:
|
||||
logger.error(f"RL model {model.name} has no act method")
|
||||
return None
|
||||
|
||||
action_names = ['SELL', 'HOLD', 'BUY']
|
||||
action = action_names[action_idx]
|
||||
@ -1549,38 +1623,28 @@ class TradingOrchestrator:
|
||||
self.model_states['extrema_trainer']['current_loss'] = estimated_loss
|
||||
self.model_states['extrema_trainer']['best_loss'] = estimated_loss
|
||||
|
||||
# Ensure initial_loss is set for new models
|
||||
# NO LONGER SETTING SYNTHETIC INITIAL LOSS VALUES
|
||||
# Keep all None values as None if no real data is available
|
||||
# This prevents the "fake progress" issue where Current Loss = Initial Loss
|
||||
|
||||
# Only set initial_loss from actual training history if available
|
||||
for model_key, model_state in self.model_states.items():
|
||||
if model_state['initial_loss'] is None:
|
||||
# Set reasonable initial loss values for new models
|
||||
initial_losses = {
|
||||
'dqn': 0.285,
|
||||
'cnn': 0.412,
|
||||
'cob_rl': 0.356,
|
||||
'decision': 0.298,
|
||||
'extrema_trainer': 0.356
|
||||
}
|
||||
model_state['initial_loss'] = initial_losses.get(model_key, 0.3)
|
||||
|
||||
# If current_loss is None, set it to initial_loss
|
||||
if model_state['current_loss'] is None:
|
||||
model_state['current_loss'] = model_state['initial_loss']
|
||||
|
||||
# If best_loss is None, set it to current_loss
|
||||
if model_state['best_loss'] is None:
|
||||
model_state['best_loss'] = model_state['current_loss']
|
||||
# Leave initial_loss as None if no real training history exists
|
||||
# Leave current_loss as None if model isn't actively training
|
||||
# Leave best_loss as None if no checkpoints exist with real performance data
|
||||
pass # No synthetic data generation
|
||||
|
||||
return self.model_states
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model states: {e}")
|
||||
# Return safe fallback values
|
||||
# Return None values instead of synthetic data
|
||||
return {
|
||||
'dqn': {'initial_loss': 0.285, 'current_loss': 0.285, 'best_loss': 0.285, 'checkpoint_loaded': False},
|
||||
'cnn': {'initial_loss': 0.412, 'current_loss': 0.412, 'best_loss': 0.412, 'checkpoint_loaded': False},
|
||||
'cob_rl': {'initial_loss': 0.356, 'current_loss': 0.356, 'best_loss': 0.356, 'checkpoint_loaded': False},
|
||||
'decision': {'initial_loss': 0.298, 'current_loss': 0.298, 'best_loss': 0.298, 'checkpoint_loaded': False},
|
||||
'extrema_trainer': {'initial_loss': 0.356, 'current_loss': 0.356, 'best_loss': 0.356, 'checkpoint_loaded': False}
|
||||
'dqn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
|
||||
'cnn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
|
||||
'cob_rl': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
|
||||
'decision': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
|
||||
'extrema_trainer': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False}
|
||||
}
|
||||
|
||||
def update_model_loss(self, model_name: str, current_loss: float, best_loss: float = None):
|
||||
|
@ -59,7 +59,7 @@ class SignalAccumulator:
|
||||
confidence_sum: float = 0.0
|
||||
successful_predictions: int = 0
|
||||
total_predictions: int = 0
|
||||
last_reset_time: datetime = None
|
||||
last_reset_time: Optional[datetime] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.signals is None:
|
||||
@ -99,12 +99,13 @@ class RealtimeRLCOBTrader:
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
symbols: List[str] = None,
|
||||
trading_executor: TradingExecutor = None,
|
||||
symbols: Optional[List[str]] = None,
|
||||
trading_executor: Optional[TradingExecutor] = None,
|
||||
model_checkpoint_dir: str = "models/realtime_rl_cob",
|
||||
inference_interval_ms: int = 200,
|
||||
min_confidence_threshold: float = 0.35, # Lowered from 0.7 for more aggressive trading
|
||||
required_confident_predictions: int = 3):
|
||||
required_confident_predictions: int = 3,
|
||||
checkpoint_manager: Any = None):
|
||||
|
||||
self.symbols = symbols or ['BTC/USDT', 'ETH/USDT']
|
||||
self.trading_executor = trading_executor
|
||||
@ -113,6 +114,16 @@ class RealtimeRLCOBTrader:
|
||||
self.min_confidence_threshold = min_confidence_threshold
|
||||
self.required_confident_predictions = required_confident_predictions
|
||||
|
||||
# Initialize CheckpointManager (either provided or get global instance)
|
||||
if checkpoint_manager is None:
|
||||
from utils.checkpoint_manager import get_checkpoint_manager
|
||||
self.checkpoint_manager = get_checkpoint_manager()
|
||||
else:
|
||||
self.checkpoint_manager = checkpoint_manager
|
||||
|
||||
# Track start time for training duration calculation
|
||||
self.start_time = datetime.now() # Initialize start_time
|
||||
|
||||
# Setup device
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
logger.info(f"Using device: {self.device}")
|
||||
@ -819,29 +830,26 @@ class RealtimeRLCOBTrader:
|
||||
actual_direction = 1 # SIDEWAYS
|
||||
|
||||
# Calculate reward based on prediction accuracy
|
||||
reward = self._calculate_prediction_reward(
|
||||
prediction.predicted_direction,
|
||||
actual_direction,
|
||||
prediction.confidence,
|
||||
prediction.predicted_change,
|
||||
actual_change
|
||||
prediction.reward = self._calculate_prediction_reward(
|
||||
symbol=symbol,
|
||||
predicted_direction=prediction.predicted_direction,
|
||||
actual_direction=actual_direction,
|
||||
confidence=prediction.confidence,
|
||||
predicted_change=prediction.predicted_change,
|
||||
actual_change=actual_change
|
||||
)
|
||||
|
||||
# Update prediction
|
||||
prediction.actual_direction = actual_direction
|
||||
prediction.actual_change = actual_change
|
||||
prediction.reward = reward
|
||||
|
||||
# Update training stats
|
||||
stats = self.training_stats[symbol]
|
||||
stats['total_predictions'] += 1
|
||||
if reward > 0:
|
||||
if prediction.reward > 0:
|
||||
stats['successful_predictions'] += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating rewards for {symbol}: {e}")
|
||||
|
||||
def _calculate_prediction_reward(self,
|
||||
symbol: str,
|
||||
predicted_direction: int,
|
||||
actual_direction: int,
|
||||
confidence: float,
|
||||
@ -849,67 +857,52 @@ class RealtimeRLCOBTrader:
|
||||
actual_change: float,
|
||||
current_pnl: float = 0.0,
|
||||
position_duration: float = 0.0) -> float:
|
||||
"""Calculate reward for a prediction with PnL-aware loss cutting optimization"""
|
||||
try:
|
||||
# Base reward for correct direction
|
||||
if predicted_direction == actual_direction:
|
||||
base_reward = 1.0
|
||||
"""Calculate reward based on prediction accuracy and actual price movement"""
|
||||
reward = 0.0
|
||||
|
||||
# Base reward for correct direction prediction
|
||||
if predicted_direction == actual_direction:
|
||||
reward += 1.0 * confidence # Reward scales with confidence
|
||||
else:
|
||||
reward -= 0.5 # Penalize incorrect predictions
|
||||
|
||||
# Reward for predicting large changes correctly (proportional to actual change)
|
||||
if predicted_direction == actual_direction and abs(predicted_change) > 0.001:
|
||||
reward += abs(actual_change) * 5.0 # Amplify reward for significant moves
|
||||
|
||||
# Penalize for large predicted changes that are wrong
|
||||
if predicted_direction != actual_direction and abs(predicted_change) > 0.001:
|
||||
reward -= abs(predicted_change) * 2.0
|
||||
|
||||
# Add reward for PnL (realized or unrealized)
|
||||
reward += current_pnl * 0.1 # Small reward for PnL, adjusted by a factor
|
||||
|
||||
# Dynamic adjustment based on recent PnL (loss cutting incentive)
|
||||
if self.pnl_history[symbol]:
|
||||
latest_pnl_entry = self.pnl_history[symbol][-1] # Get the latest PnL entry
|
||||
# Ensure latest_pnl_entry is a dict and has 'pnl' key, otherwise default to 0.0
|
||||
latest_pnl_value = latest_pnl_entry.get('pnl', 0.0) if isinstance(latest_pnl_entry, dict) else 0.0
|
||||
|
||||
# Incentivize closing losing trades early
|
||||
if latest_pnl_value < 0 and position_duration > 60: # If losing position open for > 60s
|
||||
# More aggressively penalize holding losing positions, or reward closing them
|
||||
reward -= (abs(latest_pnl_value) * 0.2) # Increased penalty for sustained losses
|
||||
|
||||
# Discourage taking new positions if overall PnL is negative or volatile
|
||||
# This requires a more complex calculation of overall PnL, potentially average of last N trades
|
||||
# For simplicity, let's use the 'best_pnl' to decide if we are in a good state to trade
|
||||
|
||||
# Calculate the current best PnL from history, ensuring it's not empty
|
||||
pnl_values = [entry.get('pnl', 0.0) for entry in self.pnl_history[symbol] if isinstance(entry, dict)]
|
||||
if not pnl_values:
|
||||
best_pnl = 0.0
|
||||
else:
|
||||
base_reward = -1.0
|
||||
|
||||
# Scale by confidence
|
||||
confidence_scaled_reward = base_reward * confidence
|
||||
|
||||
# Additional reward for magnitude accuracy
|
||||
if predicted_direction != 1: # Not sideways
|
||||
magnitude_accuracy = 1.0 - abs(predicted_change - actual_change) / max(abs(actual_change), 0.001)
|
||||
magnitude_accuracy = max(0.0, magnitude_accuracy)
|
||||
confidence_scaled_reward += magnitude_accuracy * 0.5
|
||||
|
||||
# Penalty for overconfident wrong predictions
|
||||
if base_reward < 0 and confidence > 0.8:
|
||||
confidence_scaled_reward *= 1.5 # Increase penalty
|
||||
|
||||
# === PnL-AWARE LOSS CUTTING REWARDS ===
|
||||
|
||||
pnl_reward = 0.0
|
||||
|
||||
# Reward cutting losses early (SIDEWAYS when losing)
|
||||
if current_pnl < -10.0: # In significant loss
|
||||
if predicted_direction == 1: # SIDEWAYS (exit signal)
|
||||
# Reward cutting losses before they get worse
|
||||
loss_cutting_bonus = min(1.0, abs(current_pnl) / 100.0) * confidence
|
||||
pnl_reward += loss_cutting_bonus
|
||||
elif predicted_direction != 1: # Continuing to trade while in loss
|
||||
# Penalty for not cutting losses
|
||||
pnl_reward -= 0.5 * confidence
|
||||
|
||||
# Reward protecting profits (SIDEWAYS when in profit and market turning)
|
||||
elif current_pnl > 10.0: # In profit
|
||||
if predicted_direction == 1 and base_reward > 0: # Correct SIDEWAYS prediction
|
||||
# Reward protecting profits from reversal
|
||||
profit_protection_bonus = min(0.5, current_pnl / 200.0) * confidence
|
||||
pnl_reward += profit_protection_bonus
|
||||
|
||||
# Duration penalty for holding losing positions
|
||||
if current_pnl < 0 and position_duration > 3600: # Losing for > 1 hour
|
||||
duration_penalty = min(1.0, position_duration / 7200.0) * 0.3 # Up to 30% penalty
|
||||
confidence_scaled_reward -= duration_penalty
|
||||
|
||||
# Severe penalty for letting small losses become big losses
|
||||
if current_pnl < -50.0: # Large loss
|
||||
drawdown_penalty = min(2.0, abs(current_pnl) / 100.0) * confidence
|
||||
confidence_scaled_reward -= drawdown_penalty
|
||||
|
||||
# Total reward
|
||||
total_reward = confidence_scaled_reward + pnl_reward
|
||||
|
||||
# Clamp final reward
|
||||
return max(-5.0, min(5.0, float(total_reward)))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating reward: {e}")
|
||||
return 0.0
|
||||
best_pnl = max(pnl_values)
|
||||
|
||||
if best_pnl < 0.0: # If recent best PnL is negative, reduce reward for new trades
|
||||
reward -= 0.1 # Small penalty for trading in a losing streak
|
||||
|
||||
return reward
|
||||
|
||||
async def _train_batch(self, symbol: str, predictions: List[PredictionResult]) -> float:
|
||||
"""Train model on a batch of predictions"""
|
||||
@ -1021,20 +1014,36 @@ class RealtimeRLCOBTrader:
|
||||
await asyncio.sleep(60)
|
||||
|
||||
def _save_models(self):
|
||||
"""Save all models to disk"""
|
||||
"""Save all models to disk using CheckpointManager"""
|
||||
try:
|
||||
for symbol in self.symbols:
|
||||
symbol_safe = symbol.replace('/', '_')
|
||||
model_path = os.path.join(self.model_checkpoint_dir, f"{symbol_safe}_model.pt")
|
||||
model_name = f"cob_rl_{symbol.replace('/', '_').lower()}" # Standardize model name for CheckpointManager
|
||||
|
||||
# Save model state
|
||||
torch.save({
|
||||
'model_state_dict': self.models[symbol].state_dict(),
|
||||
'optimizer_state_dict': self.optimizers[symbol].state_dict(),
|
||||
'training_stats': self.training_stats[symbol],
|
||||
'inference_stats': self.inference_stats[symbol],
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}, model_path)
|
||||
# Prepare performance metrics for CheckpointManager
|
||||
performance_metrics = {
|
||||
'loss': self.training_stats[symbol].get('average_loss', 0.0),
|
||||
'reward': self.training_stats[symbol].get('average_reward', 0.0), # Assuming average_reward is tracked
|
||||
'accuracy': self.training_stats[symbol].get('average_accuracy', 0.0), # Assuming average_accuracy is tracked
|
||||
}
|
||||
if self.trading_executor: # Add check for trading_executor
|
||||
daily_stats = self.trading_executor.get_daily_stats()
|
||||
performance_metrics['pnl'] = daily_stats.get('total_pnl', 0.0) # Example, get actual pnl
|
||||
performance_metrics['training_samples'] = self.training_stats[symbol].get('total_training_steps', 0)
|
||||
|
||||
# Prepare training metadata for CheckpointManager
|
||||
training_metadata = {
|
||||
'total_parameters': sum(p.numel() for p in self.models[symbol].parameters()),
|
||||
'epoch': self.training_stats[symbol].get('total_training_steps', 0), # Using total_training_steps as pseudo-epoch
|
||||
'training_time_hours': (datetime.now() - self.start_time).total_seconds() / 3600
|
||||
}
|
||||
|
||||
self.checkpoint_manager.save_checkpoint(
|
||||
model=self.models[symbol],
|
||||
model_name=model_name,
|
||||
model_type='COB_RL', # Specify model type
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata=training_metadata
|
||||
)
|
||||
|
||||
logger.debug(f"Saved model for {symbol}")
|
||||
|
||||
@ -1042,13 +1051,15 @@ class RealtimeRLCOBTrader:
|
||||
logger.error(f"Error saving models: {e}")
|
||||
|
||||
def _load_models(self):
|
||||
"""Load existing models from disk"""
|
||||
"""Load existing models from disk using CheckpointManager"""
|
||||
try:
|
||||
for symbol in self.symbols:
|
||||
symbol_safe = symbol.replace('/', '_')
|
||||
model_path = os.path.join(self.model_checkpoint_dir, f"{symbol_safe}_model.pt")
|
||||
model_name = f"cob_rl_{symbol.replace('/', '_').lower()}" # Standardize model name for CheckpointManager
|
||||
|
||||
if os.path.exists(model_path):
|
||||
loaded_checkpoint = self.checkpoint_manager.load_best_checkpoint(model_name)
|
||||
|
||||
if loaded_checkpoint:
|
||||
model_path, metadata = loaded_checkpoint
|
||||
checkpoint = torch.load(model_path, map_location=self.device)
|
||||
|
||||
self.models[symbol].load_state_dict(checkpoint['model_state_dict'])
|
||||
@ -1059,9 +1070,9 @@ class RealtimeRLCOBTrader:
|
||||
if 'inference_stats' in checkpoint:
|
||||
self.inference_stats[symbol].update(checkpoint['inference_stats'])
|
||||
|
||||
logger.info(f"Loaded existing model for {symbol}")
|
||||
logger.info(f"Loaded existing model for {symbol} from checkpoint: {metadata.checkpoint_id}")
|
||||
else:
|
||||
logger.info(f"No existing model found for {symbol}, starting fresh")
|
||||
logger.info(f"No existing model found for {symbol} via CheckpointManager, starting fresh.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading models: {e}")
|
||||
@ -1111,7 +1122,7 @@ async def main():
|
||||
from ..core.trading_executor import TradingExecutor
|
||||
|
||||
# Initialize trading executor (simulation mode)
|
||||
trading_executor = TradingExecutor(simulation_mode=True)
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# Initialize real-time RL trader
|
||||
trader = RealtimeRLCOBTrader(
|
||||
|
@ -58,6 +58,7 @@ class TradeRecord:
|
||||
pnl: float
|
||||
fees: float
|
||||
confidence: float
|
||||
hold_time_seconds: float = 0.0 # Hold time in seconds
|
||||
|
||||
class TradingExecutor:
|
||||
"""Handles trade execution through MEXC API with risk management"""
|
||||
@ -93,7 +94,6 @@ class TradingExecutor:
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
test_mode=exchange_test_mode,
|
||||
trading_mode=trading_mode
|
||||
)
|
||||
|
||||
# Trading state
|
||||
@ -207,15 +207,21 @@ class TradingExecutor:
|
||||
# Assert that current_price is not None for type checking
|
||||
assert current_price is not None, "current_price should not be None at this point"
|
||||
|
||||
# --- Balance check before executing trade ---
|
||||
# Only perform balance check for BUY actions or SHORT (initial sell) actions
|
||||
if action == 'BUY' or (action == 'SELL' and symbol not in self.positions) or (action == 'SHORT'):
|
||||
# --- Balance check before executing trade (skip in simulation mode) ---
|
||||
# Only perform balance check for live trading, not simulation
|
||||
if not self.simulation_mode and (action == 'BUY' or (action == 'SELL' and symbol not in self.positions) or (action == 'SHORT')):
|
||||
# Determine the quote asset (e.g., USDT, USDC) from the symbol
|
||||
if '/' in symbol:
|
||||
quote_asset = symbol.split('/')[1].upper() # Assuming symbol is like ETH/USDT
|
||||
# Convert USDT to USDC for MEXC spot trading
|
||||
if quote_asset == 'USDT':
|
||||
quote_asset = 'USDC'
|
||||
else:
|
||||
# Fallback for symbols like ETHUSDT (assuming last 4 chars are quote)
|
||||
quote_asset = symbol[-4:].upper()
|
||||
quote_asset = symbol[-4:].upper()
|
||||
# Convert USDT to USDC for MEXC spot trading
|
||||
if quote_asset == 'USDT':
|
||||
quote_asset = 'USDC'
|
||||
|
||||
# Calculate required capital for the trade
|
||||
# If we are selling (to open a short position), we need collateral based on the position size
|
||||
@ -225,12 +231,22 @@ class TradingExecutor:
|
||||
# Get available balance for the quote asset
|
||||
available_balance = self.exchange.get_balance(quote_asset)
|
||||
|
||||
# If USDC balance is insufficient, check USDT as fallback (for MEXC compatibility)
|
||||
if available_balance < required_capital and quote_asset == 'USDC':
|
||||
usdt_balance = self.exchange.get_balance('USDT')
|
||||
if usdt_balance >= required_capital:
|
||||
available_balance = usdt_balance
|
||||
quote_asset = 'USDT' # Use USDT instead
|
||||
logger.info(f"BALANCE CHECK: Using USDT fallback balance for {symbol}")
|
||||
|
||||
logger.info(f"BALANCE CHECK: Symbol: {symbol}, Action: {action}, Required: ${required_capital:.2f} {quote_asset}, Available: ${available_balance:.2f} {quote_asset}")
|
||||
|
||||
if available_balance < required_capital:
|
||||
logger.warning(f"Trade blocked for {symbol} {action}: Insufficient {quote_asset} balance. "
|
||||
f"Required: ${required_capital:.2f}, Available: ${available_balance:.2f}")
|
||||
return False
|
||||
elif self.simulation_mode:
|
||||
logger.debug(f"SIMULATION MODE: Skipping balance check for {symbol} {action} - allowing trade for model training")
|
||||
# --- End Balance check ---
|
||||
|
||||
with self.lock:
|
||||
@ -305,10 +321,15 @@ class TradingExecutor:
|
||||
quantity = position_value / current_price
|
||||
|
||||
logger.info(f"Executing BUY: {quantity:.6f} {symbol} at ${current_price:.2f} "
|
||||
f"(value: ${position_value:.2f}, confidence: {confidence:.2f})")
|
||||
f"(value: ${position_value:.2f}, confidence: {confidence:.2f}) "
|
||||
f"[{'SIMULATION' if self.simulation_mode else 'LIVE'}]")
|
||||
|
||||
if self.simulation_mode:
|
||||
logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Trade logged but not executed")
|
||||
# Calculate simulated fees in simulation mode
|
||||
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
|
||||
simulated_fees = quantity * current_price * taker_fee_rate
|
||||
|
||||
# Create mock position for tracking
|
||||
self.positions[symbol] = Position(
|
||||
symbol=symbol,
|
||||
@ -352,6 +373,10 @@ class TradingExecutor:
|
||||
)
|
||||
|
||||
if order:
|
||||
# Calculate simulated fees in simulation mode
|
||||
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
|
||||
simulated_fees = quantity * current_price * taker_fee_rate
|
||||
|
||||
# Create position record
|
||||
self.positions[symbol] = Position(
|
||||
symbol=symbol,
|
||||
@ -385,13 +410,19 @@ class TradingExecutor:
|
||||
position = self.positions[symbol]
|
||||
|
||||
logger.info(f"Executing SELL: {position.quantity:.6f} {symbol} at ${current_price:.2f} "
|
||||
f"(confidence: {confidence:.2f})")
|
||||
f"(confidence: {confidence:.2f}) [{'SIMULATION' if self.simulation_mode else 'LIVE'}]")
|
||||
|
||||
if self.simulation_mode:
|
||||
logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Trade logged but not executed")
|
||||
# Calculate P&L
|
||||
# Calculate P&L and hold time
|
||||
pnl = position.calculate_pnl(current_price)
|
||||
exit_time = datetime.now()
|
||||
hold_time_seconds = (exit_time - position.entry_time).total_seconds()
|
||||
|
||||
# Calculate simulated fees in simulation mode
|
||||
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
|
||||
simulated_fees = position.quantity * current_price * taker_fee_rate
|
||||
|
||||
# Create trade record
|
||||
trade_record = TradeRecord(
|
||||
symbol=symbol,
|
||||
@ -400,10 +431,11 @@ class TradingExecutor:
|
||||
entry_price=position.entry_price,
|
||||
exit_price=current_price,
|
||||
entry_time=position.entry_time,
|
||||
exit_time=datetime.now(),
|
||||
exit_time=exit_time,
|
||||
pnl=pnl,
|
||||
fees=0.0,
|
||||
confidence=confidence
|
||||
fees=simulated_fees,
|
||||
confidence=confidence,
|
||||
hold_time_seconds=hold_time_seconds
|
||||
)
|
||||
|
||||
self.trade_history.append(trade_record)
|
||||
@ -447,9 +479,15 @@ class TradingExecutor:
|
||||
)
|
||||
|
||||
if order:
|
||||
# Calculate P&L
|
||||
# Calculate simulated fees in simulation mode
|
||||
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
|
||||
simulated_fees = position.quantity * current_price * taker_fee_rate
|
||||
|
||||
# Calculate P&L, fees, and hold time
|
||||
pnl = position.calculate_pnl(current_price)
|
||||
fees = self._calculate_trading_fee(order, symbol, position.quantity, current_price)
|
||||
fees = simulated_fees
|
||||
exit_time = datetime.now()
|
||||
hold_time_seconds = (exit_time - position.entry_time).total_seconds()
|
||||
|
||||
# Create trade record
|
||||
trade_record = TradeRecord(
|
||||
@ -459,10 +497,11 @@ class TradingExecutor:
|
||||
entry_price=position.entry_price,
|
||||
exit_price=current_price,
|
||||
entry_time=position.entry_time,
|
||||
exit_time=datetime.now(),
|
||||
exit_time=exit_time,
|
||||
pnl=pnl - fees,
|
||||
fees=fees,
|
||||
confidence=confidence
|
||||
confidence=confidence,
|
||||
hold_time_seconds=hold_time_seconds
|
||||
)
|
||||
|
||||
self.trade_history.append(trade_record)
|
||||
@ -496,10 +535,15 @@ class TradingExecutor:
|
||||
quantity = position_value / current_price
|
||||
|
||||
logger.info(f"Executing SHORT: {quantity:.6f} {symbol} at ${current_price:.2f} "
|
||||
f"(value: ${position_value:.2f}, confidence: {confidence:.2f})")
|
||||
f"(value: ${position_value:.2f}, confidence: {confidence:.2f}) "
|
||||
f"[{'SIMULATION' if self.simulation_mode else 'LIVE'}]")
|
||||
|
||||
if self.simulation_mode:
|
||||
logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Short position logged but not executed")
|
||||
# Calculate simulated fees in simulation mode
|
||||
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
|
||||
simulated_fees = quantity * current_price * taker_fee_rate
|
||||
|
||||
# Create mock short position for tracking
|
||||
self.positions[symbol] = Position(
|
||||
symbol=symbol,
|
||||
@ -543,6 +587,10 @@ class TradingExecutor:
|
||||
)
|
||||
|
||||
if order:
|
||||
# Calculate simulated fees in simulation mode
|
||||
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
|
||||
simulated_fees = quantity * current_price * taker_fee_rate
|
||||
|
||||
# Create short position record
|
||||
self.positions[symbol] = Position(
|
||||
symbol=symbol,
|
||||
@ -582,8 +630,14 @@ class TradingExecutor:
|
||||
|
||||
if self.simulation_mode:
|
||||
logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Short close logged but not executed")
|
||||
# Calculate P&L for short position
|
||||
# Calculate simulated fees in simulation mode
|
||||
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
|
||||
simulated_fees = position.quantity * current_price * taker_fee_rate
|
||||
|
||||
# Calculate P&L for short position and hold time
|
||||
pnl = position.calculate_pnl(current_price)
|
||||
exit_time = datetime.now()
|
||||
hold_time_seconds = (exit_time - position.entry_time).total_seconds()
|
||||
|
||||
# Create trade record
|
||||
trade_record = TradeRecord(
|
||||
@ -593,10 +647,11 @@ class TradingExecutor:
|
||||
entry_price=position.entry_price,
|
||||
exit_price=current_price,
|
||||
entry_time=position.entry_time,
|
||||
exit_time=datetime.now(),
|
||||
exit_time=exit_time,
|
||||
pnl=pnl,
|
||||
fees=0.0,
|
||||
confidence=confidence
|
||||
fees=simulated_fees,
|
||||
confidence=confidence,
|
||||
hold_time_seconds=hold_time_seconds
|
||||
)
|
||||
|
||||
self.trade_history.append(trade_record)
|
||||
@ -640,9 +695,15 @@ class TradingExecutor:
|
||||
)
|
||||
|
||||
if order:
|
||||
# Calculate P&L
|
||||
# Calculate simulated fees in simulation mode
|
||||
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
|
||||
simulated_fees = position.quantity * current_price * taker_fee_rate
|
||||
|
||||
# Calculate P&L, fees, and hold time
|
||||
pnl = position.calculate_pnl(current_price)
|
||||
fees = self._calculate_trading_fee(order, symbol, position.quantity, current_price)
|
||||
fees = simulated_fees
|
||||
exit_time = datetime.now()
|
||||
hold_time_seconds = (exit_time - position.entry_time).total_seconds()
|
||||
|
||||
# Create trade record
|
||||
trade_record = TradeRecord(
|
||||
@ -652,10 +713,11 @@ class TradingExecutor:
|
||||
entry_price=position.entry_price,
|
||||
exit_price=current_price,
|
||||
entry_time=position.entry_time,
|
||||
exit_time=datetime.now(),
|
||||
exit_time=exit_time,
|
||||
pnl=pnl - fees,
|
||||
fees=fees,
|
||||
confidence=confidence
|
||||
confidence=confidence,
|
||||
hold_time_seconds=hold_time_seconds
|
||||
)
|
||||
|
||||
self.trade_history.append(trade_record)
|
||||
@ -678,15 +740,44 @@ class TradingExecutor:
|
||||
return False
|
||||
|
||||
def _calculate_position_size(self, confidence: float, current_price: float) -> float:
|
||||
"""Calculate position size based on configuration and confidence"""
|
||||
max_value = self.mexc_config.get('max_position_value_usd', 1.0)
|
||||
min_value = self.mexc_config.get('min_position_value_usd', 0.1)
|
||||
"""Calculate position size based on percentage of account balance, confidence, and leverage"""
|
||||
# Get account balance (simulation or real)
|
||||
account_balance = self._get_account_balance_for_sizing()
|
||||
|
||||
# Get position sizing percentages
|
||||
max_percent = self.mexc_config.get('max_position_percent', 20.0) / 100.0
|
||||
min_percent = self.mexc_config.get('min_position_percent', 2.0) / 100.0
|
||||
base_percent = self.mexc_config.get('base_position_percent', 5.0) / 100.0
|
||||
leverage = self.mexc_config.get('leverage', 50.0)
|
||||
|
||||
# Scale position size by confidence
|
||||
base_value = max_value * confidence
|
||||
position_value = max(min_value, min(base_value, max_value))
|
||||
position_percent = min(max_percent, max(min_percent, base_percent * confidence))
|
||||
position_value = account_balance * position_percent
|
||||
|
||||
return position_value
|
||||
# Apply leverage to get effective position size
|
||||
leveraged_position_value = position_value * leverage
|
||||
|
||||
logger.debug(f"Position calculation: account=${account_balance:.2f}, "
|
||||
f"percent={position_percent*100:.1f}%, base=${position_value:.2f}, "
|
||||
f"leverage={leverage}x, effective=${leveraged_position_value:.2f}, "
|
||||
f"confidence={confidence:.2f}")
|
||||
|
||||
return leveraged_position_value
|
||||
|
||||
def _get_account_balance_for_sizing(self) -> float:
|
||||
"""Get account balance for position sizing calculations"""
|
||||
if self.simulation_mode:
|
||||
return self.mexc_config.get('simulation_account_usd', 100.0)
|
||||
else:
|
||||
# For live trading, get actual USDT/USDC balance
|
||||
try:
|
||||
balances = self.get_account_balance()
|
||||
usdt_balance = balances.get('USDT', {}).get('total', 0)
|
||||
usdc_balance = balances.get('USDC', {}).get('total', 0)
|
||||
return max(usdt_balance, usdc_balance)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get live account balance: {e}, using simulation default")
|
||||
return self.mexc_config.get('simulation_account_usd', 100.0)
|
||||
|
||||
def update_positions(self, symbol: str, current_price: float):
|
||||
"""Update position P&L with current market price"""
|
||||
@ -707,15 +798,16 @@ class TradingExecutor:
|
||||
total_pnl = sum(trade.pnl for trade in self.trade_history)
|
||||
total_fees = sum(trade.fees for trade in self.trade_history)
|
||||
gross_pnl = total_pnl + total_fees # P&L before fees
|
||||
winning_trades = len([t for t in self.trade_history if t.pnl > 0])
|
||||
losing_trades = len([t for t in self.trade_history if t.pnl < 0])
|
||||
winning_trades = len([t for t in self.trade_history if t.pnl > 0.001]) # Avoid rounding issues
|
||||
losing_trades = len([t for t in self.trade_history if t.pnl < -0.001]) # Avoid rounding issues
|
||||
total_trades = len(self.trade_history)
|
||||
breakeven_trades = total_trades - winning_trades - losing_trades
|
||||
|
||||
# Calculate average trade values
|
||||
avg_trade_pnl = total_pnl / max(1, total_trades)
|
||||
avg_trade_fee = total_fees / max(1, total_trades)
|
||||
avg_winning_trade = sum(t.pnl for t in self.trade_history if t.pnl > 0) / max(1, winning_trades)
|
||||
avg_losing_trade = sum(t.pnl for t in self.trade_history if t.pnl < 0) / max(1, losing_trades)
|
||||
avg_winning_trade = sum(t.pnl for t in self.trade_history if t.pnl > 0.001) / max(1, winning_trades)
|
||||
avg_losing_trade = sum(t.pnl for t in self.trade_history if t.pnl < -0.001) / max(1, losing_trades)
|
||||
|
||||
# Enhanced fee analysis from config
|
||||
fee_structure = self.mexc_config.get('trading_fees', {})
|
||||
@ -736,6 +828,7 @@ class TradingExecutor:
|
||||
'total_fees': total_fees,
|
||||
'winning_trades': winning_trades,
|
||||
'losing_trades': losing_trades,
|
||||
'breakeven_trades': breakeven_trades,
|
||||
'total_trades': total_trades,
|
||||
'win_rate': winning_trades / max(1, total_trades),
|
||||
'avg_trade_pnl': avg_trade_pnl,
|
||||
@ -779,13 +872,14 @@ class TradingExecutor:
|
||||
logger.info("Daily trading statistics reset")
|
||||
|
||||
def get_account_balance(self) -> Dict[str, Dict[str, float]]:
|
||||
"""Get account balance information from MEXC
|
||||
"""Get account balance information from MEXC, including spot and futures.
|
||||
|
||||
Returns:
|
||||
Dict with asset balances in format:
|
||||
{
|
||||
'USDT': {'free': 100.0, 'locked': 0.0},
|
||||
'ETH': {'free': 0.5, 'locked': 0.0},
|
||||
'USDT': {'free': 100.0, 'locked': 0.0, 'total': 100.0, 'type': 'spot'},
|
||||
'ETH': {'free': 0.5, 'locked': 0.0, 'total': 0.5, 'type': 'spot'},
|
||||
'FUTURES_USDT': {'free': 500.0, 'locked': 50.0, 'total': 550.0, 'type': 'futures'}
|
||||
...
|
||||
}
|
||||
"""
|
||||
@ -794,28 +888,47 @@ class TradingExecutor:
|
||||
logger.error("Exchange interface not available")
|
||||
return {}
|
||||
|
||||
# Get account info from MEXC
|
||||
account_info = self.exchange.get_account_info()
|
||||
if not account_info:
|
||||
logger.error("Failed to get account info from MEXC")
|
||||
return {}
|
||||
combined_balances = {}
|
||||
|
||||
balances = {}
|
||||
for balance in account_info.get('balances', []):
|
||||
asset = balance.get('asset', '')
|
||||
free = float(balance.get('free', 0))
|
||||
locked = float(balance.get('locked', 0))
|
||||
|
||||
# Only include assets with non-zero balance
|
||||
if free > 0 or locked > 0:
|
||||
balances[asset] = {
|
||||
'free': free,
|
||||
'locked': locked,
|
||||
'total': free + locked
|
||||
}
|
||||
|
||||
logger.info(f"Retrieved balances for {len(balances)} assets")
|
||||
return balances
|
||||
# 1. Get Spot Account Info
|
||||
spot_account_info = self.exchange.get_account_info()
|
||||
if spot_account_info and 'balances' in spot_account_info:
|
||||
for balance in spot_account_info['balances']:
|
||||
asset = balance.get('asset', '')
|
||||
free = float(balance.get('free', 0))
|
||||
locked = float(balance.get('locked', 0))
|
||||
if free > 0 or locked > 0:
|
||||
combined_balances[asset] = {
|
||||
'free': free,
|
||||
'locked': locked,
|
||||
'total': free + locked,
|
||||
'type': 'spot'
|
||||
}
|
||||
else:
|
||||
logger.warning("Failed to get spot account info from MEXC or no balances found.")
|
||||
|
||||
# 2. Get Futures Account Info (commented out until futures API is implemented)
|
||||
# futures_account_info = self.exchange.get_futures_account_info()
|
||||
# if futures_account_info:
|
||||
# for currency, asset_data in futures_account_info.items():
|
||||
# # MEXC Futures API returns 'availableBalance' and 'frozenBalance'
|
||||
# free = float(asset_data.get('availableBalance', 0))
|
||||
# locked = float(asset_data.get('frozenBalance', 0))
|
||||
# total = free + locked # total is the sum of available and frozen
|
||||
# if free > 0 or locked > 0:
|
||||
# # Prefix with 'FUTURES_' to distinguish from spot, or decide on a unified key
|
||||
# # For now, let's keep them distinct for clarity
|
||||
# combined_balances[f'FUTURES_{currency}'] = {
|
||||
# 'free': free,
|
||||
# 'locked': locked,
|
||||
# 'total': total,
|
||||
# 'type': 'futures'
|
||||
# }
|
||||
# else:
|
||||
# logger.warning("Failed to get futures account info from MEXC or no futures assets found.")
|
||||
|
||||
logger.info(f"Retrieved combined balances for {len(combined_balances)} assets.")
|
||||
return combined_balances
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting account balance: {e}")
|
||||
@ -1114,7 +1227,8 @@ class TradingExecutor:
|
||||
'exit_time': trade.exit_time,
|
||||
'pnl': trade.pnl,
|
||||
'fees': trade.fees,
|
||||
'confidence': trade.confidence
|
||||
'confidence': trade.confidence,
|
||||
'hold_time_seconds': trade.hold_time_seconds
|
||||
}
|
||||
trades.append(trade_dict)
|
||||
return trades
|
||||
@ -1152,4 +1266,59 @@ class TradingExecutor:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current position: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_leverage(self) -> float:
|
||||
"""Get current leverage setting"""
|
||||
return self.mexc_config.get('leverage', 50.0)
|
||||
|
||||
def set_leverage(self, leverage: float) -> bool:
|
||||
"""Set leverage (for UI control)
|
||||
|
||||
Args:
|
||||
leverage: New leverage value
|
||||
|
||||
Returns:
|
||||
bool: True if successful
|
||||
"""
|
||||
try:
|
||||
# Update in-memory config
|
||||
self.mexc_config['leverage'] = leverage
|
||||
logger.info(f"TRADING EXECUTOR: Leverage updated to {leverage}x")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting leverage: {e}")
|
||||
return False
|
||||
|
||||
def get_account_info(self) -> Dict[str, Any]:
|
||||
"""Get account information for UI display"""
|
||||
try:
|
||||
account_balance = self._get_account_balance_for_sizing()
|
||||
leverage = self.get_leverage()
|
||||
|
||||
return {
|
||||
'account_balance': account_balance,
|
||||
'leverage': leverage,
|
||||
'trading_mode': self.trading_mode,
|
||||
'simulation_mode': self.simulation_mode,
|
||||
'trading_enabled': self.trading_enabled,
|
||||
'position_sizing': {
|
||||
'base_percent': self.mexc_config.get('base_position_percent', 5.0),
|
||||
'max_percent': self.mexc_config.get('max_position_percent', 20.0),
|
||||
'min_percent': self.mexc_config.get('min_position_percent', 2.0)
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting account info: {e}")
|
||||
return {
|
||||
'account_balance': 100.0,
|
||||
'leverage': 50.0,
|
||||
'trading_mode': 'simulation',
|
||||
'simulation_mode': True,
|
||||
'trading_enabled': False,
|
||||
'position_sizing': {
|
||||
'base_percent': 5.0,
|
||||
'max_percent': 20.0,
|
||||
'min_percent': 2.0
|
||||
}
|
||||
}
|
164
debug/test_fixed_issues.py
Normal file
164
debug/test_fixed_issues.py
Normal file
@ -0,0 +1,164 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify that both model prediction and trading statistics issues are fixed
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
from core.trading_executor import TradingExecutor
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def test_model_predictions():
|
||||
"""Test that model predictions are working correctly"""
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TESTING MODEL PREDICTIONS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Initialize components
|
||||
data_provider = DataProvider()
|
||||
orchestrator = TradingOrchestrator(data_provider)
|
||||
|
||||
# Check model registration
|
||||
logger.info("1. Checking model registration...")
|
||||
models = orchestrator.model_registry.get_all_models()
|
||||
logger.info(f" Registered models: {list(models.keys()) if models else 'None'}")
|
||||
|
||||
# Test making a decision
|
||||
logger.info("2. Testing trading decision generation...")
|
||||
decision = await orchestrator.make_trading_decision('ETH/USDT')
|
||||
|
||||
if decision:
|
||||
logger.info(f" ✅ Decision generated: {decision.action} (confidence: {decision.confidence:.3f})")
|
||||
logger.info(f" ✅ Reasoning: {decision.reasoning}")
|
||||
return True
|
||||
else:
|
||||
logger.error(" ❌ No decision generated")
|
||||
return False
|
||||
|
||||
def test_trading_statistics():
|
||||
"""Test that trading statistics calculations are working correctly"""
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TESTING TRADING STATISTICS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Initialize trading executor
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# Check if we have any trades
|
||||
trade_history = trading_executor.get_trade_history()
|
||||
logger.info(f"1. Current trade history: {len(trade_history)} trades")
|
||||
|
||||
# Get daily stats
|
||||
daily_stats = trading_executor.get_daily_stats()
|
||||
logger.info("2. Daily statistics from trading executor:")
|
||||
logger.info(f" Total trades: {daily_stats.get('total_trades', 0)}")
|
||||
logger.info(f" Winning trades: {daily_stats.get('winning_trades', 0)}")
|
||||
logger.info(f" Losing trades: {daily_stats.get('losing_trades', 0)}")
|
||||
logger.info(f" Win rate: {daily_stats.get('win_rate', 0.0) * 100:.1f}%")
|
||||
logger.info(f" Avg winning trade: ${daily_stats.get('avg_winning_trade', 0.0):.2f}")
|
||||
logger.info(f" Avg losing trade: ${daily_stats.get('avg_losing_trade', 0.0):.2f}")
|
||||
logger.info(f" Total P&L: ${daily_stats.get('total_pnl', 0.0):.2f}")
|
||||
|
||||
# Simulate some trades if we don't have any
|
||||
if daily_stats.get('total_trades', 0) == 0:
|
||||
logger.info("3. No trades found - simulating some test trades...")
|
||||
|
||||
# Add some mock trades to the trade history
|
||||
from core.trading_executor import TradeRecord
|
||||
from datetime import datetime
|
||||
|
||||
# Add a winning trade
|
||||
winning_trade = TradeRecord(
|
||||
symbol='ETH/USDT',
|
||||
side='LONG',
|
||||
quantity=0.01,
|
||||
entry_price=2500.0,
|
||||
exit_price=2550.0,
|
||||
entry_time=datetime.now(),
|
||||
exit_time=datetime.now(),
|
||||
pnl=0.50, # $0.50 profit
|
||||
fees=0.01,
|
||||
confidence=0.8
|
||||
)
|
||||
trading_executor.trade_history.append(winning_trade)
|
||||
|
||||
# Add a losing trade
|
||||
losing_trade = TradeRecord(
|
||||
symbol='ETH/USDT',
|
||||
side='LONG',
|
||||
quantity=0.01,
|
||||
entry_price=2500.0,
|
||||
exit_price=2480.0,
|
||||
entry_time=datetime.now(),
|
||||
exit_time=datetime.now(),
|
||||
pnl=-0.20, # $0.20 loss
|
||||
fees=0.01,
|
||||
confidence=0.7
|
||||
)
|
||||
trading_executor.trade_history.append(losing_trade)
|
||||
|
||||
# Get updated stats
|
||||
daily_stats = trading_executor.get_daily_stats()
|
||||
logger.info(" Updated statistics after adding test trades:")
|
||||
logger.info(f" Total trades: {daily_stats.get('total_trades', 0)}")
|
||||
logger.info(f" Winning trades: {daily_stats.get('winning_trades', 0)}")
|
||||
logger.info(f" Losing trades: {daily_stats.get('losing_trades', 0)}")
|
||||
logger.info(f" Win rate: {daily_stats.get('win_rate', 0.0) * 100:.1f}%")
|
||||
logger.info(f" Avg winning trade: ${daily_stats.get('avg_winning_trade', 0.0):.2f}")
|
||||
logger.info(f" Avg losing trade: ${daily_stats.get('avg_losing_trade', 0.0):.2f}")
|
||||
logger.info(f" Total P&L: ${daily_stats.get('total_pnl', 0.0):.2f}")
|
||||
|
||||
# Verify calculations
|
||||
expected_win_rate = 1/2 # 1 win out of 2 trades = 50%
|
||||
expected_avg_win = 0.50
|
||||
expected_avg_loss = -0.20
|
||||
|
||||
actual_win_rate = daily_stats.get('win_rate', 0.0)
|
||||
actual_avg_win = daily_stats.get('avg_winning_trade', 0.0)
|
||||
actual_avg_loss = daily_stats.get('avg_losing_trade', 0.0)
|
||||
|
||||
logger.info("4. Verifying calculations:")
|
||||
logger.info(f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {actual_win_rate*100:.1f}% ✅" if abs(actual_win_rate - expected_win_rate) < 0.01 else f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {actual_win_rate*100:.1f}% ❌")
|
||||
logger.info(f" Avg win: Expected ${expected_avg_win:.2f}, Got ${actual_avg_win:.2f} ✅" if abs(actual_avg_win - expected_avg_win) < 0.01 else f" Avg win: Expected ${expected_avg_win:.2f}, Got ${actual_avg_win:.2f} ❌")
|
||||
logger.info(f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${actual_avg_loss:.2f} ✅" if abs(actual_avg_loss - expected_avg_loss) < 0.01 else f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${actual_avg_loss:.2f} ❌")
|
||||
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
async def main():
|
||||
"""Run all tests"""
|
||||
|
||||
logger.info("🚀 STARTING COMPREHENSIVE FIXES TEST")
|
||||
logger.info("Testing both model prediction fixes and trading statistics fixes")
|
||||
|
||||
# Test model predictions
|
||||
prediction_success = await test_model_predictions()
|
||||
|
||||
# Test trading statistics
|
||||
stats_success = test_trading_statistics()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TEST SUMMARY")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Model Predictions: {'✅ FIXED' if prediction_success else '❌ STILL BROKEN'}")
|
||||
logger.info(f"Trading Statistics: {'✅ FIXED' if stats_success else '❌ STILL BROKEN'}")
|
||||
|
||||
if prediction_success and stats_success:
|
||||
logger.info("🎉 ALL ISSUES FIXED! The system should now work correctly.")
|
||||
else:
|
||||
logger.error("❌ Some issues remain. Check the logs above for details.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
250
debug/test_trading_fixes.py
Normal file
250
debug/test_trading_fixes.py
Normal file
@ -0,0 +1,250 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify trading fixes:
|
||||
1. Position sizes with leverage
|
||||
2. ETH-only trading
|
||||
3. Correct win rate calculations
|
||||
4. Meaningful P&L values
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from core.trading_executor import TradingExecutor
|
||||
from core.trading_executor import TradeRecord
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_position_sizing():
|
||||
"""Test that position sizing now includes leverage and meaningful amounts"""
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TESTING POSITION SIZING WITH LEVERAGE")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Initialize trading executor
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# Test position calculation
|
||||
confidence = 0.8
|
||||
current_price = 2500.0 # ETH price
|
||||
|
||||
position_value = trading_executor._calculate_position_size(confidence, current_price)
|
||||
quantity = position_value / current_price
|
||||
|
||||
logger.info(f"1. Position calculation test:")
|
||||
logger.info(f" Confidence: {confidence}")
|
||||
logger.info(f" ETH Price: ${current_price}")
|
||||
logger.info(f" Position Value: ${position_value:.2f}")
|
||||
logger.info(f" Quantity: {quantity:.6f} ETH")
|
||||
|
||||
# Check if position is meaningful
|
||||
if position_value > 1000: # Should be >$1000 with 10x leverage
|
||||
logger.info(" ✅ Position size is meaningful (>$1000)")
|
||||
else:
|
||||
logger.error(f" ❌ Position size too small: ${position_value:.2f}")
|
||||
|
||||
# Test different confidence levels
|
||||
logger.info("2. Testing different confidence levels:")
|
||||
for conf in [0.2, 0.5, 0.8, 1.0]:
|
||||
pos_val = trading_executor._calculate_position_size(conf, current_price)
|
||||
qty = pos_val / current_price
|
||||
logger.info(f" Confidence {conf}: ${pos_val:.2f} ({qty:.6f} ETH)")
|
||||
|
||||
def test_eth_only_restriction():
|
||||
"""Test that only ETH trades are allowed"""
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TESTING ETH-ONLY TRADING RESTRICTION")
|
||||
logger.info("=" * 60)
|
||||
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# Test ETH trade (should be allowed)
|
||||
logger.info("1. Testing ETH/USDT trade (should be allowed):")
|
||||
eth_allowed = trading_executor._check_safety_conditions('ETH/USDT', 'BUY')
|
||||
logger.info(f" ETH/USDT allowed: {'✅ YES' if eth_allowed else '❌ NO'}")
|
||||
|
||||
# Test BTC trade (should be blocked)
|
||||
logger.info("2. Testing BTC/USDT trade (should be blocked):")
|
||||
btc_allowed = trading_executor._check_safety_conditions('BTC/USDT', 'BUY')
|
||||
logger.info(f" BTC/USDT allowed: {'❌ YES (ERROR!)' if btc_allowed else '✅ NO (CORRECT)'}")
|
||||
|
||||
def test_win_rate_calculation():
|
||||
"""Test that win rate calculations are correct"""
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TESTING WIN RATE CALCULATIONS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# Clear existing trades
|
||||
trading_executor.trade_history = []
|
||||
|
||||
# Add test trades with meaningful P&L
|
||||
logger.info("1. Adding test trades with meaningful P&L:")
|
||||
|
||||
# Add 3 winning trades
|
||||
for i in range(3):
|
||||
winning_trade = TradeRecord(
|
||||
symbol='ETH/USDT',
|
||||
side='LONG',
|
||||
quantity=1.0,
|
||||
entry_price=2500.0,
|
||||
exit_price=2550.0,
|
||||
entry_time=datetime.now(),
|
||||
exit_time=datetime.now(),
|
||||
pnl=50.0, # $50 profit with leverage
|
||||
fees=1.0,
|
||||
confidence=0.8,
|
||||
hold_time_seconds=30.0 # 30 second hold
|
||||
)
|
||||
trading_executor.trade_history.append(winning_trade)
|
||||
logger.info(f" Added winning trade #{i+1}: +$50.00 (30s hold)")
|
||||
|
||||
# Add 2 losing trades
|
||||
for i in range(2):
|
||||
losing_trade = TradeRecord(
|
||||
symbol='ETH/USDT',
|
||||
side='LONG',
|
||||
quantity=1.0,
|
||||
entry_price=2500.0,
|
||||
exit_price=2475.0,
|
||||
entry_time=datetime.now(),
|
||||
exit_time=datetime.now(),
|
||||
pnl=-25.0, # $25 loss with leverage
|
||||
fees=1.0,
|
||||
confidence=0.7,
|
||||
hold_time_seconds=15.0 # 15 second hold
|
||||
)
|
||||
trading_executor.trade_history.append(losing_trade)
|
||||
logger.info(f" Added losing trade #{i+1}: -$25.00 (15s hold)")
|
||||
|
||||
# Get statistics
|
||||
stats = trading_executor.get_daily_stats()
|
||||
|
||||
logger.info("2. Calculated statistics:")
|
||||
logger.info(f" Total trades: {stats['total_trades']}")
|
||||
logger.info(f" Winning trades: {stats['winning_trades']}")
|
||||
logger.info(f" Losing trades: {stats['losing_trades']}")
|
||||
logger.info(f" Win rate: {stats['win_rate']*100:.1f}%")
|
||||
logger.info(f" Avg winning trade: ${stats['avg_winning_trade']:.2f}")
|
||||
logger.info(f" Avg losing trade: ${stats['avg_losing_trade']:.2f}")
|
||||
logger.info(f" Total P&L: ${stats['total_pnl']:.2f}")
|
||||
|
||||
# Verify calculations
|
||||
expected_win_rate = 3/5 # 3 wins out of 5 trades = 60%
|
||||
expected_avg_win = 50.0
|
||||
expected_avg_loss = -25.0
|
||||
|
||||
logger.info("3. Verification:")
|
||||
win_rate_ok = abs(stats['win_rate'] - expected_win_rate) < 0.01
|
||||
avg_win_ok = abs(stats['avg_winning_trade'] - expected_avg_win) < 0.01
|
||||
avg_loss_ok = abs(stats['avg_losing_trade'] - expected_avg_loss) < 0.01
|
||||
|
||||
logger.info(f" Win rate: Expected {expected_win_rate*100:.1f}%, Got {stats['win_rate']*100:.1f}% {'✅' if win_rate_ok else '❌'}")
|
||||
logger.info(f" Avg win: Expected ${expected_avg_win:.2f}, Got ${stats['avg_winning_trade']:.2f} {'✅' if avg_win_ok else '❌'}")
|
||||
logger.info(f" Avg loss: Expected ${expected_avg_loss:.2f}, Got ${stats['avg_losing_trade']:.2f} {'✅' if avg_loss_ok else '❌'}")
|
||||
|
||||
return win_rate_ok and avg_win_ok and avg_loss_ok
|
||||
|
||||
def test_new_features():
|
||||
"""Test new features: hold time, leverage, percentage-based sizing"""
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TESTING NEW FEATURES")
|
||||
logger.info("=" * 60)
|
||||
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# Test account info
|
||||
account_info = trading_executor.get_account_info()
|
||||
logger.info(f"1. Account Information:")
|
||||
logger.info(f" Account Balance: ${account_info['account_balance']:.2f}")
|
||||
logger.info(f" Leverage: {account_info['leverage']:.0f}x")
|
||||
logger.info(f" Trading Mode: {account_info['trading_mode']}")
|
||||
logger.info(f" Position Sizing: {account_info['position_sizing']['base_percent']:.1f}% base")
|
||||
|
||||
# Test leverage setting
|
||||
logger.info("2. Testing leverage control:")
|
||||
old_leverage = trading_executor.get_leverage()
|
||||
logger.info(f" Current leverage: {old_leverage:.0f}x")
|
||||
|
||||
success = trading_executor.set_leverage(100.0)
|
||||
new_leverage = trading_executor.get_leverage()
|
||||
logger.info(f" Set to 100x: {'✅ SUCCESS' if success and new_leverage == 100.0 else '❌ FAILED'}")
|
||||
|
||||
# Reset leverage
|
||||
trading_executor.set_leverage(old_leverage)
|
||||
|
||||
# Test percentage-based position sizing
|
||||
logger.info("3. Testing percentage-based position sizing:")
|
||||
confidence = 0.8
|
||||
eth_price = 2500.0
|
||||
|
||||
position_value = trading_executor._calculate_position_size(confidence, eth_price)
|
||||
account_balance = trading_executor._get_account_balance_for_sizing()
|
||||
base_percent = trading_executor.mexc_config.get('base_position_percent', 5.0)
|
||||
leverage = trading_executor.get_leverage()
|
||||
|
||||
expected_base = account_balance * (base_percent / 100.0) * confidence
|
||||
expected_leveraged = expected_base * leverage
|
||||
|
||||
logger.info(f" Account: ${account_balance:.2f}")
|
||||
logger.info(f" Base %: {base_percent:.1f}%")
|
||||
logger.info(f" Confidence: {confidence:.1f}")
|
||||
logger.info(f" Leverage: {leverage:.0f}x")
|
||||
logger.info(f" Expected base: ${expected_base:.2f}")
|
||||
logger.info(f" Expected leveraged: ${expected_leveraged:.2f}")
|
||||
logger.info(f" Actual: ${position_value:.2f}")
|
||||
|
||||
sizing_ok = abs(position_value - expected_leveraged) < 0.01
|
||||
logger.info(f" Percentage sizing: {'✅ CORRECT' if sizing_ok else '❌ INCORRECT'}")
|
||||
|
||||
return sizing_ok
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
|
||||
logger.info("🚀 TESTING TRADING FIXES AND NEW FEATURES")
|
||||
logger.info("Testing position sizing, ETH-only trading, win rate calculations, and new features")
|
||||
|
||||
# Test position sizing
|
||||
test_position_sizing()
|
||||
|
||||
# Test ETH-only restriction
|
||||
test_eth_only_restriction()
|
||||
|
||||
# Test win rate calculation
|
||||
calculation_success = test_win_rate_calculation()
|
||||
|
||||
# Test new features
|
||||
features_success = test_new_features()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("TEST SUMMARY")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Position Sizing: ✅ Updated with percentage-based leverage")
|
||||
logger.info(f"ETH-Only Trading: ✅ Configured in config")
|
||||
logger.info(f"Win Rate Calculation: {'✅ FIXED' if calculation_success else '❌ STILL BROKEN'}")
|
||||
logger.info(f"New Features: {'✅ WORKING' if features_success else '❌ ISSUES FOUND'}")
|
||||
|
||||
if calculation_success and features_success:
|
||||
logger.info("🎉 ALL FEATURES WORKING! Now you should see:")
|
||||
logger.info(" - Percentage-based position sizing (2-20% of account)")
|
||||
logger.info(" - 50x leverage (adjustable in UI)")
|
||||
logger.info(" - Hold time in seconds for each trade")
|
||||
logger.info(" - Total fees in trading statistics")
|
||||
logger.info(" - Only ETH/USDT trades")
|
||||
logger.info(" - Correct win rate calculations")
|
||||
else:
|
||||
logger.error("❌ Some issues remain. Check the logs above for details.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
docs/dev/architecture.md
Normal file
1
docs/dev/architecture.md
Normal file
@ -0,0 +1 @@
|
||||
our system architecture is such that we have data inflow with different rates from different providers. our data flow though the system should be single and centralized. I think our orchestrator class is taking that role. since our different data feeds have different rates (and also each model has different inference times and cycle) our orchestrator should keep cache of the latest available data and keep track of the rates and statistics of each data source - being data api or our own model outputs. so the available data is constantly updated and refreshed in realtime by multiple sources, and is also consumed by all smodels
|
@ -124,7 +124,7 @@ class CheckpointManager:
|
||||
self._rotate_checkpoints(model_name)
|
||||
self._save_metadata()
|
||||
|
||||
logger.info(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
|
||||
logger.debug(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
|
||||
return metadata
|
||||
|
||||
except Exception as e:
|
||||
@ -133,19 +133,37 @@ class CheckpointManager:
|
||||
|
||||
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
|
||||
try:
|
||||
if model_name not in self.checkpoints or not self.checkpoints[model_name]:
|
||||
logger.warning(f"No checkpoints found for model: {model_name}")
|
||||
return None
|
||||
# First, try the standard checkpoint system
|
||||
if model_name in self.checkpoints and self.checkpoints[model_name]:
|
||||
# Filter out checkpoints with non-existent files
|
||||
valid_checkpoints = [
|
||||
cp for cp in self.checkpoints[model_name]
|
||||
if Path(cp.file_path).exists()
|
||||
]
|
||||
|
||||
if valid_checkpoints:
|
||||
best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score)
|
||||
logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
|
||||
return best_checkpoint.file_path, best_checkpoint
|
||||
else:
|
||||
# Clean up invalid metadata entries
|
||||
invalid_count = len(self.checkpoints[model_name])
|
||||
logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata")
|
||||
self.checkpoints[model_name] = []
|
||||
self._save_metadata()
|
||||
|
||||
best_checkpoint = max(self.checkpoints[model_name], key=lambda x: x.performance_score)
|
||||
# Fallback: Look for existing saved models in the legacy format
|
||||
logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models")
|
||||
legacy_model_path = self._find_legacy_model(model_name)
|
||||
|
||||
if not Path(best_checkpoint.file_path).exists():
|
||||
# temporary disable logging to avoid spam
|
||||
# logger.error(f"Best checkpoint file not found: {best_checkpoint.file_path}")
|
||||
return None
|
||||
if legacy_model_path:
|
||||
# Create checkpoint metadata for the legacy model using actual file data
|
||||
legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path)
|
||||
logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
|
||||
return str(legacy_model_path), legacy_metadata
|
||||
|
||||
logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
|
||||
return best_checkpoint.file_path, best_checkpoint
|
||||
logger.warning(f"No checkpoints or legacy models found for: {model_name}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading best checkpoint for {model_name}: {e}")
|
||||
@ -181,16 +199,8 @@ class CheckpointManager:
|
||||
# Bonus for processing more training samples
|
||||
score += min(10, metrics['training_samples'] / 10)
|
||||
|
||||
# Ensure minimum score for any training activity
|
||||
if score == 0.0 and metrics:
|
||||
# Use the first available metric with better scaling
|
||||
first_metric = next(iter(metrics.values()))
|
||||
if first_metric > 0:
|
||||
score = max(0.1, min(10, first_metric))
|
||||
else:
|
||||
score = 0.1
|
||||
|
||||
return max(score, 0.1)
|
||||
# Return actual calculated score - NO SYNTHETIC MINIMUM
|
||||
return score
|
||||
|
||||
def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
|
||||
"""Improved checkpoint saving logic with more frequent saves during training"""
|
||||
@ -222,7 +232,7 @@ class CheckpointManager:
|
||||
|
||||
# Save more frequently during active training (every 5th attempt instead of 10th)
|
||||
if random.random() < 0.2: # 20% chance to save anyway
|
||||
logger.info(f"Saving checkpoint for {model_name} - periodic save during active training")
|
||||
logger.debug(f"Saving checkpoint for {model_name} - periodic save during active training")
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -258,7 +268,7 @@ class CheckpointManager:
|
||||
file_path = Path(checkpoint.file_path)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
logger.info(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
|
||||
logger.debug(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")
|
||||
|
||||
@ -331,6 +341,110 @@ class CheckpointManager:
|
||||
stats['total_size_mb'] += model_size
|
||||
|
||||
return stats
|
||||
|
||||
def _find_legacy_model(self, model_name: str) -> Optional[Path]:
|
||||
"""Find legacy saved models based on model name patterns"""
|
||||
base_dir = Path(self.base_dir)
|
||||
|
||||
# Define model name mappings and patterns for legacy files
|
||||
legacy_patterns = {
|
||||
'dqn_agent': [
|
||||
'dqn_agent_best_policy.pt',
|
||||
'enhanced_dqn_best_policy.pt',
|
||||
'improved_dqn_agent_best_policy.pt',
|
||||
'dqn_agent_final_policy.pt'
|
||||
],
|
||||
'enhanced_cnn': [
|
||||
'cnn_model_best.pt',
|
||||
'optimized_short_term_model_best.pt',
|
||||
'optimized_short_term_model_realtime_best.pt',
|
||||
'optimized_short_term_model_ticks_best.pt'
|
||||
],
|
||||
'extrema_trainer': [
|
||||
'supervised_model_best.pt'
|
||||
],
|
||||
'cob_rl': [
|
||||
'best_rl_model.pth_policy.pt',
|
||||
'rl_agent_best_policy.pt'
|
||||
],
|
||||
'decision': [
|
||||
# Decision models might be in subdirectories, but let's check main dir too
|
||||
'decision_best.pt',
|
||||
'decision_model_best.pt',
|
||||
# Check for transformer models which might be used as decision models
|
||||
'enhanced_dqn_best_policy.pt',
|
||||
'improved_dqn_agent_best_policy.pt'
|
||||
]
|
||||
}
|
||||
|
||||
# Get patterns for this model name
|
||||
patterns = legacy_patterns.get(model_name, [])
|
||||
|
||||
# Also try generic patterns based on model name
|
||||
patterns.extend([
|
||||
f'{model_name}_best.pt',
|
||||
f'{model_name}_best_policy.pt',
|
||||
f'{model_name}_final.pt',
|
||||
f'{model_name}_final_policy.pt'
|
||||
])
|
||||
|
||||
# Search for the model files
|
||||
for pattern in patterns:
|
||||
candidate_path = base_dir / pattern
|
||||
if candidate_path.exists():
|
||||
logger.debug(f"Found legacy model file: {candidate_path}")
|
||||
return candidate_path
|
||||
|
||||
# Also check subdirectories
|
||||
for subdir in base_dir.iterdir():
|
||||
if subdir.is_dir() and subdir.name == model_name:
|
||||
for pattern in patterns:
|
||||
candidate_path = subdir / pattern
|
||||
if candidate_path.exists():
|
||||
logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
|
||||
return candidate_path
|
||||
|
||||
return None
|
||||
|
||||
def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:
|
||||
"""Create metadata for legacy model files using only actual file information"""
|
||||
try:
|
||||
file_size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
created_time = datetime.fromtimestamp(file_path.stat().st_mtime)
|
||||
|
||||
# NO SYNTHETIC DATA - use only actual file information
|
||||
return CheckpointMetadata(
|
||||
checkpoint_id=f"legacy_{model_name}_{int(created_time.timestamp())}",
|
||||
model_name=model_name,
|
||||
model_type=model_name,
|
||||
file_path=str(file_path),
|
||||
created_at=created_time,
|
||||
file_size_mb=file_size_mb,
|
||||
performance_score=0.0, # Unknown performance - use 0, not synthetic values
|
||||
accuracy=None,
|
||||
loss=None,
|
||||
val_accuracy=None,
|
||||
val_loss=None,
|
||||
reward=None,
|
||||
pnl=None,
|
||||
epoch=None,
|
||||
training_time_hours=None,
|
||||
total_parameters=None,
|
||||
wandb_run_id=None,
|
||||
wandb_artifact_name=None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating legacy metadata for {model_name}: {e}")
|
||||
# Return a basic metadata with minimal info - NO SYNTHETIC VALUES
|
||||
return CheckpointMetadata(
|
||||
checkpoint_id=f"legacy_{model_name}",
|
||||
model_name=model_name,
|
||||
model_type=model_name,
|
||||
file_path=str(file_path),
|
||||
created_at=datetime.now(),
|
||||
file_size_mb=0.0,
|
||||
performance_score=0.0 # Unknown - use 0, not synthetic
|
||||
)
|
||||
|
||||
_checkpoint_manager = None
|
||||
|
||||
|
@ -1934,11 +1934,13 @@ class CleanTradingDashboard:
|
||||
|
||||
# Fallback if orchestrator not available or returns None
|
||||
if model_states is None:
|
||||
# FIXED: No longer using hardcoded placeholder loss values
|
||||
# Dashboard should show "No Data" or actual training status instead
|
||||
model_states = {
|
||||
'dqn': {'initial_loss': 0.2850, 'current_loss': 0.0145, 'best_loss': 0.0098, 'checkpoint_loaded': False},
|
||||
'cnn': {'initial_loss': 0.4120, 'current_loss': 0.0187, 'best_loss': 0.0134, 'checkpoint_loaded': False},
|
||||
'cob_rl': {'initial_loss': 0.3560, 'current_loss': 0.0098, 'best_loss': 0.0076, 'checkpoint_loaded': False},
|
||||
'decision': {'initial_loss': 0.2980, 'current_loss': 0.0089, 'best_loss': 0.0065, 'checkpoint_loaded': False}
|
||||
'dqn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
|
||||
'cnn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
|
||||
'cob_rl': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
|
||||
'decision': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False}
|
||||
}
|
||||
|
||||
# Get latest predictions from all models
|
||||
@ -1956,6 +1958,13 @@ class CleanTradingDashboard:
|
||||
except (TypeError, ZeroDivisionError):
|
||||
return default_improvement
|
||||
|
||||
# Helper function to format loss values
|
||||
def format_loss_value(loss_value: Optional[float]) -> str:
|
||||
"""Format loss value for display, showing 'No Data' for None values"""
|
||||
if loss_value is None:
|
||||
return "No Data"
|
||||
return f"{loss_value:.4f}"
|
||||
|
||||
# Helper function to get timing information
|
||||
def get_model_timing_info(model_name: str) -> Dict[str, Any]:
|
||||
timing = {
|
||||
@ -2041,12 +2050,12 @@ class CleanTradingDashboard:
|
||||
},
|
||||
# FIXED: Get REAL loss values from orchestrator model, not placeholders
|
||||
'loss_5ma': self._get_real_model_loss('dqn'),
|
||||
'initial_loss': dqn_state.get('initial_loss', 0.2850),
|
||||
'initial_loss': dqn_state.get('initial_loss'), # No fallback - show None if unknown
|
||||
'best_loss': self._get_real_best_loss('dqn'),
|
||||
'improvement': safe_improvement_calc(
|
||||
dqn_state.get('initial_loss', 0.2850),
|
||||
dqn_state.get('initial_loss'),
|
||||
self._get_real_model_loss('dqn'),
|
||||
0.0 if not dqn_active else 94.9 # Default if no real improvement available
|
||||
0.0 # No synthetic default improvement
|
||||
),
|
||||
'checkpoint_loaded': dqn_checkpoint_loaded,
|
||||
'model_type': 'DQN',
|
||||
@ -2109,13 +2118,13 @@ class CleanTradingDashboard:
|
||||
'predicted_price': cnn_predicted_price,
|
||||
'type': cnn_latest.get('type', 'cnn_pivot') if cnn_latest else 'cnn_pivot'
|
||||
},
|
||||
'loss_5ma': cnn_state.get('current_loss', 0.0187),
|
||||
'initial_loss': cnn_state.get('initial_loss', 0.4120),
|
||||
'best_loss': cnn_state.get('best_loss', 0.0134),
|
||||
'loss_5ma': cnn_state.get('current_loss'),
|
||||
'initial_loss': cnn_state.get('initial_loss'),
|
||||
'best_loss': cnn_state.get('best_loss'),
|
||||
'improvement': safe_improvement_calc(
|
||||
cnn_state.get('initial_loss', 0.4120),
|
||||
cnn_state.get('current_loss', 0.0187),
|
||||
95.5 # Default improvement percentage
|
||||
cnn_state.get('initial_loss'),
|
||||
cnn_state.get('current_loss'),
|
||||
0.0 # No synthetic default improvement
|
||||
),
|
||||
'checkpoint_loaded': cnn_state.get('checkpoint_loaded', False),
|
||||
'model_type': 'CNN',
|
||||
@ -3948,11 +3957,11 @@ class CleanTradingDashboard:
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
def _get_real_model_loss(self, model_name: str) -> float:
|
||||
def _get_real_model_loss(self, model_name: str) -> Optional[float]:
|
||||
"""Get REAL current loss from the actual model, not placeholders"""
|
||||
try:
|
||||
if not self.orchestrator:
|
||||
return 0.2850 # Default fallback
|
||||
return None # No orchestrator = no real data
|
||||
|
||||
if model_name == 'dqn' and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||
# Get real loss from DQN agent
|
||||
@ -3961,8 +3970,8 @@ class CleanTradingDashboard:
|
||||
# Average of last 50 losses for current loss
|
||||
recent_losses = agent.losses[-50:]
|
||||
return sum(recent_losses) / len(recent_losses)
|
||||
elif hasattr(agent, 'current_loss'):
|
||||
return float(getattr(agent, 'current_loss', 0.2850))
|
||||
elif hasattr(agent, 'current_loss') and agent.current_loss is not None:
|
||||
return float(agent.current_loss)
|
||||
|
||||
elif model_name == 'cnn' and hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
# Get real loss from CNN model
|
||||
@ -3970,8 +3979,8 @@ class CleanTradingDashboard:
|
||||
if hasattr(model, 'training_losses') and len(getattr(model, 'training_losses',[])) > 0:
|
||||
recent_losses = getattr(model, 'training_losses',[])[-50:]
|
||||
return sum(recent_losses) / len(recent_losses)
|
||||
elif hasattr(model, 'current_loss'):
|
||||
return float(getattr(model, 'current_loss', 0.2850))
|
||||
elif hasattr(model, 'current_loss') and model.current_loss is not None:
|
||||
return float(model.current_loss)
|
||||
|
||||
elif model_name == 'decision' and hasattr(self.orchestrator, 'decision_fusion_network'):
|
||||
# Get real loss from decision fusion
|
||||
@ -3983,45 +3992,45 @@ class CleanTradingDashboard:
|
||||
# Fallback to model states
|
||||
model_states = self.orchestrator.get_model_states() if hasattr(self.orchestrator, 'get_model_states') else {}
|
||||
state = model_states.get(model_name, {})
|
||||
return state.get('current_loss', 0.2850)
|
||||
return state.get('current_loss') # Return None if no real data
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting real loss for {model_name}: {e}")
|
||||
return 0.2850 # Safe fallback
|
||||
return None # Return None instead of synthetic data
|
||||
|
||||
def _get_real_best_loss(self, model_name: str) -> float:
|
||||
def _get_real_best_loss(self, model_name: str) -> Optional[float]:
|
||||
"""Get REAL best loss from the actual model"""
|
||||
try:
|
||||
if not self.orchestrator:
|
||||
return 0.0145 # Default fallback
|
||||
return None # No orchestrator = no real data
|
||||
|
||||
if model_name == 'dqn' and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||
agent = self.orchestrator.rl_agent
|
||||
if hasattr(agent, 'best_loss'):
|
||||
return float(getattr(agent, 'best_loss', 0.0145))
|
||||
if hasattr(agent, 'best_loss') and agent.best_loss is not None:
|
||||
return float(agent.best_loss)
|
||||
elif hasattr(agent, 'losses') and len(agent.losses) > 0:
|
||||
return min(agent.losses)
|
||||
|
||||
elif model_name == 'cnn' and hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
model = self.orchestrator.cnn_model
|
||||
if hasattr(model, 'best_loss'):
|
||||
return float(getattr(model, 'best_loss', 0.0145))
|
||||
if hasattr(model, 'best_loss') and model.best_loss is not None:
|
||||
return float(model.best_loss)
|
||||
elif hasattr(model, 'training_losses') and len(getattr(model, 'training_losses', [])) > 0:
|
||||
return min(getattr(model, 'training_losses', [0.0145]))
|
||||
return min(getattr(model, 'training_losses', []))
|
||||
|
||||
elif model_name == 'decision' and hasattr(self.orchestrator, 'fusion_training_data'):
|
||||
if len(self.orchestrator.fusion_training_data) > 0:
|
||||
all_losses = [entry['loss'] for entry in self.orchestrator.fusion_training_data]
|
||||
return min(all_losses) if all_losses else 0.0065
|
||||
return min(all_losses) if all_losses else None
|
||||
|
||||
# Fallback to model states
|
||||
model_states = self.orchestrator.get_model_states() if hasattr(self.orchestrator, 'get_model_states') else {}
|
||||
state = model_states.get(model_name, {})
|
||||
return state.get('best_loss', 0.0145)
|
||||
return state.get('best_loss') # Return None if no real data
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting best loss for {model_name}: {e}")
|
||||
return 0.0145 # Safe fallback
|
||||
return None # Return None instead of synthetic data
|
||||
|
||||
def _clear_old_signals_for_tick_range(self):
|
||||
"""Clear old signals that are outside the current tick cache time range - VERY CONSERVATIVE"""
|
||||
@ -4707,8 +4716,50 @@ class CleanTradingDashboard:
|
||||
return 0
|
||||
|
||||
def _get_trading_statistics(self) -> Dict[str, Any]:
|
||||
"""Calculate trading statistics from closed trades"""
|
||||
"""Get trading statistics from trading executor"""
|
||||
try:
|
||||
# Try to get statistics from trading executor first
|
||||
if self.trading_executor:
|
||||
executor_stats = self.trading_executor.get_daily_stats()
|
||||
closed_trades = self.trading_executor.get_closed_trades()
|
||||
|
||||
if executor_stats and executor_stats.get('total_trades', 0) > 0:
|
||||
# Calculate largest win/loss from closed trades
|
||||
largest_win = 0.0
|
||||
largest_loss = 0.0
|
||||
|
||||
if closed_trades:
|
||||
for trade in closed_trades:
|
||||
try:
|
||||
# Handle both dictionary and object formats
|
||||
if isinstance(trade, dict):
|
||||
pnl = trade.get('pnl', 0)
|
||||
else:
|
||||
pnl = getattr(trade, 'pnl', 0)
|
||||
|
||||
if pnl > 0:
|
||||
largest_win = max(largest_win, pnl)
|
||||
elif pnl < 0:
|
||||
largest_loss = max(largest_loss, abs(pnl))
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error processing trade for statistics: {e}")
|
||||
continue
|
||||
|
||||
# Map executor stats to dashboard format
|
||||
return {
|
||||
'total_trades': executor_stats.get('total_trades', 0),
|
||||
'winning_trades': executor_stats.get('winning_trades', 0),
|
||||
'losing_trades': executor_stats.get('losing_trades', 0),
|
||||
'win_rate': executor_stats.get('win_rate', 0.0) * 100, # Convert to percentage
|
||||
'avg_win_size': executor_stats.get('avg_winning_trade', 0.0), # Correct mapping
|
||||
'avg_loss_size': abs(executor_stats.get('avg_losing_trade', 0.0)), # Make positive for display
|
||||
'largest_win': largest_win,
|
||||
'largest_loss': largest_loss,
|
||||
'total_pnl': executor_stats.get('total_pnl', 0.0)
|
||||
}
|
||||
|
||||
# Fallback to dashboard's own trade list if no trading executor
|
||||
if not self.closed_trades:
|
||||
return {
|
||||
'total_trades': 0,
|
||||
@ -4841,17 +4892,19 @@ class CleanTradingDashboard:
|
||||
logger.error(f"Error initiating orchestrator connection: {e}")
|
||||
|
||||
async def _on_trading_decision(self, decision):
|
||||
"""Handle trading decision from orchestrator."""
|
||||
"""Handle trading decision from orchestrator and execute through trading executor."""
|
||||
try:
|
||||
# Handle both object and dict formats
|
||||
if hasattr(decision, 'action'):
|
||||
action = getattr(decision, 'action', 'HOLD')
|
||||
symbol = getattr(decision, 'symbol', 'ETH/USDT')
|
||||
confidence = getattr(decision, 'confidence', 0.0)
|
||||
price = getattr(decision, 'price', None)
|
||||
else:
|
||||
action = decision.get('action', 'HOLD')
|
||||
symbol = decision.get('symbol', 'ETH/USDT')
|
||||
confidence = decision.get('confidence', 0.0)
|
||||
price = decision.get('price', None)
|
||||
|
||||
if action == 'HOLD':
|
||||
return
|
||||
@ -4877,11 +4930,45 @@ class CleanTradingDashboard:
|
||||
dashboard_decision['timestamp'] = datetime.now()
|
||||
dashboard_decision['executed'] = False
|
||||
|
||||
logger.info(f"[ORCHESTRATOR SIGNAL] Received: {action} for {symbol} (confidence: {confidence:.3f})")
|
||||
|
||||
# EXECUTE THE DECISION THROUGH TRADING EXECUTOR
|
||||
if self.trading_executor and confidence > 0.5: # Only execute high confidence signals
|
||||
try:
|
||||
logger.info(f"[ORCHESTRATOR EXECUTION] Attempting to execute {action} for {symbol} via trading executor...")
|
||||
success = self.trading_executor.execute_signal(
|
||||
symbol=symbol,
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
current_price=price
|
||||
)
|
||||
|
||||
if success:
|
||||
dashboard_decision['executed'] = True
|
||||
dashboard_decision['execution_time'] = datetime.now()
|
||||
logger.info(f"[ORCHESTRATOR EXECUTION] SUCCESS: {action} executed for {symbol}")
|
||||
|
||||
# Sync position from trading executor after execution
|
||||
self._sync_position_from_executor(symbol)
|
||||
|
||||
else:
|
||||
logger.warning(f"[ORCHESTRATOR EXECUTION] FAILED: {action} execution blocked for {symbol}")
|
||||
dashboard_decision['execution_failure'] = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ORCHESTRATOR EXECUTION] ERROR: Failed to execute {action} for {symbol}: {e}")
|
||||
dashboard_decision['execution_error'] = str(e)
|
||||
else:
|
||||
if not self.trading_executor:
|
||||
logger.warning("[ORCHESTRATOR EXECUTION] No trading executor available")
|
||||
elif confidence <= 0.5:
|
||||
logger.info(f"[ORCHESTRATOR EXECUTION] Low confidence signal ignored: {action} for {symbol} (confidence: {confidence:.3f})")
|
||||
|
||||
# Store decision in dashboard
|
||||
self.recent_decisions.append(dashboard_decision)
|
||||
if len(self.recent_decisions) > 200:
|
||||
self.recent_decisions.pop(0)
|
||||
|
||||
logger.info(f"[ORCHESTRATOR SIGNAL] Received: {action} for {symbol} (confidence: {confidence:.3f})")
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling trading decision: {e}")
|
||||
|
||||
|
@ -96,6 +96,8 @@ class DashboardComponentManager:
|
||||
total_trades = trading_stats.get('total_trades', 0)
|
||||
winning_trades = trading_stats.get('winning_trades', 0)
|
||||
losing_trades = trading_stats.get('losing_trades', 0)
|
||||
total_fees = trading_stats.get('total_fees', 0)
|
||||
breakeven_trades = trading_stats.get('breakeven_trades', 0)
|
||||
|
||||
win_rate_class = "text-success" if win_rate >= 50 else "text-warning" if win_rate >= 30 else "text-danger"
|
||||
|
||||
@ -106,16 +108,20 @@ class DashboardComponentManager:
|
||||
html.Div([
|
||||
html.Span("Win Rate: ", className="small text-muted"),
|
||||
html.Span(f"{win_rate:.1f}%", className=f"fw-bold {win_rate_class}"),
|
||||
html.Span(f" ({winning_trades}W/{losing_trades}L)", className="small text-muted")
|
||||
], className="col-4"),
|
||||
html.Span(f" ({winning_trades}W/{losing_trades}L/{breakeven_trades}B)", className="small text-muted")
|
||||
], className="col-3"),
|
||||
html.Div([
|
||||
html.Span("Avg Win: ", className="small text-muted"),
|
||||
html.Span(f"${avg_win:.2f}", className="fw-bold text-success")
|
||||
], className="col-4"),
|
||||
], className="col-3"),
|
||||
html.Div([
|
||||
html.Span("Avg Loss: ", className="small text-muted"),
|
||||
html.Span(f"${avg_loss:.2f}", className="fw-bold text-danger")
|
||||
], className="col-4")
|
||||
], className="col-3"),
|
||||
html.Div([
|
||||
html.Span("Total Fees: ", className="small text-muted"),
|
||||
html.Span(f"${total_fees:.2f}", className="fw-bold text-warning")
|
||||
], className="col-3")
|
||||
], className="row"),
|
||||
html.Hr(className="my-2")
|
||||
], className="mb-3")
|
||||
@ -135,6 +141,7 @@ class DashboardComponentManager:
|
||||
html.Th("Size", className="small"),
|
||||
html.Th("Entry", className="small"),
|
||||
html.Th("Exit", className="small"),
|
||||
html.Th("Hold (s)", className="small"),
|
||||
html.Th("P&L", className="small"),
|
||||
html.Th("Fees", className="small")
|
||||
])
|
||||
@ -142,7 +149,7 @@ class DashboardComponentManager:
|
||||
|
||||
# Create table rows
|
||||
rows = []
|
||||
for trade in closed_trades[-20:]: # Last 20 trades
|
||||
for trade in closed_trades: # Removed [-20:] to show all trades
|
||||
# Handle both trade objects and dictionary formats
|
||||
if hasattr(trade, 'entry_time'):
|
||||
# This is a trade object
|
||||
@ -153,15 +160,17 @@ class DashboardComponentManager:
|
||||
exit_price = getattr(trade, 'exit_price', 0)
|
||||
pnl = getattr(trade, 'pnl', 0)
|
||||
fees = getattr(trade, 'fees', 0)
|
||||
hold_time_seconds = getattr(trade, 'hold_time_seconds', 0.0)
|
||||
else:
|
||||
# This is a dictionary format
|
||||
entry_time = trade.get('entry_time', 'Unknown')
|
||||
side = trade.get('side', 'UNKNOWN')
|
||||
size = trade.get('size', 0)
|
||||
size = trade.get('quantity', trade.get('size', 0)) # Try 'quantity' first, then 'size'
|
||||
entry_price = trade.get('entry_price', 0)
|
||||
exit_price = trade.get('exit_price', 0)
|
||||
pnl = trade.get('pnl', 0)
|
||||
fees = trade.get('fees', 0)
|
||||
hold_time_seconds = trade.get('hold_time_seconds', 0.0)
|
||||
|
||||
# Format time
|
||||
if isinstance(entry_time, datetime):
|
||||
@ -179,6 +188,7 @@ class DashboardComponentManager:
|
||||
html.Td(f"{size:.3f}", className="small"),
|
||||
html.Td(f"${entry_price:.2f}", className="small"),
|
||||
html.Td(f"${exit_price:.2f}", className="small"),
|
||||
html.Td(f"{hold_time_seconds:.0f}", className="small text-info"),
|
||||
html.Td(f"${pnl:.2f}", className=f"small {pnl_class}"),
|
||||
html.Td(f"${fees:.3f}", className="small text-muted")
|
||||
])
|
||||
@ -188,11 +198,17 @@ class DashboardComponentManager:
|
||||
|
||||
table = html.Table([headers, tbody], className="table table-sm table-striped")
|
||||
|
||||
# Wrap the table in a scrollable div
|
||||
scrollable_table_container = html.Div(
|
||||
table,
|
||||
style={'maxHeight': '300px', 'overflowY': 'scroll', 'overflowX': 'hidden'}
|
||||
)
|
||||
|
||||
# Combine statistics header with table
|
||||
if stats_header:
|
||||
return html.Div(stats_header + [table])
|
||||
return html.Div(stats_header + [scrollable_table_container])
|
||||
else:
|
||||
return table
|
||||
return scrollable_table_container
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting closed trades: {e}")
|
||||
|
Reference in New Issue
Block a user