4 Commits

Author SHA1 Message Date
6c91bf0b93 fix sim and wip fix live 2025-07-08 02:47:10 +03:00
64678bd8d3 more live trades fix 2025-07-08 02:03:32 +03:00
4ab7bc1846 tweaks, try live trading 2025-07-08 01:33:22 +03:00
9cd2d5d8a4 fixes 2025-07-07 23:39:12 +03:00
15 changed files with 640 additions and 356 deletions

Binary file not shown.

View File

@ -5,6 +5,7 @@ import requests
import hmac import hmac
import hashlib import hashlib
from urllib.parse import urlencode, quote_plus from urllib.parse import urlencode, quote_plus
import json # Added for json.dumps
from .exchange_interface import ExchangeInterface from .exchange_interface import ExchangeInterface
@ -85,37 +86,40 @@ class MEXCInterface(ExchangeInterface):
return symbol.replace('/', '_').upper() return symbol.replace('/', '_').upper()
def _generate_signature(self, timestamp: str, method: str, endpoint: str, params: Dict[str, Any]) -> str: def _generate_signature(self, timestamp: str, method: str, endpoint: str, params: Dict[str, Any]) -> str:
"""Generate signature for private API calls using MEXC's expected parameter order""" """Generate signature for private API calls using MEXC's official method"""
# MEXC requires specific parameter ordering, not alphabetical # MEXC signature format varies by method:
# Based on successful test: symbol, side, type, quantity, timestamp, then other params # For GET/DELETE: URL-encoded query string of alphabetically sorted parameters.
mexc_param_order = ['symbol', 'side', 'type', 'quantity', 'timestamp', 'recvWindow'] # For POST: JSON string of parameters (no sorting needed).
# The API-Secret is used as the HMAC SHA256 key.
# Build ordered parameter list # Remove signature from params to avoid circular inclusion
ordered_params = [] clean_params = {k: v for k, v in params.items() if k != 'signature'}
# Add parameters in MEXC's expected order parameter_string: str
for param_name in mexc_param_order:
if param_name in params and param_name != 'signature':
ordered_params.append(f"{param_name}={params[param_name]}")
# Add any remaining parameters not in the standard order (alphabetically) if method.upper() == "POST":
remaining_params = {k: v for k, v in params.items() if k not in mexc_param_order and k != 'signature'} # For POST requests, the signature parameter is a JSON string
for key in sorted(remaining_params.keys()): # Ensure sorting keys for consistent JSON string generation across runs
ordered_params.append(f"{key}={remaining_params[key]}") # even though MEXC says sorting is not required for POST params, it's good practice.
parameter_string = json.dumps(clean_params, sort_keys=True, separators=(',', ':'))
else:
# For GET/DELETE requests, parameters are spliced in dictionary order with & interval
sorted_params = sorted(clean_params.items())
parameter_string = '&'.join(f"{key}={str(value)}" for key, value in sorted_params)
# Create query string (MEXC doesn't use the api_key + timestamp prefix) # The string to be signed is: accessKey + timestamp + obtained parameter string.
query_string = '&'.join(ordered_params) string_to_sign = f"{self.api_key}{timestamp}{parameter_string}"
logger.debug(f"MEXC signature query string: {query_string}") logger.debug(f"MEXC string to sign (method {method}): {string_to_sign}")
# Generate HMAC SHA256 signature # Generate HMAC SHA256 signature
signature = hmac.new( signature = hmac.new(
self.api_secret.encode('utf-8'), self.api_secret.encode('utf-8'),
query_string.encode('utf-8'), string_to_sign.encode('utf-8'),
hashlib.sha256 hashlib.sha256
).hexdigest() ).hexdigest()
logger.debug(f"MEXC signature: {signature}") logger.debug(f"MEXC generated signature: {signature}")
return signature return signature
def _send_public_request(self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: def _send_public_request(self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
@ -145,7 +149,7 @@ class MEXCInterface(ExchangeInterface):
logger.error(f"Error in public request to {endpoint}: {e}") logger.error(f"Error in public request to {endpoint}: {e}")
return {} return {}
def _send_private_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Optional[Dict[str, Any]]: def _send_private_request(self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
"""Send a private request to the exchange with proper signature""" """Send a private request to the exchange with proper signature"""
if params is None: if params is None:
params = {} params = {}
@ -170,8 +174,11 @@ class MEXCInterface(ExchangeInterface):
if method.upper() == "GET": if method.upper() == "GET":
response = self.session.get(url, headers=headers, params=params, timeout=10) response = self.session.get(url, headers=headers, params=params, timeout=10)
elif method.upper() == "POST": elif method.upper() == "POST":
# MEXC expects POST parameters as query string, not in body # MEXC expects POST parameters as JSON in the request body, not as query string
response = self.session.post(url, headers=headers, params=params, timeout=10) # The signature is generated from the JSON string of parameters.
# We need to exclude 'signature' from the JSON body sent, as it's for the header.
params_for_body = {k: v for k, v in params.items() if k != 'signature'}
response = self.session.post(url, headers=headers, json=params_for_body, timeout=10)
else: else:
logger.error(f"Unsupported method: {method}") logger.error(f"Unsupported method: {method}")
return None return None
@ -217,12 +224,9 @@ class MEXCInterface(ExchangeInterface):
response = self._send_public_request('GET', endpoint, params) response = self._send_public_request('GET', endpoint, params)
if response:
# MEXC ticker returns a dictionary if single symbol, list if all symbols
if isinstance(response, dict): if isinstance(response, dict):
ticker_data = response ticker_data: Dict[str, Any] = response
elif isinstance(response, list) and len(response) > 0: elif isinstance(response, list) and len(response) > 0:
# If the response is a list, try to find the specific symbol
found_ticker = next((item for item in response if item.get('symbol') == formatted_symbol), None) found_ticker = next((item for item in response if item.get('symbol') == formatted_symbol), None)
if found_ticker: if found_ticker:
ticker_data = found_ticker ticker_data = found_ticker
@ -233,6 +237,9 @@ class MEXCInterface(ExchangeInterface):
logger.error(f"Unexpected ticker response format: {response}") logger.error(f"Unexpected ticker response format: {response}")
return None return None
# At this point, ticker_data is guaranteed to be a Dict[str, Any] due to the above logic
# If it was None, we would have returned early.
# Extract relevant info and format for universal use # Extract relevant info and format for universal use
last_price = float(ticker_data.get('lastPrice', 0)) last_price = float(ticker_data.get('lastPrice', 0))
bid_price = float(ticker_data.get('bidPrice', 0)) bid_price = float(ticker_data.get('bidPrice', 0))
@ -257,8 +264,6 @@ class MEXCInterface(ExchangeInterface):
'exchange': 'MEXC', 'exchange': 'MEXC',
'raw_data': ticker_data 'raw_data': ticker_data
} }
logger.error(f"Failed to get ticker for {symbol}")
return None
def get_api_symbols(self) -> List[str]: def get_api_symbols(self) -> List[str]:
"""Get list of symbols supported for API trading""" """Get list of symbols supported for API trading"""
@ -293,40 +298,90 @@ class MEXCInterface(ExchangeInterface):
logger.info(f"Supported symbols include: {supported_symbols[:10]}...") # Show first 10 logger.info(f"Supported symbols include: {supported_symbols[:10]}...") # Show first 10
return {} return {}
# Format quantity according to symbol precision requirements
formatted_quantity = self._format_quantity_for_symbol(formatted_symbol, quantity)
if formatted_quantity is None:
logger.error(f"MEXC: Failed to format quantity {quantity} for {formatted_symbol}")
return {}
# Handle order type restrictions for specific symbols
final_order_type = self._adjust_order_type_for_symbol(formatted_symbol, order_type.upper())
# Get price for limit orders
final_price = price
if final_order_type == 'LIMIT' and price is None:
# Get current market price
ticker = self.get_ticker(symbol)
if ticker and 'last' in ticker:
final_price = ticker['last']
logger.info(f"MEXC: Using market price ${final_price:.2f} for LIMIT order")
else:
logger.error(f"MEXC: Could not get market price for LIMIT order on {formatted_symbol}")
return {}
endpoint = "order" endpoint = "order"
params: Dict[str, Any] = { params: Dict[str, Any] = {
'symbol': formatted_symbol, 'symbol': formatted_symbol,
'side': side.upper(), 'side': side.upper(),
'type': order_type.upper(), 'type': final_order_type,
'quantity': str(quantity) # Quantity must be a string 'quantity': str(formatted_quantity) # Quantity must be a string
} }
if price is not None: if final_price is not None:
params['price'] = str(price) # Price must be a string for limit orders params['price'] = str(final_price) # Price must be a string for limit orders
logger.info(f"MEXC: Placing {side.upper()} {order_type.upper()} order for {quantity} {formatted_symbol} at price {price}") logger.info(f"MEXC: Placing {side.upper()} {final_order_type} order for {formatted_quantity} {formatted_symbol} at price {final_price}")
# For market orders, some parameters might be optional or handled differently.
# Check MEXC API docs for market order specifics (e.g., quoteOrderQty for buy market orders)
if order_type.upper() == 'MARKET' and side.upper() == 'BUY':
# If it's a market buy order, MEXC often expects quoteOrderQty instead of quantity
# Assuming quantity here refers to the base asset, if quoteOrderQty is needed, adjust.
# For now, we will stick to quantity and let MEXC handle the conversion if possible
pass # No specific change needed based on the current params structure
try: try:
# MEXC API endpoint for placing orders is /api/v3/order (POST) # MEXC API endpoint for placing orders is /api/v3/order (POST)
order_result = self._send_private_request('POST', endpoint, params) order_result = self._send_private_request('POST', endpoint, params)
if order_result: if order_result is not None:
logger.info(f"MEXC: Order placed successfully: {order_result}") logger.info(f"MEXC: Order placed successfully: {order_result}")
return order_result return order_result
else: else:
logger.error(f"MEXC: Error placing order: {order_result}") logger.error(f"MEXC: Error placing order: request returned None")
return {} return {}
except Exception as e: except Exception as e:
logger.error(f"MEXC: Exception placing order: {e}") logger.error(f"MEXC: Exception placing order: {e}")
return {} return {}
def _format_quantity_for_symbol(self, formatted_symbol: str, quantity: float) -> Optional[float]:
"""Format quantity according to symbol precision requirements"""
try:
# Symbol-specific precision rules
if formatted_symbol == 'ETHUSDC':
# ETHUSDC requires max 5 decimal places, step size 0.000001
formatted_qty = round(quantity, 5)
# Ensure it meets minimum step size
step_size = 0.000001
formatted_qty = round(formatted_qty / step_size) * step_size
# Round again to remove floating point errors
formatted_qty = round(formatted_qty, 6)
logger.info(f"MEXC: Formatted ETHUSDC quantity {quantity} -> {formatted_qty}")
return formatted_qty
elif formatted_symbol == 'BTCUSDC':
# Assume similar precision for BTC
formatted_qty = round(quantity, 6)
step_size = 0.000001
formatted_qty = round(formatted_qty / step_size) * step_size
formatted_qty = round(formatted_qty, 6)
return formatted_qty
else:
# Default formatting - 6 decimal places
return round(quantity, 6)
except Exception as e:
logger.error(f"Error formatting quantity for {formatted_symbol}: {e}")
return None
def _adjust_order_type_for_symbol(self, formatted_symbol: str, order_type: str) -> str:
"""Adjust order type based on symbol restrictions"""
if formatted_symbol == 'ETHUSDC':
# ETHUSDC only supports LIMIT and LIMIT_MAKER orders
if order_type == 'MARKET':
logger.info(f"MEXC: Converting MARKET order to LIMIT for {formatted_symbol} (MARKET not supported)")
return 'LIMIT'
return order_type
def cancel_order(self, symbol: str, order_id: str) -> Dict[str, Any]: def cancel_order(self, symbol: str, order_id: str) -> Dict[str, Any]:
"""Cancel an existing order on MEXC.""" """Cancel an existing order on MEXC."""
formatted_symbol = self._format_spot_symbol(symbol) formatted_symbol = self._format_spot_symbol(symbol)

View File

@ -229,8 +229,8 @@ class COBRLModelInterface(ModelInterface):
Interface for the COB RL model that handles model management, training, and inference Interface for the COB RL model that handles model management, training, and inference
""" """
def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None): def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None, name=None, **kwargs):
super().__init__(name="cob_rl_model") # Initialize ModelInterface with a name super().__init__(name=name) # Initialize ModelInterface with a name
self.model_checkpoint_dir = model_checkpoint_dir self.model_checkpoint_dir = model_checkpoint_dir
self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu')) self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))

