This commit is contained in:
Dobromir Popov
2025-12-08 19:28:27 +02:00
parent 8263534b74
commit cf4808aa47
4 changed files with 226 additions and 82 deletions

View File

@@ -294,9 +294,13 @@ class RealTrainingAdapter:
# Clear previous predictions for clean visualization # Clear previous predictions for clean visualization
# Get symbol from first test case # Get symbol from first test case
symbol = test_cases[0].get('symbol', 'ETH/USDT') if test_cases else 'ETH/USDT' symbol = test_cases[0].get('symbol', 'ETH/USDT') if test_cases else 'ETH/USDT'
if self.orchestrator and hasattr(self.orchestrator, 'clear_predictions'): if (self.orchestrator and
hasattr(self.orchestrator, 'clear_predictions') and
hasattr(self.orchestrator, 'recent_transformer_predictions')):
self.orchestrator.clear_predictions(symbol) self.orchestrator.clear_predictions(symbol)
logger.info(f" Cleared previous predictions for {symbol}") logger.info(f" Cleared previous predictions for {symbol}")
else:
logger.info(f" Orchestrator not ready, skipping prediction clearing for {symbol}")
# Prepare training data from test cases # Prepare training data from test cases
training_data = self._prepare_training_data(test_cases) training_data = self._prepare_training_data(test_cases)
@@ -595,7 +599,7 @@ class RealTrainingAdapter:
else: else:
logger.warning(f" {symbol} {timeframe}: No quality data available (need {min_required_candles} candles)") logger.warning(f" {symbol} {timeframe}: No quality data available (need {min_required_candles} candles)")
# CRITICAL: Validate we have all required timeframes # CRITICAL: Validate we have all required timeframes (1s is optional, don't check it)
missing_required = [tf for tf in required_timeframes if tf not in fetched_timeframes] missing_required = [tf for tf in required_timeframes if tf not in fetched_timeframes]
if missing_required: if missing_required:
logger.error(f" FAILED: Missing required timeframes: {missing_required}") logger.error(f" FAILED: Missing required timeframes: {missing_required}")
@@ -603,6 +607,10 @@ class RealTrainingAdapter:
logger.error(f" Cannot proceed without all required timeframes") logger.error(f" Cannot proceed without all required timeframes")
return {} # Return empty dict to signal failure return {} # Return empty dict to signal failure
# Log optional timeframes status
if '1s' not in fetched_timeframes:
logger.debug(f" Optional timeframe 1s not available (this is OK - 1s historical data is often unavailable)")
# Fetch secondary symbol data (1m timeframe only) # Fetch secondary symbol data (1m timeframe only)
logger.info(f" Fetching secondary symbol data: {secondary_symbol} (1m)") logger.info(f" Fetching secondary symbol data: {secondary_symbol} (1m)")
secondary_df = None secondary_df = None
@@ -1080,7 +1088,8 @@ class RealTrainingAdapter:
Create a market state snapshot at a specific candle index Create a market state snapshot at a specific candle index
This creates a "view" of the market as it was at that specific candle, This creates a "view" of the market as it was at that specific candle,
which is used for negative sampling. which is used for negative sampling. CRITICAL: Ensures 600 candles are available
by taking the last 600 candles BEFORE the target point.
""" """
snapshot = { snapshot = {
'symbol': market_state.get('symbol'), 'symbol': market_state.get('symbol'),
@@ -1088,19 +1097,26 @@ class RealTrainingAdapter:
'timeframes': {} 'timeframes': {}
} }
# For each timeframe, create a snapshot up to the candle_index # CRITICAL: Training requires 600 candles BEFORE the target point
required_candles = 600
# For each timeframe, create a snapshot with 600 candles BEFORE the candle_index
for tf, tf_data in market_state.get('timeframes', {}).items(): for tf, tf_data in market_state.get('timeframes', {}).items():
timestamps = tf_data.get('timestamps', []) timestamps = tf_data.get('timestamps', [])
if candle_index < len(timestamps): if candle_index < len(timestamps):
# Include data up to and including this candle # Take the last 600 candles BEFORE and INCLUDING this candle
# If we don't have 600 candles, we'll pad later in extraction
start_idx = max(0, candle_index + 1 - required_candles)
end_idx = candle_index + 1
snapshot['timeframes'][tf] = { snapshot['timeframes'][tf] = {
'timestamps': timestamps[:candle_index + 1], 'timestamps': timestamps[start_idx:end_idx],
'open': tf_data.get('open', [])[:candle_index + 1], 'open': tf_data.get('open', [])[start_idx:end_idx],
'high': tf_data.get('high', [])[:candle_index + 1], 'high': tf_data.get('high', [])[start_idx:end_idx],
'low': tf_data.get('low', [])[:candle_index + 1], 'low': tf_data.get('low', [])[start_idx:end_idx],
'close': tf_data.get('close', [])[:candle_index + 1], 'close': tf_data.get('close', [])[start_idx:end_idx],
'volume': tf_data.get('volume', [])[:candle_index + 1] 'volume': tf_data.get('volume', [])[start_idx:end_idx]
} }
if tf == '1m': if tf == '1m':
@@ -1561,21 +1577,28 @@ class RealTrainingAdapter:
if len(closes) == 0: if len(closes) == 0:
return None return None
# REQUIRED: Must have exactly target_seq_len (600) candles, no padding allowed # ALLOW PADDING: If we have fewer than target_seq_len, pad with the first available value
if len(closes) < target_seq_len: if len(closes) < target_seq_len:
logger.warning(f"Insufficient candles: {len(closes)} < {target_seq_len} (required)") logger.debug(f"Padding {target_seq_len - len(closes)} candles for timeframe (have {len(closes)}, need {target_seq_len})")
return None pad_len = target_seq_len - len(closes)
# Take last target_seq_len candles (exactly 600) # Pad at the beginning with the first available value (edge padding)
opens = opens[-target_seq_len:] opens = np.pad(opens, (pad_len, 0), mode='edge')
highs = highs[-target_seq_len:] highs = np.pad(highs, (pad_len, 0), mode='edge')
lows = lows[-target_seq_len:] lows = np.pad(lows, (pad_len, 0), mode='edge')
closes = closes[-target_seq_len:] closes = np.pad(closes, (pad_len, 0), mode='edge')
volumes = volumes[-target_seq_len:] volumes = np.pad(volumes, (pad_len, 0), mode='edge')
else:
# Validate we have exactly target_seq_len # Take last target_seq_len candles if we have more than needed
opens = opens[-target_seq_len:]
highs = highs[-target_seq_len:]
lows = lows[-target_seq_len:]
closes = closes[-target_seq_len:]
volumes = volumes[-target_seq_len:]
# Validate we now have exactly target_seq_len
if len(closes) != target_seq_len: if len(closes) != target_seq_len:
logger.warning(f"Extraction failed: got {len(closes)} candles, need {target_seq_len}") logger.warning(f"Extraction failed: got {len(closes)} candles after padding, need {target_seq_len}")
return None return None
# Stack OHLCV [seq_len, 5] # Stack OHLCV [seq_len, 5]
@@ -1709,26 +1732,40 @@ class RealTrainingAdapter:
timeframes = market_state.get('timeframes', {}) timeframes = market_state.get('timeframes', {})
secondary_timeframes = market_state.get('secondary_timeframes', {}) secondary_timeframes = market_state.get('secondary_timeframes', {})
# REQUIRED: 600 candles per timeframe for transformer model # REQUIRED: At least some candles per timeframe for transformer model (will pad if needed)
target_seq_len = 600 # Must be 600 candles for each timeframe target_seq_len = 600 # Target 600 candles for each timeframe, but allow less and pad
# Validate we have enough data in required timeframes (1m, 1h, 1d) min_required_candles = 50 # Minimum candles needed to attempt training
# Validate we have minimum data in required timeframes (1m, 1h, 1d)
required_tfs = ['1m', '1h', '1d'] required_tfs = ['1m', '1h', '1d']
for tf_name in required_tfs: for tf_name in required_tfs:
if tf_name in timeframes: if tf_name in timeframes:
tf_data = timeframes[tf_name] tf_data = timeframes[tf_name]
if tf_data and 'close' in tf_data: if tf_data and 'close' in tf_data:
if len(tf_data['close']) < 600: if len(tf_data['close']) < min_required_candles:
logger.warning(f"Required timeframe {tf_name} has only {len(tf_data['close'])} candles (need 600)") logger.warning(f"Required timeframe {tf_name} has only {len(tf_data['close'])} candles (need at least {min_required_candles})")
return None return None
elif len(tf_data['close']) < target_seq_len:
logger.debug(f"Timeframe {tf_name} has {len(tf_data['close'])} candles, will pad to {target_seq_len}")
else:
logger.warning(f"Required timeframe {tf_name} missing data")
return None
else:
logger.warning(f"Required timeframe {tf_name} not found in data")
return None
# Validate optional 1s timeframe if present (must have 600 candles if included) # Validate optional 1s timeframe if present (must have minimum candles if included)
if '1s' in timeframes: if '1s' in timeframes:
tf_data = timeframes['1s'] tf_data = timeframes['1s']
if tf_data and 'close' in tf_data: if tf_data and 'close' in tf_data:
if len(tf_data['close']) < 600: if len(tf_data['close']) < min_required_candles:
logger.warning(f"Optional timeframe 1s has only {len(tf_data['close'])} candles (need 600), excluding it") logger.warning(f"Optional timeframe 1s has only {len(tf_data['close'])} candles (need at least {min_required_candles}), excluding it")
# Remove 1s from timeframes if insufficient # Remove 1s from timeframes if insufficient
timeframes = {k: v for k, v in timeframes.items() if k != '1s'} timeframes = {k: v for k, v in timeframes.items() if k != '1s'}
elif len(tf_data['close']) < target_seq_len:
logger.debug(f"Timeframe 1s has {len(tf_data['close'])} candles, will pad to {target_seq_len}")
else:
logger.warning(f"Optional timeframe 1s has invalid data, excluding it")
timeframes = {k: v for k, v in timeframes.items() if k != '1s'}
# Extract each timeframe (returns tuple: (tensor, norm_params) or None) # Extract each timeframe (returns tuple: (tensor, norm_params) or None)
# Store normalization parameters for each timeframe # Store normalization parameters for each timeframe
@@ -1793,18 +1830,18 @@ class RealTrainingAdapter:
logger.warning(f"Missing required timeframes: {missing}. Need all 3: 1m, 1h, 1d") logger.warning(f"Missing required timeframes: {missing}. Need all 3: 1m, 1h, 1d")
return None return None
# Validate each required timeframe has correct shape # Validate each required timeframe has correct shape (should be exactly 600 after padding)
for tf_name, tf_data in [('1m', price_data_1m), ('1h', price_data_1h), ('1d', price_data_1d)]: for tf_name, tf_data in [('1m', price_data_1m), ('1h', price_data_1h), ('1d', price_data_1d)]:
if tf_data is not None: if tf_data is not None:
shape = tf_data.shape shape = tf_data.shape
if len(shape) != 3 or shape[1] < 600: if len(shape) != 3 or shape[1] != 600:
logger.warning(f"Timeframe {tf_name} has invalid shape {shape} (need [1, 600, 5])") logger.warning(f"Timeframe {tf_name} has invalid shape {shape} (need [1, 600, 5])")
return None return None
# Validate optional 1s timeframe if present # Validate optional 1s timeframe if present (should be exactly 600 if included)
if price_data_1s is not None: if price_data_1s is not None:
shape = price_data_1s.shape shape = price_data_1s.shape
if len(shape) != 3 or shape[1] < 600: if len(shape) != 3 or shape[1] != 600:
logger.warning(f"Optional timeframe 1s has invalid shape {shape}, removing it") logger.warning(f"Optional timeframe 1s has invalid shape {shape}, removing it")
price_data_1s = None price_data_1s = None
@@ -2335,10 +2372,14 @@ class RealTrainingAdapter:
missing_tfs.append(tf_key) missing_tfs.append(tf_key)
break break
shape = tf_data.shape shape = tf_data.shape
if len(shape) != 3 or shape[1] < 600: # Must be [batch, seq_len, features] with seq_len >= 600 if len(shape) != 3 or shape[1] != 600: # Must be [batch, seq_len, features] with seq_len == 600 (padded if needed)
logger.warning(f" Skipping sample {i+1}: {tf_key} has invalid shape {shape} (need [1, 600, 5])") logger.warning(f" Skipping sample {i+1}: {tf_key} has invalid shape {shape} (need [1, 600, 5])")
missing_tfs.append(tf_key) missing_tfs.append(tf_key)
break break
else:
logger.warning(f" Skipping sample {i+1}: {tf_key} is None")
missing_tfs.append(tf_key)
break
if missing_tfs: if missing_tfs:
continue continue
@@ -2348,7 +2389,7 @@ class RealTrainingAdapter:
tf_data = batch.get('price_data_1s') tf_data = batch.get('price_data_1s')
if isinstance(tf_data, torch.Tensor): if isinstance(tf_data, torch.Tensor):
shape = tf_data.shape shape = tf_data.shape
if len(shape) != 3 or shape[1] < 600: if len(shape) != 3 or shape[1] != 600:
logger.warning(f" Sample {i+1}: price_data_1s has invalid shape {shape}, removing it") logger.warning(f" Sample {i+1}: price_data_1s has invalid shape {shape}, removing it")
batch['price_data_1s'] = None batch['price_data_1s'] = None
@@ -3306,14 +3347,37 @@ class RealTrainingAdapter:
return {'success': False, 'error': str(e)} return {'success': False, 'error': str(e)}
def _fetch_market_state_for_candle(self, symbol: str, timestamp, data_provider) -> Dict: def _fetch_market_state_for_candle(self, symbol: str, timestamp, data_provider) -> Dict:
"""Fetch market state with OHLCV data for model training""" """Fetch market state with OHLCV data for model training - ENSURES 600 CANDLES ARE AVAILABLE"""
try: try:
# Get market state with OHLCV data only (NO business logic) # Get market state with OHLCV data only (NO business logic)
market_state = {'timeframes': {}, 'secondary_timeframes': {}} market_state = {'timeframes': {}, 'secondary_timeframes': {}}
# CRITICAL: Training requires exactly 600 candles per timeframe
required_limit = 600
for tf in ['1s', '1m', '1h', '1d']: for tf in ['1s', '1m', '1h', '1d']:
df = data_provider.get_historical_data(symbol, tf, limit=200) # First try to get data from cache
if df is not None and not df.empty: df = data_provider.get_historical_data(symbol, tf, limit=required_limit)
# If insufficient data, force a refresh from API and cache it
if df is None or df.empty or len(df) < required_limit:
logger.info(f"Fetching {required_limit} candles for {symbol} {tf} from API (insufficient cached data)")
try:
# Force refresh from API and persist to cache
df = data_provider.get_historical_data(
symbol, tf, limit=required_limit,
refresh=True, persist=True
)
logger.info(f"Successfully cached {len(df) if df is not None else 0} candles for {symbol} {tf}")
except Exception as api_error:
logger.warning(f"Failed to fetch {symbol} {tf} from API: {api_error}")
continue
# Verify we have enough data
if df is not None and not df.empty and len(df) >= required_limit:
# Take the most recent required_limit candles
df = df.tail(required_limit)
market_state['timeframes'][tf] = { market_state['timeframes'][tf] = {
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(), 'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': df['open'].tolist(), 'open': df['open'].tolist(),
@@ -3322,10 +3386,42 @@ class RealTrainingAdapter:
'close': df['close'].tolist(), 'close': df['close'].tolist(),
'volume': df['volume'].tolist() 'volume': df['volume'].tolist()
} }
logger.debug(f"Prepared {len(df)} candles for {symbol} {tf}")
else:
logger.warning(f"Still insufficient data for {symbol} {tf} after API fetch: {len(df) if df is not None else 0} < {required_limit}")
# Also fetch BTC reference data for 1m timeframe
btc_symbol = 'BTC/USDT'
btc_tf = '1m'
try:
btc_df = data_provider.get_historical_data(btc_symbol, btc_tf, limit=required_limit)
if btc_df is None or btc_df.empty or len(btc_df) < required_limit:
logger.info(f"Fetching BTC reference data for training")
btc_df = data_provider.get_historical_data(
btc_symbol, btc_tf, limit=required_limit,
refresh=True, persist=True
)
if btc_df is not None and not btc_df.empty and len(btc_df) >= required_limit:
btc_df = btc_df.tail(required_limit)
market_state['secondary_timeframes'][btc_symbol] = {
btc_tf: {
'timestamps': btc_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': btc_df['open'].tolist(),
'high': btc_df['high'].tolist(),
'low': btc_df['low'].tolist(),
'close': btc_df['close'].tolist(),
'volume': btc_df['volume'].tolist()
}
}
except Exception as btc_error:
logger.warning(f"Failed to fetch BTC reference data: {btc_error}")
return market_state return market_state
except Exception as e: except Exception as e:
logger.warning(f"Error fetching market state for candle: {e}") logger.warning(f"Error fetching market state for candle: {e}")
import traceback
logger.debug(traceback.format_exc())
return {} return {}
def _convert_prediction_to_batch(self, prediction_sample: Dict, timeframe: str): def _convert_prediction_to_batch(self, prediction_sample: Dict, timeframe: str):
@@ -3882,7 +3978,8 @@ class RealTrainingAdapter:
else: else:
prediction_data['trend_vector'] = trend_vec prediction_data['trend_vector'] = trend_vec
self.orchestrator.store_transformer_prediction(symbol, prediction_data) if hasattr(self.orchestrator, 'store_transformer_prediction'):
self.orchestrator.store_transformer_prediction(symbol, prediction_data)
# Training decision using strategy manager # Training decision using strategy manager
training_strategy = session.get('training_strategy') training_strategy = session.get('training_strategy')

