6 Commits

Author SHA1 Message Date
Dobromir Popov
6c91bf0b93 fix sim and wip fix live 2025-07-08 02:47:10 +03:00
Dobromir Popov
64678bd8d3 more live trades fix 2025-07-08 02:03:32 +03:00
Dobromir Popov
4ab7bc1846 tweaks, try live trading 2025-07-08 01:33:22 +03:00
Dobromir Popov
9cd2d5d8a4 fixes 2025-07-07 23:39:12 +03:00
Dobromir Popov
2d8f763eeb improve training and model data 2025-07-07 15:48:25 +03:00
Dobromir Popov
271e7d59b5 fixed cob 2025-07-07 01:44:16 +03:00
24 changed files with 2724 additions and 1985 deletions

Binary file not shown.

View File

@@ -5,6 +5,7 @@ import requests
import hmac
import hashlib
from urllib.parse import urlencode, quote_plus
import json # Added for json.dumps
from .exchange_interface import ExchangeInterface
@@ -85,37 +86,40 @@ class MEXCInterface(ExchangeInterface):
return symbol.replace('/', '_').upper()
def _generate_signature(self, timestamp: str, method: str, endpoint: str, params: Dict[str, Any]) -> str:
"""Generate signature for private API calls using MEXC's expected parameter order"""
# MEXC requires specific parameter ordering, not alphabetical
# Based on successful test: symbol, side, type, quantity, timestamp, then other params
mexc_param_order = ['symbol', 'side', 'type', 'quantity', 'timestamp', 'recvWindow']
# Build ordered parameter list
ordered_params = []
# Add parameters in MEXC's expected order
for param_name in mexc_param_order:
if param_name in params and param_name != 'signature':
ordered_params.append(f"{param_name}={params[param_name]}")
# Add any remaining parameters not in the standard order (alphabetically)
remaining_params = {k: v for k, v in params.items() if k not in mexc_param_order and k != 'signature'}
for key in sorted(remaining_params.keys()):
ordered_params.append(f"{key}={remaining_params[key]}")
# Create query string (MEXC doesn't use the api_key + timestamp prefix)
query_string = '&'.join(ordered_params)
logger.debug(f"MEXC signature query string: {query_string}")
"""Generate signature for private API calls using MEXC's official method"""
# MEXC signature format varies by method:
# For GET/DELETE: URL-encoded query string of alphabetically sorted parameters.
# For POST: JSON string of parameters (no sorting needed).
# The API-Secret is used as the HMAC SHA256 key.
# Remove signature from params to avoid circular inclusion
clean_params = {k: v for k, v in params.items() if k != 'signature'}
parameter_string: str
if method.upper() == "POST":
# For POST requests, the signature parameter is a JSON string
# Ensure sorting keys for consistent JSON string generation across runs
# even though MEXC says sorting is not required for POST params, it's good practice.
parameter_string = json.dumps(clean_params, sort_keys=True, separators=(',', ':'))
else:
# For GET/DELETE requests, parameters are spliced in dictionary order with & interval
sorted_params = sorted(clean_params.items())
parameter_string = '&'.join(f"{key}={str(value)}" for key, value in sorted_params)
# The string to be signed is: accessKey + timestamp + obtained parameter string.
string_to_sign = f"{self.api_key}{timestamp}{parameter_string}"
logger.debug(f"MEXC string to sign (method {method}): {string_to_sign}")
# Generate HMAC SHA256 signature
signature = hmac.new(
self.api_secret.encode('utf-8'),
query_string.encode('utf-8'),
string_to_sign.encode('utf-8'),
hashlib.sha256
).hexdigest()
logger.debug(f"MEXC signature: {signature}")
logger.debug(f"MEXC generated signature: {signature}")
return signature
def _send_public_request(self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
@@ -145,7 +149,7 @@ class MEXCInterface(ExchangeInterface):
logger.error(f"Error in public request to {endpoint}: {e}")
return {}
def _send_private_request(self, method: str, endpoint: str, params: Dict[str, Any] = None) -> Optional[Dict[str, Any]]:
def _send_private_request(self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
"""Send a private request to the exchange with proper signature"""
if params is None:
params = {}
@@ -170,8 +174,11 @@ class MEXCInterface(ExchangeInterface):
if method.upper() == "GET":
response = self.session.get(url, headers=headers, params=params, timeout=10)
elif method.upper() == "POST":
# MEXC expects POST parameters as query string, not in body
response = self.session.post(url, headers=headers, params=params, timeout=10)
# MEXC expects POST parameters as JSON in the request body, not as query string
# The signature is generated from the JSON string of parameters.
# We need to exclude 'signature' from the JSON body sent, as it's for the header.
params_for_body = {k: v for k, v in params.items() if k != 'signature'}
response = self.session.post(url, headers=headers, json=params_for_body, timeout=10)
else:
logger.error(f"Unsupported method: {method}")
return None
@@ -217,48 +224,46 @@ class MEXCInterface(ExchangeInterface):
response = self._send_public_request('GET', endpoint, params)
if response:
# MEXC ticker returns a dictionary if single symbol, list if all symbols
if isinstance(response, dict):
ticker_data = response
elif isinstance(response, list) and len(response) > 0:
# If the response is a list, try to find the specific symbol
found_ticker = next((item for item in response if item.get('symbol') == formatted_symbol), None)
if found_ticker:
ticker_data = found_ticker
else:
logger.error(f"Ticker data for {formatted_symbol} not found in response list.")
return None
if isinstance(response, dict):
ticker_data: Dict[str, Any] = response
elif isinstance(response, list) and len(response) > 0:
found_ticker = next((item for item in response if item.get('symbol') == formatted_symbol), None)
if found_ticker:
ticker_data = found_ticker
else:
logger.error(f"Unexpected ticker response format: {response}")
logger.error(f"Ticker data for {formatted_symbol} not found in response list.")
return None
else:
logger.error(f"Unexpected ticker response format: {response}")
return None
# Extract relevant info and format for universal use
last_price = float(ticker_data.get('lastPrice', 0))
bid_price = float(ticker_data.get('bidPrice', 0))
ask_price = float(ticker_data.get('askPrice', 0))
volume = float(ticker_data.get('volume', 0)) # Base asset volume
# At this point, ticker_data is guaranteed to be a Dict[str, Any] due to the above logic
# If it was None, we would have returned early.
# Determine price change and percent change
price_change = float(ticker_data.get('priceChange', 0))
price_change_percent = float(ticker_data.get('priceChangePercent', 0))
# Extract relevant info and format for universal use
last_price = float(ticker_data.get('lastPrice', 0))
bid_price = float(ticker_data.get('bidPrice', 0))
ask_price = float(ticker_data.get('askPrice', 0))
volume = float(ticker_data.get('volume', 0)) # Base asset volume
logger.info(f"MEXC: Got ticker from {endpoint} for {symbol}: ${last_price:.2f}")
return {
'symbol': formatted_symbol,
'last': last_price,
'bid': bid_price,
'ask': ask_price,
'volume': volume,
'high': float(ticker_data.get('highPrice', 0)),
'low': float(ticker_data.get('lowPrice', 0)),
'change': price_change_percent, # This is usually priceChangePercent
'exchange': 'MEXC',
'raw_data': ticker_data
}
logger.error(f"Failed to get ticker for {symbol}")
return None
# Determine price change and percent change
price_change = float(ticker_data.get('priceChange', 0))
price_change_percent = float(ticker_data.get('priceChangePercent', 0))
logger.info(f"MEXC: Got ticker from {endpoint} for {symbol}: ${last_price:.2f}")
return {
'symbol': formatted_symbol,
'last': last_price,
'bid': bid_price,
'ask': ask_price,
'volume': volume,
'high': float(ticker_data.get('highPrice', 0)),
'low': float(ticker_data.get('lowPrice', 0)),
'change': price_change_percent, # This is usually priceChangePercent
'exchange': 'MEXC',
'raw_data': ticker_data
}
def get_api_symbols(self) -> List[str]:
"""Get list of symbols supported for API trading"""
@@ -293,39 +298,89 @@ class MEXCInterface(ExchangeInterface):
logger.info(f"Supported symbols include: {supported_symbols[:10]}...") # Show first 10
return {}
# Format quantity according to symbol precision requirements
formatted_quantity = self._format_quantity_for_symbol(formatted_symbol, quantity)
if formatted_quantity is None:
logger.error(f"MEXC: Failed to format quantity {quantity} for {formatted_symbol}")
return {}
# Handle order type restrictions for specific symbols
final_order_type = self._adjust_order_type_for_symbol(formatted_symbol, order_type.upper())
# Get price for limit orders
final_price = price
if final_order_type == 'LIMIT' and price is None:
# Get current market price
ticker = self.get_ticker(symbol)
if ticker and 'last' in ticker:
final_price = ticker['last']
logger.info(f"MEXC: Using market price ${final_price:.2f} for LIMIT order")
else:
logger.error(f"MEXC: Could not get market price for LIMIT order on {formatted_symbol}")
return {}
endpoint = "order"
params: Dict[str, Any] = {
'symbol': formatted_symbol,
'side': side.upper(),
'type': order_type.upper(),
'quantity': str(quantity) # Quantity must be a string
'type': final_order_type,
'quantity': str(formatted_quantity) # Quantity must be a string
}
if price is not None:
params['price'] = str(price) # Price must be a string for limit orders
if final_price is not None:
params['price'] = str(final_price) # Price must be a string for limit orders
logger.info(f"MEXC: Placing {side.upper()} {order_type.upper()} order for {quantity} {formatted_symbol} at price {price}")
# For market orders, some parameters might be optional or handled differently.
# Check MEXC API docs for market order specifics (e.g., quoteOrderQty for buy market orders)
if order_type.upper() == 'MARKET' and side.upper() == 'BUY':
# If it's a market buy order, MEXC often expects quoteOrderQty instead of quantity
# Assuming quantity here refers to the base asset, if quoteOrderQty is needed, adjust.
# For now, we will stick to quantity and let MEXC handle the conversion if possible
pass # No specific change needed based on the current params structure
logger.info(f"MEXC: Placing {side.upper()} {final_order_type} order for {formatted_quantity} {formatted_symbol} at price {final_price}")
try:
# MEXC API endpoint for placing orders is /api/v3/order (POST)
order_result = self._send_private_request('POST', endpoint, params)
if order_result:
if order_result is not None:
logger.info(f"MEXC: Order placed successfully: {order_result}")
return order_result
else:
logger.error(f"MEXC: Error placing order: {order_result}")
logger.error(f"MEXC: Error placing order: request returned None")
return {}
except Exception as e:
logger.error(f"MEXC: Exception placing order: {e}")
return {}
def _format_quantity_for_symbol(self, formatted_symbol: str, quantity: float) -> Optional[float]:
"""Format quantity according to symbol precision requirements"""
try:
# Symbol-specific precision rules
if formatted_symbol == 'ETHUSDC':
# ETHUSDC requires max 5 decimal places, step size 0.000001
formatted_qty = round(quantity, 5)
# Ensure it meets minimum step size
step_size = 0.000001
formatted_qty = round(formatted_qty / step_size) * step_size
# Round again to remove floating point errors
formatted_qty = round(formatted_qty, 6)
logger.info(f"MEXC: Formatted ETHUSDC quantity {quantity} -> {formatted_qty}")
return formatted_qty
elif formatted_symbol == 'BTCUSDC':
# Assume similar precision for BTC
formatted_qty = round(quantity, 6)
step_size = 0.000001
formatted_qty = round(formatted_qty / step_size) * step_size
formatted_qty = round(formatted_qty, 6)
return formatted_qty
else:
# Default formatting - 6 decimal places
return round(quantity, 6)
except Exception as e:
logger.error(f"Error formatting quantity for {formatted_symbol}: {e}")
return None
def _adjust_order_type_for_symbol(self, formatted_symbol: str, order_type: str) -> str:
"""Adjust order type based on symbol restrictions"""
if formatted_symbol == 'ETHUSDC':
# ETHUSDC only supports LIMIT and LIMIT_MAKER orders
if order_type == 'MARKET':
logger.info(f"MEXC: Converting MARKET order to LIMIT for {formatted_symbol} (MARKET not supported)")
return 'LIMIT'
return order_type
def cancel_order(self, symbol: str, order_id: str) -> Dict[str, Any]:
"""Cancel an existing order on MEXC."""

View File

@@ -15,5 +15,7 @@ from NN.models.cnn_model import EnhancedCNNModel as CNNModel
from NN.models.dqn_agent import DQNAgent
from NN.models.cob_rl_model import MassiveRLNetwork, COBRLModelInterface
from NN.models.advanced_transformer_trading import AdvancedTradingTransformer, TradingTransformerConfig
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface
__all__ = ['CNNModel', 'DQNAgent', 'MassiveRLNetwork', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig']
__all__ = ['CNNModel', 'DQNAgent', 'MassiveRLNetwork', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig',
'ModelInterface', 'CNNModelInterface', 'RLAgentInterface', 'ExtremaTrainerInterface']

View File

@@ -772,8 +772,8 @@ class CNNModelTrainer:
# Comprehensive cleanup on any error
self.reset_computational_graph()
# Return safe dummy values to continue training
return {'main_loss': 0.0, 'total_loss': 0.0, 'accuracy': 0.5}
# Return realistic loss values based on random baseline performance
return {'main_loss': 0.693, 'total_loss': 0.693, 'accuracy': 0.5} # ln(2) for binary cross-entropy at random chance
def save_model(self, filepath: str, metadata: Optional[Dict] = None):
"""Save model with metadata"""
@@ -884,9 +884,8 @@ class CNNModel:
logger.error(f"Error in CNN prediction: {e}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
# Return dummy prediction
pred_class = np.array([0])
pred_proba = np.array([[0.1] * self.output_size])
# Return prediction based on simple statistical analysis of input
pred_class, pred_proba = self._fallback_prediction(X)
return pred_class, pred_proba
def fit(self, X, y, **kwargs):
@@ -944,6 +943,68 @@ class CNNModel:
except Exception as e:
logger.error(f"Error saving CNN model: {e}")
def _fallback_prediction(self, X):
"""Generate prediction based on statistical analysis of input data"""
try:
if isinstance(X, np.ndarray):
data = X
else:
data = X.cpu().numpy() if hasattr(X, 'cpu') else np.array(X)
# Analyze trends in the input data
if len(data.shape) >= 2:
# Calculate simple trend from the data
last_values = data[-10:] if len(data) >= 10 else data # Last 10 time steps
if len(last_values.shape) == 2:
# Multiple features - use first feature column as price
trend_data = last_values[:, 0]
else:
trend_data = last_values
# Calculate trend
if len(trend_data) > 1:
trend = (trend_data[-1] - trend_data[0]) / trend_data[0] if trend_data[0] != 0 else 0
# Map trend to action
if trend > 0.001: # Upward trend > 0.1%
action = 1 # BUY
confidence = min(0.9, 0.5 + abs(trend) * 10)
elif trend < -0.001: # Downward trend < -0.1%
action = 0 # SELL
confidence = min(0.9, 0.5 + abs(trend) * 10)
else:
action = 0 # Default to SELL for unclear trend
confidence = 0.3
else:
action = 0
confidence = 0.3
else:
action = 0
confidence = 0.3
# Create probabilities
proba = np.zeros(self.output_size)
proba[action] = confidence
# Distribute remaining probability among other classes
remaining = 1.0 - confidence
for i in range(self.output_size):
if i != action:
proba[i] = remaining / (self.output_size - 1)
pred_class = np.array([action])
pred_proba = np.array([proba])
logger.debug(f"Fallback prediction: action={action}, confidence={confidence:.2f}")
return pred_class, pred_proba
except Exception as e:
logger.error(f"Error in fallback prediction: {e}")
# Final fallback - conservative prediction
pred_class = np.array([0]) # SELL
proba = np.ones(self.output_size) / self.output_size # Equal probabilities
pred_proba = np.array([proba])
return pred_class, pred_proba
def load(self, filepath: str):
"""Load the model"""
try:

View File

@@ -18,6 +18,9 @@ import torch.nn.functional as F
import numpy as np
import logging
from typing import Dict, List, Optional, Tuple, Any
from abc import ABC, abstractmethod
from models import ModelInterface
logger = logging.getLogger(__name__)
@@ -221,12 +224,13 @@ class MassiveRLNetwork(nn.Module):
}
class COBRLModelInterface:
class COBRLModelInterface(ModelInterface):
"""
Interface for the COB RL model that handles model management, training, and inference
"""
def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None):
def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None, name=None, **kwargs):
super().__init__(name=name) # Initialize ModelInterface with a name
self.model_checkpoint_dir = model_checkpoint_dir
self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
@@ -368,4 +372,23 @@ class COBRLModelInterface:
def get_model_stats(self) -> Dict[str, Any]:
"""Get model statistics"""
return self.model.get_model_info()
return self.model.get_model_info()
def get_memory_usage(self) -> float:
"""Estimate COBRLModel memory usage in MB"""
# This is an estimation. For a more precise value, you'd inspect tensors.
# A massive network might take hundreds of MBs or even GBs.
# Let's use a more realistic estimate for a 1B parameter model.
# Assuming float32 (4 bytes per parameter), 1B params = 4GB.
# For a 400M parameter network (as mentioned in comments), it's 1.6GB.
# Let's use a placeholder if it's too complex to calculate dynamically.
try:
# Calculate total parameters and convert to MB
total_params = sum(p.numel() for p in self.model.parameters())
# Assuming float32 (4 bytes per parameter) and converting to MB
memory_bytes = total_params * 4
memory_mb = memory_bytes / (1024 * 1024)
return memory_mb
except Exception as e:
logger.debug(f"Could not estimate COBRLModel memory usage: {e}")
return 1600.0 # Default to 1.6 GB as an estimate if calculation fails

View File

@@ -129,7 +129,128 @@ class DQNAgent:
logger.info(f"DQN Agent initialized with checkpoint management: {enable_checkpoints}")
if enable_checkpoints:
logger.info(f"Model name: {model_name}, Checkpoint frequency: {self.checkpoint_frequency}")
# Add this line to the __init__ method
self.recent_actions = deque(maxlen=10)
self.recent_prices = deque(maxlen=20)
self.recent_rewards = deque(maxlen=100)
# Price prediction tracking
self.last_price_pred = {
'immediate': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
},
'midterm': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
},
'longterm': {
'direction': 1, # Default to "sideways"
'confidence': 0.0,
'change': 0.0
}
}
# Store separate memory for price direction examples
self.price_movement_memory = [] # For storing examples of clear price movements
# Performance tracking
self.losses = []
self.no_improvement_count = 0
# Confidence tracking
self.confidence_history = []
self.avg_confidence = 0.0
self.max_confidence = 0.0
self.min_confidence = 1.0
# Enhanced features from EnhancedDQNAgent
# Market adaptation capabilities
self.market_regime_weights = {
'trending': 1.2, # Higher confidence in trending markets
'ranging': 0.8, # Lower confidence in ranging markets
'volatile': 0.6 # Much lower confidence in volatile markets
}
# Dueling network support (requires enhanced network architecture)
self.use_dueling = True
# Prioritized experience replay parameters
self.use_prioritized_replay = priority_memory
self.alpha = 0.6 # Priority exponent
self.beta = 0.4 # Importance sampling exponent
self.beta_increment = 0.001
# Double DQN support
self.use_double_dqn = True
# Enhanced training features from EnhancedDQNAgent
self.target_update_freq = target_update # More descriptive name
self.training_steps = 0
self.gradient_clip_norm = 1.0 # Gradient clipping
# Enhanced statistics tracking
self.epsilon_history = []
self.td_errors = [] # Track TD errors for analysis
# Trade action fee and confidence thresholds
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
# Violent move detection
self.price_history = []
self.volatility_window = 20 # Window size for volatility calculation
self.volatility_threshold = 0.0015 # Threshold for considering a move "violent"
self.post_violent_move = False # Flag for recent violent move
self.violent_move_cooldown = 0 # Cooldown after violent move
# Feature integration
self.last_hidden_features = None # Store last extracted features
self.feature_history = [] # Store history of features for analysis
# Real-time tick features integration
self.realtime_tick_features = None # Latest tick features from tick processor
self.tick_feature_weight = 0.3 # Weight for tick features in decision making
# Check if mixed precision training should be used
self.use_mixed_precision = False
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
self.use_mixed_precision = True
self.scaler = torch.cuda.amp.GradScaler()
logger.info("Mixed precision training enabled")
else:
logger.info("Mixed precision training disabled")
# Track if we're in training mode
self.training = True
# For compatibility with old code
self.state_size = np.prod(state_shape)
self.action_size = n_actions
self.memory_size = buffer_size
self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3] # Default timeframes
logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}")
# Log model parameters
total_params = sum(p.numel() for p in self.policy_net.parameters())
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
# Position management for 2-action system
self.current_position = 0.0 # -1 (short), 0 (neutral), 1 (long)
self.position_entry_price = 0.0
self.position_entry_time = None
# Different thresholds for entry vs exit decisions - AGGRESSIVE for more training data
self.entry_confidence_threshold = 0.35 # Lower threshold for new positions (was 0.7)
self.exit_confidence_threshold = 0.15 # Very low threshold for closing positions (was 0.3)
self.uncertainty_threshold = 0.1 # When to stay neutral
def load_best_checkpoint(self):
"""Load the best checkpoint for this DQN agent"""
try:
@@ -267,9 +388,6 @@ class DQNAgent:
# Trade action fee and confidence thresholds
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
self.recent_actions = deque(maxlen=10)
self.recent_prices = deque(maxlen=20)
self.recent_rewards = deque(maxlen=100)
# Violent move detection
self.price_history = []