View File

@ -5,7 +5,7 @@ import numpy as np
from collections import deque from collections import deque
import random import random
from typing import Tuple, List from typing import Tuple, List
import osvu import os
import sys import sys
import logging import logging
import torch.nn.functional as F import torch.nn.functional as F

View File

@ -162,11 +162,11 @@ mexc_trading:
trading_mode: simulation # simulation, testnet, live trading_mode: simulation # simulation, testnet, live
# Position sizing as percentage of account balance # Position sizing as percentage of account balance
base_position_percent: 5.0 # 5% base position of account base_position_percent: 1 # 1% base position of account (MUCH SAFER)
max_position_percent: 20.0 # 20% max position of account max_position_percent: 5.0 # 5% max position of account (REDUCED)
min_position_percent: 2.0 # 2% min position of account min_position_percent: 0.5 # 0.5% min position of account (REDUCED)
leverage: 50.0 # 50x leverage (adjustable in UI) leverage: 1.0 # 1x leverage (NO LEVERAGE FOR TESTING)
simulation_account_usd: 100.0 # $100 simulation account balance simulation_account_usd: 99.9 # $99.90 simulation account balance
# Risk management # Risk management
max_daily_loss_usd: 200.0 max_daily_loss_usd: 200.0

View File

@ -34,7 +34,7 @@ class COBIntegration:
Integration layer for Multi-Exchange COB data with gogo2 trading system Integration layer for Multi-Exchange COB data with gogo2 trading system
""" """
def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None): def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None, initial_data_limit=None, **kwargs):
""" """
Initialize COB Integration Initialize COB Integration

