698 lines
30 KiB
Python
698 lines
30 KiB
Python
"""
Enhanced Trading Orchestrator - Advanced Multi-Modal Decision Making

This enhanced orchestrator implements:

1. Multi-timeframe CNN predictions with individual confidence scores
2. Advanced RL feedback loop for continuous learning
3. Multi-symbol (ETH, BTC) coordinated decision making
4. Perfect move marking for CNN backpropagation training
5. Market environment adaptation through RL evaluation
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
import numpy as np
|
|
import pandas as pd
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Tuple, Any, Union
|
|
from dataclasses import dataclass, field
|
|
from collections import deque
|
|
import torch
|
|
|
|
from .config import get_config
|
|
from .data_provider import DataProvider
|
|
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface
|
|
|
|
# Module-level logger named after this module so its level and handlers can be
# configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
|
@dataclass
class TimeframePrediction:
    """CNN prediction for a specific timeframe with confidence.

    Field order defines the positional ``__init__`` signature -- do not reorder.
    """
    timeframe: str  # Timeframe key such as '1m' or '1h'
    action: str  # 'BUY', 'SELL', 'HOLD'
    confidence: float  # 0.0 to 1.0
    probabilities: Dict[str, float]  # Action probabilities, keyed by action name
    timestamp: datetime  # When the prediction was produced
    market_features: Dict[str, float] = field(default_factory=dict)  # Additional context (e.g. volatility, volume, trend_strength)
|
|
|
|
@dataclass
class EnhancedPrediction:
    """Enhanced prediction structure with timeframe breakdown.

    One model's view of a symbol: the individual per-timeframe predictions
    plus the combined action and confidence derived from them.
    Field order defines the positional ``__init__`` signature -- do not reorder.
    """
    symbol: str
    timeframe_predictions: List[TimeframePrediction]  # Per-timeframe breakdown
    overall_action: str  # Combined action: 'BUY', 'SELL' or 'HOLD'
    overall_confidence: float  # Combined confidence for overall_action
    model_name: str  # Name of the model that produced this prediction
    timestamp: datetime
    metadata: Dict[str, Any] = field(default_factory=dict)  # Free-form extra context
|
|
|
|
@dataclass
class TradingAction:
    """Represents a trading action with full context.

    Field order defines the positional ``__init__`` signature -- do not reorder.
    """
    symbol: str
    action: str  # 'BUY', 'SELL', 'HOLD'
    quantity: float  # Position size; the orchestrator in this module sizes it as a fraction of the portfolio
    confidence: float  # Confidence behind the decision
    price: float  # Reference price at decision time
    timestamp: datetime
    reasoning: Dict[str, Any]  # Free-form explanation of how the decision was reached
    timeframe_analysis: List[TimeframePrediction]  # Per-timeframe predictions behind the decision
|
|
|
|
@dataclass
class MarketState:
    """Complete market state for RL evaluation.

    Field order defines the positional ``__init__`` signature -- do not reorder.
    """
    symbol: str
    timestamp: datetime
    prices: Dict[str, float]  # {timeframe: current_price}
    features: Dict[str, np.ndarray]  # {timeframe: feature_matrix}
    volatility: float
    volume: float  # Volume ratio vs. average, as produced by the orchestrator's _get_current_volume
    trend_strength: float  # 0 = no trend, 1 = strong trend (per the orchestrator's helper)
    market_regime: str  # 'trending', 'ranging', 'volatile'
|
|
|
|
@dataclass
class PerfectMove:
    """Marked perfect move for CNN training.

    Records, in hindsight, which action would have been right for a past
    trading action and how confident the model should have been.
    Field order defines the positional ``__init__`` signature -- do not reorder.
    """
    symbol: str
    timeframe: str
    timestamp: datetime  # Timestamp of the original trading action
    optimal_action: str  # Action that would have been right in hindsight
    actual_outcome: float  # Outcome of the move. NOTE(review): previously described as "price change percentage", but the orchestrator in this module stores its confidence-scaled RL reward here -- confirm intended meaning
    market_state_before: MarketState
    market_state_after: MarketState
    confidence_should_have_been: float  # Target confidence for CNN training
|
|
|
|
class EnhancedTradingOrchestrator:
    """
    Enhanced orchestrator with sophisticated multi-modal decision making.

    One decision cycle (``make_coordinated_decisions``):

    1. Snapshot a ``MarketState`` for every configured symbol.
    2. Ask every registered ``CNNModelInterface`` for a prediction on each
       timeframe and fuse those into one ``EnhancedPrediction`` per model.
    3. Take the most confident prediction per symbol, damp its confidence
       when correlated symbols disagree, apply ``confidence_threshold`` and
       emit a ``TradingAction`` (``None`` means HOLD / no trade).

    Emitted actions are queued. ``evaluate_actions_with_rl`` later scores them
    against the then-current price, feeds the experience to every registered
    ``RLAgentInterface`` and records large outcomes as ``PerfectMove``
    samples for CNN training.
    """

    # Assumed index order of a CNN action-probability vector (matches the RL
    # action indices used in _update_rl_agents).
    _ACTION_NAMES = ('SELL', 'HOLD', 'BUY')

    # Base importance per timeframe before normalization. Higher timeframes
    # weigh more for trend direction, lower ones for entry/exit timing.
    _BASE_TIMEFRAME_WEIGHTS = {
        '1m': 0.05,  # Noise filtering
        '5m': 0.10,  # Short-term momentum
        '15m': 0.15,  # Entry/exit timing
        '1h': 0.25,  # Medium-term trend
        '4h': 0.25,  # Stronger trend confirmation
        '1d': 0.20,  # Long-term direction
    }

    # Confidence multiplier per timeframe; unlisted timeframes use 1.0.
    _TIMEFRAME_RELIABILITY = {
        '1m': 0.7, '5m': 0.8, '15m': 0.9, '1h': 1.0, '4h': 1.1, '1d': 1.2
    }

    def __init__(self, data_provider: Optional[DataProvider] = None):
        """Initialize the enhanced orchestrator.

        Args:
            data_provider: Source of prices and feature matrices. When ``None``
                a new ``DataProvider()`` is created.
        """
        self.config = get_config()
        self.data_provider = data_provider if data_provider is not None else DataProvider()
        self.model_registry = get_model_registry()

        # Multi-symbol configuration
        self.symbols = self.config.symbols
        self.timeframes = self.config.timeframes

        # Decision configuration
        self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.6)
        # Loaded from config; not referenced by any method of this class.
        self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)

        # Enhanced weighting system
        self.timeframe_weights = self._initialize_timeframe_weights()
        self.symbol_correlation_matrix = self._initialize_correlation_matrix()

        # Per-symbol state tracking (bounded deques drop their oldest entries)
        self.symbol_states = {symbol: {} for symbol in self.symbols}
        self.recent_actions = {symbol: deque(maxlen=100) for symbol in self.symbols}
        self.market_states = {symbol: deque(maxlen=1000) for symbol in self.symbols}

        # Perfect move tracking for CNN training
        self.perfect_moves = deque(maxlen=10000)
        self.performance_tracker = {}

        # RL feedback system
        self.rl_evaluation_queue = deque(maxlen=1000)
        self.environment_adaptation_rate = 0.01

        # Decision callbacks (not invoked by any method of this class)
        self.decision_callbacks = []
        self.learning_callbacks = []

        logger.info("Enhanced TradingOrchestrator initialized")
        logger.info(f"Symbols: {self.symbols}")
        logger.info(f"Timeframes: {self.timeframes}")
        logger.info(f"Enhanced confidence threshold: {self.confidence_threshold}")

    def _initialize_timeframe_weights(self) -> Dict[str, float]:
        """Return weights for the configured timeframes, normalized to sum to 1.

        Timeframes without a base weight start at 0.1 before normalization.
        An empty timeframe list yields ``{}`` (instead of dividing by zero).
        """
        configured_weights = {
            tf: self._BASE_TIMEFRAME_WEIGHTS.get(tf, 0.1) for tf in self.timeframes
        }
        total = sum(configured_weights.values())
        if total <= 0:
            return {}
        return {tf: w / total for tf, w in configured_weights.items()}

    def _initialize_correlation_matrix(self) -> Dict[Tuple[str, str], float]:
        """Build a static pairwise correlation table for the configured symbols.

        Every ordered pair of distinct list positions gets an entry: 0.85 when
        one symbol contains 'ETH' and the other 'BTC' (either order), 0.7
        otherwise. The values are hard-coded, not estimated from data.
        """
        correlations = {}
        for i, symbol1 in enumerate(self.symbols):
            for j, symbol2 in enumerate(self.symbols):
                if i == j:
                    continue
                # ETH and BTC are typically highly correlated
                is_eth_btc_pair = (('ETH' in symbol1 and 'BTC' in symbol2) or
                                   ('BTC' in symbol1 and 'ETH' in symbol2))
                correlations[(symbol1, symbol2)] = 0.85 if is_eth_btc_pair else 0.7
        return correlations

    async def make_coordinated_decisions(self) -> Dict[str, Optional[TradingAction]]:
        """
        Make coordinated trading decisions across all symbols.

        Returns:
            Mapping of symbol -> ``TradingAction``, or ``None`` when the
            decision is HOLD or failed. Symbols without a usable market state
            are absent. An unexpected error is logged and the decisions made
            so far are returned.
        """
        decisions = {}

        try:
            # Snapshot the market for every symbol
            market_states = await self._get_all_market_states()

            # Gather all predictions first so each decision can consult the
            # other symbols' predictions.
            symbol_predictions = {}
            for symbol in self.symbols:
                if symbol in market_states:
                    symbol_predictions[symbol] = await self._get_enhanced_predictions(
                        symbol, market_states[symbol]
                    )

            # Decide per symbol, taking correlated symbols into account
            for symbol in self.symbols:
                if symbol not in symbol_predictions:
                    continue

                decision = await self._make_coordinated_decision(
                    symbol,
                    symbol_predictions[symbol],
                    symbol_predictions,
                    market_states[symbol]
                )
                decisions[symbol] = decision

                # Queue actual trades for delayed RL evaluation
                if decision and decision.action != 'HOLD':
                    self._queue_for_rl_evaluation(decision, market_states[symbol])

        except Exception as e:
            logger.error(f"Error in coordinated decision making: {e}")

        return decisions

    async def _get_all_market_states(self) -> Dict[str, MarketState]:
        """Snapshot the current market state for every configured symbol.

        Each snapshot is also appended to ``self.market_states[symbol]``.
        Symbols that yield no price or no feature matrix are omitted, and
        per-symbol errors are logged and skipped.
        """
        market_states = {}

        for symbol in self.symbols:
            try:
                # The spot price does not depend on the timeframe: fetch it
                # once and record it under every timeframe key.
                current_price = self.data_provider.get_current_price(symbol)
                prices = {}
                features = {}

                for timeframe in self.timeframes:
                    if current_price:
                        prices[timeframe] = current_price

                    # Feature matrix for this single timeframe
                    feature_matrix = self.data_provider.get_feature_matrix(
                        symbol=symbol,
                        timeframes=[timeframe],
                        window_size=20  # Standard window
                    )
                    if feature_matrix is not None:
                        features[timeframe] = feature_matrix

                if prices and features:
                    market_state = MarketState(
                        symbol=symbol,
                        timestamp=datetime.now(),
                        prices=prices,
                        features=features,
                        volatility=self._calculate_volatility(symbol),
                        volume=self._get_current_volume(symbol),
                        trend_strength=self._calculate_trend_strength(symbol),
                        market_regime=self._determine_market_regime(symbol)
                    )
                    market_states[symbol] = market_state

                    # Store for historical tracking
                    self.market_states[symbol].append(market_state)

            except Exception as e:
                logger.error(f"Error getting market state for {symbol}: {e}")

        return market_states

    async def _get_enhanced_predictions(self, symbol: str, market_state: MarketState) -> List[EnhancedPrediction]:
        """Collect a per-timeframe prediction breakdown from every CNN model.

        Returns:
            One ``EnhancedPrediction`` per CNN model that produced at least one
            timeframe prediction. A model that raises is logged and skipped.
        """
        predictions = []

        for model_name, model in self.model_registry.models.items():
            if not isinstance(model, CNNModelInterface):
                continue

            try:
                timeframe_predictions = []

                for timeframe in self.timeframes:
                    if timeframe not in market_state.features:
                        continue

                    action_probs, confidence = await self._get_timeframe_prediction(
                        model, market_state.features[timeframe], timeframe, market_state
                    )
                    if action_probs is None:
                        continue

                    # NOTE(review): assumes action_probs is a length-3 vector in
                    # SELL/HOLD/BUY order -- confirm against the model interface.
                    best_action = self._ACTION_NAMES[int(np.argmax(action_probs))]

                    timeframe_predictions.append(TimeframePrediction(
                        timeframe=timeframe,
                        action=best_action,
                        confidence=float(confidence),
                        probabilities={name: float(prob) for name, prob in zip(self._ACTION_NAMES, action_probs)},
                        timestamp=datetime.now(),
                        market_features={
                            'volatility': market_state.volatility,
                            'volume': market_state.volume,
                            'trend_strength': market_state.trend_strength
                        }
                    ))

                if timeframe_predictions:
                    overall_action, overall_confidence = self._combine_timeframe_predictions(
                        timeframe_predictions, symbol
                    )
                    predictions.append(EnhancedPrediction(
                        symbol=symbol,
                        timeframe_predictions=timeframe_predictions,
                        overall_action=overall_action,
                        overall_confidence=overall_confidence,
                        model_name=model.name,
                        timestamp=datetime.now(),
                        metadata={
                            'market_regime': market_state.market_regime,
                            'symbol_correlation': self._get_symbol_correlation(symbol)
                        }
                    ))

            except Exception as e:
                logger.error(f"Error getting enhanced predictions from {model_name}: {e}")

        return predictions

    async def _get_timeframe_prediction(self, model: CNNModelInterface, feature_matrix: np.ndarray,
                                        timeframe: str, market_state: MarketState) -> Tuple[Optional[np.ndarray], float]:
        """Run one model on one timeframe and context-adjust its confidence.

        Uses ``model.predict_timeframe(features, timeframe)`` when available,
        otherwise ``model.predict(features)``; both are unpacked as
        ``(action_probs, confidence)``.

        Returns:
            ``(action_probs, adjusted_confidence)``, or ``(None, 0.0)`` when the
            model returned ``None`` for either value or raised (logged).
        """
        try:
            if hasattr(model, 'predict_timeframe'):
                action_probs, confidence = model.predict_timeframe(feature_matrix, timeframe)
            else:
                action_probs, confidence = model.predict(feature_matrix)

            if action_probs is not None and confidence is not None:
                enhanced_confidence = self._enhance_confidence_with_context(
                    confidence, timeframe, market_state
                )
                return action_probs, enhanced_confidence

        except Exception as e:
            logger.error(f"Error getting timeframe prediction for {timeframe}: {e}")

        return None, 0.0

    def _enhance_confidence_with_context(self, base_confidence: float, timeframe: str,
                                         market_state: MarketState) -> float:
        """Scale a model confidence by regime, timeframe reliability and volume.

        Multipliers applied in order:
        - regime: 'trending' x1.1, 'volatile' x0.8;
        - timeframe: per ``_TIMEFRAME_RELIABILITY`` (default 1.0);
        - volume ratio: > 1.5 x1.05, < 0.5 x0.9.

        The result is capped at 1.0.
        """
        enhanced = base_confidence

        if market_state.market_regime == 'trending':
            enhanced *= 1.1  # More confident in trending markets
        elif market_state.market_regime == 'volatile':
            enhanced *= 0.8  # Less confident in volatile markets

        enhanced *= self._TIMEFRAME_RELIABILITY.get(timeframe, 1.0)

        if market_state.volume > 1.5:  # High volume
            enhanced *= 1.05
        elif market_state.volume < 0.5:  # Low volume
            enhanced *= 0.9

        return min(enhanced, 1.0)  # Cap at 1.0

    def _combine_timeframe_predictions(self, timeframe_predictions: List[TimeframePrediction],
                                       symbol: str) -> Tuple[str, float]:
        """Fuse per-timeframe predictions into one action and confidence.

        Each timeframe votes for its action with ``confidence * weight``.
        The winning action's score is divided by the total timeframe weight.
        The result is therefore the weight-averaged confidence behind that
        action, with dissenting timeframes counting as 0. It stays in [0, 1]
        when the inputs are.

        Ties are resolved in the order BUY, SELL, HOLD. With no positive vote,
        ``('HOLD', 0.0)`` is returned. ``symbol`` is currently unused.
        """
        action_scores = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
        total_weight = 0.0

        for tf_pred in timeframe_predictions:
            tf_weight = self.timeframe_weights.get(tf_pred.timeframe, 0.1)
            action_scores[tf_pred.action] += tf_pred.confidence * tf_weight
            # Normalize by weights alone. Dividing by the summed weighted
            # confidences would report 1.0 for any unanimous vote, however low
            # the underlying confidences, which defeats the confidence threshold.
            total_weight += tf_weight

        best_action = max(action_scores, key=action_scores.get)
        if total_weight <= 0 or action_scores[best_action] <= 0:
            # No usable signal; max() would otherwise pick 'BUY' on an all-zero tie.
            return 'HOLD', 0.0

        return best_action, action_scores[best_action] / total_weight

    async def _make_coordinated_decision(self, symbol: str, predictions: List[EnhancedPrediction],
                                         all_predictions: Dict[str, List[EnhancedPrediction]],
                                         market_state: MarketState) -> Optional[TradingAction]:
        """Turn a symbol's predictions into a trade, considering correlated symbols.

        Steps:
        1. The highest-confidence prediction wins.
        2. Its confidence is multiplied by 0.8 when fewer than half of the
           correlated symbols agree.
        3. The action becomes HOLD below ``self.confidence_threshold``.

        Returns:
            A ``TradingAction``, also appended to ``self.recent_actions``.
            ``None`` for HOLD, for empty ``predictions``, or on error (logged).
        """
        if not predictions:
            return None

        try:
            primary_pred = max(predictions, key=lambda p: p.overall_confidence)
            correlated_sentiment = self._get_correlated_sentiment(symbol, all_predictions)

            final_action = primary_pred.overall_action
            final_confidence = primary_pred.overall_confidence

            # If correlated symbols strongly disagree, reduce confidence
            if correlated_sentiment['agreement'] < 0.5:
                final_confidence *= 0.8
                logger.info(f"Reduced confidence for {symbol} due to correlation disagreement")

            if final_confidence < self.confidence_threshold:
                final_action = 'HOLD'
                logger.info(f"Action for {symbol} changed to HOLD due to low confidence: {final_confidence:.3f}")

            if final_action == 'HOLD':
                return None

            # The same price is stored under every timeframe key; use the first.
            current_price = market_state.prices.get(self.timeframes[0], 0)
            quantity = self._calculate_position_size(symbol, final_action, final_confidence)

            action = TradingAction(
                symbol=symbol,
                action=final_action,
                quantity=quantity,
                confidence=final_confidence,
                price=current_price,
                timestamp=datetime.now(),
                reasoning={
                    'primary_model': primary_pred.model_name,
                    'timeframe_breakdown': [(tf.timeframe, tf.action, tf.confidence)
                                            for tf in primary_pred.timeframe_predictions],
                    'correlated_sentiment': correlated_sentiment,
                    'market_regime': market_state.market_regime
                },
                timeframe_analysis=primary_pred.timeframe_predictions
            )

            self.recent_actions[symbol].append(action)
            return action

        except Exception as e:
            logger.error(f"Error making coordinated decision for {symbol}: {e}")

        return None

    def _get_correlated_sentiment(self, symbol: str,
                                  all_predictions: Dict[str, List[EnhancedPrediction]]) -> Dict[str, Any]:
        """Summarize predictions of symbols correlated (> 0.5) with ``symbol``.

        Returns:
            ``{'agreement', 'sentiment', 'correlated_symbols'}``:

            - ``agreement``: fraction of correlated symbols whose most
              confident action matches ``symbol``'s most confident action
              (0.5 if ``symbol`` has no predictions).
            - ``sentiment``: the action with the largest correlation-weighted
              confidence, or ``'NEUTRAL'`` when none is positive.
            - ``correlated_symbols``: how many symbols were considered.

            With no correlated symbols: ``agreement=1.0``, ``'NEUTRAL'``, 0.
        """
        correlated_actions = []
        correlated_confidences = []

        for other_symbol, predictions in all_predictions.items():
            if other_symbol == symbol or not predictions:
                continue
            correlation = self.symbol_correlation_matrix.get((symbol, other_symbol), 0.0)
            if correlation > 0.5:  # Only consider significantly correlated symbols
                best_pred = max(predictions, key=lambda p: p.overall_confidence)
                correlated_actions.append(best_pred.overall_action)
                correlated_confidences.append(best_pred.overall_confidence * correlation)

        if not correlated_actions:
            return {'agreement': 1.0, 'sentiment': 'NEUTRAL', 'correlated_symbols': 0}

        # Compare against the prediction _make_coordinated_decision acts on
        # (the most confident one), not merely the first in the list.
        own_predictions = all_predictions.get(symbol)
        if own_predictions:
            primary_pred = max(own_predictions, key=lambda p: p.overall_confidence)
            agreement_count = sum(1 for action in correlated_actions
                                  if action == primary_pred.overall_action)
            agreement = agreement_count / len(correlated_actions)
        else:
            agreement = 0.5

        action_weights = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
        for action, confidence in zip(correlated_actions, correlated_confidences):
            action_weights[action] += confidence

        dominant_sentiment = max(action_weights, key=action_weights.get)
        if action_weights[dominant_sentiment] <= 0:
            # All-zero weights: report no sentiment rather than defaulting to 'BUY'
            dominant_sentiment = 'NEUTRAL'

        return {
            'agreement': agreement,
            'sentiment': dominant_sentiment,
            'correlated_symbols': len(correlated_actions)
        }

    def _queue_for_rl_evaluation(self, action: TradingAction, market_state: MarketState):
        """Queue a trading action and its entry market state for later RL scoring."""
        self.rl_evaluation_queue.append({
            'action': action,
            'market_state_before': market_state,
            'timestamp': datetime.now(),
            'evaluation_pending': True
        })

    async def evaluate_actions_with_rl(self, min_age_seconds: float = 3600):
        """Evaluate queued actions once they are old enough, for continuous learning.

        Args:
            min_age_seconds: Minimum time since queuing before an action is
                scored. Defaults to one hour.

        Each due item is evaluated once and then flagged
        ``evaluation_pending = False``. Evaluated items remain in the bounded
        queue until evicted.
        """
        if not self.rl_evaluation_queue:
            return

        current_time = datetime.now()
        due_items = [
            item for item in self.rl_evaluation_queue
            if item['evaluation_pending']
            and (current_time - item['timestamp']).total_seconds() >= min_age_seconds
        ]
        if not due_items:
            return

        # One snapshot for the whole batch. _get_all_market_states queries every
        # symbol and appends to the per-symbol history, so taking a snapshot
        # per item would repeat that work and duplicate history entries.
        current_market_states = await self._get_all_market_states()

        for item in due_items:
            await self._evaluate_single_action(item, current_market_states)
            item['evaluation_pending'] = False

    async def _evaluate_single_action(self, evaluation_item: Dict[str, Any],
                                      current_market_states: Optional[Dict[str, MarketState]] = None):
        """Score one queued action against the current price and learn from it.

        Args:
            evaluation_item: Entry created by ``_queue_for_rl_evaluation``.
            current_market_states: Pre-fetched output of
                ``_get_all_market_states``; fetched here when ``None``.

        The reward is fed to every RL agent. When ``abs(reward) > 0.02``,
        ``PerfectMove`` samples are recorded as well. Errors are logged, not
        raised.
        """
        try:
            action = evaluation_item['action']
            initial_state = evaluation_item['market_state_before']

            if current_market_states is None:
                current_market_states = await self._get_all_market_states()
            current_state = current_market_states.get(action.symbol)
            if current_state is None:
                return

            # Same reference timeframe as the entry price in _make_coordinated_decision
            reference_timeframe = self.timeframes[0]
            initial_price = initial_state.prices.get(reference_timeframe, 0)
            current_price = current_state.prices.get(reference_timeframe, 0)

            # A missing/zero current price would otherwise read as a -100% move
            if initial_price <= 0 or current_price <= 0:
                return

            price_change = (current_price - initial_price) / initial_price
            reward = self._calculate_reward(action.action, price_change, action.confidence)

            await self._update_rl_agents(action, initial_state, current_state, reward)

            # Significant outcomes become CNN training samples
            if abs(reward) > 0.02:
                self._mark_perfect_move(action, initial_state, current_state, reward)

        except Exception as e:
            logger.error(f"Error evaluating action: {e}")

    def _calculate_reward(self, action: str, price_change: float, confidence: float) -> float:
        """Reward an action given the subsequent fractional price change.

        Scoring:
        - BUY on a rise, or SELL on a fall: ``+10 * |price_change|``.
        - Any other BUY/SELL outcome: ``-5 * |price_change|``.
        - HOLD: +0.01 if the move was under 0.5%, else -0.01.

        The result is scaled by ``0.5 + confidence``.
        """
        if action == 'BUY' and price_change > 0:
            base_reward = price_change * 10  # Reward proportional to gain
        elif action == 'SELL' and price_change < 0:
            base_reward = abs(price_change) * 10  # Reward for avoiding loss
        elif action == 'HOLD':
            base_reward = 0.01 if abs(price_change) < 0.005 else -0.01  # Small reward for correct holds
        else:
            base_reward = -abs(price_change) * 5  # Penalty for wrong actions

        confidence_multiplier = 0.5 + confidence  # 0.5 to 1.5 for confidence in [0, 1]
        return base_reward * confidence_multiplier

    async def _update_rl_agents(self, action: TradingAction, initial_state: MarketState,
                                current_state: MarketState, reward: float):
        """Store the (state, action, reward, next_state) experience in every RL agent and replay.

        State vectors are rebuilt per agent, so no two agents share an array.
        Per-agent errors are logged and do not stop the remaining agents.
        """
        for model_name, model in self.model_registry.models.items():
            if not isinstance(model, RLAgentInterface):
                continue
            try:
                initial_rl_state = self._market_state_to_rl_state(initial_state)
                current_rl_state = self._market_state_to_rl_state(current_state)

                # Unknown actions map to HOLD
                action_idx = {'SELL': 0, 'HOLD': 1, 'BUY': 2}.get(action.action, 1)

                model.remember(
                    state=initial_rl_state,
                    action=action_idx,
                    reward=reward,
                    next_state=current_rl_state,
                    done=False
                )

                loss = model.replay()
                if loss is not None:
                    logger.info(f"RL agent {model_name} updated with loss: {loss:.4f}")

            except Exception as e:
                logger.error(f"Error updating RL agent {model_name}: {e}")

    def _mark_perfect_move(self, action: TradingAction, initial_state: MarketState,
                           final_state: MarketState, reward: float):
        """Record hindsight-optimal labels for CNN training.

        The optimal action is the taken action when ``reward > 0`` (or when it
        was HOLD); otherwise it is the opposite trade. The target confidence
        is ``min(0.95, 10 * |reward|)``. One ``PerfectMove`` is appended per
        timeframe in ``action.timeframe_analysis``; note that ``actual_outcome``
        receives the reward, not the raw price change. Errors are logged.
        """
        try:
            if reward > 0 or action.action == 'HOLD':
                optimal_action = action.action
            elif action.action == 'BUY':
                optimal_action = 'SELL'
            else:
                optimal_action = 'BUY'

            # Higher reward => the model should have been more confident
            optimal_confidence = min(0.95, abs(reward) * 10)

            for tf_pred in action.timeframe_analysis:
                self.perfect_moves.append(PerfectMove(
                    symbol=action.symbol,
                    timeframe=tf_pred.timeframe,
                    timestamp=action.timestamp,
                    optimal_action=optimal_action,
                    actual_outcome=reward,
                    market_state_before=initial_state,
                    market_state_after=final_state,
                    confidence_should_have_been=optimal_confidence
                ))

            logger.info(f"Marked perfect move for {action.symbol}: {optimal_action} with confidence {optimal_confidence:.3f}")

        except Exception as e:
            logger.error(f"Error marking perfect move: {e}")

    def get_perfect_moves_for_training(self, symbol: str = None, timeframe: str = None,
                                       limit: int = 1000) -> List[PerfectMove]:
        """Return recorded perfect moves, optionally filtered, newest last.

        Args:
            symbol: Keep only moves for this symbol (falsy = no filter).
            timeframe: Keep only moves for this timeframe (falsy = no filter).
            limit: Return at most the last ``limit`` matches; a falsy value
                returns all of them.
        """
        moves = list(self.perfect_moves)

        if symbol:
            moves = [m for m in moves if m.symbol == symbol]
        if timeframe:
            moves = [m for m in moves if m.timeframe == timeframe]

        return moves[-limit:] if limit else moves

    # Helper methods for market analysis

    def _calculate_volatility(self, symbol: str) -> float:
        """Calculate current volatility for symbol.

        Placeholder: always returns 0.02 (2%).
        """
        return 0.02

    def _get_current_volume(self, symbol: str) -> float:
        """Get current volume ratio compared to average.

        Placeholder: always returns 1.0 (normal volume).
        """
        return 1.0

    def _calculate_trend_strength(self, symbol: str) -> float:
        """Calculate trend strength (0 = no trend, 1 = strong trend).

        Placeholder: always returns 0.5 (moderate trend).
        """
        return 0.5

    def _determine_market_regime(self, symbol: str) -> str:
        """Determine current market regime.

        Placeholder: always returns 'trending'.
        """
        return 'trending'

    def _get_symbol_correlation(self, symbol: str) -> Dict[str, float]:
        """Return ``{other_symbol: correlation}`` for every other configured symbol (0.0 if unknown)."""
        return {
            other_symbol: self.symbol_correlation_matrix.get((symbol, other_symbol), 0.0)
            for other_symbol in self.symbols
            if other_symbol != symbol
        }

    def _calculate_position_size(self, symbol: str, action: str, confidence: float) -> float:
        """Size a position as a portfolio fraction: 2% scaled by confidence, capped at 5%."""
        base_size = 0.02  # 2% of portfolio
        max_size = 0.05  # 5% maximum
        return min(base_size * confidence, max_size)

    def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
        """Flatten a market state into a float32 RL state vector.

        The vector has two parts:
        - ``[volatility, volume, trend_strength]``;
        - then, for each timeframe key in sorted (lexicographic) order, the
          flattened last row of its feature matrix, or the whole array when
          it is 1-D.
        """
        state_components = [
            market_state.volatility,
            market_state.volume,
            market_state.trend_strength,
        ]

        for timeframe in sorted(market_state.features.keys()):
            features = market_state.features[timeframe]
            if features is None:
                continue
            # NOTE(review): assumes the first axis is time, so [-1] is the most
            # recent step -- confirm against DataProvider.get_feature_matrix.
            latest_features = features[-1] if len(features.shape) > 1 else features
            state_components.extend(latest_features.flatten())

        return np.array(state_components, dtype=np.float32)