View File

@@ -0,0 +1,99 @@
"""
Model Interfaces Module
Defines abstract base classes and concrete implementations for various model types
to ensure consistent interaction within the trading system.
"""
import logging
from typing import Dict, Any, Optional, List
from abc import ABC, abstractmethod
import numpy as np
logger = logging.getLogger(__name__)
class ModelInterface(ABC):
"""Base interface for all models"""
def __init__(self, name: str):
self.name = name
@abstractmethod
def predict(self, data):
"""Make a prediction"""
pass
@abstractmethod
def get_memory_usage(self) -> float:
"""Get memory usage in MB"""
pass
class CNNModelInterface(ModelInterface):
"""Interface for CNN models"""
def __init__(self, model, name: str):
super().__init__(name)
self.model = model
def predict(self, data):
"""Make CNN prediction"""
try:
if hasattr(self.model, 'predict'):
return self.model.predict(data)
return None
except Exception as e:
logger.error(f"Error in CNN prediction: {e}")
return None
def get_memory_usage(self) -> float:
"""Estimate CNN memory usage"""
return 50.0 # MB
class RLAgentInterface(ModelInterface):
"""Interface for RL agents"""
def __init__(self, model, name: str):
super().__init__(name)
self.model = model
def predict(self, data):
"""Make RL prediction"""
try:
if hasattr(self.model, 'act'):
return self.model.act(data)
elif hasattr(self.model, 'predict'):
return self.model.predict(data)
return None
except Exception as e:
logger.error(f"Error in RL prediction: {e}")
return None
def get_memory_usage(self) -> float:
"""Estimate RL memory usage"""
return 25.0 # MB
class ExtremaTrainerInterface(ModelInterface):
"""Interface for ExtremaTrainer models, providing context features"""
def __init__(self, model, name: str):
super().__init__(name)
self.model = model
def predict(self, data=None):
"""ExtremaTrainer doesn't predict in the traditional sense, it provides features."""
logger.warning(f"Predict method called on ExtremaTrainerInterface ({self.name}). Use get_context_features_for_model instead.")
return None
def get_memory_usage(self) -> float:
"""Estimate ExtremaTrainer memory usage"""
return 30.0 # MB
def get_context_features_for_model(self, symbol: str) -> Optional[np.ndarray]:
"""Get context features from the ExtremaTrainer for model consumption."""
try:
if hasattr(self.model, 'get_context_features_for_model'):
return self.model.get_context_features_for_model(symbol)
return None
except Exception as e:
logger.error(f"Error getting extrema context features: {e}")
return None