View File

@ -1007,6 +1007,17 @@ class TradingOrchestrator:
if enhanced_features is not None: if enhanced_features is not None:
# Get CNN prediction - use the actual underlying model # Get CNN prediction - use the actual underlying model
try: try:
# Ensure features are properly shaped and limited
if isinstance(enhanced_features, np.ndarray):
# Flatten and limit features to prevent shape mismatches
enhanced_features = enhanced_features.flatten()
if len(enhanced_features) > 100: # Limit to 100 features
enhanced_features = enhanced_features[:100]
elif len(enhanced_features) < 100: # Pad with zeros
padded = np.zeros(100)
padded[:len(enhanced_features)] = enhanced_features
enhanced_features = padded
if hasattr(model.model, 'act'): if hasattr(model.model, 'act'):
# Use the CNN's act method # Use the CNN's act method
action_result = model.model.act(enhanced_features, explore=False) action_result = model.model.act(enhanced_features, explore=False)
@ -1138,6 +1149,17 @@ class TradingOrchestrator:
) )
if feature_matrix is not None: if feature_matrix is not None:
# Ensure feature_matrix is properly shaped and limited
if isinstance(feature_matrix, np.ndarray):
# Flatten and limit features to prevent shape mismatches
feature_matrix = feature_matrix.flatten()
if len(feature_matrix) > 2000: # Limit to 2000 features for generic models
feature_matrix = feature_matrix[:2000]
elif len(feature_matrix) < 2000: # Pad with zeros
padded = np.zeros(2000)
padded[:len(feature_matrix)] = feature_matrix
feature_matrix = padded
prediction_result = model.predict(feature_matrix) prediction_result = model.predict(feature_matrix)
# Handle different return formats from model.predict() # Handle different return formats from model.predict()
@ -1834,3 +1856,100 @@ class TradingOrchestrator:
"""Set the trading executor for position tracking""" """Set the trading executor for position tracking"""
self.trading_executor = trading_executor self.trading_executor = trading_executor
logger.info("Trading executor set for position tracking and P&L feedback") logger.info("Trading executor set for position tracking and P&L feedback")
def _get_current_price(self, symbol: str) -> float:
"""Get current price for symbol"""
try:
# Try to get from data provider
if self.data_provider:
try:
# Try different methods to get current price
if hasattr(self.data_provider, 'get_latest_data'):
latest_data = self.data_provider.get_latest_data(symbol)
if latest_data and 'price' in latest_data:
return float(latest_data['price'])
elif latest_data and 'close' in latest_data:
return float(latest_data['close'])
elif hasattr(self.data_provider, 'get_current_price'):
return float(self.data_provider.get_current_price(symbol))
elif hasattr(self.data_provider, 'get_latest_candle'):
latest_candle = self.data_provider.get_latest_candle(symbol, '1m')
if latest_candle and 'close' in latest_candle:
return float(latest_candle['close'])
except Exception as e:
logger.debug(f"Could not get price from data provider: {e}")
# Try to get from universal adapter
if self.universal_adapter:
try:
data_stream = self.universal_adapter.get_latest_data(symbol)
if data_stream and hasattr(data_stream, 'current_price'):
return float(data_stream.current_price)
except Exception as e:
logger.debug(f"Could not get price from universal adapter: {e}")
# Fallback to default prices
default_prices = {
'ETH/USDT': 2500.0,
'BTC/USDT': 108000.0
}
return default_prices.get(symbol, 1000.0)
except Exception as e:
logger.error(f"Error getting current price for {symbol}: {e}")
# Return default price based on symbol
if 'ETH' in symbol:
return 2500.0
elif 'BTC' in symbol:
return 108000.0
else:
return 1000.0
def _generate_fallback_prediction(self, symbol: str) -> Dict[str, Any]:
"""Generate fallback prediction when models fail"""
try:
return {
'action': 'HOLD',
'confidence': 0.5,
'price': self._get_current_price(symbol) or 2500.0,
'timestamp': datetime.now(),
'model': 'fallback'
}
except Exception as e:
logger.debug(f"Error generating fallback prediction: {e}")
return {
'action': 'HOLD',
'confidence': 0.5,
'price': 2500.0,
'timestamp': datetime.now(),
'model': 'fallback'
}
def capture_dqn_prediction(self, symbol: str, action_idx: int, confidence: float, price: float, q_values: List[float] = None):
    """Record a DQN prediction in the per-symbol history used by the dashboard.

    Args:
        symbol: Trading pair the prediction refers to.
        action_idx: Index into ``('SELL', 'HOLD', 'BUY')``. Values outside
            0..2, including negative ones, are rejected rather than allowed
            to wrap around to a different action label.
        confidence: Confidence value stored with the prediction.
        price: Price stored with the prediction.
        q_values: Optional Q-values; any sequence or array-like is accepted.
            When missing or empty, ``[0.33, 0.33, 0.34]`` is stored.

    Returns:
        None. Capturing is best-effort: failures are logged, never raised.
    """
    action_labels = ('SELL', 'HOLD', 'BUY')
    try:
        if not 0 <= action_idx < len(action_labels):
            logger.warning(f"Ignoring DQN prediction for {symbol}: invalid action index {action_idx}")
            return

        # Explicit None/length checks: truth-testing an array-like raises.
        if q_values is None or len(q_values) == 0:
            stored_q_values = [0.33, 0.33, 0.34]
        elif hasattr(q_values, 'tolist'):
            stored_q_values = q_values.tolist()
        else:
            stored_q_values = list(q_values)

        if symbol not in self.recent_dqn_predictions:
            self.recent_dqn_predictions[symbol] = deque(maxlen=100)

        prediction_data = {
            'timestamp': datetime.now(),
            'action': action_labels[action_idx],
            'confidence': confidence,
            'price': price,
            'q_values': stored_q_values
        }
        self.recent_dqn_predictions[symbol].append(prediction_data)
    except Exception as e:
        logger.debug(f"Error capturing DQN prediction: {e}")
