I think we fixed the MEXC interface at the end!
@@ -1549,38 +1549,28 @@ class TradingOrchestrator:
                     self.model_states['extrema_trainer']['current_loss'] = estimated_loss
                     self.model_states['extrema_trainer']['best_loss'] = estimated_loss

-            # Ensure initial_loss is set for new models
+            # NO LONGER SETTING SYNTHETIC INITIAL LOSS VALUES
+            # Keep all None values as None if no real data is available
+            # This prevents the "fake progress" issue where Current Loss = Initial Loss
+
+            # Only set initial_loss from actual training history if available
             for model_key, model_state in self.model_states.items():
                 if model_state['initial_loss'] is None:
-                    # Set reasonable initial loss values for new models
-                    initial_losses = {
-                        'dqn': 0.285,
-                        'cnn': 0.412,
-                        'cob_rl': 0.356,
-                        'decision': 0.298,
-                        'extrema_trainer': 0.356
-                    }
-                    model_state['initial_loss'] = initial_losses.get(model_key, 0.3)
-
-                # If current_loss is None, set it to initial_loss
-                if model_state['current_loss'] is None:
-                    model_state['current_loss'] = model_state['initial_loss']
-
-                # If best_loss is None, set it to current_loss
-                if model_state['best_loss'] is None:
-                    model_state['best_loss'] = model_state['current_loss']
+                    # Leave initial_loss as None if no real training history exists
+                    # Leave current_loss as None if model isn't actively training
+                    # Leave best_loss as None if no checkpoints exist with real performance data
+                    pass  # No synthetic data generation

             return self.model_states

         except Exception as e:
             logger.error(f"Error getting model states: {e}")
-            # Return safe fallback values
+            # Return None values instead of synthetic data
             return {
-                'dqn': {'initial_loss': 0.285, 'current_loss': 0.285, 'best_loss': 0.285, 'checkpoint_loaded': False},
-                'cnn': {'initial_loss': 0.412, 'current_loss': 0.412, 'best_loss': 0.412, 'checkpoint_loaded': False},
-                'cob_rl': {'initial_loss': 0.356, 'current_loss': 0.356, 'best_loss': 0.356, 'checkpoint_loaded': False},
-                'decision': {'initial_loss': 0.298, 'current_loss': 0.298, 'best_loss': 0.298, 'checkpoint_loaded': False},
-                'extrema_trainer': {'initial_loss': 0.356, 'current_loss': 0.356, 'best_loss': 0.356, 'checkpoint_loaded': False}
+                'dqn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
+                'cnn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
+                'cob_rl': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
+                'decision': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
+                'extrema_trainer': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False}
             }

     def update_model_loss(self, model_name: str, current_loss: float, best_loss: float = None):
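The point of this hunk is easier to see in isolation: once losses stay None until real training data exists, the display layer must treat None as "no data" instead of inventing numbers. A minimal sketch, where format_loss and has_real_progress are hypothetical helpers, not part of this codebase:

from typing import Optional

def format_loss(loss: Optional[float]) -> str:
    """Render a loss value for the dashboard; None means 'no real data yet'."""
    return f"{loss:.4f}" if loss is not None else "N/A"

def has_real_progress(initial: Optional[float], current: Optional[float]) -> bool:
    """True only when both values exist and training actually moved the loss."""
    return initial is not None and current is not None and current != initial

# A fresh model now shows N/A instead of identical synthetic numbers
print(format_loss(None))              # -> N/A
print(has_real_progress(None, None))  # -> False (no fake 0.285 == 0.285 'progress')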
@@ -59,7 +59,7 @@ class SignalAccumulator:
     confidence_sum: float = 0.0
     successful_predictions: int = 0
     total_predictions: int = 0
-    last_reset_time: datetime = None
+    last_reset_time: Optional[datetime] = None

     def __post_init__(self):
         if self.signals is None:
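The removed line annotated a None default as a plain datetime, which strict type checkers reject; Optional[datetime] states the None-able contract explicitly. A stdlib-only sketch of the same dataclass pattern, using field(default_factory=list) as an alternative to the __post_init__ None check this class uses:

from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Optional

@dataclass
class Accumulator:
    signals: List[float] = field(default_factory=list)  # safe mutable default
    last_reset_time: Optional[datetime] = None          # explicitly None-able field

acc = Accumulator()
print(acc.signals, acc.last_reset_time)  # [] None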
@@ -99,12 +99,13 @@ class RealtimeRLCOBTrader:
     """

     def __init__(self,
-                 symbols: List[str] = None,
-                 trading_executor: TradingExecutor = None,
+                 symbols: Optional[List[str]] = None,
+                 trading_executor: Optional[TradingExecutor] = None,
                  model_checkpoint_dir: str = "models/realtime_rl_cob",
                  inference_interval_ms: int = 200,
                  min_confidence_threshold: float = 0.35,  # Lowered from 0.7 for more aggressive trading
-                 required_confident_predictions: int = 3):
+                 required_confident_predictions: int = 3,
+                 checkpoint_manager: Any = None):

         self.symbols = symbols or ['BTC/USDT', 'ETH/USDT']
         self.trading_executor = trading_executor
@@ -113,6 +114,16 @@ class RealtimeRLCOBTrader:
         self.min_confidence_threshold = min_confidence_threshold
         self.required_confident_predictions = required_confident_predictions

+        # Initialize CheckpointManager (either provided or get global instance)
+        if checkpoint_manager is None:
+            from utils.checkpoint_manager import get_checkpoint_manager
+            self.checkpoint_manager = get_checkpoint_manager()
+        else:
+            self.checkpoint_manager = checkpoint_manager
+
+        # Track start time for training duration calculation
+        self.start_time = datetime.now()  # Initialize start_time
+
         # Setup device
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         logger.info(f"Using device: {self.device}")
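The constructor now uses dependency injection with a global fallback: tests can pass their own checkpoint_manager, while production code shares one instance. A sketch of the assumed shape of get_checkpoint_manager (the real one lives in utils/checkpoint_manager.py; CheckpointManagerStub is hypothetical):

from typing import Optional

class CheckpointManagerStub:
    def __init__(self, checkpoint_dir: str = "models/checkpoints"):
        self.checkpoint_dir = checkpoint_dir

_GLOBAL_MANAGER: Optional[CheckpointManagerStub] = None

def get_checkpoint_manager() -> CheckpointManagerStub:
    """Create the shared manager on first use, then reuse it."""
    global _GLOBAL_MANAGER
    if _GLOBAL_MANAGER is None:
        _GLOBAL_MANAGER = CheckpointManagerStub()
    return _GLOBAL_MANAGER

# Callers can still inject their own instance (e.g. in tests):
manager = get_checkpoint_manager()
assert manager is get_checkpoint_manager()  # same singleton every time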
@@ -819,29 +830,26 @@ class RealtimeRLCOBTrader:
                     actual_direction = 1  # SIDEWAYS

                 # Calculate reward based on prediction accuracy
-                reward = self._calculate_prediction_reward(
-                    prediction.predicted_direction,
-                    actual_direction,
-                    prediction.confidence,
-                    prediction.predicted_change,
-                    actual_change
+                prediction.reward = self._calculate_prediction_reward(
+                    symbol=symbol,
+                    predicted_direction=prediction.predicted_direction,
+                    actual_direction=actual_direction,
+                    confidence=prediction.confidence,
+                    predicted_change=prediction.predicted_change,
+                    actual_change=actual_change
                 )

                 # Update prediction
                 prediction.actual_direction = actual_direction
                 prediction.actual_change = actual_change
-                prediction.reward = reward

                 # Update training stats
                 stats = self.training_stats[symbol]
                 stats['total_predictions'] += 1
-                if reward > 0:
+                if prediction.reward > 0:
                     stats['successful_predictions'] += 1

         except Exception as e:
             logger.error(f"Error calculating rewards for {symbol}: {e}")

     def _calculate_prediction_reward(self,
+                                     symbol: str,
                                      predicted_direction: int,
                                      actual_direction: int,
                                      confidence: float,
@@ -849,67 +857,52 @@ class RealtimeRLCOBTrader:
                                      actual_change: float,
                                      current_pnl: float = 0.0,
                                      position_duration: float = 0.0) -> float:
-        """Calculate reward for a prediction with PnL-aware loss cutting optimization"""
-        try:
-            # Base reward for correct direction
-            if predicted_direction == actual_direction:
-                base_reward = 1.0
+        """Calculate reward based on prediction accuracy and actual price movement"""
+        reward = 0.0
+
+        # Base reward for correct direction prediction
+        if predicted_direction == actual_direction:
+            reward += 1.0 * confidence  # Reward scales with confidence
+        else:
+            reward -= 0.5  # Penalize incorrect predictions
+
+        # Reward for predicting large changes correctly (proportional to actual change)
+        if predicted_direction == actual_direction and abs(predicted_change) > 0.001:
+            reward += abs(actual_change) * 5.0  # Amplify reward for significant moves
+
+        # Penalize for large predicted changes that are wrong
+        if predicted_direction != actual_direction and abs(predicted_change) > 0.001:
+            reward -= abs(predicted_change) * 2.0
+
+        # Add reward for PnL (realized or unrealized)
+        reward += current_pnl * 0.1  # Small reward for PnL, adjusted by a factor
+
+        # Dynamic adjustment based on recent PnL (loss cutting incentive)
+        if self.pnl_history[symbol]:
+            latest_pnl_entry = self.pnl_history[symbol][-1]  # Get the latest PnL entry
+            # Ensure latest_pnl_entry is a dict and has 'pnl' key, otherwise default to 0.0
+            latest_pnl_value = latest_pnl_entry.get('pnl', 0.0) if isinstance(latest_pnl_entry, dict) else 0.0
+
+            # Incentivize closing losing trades early
+            if latest_pnl_value < 0 and position_duration > 60:  # If losing position open for > 60s
+                # More aggressively penalize holding losing positions, or reward closing them
+                reward -= (abs(latest_pnl_value) * 0.2)  # Increased penalty for sustained losses
+
+            # Discourage taking new positions if overall PnL is negative or volatile
+            # This requires a more complex calculation of overall PnL, potentially average of last N trades
+            # For simplicity, let's use the 'best_pnl' to decide if we are in a good state to trade
+
+            # Calculate the current best PnL from history, ensuring it's not empty
+            pnl_values = [entry.get('pnl', 0.0) for entry in self.pnl_history[symbol] if isinstance(entry, dict)]
+            if not pnl_values:
+                best_pnl = 0.0
             else:
-                base_reward = -1.0
-
-            # Scale by confidence
-            confidence_scaled_reward = base_reward * confidence
-
-            # Additional reward for magnitude accuracy
-            if predicted_direction != 1:  # Not sideways
-                magnitude_accuracy = 1.0 - abs(predicted_change - actual_change) / max(abs(actual_change), 0.001)
-                magnitude_accuracy = max(0.0, magnitude_accuracy)
-                confidence_scaled_reward += magnitude_accuracy * 0.5
-
-            # Penalty for overconfident wrong predictions
-            if base_reward < 0 and confidence > 0.8:
-                confidence_scaled_reward *= 1.5  # Increase penalty
-
-            # === PnL-AWARE LOSS CUTTING REWARDS ===
-
-            pnl_reward = 0.0
-
-            # Reward cutting losses early (SIDEWAYS when losing)
-            if current_pnl < -10.0:  # In significant loss
-                if predicted_direction == 1:  # SIDEWAYS (exit signal)
-                    # Reward cutting losses before they get worse
-                    loss_cutting_bonus = min(1.0, abs(current_pnl) / 100.0) * confidence
-                    pnl_reward += loss_cutting_bonus
-                elif predicted_direction != 1:  # Continuing to trade while in loss
-                    # Penalty for not cutting losses
-                    pnl_reward -= 0.5 * confidence
-
-            # Reward protecting profits (SIDEWAYS when in profit and market turning)
-            elif current_pnl > 10.0:  # In profit
-                if predicted_direction == 1 and base_reward > 0:  # Correct SIDEWAYS prediction
-                    # Reward protecting profits from reversal
-                    profit_protection_bonus = min(0.5, current_pnl / 200.0) * confidence
-                    pnl_reward += profit_protection_bonus
-
-            # Duration penalty for holding losing positions
-            if current_pnl < 0 and position_duration > 3600:  # Losing for > 1 hour
-                duration_penalty = min(1.0, position_duration / 7200.0) * 0.3  # Up to 30% penalty
-                confidence_scaled_reward -= duration_penalty
-
-            # Severe penalty for letting small losses become big losses
-            if current_pnl < -50.0:  # Large loss
-                drawdown_penalty = min(2.0, abs(current_pnl) / 100.0) * confidence
-                confidence_scaled_reward -= drawdown_penalty
-
-            # Total reward
-            total_reward = confidence_scaled_reward + pnl_reward
-
-            # Clamp final reward
-            return max(-5.0, min(5.0, float(total_reward)))
-
-        except Exception as e:
-            logger.error(f"Error calculating reward: {e}")
-            return 0.0
+                best_pnl = max(pnl_values)
+
+            if best_pnl < 0.0:  # If recent best PnL is negative, reduce reward for new trades
+                reward -= 0.1  # Small penalty for trading in a losing streak
+
+        return reward

     async def _train_batch(self, symbol: str, predictions: List[PredictionResult]) -> float:
        """Train model on a batch of predictions"""
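The replacement reward is additive: a direction term scaled by confidence, magnitude bonuses and penalties, and PnL terms (note the old version clamped the result to [-5, 5]; the new one does not). A standalone sketch of just the direction/magnitude/PnL terms, omitting the pnl_history adjustment; prediction_reward is a hypothetical free-function rename for illustration:

def prediction_reward(predicted_direction: int, actual_direction: int,
                      confidence: float, predicted_change: float,
                      actual_change: float, current_pnl: float = 0.0) -> float:
    reward = 0.0
    if predicted_direction == actual_direction:
        reward += 1.0 * confidence          # correct call, scaled by confidence
    else:
        reward -= 0.5                       # flat penalty for a wrong call
    if predicted_direction == actual_direction and abs(predicted_change) > 0.001:
        reward += abs(actual_change) * 5.0  # amplify reward on significant moves
    if predicted_direction != actual_direction and abs(predicted_change) > 0.001:
        reward -= abs(predicted_change) * 2.0
    reward += current_pnl * 0.1             # small PnL term
    return reward

# Correct, confident call on a 1% move: 1.0*0.9 + 0.01*5.0 = 0.95
print(round(prediction_reward(2, 2, 0.9, 0.008, 0.01), 4))
# Wrong call with a large predicted move: -0.5 - 0.008*2.0 = -0.516
print(round(prediction_reward(2, 0, 0.9, 0.008, -0.01), 4))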
@@ -1021,20 +1014,36 @@ class RealtimeRLCOBTrader:
                 await asyncio.sleep(60)

     def _save_models(self):
-        """Save all models to disk"""
+        """Save all models to disk using CheckpointManager"""
         try:
             for symbol in self.symbols:
-                symbol_safe = symbol.replace('/', '_')
-                model_path = os.path.join(self.model_checkpoint_dir, f"{symbol_safe}_model.pt")
+                model_name = f"cob_rl_{symbol.replace('/', '_').lower()}"  # Standardize model name for CheckpointManager

-                # Save model state
-                torch.save({
-                    'model_state_dict': self.models[symbol].state_dict(),
-                    'optimizer_state_dict': self.optimizers[symbol].state_dict(),
-                    'training_stats': self.training_stats[symbol],
-                    'inference_stats': self.inference_stats[symbol],
-                    'timestamp': datetime.now().isoformat()
-                }, model_path)
+                # Prepare performance metrics for CheckpointManager
+                performance_metrics = {
+                    'loss': self.training_stats[symbol].get('average_loss', 0.0),
+                    'reward': self.training_stats[symbol].get('average_reward', 0.0),  # Assuming average_reward is tracked
+                    'accuracy': self.training_stats[symbol].get('average_accuracy', 0.0),  # Assuming average_accuracy is tracked
+                }
+                if self.trading_executor:  # Add check for trading_executor
+                    daily_stats = self.trading_executor.get_daily_stats()
+                    performance_metrics['pnl'] = daily_stats.get('total_pnl', 0.0)  # Example, get actual pnl
+                performance_metrics['training_samples'] = self.training_stats[symbol].get('total_training_steps', 0)
+
+                # Prepare training metadata for CheckpointManager
+                training_metadata = {
+                    'total_parameters': sum(p.numel() for p in self.models[symbol].parameters()),
+                    'epoch': self.training_stats[symbol].get('total_training_steps', 0),  # Using total_training_steps as pseudo-epoch
+                    'training_time_hours': (datetime.now() - self.start_time).total_seconds() / 3600
+                }
+
+                self.checkpoint_manager.save_checkpoint(
+                    model=self.models[symbol],
+                    model_name=model_name,
+                    model_type='COB_RL',  # Specify model type
+                    performance_metrics=performance_metrics,
+                    training_metadata=training_metadata
+                )

                 logger.debug(f"Saved model for {symbol}")
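Two of the metadata values above are computed rather than read from stats. A tiny sketch of both, with nn.Linear standing in for the COB RL model (an assumption for illustration only):

import torch.nn as nn
from datetime import datetime, timedelta

model = nn.Linear(10, 3)  # stand-in model
total_parameters = sum(p.numel() for p in model.parameters())
print(total_parameters)  # 10*3 weights + 3 biases = 33

start_time = datetime.now() - timedelta(hours=2)  # pretend training began 2h ago
training_time_hours = (datetime.now() - start_time).total_seconds() / 3600
print(round(training_time_hours, 2))  # ~2.0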
@@ -1042,13 +1051,15 @@ class RealtimeRLCOBTrader:
             logger.error(f"Error saving models: {e}")

     def _load_models(self):
-        """Load existing models from disk"""
+        """Load existing models from disk using CheckpointManager"""
         try:
             for symbol in self.symbols:
-                symbol_safe = symbol.replace('/', '_')
-                model_path = os.path.join(self.model_checkpoint_dir, f"{symbol_safe}_model.pt")
+                model_name = f"cob_rl_{symbol.replace('/', '_').lower()}"  # Standardize model name for CheckpointManager

-                if os.path.exists(model_path):
+                loaded_checkpoint = self.checkpoint_manager.load_best_checkpoint(model_name)
+
+                if loaded_checkpoint:
+                    model_path, metadata = loaded_checkpoint
                     checkpoint = torch.load(model_path, map_location=self.device)

                     self.models[symbol].load_state_dict(checkpoint['model_state_dict'])
@@ -1059,9 +1070,9 @@ class RealtimeRLCOBTrader:
                     if 'inference_stats' in checkpoint:
                         self.inference_stats[symbol].update(checkpoint['inference_stats'])

-                    logger.info(f"Loaded existing model for {symbol}")
+                    logger.info(f"Loaded existing model for {symbol} from checkpoint: {metadata.checkpoint_id}")
                 else:
-                    logger.info(f"No existing model found for {symbol}, starting fresh")
+                    logger.info(f"No existing model found for {symbol} via CheckpointManager, starting fresh.")

         except Exception as e:
             logger.error(f"Error loading models: {e}")
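Both _save_models and _load_models rely on a particular CheckpointManager surface. A sketch of that interface as inferred from the call sites (save_checkpoint kwargs, load_best_checkpoint returning a (path, metadata) tuple whose metadata carries checkpoint_id); the real definitions live in utils/checkpoint_manager.py, so every name below is an assumption:

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

@dataclass
class CheckpointMetadata:
    checkpoint_id: str
    performance_metrics: Dict[str, float]
    training_metadata: Dict[str, Any]

class CheckpointManagerSketch:
    def save_checkpoint(self, model: Any, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Dict[str, Any]) -> CheckpointMetadata:
        """Persist model state and rank it by performance; return its metadata."""
        raise NotImplementedError

    def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
        """Return (checkpoint_path, metadata) for the best-ranked checkpoint, or None."""
        raise NotImplementedError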
@@ -1111,7 +1122,7 @@ async def main():
     from ..core.trading_executor import TradingExecutor

     # Initialize trading executor (simulation mode)
-    trading_executor = TradingExecutor(simulation_mode=True)
+    trading_executor = TradingExecutor()

     # Initialize real-time RL trader
     trader = RealtimeRLCOBTrader(
@@ -93,7 +93,6 @@ class TradingExecutor:
                 api_key=api_key,
                 api_secret=api_secret,
                 test_mode=exchange_test_mode,
                 trading_mode=trading_mode
             )

         # Trading state
@@ -213,9 +212,15 @@ class TradingExecutor:
             # Determine the quote asset (e.g., USDT, USDC) from the symbol
             if '/' in symbol:
                 quote_asset = symbol.split('/')[1].upper()  # Assuming symbol is like ETH/USDT
+                # Convert USDT to USDC for MEXC spot trading
+                if quote_asset == 'USDT':
+                    quote_asset = 'USDC'
             else:
                 # Fallback for symbols like ETHUSDT (assuming last 4 chars are quote)
-                quote_asset = symbol[-4:].upper()
+                quote_asset = symbol[-4:].upper()
+                # Convert USDT to USDC for MEXC spot trading
+                if quote_asset == 'USDT':
+                    quote_asset = 'USDC'

             # Calculate required capital for the trade
             # If we are selling (to open a short position), we need collateral based on the position size
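Both branches now end with the same USDT-to-USDC conversion; factoring the logic into one helper keeps the two paths from drifting. A sketch (get_quote_asset is a hypothetical refactoring, not part of this diff):

def get_quote_asset(symbol: str) -> str:
    """Extract the quote asset from 'ETH/USDT' or 'ETHUSDT' style symbols."""
    if '/' in symbol:
        quote = symbol.split('/')[1].upper()
    else:
        quote = symbol[-4:].upper()  # fallback: assume last 4 chars are the quote
    return 'USDC' if quote == 'USDT' else quote  # MEXC spot trades against USDC

assert get_quote_asset('ETH/USDT') == 'USDC'
assert get_quote_asset('ETHUSDT') == 'USDC'
assert get_quote_asset('BTC/USDC') == 'USDC'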
@@ -779,13 +784,14 @@ class TradingExecutor:
         logger.info("Daily trading statistics reset")

     def get_account_balance(self) -> Dict[str, Dict[str, float]]:
-        """Get account balance information from MEXC
+        """Get account balance information from MEXC, including spot and futures.

         Returns:
             Dict with asset balances in format:
             {
-                'USDT': {'free': 100.0, 'locked': 0.0},
-                'ETH': {'free': 0.5, 'locked': 0.0},
+                'USDT': {'free': 100.0, 'locked': 0.0, 'total': 100.0, 'type': 'spot'},
+                'ETH': {'free': 0.5, 'locked': 0.0, 'total': 0.5, 'type': 'spot'},
+                'FUTURES_USDT': {'free': 500.0, 'locked': 50.0, 'total': 550.0, 'type': 'futures'}
                 ...
             }
         """
@@ -794,28 +800,47 @@ class TradingExecutor:
                 logger.error("Exchange interface not available")
                 return {}

-            # Get account info from MEXC
-            account_info = self.exchange.get_account_info()
-            if not account_info:
-                logger.error("Failed to get account info from MEXC")
-                return {}
+            combined_balances = {}

-            balances = {}
-            for balance in account_info.get('balances', []):
-                asset = balance.get('asset', '')
-                free = float(balance.get('free', 0))
-                locked = float(balance.get('locked', 0))
-
-                # Only include assets with non-zero balance
-                if free > 0 or locked > 0:
-                    balances[asset] = {
-                        'free': free,
-                        'locked': locked,
-                        'total': free + locked
-                    }
-
-            logger.info(f"Retrieved balances for {len(balances)} assets")
-            return balances
+            # 1. Get Spot Account Info
+            spot_account_info = self.exchange.get_account_info()
+            if spot_account_info and 'balances' in spot_account_info:
+                for balance in spot_account_info['balances']:
+                    asset = balance.get('asset', '')
+                    free = float(balance.get('free', 0))
+                    locked = float(balance.get('locked', 0))
+                    if free > 0 or locked > 0:
+                        combined_balances[asset] = {
+                            'free': free,
+                            'locked': locked,
+                            'total': free + locked,
+                            'type': 'spot'
+                        }
+            else:
+                logger.warning("Failed to get spot account info from MEXC or no balances found.")
+
+            # 2. Get Futures Account Info (commented out until futures API is implemented)
+            # futures_account_info = self.exchange.get_futures_account_info()
+            # if futures_account_info:
+            #     for currency, asset_data in futures_account_info.items():
+            #         # MEXC Futures API returns 'availableBalance' and 'frozenBalance'
+            #         free = float(asset_data.get('availableBalance', 0))
+            #         locked = float(asset_data.get('frozenBalance', 0))
+            #         total = free + locked  # total is the sum of available and frozen
+            #         if free > 0 or locked > 0:
+            #             # Prefix with 'FUTURES_' to distinguish from spot, or decide on a unified key
+            #             # For now, let's keep them distinct for clarity
+            #             combined_balances[f'FUTURES_{currency}'] = {
+            #                 'free': free,
+            #                 'locked': locked,
+            #                 'total': total,
+            #                 'type': 'futures'
+            #             }
+            # else:
+            #     logger.warning("Failed to get futures account info from MEXC or no futures assets found.")
+
+            logger.info(f"Retrieved combined balances for {len(combined_balances)} assets.")
+            return combined_balances

         except Exception as e:
             logger.error(f"Error getting account balance: {e}")
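The new 'type' key lets callers separate spot and futures balances without parsing the FUTURES_ prefix. A small sketch of consuming the combined shape (totals_by_account_type is a hypothetical helper):

from collections import defaultdict
from typing import Dict

def totals_by_account_type(balances: Dict[str, dict]) -> Dict[str, float]:
    """Sum the 'total' field per account type ('spot', 'futures')."""
    sums: Dict[str, float] = defaultdict(float)
    for info in balances.values():
        sums[info.get('type', 'spot')] += info.get('total', 0.0)
    return dict(sums)

example = {
    'USDT': {'free': 100.0, 'locked': 0.0, 'total': 100.0, 'type': 'spot'},
    'FUTURES_USDT': {'free': 500.0, 'locked': 50.0, 'total': 550.0, 'type': 'futures'},
}
print(totals_by_account_type(example))  # {'spot': 100.0, 'futures': 550.0}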