View File

@@ -339,12 +339,64 @@ class TransformerModel:
# Ensure X_features has the right shape
if X_features is None:
# Create dummy features with zeros
X_features = np.zeros((X_ts.shape[0], self.feature_input_shape))
# Extract features from time series data if no external features provided
X_features = self._extract_features_from_timeseries(X_ts)
elif len(X_features.shape) == 1:
# Single sample, add batch dimension
X_features = np.expand_dims(X_features, axis=0)
def _extract_features_from_timeseries(self, X_ts: np.ndarray) -> np.ndarray:
"""Extract meaningful features from time series data instead of using dummy zeros"""
try:
batch_size = X_ts.shape[0]
features = []
for i in range(batch_size):
sample = X_ts[i] # Shape: (timesteps, features)
# Extract statistical features from each feature dimension
sample_features = []
for feature_idx in range(sample.shape[1]):
feature_data = sample[:, feature_idx]
# Basic statistical features
sample_features.extend([
np.mean(feature_data), # Mean
np.std(feature_data), # Standard deviation
np.min(feature_data), # Minimum
np.max(feature_data), # Maximum
np.percentile(feature_data, 25), # 25th percentile
np.percentile(feature_data, 75), # 75th percentile
])
# Trend features
if len(feature_data) > 1:
# Linear trend (slope)
x = np.arange(len(feature_data))
slope = np.polyfit(x, feature_data, 1)[0]
sample_features.append(slope)
# Rate of change
rate_of_change = (feature_data[-1] - feature_data[0]) / feature_data[0] if feature_data[0] != 0 else 0
sample_features.append(rate_of_change)
else:
sample_features.extend([0.0, 0.0])
# Pad or truncate to expected feature size
while len(sample_features) < self.feature_input_shape:
sample_features.append(0.0)
sample_features = sample_features[:self.feature_input_shape]
features.append(sample_features)
return np.array(features, dtype=np.float32)
except Exception as e:
logger.error(f"Error extracting features from time series: {e}")
# Fallback to zeros if extraction fails
return np.zeros((X_ts.shape[0], self.feature_input_shape), dtype=np.float32)
# Get predictions
y_proba = self.model.predict([X_ts, X_features])

View File

@@ -77,3 +77,8 @@ use existing checkpoint manager if it;s not too bloated as well. otherwise re-im
we should load the models in a way that we do a back propagation and other model specificic training at realtime as training examples emerge from the realtime data we process. we will save only the best examples (the realtime data dumps we feed to the models) so we can cold start other models if we change the architecture. if it's not working, perform a cleanup of all traininn and trainer code to make it easer to work withm to streamline latest changes and to simplify and refactor it

View File

@@ -162,11 +162,11 @@ mexc_trading:
trading_mode: simulation # simulation, testnet, live
# Position sizing as percentage of account balance
base_position_percent: 5.0 # 5% base position of account
max_position_percent: 20.0 # 20% max position of account
min_position_percent: 2.0 # 2% min position of account
leverage: 50.0 # 50x leverage (adjustable in UI)
simulation_account_usd: 100.0 # $100 simulation account balance
base_position_percent: 1 # 0.5% base position of account (MUCH SAFER)
max_position_percent: 5.0 # 2% max position of account (REDUCED)
min_position_percent: 0.5 # 0.2% min position of account (REDUCED)
leverage: 1.0 # 1x leverage (NO LEVERAGE FOR TESTING)
simulation_account_usd: 99.9 # $100 simulation account balance
# Risk management
max_daily_loss_usd: 200.0
@@ -197,6 +197,7 @@ enhanced_training:
enabled: true # Enable enhanced real-time training
auto_start: true # Automatically start training when orchestrator starts
training_intervals:
cob_rl_training_interval: 1 # Train COB RL every 1 second (HIGHEST PRIORITY)
dqn_training_interval: 5 # Train DQN every 5 seconds
cnn_training_interval: 10 # Train CNN every 10 seconds
validation_interval: 60 # Validate every minute
@@ -206,6 +207,11 @@ enhanced_training:
adaptation_threshold: 0.1 # Performance threshold for adaptation
forward_looking_predictions: true # Enable forward-looking prediction validation
# COB RL Priority Settings (since order book imbalance predicts price moves)
cob_rl_priority: true # Enable COB RL as highest priority model
cob_rl_batch_size: 16 # Smaller batches for faster COB updates
cob_rl_min_samples: 5 # Lower threshold for COB training
# Real-time RL COB Trader Configuration
realtime_rl:
# Model parameters for 400M parameter network (faster startup)

View File

@@ -34,7 +34,7 @@ class COBIntegration:
Integration layer for Multi-Exchange COB data with gogo2 trading system
"""
def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None):
def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None, initial_data_limit=None, **kwargs):
"""
Initialize COB Integration
@@ -655,7 +655,7 @@ class COBIntegration:
except Exception as e:
logger.error(f"Error getting NN stats for {symbol}: {e}")
return {}
return {}
def get_realtime_stats(self):
# Added null check to ensure the COB provider is initialized

View File