def capture_cnn_prediction(self, symbol: str, direction: int, confidence: float, current_price: float, predicted_price: float):
    """Record a CNN prediction in the per-symbol history used by the dashboard.

    Args:
        symbol: Trading pair the prediction refers to.
        direction: Index into ``('DOWN', 'SAME', 'UP')``. Values outside
            0..2, including negative ones, are rejected rather than allowed
            to wrap around to a different direction label.
        confidence: Confidence value stored with the prediction.
        current_price: Price at prediction time.
        predicted_price: Price the model predicts.

    Returns:
        None. Capturing is best-effort: failures are logged, never raised.
    """
    direction_labels = ('DOWN', 'SAME', 'UP')
    try:
        if not 0 <= direction < len(direction_labels):
            logger.warning(f"Ignoring CNN prediction for {symbol}: invalid direction index {direction}")
            return

        if symbol not in self.recent_cnn_predictions:
            self.recent_cnn_predictions[symbol] = deque(maxlen=50)

        prediction_data = {
            'timestamp': datetime.now(),
            'direction': direction_labels[direction],
            'confidence': confidence,
            'current_price': current_price,
            'predicted_price': predicted_price
        }
        self.recent_cnn_predictions[symbol].append(prediction_data)
    except Exception as e:
        logger.debug(f"Error capturing CNN prediction: {e}")

View File

@ -114,8 +114,13 @@ class TradingExecutor:
# Thread safety # Thread safety
self.lock = Lock() self.lock = Lock()
# Connect to exchange # Connect to exchange - skip connection check in simulation mode
if self.trading_enabled: if self.trading_enabled:
if self.simulation_mode:
logger.info("TRADING EXECUTOR: Simulation mode - skipping exchange connection check")
# In simulation mode, we don't need a real exchange connection
# Trading should remain enabled for simulation trades
else:
logger.info("TRADING EXECUTOR: Attempting to connect to exchange...") logger.info("TRADING EXECUTOR: Attempting to connect to exchange...")
if not self._connect_exchange(): if not self._connect_exchange():
logger.error("TRADING EXECUTOR: Failed initial exchange connection. Trading will be disabled.") logger.error("TRADING EXECUTOR: Failed initial exchange connection. Trading will be disabled.")
@ -230,15 +235,25 @@ class TradingExecutor:
required_capital = self._calculate_position_size(confidence, current_price) required_capital = self._calculate_position_size(confidence, current_price)
# Get available balance for the quote asset # Get available balance for the quote asset
available_balance = self.exchange.get_balance(quote_asset) # For MEXC, prioritize USDT over USDC since most accounts have USDT
if quote_asset == 'USDC':
# If USDC balance is insufficient, check USDT as fallback (for MEXC compatibility) # Check USDT first (most common balance)
if available_balance < required_capital and quote_asset == 'USDC':
usdt_balance = self.exchange.get_balance('USDT') usdt_balance = self.exchange.get_balance('USDT')
usdc_balance = self.exchange.get_balance('USDC')
if usdt_balance >= required_capital: if usdt_balance >= required_capital:
available_balance = usdt_balance available_balance = usdt_balance
quote_asset = 'USDT' # Use USDT instead quote_asset = 'USDT' # Use USDT for trading
logger.info(f"BALANCE CHECK: Using USDT fallback balance for {symbol}") logger.info(f"BALANCE CHECK: Using USDT balance for {symbol} (preferred)")
elif usdc_balance >= required_capital:
available_balance = usdc_balance
logger.info(f"BALANCE CHECK: Using USDC balance for {symbol}")
else:
# Use the larger balance for reporting
available_balance = max(usdt_balance, usdc_balance)
quote_asset = 'USDT' if usdt_balance > usdc_balance else 'USDC'
else:
available_balance = self.exchange.get_balance(quote_asset)
logger.info(f"BALANCE CHECK: Symbol: {symbol}, Action: {action}, Required: ${required_capital:.2f} {quote_asset}, Available: ${available_balance:.2f} {quote_asset}") logger.info(f"BALANCE CHECK: Symbol: {symbol}, Action: {action}, Required: ${required_capital:.2f} {quote_asset}, Available: ${available_balance:.2f} {quote_asset}")

View File

@ -229,9 +229,12 @@ class TrainingIntegration:
# Truncate # Truncate
features = features[:50] features = features[:50]
# Get the model's device to ensure tensors are on the same device
model_device = next(cnn_model.parameters()).device
# Create tensors # Create tensors
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device) features_tensor = torch.FloatTensor(features).unsqueeze(0).to(model_device)
target_tensor = torch.LongTensor([target]).to(device) target_tensor = torch.LongTensor([target]).to(model_device)
# Training step # Training step
cnn_model.train() cnn_model.train()

View File

