more training
@@ -137,9 +137,16 @@ class DataProvider:
                     logger.info(f"Using cached data for {symbol} {timeframe}")
                     return cached_data.tail(limit)
 
-            # Fetch from API
-            logger.info(f"Fetching historical data for {symbol} {timeframe}")
-            df = self._fetch_from_binance(symbol, timeframe, limit)
+            # Check if we need to preload 300s of data for first load
+            should_preload = self._should_preload_data(symbol, timeframe, limit)
+
+            if should_preload:
+                logger.info(f"Preloading 300s of data for {symbol} {timeframe}")
+                df = self._preload_300s_data(symbol, timeframe)
+            else:
+                # Fetch from API with requested limit
+                logger.info(f"Fetching historical data for {symbol} {timeframe}")
+                df = self._fetch_from_binance(symbol, timeframe, limit)
 
             if df is not None and not df.empty:
                 # Add technical indicators
@@ -154,7 +161,8 @@ class DataProvider:
                     self.historical_data[symbol] = {}
                 self.historical_data[symbol][timeframe] = df
 
-                return df
+                # Return requested amount
+                return df.tail(limit)
 
             logger.warning(f"No data received for {symbol} {timeframe}")
             return None
@@ -163,6 +171,124 @@ class DataProvider:
             logger.error(f"Error fetching historical data for {symbol} {timeframe}: {e}")
             return None
 
+    def _should_preload_data(self, symbol: str, timeframe: str, limit: int) -> bool:
+        """Determine if we should preload 300s of data"""
+        try:
+            # Check if we have any cached data
+            if self.cache_enabled:
+                cached_data = self._load_from_cache(symbol, timeframe)
+                if cached_data is not None and len(cached_data) > 0:
+                    return False  # Already have some data
+
+            # Check if we have data in memory
+            if (symbol in self.historical_data and
+                timeframe in self.historical_data[symbol] and
+                len(self.historical_data[symbol][timeframe]) > 0):
+                return False  # Already have data in memory
+
+            # Calculate if 300s worth of data would be more than requested limit
+            timeframe_seconds = self.timeframe_seconds.get(timeframe, 60)
+            candles_in_300s = 300 // timeframe_seconds
+
+            # Preload if we need more than the requested limit or if it's a short timeframe
+            if candles_in_300s > limit or timeframe in ['1s', '1m']:
+                return True
+
+            return False
+
+        except Exception as e:
+            logger.error(f"Error determining if should preload data: {e}")
+            return False
+
+    def _preload_300s_data(self, symbol: str, timeframe: str) -> Optional[pd.DataFrame]:
+        """Preload 300 seconds worth of data for better initial performance"""
+        try:
+            # Calculate how many candles we need for 300 seconds
+            timeframe_seconds = self.timeframe_seconds.get(timeframe, 60)
+            candles_needed = max(300 // timeframe_seconds, 100)  # At least 100 candles
+
+            # For very short timeframes, limit to reasonable amount
+            if timeframe == '1s':
+                candles_needed = min(candles_needed, 300)  # Max 300 1s candles
+            elif timeframe == '1m':
+                candles_needed = min(candles_needed, 60)  # Max 60 1m candles (1 hour)
+            else:
+                candles_needed = min(candles_needed, 500)  # Max 500 candles for other timeframes
+
+            logger.info(f"Preloading {candles_needed} candles for {symbol} {timeframe} (300s worth)")
+
+            # Fetch the data
+            df = self._fetch_from_binance(symbol, timeframe, candles_needed)
+
+            if df is not None and not df.empty:
+                logger.info(f"Successfully preloaded {len(df)} candles for {symbol} {timeframe}")
+                return df
+            else:
+                logger.warning(f"Failed to preload data for {symbol} {timeframe}")
+                return None
+
+        except Exception as e:
+            logger.error(f"Error preloading 300s data for {symbol} {timeframe}: {e}")
+            return None
+
+    def preload_all_symbols_data(self, timeframes: List[str] = None) -> Dict[str, Dict[str, bool]]:
+        """Preload 300s of data for all symbols and timeframes"""
+        try:
+            if timeframes is None:
+                timeframes = self.timeframes
+
+            preload_results = {}
+
+            for symbol in self.symbols:
+                preload_results[symbol] = {}
+
+                for timeframe in timeframes:
+                    try:
+                        logger.info(f"Preloading data for {symbol} {timeframe}")
+
+                        # Check if we should preload
+                        if self._should_preload_data(symbol, timeframe, 100):
+                            df = self._preload_300s_data(symbol, timeframe)
+
+                            if df is not None and not df.empty:
+                                # Add technical indicators
+                                df = self._add_technical_indicators(df)
+
+                                # Cache the data
+                                if self.cache_enabled:
+                                    self._save_to_cache(df, symbol, timeframe)
+
+                                # Store in memory
+                                if symbol not in self.historical_data:
+                                    self.historical_data[symbol] = {}
+                                self.historical_data[symbol][timeframe] = df
+
+                                preload_results[symbol][timeframe] = True
+                                logger.info(f"✅ Preloaded {len(df)} candles for {symbol} {timeframe}")
+                            else:
+                                preload_results[symbol][timeframe] = False
+                                logger.warning(f"❌ Failed to preload {symbol} {timeframe}")
+                        else:
+                            preload_results[symbol][timeframe] = True  # Already have data
+                            logger.info(f"⏭️ Skipped preloading {symbol} {timeframe} (already have data)")
+
+                    except Exception as e:
+                        logger.error(f"Error preloading {symbol} {timeframe}: {e}")
+                        preload_results[symbol][timeframe] = False
+
+            # Log summary
+            total_pairs = len(self.symbols) * len(timeframes)
+            successful_pairs = sum(1 for symbol_results in preload_results.values()
+                                   for success in symbol_results.values() if success)
+
+            logger.info(f"Preloading completed: {successful_pairs}/{total_pairs} symbol-timeframe pairs loaded")
+
+            return preload_results
+
+        except Exception as e:
+            logger.error(f"Error in preload_all_symbols_data: {e}")
+            return {}
+
+
     def _fetch_from_binance(self, symbol: str, timeframe: str, limit: int) -> Optional[pd.DataFrame]:
         """Fetch data from Binance API"""
         try:
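Note: the preload sizing above reduces to a small pure function. A minimal standalone sketch (illustrative names only; `TIMEFRAME_SECONDS` stands in for the provider's `timeframe_seconds` mapping) showing the candle counts that result:

```python
# Sketch of _preload_300s_data's sizing rules; not an import from the repo.
TIMEFRAME_SECONDS = {'1s': 1, '1m': 60, '5m': 300, '1h': 3600}

def preload_candle_count(timeframe: str) -> int:
    seconds = TIMEFRAME_SECONDS.get(timeframe, 60)
    needed = max(300 // seconds, 100)    # at least 100 candles
    if timeframe == '1s':
        return min(needed, 300)          # cap: 300 one-second candles
    if timeframe == '1m':
        return min(needed, 60)           # cap: 60 one-minute candles (1 hour)
    return min(needed, 500)              # cap for all other timeframes

for tf in ('1s', '1m', '5m', '1h'):
    print(tf, preload_candle_count(tf))  # 300, 60, 100, 100
```

Note that the 100-candle floor and per-timeframe caps dominate the stated 300s target: only '1s' fetches roughly 300 seconds of history, '1m' fetches a full hour, and longer timeframes fetch at least 100 candles.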
@@ -117,6 +117,25 @@ class EnhancedTradingOrchestrator:
         self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.25)  # Much lower for closing
         self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
 
+        # DQN RL-based sensitivity learning parameters
+        self.sensitivity_learning_enabled = True
+        self.sensitivity_dqn_agent = None  # Will be initialized when first DQN model is available
+        self.sensitivity_state_size = 20  # Features for sensitivity learning
+        self.sensitivity_action_space = 5  # 5 sensitivity levels: very_low, low, medium, high, very_high
+        self.current_sensitivity_level = 2  # Start with medium (index 2)
+        self.sensitivity_levels = {
+            0: {'name': 'very_low', 'close_threshold_multiplier': 0.5, 'open_threshold_multiplier': 1.2},
+            1: {'name': 'low', 'close_threshold_multiplier': 0.7, 'open_threshold_multiplier': 1.1},
+            2: {'name': 'medium', 'close_threshold_multiplier': 1.0, 'open_threshold_multiplier': 1.0},
+            3: {'name': 'high', 'close_threshold_multiplier': 1.3, 'open_threshold_multiplier': 0.9},
+            4: {'name': 'very_high', 'close_threshold_multiplier': 1.5, 'open_threshold_multiplier': 0.8}
+        }
+
+        # Trade tracking for sensitivity learning
+        self.active_trades = {}  # symbol -> trade_info with entry details
+        self.completed_trades = deque(maxlen=1000)  # Store last 1000 completed trades for learning
+        self.sensitivity_learning_queue = deque(maxlen=500)  # Queue for DQN training
+
         # Enhanced weighting system
         self.timeframe_weights = self._initialize_timeframe_weights()
         self.symbol_correlation_matrix = self._initialize_correlation_matrix()
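Note: combined with the clipping in `_update_thresholds_from_sensitivity` (later in this commit), these multipliers produce the following effective thresholds, assuming the config defaults of 0.6 (open) and 0.25 (close). A quick sketch:

```python
import numpy as np

# Multipliers copied from the commit; base thresholds are the config defaults.
LEVELS = {
    'very_low':  (1.2, 0.5),
    'low':       (1.1, 0.7),
    'medium':    (1.0, 1.0),
    'high':      (0.9, 1.3),
    'very_high': (0.8, 1.5),
}

for name, (open_mult, close_mult) in LEVELS.items():
    open_t = float(np.clip(0.6 * open_mult, 0.3, 0.9))
    close_t = float(np.clip(0.25 * close_mult, 0.1, 0.6))
    print(f"{name:>9}: open={open_t:.3f} close={close_t:.3f}")
# very_low: open=0.720 close=0.125 ... very_high: open=0.480 close=0.375
```

Under these multipliers, higher sensitivity levels lower the open threshold and raise the close threshold; the DQN's job is to pick among these five configurations.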
@@ -172,6 +191,7 @@ class EnhancedTradingOrchestrator:
         logger.info("Real-time tick processor integrated for ultra-low latency processing")
         logger.info("Raw tick and OHLCV bar processing enabled for pattern detection")
         logger.info("Enhanced retrospective learning enabled for perfect opportunity detection")
+        logger.info("DQN RL-based sensitivity learning enabled for adaptive thresholds")
 
     def _initialize_timeframe_weights(self) -> Dict[str, float]:
         """Initialize weights for different timeframes"""
@@ -640,8 +660,10 @@ class EnhancedTradingOrchestrator:
         if action.action == 'BUY':
             # Close any short position, open long position
             if symbol in self.open_positions and self.open_positions[symbol]['side'] == 'SHORT':
+                self._close_trade_for_sensitivity_learning(symbol, action)
                 del self.open_positions[symbol]
             else:
+                self._open_trade_for_sensitivity_learning(symbol, action)
                 self.open_positions[symbol] = {
                     'side': 'LONG',
                     'entry_price': action.price,
@@ -650,474 +672,464 @@ class EnhancedTradingOrchestrator:
         elif action.action == 'SELL':
             # Close any long position, open short position
             if symbol in self.open_positions and self.open_positions[symbol]['side'] == 'LONG':
+                self._close_trade_for_sensitivity_learning(symbol, action)
                 del self.open_positions[symbol]
             else:
+                self._open_trade_for_sensitivity_learning(symbol, action)
                 self.open_positions[symbol] = {
                     'side': 'SHORT',
                     'entry_price': action.price,
                     'timestamp': action.timestamp
                 }
 
-    async def trigger_retrospective_learning(self):
-        """Trigger retrospective learning analysis on recent perfect opportunities"""
-        try:
-            current_time = datetime.now()
-
-            # Only run retrospective analysis every 5 minutes to avoid overload
-            if (current_time - self.last_retrospective_analysis).total_seconds() < 300:
-                return
-
-            self.last_retrospective_analysis = current_time
-
-            # Analyze recent market moves for missed opportunities
-            await self._analyze_missed_opportunities()
-
-            # Update model confidence thresholds based on recent performance
-            self._adjust_confidence_thresholds()
-
-            # Mark retrospective learning as active
-            self.retrospective_learning_active = True
-
-            logger.info("Retrospective learning analysis completed")
-
-        except Exception as e:
-            logger.error(f"Error in retrospective learning: {e}")
-
-    async def _analyze_missed_opportunities(self):
-        """Analyze recent price movements to identify missed perfect opportunities"""
-        try:
-            for symbol in self.symbols:
-                # Get recent price data
-                recent_data = self.data_provider.get_latest_candles(symbol, '1m', limit=60)
-
-                if recent_data is None or len(recent_data) < 10:
-                    continue
-
-                # Look for significant price movements (>1% in 5 minutes)
-                for i in range(5, len(recent_data)):
-                    price_change = (recent_data.iloc[i]['close'] - recent_data.iloc[i-5]['close']) / recent_data.iloc[i-5]['close']
-
-                    if abs(price_change) > 0.01:  # 1% move
-                        # This was a perfect opportunity
-                        optimal_action = 'BUY' if price_change > 0 else 'SELL'
-
-                        # Create perfect move for retrospective learning
-                        perfect_move = PerfectMove(
-                            symbol=symbol,
-                            timeframe='1m',
-                            timestamp=recent_data.iloc[i-5]['timestamp'],
-                            optimal_action=optimal_action,
-                            actual_outcome=price_change,
-                            market_state_before=None,  # Would need to reconstruct
-                            market_state_after=None,  # Would need to reconstruct
-                            confidence_should_have_been=min(0.95, abs(price_change) * 20)  # Higher confidence for bigger moves
-                        )
-
-                        self.perfect_moves.append(perfect_move)
-
-                        logger.info(f"Retrospective perfect opportunity identified: {optimal_action} {symbol} ({price_change*100:+.2f}%)")
-
-        except Exception as e:
-            logger.error(f"Error analyzing missed opportunities: {e}")
-
-    def _adjust_confidence_thresholds(self):
-        """Dynamically adjust confidence thresholds based on recent performance"""
-        try:
-            if len(self.perfect_moves) < 10:
-                return
-
-            # Analyze recent perfect moves
-            recent_moves = list(self.perfect_moves)[-20:]
-            avg_confidence_needed = np.mean([move.confidence_should_have_been for move in recent_moves])
-            avg_outcome = np.mean([abs(move.actual_outcome) for move in recent_moves])
-
-            # Adjust opening threshold based on missed opportunities
-            if avg_confidence_needed > self.confidence_threshold_open:
-                adjustment = min(0.1, (avg_confidence_needed - self.confidence_threshold_open) * 0.1)
-                self.confidence_threshold_open = max(0.3, self.confidence_threshold_open - adjustment)
-                logger.info(f"Adjusted opening confidence threshold to {self.confidence_threshold_open:.3f}")
-
-            # Keep closing threshold very low for sensitivity
-            if avg_outcome > 0.02:  # If we're seeing big moves
-                self.confidence_threshold_close = max(0.15, self.confidence_threshold_close * 0.9)
-                logger.info(f"Lowered closing confidence threshold to {self.confidence_threshold_close:.3f}")
-
-        except Exception as e:
-            logger.error(f"Error adjusting confidence thresholds: {e}")
-
-    def _get_correlated_sentiment(self, symbol: str,
-                                  all_predictions: Dict[str, List[EnhancedPrediction]]) -> Dict[str, Any]:
-        """Get sentiment from correlated symbols"""
-        correlated_actions = []
-        correlated_confidences = []
-
-        for other_symbol, predictions in all_predictions.items():
-            if other_symbol != symbol and predictions:
-                correlation = self.symbol_correlation_matrix.get((symbol, other_symbol), 0.0)
-
-                if correlation > 0.5:  # Only consider significantly correlated symbols
-                    best_pred = max(predictions, key=lambda p: p.overall_confidence)
-                    correlated_actions.append(best_pred.overall_action)
-                    correlated_confidences.append(best_pred.overall_confidence * correlation)
-
-        if not correlated_actions:
-            return {'agreement': 1.0, 'sentiment': 'NEUTRAL'}
-
-        # Calculate agreement
-        primary_pred = all_predictions[symbol][0] if all_predictions.get(symbol) else None
-        if primary_pred:
-            agreement_count = sum(1 for action in correlated_actions
-                                  if action == primary_pred.overall_action)
-            agreement = agreement_count / len(correlated_actions)
-        else:
-            agreement = 0.5
-
-        # Calculate overall sentiment
-        action_weights = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
-        for action, confidence in zip(correlated_actions, correlated_confidences):
-            action_weights[action] += confidence
-
-        dominant_sentiment = max(action_weights, key=action_weights.get)
-
-        return {
-            'agreement': agreement,
-            'sentiment': dominant_sentiment,
-            'correlated_symbols': len(correlated_actions)
-        }
-
-    def _queue_for_rl_evaluation(self, action: TradingAction, market_state: MarketState):
-        """Queue trading action for RL evaluation"""
-        evaluation_item = {
-            'action': action,
-            'market_state_before': market_state,
-            'timestamp': datetime.now(),
-            'evaluation_pending': True
-        }
-        self.rl_evaluation_queue.append(evaluation_item)
-
-    async def evaluate_actions_with_rl(self):
-        """Evaluate recent actions using RL agents for continuous learning"""
-        if not self.rl_evaluation_queue:
-            return
-
-        current_time = datetime.now()
-
-        # Process actions that are ready for evaluation (e.g., 1 hour old)
-        for item in list(self.rl_evaluation_queue):
-            if item['evaluation_pending']:
-                time_since_action = (current_time - item['timestamp']).total_seconds()
-
-                # Evaluate after sufficient time has passed
-                if time_since_action >= 3600:  # 1 hour
-                    await self._evaluate_single_action(item)
-                    item['evaluation_pending'] = False
-
-    async def _evaluate_single_action(self, evaluation_item: Dict[str, Any]):
-        """Evaluate a single action using RL"""
-        try:
-            action = evaluation_item['action']
-            initial_state = evaluation_item['market_state_before']
-
-            # Get current market state for comparison
-            current_market_states = await self._get_all_market_states_universal(self.universal_adapter.get_universal_data_stream())
-            current_state = current_market_states.get(action.symbol)
-
-            if current_state:
-                # Calculate reward based on price movement
-                initial_price = initial_state.prices.get(self.timeframes[0], 0)
-                current_price = current_state.prices.get(self.timeframes[0], 0)
-
-                if initial_price > 0:
-                    price_change = (current_price - initial_price) / initial_price
-
-                    # Calculate reward based on action and price movement
-                    reward = self._calculate_reward(action.action, price_change, action.confidence)
-
-                    # Update RL agents
-                    await self._update_rl_agents(action, initial_state, current_state, reward)
-
-                    # Check if this was a perfect move for CNN training
-                    if abs(reward) > 0.02:  # Significant outcome
-                        self._mark_perfect_move(action, initial_state, current_state, reward)
-
-        except Exception as e:
-            logger.error(f"Error evaluating action: {e}")
-
-    def _calculate_reward(self, action: str, price_change: float, confidence: float) -> float:
-        """Calculate reward for RL training"""
-        base_reward = 0.0
-
-        if action == 'BUY' and price_change > 0:
-            base_reward = price_change * 10  # Reward proportional to gain
-        elif action == 'SELL' and price_change < 0:
-            base_reward = abs(price_change) * 10  # Reward for avoiding loss
-        elif action == 'HOLD':
-            base_reward = 0.01 if abs(price_change) < 0.005 else -0.01  # Small reward for correct holds
-        else:
-            base_reward = -abs(price_change) * 5  # Penalty for wrong actions
-
-        # Adjust reward based on confidence
-        confidence_multiplier = 0.5 + confidence  # 0.5 to 1.5 range
-
-        return base_reward * confidence_multiplier
-
-    async def _update_rl_agents(self, action: TradingAction, initial_state: MarketState,
-                                current_state: MarketState, reward: float):
-        """Update RL agents with action evaluation"""
-        for model_name, model in self.model_registry.models.items():
-            if isinstance(model, RLAgentInterface):
-                try:
-                    # Convert market states to RL state format
-                    initial_rl_state = self._market_state_to_rl_state(initial_state)
-                    current_rl_state = self._market_state_to_rl_state(current_state)
-
-                    # Convert action to RL action index
-                    action_idx = {'SELL': 0, 'HOLD': 1, 'BUY': 2}.get(action.action, 1)
-
-                    # Store experience
-                    model.remember(
-                        state=initial_rl_state,
-                        action=action_idx,
-                        reward=reward,
-                        next_state=current_rl_state,
-                        done=False
-                    )
-
-                    # Trigger replay learning
-                    loss = model.replay()
-                    if loss is not None:
-                        logger.info(f"RL agent {model_name} updated with loss: {loss:.4f}")
-
-                except Exception as e:
-                    logger.error(f"Error updating RL agent {model_name}: {e}")
-
-    def _mark_perfect_move(self, action: TradingAction, initial_state: MarketState,
-                           final_state: MarketState, reward: float):
-        """Mark a perfect move for CNN training"""
-        try:
-            # Determine what the optimal action should have been
-            optimal_action = action.action if reward > 0 else ('HOLD' if action.action == 'HOLD' else
-                                                               ('SELL' if action.action == 'BUY' else 'BUY'))
-
-            # Calculate what confidence should have been
-            optimal_confidence = min(0.95, abs(reward) * 10)  # Higher reward = higher confidence should have been
-
-            for tf_pred in action.timeframe_analysis:
-                perfect_move = PerfectMove(
-                    symbol=action.symbol,
-                    timeframe=tf_pred.timeframe,
-                    timestamp=action.timestamp,
-                    optimal_action=optimal_action,
-                    actual_outcome=reward,
-                    market_state_before=initial_state,
-                    market_state_after=final_state,
-                    confidence_should_have_been=optimal_confidence
-                )
-                self.perfect_moves.append(perfect_move)
-
-            logger.info(f"Marked perfect move for {action.symbol}: {optimal_action} with confidence {optimal_confidence:.3f}")
-
-        except Exception as e:
-            logger.error(f"Error marking perfect move: {e}")
-
-    def get_recent_perfect_moves(self, limit: int = 10) -> List[PerfectMove]:
-        """Get recent perfect moves for display/monitoring"""
-        return list(self.perfect_moves)[-limit:]
-
-    async def queue_action_for_evaluation(self, action: TradingAction):
-        """Queue a trading action for future RL evaluation"""
-        try:
-            # Get current market state
-            market_states = await self._get_all_market_states_universal(self.universal_adapter.get_universal_data_stream())
-            if action.symbol in market_states:
-                evaluation_item = {
-                    'action': action,
-                    'market_state_before': market_states[action.symbol],
-                    'timestamp': datetime.now()
-                }
-                self.rl_evaluation_queue.append(evaluation_item)
-                logger.debug(f"Queued action for RL evaluation: {action.action} {action.symbol}")
-
-        except Exception as e:
-            logger.error(f"Error queuing action for evaluation: {e}")
+    def _open_trade_for_sensitivity_learning(self, symbol: str, action: TradingAction):
+        """Track trade opening for sensitivity learning"""
+        try:
+            # Get current market state for learning context
+            market_state = self._get_current_market_state_for_sensitivity(symbol)
+
+            trade_info = {
+                'symbol': symbol,
+                'side': 'LONG' if action.action == 'BUY' else 'SHORT',
+                'entry_price': action.price,
+                'entry_time': action.timestamp,
+                'entry_confidence': action.confidence,
+                'entry_market_state': market_state,
+                'sensitivity_level_at_entry': self.current_sensitivity_level,
+                'thresholds_used': {
+                    'open': self._get_current_open_threshold(),
+                    'close': self._get_current_close_threshold()
+                }
+            }
+
+            self.active_trades[symbol] = trade_info
+            logger.info(f"Opened trade for sensitivity learning: {symbol} {trade_info['side']} @ ${action.price:.2f}")
+
+        except Exception as e:
+            logger.error(f"Error tracking trade opening for sensitivity learning: {e}")
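Note: the removed `_calculate_reward` is easy to sanity-check in isolation. A standalone restatement with a few worked cases:

```python
# Restatement of the removed reward logic, for illustration only.
def calculate_reward(action: str, price_change: float, confidence: float) -> float:
    if action == 'BUY' and price_change > 0:
        base = price_change * 10            # reward proportional to gain
    elif action == 'SELL' and price_change < 0:
        base = abs(price_change) * 10       # reward for avoiding loss
    elif action == 'HOLD':
        base = 0.01 if abs(price_change) < 0.005 else -0.01
    else:
        base = -abs(price_change) * 5       # penalty for wrong direction
    return base * (0.5 + confidence)        # confidence scales into 0.5x-1.5x

print(calculate_reward('BUY', +0.02, 0.8))   # 0.26
print(calculate_reward('BUY', -0.02, 0.8))   # -0.13 (wrong direction, half slope)
print(calculate_reward('HOLD', 0.001, 0.5))  # 0.01
```

Note the asymmetry: correct calls earn 10x the move while wrong calls lose only 5x, so the agent is rewarded more for being right than it is punished for being wrong.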
     def get_perfect_moves_for_training(self, symbol: str = None, timeframe: str = None,
                                        limit: int = 1000) -> List[PerfectMove]:
         """Get perfect moves for CNN training"""
         moves = list(self.perfect_moves)
 
         # Filter by symbol if specified
         if symbol:
             moves = [move for move in moves if move.symbol == symbol]
 
         # Filter by timeframe if specified
         if timeframe:
             moves = [move for move in moves if move.timeframe == timeframe]
 
         return moves[-limit:]  # Return most recent moves
 
-    # Helper methods for market analysis using universal data
-    def _calculate_volatility_from_universal(self, symbol: str, universal_stream: UniversalDataStream) -> float:
-        """Calculate current volatility for symbol using universal data"""
-        try:
-            if symbol == 'ETH/USDT' and len(universal_stream.eth_ticks) > 10:
-                # Calculate volatility from tick data
-                prices = universal_stream.eth_ticks[-10:, 4]  # Last 10 close prices
-                returns = np.diff(prices) / prices[:-1]
-                volatility = np.std(returns) * np.sqrt(86400)  # Annualized volatility
-                return float(volatility)
-            elif symbol == 'BTC/USDT' and len(universal_stream.btc_ticks) > 10:
-                # Calculate volatility from BTC tick data
-                prices = universal_stream.btc_ticks[-10:, 4]  # Last 10 close prices
-                returns = np.diff(prices) / prices[:-1]
-                volatility = np.std(returns) * np.sqrt(86400)  # Annualized volatility
-                return float(volatility)
-
-        except Exception as e:
-            logger.error(f"Error calculating volatility from universal data: {e}")
-
-        return 0.02  # Default 2% volatility
+    def _close_trade_for_sensitivity_learning(self, symbol: str, action: TradingAction):
+        """Track trade closing and create learning case for DQN"""
+        try:
+            if symbol not in self.active_trades:
+                return
+
+            trade_info = self.active_trades[symbol]
+
+            # Calculate trade outcome
+            entry_price = trade_info['entry_price']
+            exit_price = action.price
+            side = trade_info['side']
+
+            if side == 'LONG':
+                pnl_pct = (exit_price - entry_price) / entry_price
+            else:  # SHORT
+                pnl_pct = (entry_price - exit_price) / entry_price
+
+            # Calculate trade duration
+            duration = (action.timestamp - trade_info['entry_time']).total_seconds()
+
+            # Get current market state for exit context
+            exit_market_state = self._get_current_market_state_for_sensitivity(symbol)
+
+            # Create completed trade record
+            completed_trade = {
+                'symbol': symbol,
+                'side': side,
+                'entry_price': entry_price,
+                'exit_price': exit_price,
+                'entry_time': trade_info['entry_time'],
+                'exit_time': action.timestamp,
+                'duration': duration,
+                'pnl_pct': pnl_pct,
+                'entry_confidence': trade_info['entry_confidence'],
+                'exit_confidence': action.confidence,
+                'entry_market_state': trade_info['entry_market_state'],
+                'exit_market_state': exit_market_state,
+                'sensitivity_level_used': trade_info['sensitivity_level_at_entry'],
+                'thresholds_used': trade_info['thresholds_used']
+            }
+
+            self.completed_trades.append(completed_trade)
+
+            # Create sensitivity learning case for DQN
+            self._create_sensitivity_learning_case(completed_trade)
+
+            # Remove from active trades
+            del self.active_trades[symbol]
+
+            logger.info(f"Closed trade for sensitivity learning: {symbol} {side} P&L: {pnl_pct*100:+.2f}% Duration: {duration:.0f}s")
+
+        except Exception as e:
+            logger.error(f"Error tracking trade closing for sensitivity learning: {e}")
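Note: the P&L convention used when closing a tracked trade, as a minimal standalone sketch:

```python
# Same formula as _close_trade_for_sensitivity_learning, stated standalone.
def pnl_pct(side: str, entry_price: float, exit_price: float) -> float:
    if side == 'LONG':
        return (exit_price - entry_price) / entry_price   # LONG gains as price rises
    return (entry_price - exit_price) / entry_price       # SHORT gains as price falls

assert abs(pnl_pct('LONG', 100.0, 102.0) - 0.02) < 1e-9
assert abs(pnl_pct('SHORT', 100.0, 102.0) + 0.02) < 1e-9
```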
-    def _get_current_volume_from_universal(self, symbol: str, universal_stream: UniversalDataStream) -> float:
-        """Get current volume ratio compared to average using universal data"""
-        try:
-            if symbol == 'ETH/USDT':
-                # Use 1m data for volume analysis
-                if len(universal_stream.eth_1m) > 10:
-                    volumes = universal_stream.eth_1m[-10:, 5]  # Last 10 volume values
-                    current_volume = universal_stream.eth_1m[-1, 5]
-                    avg_volume = np.mean(volumes[:-1])
-                    if avg_volume > 0:
-                        return float(current_volume / avg_volume)
-            elif symbol == 'BTC/USDT':
-                # Use BTC tick data for volume analysis
-                if len(universal_stream.btc_ticks) > 10:
-                    volumes = universal_stream.btc_ticks[-10:, 5]  # Last 10 volume values
-                    current_volume = universal_stream.btc_ticks[-1, 5]
-                    avg_volume = np.mean(volumes[:-1])
-                    if avg_volume > 0:
-                        return float(current_volume / avg_volume)
-
-        except Exception as e:
-            logger.error(f"Error calculating volume from universal data: {e}")
-
-        return 1.0  # Normal volume
+    def _get_current_market_state_for_sensitivity(self, symbol: str) -> Dict[str, float]:
+        """Get current market state features for sensitivity learning"""
+        try:
+            # Get recent price data
+            recent_data = self.data_provider.get_historical_data(symbol, '1m', limit=20)
+
+            if recent_data is None or len(recent_data) < 10:
+                return self._get_default_market_state()
+
+            # Calculate market features
+            current_price = recent_data['close'].iloc[-1]
+
+            # Volatility (20-period)
+            volatility = recent_data['close'].pct_change().std() * 100
+
+            # Price momentum (5-period)
+            momentum_5 = (current_price - recent_data['close'].iloc[-6]) / recent_data['close'].iloc[-6] * 100
+
+            # Volume ratio
+            avg_volume = recent_data['volume'].mean()
+            current_volume = recent_data['volume'].iloc[-1]
+            volume_ratio = current_volume / avg_volume if avg_volume > 0 else 1.0
+
+            # RSI
+            rsi = recent_data['rsi'].iloc[-1] if 'rsi' in recent_data.columns else 50.0
+
+            # MACD signal
+            macd_signal = 0.0
+            if 'macd' in recent_data.columns and 'macd_signal' in recent_data.columns:
+                macd_signal = recent_data['macd'].iloc[-1] - recent_data['macd_signal'].iloc[-1]
+
+            # Bollinger Band position
+            bb_position = 0.5  # Default middle
+            if 'bb_upper' in recent_data.columns and 'bb_lower' in recent_data.columns:
+                bb_upper = recent_data['bb_upper'].iloc[-1]
+                bb_lower = recent_data['bb_lower'].iloc[-1]
+                if bb_upper > bb_lower:
+                    bb_position = (current_price - bb_lower) / (bb_upper - bb_lower)
+
+            # Recent price change patterns
+            price_changes = recent_data['close'].pct_change().tail(5).tolist()
+
+            return {
+                'volatility': volatility,
+                'momentum_5': momentum_5,
+                'volume_ratio': volume_ratio,
+                'rsi': rsi,
+                'macd_signal': macd_signal,
+                'bb_position': bb_position,
+                'price_change_1': price_changes[-1] if len(price_changes) > 0 else 0.0,
+                'price_change_2': price_changes[-2] if len(price_changes) > 1 else 0.0,
+                'price_change_3': price_changes[-3] if len(price_changes) > 2 else 0.0,
+                'price_change_4': price_changes[-4] if len(price_changes) > 3 else 0.0,
+                'price_change_5': price_changes[-5] if len(price_changes) > 4 else 0.0
+            }
+
+        except Exception as e:
+            logger.error(f"Error getting market state for sensitivity learning: {e}")
+            return self._get_default_market_state()
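Note: the snapshot returned above is a flat dict of eleven floats. An example instance (values made up) of the shape downstream code expects:

```python
example_market_state = {
    'volatility': 2.0,        # 20-period std of pct change, in percent
    'momentum_5': 0.4,        # 5-period price change, in percent
    'volume_ratio': 1.2,      # current volume vs. recent mean
    'rsi': 55.0,
    'macd_signal': 0.1,       # MACD minus its signal line
    'bb_position': 0.6,       # 0 = lower Bollinger band, 1 = upper
    'price_change_1': 0.001,  # five most recent 1m pct changes follow
    'price_change_2': -0.002,
    'price_change_3': 0.0005,
    'price_change_4': 0.0,
    'price_change_5': 0.003,
}
```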
-    def _calculate_trend_strength_from_universal(self, symbol: str, universal_stream: UniversalDataStream) -> float:
-        """Calculate trend strength using universal data"""
-        try:
-            if symbol == 'ETH/USDT':
-                # Use multiple timeframes to determine trend strength
-                trend_scores = []
-
-                # Check 1m trend
-                if len(universal_stream.eth_1m) > 20:
-                    prices = universal_stream.eth_1m[-20:, 4]  # Last 20 close prices
-                    slope = np.polyfit(range(len(prices)), prices, 1)[0]
-                    trend_scores.append(abs(slope) / np.mean(prices))
-
-                # Check 1h trend
-                if len(universal_stream.eth_1h) > 10:
-                    prices = universal_stream.eth_1h[-10:, 4]  # Last 10 close prices
-                    slope = np.polyfit(range(len(prices)), prices, 1)[0]
-                    trend_scores.append(abs(slope) / np.mean(prices))
-
-                if trend_scores:
-                    return float(np.mean(trend_scores))
-
-            elif symbol == 'BTC/USDT':
-                # Use BTC tick data for trend analysis
-                if len(universal_stream.btc_ticks) > 20:
-                    prices = universal_stream.btc_ticks[-20:, 4]  # Last 20 close prices
-                    slope = np.polyfit(range(len(prices)), prices, 1)[0]
-                    return float(abs(slope) / np.mean(prices))
-
-        except Exception as e:
-            logger.error(f"Error calculating trend strength from universal data: {e}")
-
-        return 0.5  # Moderate trend
+    def _get_default_market_state(self) -> Dict[str, float]:
+        """Get default market state when data is unavailable"""
+        return {
+            'volatility': 2.0,
+            'momentum_5': 0.0,
+            'volume_ratio': 1.0,
+            'rsi': 50.0,
+            'macd_signal': 0.0,
+            'bb_position': 0.5,
+            'price_change_1': 0.0,
+            'price_change_2': 0.0,
+            'price_change_3': 0.0,
+            'price_change_4': 0.0,
+            'price_change_5': 0.0
+        }
+
+    def _create_sensitivity_learning_case(self, completed_trade: Dict[str, Any]):
+        """Create a learning case for the DQN sensitivity agent"""
+        try:
+            # Create state vector from market conditions at entry
+            entry_state = self._market_state_to_sensitivity_state(
+                completed_trade['entry_market_state'],
+                completed_trade['sensitivity_level_used']
+            )
+
+            # Create state vector from market conditions at exit
+            exit_state = self._market_state_to_sensitivity_state(
+                completed_trade['exit_market_state'],
+                completed_trade['sensitivity_level_used']
+            )
+
+            # Calculate reward based on trade outcome
+            reward = self._calculate_sensitivity_reward(completed_trade)
+
+            # Determine optimal sensitivity action based on outcome
+            optimal_sensitivity = self._determine_optimal_sensitivity(completed_trade)
+
+            # Create learning experience
+            learning_case = {
+                'state': entry_state,
+                'action': completed_trade['sensitivity_level_used'],
+                'reward': reward,
+                'next_state': exit_state,
+                'done': True,  # Trade is completed
+                'optimal_action': optimal_sensitivity,
+                'trade_outcome': completed_trade['pnl_pct'],
+                'trade_duration': completed_trade['duration'],
+                'symbol': completed_trade['symbol']
+            }
+
+            self.sensitivity_learning_queue.append(learning_case)
+
+            # Train DQN if we have enough cases
+            if len(self.sensitivity_learning_queue) >= 32:  # Batch size
+                self._train_sensitivity_dqn()
+
+            logger.info(f"Created sensitivity learning case: reward={reward:.3f}, optimal_sensitivity={optimal_sensitivity}")
+
+        except Exception as e:
+            logger.error(f"Error creating sensitivity learning case: {e}")
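Note: each completed trade becomes one terminal-episode experience. An example of a queued case (placeholder vectors, made-up numbers):

```python
import numpy as np

entry_state = np.zeros(20, dtype=np.float32)  # from _market_state_to_sensitivity_state
exit_state = np.zeros(20, dtype=np.float32)

learning_case = {
    'state': entry_state,
    'action': 2,               # sensitivity level in force at entry (medium)
    'reward': 0.117,           # from _calculate_sensitivity_reward
    'next_state': exit_state,
    'done': True,              # every trade is a one-step, terminal episode
    'optimal_action': 3,       # hindsight label from _determine_optimal_sensitivity
    'trade_outcome': 0.015,    # +1.5% P&L
    'trade_duration': 1200.0,  # seconds
    'symbol': 'ETH/USDT',
}
```

Since every transition is terminal, the setup is closer to a contextual bandit than a full MDP, which fits the "pick one sensitivity level per market regime" framing.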
-    def _determine_market_regime_from_universal(self, symbol: str, universal_stream: UniversalDataStream) -> str:
-        """Determine current market regime using universal data"""
-        try:
-            if symbol == 'ETH/USDT':
-                # Analyze volatility and trend from multiple timeframes
-                volatility = self._calculate_volatility_from_universal(symbol, universal_stream)
-                trend_strength = self._calculate_trend_strength_from_universal(symbol, universal_stream)
-
-                # Determine regime based on volatility and trend
-                if volatility > 0.05:  # High volatility
-                    return 'volatile'
-                elif trend_strength > 0.002:  # Strong trend
-                    return 'trending'
-                else:
-                    return 'ranging'
-
-            elif symbol == 'BTC/USDT':
-                # Analyze BTC regime
-                volatility = self._calculate_volatility_from_universal(symbol, universal_stream)
-
-                if volatility > 0.04:  # High volatility for BTC
-                    return 'volatile'
-                else:
-                    return 'trending'  # Default for BTC
-
-        except Exception as e:
-            logger.error(f"Error determining market regime from universal data: {e}")
-
-        return 'trending'  # Default regime
+    def _market_state_to_sensitivity_state(self, market_state: Dict[str, float], current_sensitivity: int) -> np.ndarray:
+        """Convert market state to DQN state vector for sensitivity learning"""
+        try:
+            # Create state vector with market features + current sensitivity
+            state_features = [
+                market_state.get('volatility', 2.0) / 10.0,  # Normalize volatility
+                market_state.get('momentum_5', 0.0) / 5.0,  # Normalize momentum
+                market_state.get('volume_ratio', 1.0),  # Volume ratio
+                market_state.get('rsi', 50.0) / 100.0,  # Normalize RSI
+                market_state.get('macd_signal', 0.0) / 2.0,  # Normalize MACD
+                market_state.get('bb_position', 0.5),  # BB position (already 0-1)
+                market_state.get('price_change_1', 0.0) * 100,  # Recent price changes
+                market_state.get('price_change_2', 0.0) * 100,
+                market_state.get('price_change_3', 0.0) * 100,
+                market_state.get('price_change_4', 0.0) * 100,
+                market_state.get('price_change_5', 0.0) * 100,
+                current_sensitivity / 4.0,  # Normalize current sensitivity (0-4 -> 0-1)
+            ]
+
+            # Add recent performance metrics
+            if len(self.completed_trades) > 0:
+                recent_trades = list(self.completed_trades)[-10:]  # Last 10 trades
+                avg_pnl = np.mean([t['pnl_pct'] for t in recent_trades])
+                win_rate = len([t for t in recent_trades if t['pnl_pct'] > 0]) / len(recent_trades)
+                avg_duration = np.mean([t['duration'] for t in recent_trades]) / 3600  # Normalize to hours
+            else:
+                avg_pnl = 0.0
+                win_rate = 0.5
+                avg_duration = 0.5
+
+            state_features.extend([
+                avg_pnl * 10,  # Recent average P&L
+                win_rate,  # Recent win rate
+                avg_duration,  # Recent average duration
+            ])
+
+            # Pad or truncate to exact state size
+            while len(state_features) < self.sensitivity_state_size:
+                state_features.append(0.0)
+
+            state_features = state_features[:self.sensitivity_state_size]
+
+            return np.array(state_features, dtype=np.float32)
+
+        except Exception as e:
+            logger.error(f"Error converting market state to sensitivity state: {e}")
+            return np.zeros(self.sensitivity_state_size, dtype=np.float32)
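Note: the state construction above normalizes each feature to roughly unit scale before padding to the fixed 20-dim input. A compact standalone restatement:

```python
import numpy as np

# Restatement of _market_state_to_sensitivity_state's normalization; the
# fixed defaults stand in for the recent-performance slots when no trades
# have been recorded yet.
def to_state(ms: dict, sensitivity: int, size: int = 20) -> np.ndarray:
    feats = [
        ms.get('volatility', 2.0) / 10.0,
        ms.get('momentum_5', 0.0) / 5.0,
        ms.get('volume_ratio', 1.0),
        ms.get('rsi', 50.0) / 100.0,
        ms.get('macd_signal', 0.0) / 2.0,
        ms.get('bb_position', 0.5),
    ]
    feats += [ms.get(f'price_change_{i}', 0.0) * 100 for i in range(1, 6)]
    feats.append(sensitivity / 4.0)        # current level, mapped to 0-1
    feats += [0.0, 0.5, 0.5]               # avg P&L, win rate, avg duration defaults
    feats = (feats + [0.0] * size)[:size]  # pad with zeros, then truncate
    return np.array(feats, dtype=np.float32)

assert to_state({}, 2).shape == (20,)
```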
-    # Legacy helper methods (kept for compatibility)
-    def _calculate_volatility(self, symbol: str) -> float:
-        """Calculate current volatility for symbol (legacy method)"""
-        return 0.02  # 2% default volatility
+    def _calculate_sensitivity_reward(self, completed_trade: Dict[str, Any]) -> float:
+        """Calculate reward for sensitivity learning based on trade outcome"""
+        try:
+            pnl_pct = completed_trade['pnl_pct']
+            duration = completed_trade['duration']
+
+            # Base reward from P&L
+            base_reward = pnl_pct * 10  # Scale P&L percentage
+
+            # Duration penalty/bonus
+            if duration < 300:  # Less than 5 minutes - too quick
+                duration_factor = 0.8
+            elif duration < 1800:  # Less than 30 minutes - good for scalping
+                duration_factor = 1.2
+            elif duration < 3600:  # Less than 1 hour - acceptable
+                duration_factor = 1.0
+            else:  # More than 1 hour - too slow for scalping
+                duration_factor = 0.7
+
+            # Confidence factor - reward appropriate confidence levels
+            entry_conf = completed_trade['entry_confidence']
+            exit_conf = completed_trade['exit_confidence']
+
+            if pnl_pct > 0:  # Winning trade
+                # Reward high entry confidence and appropriate exit confidence
+                conf_factor = (entry_conf + exit_conf) / 2
+            else:  # Losing trade
+                # Reward quick exit (high exit confidence for losses)
+                conf_factor = exit_conf
+
+            # Calculate final reward
+            final_reward = base_reward * duration_factor * conf_factor
+
+            # Clip reward to reasonable range
+            final_reward = np.clip(final_reward, -2.0, 2.0)
+
+            return float(final_reward)
+
+        except Exception as e:
+            logger.error(f"Error calculating sensitivity reward: {e}")
+            return 0.0
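Note: a worked example of the reward shaping above — a +1.5% win closed in 20 minutes with entry/exit confidence 0.7/0.6:

```python
import numpy as np

base = 0.015 * 10              # P&L scaling -> 0.15
duration_factor = 1.2          # 300s <= 1200s < 1800s: scalping-friendly bracket
conf_factor = (0.7 + 0.6) / 2  # winning trade: mean of entry and exit confidence
reward = float(np.clip(base * duration_factor * conf_factor, -2.0, 2.0))
print(reward)                  # 0.117
```

The clip to [-2, 2] only binds for P&L beyond roughly +/-20% at neutral factors, so in practice the reward is nearly linear in P&L.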
-    def _get_current_volume(self, symbol: str) -> float:
-        """Get current volume ratio compared to average (legacy method)"""
-        return 1.0  # Normal volume
+    def _determine_optimal_sensitivity(self, completed_trade: Dict[str, Any]) -> int:
+        """Determine optimal sensitivity level based on trade outcome"""
+        try:
+            pnl_pct = completed_trade['pnl_pct']
+            duration = completed_trade['duration']
+            current_sensitivity = completed_trade['sensitivity_level_used']
+
+            # If trade was profitable and quick, current sensitivity was good
+            if pnl_pct > 0.01 and duration < 1800:  # >1% profit in <30 min
+                return current_sensitivity
+
+            # If trade was very profitable, could have been more aggressive
+            if pnl_pct > 0.02:  # >2% profit
+                return min(4, current_sensitivity + 1)  # Increase sensitivity
+
+            # If trade was a small loss, might need more sensitivity
+            if -0.01 < pnl_pct < 0:  # Small loss
+                return min(4, current_sensitivity + 1)  # Increase sensitivity
+
+            # If trade was a big loss, need less sensitivity
+            if pnl_pct < -0.02:  # >2% loss
+                return max(0, current_sensitivity - 1)  # Decrease sensitivity
+
+            # If trade took too long, need more sensitivity
+            if duration > 3600:  # >1 hour
+                return min(4, current_sensitivity + 1)  # Increase sensitivity
+
+            # Default: keep current sensitivity
+            return current_sensitivity
+
+        except Exception as e:
+            logger.error(f"Error determining optimal sensitivity: {e}")
+            return 2  # Default to medium
-    def _calculate_trend_strength(self, symbol: str) -> float:
-        """Calculate trend strength (legacy method)"""
-        return 0.5  # Moderate trend
+    def _train_sensitivity_dqn(self):
+        """Train the DQN agent for sensitivity learning"""
+        try:
+            # Initialize DQN agent if not already done
+            if self.sensitivity_dqn_agent is None:
+                self._initialize_sensitivity_dqn()
+
+            if self.sensitivity_dqn_agent is None:
+                return
+
+            # Get batch of learning cases
+            batch_size = min(32, len(self.sensitivity_learning_queue))
+            if batch_size < 8:  # Need minimum batch size
+                return
+
+            # Sample random batch
+            batch_indices = np.random.choice(len(self.sensitivity_learning_queue), batch_size, replace=False)
+            batch = [self.sensitivity_learning_queue[i] for i in batch_indices]
+
+            # Train the DQN agent
+            for case in batch:
+                self.sensitivity_dqn_agent.remember(
+                    state=case['state'],
+                    action=case['action'],
+                    reward=case['reward'],
+                    next_state=case['next_state'],
+                    done=case['done']
+                )
+
+            # Perform replay training
+            loss = self.sensitivity_dqn_agent.replay()
+
+            if loss is not None:
+                logger.info(f"Sensitivity DQN training completed. Loss: {loss:.4f}")
+
+                # Update current sensitivity level based on recent performance
+                self._update_current_sensitivity_level()
+
+        except Exception as e:
+            logger.error(f"Error training sensitivity DQN: {e}")
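Note: indexing a deque repeatedly is O(n) per access; an equivalent batch sampler (hypothetical helper, not in the commit) that snapshots to a list first:

```python
import random
from collections import deque

def sample_batch(queue: deque, batch_size: int = 32, min_batch: int = 8) -> list:
    # Same gating as _train_sensitivity_dqn: skip until enough cases exist.
    n = min(batch_size, len(queue))
    if n < min_batch:
        return []
    return random.sample(list(queue), n)  # one O(n) snapshot, then O(1) picks
```

Also worth noting: the sampled cases are pushed through `remember()` and then a single `replay()` is called, so the agent's own internal buffer performs the actual minibatching.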
-    def _determine_market_regime(self, symbol: str) -> str:
-        """Determine current market regime (legacy method)"""
-        return 'trending'  # Default to trending
+    def _initialize_sensitivity_dqn(self):
+        """Initialize the DQN agent for sensitivity learning"""
+        try:
+            # Try to import DQN agent
+            from NN.models.dqn_agent import DQNAgent
+
+            # Create DQN agent for sensitivity learning
+            self.sensitivity_dqn_agent = DQNAgent(
+                state_shape=(self.sensitivity_state_size,),
+                n_actions=self.sensitivity_action_space,
+                learning_rate=0.001,
+                gamma=0.95,
+                epsilon=0.3,  # Lower epsilon for more exploitation
+                epsilon_min=0.05,
+                epsilon_decay=0.995,
+                buffer_size=1000,
+                batch_size=32,
+                target_update=10
+            )
+
+            logger.info("Sensitivity DQN agent initialized successfully")
+
+        except Exception as e:
+            logger.error(f"Error initializing sensitivity DQN agent: {e}")
+            self.sensitivity_dqn_agent = None
-    def _get_symbol_correlation(self, symbol: str) -> Dict[str, float]:
-        """Get correlations with other symbols"""
-        correlations = {}
-        for other_symbol in self.symbols:
-            if other_symbol != symbol:
-                correlations[other_symbol] = self.symbol_correlation_matrix.get((symbol, other_symbol), 0.0)
-        return correlations
+    def _update_current_sensitivity_level(self):
+        """Update current sensitivity level using trained DQN"""
+        try:
+            if self.sensitivity_dqn_agent is None:
+                return
+
+            # Get current market state
+            current_market_state = self._get_current_market_state_for_sensitivity('ETH/USDT')  # Use ETH as primary
+            current_state = self._market_state_to_sensitivity_state(current_market_state, self.current_sensitivity_level)
+
+            # Get action from DQN (without exploration for production use)
+            action = self.sensitivity_dqn_agent.act(current_state, explore=False)
+
+            # Update sensitivity level if it changed
+            if action != self.current_sensitivity_level:
+                old_level = self.current_sensitivity_level
+                self.current_sensitivity_level = action
+
+                # Update thresholds based on new sensitivity level
+                self._update_thresholds_from_sensitivity()
+
+                logger.info(f"Sensitivity level updated: {self.sensitivity_levels[old_level]['name']} -> {self.sensitivity_levels[action]['name']}")
+
+        except Exception as e:
+            logger.error(f"Error updating current sensitivity level: {e}")
-    def _calculate_position_size(self, symbol: str, action: str, confidence: float) -> float:
-        """Calculate position size based on confidence and risk management"""
-        base_size = 0.02  # 2% of portfolio
-        confidence_multiplier = confidence  # Scale by confidence
-        max_size = 0.05  # 5% maximum
-
-        return min(base_size * confidence_multiplier, max_size)
+    def _update_thresholds_from_sensitivity(self):
+        """Update confidence thresholds based on current sensitivity level"""
+        try:
+            sensitivity_config = self.sensitivity_levels[self.current_sensitivity_level]
+
+            # Get base thresholds from config
+            base_open_threshold = self.config.orchestrator.get('confidence_threshold', 0.6)
+            base_close_threshold = self.config.orchestrator.get('confidence_threshold_close', 0.25)
+
+            # Apply sensitivity multipliers
+            self.confidence_threshold_open = base_open_threshold * sensitivity_config['open_threshold_multiplier']
+            self.confidence_threshold_close = base_close_threshold * sensitivity_config['close_threshold_multiplier']
+
+            # Ensure thresholds stay within reasonable bounds
+            self.confidence_threshold_open = np.clip(self.confidence_threshold_open, 0.3, 0.9)
+            self.confidence_threshold_close = np.clip(self.confidence_threshold_close, 0.1, 0.6)
+
+            logger.info(f"Updated thresholds - Open: {self.confidence_threshold_open:.3f}, Close: {self.confidence_threshold_close:.3f}")
+
+        except Exception as e:
+            logger.error(f"Error updating thresholds from sensitivity: {e}")
-    def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
-        """Convert market state to RL state vector"""
-        # Combine features from all timeframes into a single state vector
-        state_components = []
-
-        # Add price features
-        state_components.extend([
-            market_state.volatility,
-            market_state.volume,
-            market_state.trend_strength
-        ])
-
-        # Add flattened features from each timeframe
-        for timeframe in sorted(market_state.features.keys()):
-            features = market_state.features[timeframe]
-            if features is not None:
-                # Take the last row (most recent) and flatten
-                latest_features = features[-1] if len(features.shape) > 1 else features
-                state_components.extend(latest_features.flatten())
-
-        return np.array(state_components, dtype=np.float32)
+    def _get_current_open_threshold(self) -> float:
+        """Get current opening threshold"""
+        return self.confidence_threshold_open
+
+    def _get_current_close_threshold(self) -> float:
+        """Get current closing threshold"""
+        return self.confidence_threshold_close
 
     def process_realtime_features(self, feature_dict: Dict[str, Any]):
         """Process real-time tick features from the tick processor"""
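Note: the removed `_market_state_to_rl_state` concatenated three scalars with the latest feature row per timeframe; a compact restatement with a shape check:

```python
import numpy as np

# Restatement of the removed RL state packing, for illustration only.
def pack_state(volatility: float, volume: float, trend: float,
               features_by_tf: dict) -> np.ndarray:
    parts = [volatility, volume, trend]
    for tf in sorted(features_by_tf):                        # deterministic order
        feats = features_by_tf[tf]
        if feats is not None:
            latest = feats[-1] if feats.ndim > 1 else feats  # most recent row
            parts.extend(np.asarray(latest).ravel())
    return np.array(parts, dtype=np.float32)

state = pack_state(0.02, 1.0, 0.5, {'1m': np.ones((10, 4)), '1h': np.ones((5, 4))})
assert state.shape == (11,)  # 3 scalars + 4 features per timeframe x 2 timeframes
```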