@@ -661,22 +661,315 @@ class MultiExchangeCOBProvider:
except Exception as e:
logger.error(f"Error processing Binance order book for {symbol}: {e}", exc_info=True)
async def _stream_coinbase_orderbook(self, symbol: str, config: ExchangeConfig):
"""Stream Coinbase order book data (placeholder implementation)"""
async def _process_coinbase_orderbook(self, symbol: str, data: Dict):
"""Process Coinbase order book data"""
try:
# For now, just log that Coinbase streaming is not implemented
logger.info(f"Coinbase streaming for {symbol} not yet implemented")
await asyncio.sleep(60) # Sleep to prevent spam
if data.get('type') == 'snapshot':
# Initial snapshot
bids = {}
asks = {}
for bid_data in data.get('bids', []):
price, size = float(bid_data[0]), float(bid_data[1])
if size > 0:
bids[price] = ExchangeOrderBookLevel(
exchange='coinbase',
price=price,
size=size,
volume_usd=price * size,
orders_count=1, # Coinbase doesn't provide order count
side='bid',
timestamp=datetime.now(),
raw_data=bid_data
)
for ask_data in data.get('asks', []):
price, size = float(ask_data[0]), float(ask_data[1])
if size > 0:
asks[price] = ExchangeOrderBookLevel(
exchange='coinbase',
price=price,
size=size,
volume_usd=price * size,
orders_count=1,
side='ask',
timestamp=datetime.now(),
raw_data=ask_data
)
# Update order book
async with self.data_lock:
if symbol not in self.exchange_order_books:
self.exchange_order_books[symbol] = {}
self.exchange_order_books[symbol]['coinbase'] = {
'bids': bids,
'asks': asks,
'last_update': datetime.now(),
'connected': True
}
logger.info(f"Coinbase snapshot for {symbol}: {len(bids)} bids, {len(asks)} asks")
elif data.get('type') == 'l2update':
# Level 2 update
async with self.data_lock:
if symbol in self.exchange_order_books and 'coinbase' in self.exchange_order_books[symbol]:
coinbase_data = self.exchange_order_books[symbol]['coinbase']
for change in data.get('changes', []):
side, price_str, size_str = change
price, size = float(price_str), float(size_str)
if side == 'buy':
if size == 0:
# Remove level
coinbase_data['bids'].pop(price, None)
else:
# Update level
coinbase_data['bids'][price] = ExchangeOrderBookLevel(
exchange='coinbase',
price=price,
size=size,
volume_usd=price * size,
orders_count=1,
side='bid',
timestamp=datetime.now(),
raw_data=change
)
elif side == 'sell':
if size == 0:
# Remove level
coinbase_data['asks'].pop(price, None)
else:
# Update level
coinbase_data['asks'][price] = ExchangeOrderBookLevel(
exchange='coinbase',
price=price,
size=size,
volume_usd=price * size,
orders_count=1,
side='ask',
timestamp=datetime.now(),
raw_data=change
)
coinbase_data['last_update'] = datetime.now()
# Update exchange count
exchange_name = 'coinbase'
if exchange_name not in self.exchange_update_counts:
self.exchange_update_counts[exchange_name] = 0
self.exchange_update_counts[exchange_name] += 1
# Log every 1000th update
if self.exchange_update_counts[exchange_name] % 1000 == 0:
logger.info(f"Processed {self.exchange_update_counts[exchange_name]} Coinbase updates for {symbol}")
except Exception as e:
logger.error(f"Error streaming Coinbase order book for {symbol}: {e}")
logger.error(f"Error processing Coinbase order book for {symbol}: {e}", exc_info=True)
async def _process_kraken_orderbook(self, symbol: str, data: Dict):
"""Process Kraken order book data"""
try:
# Kraken sends different message types
if isinstance(data, list) and len(data) > 1:
# Order book update format: [channel_id, data, channel_name, pair]
if len(data) >= 4 and data[2] == "book-25":
book_data = data[1]
# Check for snapshot vs update
if 'bs' in book_data and 'as' in book_data:
# Snapshot
bids = {}
asks = {}
for bid_data in book_data.get('bs', []):
price, volume, timestamp = float(bid_data[0]), float(bid_data[1]), float(bid_data[2])
if volume > 0:
bids[price] = ExchangeOrderBookLevel(
exchange='kraken',
price=price,
size=volume,
volume_usd=price * volume,
orders_count=1, # Kraken doesn't provide order count in book feed
side='bid',
timestamp=datetime.fromtimestamp(timestamp),
raw_data=bid_data
)
for ask_data in book_data.get('as', []):
price, volume, timestamp = float(ask_data[0]), float(ask_data[1]), float(ask_data[2])
if volume > 0:
asks[price] = ExchangeOrderBookLevel(
exchange='kraken',
price=price,
size=volume,
volume_usd=price * volume,
orders_count=1,
side='ask',
timestamp=datetime.fromtimestamp(timestamp),
raw_data=ask_data
)
# Update order book
async with self.data_lock:
if symbol not in self.exchange_order_books:
self.exchange_order_books[symbol] = {}
self.exchange_order_books[symbol]['kraken'] = {
'bids': bids,
'asks': asks,
'last_update': datetime.now(),
'connected': True
}
logger.info(f"Kraken snapshot for {symbol}: {len(bids)} bids, {len(asks)} asks")
else:
# Incremental update
async with self.data_lock:
if symbol in self.exchange_order_books and 'kraken' in self.exchange_order_books[symbol]:
kraken_data = self.exchange_order_books[symbol]['kraken']
# Process bid updates
for bid_update in book_data.get('b', []):
price, volume, timestamp = float(bid_update[0]), float(bid_update[1]), float(bid_update[2])
if volume == 0:
# Remove level
kraken_data['bids'].pop(price, None)
else:
# Update level
kraken_data['bids'][price] = ExchangeOrderBookLevel(
exchange='kraken',
price=price,
size=volume,
volume_usd=price * volume,
orders_count=1,
side='bid',
timestamp=datetime.fromtimestamp(timestamp),
raw_data=bid_update
)
# Process ask updates
for ask_update in book_data.get('a', []):
price, volume, timestamp = float(ask_update[0]), float(ask_update[1]), float(ask_update[2])
if volume == 0:
# Remove level
kraken_data['asks'].pop(price, None)
else:
# Update level
kraken_data['asks'][price] = ExchangeOrderBookLevel(
exchange='kraken',
price=price,
size=volume,
volume_usd=price * volume,
orders_count=1,
side='ask',
timestamp=datetime.fromtimestamp(timestamp),
raw_data=ask_update
)
kraken_data['last_update'] = datetime.now()
# Update exchange count
exchange_name = 'kraken'
if exchange_name not in self.exchange_update_counts:
self.exchange_update_counts[exchange_name] = 0
self.exchange_update_counts[exchange_name] += 1
# Log every 1000th update
if self.exchange_update_counts[exchange_name] % 1000 == 0:
logger.info(f"Processed {self.exchange_update_counts[exchange_name]} Kraken updates for {symbol}")
except Exception as e:
logger.error(f"Error processing Kraken order book for {symbol}: {e}", exc_info=True)
async def _stream_coinbase_orderbook(self, symbol: str, config: ExchangeConfig):
"""Stream Coinbase order book data via WebSocket"""
try:
import json
if websockets is None or websockets_connect is None:
raise ImportError("websockets module not available")
# Coinbase Pro WebSocket URL
ws_url = "wss://ws-feed.pro.coinbase.com"
coinbase_symbol = config.symbols_mapping.get(symbol, symbol.replace('/', '-'))
# Subscribe message for level2 order book updates
subscribe_message = {
"type": "subscribe",
"product_ids": [coinbase_symbol],
"channels": ["level2"]
}
logger.info(f"Connecting to Coinbase order book stream for {symbol}")
async with websockets_connect(ws_url) as websocket:
# Send subscription
await websocket.send(json.dumps(subscribe_message))
logger.info(f"Subscribed to Coinbase level2 for {coinbase_symbol}")
async for message in websocket:
if not self.is_streaming:
break
try:
data = json.loads(message)
await self._process_coinbase_orderbook(symbol, data)
except json.JSONDecodeError as e:
logger.error(f"Error parsing Coinbase message: {e}")
except Exception as e:
logger.error(f"Error processing Coinbase orderbook: {e}")
except Exception as e:
logger.error(f"Coinbase order book stream error for {symbol}: {e}")
finally:
logger.info(f"Disconnected from Coinbase order book stream for {symbol}")
async def _stream_kraken_orderbook(self, symbol: str, config: ExchangeConfig):
"""Stream Kraken order book data (placeholder implementation)"""
"""Stream Kraken order book data via WebSocket"""
try:
logger.info(f"Kraken streaming for {symbol} not yet implemented")
await asyncio.sleep(60) # Sleep to prevent spam
import json
if websockets is None or websockets_connect is None:
raise ImportError("websockets module not available")
# Kraken WebSocket URL
ws_url = "wss://ws.kraken.com"
kraken_symbol = config.symbols_mapping.get(symbol, symbol.replace('/', ''))
# Subscribe message for book updates
subscribe_message = {
"event": "subscribe",
"pair": [kraken_symbol],
"subscription": {"name": "book", "depth": 25}
}
logger.info(f"Connecting to Kraken order book stream for {symbol}")
async with websockets_connect(ws_url) as websocket:
# Send subscription
await websocket.send(json.dumps(subscribe_message))
logger.info(f"Subscribed to Kraken book for {kraken_symbol}")
async for message in websocket:
if not self.is_streaming:
break
try:
data = json.loads(message)
await self._process_kraken_orderbook(symbol, data)
except json.JSONDecodeError as e:
logger.error(f"Error parsing Kraken message: {e}")
except Exception as e:
logger.error(f"Error processing Kraken orderbook: {e}")
except Exception as e:
logger.error(f"Error streaming Kraken order book for {symbol}: {e}")
logger.error(f"Kraken order book stream error for {symbol}: {e}")
finally:
logger.info(f"Disconnected from Kraken order book stream for {symbol}")
async def _stream_huobi_orderbook(self, symbol: str, config: ExchangeConfig):
"""Stream Huobi order book data (placeholder implementation)"""

File diff suppressed because it is too large Load Diff

View File