@ -1454,9 +1454,10 @@ class EnhancedRealtimeTrainingSystem:
model.train() model.train()
optimizer.zero_grad() optimizer.zero_grad()
# Convert numpy arrays to PyTorch tensors # Convert numpy arrays to PyTorch tensors and move to device
features_tensor = torch.from_numpy(features).float() device = next(model.parameters()).device
targets_tensor = torch.from_numpy(targets).long() features_tensor = torch.from_numpy(features).float().to(device)
targets_tensor = torch.from_numpy(targets).long().to(device)
# Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width) # Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width)
# Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20) # Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20)
@ -1471,10 +1472,37 @@ class EnhancedRealtimeTrainingSystem:
# If the CNN expects (batch_size, channels, sequence_length) # If the CNN expects (batch_size, channels, sequence_length)
# features_tensor = features_tensor.view(features_tensor.shape[0], 1, 15 * 20) # Example for 1D CNN # features_tensor = features_tensor.view(features_tensor.shape[0], 1, 15 * 20) # Example for 1D CNN
# Let's assume the CNN expects 2D input (batch_size, flattened_features) # Ensure proper shape for CNN input
if len(features_tensor.shape) == 2:
# If it's (batch_size, features), keep as is for 1D CNN
pass
elif len(features_tensor.shape) == 1:
# If it's (features), add batch dimension
features_tensor = features_tensor.unsqueeze(0)
else:
# Reshape to (batch_size, features) if needed
features_tensor = features_tensor.view(features_tensor.shape[0], -1)
# Limit input size to prevent shape mismatches
if features_tensor.shape[1] > 1000: # Limit to 1000 features
features_tensor = features_tensor[:, :1000]
outputs = model(features_tensor) outputs = model(features_tensor)
loss = criterion(outputs, targets_tensor) # Extract logits from model output (model returns a dictionary)
if isinstance(outputs, dict):
logits = outputs['logits']
elif isinstance(outputs, tuple):
logits = outputs[0] # First element is usually logits
else:
logits = outputs
# Ensure logits is a tensor
if not isinstance(logits, torch.Tensor):
logger.error(f"CNN output is not a tensor: {type(logits)}")
return 0.0
loss = criterion(logits, targets_tensor)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@ -1856,13 +1884,19 @@ class EnhancedRealtimeTrainingSystem:
if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent') if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
and self.orchestrator.rl_agent): and self.orchestrator.rl_agent):
# Get Q-values from model # Use RL agent to make prediction
q_values = self.orchestrator.rl_agent.act(current_state, return_q_values=True) current_state = self._get_dqn_state(symbol)
if isinstance(q_values, tuple): if current_state is None:
action, q_vals = q_values return
q_values = q_vals.tolist() if hasattr(q_vals, 'tolist') else [0, 0, 0] action = self.orchestrator.rl_agent.act(current_state, explore=False)
# Get Q-values separately if available
if hasattr(self.orchestrator.rl_agent, 'policy_net'):
with torch.no_grad():
state_tensor = torch.FloatTensor(current_state).unsqueeze(0).to(self.orchestrator.rl_agent.device)
q_values_tensor = self.orchestrator.rl_agent.policy_net(state_tensor)
if isinstance(q_values_tensor, tuple):
q_values = q_values_tensor[0].cpu().numpy()[0].tolist()
else: else:
action = q_values
q_values = [0.33, 0.33, 0.34] # Default uniform distribution q_values = [0.33, 0.33, 0.34] # Default uniform distribution
confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33 confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33

View File

@ -1,201 +1,121 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Run Clean Trading Dashboard with Full Training Pipeline Clean Trading Dashboard Runner with Enhanced Stability and Error Handling
Integrated system with both training loop and clean web dashboard
""" """
import os
# Fix OpenMP library conflicts before importing other modules
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
os.environ['OMP_NUM_THREADS'] = '4'
import asyncio
import logging
import sys import sys
import threading import logging
import traceback
import gc
import time import time
import psutil
import torch
from pathlib import Path from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import get_config, setup_logging
from core.data_provider import DataProvider
# Import checkpoint management
from utils.checkpoint_manager import get_checkpoint_manager
from utils.training_integration import get_training_integration
# Setup logging # Setup logging
setup_logging() logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def start_training_pipeline(orchestrator, trading_executor): def clear_gpu_memory():
"""Start the training pipeline in the background""" """Clear GPU memory cache"""
logger.info("=" * 70) if torch.cuda.is_available():
logger.info("STARTING TRAINING PIPELINE WITH CLEAN DASHBOARD") torch.cuda.empty_cache()
logger.info("=" * 70) torch.cuda.synchronize()
# Initialize checkpoint management def check_system_resources():
checkpoint_manager = get_checkpoint_manager() """Check if system has enough resources"""
training_integration = get_training_integration() available_ram = psutil.virtual_memory().available / 1024**3
if available_ram < 2.0: # Less than 2GB available
logger.warning(f"Low RAM: {available_ram:.1f} GB available")
gc.collect()
clear_gpu_memory()
return False
return True
# Training statistics def run_dashboard_with_recovery():
training_stats = { """Run dashboard with automatic error recovery"""
'iteration_count': 0, max_retries = 3
'total_decisions': 0, retry_count = 0
'successful_trades': 0,
'best_performance': 0.0,
'last_checkpoint_iteration': 0
}
while retry_count < max_retries:
try: try:
# Start real-time processing (available in Enhanced orchestrator) logger.info(f"Starting Clean Trading Dashboard (attempt {retry_count + 1}/{max_retries})")
if hasattr(orchestrator, 'start_realtime_processing'):
await orchestrator.start_realtime_processing()
logger.info("Real-time processing started")
# Start COB integration (available in Enhanced orchestrator) # Check system resources
if hasattr(orchestrator, 'start_cob_integration'): if not check_system_resources():
await orchestrator.start_cob_integration() logger.warning("System resources low, waiting 30 seconds...")
logger.info("COB integration started - 5-minute data matrix active") time.sleep(30)
else: continue
logger.info("COB integration not available")
# Main training loop # Import here to avoid memory issues on restart
iteration = 0
last_checkpoint_time = time.time()
while True:
try:
iteration += 1
training_stats['iteration_count'] = iteration
# Get symbols to process
symbols = orchestrator.symbols if hasattr(orchestrator, 'symbols') else ['ETH/USDT']
# Process each symbol
for symbol in symbols:
try:
# Make trading decision (this triggers model training)
decision = await orchestrator.make_trading_decision(symbol)
if decision:
training_stats['total_decisions'] += 1
logger.debug(f"[{symbol}] Decision: {decision.action} @ {decision.confidence:.1%}")
except Exception as e:
logger.warning(f"Error processing {symbol}: {e}")
# Status logging every 100 iterations
if iteration % 100 == 0:
current_time = time.time()
elapsed = current_time - last_checkpoint_time
logger.info(f"[TRAINING] Iteration {iteration}, Decisions: {training_stats['total_decisions']}, Time: {elapsed:.1f}s")
# Models will save their own checkpoints when performance improves
training_stats['last_checkpoint_iteration'] = iteration
last_checkpoint_time = current_time
# Brief pause to prevent overwhelming the system
await asyncio.sleep(0.1) # 100ms between iterations
except Exception as e:
logger.error(f"Training loop error: {e}")
await asyncio.sleep(5) # Wait longer on error
except Exception as e:
logger.error(f"Training pipeline error: {e}")
import traceback
logger.error(traceback.format_exc())
def start_clean_dashboard_with_training():
"""Start clean dashboard with full training pipeline"""
try:
logger.info("=" * 80)
logger.info("CLEAN TRADING DASHBOARD + FULL TRAINING PIPELINE")
logger.info("=" * 80)
logger.info("Features: Real-time Training, COB Integration, Clean UI")
logger.info("Universal Data Stream: ENABLED")
logger.info("Neural Decision Fusion: ENABLED")
logger.info("COB Integration: ENABLED")
logger.info("GPU Training: ENABLED")
logger.info("Multi-symbol: ETH/USDT, BTC/USDT")
# Get port from environment or use default
dashboard_port = int(os.environ.get('DASHBOARD_PORT', '8051'))
logger.info(f"Dashboard: http://127.0.0.1:{dashboard_port}")
logger.info("=" * 80)
# Check environment variables
enable_universal_stream = os.environ.get('ENABLE_UNIVERSAL_DATA_STREAM', '1') == '1'
enable_nn_fusion = os.environ.get('ENABLE_NN_DECISION_FUSION', '1') == '1'
enable_cob = os.environ.get('ENABLE_COB_INTEGRATION', '1') == '1'
logger.info(f"Universal Data Stream: {'ENABLED' if enable_universal_stream else 'DISABLED'}")
logger.info(f"Neural Decision Fusion: {'ENABLED' if enable_nn_fusion else 'DISABLED'}")
logger.info(f"COB Integration: {'ENABLED' if enable_cob else 'DISABLED'}")
# Get configuration
config = get_config()
# Initialize core components
from core.data_provider import DataProvider from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator from core.orchestrator import TradingOrchestrator
from core.trading_executor import TradingExecutor from core.trading_executor import TradingExecutor
# Create data provider
data_provider = DataProvider()
# Create enhanced orchestrator with COB integration - stable and efficient
orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
logger.info("Enhanced Trading Orchestrator created with COB integration")
# Create trading executor
trading_executor = TradingExecutor()
# Import clean dashboard
from web.clean_dashboard import create_clean_dashboard from web.clean_dashboard import create_clean_dashboard
# Create clean dashboard logger.info("Creating data provider...")
dashboard = create_clean_dashboard( data_provider = DataProvider()
logger.info("Creating trading orchestrator...")
orchestrator = TradingOrchestrator(
data_provider=data_provider, data_provider=data_provider,
orchestrator=orchestrator, enhanced_rl_training=True
trading_executor=trading_executor
) )
logger.info("Clean Trading Dashboard created")
# Start training pipeline in background thread logger.info("Creating trading executor...")
def training_worker(): trading_executor = TradingExecutor()
"""Run training pipeline in background"""
logger.info("Creating clean dashboard...")
dashboard = create_clean_dashboard(data_provider, orchestrator, trading_executor)
logger.info("Dashboard created successfully")
logger.info("=== Clean Trading Dashboard Status ===")
logger.info("- Data Provider: Active")
logger.info("- Trading Orchestrator: Active")
logger.info("- Trading Executor: Active")
logger.info("- Enhanced Training: Active")
logger.info("- Dashboard: Ready")
logger.info("=======================================")
# Start the dashboard server with error handling
try: try:
asyncio.run(start_training_pipeline(orchestrator, trading_executor)) logger.info("Starting dashboard server on http://127.0.0.1:8050")
except Exception as e: dashboard.run_server(host='127.0.0.1', port=8050, debug=False)
logger.error(f"Training worker error: {e}")
training_thread = threading.Thread(target=training_worker, daemon=True)
training_thread.start()
logger.info("Training pipeline started in background")
# Wait a moment for training to initialize
time.sleep(3)
# Start dashboard server (this blocks)
logger.info(" Starting Clean Dashboard Server...")
dashboard.run_server(host='127.0.0.1', port=dashboard_port, debug=False)
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("System stopped by user") logger.info("Dashboard stopped by user")
break
except Exception as e: except Exception as e:
logger.error(f"Error running clean dashboard with training: {e}") logger.error(f"Dashboard server error: {e}")
import traceback logger.error(traceback.format_exc())
traceback.print_exc() raise
except Exception as e:
logger.error(f"Critical error in dashboard: {e}")
logger.error(traceback.format_exc())
retry_count += 1
if retry_count < max_retries:
logger.info(f"Attempting recovery... ({retry_count}/{max_retries})")
# Cleanup
gc.collect()
clear_gpu_memory()
# Wait before retry
wait_time = 30 * retry_count # Exponential backoff
logger.info(f"Waiting {wait_time} seconds before retry...")
time.sleep(wait_time)
else:
logger.error("Max retries reached. Exiting.")
sys.exit(1) sys.exit(1)
def main():
    """Script entry point: launch the clean dashboard with its training pipeline."""
    # All setup, logging and error handling live in the callee.
    start_clean_dashboard_with_training()
