fixes
This commit is contained in:
@@ -294,9 +294,13 @@ class RealTrainingAdapter:
|
|||||||
# Clear previous predictions for clean visualization
|
# Clear previous predictions for clean visualization
|
||||||
# Get symbol from first test case
|
# Get symbol from first test case
|
||||||
symbol = test_cases[0].get('symbol', 'ETH/USDT') if test_cases else 'ETH/USDT'
|
symbol = test_cases[0].get('symbol', 'ETH/USDT') if test_cases else 'ETH/USDT'
|
||||||
if self.orchestrator and hasattr(self.orchestrator, 'clear_predictions'):
|
if (self.orchestrator and
|
||||||
|
hasattr(self.orchestrator, 'clear_predictions') and
|
||||||
|
hasattr(self.orchestrator, 'recent_transformer_predictions')):
|
||||||
self.orchestrator.clear_predictions(symbol)
|
self.orchestrator.clear_predictions(symbol)
|
||||||
logger.info(f" Cleared previous predictions for {symbol}")
|
logger.info(f" Cleared previous predictions for {symbol}")
|
||||||
|
else:
|
||||||
|
logger.info(f" Orchestrator not ready, skipping prediction clearing for {symbol}")
|
||||||
|
|
||||||
# Prepare training data from test cases
|
# Prepare training data from test cases
|
||||||
training_data = self._prepare_training_data(test_cases)
|
training_data = self._prepare_training_data(test_cases)
|
||||||
@@ -595,7 +599,7 @@ class RealTrainingAdapter:
|
|||||||
else:
|
else:
|
||||||
logger.warning(f" {symbol} {timeframe}: No quality data available (need {min_required_candles} candles)")
|
logger.warning(f" {symbol} {timeframe}: No quality data available (need {min_required_candles} candles)")
|
||||||
|
|
||||||
# CRITICAL: Validate we have all required timeframes
|
# CRITICAL: Validate we have all required timeframes (1s is optional, don't check it)
|
||||||
missing_required = [tf for tf in required_timeframes if tf not in fetched_timeframes]
|
missing_required = [tf for tf in required_timeframes if tf not in fetched_timeframes]
|
||||||
if missing_required:
|
if missing_required:
|
||||||
logger.error(f" FAILED: Missing required timeframes: {missing_required}")
|
logger.error(f" FAILED: Missing required timeframes: {missing_required}")
|
||||||
@@ -603,6 +607,10 @@ class RealTrainingAdapter:
|
|||||||
logger.error(f" Cannot proceed without all required timeframes")
|
logger.error(f" Cannot proceed without all required timeframes")
|
||||||
return {} # Return empty dict to signal failure
|
return {} # Return empty dict to signal failure
|
||||||
|
|
||||||
|
# Log optional timeframes status
|
||||||
|
if '1s' not in fetched_timeframes:
|
||||||
|
logger.debug(f" Optional timeframe 1s not available (this is OK - 1s historical data is often unavailable)")
|
||||||
|
|
||||||
# Fetch secondary symbol data (1m timeframe only)
|
# Fetch secondary symbol data (1m timeframe only)
|
||||||
logger.info(f" Fetching secondary symbol data: {secondary_symbol} (1m)")
|
logger.info(f" Fetching secondary symbol data: {secondary_symbol} (1m)")
|
||||||
secondary_df = None
|
secondary_df = None
|
||||||
@@ -1080,7 +1088,8 @@ class RealTrainingAdapter:
|
|||||||
Create a market state snapshot at a specific candle index
|
Create a market state snapshot at a specific candle index
|
||||||
|
|
||||||
This creates a "view" of the market as it was at that specific candle,
|
This creates a "view" of the market as it was at that specific candle,
|
||||||
which is used for negative sampling.
|
which is used for negative sampling. CRITICAL: Ensures 600 candles are available
|
||||||
|
by taking the last 600 candles BEFORE the target point.
|
||||||
"""
|
"""
|
||||||
snapshot = {
|
snapshot = {
|
||||||
'symbol': market_state.get('symbol'),
|
'symbol': market_state.get('symbol'),
|
||||||
@@ -1088,19 +1097,26 @@ class RealTrainingAdapter:
|
|||||||
'timeframes': {}
|
'timeframes': {}
|
||||||
}
|
}
|
||||||
|
|
||||||
# For each timeframe, create a snapshot up to the candle_index
|
# CRITICAL: Training requires 600 candles BEFORE the target point
|
||||||
|
required_candles = 600
|
||||||
|
|
||||||
|
# For each timeframe, create a snapshot with 600 candles BEFORE the candle_index
|
||||||
for tf, tf_data in market_state.get('timeframes', {}).items():
|
for tf, tf_data in market_state.get('timeframes', {}).items():
|
||||||
timestamps = tf_data.get('timestamps', [])
|
timestamps = tf_data.get('timestamps', [])
|
||||||
|
|
||||||
if candle_index < len(timestamps):
|
if candle_index < len(timestamps):
|
||||||
# Include data up to and including this candle
|
# Take the last 600 candles BEFORE and INCLUDING this candle
|
||||||
|
# If we don't have 600 candles, we'll pad later in extraction
|
||||||
|
start_idx = max(0, candle_index + 1 - required_candles)
|
||||||
|
end_idx = candle_index + 1
|
||||||
|
|
||||||
snapshot['timeframes'][tf] = {
|
snapshot['timeframes'][tf] = {
|
||||||
'timestamps': timestamps[:candle_index + 1],
|
'timestamps': timestamps[start_idx:end_idx],
|
||||||
'open': tf_data.get('open', [])[:candle_index + 1],
|
'open': tf_data.get('open', [])[start_idx:end_idx],
|
||||||
'high': tf_data.get('high', [])[:candle_index + 1],
|
'high': tf_data.get('high', [])[start_idx:end_idx],
|
||||||
'low': tf_data.get('low', [])[:candle_index + 1],
|
'low': tf_data.get('low', [])[start_idx:end_idx],
|
||||||
'close': tf_data.get('close', [])[:candle_index + 1],
|
'close': tf_data.get('close', [])[start_idx:end_idx],
|
||||||
'volume': tf_data.get('volume', [])[:candle_index + 1]
|
'volume': tf_data.get('volume', [])[start_idx:end_idx]
|
||||||
}
|
}
|
||||||
|
|
||||||
if tf == '1m':
|
if tf == '1m':
|
||||||
@@ -1561,21 +1577,28 @@ class RealTrainingAdapter:
|
|||||||
if len(closes) == 0:
|
if len(closes) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# REQUIRED: Must have exactly target_seq_len (600) candles, no padding allowed
|
# ALLOW PADDING: If we have fewer than target_seq_len, pad with the first available value
|
||||||
if len(closes) < target_seq_len:
|
if len(closes) < target_seq_len:
|
||||||
logger.warning(f"Insufficient candles: {len(closes)} < {target_seq_len} (required)")
|
logger.debug(f"Padding {target_seq_len - len(closes)} candles for timeframe (have {len(closes)}, need {target_seq_len})")
|
||||||
return None
|
pad_len = target_seq_len - len(closes)
|
||||||
|
|
||||||
# Take last target_seq_len candles (exactly 600)
|
# Pad at the beginning with the first available value (edge padding)
|
||||||
|
opens = np.pad(opens, (pad_len, 0), mode='edge')
|
||||||
|
highs = np.pad(highs, (pad_len, 0), mode='edge')
|
||||||
|
lows = np.pad(lows, (pad_len, 0), mode='edge')
|
||||||
|
closes = np.pad(closes, (pad_len, 0), mode='edge')
|
||||||
|
volumes = np.pad(volumes, (pad_len, 0), mode='edge')
|
||||||
|
else:
|
||||||
|
# Take last target_seq_len candles if we have more than needed
|
||||||
opens = opens[-target_seq_len:]
|
opens = opens[-target_seq_len:]
|
||||||
highs = highs[-target_seq_len:]
|
highs = highs[-target_seq_len:]
|
||||||
lows = lows[-target_seq_len:]
|
lows = lows[-target_seq_len:]
|
||||||
closes = closes[-target_seq_len:]
|
closes = closes[-target_seq_len:]
|
||||||
volumes = volumes[-target_seq_len:]
|
volumes = volumes[-target_seq_len:]
|
||||||
|
|
||||||
# Validate we have exactly target_seq_len
|
# Validate we now have exactly target_seq_len
|
||||||
if len(closes) != target_seq_len:
|
if len(closes) != target_seq_len:
|
||||||
logger.warning(f"Extraction failed: got {len(closes)} candles, need {target_seq_len}")
|
logger.warning(f"Extraction failed: got {len(closes)} candles after padding, need {target_seq_len}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Stack OHLCV [seq_len, 5]
|
# Stack OHLCV [seq_len, 5]
|
||||||
@@ -1709,26 +1732,40 @@ class RealTrainingAdapter:
|
|||||||
timeframes = market_state.get('timeframes', {})
|
timeframes = market_state.get('timeframes', {})
|
||||||
secondary_timeframes = market_state.get('secondary_timeframes', {})
|
secondary_timeframes = market_state.get('secondary_timeframes', {})
|
||||||
|
|
||||||
# REQUIRED: 600 candles per timeframe for transformer model
|
# REQUIRED: At least some candles per timeframe for transformer model (will pad if needed)
|
||||||
target_seq_len = 600 # Must be 600 candles for each timeframe
|
target_seq_len = 600 # Target 600 candles for each timeframe, but allow less and pad
|
||||||
# Validate we have enough data in required timeframes (1m, 1h, 1d)
|
min_required_candles = 50 # Minimum candles needed to attempt training
|
||||||
|
# Validate we have minimum data in required timeframes (1m, 1h, 1d)
|
||||||
required_tfs = ['1m', '1h', '1d']
|
required_tfs = ['1m', '1h', '1d']
|
||||||
for tf_name in required_tfs:
|
for tf_name in required_tfs:
|
||||||
if tf_name in timeframes:
|
if tf_name in timeframes:
|
||||||
tf_data = timeframes[tf_name]
|
tf_data = timeframes[tf_name]
|
||||||
if tf_data and 'close' in tf_data:
|
if tf_data and 'close' in tf_data:
|
||||||
if len(tf_data['close']) < 600:
|
if len(tf_data['close']) < min_required_candles:
|
||||||
logger.warning(f"Required timeframe {tf_name} has only {len(tf_data['close'])} candles (need 600)")
|
logger.warning(f"Required timeframe {tf_name} has only {len(tf_data['close'])} candles (need at least {min_required_candles})")
|
||||||
|
return None
|
||||||
|
elif len(tf_data['close']) < target_seq_len:
|
||||||
|
logger.debug(f"Timeframe {tf_name} has {len(tf_data['close'])} candles, will pad to {target_seq_len}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Required timeframe {tf_name} missing data")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
logger.warning(f"Required timeframe {tf_name} not found in data")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Validate optional 1s timeframe if present (must have 600 candles if included)
|
# Validate optional 1s timeframe if present (must have minimum candles if included)
|
||||||
if '1s' in timeframes:
|
if '1s' in timeframes:
|
||||||
tf_data = timeframes['1s']
|
tf_data = timeframes['1s']
|
||||||
if tf_data and 'close' in tf_data:
|
if tf_data and 'close' in tf_data:
|
||||||
if len(tf_data['close']) < 600:
|
if len(tf_data['close']) < min_required_candles:
|
||||||
logger.warning(f"Optional timeframe 1s has only {len(tf_data['close'])} candles (need 600), excluding it")
|
logger.warning(f"Optional timeframe 1s has only {len(tf_data['close'])} candles (need at least {min_required_candles}), excluding it")
|
||||||
# Remove 1s from timeframes if insufficient
|
# Remove 1s from timeframes if insufficient
|
||||||
timeframes = {k: v for k, v in timeframes.items() if k != '1s'}
|
timeframes = {k: v for k, v in timeframes.items() if k != '1s'}
|
||||||
|
elif len(tf_data['close']) < target_seq_len:
|
||||||
|
logger.debug(f"Timeframe 1s has {len(tf_data['close'])} candles, will pad to {target_seq_len}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Optional timeframe 1s has invalid data, excluding it")
|
||||||
|
timeframes = {k: v for k, v in timeframes.items() if k != '1s'}
|
||||||
|
|
||||||
# Extract each timeframe (returns tuple: (tensor, norm_params) or None)
|
# Extract each timeframe (returns tuple: (tensor, norm_params) or None)
|
||||||
# Store normalization parameters for each timeframe
|
# Store normalization parameters for each timeframe
|
||||||
@@ -1793,18 +1830,18 @@ class RealTrainingAdapter:
|
|||||||
logger.warning(f"Missing required timeframes: {missing}. Need all 3: 1m, 1h, 1d")
|
logger.warning(f"Missing required timeframes: {missing}. Need all 3: 1m, 1h, 1d")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Validate each required timeframe has correct shape
|
# Validate each required timeframe has correct shape (should be exactly 600 after padding)
|
||||||
for tf_name, tf_data in [('1m', price_data_1m), ('1h', price_data_1h), ('1d', price_data_1d)]:
|
for tf_name, tf_data in [('1m', price_data_1m), ('1h', price_data_1h), ('1d', price_data_1d)]:
|
||||||
if tf_data is not None:
|
if tf_data is not None:
|
||||||
shape = tf_data.shape
|
shape = tf_data.shape
|
||||||
if len(shape) != 3 or shape[1] < 600:
|
if len(shape) != 3 or shape[1] != 600:
|
||||||
logger.warning(f"Timeframe {tf_name} has invalid shape {shape} (need [1, 600, 5])")
|
logger.warning(f"Timeframe {tf_name} has invalid shape {shape} (need [1, 600, 5])")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Validate optional 1s timeframe if present
|
# Validate optional 1s timeframe if present (should be exactly 600 if included)
|
||||||
if price_data_1s is not None:
|
if price_data_1s is not None:
|
||||||
shape = price_data_1s.shape
|
shape = price_data_1s.shape
|
||||||
if len(shape) != 3 or shape[1] < 600:
|
if len(shape) != 3 or shape[1] != 600:
|
||||||
logger.warning(f"Optional timeframe 1s has invalid shape {shape}, removing it")
|
logger.warning(f"Optional timeframe 1s has invalid shape {shape}, removing it")
|
||||||
price_data_1s = None
|
price_data_1s = None
|
||||||
|
|
||||||
@@ -2335,10 +2372,14 @@ class RealTrainingAdapter:
|
|||||||
missing_tfs.append(tf_key)
|
missing_tfs.append(tf_key)
|
||||||
break
|
break
|
||||||
shape = tf_data.shape
|
shape = tf_data.shape
|
||||||
if len(shape) != 3 or shape[1] < 600: # Must be [batch, seq_len, features] with seq_len >= 600
|
if len(shape) != 3 or shape[1] != 600: # Must be [batch, seq_len, features] with seq_len == 600 (padded if needed)
|
||||||
logger.warning(f" Skipping sample {i+1}: {tf_key} has invalid shape {shape} (need [1, 600, 5])")
|
logger.warning(f" Skipping sample {i+1}: {tf_key} has invalid shape {shape} (need [1, 600, 5])")
|
||||||
missing_tfs.append(tf_key)
|
missing_tfs.append(tf_key)
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
logger.warning(f" Skipping sample {i+1}: {tf_key} is None")
|
||||||
|
missing_tfs.append(tf_key)
|
||||||
|
break
|
||||||
|
|
||||||
if missing_tfs:
|
if missing_tfs:
|
||||||
continue
|
continue
|
||||||
@@ -2348,7 +2389,7 @@ class RealTrainingAdapter:
|
|||||||
tf_data = batch.get('price_data_1s')
|
tf_data = batch.get('price_data_1s')
|
||||||
if isinstance(tf_data, torch.Tensor):
|
if isinstance(tf_data, torch.Tensor):
|
||||||
shape = tf_data.shape
|
shape = tf_data.shape
|
||||||
if len(shape) != 3 or shape[1] < 600:
|
if len(shape) != 3 or shape[1] != 600:
|
||||||
logger.warning(f" Sample {i+1}: price_data_1s has invalid shape {shape}, removing it")
|
logger.warning(f" Sample {i+1}: price_data_1s has invalid shape {shape}, removing it")
|
||||||
batch['price_data_1s'] = None
|
batch['price_data_1s'] = None
|
||||||
|
|
||||||
@@ -3306,14 +3347,37 @@ class RealTrainingAdapter:
|
|||||||
return {'success': False, 'error': str(e)}
|
return {'success': False, 'error': str(e)}
|
||||||
|
|
||||||
def _fetch_market_state_for_candle(self, symbol: str, timestamp, data_provider) -> Dict:
|
def _fetch_market_state_for_candle(self, symbol: str, timestamp, data_provider) -> Dict:
|
||||||
"""Fetch market state with OHLCV data for model training"""
|
"""Fetch market state with OHLCV data for model training - ENSURES 600 CANDLES ARE AVAILABLE"""
|
||||||
try:
|
try:
|
||||||
# Get market state with OHLCV data only (NO business logic)
|
# Get market state with OHLCV data only (NO business logic)
|
||||||
market_state = {'timeframes': {}, 'secondary_timeframes': {}}
|
market_state = {'timeframes': {}, 'secondary_timeframes': {}}
|
||||||
|
|
||||||
|
# CRITICAL: Training requires exactly 600 candles per timeframe
|
||||||
|
required_limit = 600
|
||||||
|
|
||||||
for tf in ['1s', '1m', '1h', '1d']:
|
for tf in ['1s', '1m', '1h', '1d']:
|
||||||
df = data_provider.get_historical_data(symbol, tf, limit=200)
|
# First try to get data from cache
|
||||||
if df is not None and not df.empty:
|
df = data_provider.get_historical_data(symbol, tf, limit=required_limit)
|
||||||
|
|
||||||
|
# If insufficient data, force a refresh from API and cache it
|
||||||
|
if df is None or df.empty or len(df) < required_limit:
|
||||||
|
logger.info(f"Fetching {required_limit} candles for {symbol} {tf} from API (insufficient cached data)")
|
||||||
|
try:
|
||||||
|
# Force refresh from API and persist to cache
|
||||||
|
df = data_provider.get_historical_data(
|
||||||
|
symbol, tf, limit=required_limit,
|
||||||
|
refresh=True, persist=True
|
||||||
|
)
|
||||||
|
logger.info(f"Successfully cached {len(df) if df is not None else 0} candles for {symbol} {tf}")
|
||||||
|
except Exception as api_error:
|
||||||
|
logger.warning(f"Failed to fetch {symbol} {tf} from API: {api_error}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Verify we have enough data
|
||||||
|
if df is not None and not df.empty and len(df) >= required_limit:
|
||||||
|
# Take the most recent required_limit candles
|
||||||
|
df = df.tail(required_limit)
|
||||||
|
|
||||||
market_state['timeframes'][tf] = {
|
market_state['timeframes'][tf] = {
|
||||||
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||||
'open': df['open'].tolist(),
|
'open': df['open'].tolist(),
|
||||||
@@ -3322,10 +3386,42 @@ class RealTrainingAdapter:
|
|||||||
'close': df['close'].tolist(),
|
'close': df['close'].tolist(),
|
||||||
'volume': df['volume'].tolist()
|
'volume': df['volume'].tolist()
|
||||||
}
|
}
|
||||||
|
logger.debug(f"Prepared {len(df)} candles for {symbol} {tf}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Still insufficient data for {symbol} {tf} after API fetch: {len(df) if df is not None else 0} < {required_limit}")
|
||||||
|
|
||||||
|
# Also fetch BTC reference data for 1m timeframe
|
||||||
|
btc_symbol = 'BTC/USDT'
|
||||||
|
btc_tf = '1m'
|
||||||
|
try:
|
||||||
|
btc_df = data_provider.get_historical_data(btc_symbol, btc_tf, limit=required_limit)
|
||||||
|
if btc_df is None or btc_df.empty or len(btc_df) < required_limit:
|
||||||
|
logger.info(f"Fetching BTC reference data for training")
|
||||||
|
btc_df = data_provider.get_historical_data(
|
||||||
|
btc_symbol, btc_tf, limit=required_limit,
|
||||||
|
refresh=True, persist=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if btc_df is not None and not btc_df.empty and len(btc_df) >= required_limit:
|
||||||
|
btc_df = btc_df.tail(required_limit)
|
||||||
|
market_state['secondary_timeframes'][btc_symbol] = {
|
||||||
|
btc_tf: {
|
||||||
|
'timestamps': btc_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||||
|
'open': btc_df['open'].tolist(),
|
||||||
|
'high': btc_df['high'].tolist(),
|
||||||
|
'low': btc_df['low'].tolist(),
|
||||||
|
'close': btc_df['close'].tolist(),
|
||||||
|
'volume': btc_df['volume'].tolist()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception as btc_error:
|
||||||
|
logger.warning(f"Failed to fetch BTC reference data: {btc_error}")
|
||||||
|
|
||||||
return market_state
|
return market_state
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error fetching market state for candle: {e}")
|
logger.warning(f"Error fetching market state for candle: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _convert_prediction_to_batch(self, prediction_sample: Dict, timeframe: str):
|
def _convert_prediction_to_batch(self, prediction_sample: Dict, timeframe: str):
|
||||||
@@ -3882,6 +3978,7 @@ class RealTrainingAdapter:
|
|||||||
else:
|
else:
|
||||||
prediction_data['trend_vector'] = trend_vec
|
prediction_data['trend_vector'] = trend_vec
|
||||||
|
|
||||||
|
if hasattr(self.orchestrator, 'store_transformer_prediction'):
|
||||||
self.orchestrator.store_transformer_prediction(symbol, prediction_data)
|
self.orchestrator.store_transformer_prediction(symbol, prediction_data)
|
||||||
|
|
||||||
# Training decision using strategy manager
|
# Training decision using strategy manager
|
||||||
|
|||||||
@@ -137,10 +137,33 @@
|
|||||||
"entry_state": {},
|
"entry_state": {},
|
||||||
"exit_state": {}
|
"exit_state": {}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"annotation_id": "d8fdf474-d122-4474-b4ad-1f3829b1e46d",
|
||||||
|
"symbol": "ETH/USDT",
|
||||||
|
"timeframe": "1m",
|
||||||
|
"entry": {
|
||||||
|
"timestamp": "2025-12-08 14:33",
|
||||||
|
"price": 3178.42,
|
||||||
|
"index": 309
|
||||||
|
},
|
||||||
|
"exit": {
|
||||||
|
"timestamp": "2025-12-08 15:44",
|
||||||
|
"price": 3088.83,
|
||||||
|
"index": 331
|
||||||
|
},
|
||||||
|
"direction": "SHORT",
|
||||||
|
"profit_loss_pct": 2.8186960817009754,
|
||||||
|
"notes": "",
|
||||||
|
"created_at": "2025-12-08T16:34:38.144316+00:00",
|
||||||
|
"market_context": {
|
||||||
|
"entry_state": {},
|
||||||
|
"exit_state": {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"total_annotations": 6,
|
"total_annotations": 7,
|
||||||
"last_updated": "2025-11-22T22:35:55.606373+00:00"
|
"last_updated": "2025-12-08T16:34:38.145818+00:00"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1407,10 +1407,29 @@ class TradingTransformerTrainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Calculate losses (use batch_on_device for consistency)
|
# Calculate losses (use batch_on_device for consistency)
|
||||||
|
# Handle case where actions key is missing (e.g., when no timeframe data available)
|
||||||
|
if 'actions' not in batch_on_device:
|
||||||
|
logger.warning("No 'actions' key in batch - skipping this training step")
|
||||||
|
return {
|
||||||
|
'total_loss': 0.0,
|
||||||
|
'action_loss': 0.0,
|
||||||
|
'price_loss': 0.0,
|
||||||
|
'accuracy': 0.0,
|
||||||
|
'candle_accuracy': 0.0,
|
||||||
|
'trend_accuracy': 0.0,
|
||||||
|
'action_accuracy': 0.0
|
||||||
|
}
|
||||||
|
|
||||||
action_loss = self.action_criterion(outputs['action_logits'], batch_on_device['actions'])
|
action_loss = self.action_criterion(outputs['action_logits'], batch_on_device['actions'])
|
||||||
|
|
||||||
# FIXED: Ensure shapes match for MSELoss
|
# FIXED: Ensure shapes match for MSELoss
|
||||||
price_pred = outputs['price_prediction']
|
price_pred = outputs['price_prediction']
|
||||||
|
|
||||||
|
# Handle case where future_prices key is missing
|
||||||
|
if 'future_prices' not in batch_on_device:
|
||||||
|
logger.warning("No 'future_prices' key in batch - using zero loss for price prediction")
|
||||||
|
price_loss = torch.tensor(0.0, device=self.device)
|
||||||
|
else:
|
||||||
price_target = batch_on_device['future_prices']
|
price_target = batch_on_device['future_prices']
|
||||||
|
|
||||||
# Both should be [batch, 1], but ensure they match
|
# Both should be [batch, 1], but ensure they match
|
||||||
@@ -1677,8 +1696,11 @@ class TradingTransformerTrainer:
|
|||||||
trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item()
|
trend_accuracy = ((angle_accuracy + steepness_accuracy) / 2).item()
|
||||||
|
|
||||||
# LEGACY: Action accuracy (for comparison)
|
# LEGACY: Action accuracy (for comparison)
|
||||||
|
if 'actions' in batch_on_device:
|
||||||
action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
|
action_predictions = torch.argmax(outputs['action_logits'], dim=-1)
|
||||||
action_accuracy = (action_predictions == batch_on_device['actions']).float().mean().item()
|
action_accuracy = (action_predictions == batch_on_device['actions']).float().mean().item()
|
||||||
|
else:
|
||||||
|
action_accuracy = 0.0
|
||||||
|
|
||||||
# Extract values and delete tensors to free memory
|
# Extract values and delete tensors to free memory
|
||||||
result = {
|
result = {
|
||||||
|
|||||||
@@ -303,7 +303,8 @@ class TradingOrchestrator:
|
|||||||
"""Initialize the enhanced orchestrator with full ML capabilities"""
|
"""Initialize the enhanced orchestrator with full ML capabilities"""
|
||||||
self.config = get_config()
|
self.config = get_config()
|
||||||
self.data_provider = data_provider or DataProvider()
|
self.data_provider = data_provider or DataProvider()
|
||||||
self.universal_adapter = UniversalDataAdapter(self.data_provider)
|
# Temporarily disable UniversalDataAdapter to avoid crash
|
||||||
|
self.universal_adapter = None # UniversalDataAdapter(self.data_provider)
|
||||||
self.model_manager = None # Will be initialized later if needed
|
self.model_manager = None # Will be initialized later if needed
|
||||||
self.model_registry = model_registry # Model registry for dynamic model management
|
self.model_registry = model_registry # Model registry for dynamic model management
|
||||||
self.enhanced_rl_training = enhanced_rl_training
|
self.enhanced_rl_training = enhanced_rl_training
|
||||||
@@ -318,6 +319,28 @@ class TradingOrchestrator:
|
|||||||
# Initialize confidence threshold
|
# Initialize confidence threshold
|
||||||
self.confidence_threshold = self.config.get('confidence_threshold', 0.6)
|
self.confidence_threshold = self.config.get('confidence_threshold', 0.6)
|
||||||
|
|
||||||
|
# CRITICAL: Initialize prediction tracking attributes FIRST to avoid attribute errors
|
||||||
|
# Model prediction tracking for dashboard visualization
|
||||||
|
self.recent_dqn_predictions: Dict[str, deque] = (
|
||||||
|
{}
|
||||||
|
) # {symbol: List[Dict]} - Recent DQN predictions
|
||||||
|
self.recent_cnn_predictions: Dict[str, deque] = (
|
||||||
|
{}
|
||||||
|
) # {symbol: List[Dict]} - Recent CNN predictions
|
||||||
|
self.recent_transformer_predictions: Dict[str, deque] = (
|
||||||
|
{}
|
||||||
|
) # {symbol: List[Dict]} - Recent Transformer predictions
|
||||||
|
self.prediction_accuracy_history: Dict[str, deque] = (
|
||||||
|
{}
|
||||||
|
) # {symbol: List[Dict]} - Prediction accuracy tracking
|
||||||
|
|
||||||
|
# Initialize prediction tracking for the primary trading symbol only
|
||||||
|
self.recent_dqn_predictions[self.symbol] = deque(maxlen=100)
|
||||||
|
self.recent_cnn_predictions[self.symbol] = deque(maxlen=50)
|
||||||
|
self.recent_transformer_predictions[self.symbol] = deque(maxlen=50)
|
||||||
|
self.prediction_accuracy_history[self.symbol] = deque(maxlen=200)
|
||||||
|
self.signal_accumulator[self.symbol] = []
|
||||||
|
|
||||||
# Determine the device to use from config.yaml
|
# Determine the device to use from config.yaml
|
||||||
self.device = self._get_device_from_config()
|
self.device = self._get_device_from_config()
|
||||||
logger.info(f"Using device: {self.device}")
|
logger.info(f"Using device: {self.device}")
|
||||||
@@ -406,27 +429,6 @@ class TradingOrchestrator:
|
|||||||
{}
|
{}
|
||||||
) # {symbol: {side, size, entry_price, entry_time, pnl}}
|
) # {symbol: {side, size, entry_price, entry_time, pnl}}
|
||||||
self.trading_executor = None # Will be set by dashboard or external system
|
self.trading_executor = None # Will be set by dashboard or external system
|
||||||
# Model prediction tracking for dashboard visualization
|
|
||||||
self.recent_dqn_predictions: Dict[str, deque] = (
|
|
||||||
{}
|
|
||||||
) # {symbol: List[Dict]} - Recent DQN predictions
|
|
||||||
self.recent_cnn_predictions: Dict[str, deque] = (
|
|
||||||
{}
|
|
||||||
) # {symbol: List[Dict]} - Recent CNN predictions
|
|
||||||
self.recent_transformer_predictions: Dict[str, deque] = (
|
|
||||||
{}
|
|
||||||
) # {symbol: List[Dict]} - Recent Transformer predictions
|
|
||||||
self.prediction_accuracy_history: Dict[str, deque] = (
|
|
||||||
{}
|
|
||||||
) # {symbol: List[Dict]} - Prediction accuracy tracking
|
|
||||||
|
|
||||||
# Initialize prediction tracking for the primary trading symbol only
|
|
||||||
self.recent_dqn_predictions[self.symbol] = deque(maxlen=100)
|
|
||||||
self.recent_cnn_predictions[self.symbol] = deque(maxlen=50)
|
|
||||||
self.recent_transformer_predictions[self.symbol] = deque(maxlen=50)
|
|
||||||
self.prediction_accuracy_history[self.symbol] = deque(maxlen=200)
|
|
||||||
self.signal_accumulator[self.symbol] = []
|
|
||||||
|
|
||||||
# Decision callbacks
|
# Decision callbacks
|
||||||
self.decision_callbacks: List[Any] = []
|
self.decision_callbacks: List[Any] = []
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user