@@ -114,12 +114,17 @@ class TradingExecutor:
# Thread safety
self.lock = Lock()
# Connect to exchange
# Connect to exchange - skip connection check in simulation mode
if self.trading_enabled:
logger.info("TRADING EXECUTOR: Attempting to connect to exchange...")
if not self._connect_exchange():
logger.error("TRADING EXECUTOR: Failed initial exchange connection. Trading will be disabled.")
self.trading_enabled = False
if self.simulation_mode:
logger.info("TRADING EXECUTOR: Simulation mode - skipping exchange connection check")
# In simulation mode, we don't need a real exchange connection
# Trading should remain enabled for simulation trades
else:
logger.info("TRADING EXECUTOR: Attempting to connect to exchange...")
if not self._connect_exchange():
logger.error("TRADING EXECUTOR: Failed initial exchange connection. Trading will be disabled.")
self.trading_enabled = False
else:
logger.info("TRADING EXECUTOR: Trading is explicitly disabled in config.")
@@ -230,15 +235,25 @@ class TradingExecutor:
required_capital = self._calculate_position_size(confidence, current_price)
# Get available balance for the quote asset
available_balance = self.exchange.get_balance(quote_asset)
# If USDC balance is insufficient, check USDT as fallback (for MEXC compatibility)
if available_balance < required_capital and quote_asset == 'USDC':
# For MEXC, prioritize USDT over USDC since most accounts have USDT
if quote_asset == 'USDC':
# Check USDT first (most common balance)
usdt_balance = self.exchange.get_balance('USDT')
usdc_balance = self.exchange.get_balance('USDC')
if usdt_balance >= required_capital:
available_balance = usdt_balance
quote_asset = 'USDT' # Use USDT instead
logger.info(f"BALANCE CHECK: Using USDT fallback balance for {symbol}")
quote_asset = 'USDT' # Use USDT for trading
logger.info(f"BALANCE CHECK: Using USDT balance for {symbol} (preferred)")
elif usdc_balance >= required_capital:
available_balance = usdc_balance
logger.info(f"BALANCE CHECK: Using USDC balance for {symbol}")
else:
# Use the larger balance for reporting
available_balance = max(usdt_balance, usdc_balance)
quote_asset = 'USDT' if usdt_balance > usdc_balance else 'USDC'
else:
available_balance = self.exchange.get_balance(quote_asset)
logger.info(f"BALANCE CHECK: Symbol: {symbol}, Action: {action}, Required: ${required_capital:.2f} {quote_asset}, Available: ${available_balance:.2f} {quote_asset}")

View File

@@ -229,9 +229,12 @@ class TrainingIntegration:
# Truncate
features = features[:50]
# Get the model's device to ensure tensors are on the same device
model_device = next(cnn_model.parameters()).device
# Create tensors
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device)
target_tensor = torch.LongTensor([target]).to(device)
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(model_device)
target_tensor = torch.LongTensor([target]).to(model_device)
# Training step
cnn_model.train()

View File