if __name__ == "__main__": if __name__ == "__main__":
main() try:
run_dashboard_with_recovery()
except KeyboardInterrupt:
logger.info("Application stopped by user")
sys.exit(0)
except Exception as e:
logger.error(f"Fatal error: {e}")
logger.error(traceback.format_exc())
sys.exit(1)

View File

@ -205,6 +205,9 @@ class CleanTradingDashboard:
# Start signal generation loop to ensure continuous trading signals # Start signal generation loop to ensure continuous trading signals
self._start_signal_generation_loop() self._start_signal_generation_loop()
# Start live balance sync for trading
self._start_live_balance_sync()
# Start training sessions if models are showing FRESH status # Start training sessions if models are showing FRESH status
threading.Thread(target=self._delayed_training_check, daemon=True).start() threading.Thread(target=self._delayed_training_check, daemon=True).start()
@ -319,6 +322,66 @@ class CleanTradingDashboard:
logger.warning(f"Error getting balance: {e}") logger.warning(f"Error getting balance: {e}")
return 100.0 # Default balance return 100.0 # Default balance
def _get_live_balance(self) -> float:
    """Return the balance the dashboard should display.

    Behaviour by mode:

    * Live trading (executor has ``trading_enabled`` set, ``simulation_mode``
      unset and an ``exchange`` attribute): ask the exchange for the USDC
      balance, falling back to USDT. A positive result is cached on the
      instance. The exchange is queried at most once every 5 seconds after
      a successful lookup; in between, the cached value is returned.
    * Simulation mode: initial balance plus the summed ``'pnl'`` of
      ``self.closed_trades``.
    * Anything else, or when no positive balance could be obtained:
      ``self._get_initial_balance()``.

    Returns:
        The balance as a float.

    Errors raised while determining the balance are logged; the cached live
    balance is returned if one exists, otherwise the initial balance.
    """
    try:
        executor = self.trading_executor
        if executor:
            # Live means: trading switched on AND explicitly not simulating.
            is_live = (hasattr(executor, 'trading_enabled') and
                       executor.trading_enabled and
                       hasattr(executor, 'simulation_mode') and
                       not executor.simulation_mode)

            if is_live and hasattr(executor, 'exchange'):
                # Get real balance from exchange (throttled to avoid API spam)
                import time
                current_time = time.time()

                # Single getattr read: the balance sync worker may reset or
                # remove this attribute from another thread, so a separate
                # hasattr() check followed by a read could race.
                last_check = getattr(self, '_last_balance_check', None)

                # Re-query the exchange once the 5 second cache window expired
                if last_check is None or current_time - last_check > 5:
                    exchange = executor.exchange
                    if hasattr(exchange, 'get_balance'):
                        live_balance = exchange.get_balance('USDC')
                        if live_balance is not None and live_balance > 0:
                            self._cached_live_balance = live_balance
                            self._last_balance_check = current_time
                            logger.info(f"LIVE BALANCE: Retrieved ${live_balance:.2f} USDC from MEXC")
                            return live_balance

                        # A None balance cannot be rendered with ':.2f'; formatting
                        # it would raise and skip the USDT fallback entirely.
                        if live_balance is None:
                            logger.warning("LIVE BALANCE: USDC balance unavailable - checking USDT as fallback")
                        else:
                            logger.warning(f"LIVE BALANCE: Retrieved ${live_balance:.2f} USDC - checking USDT as fallback")

                        # Also try USDT as fallback since user might have USDT
                        usdt_balance = exchange.get_balance('USDT')
                        if usdt_balance is not None and usdt_balance > 0:
                            self._cached_live_balance = usdt_balance
                            self._last_balance_check = current_time
                            logger.info(f"LIVE BALANCE: Using USDT balance ${usdt_balance:.2f}")
                            return usdt_balance
                    else:
                        logger.warning("LIVE BALANCE: Exchange does not have get_balance method")
                else:
                    # Still inside the 5 second cache window: reuse the cached value
                    cached_balance = getattr(self, '_cached_live_balance', None)
                    if cached_balance is not None:
                        return cached_balance
            elif hasattr(executor, 'simulation_mode') and executor.simulation_mode:
                # In simulation mode, show dynamic balance based on P&L
                initial_balance = self._get_initial_balance()
                realized_pnl = sum(trade.get('pnl', 0) for trade in self.closed_trades)
                simulation_balance = initial_balance + realized_pnl
                logger.debug(f"SIMULATION BALANCE: ${simulation_balance:.2f} (Initial: ${initial_balance:.2f} + P&L: ${realized_pnl:.2f})")
                return simulation_balance
            else:
                logger.debug("LIVE BALANCE: Not in live trading mode, using initial balance")

        # Fallback: no executor, not live, or no positive balance retrieved
        return self._get_initial_balance()

    except Exception as e:
        logger.error(f"Error getting live balance: {e}")
        # Return cached balance if available, otherwise fallback
        cached_balance = getattr(self, '_cached_live_balance', None)
        if cached_balance is not None:
            return cached_balance
        return self._get_initial_balance()