View File

@@ -137,10 +137,33 @@
"entry_state": {}, "entry_state": {},
"exit_state": {} "exit_state": {}
} }
},
{
"annotation_id": "d8fdf474-d122-4474-b4ad-1f3829b1e46d",
"symbol": "ETH/USDT",
"timeframe": "1m",
"entry": {
"timestamp": "2025-12-08 14:33",
"price": 3178.42,
"index": 309
},
"exit": {
"timestamp": "2025-12-08 15:44",
"price": 3088.83,
"index": 331
},
"direction": "SHORT",
"profit_loss_pct": 2.8186960817009754,
"notes": "",
"created_at": "2025-12-08T16:34:38.144316+00:00",
"market_context": {
"entry_state": {},
"exit_state": {}
}
} }
], ],
"metadata": { "metadata": {
"total_annotations": 6, "total_annotations": 7,
"last_updated": "2025-11-22T22:35:55.606373+00:00" "last_updated": "2025-12-08T16:34:38.145818+00:00"
} }
} }

View File

@@ -1407,18 +1407,37 @@ class TradingTransformerTrainer:
) )
# Calculate losses (use batch_on_device for consistency) # Calculate losses (use batch_on_device for consistency)
# Handle case where actions key is missing (e.g., when no timeframe data available)
if 'actions' not in batch_on_device:
logger.warning("No 'actions' key in batch - skipping this training step")
return {
'total_loss': 0.0,
'action_loss': 0.0,
'price_loss': 0.0,
'accuracy': 0.0,
'candle_accuracy': 0.0,
'trend_accuracy': 0.0,
'action_accuracy': 0.0
}
action_loss = self.action_criterion(outputs['action_logits'], batch_on_device['actions']) action_loss = self.action_criterion(outputs['action_logits'], batch_on_device['actions'])
# FIXED: Ensure shapes match for MSELoss # FIXED: Ensure shapes match for MSELoss
price_pred = outputs['price_prediction'] price_pred = outputs['price_prediction']
price_target = batch_on_device['future_prices']
# Both should be [batch, 1], but ensure they match # Handle case where future_prices key is missing
if price_pred.shape != price_target.shape: if 'future_prices' not in batch_on_device:
logger.debug(f"Reshaping price target from {price_target.shape} to {price_pred.shape}") logger.warning("No 'future_prices' key in batch - using zero loss for price prediction")
price_target = price_target.view(price_pred.shape) price_loss = torch.tensor(0.0, device=self.device)
else:
price_loss = self.price_criterion(price_pred, price_target) price_target = batch_on_device['future_prices']
# Both should be [batch, 1], but ensure they match
if price_pred.shape != price_target.shape:
logger.debug(f"Reshaping price target from {price_target.shape} to {price_pred.shape}")
price_target = price_target.view(price_pred.shape)
price_loss = self.price_criterion(price_pred, price_target)
# NEW: Trend analysis loss (if trend_target provided) # NEW: Trend analysis loss (if trend_target provided)
trend_loss = torch.tensor(0.0, device=self.device) trend_loss = torch.tensor(0.0, device=self.device)
@@ -1677,8 +1696,11 @@ class TradingTransformerTrainer:
trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item() trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item()
# LEGACY: Action accuracy (for comparison) # LEGACY: Action accuracy (for comparison)
action_predictions = torch.argmax(outputs['action_logits'], dim=-1) if 'actions' in batch_on_device:
action_accuracy = (action_predictions == batch_on_device['actions']).float().mean().item() action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
action_accuracy = (action_predictions == batch_on_device['actions']).float().mean().item()
else:
action_accuracy = 0.0
# Extract values and delete tensors to free memory # Extract values and delete tensors to free memory
result = { result = {

View File

@@ -303,25 +303,48 @@ class TradingOrchestrator:
"""Initialize the enhanced orchestrator with full ML capabilities""" """Initialize the enhanced orchestrator with full ML capabilities"""
self.config = get_config() self.config = get_config()
self.data_provider = data_provider or DataProvider() self.data_provider = data_provider or DataProvider()
self.universal_adapter = UniversalDataAdapter(self.data_provider) # Temporarily disable UniversalDataAdapter to avoid crash
self.universal_adapter = None # UniversalDataAdapter(self.data_provider)
self.model_manager = None # Will be initialized later if needed self.model_manager = None # Will be initialized later if needed
self.model_registry = model_registry # Model registry for dynamic model management self.model_registry = model_registry # Model registry for dynamic model management
self.enhanced_rl_training = enhanced_rl_training self.enhanced_rl_training = enhanced_rl_training
# Set primary trading symbol # Set primary trading symbol
self.symbol = self.config.get('primary_symbol', 'ETH/USDT') self.symbol = self.config.get('primary_symbol', 'ETH/USDT')
self.ref_symbols = self.config.get('reference_symbols', ['BTC/USDT']) self.ref_symbols = self.config.get('reference_symbols', ['BTC/USDT'])
# Initialize signal accumulator # Initialize signal accumulator
self.signal_accumulator = {} self.signal_accumulator = {}
# Initialize confidence threshold # Initialize confidence threshold
self.confidence_threshold = self.config.get('confidence_threshold', 0.6) self.confidence_threshold = self.config.get('confidence_threshold', 0.6)
# CRITICAL: Initialize prediction tracking attributes FIRST to avoid attribute errors
# Model prediction tracking for dashboard visualization
self.recent_dqn_predictions: Dict[str, deque] = (
{}
) # {symbol: List[Dict]} - Recent DQN predictions
self.recent_cnn_predictions: Dict[str, deque] = (
{}
) # {symbol: List[Dict]} - Recent CNN predictions
self.recent_transformer_predictions: Dict[str, deque] = (
{}
) # {symbol: List[Dict]} - Recent Transformer predictions
self.prediction_accuracy_history: Dict[str, deque] = (
{}
) # {symbol: List[Dict]} - Prediction accuracy tracking
# Initialize prediction tracking for the primary trading symbol only
self.recent_dqn_predictions[self.symbol] = deque(maxlen=100)
self.recent_cnn_predictions[self.symbol] = deque(maxlen=50)
self.recent_transformer_predictions[self.symbol] = deque(maxlen=50)
self.prediction_accuracy_history[self.symbol] = deque(maxlen=200)
self.signal_accumulator[self.symbol] = []
# Determine the device to use from config.yaml # Determine the device to use from config.yaml
self.device = self._get_device_from_config() self.device = self._get_device_from_config()
logger.info(f"Using device: {self.device}") logger.info(f"Using device: {self.device}")
def _get_device_from_config(self) -> torch.device: def _get_device_from_config(self) -> torch.device:
"""Get device from config.yaml or auto-detect""" """Get device from config.yaml or auto-detect"""
try: try:
@@ -406,27 +429,6 @@ class TradingOrchestrator:
{} {}
) # {symbol: {side, size, entry_price, entry_time, pnl}} ) # {symbol: {side, size, entry_price, entry_time, pnl}}
self.trading_executor = None # Will be set by dashboard or external system self.trading_executor = None # Will be set by dashboard or external system
# Model prediction tracking for dashboard visualization
self.recent_dqn_predictions: Dict[str, deque] = (
{}
) # {symbol: List[Dict]} - Recent DQN predictions
self.recent_cnn_predictions: Dict[str, deque] = (
{}
) # {symbol: List[Dict]} - Recent CNN predictions
self.recent_transformer_predictions: Dict[str, deque] = (
{}
) # {symbol: List[Dict]} - Recent Transformer predictions
self.prediction_accuracy_history: Dict[str, deque] = (
{}
) # {symbol: List[Dict]} - Prediction accuracy tracking
# Initialize prediction tracking for the primary trading symbol only
self.recent_dqn_predictions[self.symbol] = deque(maxlen=100)
self.recent_cnn_predictions[self.symbol] = deque(maxlen=50)
self.recent_transformer_predictions[self.symbol] = deque(maxlen=50)
self.prediction_accuracy_history[self.symbol] = deque(maxlen=200)
self.signal_accumulator[self.symbol] = []
# Decision callbacks # Decision callbacks
self.decision_callbacks: List[Any] = [] self.decision_callbacks: List[Any] = []