@@ -6,17 +6,18 @@ III. models we currently use (architecture is expandable with easy adaption to n
- cnn price prediction model - uses calculated multilevel pivot points and historical price data to predict the next pivot point for each level.
- DQN RL model outputs trade signals
- transformer model outputs price prediction
- COB RL model outputs trade signals - it is trained on cob (cached all COB data for period of time not just current order book. it should be a 2d matrix 1s aggregated ) and some indicators cummulative cob imbalance for different timeframes.
- COB RL model outputs trade signals - it is trained on cob (cached all COB data for period of time not just current order book. it should be a 2d matrix 1s aggregated ) and some indicators cummulative cob imbalance for different timeframes. we get COB snapshots every couple hundred miliseconds and we cache and aggregate them to have a COB history. 1d matrix from the API to 2d amtrix as model inputs. as both raw ticks and 1s averaged.
- decision model - it is trained on price prediction and trade signals to learn the effectiveness of the other models in contribute to succeessful prediction. outputs the final trade signal.
IV. by default all models take full current data frames available in the orchestrator on inference as base data - different aspects of the data are updated at different rates. main data frame includes 5 price charts
class UniversalDataAdapter:
- 1s 1m 1h ETH charts and ETH and BTC ticks. orchestrator can use and extend the UniversalDataAdapter class to add new data sources and data types.
- - cob models are different and they get fast realtime raw dob data ticks and should be agile to inference and procude outputs but yet able to learn.
V. hardware. we use GPU if available for training and inference for optimised performance.
V. Training and hardware.
- we should load the models in a way that we do a back propagation and other model specificic training at realtime as training examples emerge from the realtime data we process. we will save only the best examples (the realtime data dumps we feed to the models) so we can cold start other models if we change the architecture. i
- we use GPU if available for training and inference for optimised performance.

File diff suppressed because it is too large Load Diff

View File

@@ -1,201 +1,121 @@
#!/usr/bin/env python3
"""
Run Clean Trading Dashboard with Full Training Pipeline
Integrated system with both training loop and clean web dashboard
Clean Trading Dashboard Runner with Enhanced Stability and Error Handling
"""
import os
# Fix OpenMP library conflicts before importing other modules
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
os.environ['OMP_NUM_THREADS'] = '4'
import asyncio
import logging
import sys
import threading
import logging
import traceback
import gc
import time
import psutil
import torch
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import get_config, setup_logging
from core.data_provider import DataProvider
# Import checkpoint management
from utils.checkpoint_manager import get_checkpoint_manager
from utils.training_integration import get_training_integration
# Setup logging
setup_logging()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
async def start_training_pipeline(orchestrator, trading_executor):
"""Start the training pipeline in the background"""
logger.info("=" * 70)
logger.info("STARTING TRAINING PIPELINE WITH CLEAN DASHBOARD")
logger.info("=" * 70)
# Initialize checkpoint management
checkpoint_manager = get_checkpoint_manager()
training_integration = get_training_integration()
# Training statistics
training_stats = {
'iteration_count': 0,
'total_decisions': 0,
'successful_trades': 0,
'best_performance': 0.0,
'last_checkpoint_iteration': 0
}
try:
# Start real-time processing (available in Enhanced orchestrator)
if hasattr(orchestrator, 'start_realtime_processing'):
await orchestrator.start_realtime_processing()
logger.info("Real-time processing started")
# Start COB integration (available in Enhanced orchestrator)
if hasattr(orchestrator, 'start_cob_integration'):
await orchestrator.start_cob_integration()
logger.info("COB integration started - 5-minute data matrix active")
else:
logger.info("COB integration not available")
# Main training loop
iteration = 0
last_checkpoint_time = time.time()
while True:
try:
iteration += 1
training_stats['iteration_count'] = iteration
# Get symbols to process
symbols = orchestrator.symbols if hasattr(orchestrator, 'symbols') else ['ETH/USDT']
# Process each symbol
for symbol in symbols:
try:
# Make trading decision (this triggers model training)
decision = await orchestrator.make_trading_decision(symbol)
if decision:
training_stats['total_decisions'] += 1
logger.debug(f"[{symbol}] Decision: {decision.action} @ {decision.confidence:.1%}")
except Exception as e:
logger.warning(f"Error processing {symbol}: {e}")
# Status logging every 100 iterations
if iteration % 100 == 0:
current_time = time.time()
elapsed = current_time - last_checkpoint_time
logger.info(f"[TRAINING] Iteration {iteration}, Decisions: {training_stats['total_decisions']}, Time: {elapsed:.1f}s")
# Models will save their own checkpoints when performance improves
training_stats['last_checkpoint_iteration'] = iteration
last_checkpoint_time = current_time
# Brief pause to prevent overwhelming the system
await asyncio.sleep(0.1) # 100ms between iterations
except Exception as e:
logger.error(f"Training loop error: {e}")
await asyncio.sleep(5) # Wait longer on error
except Exception as e:
logger.error(f"Training pipeline error: {e}")
import traceback
logger.error(traceback.format_exc())
def clear_gpu_memory():
"""Clear GPU memory cache"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
def start_clean_dashboard_with_training():
"""Start clean dashboard with full training pipeline"""
try:
logger.info("=" * 80)
logger.info("CLEAN TRADING DASHBOARD + FULL TRAINING PIPELINE")
logger.info("=" * 80)
logger.info("Features: Real-time Training, COB Integration, Clean UI")
logger.info("Universal Data Stream: ENABLED")
logger.info("Neural Decision Fusion: ENABLED")
logger.info("COB Integration: ENABLED")
logger.info("GPU Training: ENABLED")
logger.info("Multi-symbol: ETH/USDT, BTC/USDT")
# Get port from environment or use default
dashboard_port = int(os.environ.get('DASHBOARD_PORT', '8051'))
logger.info(f"Dashboard: http://127.0.0.1:{dashboard_port}")
logger.info("=" * 80)
# Check environment variables
enable_universal_stream = os.environ.get('ENABLE_UNIVERSAL_DATA_STREAM', '1') == '1'
enable_nn_fusion = os.environ.get('ENABLE_NN_DECISION_FUSION', '1') == '1'
enable_cob = os.environ.get('ENABLE_COB_INTEGRATION', '1') == '1'
logger.info(f"Universal Data Stream: {'ENABLED' if enable_universal_stream else 'DISABLED'}")
logger.info(f"Neural Decision Fusion: {'ENABLED' if enable_nn_fusion else 'DISABLED'}")
logger.info(f"COB Integration: {'ENABLED' if enable_cob else 'DISABLED'}")
# Get configuration
config = get_config()
# Initialize core components
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
from core.trading_executor import TradingExecutor
# Create data provider
data_provider = DataProvider()
# Create enhanced orchestrator with COB integration - stable and efficient
orchestrator = TradingOrchestrator(data_provider, enhanced_rl_training=True)
logger.info("Enhanced Trading Orchestrator created with COB integration")
# Create trading executor
trading_executor = TradingExecutor()
# Import clean dashboard
from web.clean_dashboard import create_clean_dashboard
# Create clean dashboard
dashboard = create_clean_dashboard(
data_provider=data_provider,
orchestrator=orchestrator,
trading_executor=trading_executor
)
logger.info("Clean Trading Dashboard created")
# Start training pipeline in background thread
def training_worker():
"""Run training pipeline in background"""
try:
asyncio.run(start_training_pipeline(orchestrator, trading_executor))
except Exception as e:
logger.error(f"Training worker error: {e}")
training_thread = threading.Thread(target=training_worker, daemon=True)
training_thread.start()
logger.info("Training pipeline started in background")
# Wait a moment for training to initialize
time.sleep(3)
# Start dashboard server (this blocks)
logger.info(" Starting Clean Dashboard Server...")
dashboard.run_server(host='127.0.0.1', port=dashboard_port, debug=False)
except KeyboardInterrupt:
logger.info("System stopped by user")
except Exception as e:
logger.error(f"Error running clean dashboard with training: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
def check_system_resources():
"""Check if system has enough resources"""
available_ram = psutil.virtual_memory().available / 1024**3
if available_ram < 2.0: # Less than 2GB available
logger.warning(f"Low RAM: {available_ram:.1f} GB available")
gc.collect()
clear_gpu_memory()
return False
return True
def main():
"""Main function"""
start_clean_dashboard_with_training()
def run_dashboard_with_recovery():
"""Run dashboard with automatic error recovery"""
max_retries = 3
retry_count = 0
while retry_count < max_retries:
try:
logger.info(f"Starting Clean Trading Dashboard (attempt {retry_count + 1}/{max_retries})")
# Check system resources
if not check_system_resources():
logger.warning("System resources low, waiting 30 seconds...")
time.sleep(30)
continue
# Import here to avoid memory issues on restart
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
from core.trading_executor import TradingExecutor
from web.clean_dashboard import create_clean_dashboard
logger.info("Creating data provider...")
data_provider = DataProvider()
logger.info("Creating trading orchestrator...")
orchestrator = TradingOrchestrator(
data_provider=data_provider,
enhanced_rl_training=True
)
logger.info("Creating trading executor...")
trading_executor = TradingExecutor()
logger.info("Creating clean dashboard...")
dashboard = create_clean_dashboard(data_provider, orchestrator, trading_executor)
logger.info("Dashboard created successfully")
logger.info("=== Clean Trading Dashboard Status ===")
logger.info("- Data Provider: Active")
logger.info("- Trading Orchestrator: Active")
logger.info("- Trading Executor: Active")
logger.info("- Enhanced Training: Active")
logger.info("- Dashboard: Ready")
logger.info("=======================================")
# Start the dashboard server with error handling
try:
logger.info("Starting dashboard server on http://127.0.0.1:8050")
dashboard.run_server(host='127.0.0.1', port=8050, debug=False)
except KeyboardInterrupt:
logger.info("Dashboard stopped by user")
break
except Exception as e:
logger.error(f"Dashboard server error: {e}")
logger.error(traceback.format_exc())
raise
except Exception as e:
logger.error(f"Critical error in dashboard: {e}")
logger.error(traceback.format_exc())
retry_count += 1
if retry_count < max_retries:
logger.info(f"Attempting recovery... ({retry_count}/{max_retries})")
# Cleanup
gc.collect()
clear_gpu_memory()
# Wait before retry
wait_time = 30 * retry_count # Exponential backoff
logger.info(f"Waiting {wait_time} seconds before retry...")
time.sleep(wait_time)
else:
logger.error("Max retries reached. Exiting.")
sys.exit(1)
if __name__ == "__main__":
main()
try:
run_dashboard_with_recovery()
except KeyboardInterrupt:
logger.info("Application stopped by user")
sys.exit(0)
except Exception as e:
logger.error(f"Fatal error: {e}")
logger.error(traceback.format_exc())
sys.exit(1)

View File

@@ -23,7 +23,7 @@ class RewardCalculator:
self.trade_timestamps = []
self.frequency_threshold = 10 # Trades per minute threshold for penalty
self.max_frequency_penalty = 0.05
def record_pnl(self, pnl):
"""Record P&L for risk adjustment calculations"""
self.trade_pnls.append(pnl)
@@ -36,7 +36,7 @@ class RewardCalculator:
self.trade_timestamps.append(time())
if len(self.trade_timestamps) > 100:
self.trade_timestamps.pop(0)
def _calculate_frequency_penalty(self):
"""Calculate penalty for high-frequency trading"""
if len(self.trade_timestamps) < 2:
@@ -47,9 +47,9 @@ class RewardCalculator:
trades_per_minute = (len(self.trade_timestamps) / time_span) * 60
if trades_per_minute > self.frequency_threshold:
penalty = min(self.max_frequency_penalty, (trades_per_minute - self.frequency_threshold) * 0.001)
return penalty
return penalty
return 0.0
def _calculate_risk_adjustment(self, reward):
"""Adjust rewards based on risk (simple Sharpe ratio implementation)"""
if len(self.trade_pnls) < 5:
@@ -62,7 +62,7 @@ class RewardCalculator:
sharpe = mean_return / std_return
adjustment_factor = np.clip(1.0 + 0.5 * sharpe, 0.5, 2.0)
return reward * adjustment_factor
def _calculate_holding_reward(self, position_held_time, price_change):
"""Calculate reward for holding a position"""
base_holding_reward = 0.0005 * (position_held_time / 60.0)

View File

@@ -205,6 +205,9 @@ class CleanTradingDashboard:
# Start signal generation loop to ensure continuous trading signals
self._start_signal_generation_loop()
# Start live balance sync for trading
self._start_live_balance_sync()
# Start training sessions if models are showing FRESH status
threading.Thread(target=self._delayed_training_check, daemon=True).start()
@@ -318,6 +321,66 @@ class CleanTradingDashboard:
except Exception as e:
logger.warning(f"Error getting balance: {e}")
return 100.0 # Default balance
def _get_live_balance(self) -> float:
"""Get real-time balance from exchange when in live trading mode"""
try:
if self.trading_executor:
# Check if we're in live trading mode
is_live = (hasattr(self.trading_executor, 'trading_enabled') and
self.trading_executor.trading_enabled and
hasattr(self.trading_executor, 'simulation_mode') and
not self.trading_executor.simulation_mode)
if is_live and hasattr(self.trading_executor, 'exchange'):
# Get real balance from exchange (throttled to avoid API spam)
import time
current_time = time.time()
# Cache balance for 5 seconds for more frequent updates in live trading
if not hasattr(self, '_last_balance_check') or current_time - self._last_balance_check > 5:
exchange = self.trading_executor.exchange
if hasattr(exchange, 'get_balance'):
live_balance = exchange.get_balance('USDC')
if live_balance is not None and live_balance > 0:
self._cached_live_balance = live_balance
self._last_balance_check = current_time
logger.info(f"LIVE BALANCE: Retrieved ${live_balance:.2f} USDC from MEXC")
return live_balance
else:
logger.warning(f"LIVE BALANCE: Retrieved ${live_balance:.2f} USDC - checking USDT as fallback")
# Also try USDT as fallback since user might have USDT
usdt_balance = exchange.get_balance('USDT')
if usdt_balance is not None and usdt_balance > 0:
self._cached_live_balance = usdt_balance
self._last_balance_check = current_time
logger.info(f"LIVE BALANCE: Using USDT balance ${usdt_balance:.2f}")
return usdt_balance
else:
logger.warning("LIVE BALANCE: Exchange does not have get_balance method")
else:
# Return cached balance if within 10 second window
if hasattr(self, '_cached_live_balance'):
return self._cached_live_balance
elif hasattr(self.trading_executor, 'simulation_mode') and self.trading_executor.simulation_mode:
# In simulation mode, show dynamic balance based on P&L
initial_balance = self._get_initial_balance()
realized_pnl = sum(trade.get('pnl', 0) for trade in self.closed_trades)
simulation_balance = initial_balance + realized_pnl
logger.debug(f"SIMULATION BALANCE: ${simulation_balance:.2f} (Initial: ${initial_balance:.2f} + P&L: ${realized_pnl:.2f})")
return simulation_balance
else:
logger.debug("LIVE BALANCE: Not in live trading mode, using initial balance")
# Fallback to initial balance for simulation mode
return self._get_initial_balance()
except Exception as e:
logger.error(f"Error getting live balance: {e}")
# Return cached balance if available, otherwise fallback
if hasattr(self, '_cached_live_balance'):
return self._cached_live_balance
return self._get_initial_balance()
def _setup_layout(self):
"""Setup the dashboard layout using layout manager"""
@@ -411,17 +474,48 @@ class CleanTradingDashboard:
trade_count = len(self.closed_trades)
trade_str = f"{trade_count} Trades"
# Portfolio value
initial_balance = self._get_initial_balance()
portfolio_value = initial_balance + total_session_pnl # Use total P&L including unrealized
portfolio_str = f"${portfolio_value:.2f}"
# Portfolio value - use live balance for live trading
current_balance = self._get_live_balance()
portfolio_value = current_balance + total_session_pnl # Use total P&L including unrealized
# MEXC status
# Show live balance indicator for live trading
balance_indicator = ""
if self.trading_executor:
is_live = (hasattr(self.trading_executor, 'trading_enabled') and
self.trading_executor.trading_enabled and
hasattr(self.trading_executor, 'simulation_mode') and
not self.trading_executor.simulation_mode)
if is_live:
balance_indicator = " (LIVE)"
portfolio_str = f"${portfolio_value:.2f}{balance_indicator}"
# MEXC status with balance info
mexc_status = "SIM"
if self.trading_executor:
if hasattr(self.trading_executor, 'trading_enabled') and self.trading_executor.trading_enabled:
if hasattr(self.trading_executor, 'simulation_mode') and not self.trading_executor.simulation_mode:
mexc_status = "LIVE"
if hasattr(self.trading_executor, 'simulation_mode') and self.trading_executor.simulation_mode:
# Show simulation mode status with simulated balance
mexc_status = f"SIM - ${current_balance:.2f}"
elif hasattr(self.trading_executor, 'simulation_mode') and not self.trading_executor.simulation_mode:
# Show live balance in MEXC status - detect currency
try:
exchange = self.trading_executor.exchange
usdc_balance = exchange.get_balance('USDC') if hasattr(exchange, 'get_balance') else 0
usdt_balance = exchange.get_balance('USDT') if hasattr(exchange, 'get_balance') else 0
if usdc_balance > 0:
mexc_status = f"LIVE - ${usdc_balance:.2f} USDC"
elif usdt_balance > 0:
mexc_status = f"LIVE - ${usdt_balance:.2f} USDT"
else:
mexc_status = f"LIVE - ${current_balance:.2f}"
except:
mexc_status = f"LIVE - ${current_balance:.2f}"
else:
mexc_status = "SIM"
else:
mexc_status = "DISABLED"
return price_str, session_pnl_str, position_str, trade_str, portfolio_str, mexc_status
@@ -504,18 +598,31 @@ class CleanTradingDashboard:
def update_cob_data(n):
"""Update COB data displays with real order book ladders and cumulative stats"""
try:
# Update less frequently to reduce flickering
if n % self.update_batch_interval != 0:
raise PreventUpdate
# COB data is critical - update every second (no batching)
# if n % self.update_batch_interval != 0:
# raise PreventUpdate
eth_snapshot = self._get_cob_snapshot('ETH/USDT')
btc_snapshot = self._get_cob_snapshot('BTC/USDT')
# Debug: Log COB data availability
if n % 5 == 0: # Log every 5 seconds to avoid spam
logger.info(f"COB Update #{n}: ETH snapshot: {eth_snapshot is not None}, BTC snapshot: {btc_snapshot is not None}")
if hasattr(self, 'latest_cob_data'):
eth_data_time = self.cob_last_update.get('ETH/USDT', 0) if hasattr(self, 'cob_last_update') else 0
btc_data_time = self.cob_last_update.get('BTC/USDT', 0) if hasattr(self, 'cob_last_update') else 0
import time
current_time = time.time()
logger.info(f"COB Data Age: ETH: {current_time - eth_data_time:.1f}s, BTC: {current_time - btc_data_time:.1f}s")
eth_imbalance_stats = self._calculate_cumulative_imbalance('ETH/USDT')
btc_imbalance_stats = self._calculate_cumulative_imbalance('BTC/USDT')
eth_components = self.component_manager.format_cob_data(eth_snapshot, 'ETH/USDT', eth_imbalance_stats)
btc_components = self.component_manager.format_cob_data(btc_snapshot, 'BTC/USDT', btc_imbalance_stats)
# Determine COB data source mode
cob_mode = self._get_cob_mode()
eth_components = self.component_manager.format_cob_data(eth_snapshot, 'ETH/USDT', eth_imbalance_stats, cob_mode)
btc_components = self.component_manager.format_cob_data(btc_snapshot, 'BTC/USDT', btc_imbalance_stats, cob_mode)
return eth_components, btc_components
@@ -580,6 +687,34 @@ class CleanTradingDashboard:
return f"x{leverage_value}"
return "x50"
# Entry Aggressiveness slider callback
@self.app.callback(
Output('entry-agg-display', 'children'),
[Input('entry-aggressiveness-slider', 'value')]
)
def update_entry_aggressiveness_display(agg_value):
"""Update entry aggressiveness display and orchestrator setting"""
if agg_value is not None:
# Update orchestrator's entry aggressiveness
if self.orchestrator:
self.orchestrator.entry_aggressiveness = agg_value
return f"{agg_value:.1f}"
return "0.5"
# Exit Aggressiveness slider callback
@self.app.callback(
Output('exit-agg-display', 'children'),
[Input('exit-aggressiveness-slider', 'value')]
)
def update_exit_aggressiveness_display(agg_value):
"""Update exit aggressiveness display and orchestrator setting"""
if agg_value is not None:
# Update orchestrator's exit aggressiveness
if self.orchestrator:
self.orchestrator.exit_aggressiveness = agg_value
return f"{agg_value:.1f}"
return "0.5"
# Clear session button
@self.app.callback(
Output('clear-session-btn', 'children'),
@@ -1124,27 +1259,11 @@ class CleanTradingDashboard:
def _add_cob_rl_predictions_to_chart(self, fig: go.Figure, symbol: str, df_main: pd.DataFrame, row: int = 1):
"""Add COB_RL microstructure predictions as diamond markers"""
try:
# Get recent COB_RL predictions (simulated for now since model is FRESH)
current_time = datetime.now()
current_price = self._get_current_price(symbol) or 3500.0
# Get real COB_RL predictions from orchestrator or enhanced training system
cob_predictions = self._get_real_cob_rl_predictions(symbol)
# Generate sample COB_RL predictions for visualization
cob_predictions = []
for i in range(10): # Generate 10 sample predictions over last 5 minutes
pred_time = current_time - timedelta(minutes=i * 0.5)
price_variation = (i % 3 - 1) * 2.0 # Small price variations
# Simulate COB_RL predictions based on microstructure analysis
direction = (i % 3) # 0=DOWN, 1=SIDEWAYS, 2=UP
confidence = 0.65 + (i % 4) * 0.08 # Varying confidence
cob_predictions.append({
'timestamp': pred_time,
'direction': direction,
'confidence': confidence,
'price': current_price + price_variation,
'microstructure_signal': ['SELL_PRESSURE', 'BALANCED', 'BUY_PRESSURE'][direction]
})
if not cob_predictions:
return # No real predictions to display
# Separate predictions by direction
up_predictions = [p for p in cob_predictions if p['direction'] == 2]
@@ -1315,6 +1434,61 @@ class CleanTradingDashboard:
except Exception as e:
logger.debug(f"Error adding prediction accuracy feedback to chart: {e}")
def _get_real_cob_rl_predictions(self, symbol: str) -> List[Dict]:
"""Get real COB RL predictions from the model"""
try:
cob_predictions = []
# Get predictions from enhanced training system
if hasattr(self, 'enhanced_training_system') and self.enhanced_training_system:
if hasattr(self.enhanced_training_system, 'get_prediction_summary'):
summary = self.enhanced_training_system.get_prediction_summary(symbol)
if summary and 'cob_rl_predictions' in summary:
raw_predictions = summary['cob_rl_predictions'][-10:] # Last 10 predictions
for pred in raw_predictions:
if 'timestamp' in pred and 'direction' in pred:
cob_predictions.append({
'timestamp': pred['timestamp'],
'direction': pred['direction'],
'confidence': pred.get('confidence', 0.5),
'price': pred.get('price', self._get_current_price(symbol) or 3500.0),
'microstructure_signal': pred.get('signal', ['SELL_PRESSURE', 'BALANCED', 'BUY_PRESSURE'][pred['direction']])
})
# Fallback to orchestrator COB RL agent predictions
if not cob_predictions and self.orchestrator:
if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
agent = self.orchestrator.cob_rl_agent
# Check if agent has recent predictions stored
if hasattr(agent, 'recent_predictions'):
for pred in agent.recent_predictions[-10:]:
cob_predictions.append({
'timestamp': pred.get('timestamp', datetime.now()),
'direction': pred.get('action', 1), # 0=SELL, 1=HOLD, 2=BUY
'confidence': pred.get('confidence', 0.5),
'price': pred.get('price', self._get_current_price(symbol) or 3500.0),
'microstructure_signal': ['SELL_PRESSURE', 'BALANCED', 'BUY_PRESSURE'][pred.get('action', 1)]
})
# Alternative: Try getting predictions from RL agent (DQN can handle COB features)
elif hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
agent = self.orchestrator.rl_agent
if hasattr(agent, 'recent_predictions'):
for pred in agent.recent_predictions[-10:]:
cob_predictions.append({
'timestamp': pred.get('timestamp', datetime.now()),
'direction': pred.get('action', 1),
'confidence': pred.get('confidence', 0.5),
'price': pred.get('price', self._get_current_price(symbol) or 3500.0),
'microstructure_signal': ['SELL_PRESSURE', 'BALANCED', 'BUY_PRESSURE'][pred.get('action', 1)]
})
return cob_predictions
except Exception as e:
logger.debug(f"Error getting real COB RL predictions: {e}")
return []
def _get_recent_dqn_predictions(self, symbol: str) -> List[Dict]:
"""Get recent DQN predictions from orchestrator with sample generation"""
try:
@@ -1953,6 +2127,27 @@ class CleanTradingDashboard:
except Exception as e:
logger.warning(f"Error getting COB snapshot for {symbol}: {e}")
return None
def _get_cob_mode(self) -> str:
"""Get current COB data collection mode"""
try:
# Check if orchestrator COB integration is working
if hasattr(self.orchestrator, 'cob_integration') and self.orchestrator.cob_integration:
# Try to get a snapshot from orchestrator
snapshot = self.orchestrator.cob_integration.get_cob_snapshot('ETH/USDT')
if snapshot and hasattr(snapshot, 'consolidated_bids') and snapshot.consolidated_bids:
return "WS" # WebSocket/Advanced mode
# Check if fallback data is available
if hasattr(self, 'latest_cob_data') and 'ETH/USDT' in self.latest_cob_data:
if self.latest_cob_data['ETH/USDT']:
return "REST" # REST API fallback mode
return "None" # No data available
except Exception as e:
logger.debug(f"Error determining COB mode: {e}")
return "Error"
def _get_enhanced_training_stats(self) -> Dict[str, Any]:
"""Get enhanced training statistics from the training system and orchestrator"""
@@ -2775,6 +2970,39 @@ class CleanTradingDashboard:
except Exception as e:
logger.error(f"Error starting signal generation loop: {e}")
def _start_live_balance_sync(self):
"""Start continuous live balance synchronization for trading"""
def balance_sync_worker():
while True:
try:
if self.trading_executor:
is_live = (hasattr(self.trading_executor, 'trading_enabled') and
self.trading_executor.trading_enabled and
hasattr(self.trading_executor, 'simulation_mode') and
not self.trading_executor.simulation_mode)
if is_live and hasattr(self.trading_executor, 'exchange'):
# Force balance refresh every 15 seconds in live mode
if hasattr(self, '_last_balance_check'):
del self._last_balance_check # Force refresh
balance = self._get_live_balance()
if balance > 0:
logger.debug(f"BALANCE SYNC: Live balance: ${balance:.2f}")
else:
logger.warning("BALANCE SYNC: Could not retrieve live balance")
# Sync balance every 15 seconds for live trading
time.sleep(15)
except Exception as e:
logger.debug(f"Error in balance sync loop: {e}")
time.sleep(30) # Wait longer on error
# Start balance sync thread only if we have trading enabled
if self.trading_executor:
threading.Thread(target=balance_sync_worker, daemon=True).start()
logger.info("BALANCE SYNC: Background balance synchronization started")
def _generate_dqn_signal(self, symbol: str, current_price: float) -> Optional[Dict]:
"""Generate trading signal using DQN agent - NOT AVAILABLE IN BASIC ORCHESTRATOR"""
@@ -4418,28 +4646,35 @@ class CleanTradingDashboard:
imbalance = cob_snapshot['stats']['imbalance']
abs_imbalance = abs(imbalance)
# Dynamic threshold based on imbalance strength
# Dynamic threshold based on imbalance strength with realistic confidence
if abs_imbalance > 0.8: # Very strong imbalance (>80%)
threshold = 0.05 # 5% threshold for very strong signals
confidence_multiplier = 3.0
base_confidence = 0.85 # High but not perfect confidence
confidence_boost = (abs_imbalance - 0.8) * 0.75 # Scale remaining 15%
elif abs_imbalance > 0.5: # Strong imbalance (>50%)
threshold = 0.1 # 10% threshold for strong signals
confidence_multiplier = 2.5
base_confidence = 0.70 # Good confidence
confidence_boost = (abs_imbalance - 0.5) * 0.50 # Scale up to 85%
elif abs_imbalance > 0.3: # Moderate imbalance (>30%)
threshold = 0.15 # 15% threshold for moderate signals
confidence_multiplier = 2.0
base_confidence = 0.55 # Moderate confidence
confidence_boost = (abs_imbalance - 0.3) * 0.75 # Scale up to 70%
else: # Weak imbalance
threshold = 0.2 # 20% threshold for weak signals
confidence_multiplier = 1.5
base_confidence = 0.35 # Low confidence
confidence_boost = abs_imbalance * 0.67 # Scale up to 55%
# Generate signal if imbalance exceeds threshold
if abs_imbalance > threshold:
# Calculate more realistic confidence (never exactly 1.0)
final_confidence = min(0.95, base_confidence + confidence_boost)
signal = {
'timestamp': datetime.now(),
'type': 'cob_liquidity_imbalance',
'action': 'BUY' if imbalance > 0 else 'SELL',
'symbol': symbol,
'confidence': min(1.0, abs_imbalance * confidence_multiplier),
'confidence': final_confidence,
'strength': abs_imbalance,
'threshold_used': threshold,
'signal_strength': 'very_strong' if abs_imbalance > 0.8 else 'strong' if abs_imbalance > 0.5 else 'moderate' if abs_imbalance > 0.3 else 'weak',
@@ -5150,12 +5385,24 @@ class CleanTradingDashboard:
logger.error(f"Error updating session metrics: {e}")
def _start_actual_training_if_needed(self):
"""Connect to centralized training system in orchestrator (following architecture)"""
"""Connect to centralized training system in orchestrator and start training"""
try:
if not self.orchestrator:
logger.warning("No orchestrator available for training connection")
return
logger.info("DASHBOARD: Connected to orchestrator's centralized training system")
# Actually start the orchestrator's enhanced training system
if hasattr(self.orchestrator, 'start_enhanced_training'):
training_started = self.orchestrator.start_enhanced_training()
if training_started:
logger.info("TRAINING: Orchestrator enhanced training system started successfully")
else:
logger.warning("TRAINING: Failed to start orchestrator enhanced training system")
else:
logger.warning("TRAINING: Orchestrator does not have enhanced training system")
# Dashboard only displays training status - actual training happens in orchestrator
# Training is centralized in the orchestrator as per architecture design
except Exception as e:
@@ -5365,15 +5612,18 @@ class CleanTradingDashboard:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Get the model's device to ensure tensors are on the same device
model_device = next(model.parameters()).device
# Handle different input shapes for different CNN models
if hasattr(model, 'input_shape'):
# EnhancedCNN model
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device)
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(model_device)
else:
# Basic CNN model - reshape appropriately
features_tensor = torch.FloatTensor(features).unsqueeze(0).unsqueeze(0).to(device)
features_tensor = torch.FloatTensor(features).unsqueeze(0).unsqueeze(0).to(model_device)
target_tensor = torch.LongTensor([target]).to(device)
target_tensor = torch.LongTensor([target]).to(model_device)
# Set model to training mode and zero gradients
model.train()
@@ -5492,10 +5742,11 @@ class CleanTradingDashboard:
if hasattr(network, 'forward'):
import torch
import torch.nn as nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device)
action_target_tensor = torch.LongTensor([action_target]).to(device)
confidence_target_tensor = torch.FloatTensor([confidence_target]).to(device)
# Get the model's device to ensure tensors are on the same device
model_device = next(network.parameters()).device
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(model_device)
action_target_tensor = torch.LongTensor([action_target]).to(model_device)
confidence_target_tensor = torch.FloatTensor([confidence_target]).to(model_device)
network.train()
network_output = network(features_tensor)
@@ -5919,30 +6170,7 @@ class CleanTradingDashboard:
cob_rl_agent = self.orchestrator.cob_rl_agent
if not cob_rl_agent:
# Create a simple checkpoint to prevent recreation if no agent available
try:
from utils.checkpoint_manager import save_checkpoint
checkpoint_data = {
'model_state_dict': {},
'training_samples': len(market_data),
'cob_features_processed': True
}
performance_metrics = {
'loss': 0.356,
'training_samples': len(market_data),
'model_parameters': 0
}
metadata = save_checkpoint(
model=checkpoint_data,
model_name="cob_rl",
model_type="cob_rl",
performance_metrics=performance_metrics,
training_metadata={'cob_data_processed': True}
)
if metadata:
logger.info(f"COB RL placeholder checkpoint saved: {metadata.checkpoint_id}")
except Exception as e:
logger.error(f"Error saving COB RL placeholder checkpoint: {e}")
logger.debug("No COB RL agent available for training")
return
# Perform actual COB RL training

View File

@@ -272,13 +272,14 @@ class DashboardComponentManager:
logger.error(f"Error formatting system status: {e}")
return [html.P(f"Error: {str(e)}", className="text-danger small")]
def format_cob_data(self, cob_snapshot, symbol, cumulative_imbalance_stats=None):
def format_cob_data(self, cob_snapshot, symbol, cumulative_imbalance_stats=None, cob_mode="Unknown"):
"""Format COB data into a split view with summary, imbalance stats, and a compact ladder."""
try:
if not cob_snapshot:
return html.Div([
html.H6(f"{symbol} COB", className="mb-2"),
html.P("No COB data available", className="text-muted small")
html.P("No COB data available", className="text-muted small"),
html.P(f"Mode: {cob_mode}", className="text-muted small")
])
# Handle both old format (with stats attribute) and new format (direct attributes)
@@ -316,7 +317,7 @@ class DashboardComponentManager:
}
# --- Left Panel: Overview and Stats ---
overview_panel = self._create_cob_overview_panel(symbol, stats, cumulative_imbalance_stats)
overview_panel = self._create_cob_overview_panel(symbol, stats, cumulative_imbalance_stats, cob_mode)
# --- Right Panel: Compact Ladder ---
ladder_panel = self._create_cob_ladder_panel(bids, asks, mid_price, symbol)
@@ -330,7 +331,7 @@ class DashboardComponentManager:
logger.error(f"Error formatting split COB data: {e}")
return html.P(f"Error: {str(e)}", className="text-danger small")
def _create_cob_overview_panel(self, symbol, stats, cumulative_imbalance_stats):
def _create_cob_overview_panel(self, symbol, stats, cumulative_imbalance_stats, cob_mode="Unknown"):
"""Creates the left panel with summary and imbalance stats."""
mid_price = stats.get('mid_price', 0)
spread_bps = stats.get('spread_bps', 0)
@@ -342,6 +343,10 @@ class DashboardComponentManager:
imbalance_text = f"Bid Heavy ({imbalance:.3f})" if imbalance > 0 else f"Ask Heavy ({imbalance:.3f})"
imbalance_color = "text-success" if imbalance > 0 else "text-danger"
# COB mode indicator
mode_color = "text-success" if cob_mode == "WS" else "text-warning" if cob_mode == "REST" else "text-muted"
mode_icon = "fas fa-wifi" if cob_mode == "WS" else "fas fa-globe" if cob_mode == "REST" else "fas fa-question"
imbalance_stats_display = []
if cumulative_imbalance_stats:
imbalance_stats_display.append(html.H6("Cumulative Imbalance", className="mt-3 mb-2 small text-muted text-uppercase"))
@@ -350,6 +355,12 @@ class DashboardComponentManager:
return html.Div([
html.H6(f"{symbol} - COB Overview", className="mb-2"),
html.Div([
html.Span([
html.I(className=f"{mode_icon} me-1 {mode_color}"),
html.Span(f"Mode: {cob_mode}", className=f"small {mode_color}")
], className="mb-2")
]),
html.Div([
self._create_stat_card("Mid Price", f"${mid_price:,.2f}", "fas fa-dollar-sign"),
self._create_stat_card("Spread", f"{spread_bps:.1f} bps", "fas fa-arrows-alt-h")

View File

@@ -145,6 +145,50 @@ class DashboardLayoutManager:
)
], className="mb-2"),
# Entry Aggressiveness Control
html.Div([
html.Label([
html.I(className="fas fa-bullseye me-1"),
"Entry Aggressiveness: ",
html.Span(id="entry-agg-display", children="0.5", className="fw-bold text-success")
], className="form-label small mb-1"),
dcc.Slider(
id='entry-aggressiveness-slider',
min=0.0,
max=1.0,
step=0.1,
value=0.5,
marks={
0.0: {'label': 'Conservative', 'style': {'fontSize': '7px'}},
0.5: {'label': 'Balanced', 'style': {'fontSize': '7px'}},
1.0: {'label': 'Aggressive', 'style': {'fontSize': '7px'}}
},
tooltip={"placement": "bottom", "always_visible": False}
)
], className="mb-2"),
# Exit Aggressiveness Control
html.Div([
html.Label([
html.I(className="fas fa-sign-out-alt me-1"),
"Exit Aggressiveness: ",
html.Span(id="exit-agg-display", children="0.5", className="fw-bold text-danger")
], className="form-label small mb-1"),
dcc.Slider(
id='exit-aggressiveness-slider',
min=0.0,
max=1.0,
step=0.1,
value=0.5,
marks={
0.0: {'label': 'Conservative', 'style': {'fontSize': '7px'}},
0.5: {'label': 'Balanced', 'style': {'fontSize': '7px'}},
1.0: {'label': 'Aggressive', 'style': {'fontSize': '7px'}}
},
tooltip={"placement": "bottom", "always_visible": False}
)
], className="mb-2"),
html.Button([
html.I(className="fas fa-trash me-1"),
"Clear Session"