def _setup_layout(self): def _setup_layout(self):
"""Setup the dashboard layout using layout manager""" """Setup the dashboard layout using layout manager"""
self.app.layout = self.layout_manager.create_main_layout() self.app.layout = self.layout_manager.create_main_layout()
@ -411,17 +474,48 @@ class CleanTradingDashboard:
trade_count = len(self.closed_trades) trade_count = len(self.closed_trades)
trade_str = f"{trade_count} Trades" trade_str = f"{trade_count} Trades"
# Portfolio value # Portfolio value - use live balance for live trading
initial_balance = self._get_initial_balance() current_balance = self._get_live_balance()
portfolio_value = initial_balance + total_session_pnl # Use total P&L including unrealized portfolio_value = current_balance + total_session_pnl # Use total P&L including unrealized
portfolio_str = f"${portfolio_value:.2f}"
# MEXC status # Show live balance indicator for live trading
balance_indicator = ""
if self.trading_executor:
is_live = (hasattr(self.trading_executor, 'trading_enabled') and
self.trading_executor.trading_enabled and
hasattr(self.trading_executor, 'simulation_mode') and
not self.trading_executor.simulation_mode)
if is_live:
balance_indicator = " (LIVE)"
portfolio_str = f"${portfolio_value:.2f}{balance_indicator}"
# MEXC status with balance info
mexc_status = "SIM" mexc_status = "SIM"
if self.trading_executor: if self.trading_executor:
if hasattr(self.trading_executor, 'trading_enabled') and self.trading_executor.trading_enabled: if hasattr(self.trading_executor, 'trading_enabled') and self.trading_executor.trading_enabled:
if hasattr(self.trading_executor, 'simulation_mode') and not self.trading_executor.simulation_mode: if hasattr(self.trading_executor, 'simulation_mode') and self.trading_executor.simulation_mode:
mexc_status = "LIVE" # Show simulation mode status with simulated balance
mexc_status = f"SIM - ${current_balance:.2f}"
elif hasattr(self.trading_executor, 'simulation_mode') and not self.trading_executor.simulation_mode:
# Show live balance in MEXC status - detect currency
try:
exchange = self.trading_executor.exchange
usdc_balance = exchange.get_balance('USDC') if hasattr(exchange, 'get_balance') else 0
usdt_balance = exchange.get_balance('USDT') if hasattr(exchange, 'get_balance') else 0
if usdc_balance > 0:
mexc_status = f"LIVE - ${usdc_balance:.2f} USDC"
elif usdt_balance > 0:
mexc_status = f"LIVE - ${usdt_balance:.2f} USDT"
else:
mexc_status = f"LIVE - ${current_balance:.2f}"
except:
mexc_status = f"LIVE - ${current_balance:.2f}"
else:
mexc_status = "SIM"
else:
mexc_status = "DISABLED"
return price_str, session_pnl_str, position_str, trade_str, portfolio_str, mexc_status return price_str, session_pnl_str, position_str, trade_str, portfolio_str, mexc_status
@ -2877,6 +2971,39 @@ class CleanTradingDashboard:
except Exception as e: except Exception as e:
logger.error(f"Error starting signal generation loop: {e}") logger.error(f"Error starting signal generation loop: {e}")
def _start_live_balance_sync(self):
    """Start a daemon thread that periodically refreshes the live balance.

    The worker loops forever. On each pass, if the trading executor is in
    live mode (``trading_enabled`` set, ``simulation_mode`` unset, and an
    ``exchange`` attribute present), it expires the balance throttle and
    calls ``self._get_live_balance()`` so the cached balance is refreshed.
    It then sleeps 15 seconds, or 30 seconds after an error.

    The thread is started only when ``self.trading_executor`` is truthy;
    otherwise this method does nothing.
    """
    def balance_sync_worker():
        while True:
            try:
                executor = self.trading_executor
                if executor:
                    is_live = (hasattr(executor, 'trading_enabled') and
                               executor.trading_enabled and
                               hasattr(executor, 'simulation_mode') and
                               not executor.simulation_mode)

                    if is_live and hasattr(executor, 'exchange'):
                        # Force a refresh by expiring the throttle timestamp.
                        # Resetting it (instead of deleting the attribute)
                        # keeps the attribute present for any other thread
                        # that is reading it at the same moment.
                        self._last_balance_check = 0.0
                        balance = self._get_live_balance()
                        if balance > 0:
                            logger.debug(f"BALANCE SYNC: Live balance: ${balance:.2f}")
                        else:
                            logger.warning("BALANCE SYNC: Could not retrieve live balance")

                # Sleep on every pass (live or not) so the loop never spins
                time.sleep(15)

            except Exception as e:
                # Logged at warning level so a persistently failing sync is visible
                logger.warning(f"Error in balance sync loop: {e}")
                time.sleep(30)  # Wait longer on error

    # Start the sync thread only if a trading executor is configured
    if self.trading_executor:
        threading.Thread(target=balance_sync_worker, daemon=True).start()
        logger.info("BALANCE SYNC: Background balance synchronization started")
