new overhaul
This commit is contained in:
698
core/enhanced_orchestrator.py
Normal file
698
core/enhanced_orchestrator.py
Normal file
@ -0,0 +1,698 @@
|
||||
"""
|
||||
Enhanced Trading Orchestrator - Advanced Multi-Modal Decision Making
|
||||
|
||||
This enhanced orchestrator implements:
|
||||
1. Multi-timeframe CNN predictions with individual confidence scores
|
||||
2. Advanced RL feedback loop for continuous learning
|
||||
3. Multi-symbol (ETH, BTC) coordinated decision making
|
||||
4. Perfect move marking for CNN backpropagation training
|
||||
5. Market environment adaptation through RL evaluation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union
|
||||
from dataclasses import dataclass, field
|
||||
from collections import deque
|
||||
import torch
|
||||
|
||||
from .config import get_config
|
||||
from .data_provider import DataProvider
|
||||
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class TimeframePrediction:
|
||||
"""CNN prediction for a specific timeframe with confidence"""
|
||||
timeframe: str
|
||||
action: str # 'BUY', 'SELL', 'HOLD'
|
||||
confidence: float # 0.0 to 1.0
|
||||
probabilities: Dict[str, float] # Action probabilities
|
||||
timestamp: datetime
|
||||
market_features: Dict[str, float] = field(default_factory=dict) # Additional context
|
||||
|
||||
@dataclass
|
||||
class EnhancedPrediction:
|
||||
"""Enhanced prediction structure with timeframe breakdown"""
|
||||
symbol: str
|
||||
timeframe_predictions: List[TimeframePrediction]
|
||||
overall_action: str
|
||||
overall_confidence: float
|
||||
model_name: str
|
||||
timestamp: datetime
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@dataclass
|
||||
class TradingAction:
|
||||
"""Represents a trading action with full context"""
|
||||
symbol: str
|
||||
action: str # 'BUY', 'SELL', 'HOLD'
|
||||
quantity: float
|
||||
confidence: float
|
||||
price: float
|
||||
timestamp: datetime
|
||||
reasoning: Dict[str, Any]
|
||||
timeframe_analysis: List[TimeframePrediction]
|
||||
|
||||
@dataclass
|
||||
class MarketState:
|
||||
"""Complete market state for RL evaluation"""
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
prices: Dict[str, float] # {timeframe: current_price}
|
||||
features: Dict[str, np.ndarray] # {timeframe: feature_matrix}
|
||||
volatility: float
|
||||
volume: float
|
||||
trend_strength: float
|
||||
market_regime: str # 'trending', 'ranging', 'volatile'
|
||||
|
||||
@dataclass
|
||||
class PerfectMove:
|
||||
"""Marked perfect move for CNN training"""
|
||||
symbol: str
|
||||
timeframe: str
|
||||
timestamp: datetime
|
||||
optimal_action: str
|
||||
actual_outcome: float # Price change percentage
|
||||
market_state_before: MarketState
|
||||
market_state_after: MarketState
|
||||
confidence_should_have_been: float
|
||||
|
||||
class EnhancedTradingOrchestrator:
|
||||
"""
|
||||
Enhanced orchestrator with sophisticated multi-modal decision making
|
||||
"""
|
||||
|
||||
def __init__(self, data_provider: DataProvider = None):
|
||||
"""Initialize the enhanced orchestrator"""
|
||||
self.config = get_config()
|
||||
self.data_provider = data_provider or DataProvider()
|
||||
self.model_registry = get_model_registry()
|
||||
|
||||
# Multi-symbol configuration
|
||||
self.symbols = self.config.symbols
|
||||
self.timeframes = self.config.timeframes
|
||||
|
||||
# Configuration
|
||||
self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.6)
|
||||
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
|
||||
|
||||
# Enhanced weighting system
|
||||
self.timeframe_weights = self._initialize_timeframe_weights()
|
||||
self.symbol_correlation_matrix = self._initialize_correlation_matrix()
|
||||
|
||||
# State tracking for each symbol
|
||||
self.symbol_states = {symbol: {} for symbol in self.symbols}
|
||||
self.recent_actions = {symbol: deque(maxlen=100) for symbol in self.symbols}
|
||||
self.market_states = {symbol: deque(maxlen=1000) for symbol in self.symbols}
|
||||
|
||||
# Perfect move tracking for CNN training
|
||||
self.perfect_moves = deque(maxlen=10000)
|
||||
self.performance_tracker = {}
|
||||
|
||||
# RL feedback system
|
||||
self.rl_evaluation_queue = deque(maxlen=1000)
|
||||
self.environment_adaptation_rate = 0.01
|
||||
|
||||
# Decision callbacks
|
||||
self.decision_callbacks = []
|
||||
self.learning_callbacks = []
|
||||
|
||||
logger.info("Enhanced TradingOrchestrator initialized")
|
||||
logger.info(f"Symbols: {self.symbols}")
|
||||
logger.info(f"Timeframes: {self.timeframes}")
|
||||
logger.info(f"Enhanced confidence threshold: {self.confidence_threshold}")
|
||||
|
||||
def _initialize_timeframe_weights(self) -> Dict[str, float]:
|
||||
"""Initialize weights for different timeframes"""
|
||||
# Higher timeframes get more weight for trend direction
|
||||
# Lower timeframes get more weight for entry/exit timing
|
||||
base_weights = {
|
||||
'1m': 0.05, # Noise filtering
|
||||
'5m': 0.10, # Short-term momentum
|
||||
'15m': 0.15, # Entry/exit timing
|
||||
'1h': 0.25, # Medium-term trend
|
||||
'4h': 0.25, # Stronger trend confirmation
|
||||
'1d': 0.20 # Long-term direction
|
||||
}
|
||||
|
||||
# Normalize weights for configured timeframes
|
||||
configured_weights = {tf: base_weights.get(tf, 0.1) for tf in self.timeframes}
|
||||
total = sum(configured_weights.values())
|
||||
return {tf: w/total for tf, w in configured_weights.items()}
|
||||
|
||||
def _initialize_correlation_matrix(self) -> Dict[Tuple[str, str], float]:
|
||||
"""Initialize correlation matrix between symbols"""
|
||||
correlations = {}
|
||||
for i, symbol1 in enumerate(self.symbols):
|
||||
for j, symbol2 in enumerate(self.symbols):
|
||||
if i != j:
|
||||
# ETH and BTC are typically highly correlated
|
||||
if 'ETH' in symbol1 and 'BTC' in symbol2:
|
||||
correlations[(symbol1, symbol2)] = 0.85
|
||||
elif 'BTC' in symbol1 and 'ETH' in symbol2:
|
||||
correlations[(symbol1, symbol2)] = 0.85
|
||||
else:
|
||||
correlations[(symbol1, symbol2)] = 0.7 # Default correlation
|
||||
return correlations
|
||||
|
||||
async def make_coordinated_decisions(self) -> Dict[str, Optional[TradingAction]]:
|
||||
"""
|
||||
Make coordinated trading decisions across all symbols
|
||||
"""
|
||||
decisions = {}
|
||||
|
||||
try:
|
||||
# Get market states for all symbols
|
||||
market_states = await self._get_all_market_states()
|
||||
|
||||
# Get enhanced predictions for all symbols
|
||||
symbol_predictions = {}
|
||||
for symbol in self.symbols:
|
||||
if symbol in market_states:
|
||||
predictions = await self._get_enhanced_predictions(symbol, market_states[symbol])
|
||||
symbol_predictions[symbol] = predictions
|
||||
|
||||
# Coordinate decisions considering symbol correlations
|
||||
for symbol in self.symbols:
|
||||
if symbol in symbol_predictions:
|
||||
decision = await self._make_coordinated_decision(
|
||||
symbol,
|
||||
symbol_predictions[symbol],
|
||||
symbol_predictions,
|
||||
market_states[symbol]
|
||||
)
|
||||
decisions[symbol] = decision
|
||||
|
||||
# Queue for RL evaluation
|
||||
if decision and decision.action != 'HOLD':
|
||||
self._queue_for_rl_evaluation(decision, market_states[symbol])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in coordinated decision making: {e}")
|
||||
|
||||
return decisions
|
||||
|
||||
async def _get_all_market_states(self) -> Dict[str, MarketState]:
|
||||
"""Get current market state for all symbols"""
|
||||
market_states = {}
|
||||
|
||||
for symbol in self.symbols:
|
||||
try:
|
||||
# Get current market data for all timeframes
|
||||
prices = {}
|
||||
features = {}
|
||||
|
||||
for timeframe in self.timeframes:
|
||||
# Get current price
|
||||
current_price = self.data_provider.get_current_price(symbol)
|
||||
if current_price:
|
||||
prices[timeframe] = current_price
|
||||
|
||||
# Get feature matrix for this timeframe
|
||||
feature_matrix = self.data_provider.get_feature_matrix(
|
||||
symbol=symbol,
|
||||
timeframes=[timeframe],
|
||||
window_size=20 # Standard window
|
||||
)
|
||||
if feature_matrix is not None:
|
||||
features[timeframe] = feature_matrix
|
||||
|
||||
if prices and features:
|
||||
# Calculate market metrics
|
||||
volatility = self._calculate_volatility(symbol)
|
||||
volume = self._get_current_volume(symbol)
|
||||
trend_strength = self._calculate_trend_strength(symbol)
|
||||
market_regime = self._determine_market_regime(symbol)
|
||||
|
||||
market_state = MarketState(
|
||||
symbol=symbol,
|
||||
timestamp=datetime.now(),
|
||||
prices=prices,
|
||||
features=features,
|
||||
volatility=volatility,
|
||||
volume=volume,
|
||||
trend_strength=trend_strength,
|
||||
market_regime=market_regime
|
||||
)
|
||||
|
||||
market_states[symbol] = market_state
|
||||
|
||||
# Store for historical tracking
|
||||
self.market_states[symbol].append(market_state)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting market state for {symbol}: {e}")
|
||||
|
||||
return market_states
|
||||
|
||||
async def _get_enhanced_predictions(self, symbol: str, market_state: MarketState) -> List[EnhancedPrediction]:
|
||||
"""Get enhanced predictions with timeframe breakdown"""
|
||||
predictions = []
|
||||
|
||||
for model_name, model in self.model_registry.models.items():
|
||||
try:
|
||||
if isinstance(model, CNNModelInterface):
|
||||
# Get CNN predictions for each timeframe
|
||||
timeframe_predictions = []
|
||||
|
||||
for timeframe in self.timeframes:
|
||||
if timeframe in market_state.features:
|
||||
feature_matrix = market_state.features[timeframe]
|
||||
|
||||
# Get timeframe-specific prediction
|
||||
action_probs, confidence = await self._get_timeframe_prediction(
|
||||
model, feature_matrix, timeframe, market_state
|
||||
)
|
||||
|
||||
if action_probs is not None:
|
||||
action_names = ['SELL', 'HOLD', 'BUY']
|
||||
best_action_idx = np.argmax(action_probs)
|
||||
best_action = action_names[best_action_idx]
|
||||
|
||||
# Create timeframe prediction
|
||||
tf_prediction = TimeframePrediction(
|
||||
timeframe=timeframe,
|
||||
action=best_action,
|
||||
confidence=float(confidence),
|
||||
probabilities={name: float(prob) for name, prob in zip(action_names, action_probs)},
|
||||
timestamp=datetime.now(),
|
||||
market_features={
|
||||
'volatility': market_state.volatility,
|
||||
'volume': market_state.volume,
|
||||
'trend_strength': market_state.trend_strength
|
||||
}
|
||||
)
|
||||
timeframe_predictions.append(tf_prediction)
|
||||
|
||||
if timeframe_predictions:
|
||||
# Combine timeframe predictions into overall prediction
|
||||
overall_action, overall_confidence = self._combine_timeframe_predictions(
|
||||
timeframe_predictions, symbol
|
||||
)
|
||||
|
||||
enhanced_pred = EnhancedPrediction(
|
||||
symbol=symbol,
|
||||
timeframe_predictions=timeframe_predictions,
|
||||
overall_action=overall_action,
|
||||
overall_confidence=overall_confidence,
|
||||
model_name=model.name,
|
||||
timestamp=datetime.now(),
|
||||
metadata={
|
||||
'market_regime': market_state.market_regime,
|
||||
'symbol_correlation': self._get_symbol_correlation(symbol)
|
||||
}
|
||||
)
|
||||
predictions.append(enhanced_pred)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting enhanced predictions from {model_name}: {e}")
|
||||
|
||||
return predictions
|
||||
|
||||
async def _get_timeframe_prediction(self, model: CNNModelInterface, feature_matrix: np.ndarray,
|
||||
timeframe: str, market_state: MarketState) -> Tuple[Optional[np.ndarray], float]:
|
||||
"""Get prediction for specific timeframe with enhanced context"""
|
||||
try:
|
||||
# Check if model supports timeframe-specific prediction
|
||||
if hasattr(model, 'predict_timeframe'):
|
||||
action_probs, confidence = model.predict_timeframe(feature_matrix, timeframe)
|
||||
else:
|
||||
action_probs, confidence = model.predict(feature_matrix)
|
||||
|
||||
if action_probs is not None and confidence is not None:
|
||||
# Enhance confidence based on market conditions
|
||||
enhanced_confidence = self._enhance_confidence_with_context(
|
||||
confidence, timeframe, market_state
|
||||
)
|
||||
return action_probs, enhanced_confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting timeframe prediction for {timeframe}: {e}")
|
||||
|
||||
return None, 0.0
|
||||
|
||||
def _enhance_confidence_with_context(self, base_confidence: float, timeframe: str,
|
||||
market_state: MarketState) -> float:
|
||||
"""Enhance confidence score based on market context"""
|
||||
enhanced = base_confidence
|
||||
|
||||
# Adjust based on market regime
|
||||
if market_state.market_regime == 'trending':
|
||||
enhanced *= 1.1 # More confident in trending markets
|
||||
elif market_state.market_regime == 'volatile':
|
||||
enhanced *= 0.8 # Less confident in volatile markets
|
||||
|
||||
# Adjust based on timeframe reliability
|
||||
timeframe_reliability = {
|
||||
'1m': 0.7, '5m': 0.8, '15m': 0.9, '1h': 1.0, '4h': 1.1, '1d': 1.2
|
||||
}
|
||||
enhanced *= timeframe_reliability.get(timeframe, 1.0)
|
||||
|
||||
# Adjust based on volume
|
||||
if market_state.volume > 1.5: # High volume
|
||||
enhanced *= 1.05
|
||||
elif market_state.volume < 0.5: # Low volume
|
||||
enhanced *= 0.9
|
||||
|
||||
return min(enhanced, 1.0) # Cap at 1.0
|
||||
|
||||
def _combine_timeframe_predictions(self, timeframe_predictions: List[TimeframePrediction],
|
||||
symbol: str) -> Tuple[str, float]:
|
||||
"""Combine predictions from multiple timeframes"""
|
||||
action_scores = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
|
||||
total_weight = 0.0
|
||||
|
||||
for tf_pred in timeframe_predictions:
|
||||
# Get timeframe weight
|
||||
tf_weight = self.timeframe_weights.get(tf_pred.timeframe, 0.1)
|
||||
|
||||
# Weight by confidence and timeframe importance
|
||||
weighted_confidence = tf_pred.confidence * tf_weight
|
||||
|
||||
# Add to action scores
|
||||
action_scores[tf_pred.action] += weighted_confidence
|
||||
total_weight += weighted_confidence
|
||||
|
||||
# Normalize scores
|
||||
if total_weight > 0:
|
||||
for action in action_scores:
|
||||
action_scores[action] /= total_weight
|
||||
|
||||
# Get best action and confidence
|
||||
best_action = max(action_scores, key=action_scores.get)
|
||||
best_confidence = action_scores[best_action]
|
||||
|
||||
return best_action, best_confidence
|
||||
|
||||
async def _make_coordinated_decision(self, symbol: str, predictions: List[EnhancedPrediction],
|
||||
all_predictions: Dict[str, List[EnhancedPrediction]],
|
||||
market_state: MarketState) -> Optional[TradingAction]:
|
||||
"""Make decision considering symbol correlations"""
|
||||
if not predictions:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Get primary prediction (highest confidence)
|
||||
primary_pred = max(predictions, key=lambda p: p.overall_confidence)
|
||||
|
||||
# Consider correlated symbols
|
||||
correlated_sentiment = self._get_correlated_sentiment(symbol, all_predictions)
|
||||
|
||||
# Adjust decision based on correlation
|
||||
final_action = primary_pred.overall_action
|
||||
final_confidence = primary_pred.overall_confidence
|
||||
|
||||
# If correlated symbols strongly disagree, reduce confidence
|
||||
if correlated_sentiment['agreement'] < 0.5:
|
||||
final_confidence *= 0.8
|
||||
logger.info(f"Reduced confidence for {symbol} due to correlation disagreement")
|
||||
|
||||
# Apply confidence threshold
|
||||
if final_confidence < self.confidence_threshold:
|
||||
final_action = 'HOLD'
|
||||
logger.info(f"Action for {symbol} changed to HOLD due to low confidence: {final_confidence:.3f}")
|
||||
|
||||
# Create trading action
|
||||
if final_action != 'HOLD':
|
||||
current_price = market_state.prices.get(self.timeframes[0], 0)
|
||||
quantity = self._calculate_position_size(symbol, final_action, final_confidence)
|
||||
|
||||
action = TradingAction(
|
||||
symbol=symbol,
|
||||
action=final_action,
|
||||
quantity=quantity,
|
||||
confidence=final_confidence,
|
||||
price=current_price,
|
||||
timestamp=datetime.now(),
|
||||
reasoning={
|
||||
'primary_model': primary_pred.model_name,
|
||||
'timeframe_breakdown': [(tf.timeframe, tf.action, tf.confidence)
|
||||
for tf in primary_pred.timeframe_predictions],
|
||||
'correlated_sentiment': correlated_sentiment,
|
||||
'market_regime': market_state.market_regime
|
||||
},
|
||||
timeframe_analysis=primary_pred.timeframe_predictions
|
||||
)
|
||||
|
||||
# Store recent action
|
||||
self.recent_actions[symbol].append(action)
|
||||
|
||||
return action
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error making coordinated decision for {symbol}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _get_correlated_sentiment(self, symbol: str,
|
||||
all_predictions: Dict[str, List[EnhancedPrediction]]) -> Dict[str, Any]:
|
||||
"""Get sentiment from correlated symbols"""
|
||||
correlated_actions = []
|
||||
correlated_confidences = []
|
||||
|
||||
for other_symbol, predictions in all_predictions.items():
|
||||
if other_symbol != symbol and predictions:
|
||||
correlation = self.symbol_correlation_matrix.get((symbol, other_symbol), 0.0)
|
||||
|
||||
if correlation > 0.5: # Only consider significantly correlated symbols
|
||||
best_pred = max(predictions, key=lambda p: p.overall_confidence)
|
||||
correlated_actions.append(best_pred.overall_action)
|
||||
correlated_confidences.append(best_pred.overall_confidence * correlation)
|
||||
|
||||
if not correlated_actions:
|
||||
return {'agreement': 1.0, 'sentiment': 'NEUTRAL'}
|
||||
|
||||
# Calculate agreement
|
||||
primary_pred = all_predictions[symbol][0] if all_predictions.get(symbol) else None
|
||||
if primary_pred:
|
||||
agreement_count = sum(1 for action in correlated_actions
|
||||
if action == primary_pred.overall_action)
|
||||
agreement = agreement_count / len(correlated_actions)
|
||||
else:
|
||||
agreement = 0.5
|
||||
|
||||
# Calculate overall sentiment
|
||||
action_weights = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
|
||||
for action, confidence in zip(correlated_actions, correlated_confidences):
|
||||
action_weights[action] += confidence
|
||||
|
||||
dominant_sentiment = max(action_weights, key=action_weights.get)
|
||||
|
||||
return {
|
||||
'agreement': agreement,
|
||||
'sentiment': dominant_sentiment,
|
||||
'correlated_symbols': len(correlated_actions)
|
||||
}
|
||||
|
||||
def _queue_for_rl_evaluation(self, action: TradingAction, market_state: MarketState):
|
||||
"""Queue trading action for RL evaluation"""
|
||||
evaluation_item = {
|
||||
'action': action,
|
||||
'market_state_before': market_state,
|
||||
'timestamp': datetime.now(),
|
||||
'evaluation_pending': True
|
||||
}
|
||||
self.rl_evaluation_queue.append(evaluation_item)
|
||||
|
||||
async def evaluate_actions_with_rl(self):
|
||||
"""Evaluate recent actions using RL agents for continuous learning"""
|
||||
if not self.rl_evaluation_queue:
|
||||
return
|
||||
|
||||
current_time = datetime.now()
|
||||
|
||||
# Process actions that are ready for evaluation (e.g., 1 hour old)
|
||||
for item in list(self.rl_evaluation_queue):
|
||||
if item['evaluation_pending']:
|
||||
time_since_action = (current_time - item['timestamp']).total_seconds()
|
||||
|
||||
# Evaluate after sufficient time has passed
|
||||
if time_since_action >= 3600: # 1 hour
|
||||
await self._evaluate_single_action(item)
|
||||
item['evaluation_pending'] = False
|
||||
|
||||
async def _evaluate_single_action(self, evaluation_item: Dict[str, Any]):
|
||||
"""Evaluate a single action using RL"""
|
||||
try:
|
||||
action = evaluation_item['action']
|
||||
initial_state = evaluation_item['market_state_before']
|
||||
|
||||
# Get current market state for comparison
|
||||
current_market_states = await self._get_all_market_states()
|
||||
current_state = current_market_states.get(action.symbol)
|
||||
|
||||
if current_state:
|
||||
# Calculate reward based on price movement
|
||||
initial_price = initial_state.prices.get(self.timeframes[0], 0)
|
||||
current_price = current_state.prices.get(self.timeframes[0], 0)
|
||||
|
||||
if initial_price > 0:
|
||||
price_change = (current_price - initial_price) / initial_price
|
||||
|
||||
# Calculate reward based on action and price movement
|
||||
reward = self._calculate_reward(action.action, price_change, action.confidence)
|
||||
|
||||
# Update RL agents
|
||||
await self._update_rl_agents(action, initial_state, current_state, reward)
|
||||
|
||||
# Check if this was a perfect move for CNN training
|
||||
if abs(reward) > 0.02: # Significant outcome
|
||||
self._mark_perfect_move(action, initial_state, current_state, reward)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating action: {e}")
|
||||
|
||||
def _calculate_reward(self, action: str, price_change: float, confidence: float) -> float:
|
||||
"""Calculate reward for RL training"""
|
||||
base_reward = 0.0
|
||||
|
||||
if action == 'BUY' and price_change > 0:
|
||||
base_reward = price_change * 10 # Reward proportional to gain
|
||||
elif action == 'SELL' and price_change < 0:
|
||||
base_reward = abs(price_change) * 10 # Reward for avoiding loss
|
||||
elif action == 'HOLD':
|
||||
base_reward = 0.01 if abs(price_change) < 0.005 else -0.01 # Small reward for correct holds
|
||||
else:
|
||||
base_reward = -abs(price_change) * 5 # Penalty for wrong actions
|
||||
|
||||
# Adjust reward based on confidence
|
||||
confidence_multiplier = 0.5 + confidence # 0.5 to 1.5 range
|
||||
|
||||
return base_reward * confidence_multiplier
|
||||
|
||||
async def _update_rl_agents(self, action: TradingAction, initial_state: MarketState,
|
||||
current_state: MarketState, reward: float):
|
||||
"""Update RL agents with action evaluation"""
|
||||
for model_name, model in self.model_registry.models.items():
|
||||
if isinstance(model, RLAgentInterface):
|
||||
try:
|
||||
# Convert market states to RL state format
|
||||
initial_rl_state = self._market_state_to_rl_state(initial_state)
|
||||
current_rl_state = self._market_state_to_rl_state(current_state)
|
||||
|
||||
# Convert action to RL action index
|
||||
action_idx = {'SELL': 0, 'HOLD': 1, 'BUY': 2}.get(action.action, 1)
|
||||
|
||||
# Store experience
|
||||
model.remember(
|
||||
state=initial_rl_state,
|
||||
action=action_idx,
|
||||
reward=reward,
|
||||
next_state=current_rl_state,
|
||||
done=False
|
||||
)
|
||||
|
||||
# Trigger replay learning
|
||||
loss = model.replay()
|
||||
if loss is not None:
|
||||
logger.info(f"RL agent {model_name} updated with loss: {loss:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating RL agent {model_name}: {e}")
|
||||
|
||||
def _mark_perfect_move(self, action: TradingAction, initial_state: MarketState,
|
||||
final_state: MarketState, reward: float):
|
||||
"""Mark a perfect move for CNN training"""
|
||||
try:
|
||||
# Determine what the optimal action should have been
|
||||
optimal_action = action.action if reward > 0 else ('HOLD' if action.action == 'HOLD' else
|
||||
('SELL' if action.action == 'BUY' else 'BUY'))
|
||||
|
||||
# Calculate what confidence should have been
|
||||
optimal_confidence = min(0.95, abs(reward) * 10) # Higher reward = higher confidence should have been
|
||||
|
||||
for tf_pred in action.timeframe_analysis:
|
||||
perfect_move = PerfectMove(
|
||||
symbol=action.symbol,
|
||||
timeframe=tf_pred.timeframe,
|
||||
timestamp=action.timestamp,
|
||||
optimal_action=optimal_action,
|
||||
actual_outcome=reward,
|
||||
market_state_before=initial_state,
|
||||
market_state_after=final_state,
|
||||
confidence_should_have_been=optimal_confidence
|
||||
)
|
||||
self.perfect_moves.append(perfect_move)
|
||||
|
||||
logger.info(f"Marked perfect move for {action.symbol}: {optimal_action} with confidence {optimal_confidence:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error marking perfect move: {e}")
|
||||
|
||||
def get_perfect_moves_for_training(self, symbol: str = None, timeframe: str = None,
|
||||
limit: int = 1000) -> List[PerfectMove]:
|
||||
"""Get perfect moves for CNN training"""
|
||||
moves = list(self.perfect_moves)
|
||||
|
||||
if symbol:
|
||||
moves = [m for m in moves if m.symbol == symbol]
|
||||
|
||||
if timeframe:
|
||||
moves = [m for m in moves if m.timeframe == timeframe]
|
||||
|
||||
return moves[-limit:] if limit else moves
|
||||
|
||||
# Helper methods for market analysis
|
||||
def _calculate_volatility(self, symbol: str) -> float:
|
||||
"""Calculate current volatility for symbol"""
|
||||
# Placeholder - implement based on your data provider
|
||||
return 0.02 # 2% default volatility
|
||||
|
||||
def _get_current_volume(self, symbol: str) -> float:
|
||||
"""Get current volume ratio compared to average"""
|
||||
# Placeholder - implement based on your data provider
|
||||
return 1.0 # Normal volume
|
||||
|
||||
def _calculate_trend_strength(self, symbol: str) -> float:
|
||||
"""Calculate trend strength (0 = no trend, 1 = strong trend)"""
|
||||
# Placeholder - implement based on your data provider
|
||||
return 0.5 # Moderate trend
|
||||
|
||||
def _determine_market_regime(self, symbol: str) -> str:
|
||||
"""Determine current market regime"""
|
||||
# Placeholder - implement based on your analysis
|
||||
return 'trending' # Default to trending
|
||||
|
||||
def _get_symbol_correlation(self, symbol: str) -> Dict[str, float]:
|
||||
"""Get correlations with other symbols"""
|
||||
correlations = {}
|
||||
for other_symbol in self.symbols:
|
||||
if other_symbol != symbol:
|
||||
correlations[other_symbol] = self.symbol_correlation_matrix.get((symbol, other_symbol), 0.0)
|
||||
return correlations
|
||||
|
||||
def _calculate_position_size(self, symbol: str, action: str, confidence: float) -> float:
|
||||
"""Calculate position size based on confidence and risk management"""
|
||||
base_size = 0.02 # 2% of portfolio
|
||||
confidence_multiplier = confidence # Scale by confidence
|
||||
max_size = 0.05 # 5% maximum
|
||||
|
||||
return min(base_size * confidence_multiplier, max_size)
|
||||
|
||||
def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
|
||||
"""Convert market state to RL state vector"""
|
||||
# Combine features from all timeframes into a single state vector
|
||||
state_components = []
|
||||
|
||||
# Add price features
|
||||
state_components.extend([
|
||||
market_state.volatility,
|
||||
market_state.volume,
|
||||
market_state.trend_strength
|
||||
])
|
||||
|
||||
# Add flattened features from each timeframe
|
||||
for timeframe in sorted(market_state.features.keys()):
|
||||
features = market_state.features[timeframe]
|
||||
if features is not None:
|
||||
# Take the last row (most recent) and flatten
|
||||
latest_features = features[-1] if len(features.shape) > 1 else features
|
||||
state_components.extend(latest_features.flatten())
|
||||
|
||||
return np.array(state_components, dtype=np.float32)
|
Reference in New Issue
Block a user