def _generate_dqn_signal(self, symbol: str, current_price: float) -> Optional[Dict]: def _generate_dqn_signal(self, symbol: str, current_price: float) -> Optional[Dict]:
"""Generate trading signal using DQN agent - NOT AVAILABLE IN BASIC ORCHESTRATOR""" """Generate trading signal using DQN agent - NOT AVAILABLE IN BASIC ORCHESTRATOR"""
# Basic orchestrator doesn't have DQN features # Basic orchestrator doesn't have DQN features
@ -4519,28 +4646,35 @@ class CleanTradingDashboard:
imbalance = cob_snapshot['stats']['imbalance'] imbalance = cob_snapshot['stats']['imbalance']
abs_imbalance = abs(imbalance) abs_imbalance = abs(imbalance)
# Dynamic threshold based on imbalance strength # Dynamic threshold based on imbalance strength with realistic confidence
if abs_imbalance > 0.8: # Very strong imbalance (>80%) if abs_imbalance > 0.8: # Very strong imbalance (>80%)
threshold = 0.05 # 5% threshold for very strong signals threshold = 0.05 # 5% threshold for very strong signals
confidence_multiplier = 3.0 base_confidence = 0.85 # High but not perfect confidence
confidence_boost = (abs_imbalance - 0.8) * 0.75 # Scale remaining 15%
elif abs_imbalance > 0.5: # Strong imbalance (>50%) elif abs_imbalance > 0.5: # Strong imbalance (>50%)
threshold = 0.1 # 10% threshold for strong signals threshold = 0.1 # 10% threshold for strong signals
confidence_multiplier = 2.5 base_confidence = 0.70 # Good confidence
confidence_boost = (abs_imbalance - 0.5) * 0.50 # Scale up to 85%
elif abs_imbalance > 0.3: # Moderate imbalance (>30%) elif abs_imbalance > 0.3: # Moderate imbalance (>30%)
threshold = 0.15 # 15% threshold for moderate signals threshold = 0.15 # 15% threshold for moderate signals
confidence_multiplier = 2.0 base_confidence = 0.55 # Moderate confidence
confidence_boost = (abs_imbalance - 0.3) * 0.75 # Scale up to 70%
else: # Weak imbalance else: # Weak imbalance
threshold = 0.2 # 20% threshold for weak signals threshold = 0.2 # 20% threshold for weak signals
confidence_multiplier = 1.5 base_confidence = 0.35 # Low confidence
confidence_boost = abs_imbalance * 0.67 # Scale up to 55%
# Generate signal if imbalance exceeds threshold # Generate signal if imbalance exceeds threshold
if abs_imbalance > threshold: if abs_imbalance > threshold:
# Calculate more realistic confidence (never exactly 1.0)
final_confidence = min(0.95, base_confidence + confidence_boost)
signal = { signal = {
'timestamp': datetime.now(), 'timestamp': datetime.now(),
'type': 'cob_liquidity_imbalance', 'type': 'cob_liquidity_imbalance',
'action': 'BUY' if imbalance > 0 else 'SELL', 'action': 'BUY' if imbalance > 0 else 'SELL',
'symbol': symbol, 'symbol': symbol,
'confidence': min(1.0, abs_imbalance * confidence_multiplier), 'confidence': final_confidence,
'strength': abs_imbalance, 'strength': abs_imbalance,
'threshold_used': threshold, 'threshold_used': threshold,
'signal_strength': 'very_strong' if abs_imbalance > 0.8 else 'strong' if abs_imbalance > 0.5 else 'moderate' if abs_imbalance > 0.3 else 'weak', 'signal_strength': 'very_strong' if abs_imbalance > 0.8 else 'strong' if abs_imbalance > 0.5 else 'moderate' if abs_imbalance > 0.3 else 'weak',
@ -5478,15 +5612,18 @@ class CleanTradingDashboard:
import torch import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Get the model's device to ensure tensors are on the same device
model_device = next(model.parameters()).device
# Handle different input shapes for different CNN models # Handle different input shapes for different CNN models
if hasattr(model, 'input_shape'): if hasattr(model, 'input_shape'):
# EnhancedCNN model # EnhancedCNN model
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device) features_tensor = torch.FloatTensor(features).unsqueeze(0).to(model_device)
else: else:
# Basic CNN model - reshape appropriately # Basic CNN model - reshape appropriately
features_tensor = torch.FloatTensor(features).unsqueeze(0).unsqueeze(0).to(device) features_tensor = torch.FloatTensor(features).unsqueeze(0).unsqueeze(0).to(model_device)
target_tensor = torch.LongTensor([target]).to(device) target_tensor = torch.LongTensor([target]).to(model_device)
# Set model to training mode and zero gradients # Set model to training mode and zero gradients
model.train() model.train()
@ -5605,10 +5742,11 @@ class CleanTradingDashboard:
if hasattr(network, 'forward'): if hasattr(network, 'forward'):
import torch import torch
import torch.nn as nn import torch.nn as nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Get the model's device to ensure tensors are on the same device
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device) model_device = next(network.parameters()).device
action_target_tensor = torch.LongTensor([action_target]).to(device) features_tensor = torch.FloatTensor(features).unsqueeze(0).to(model_device)
confidence_target_tensor = torch.FloatTensor([confidence_target]).to(device) action_target_tensor = torch.LongTensor([action_target]).to(model_device)
confidence_target_tensor = torch.FloatTensor([confidence_target]).to(model_device)
network.train() network.train()
network_output = network(features_tensor) network_output = network(features_tensor)