remove ws, fix predictions
@@ -3589,8 +3589,7 @@ class RealTrainingAdapter:
         if model_name == 'Transformer' and self.orchestrator:
             trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
             if trainer and trainer.model:
-                # Get recent market data
-                market_data, norm_params = self._get_realtime_market_data(symbol, data_provider)
+                # Use provided market_data and norm_params (already fetched by caller)
                 if not market_data:
                     return None

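The removed fetch is now the caller's responsibility; a minimal sketch of that fetch-once flow, assuming only the names visible in this diff (the `_predict_with_market_data` helper below is hypothetical):

# Caller fetches market data once and threads it through, so the Transformer
# branch above no longer re-fetches it.
market_data, norm_params = self._get_realtime_market_data(symbol, data_provider)
if not market_data:
    return None
# Hypothetical stand-in for the internal prediction step
prediction = self._predict_with_market_data(model_name, symbol, market_data, norm_params)
return prediction, {'norm_params': norm_params}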
@@ -4493,15 +4492,22 @@ class RealTrainingAdapter:
                     time.sleep(1)
                     continue

-                # Make prediction using the model
-                prediction = self._make_realtime_prediction(model_name, symbol, data_provider)
+                # Make prediction using the model - returns tuple (prediction_dict, market_data_dict)
+                prediction_result = self._make_realtime_prediction(model_name, symbol, data_provider)
+
+                # Unpack tuple: prediction is the dict, market_data_info contains norm_params
+                if prediction_result is None:
+                    time.sleep(1)
+                    continue
+
+                prediction, market_data_info = prediction_result

                 # Register inference frame reference for later training when actual candle arrives
                 # This stores a reference (timestamp range) instead of copying 600 candles
                 # The reference allows us to retrieve the exact data from DuckDB when training
-                if prediction and self.training_coordinator:
-                    # Get norm_params for storage in reference
-                    _, norm_params = self._get_realtime_market_data(symbol, data_provider)
+                if prediction and self.training_coordinator and market_data_info:
+                    # Get norm_params from market_data_info
+                    norm_params = market_data_info.get('norm_params', {})
                     self._register_inference_frame(session, symbol, timeframe, prediction, data_provider, norm_params)

                 if prediction:
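The wrapper now returns either None or a (prediction_dict, market_data_dict) tuple, so callers must guard before unpacking; a minimal sketch with illustrative values (only the 'norm_params', 'action', 'confidence' and 'predicted_price' keys appear in this diff):

prediction_result = (
    {'action': 'BUY', 'confidence': 0.62, 'predicted_price': 3012.5},
    {'norm_params': {'1m': {'price_min': 2900.0, 'price_max': 3100.0}}},
)

if prediction_result is None:
    pass  # caller sleeps and retries, as in the hunk above
else:
    prediction, market_data_info = prediction_result
    norm_params = market_data_info.get('norm_params', {})  # {} when params are missing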
@@ -4554,10 +4560,41 @@ class RealTrainingAdapter:

                     # Store prediction for visualization (INCLUDE predicted_candle for ghost candles!)
                     if self.orchestrator and hasattr(self.orchestrator, 'store_transformer_prediction'):
+                        # Get denormalized predicted_price (should already be denormalized from _make_realtime_prediction_internal)
+                        predicted_price = prediction.get('predicted_price')
+
+                        # Always get actual current_price from latest candle to ensure it's denormalized
+                        # This is more reliable than trusting get_current_price which might return normalized values
+                        actual_current_price = current_price
+                        try:
+                            df_latest = data_provider.get_historical_data(symbol, timeframe, limit=1, refresh=False)
+                            if df_latest is not None and not df_latest.empty:
+                                actual_current_price = float(df_latest['close'].iloc[-1])
+                            else:
+                                # Try other timeframes
+                                for tf in ['1m', '1h', '1d']:
+                                    if tf != timeframe:
+                                        df_tf = data_provider.get_historical_data(symbol, tf, limit=1, refresh=False)
+                                        if df_tf is not None and not df_tf.empty:
+                                            actual_current_price = float(df_tf['close'].iloc[-1])
+                                            break
+                        except Exception as e:
+                            logger.debug(f"Error getting actual price from candle: {e}")
+                            # Fallback: if current_price looks normalized (< 1000 for ETH/USDT), try to denormalize
+                            if current_price < 1000 and symbol == 'ETH/USDT':  # ETH should be > 1000, normalized would be < 1
+                                if market_data_info and 'norm_params' in market_data_info:
+                                    norm_params = market_data_info['norm_params']
+                                    if '1m' in norm_params:
+                                        params = norm_params['1m']
+                                        price_min = params['price_min']
+                                        price_max = params['price_max']
+                                        # Denormalize: price = normalized * (max - min) + min
+                                        actual_current_price = float(current_price * (price_max - price_min) + price_min)
+
                         prediction_data = {
                             'timestamp': datetime.now(timezone.utc).isoformat(),
-                            'current_price': current_price,
-                            'predicted_price': prediction.get('predicted_price', current_price),
+                            'current_price': actual_current_price,  # Use denormalized price
+                            'predicted_price': predicted_price if predicted_price is not None else actual_current_price,
                             'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0,
                             'confidence': prediction['confidence'],
                             'action': prediction['action'],
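The denormalization applied above is a plain min-max inversion, price = normalized * (price_max - price_min) + price_min; a worked example with illustrative parameters (not taken from the diff):

price_min, price_max = 2900.0, 3100.0   # illustrative values
normalized = 0.75
price = normalized * (price_max - price_min) + price_min   # 0.75 * 200 + 2900
assert price == 3050.0
# forward direction, for reference: (3050.0 - price_min) / (price_max - price_min) == 0.75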
@@ -4596,45 +4633,101 @@ class RealTrainingAdapter:

                             if predicted_price_val is not None:
                                 prediction_data['predicted_price'] = predicted_price_val
-                                prediction_data['price_change'] = ((predicted_price_val - current_price) / current_price) * 100
+                                # Calculate price_change using denormalized prices
+                                prediction_data['price_change'] = ((predicted_price_val - actual_current_price) / actual_current_price) * 100
                             else:
-                                prediction_data['predicted_price'] = prediction.get('predicted_price', current_price)
-                                prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0
+                                # Fallback: use predicted_price from prediction dict (should be denormalized)
+                                fallback_predicted = prediction.get('predicted_price')
+                                if fallback_predicted is not None:
+                                    prediction_data['predicted_price'] = fallback_predicted
+                                    prediction_data['price_change'] = ((fallback_predicted - actual_current_price) / actual_current_price) * 100
+                                else:
+                                    prediction_data['predicted_price'] = actual_current_price
+                                    prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0
                         else:
                             # Fallback to estimated price if no candle prediction
                             logger.warning(f"!!! No predicted_candle in prediction object - ghost candles will not appear!")
                             prediction_data['predicted_price'] = prediction.get('predicted_price', current_price * (1.01 if prediction['action'] == 'BUY' else 0.99))
                             prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0

-                        # Include trend_vector if available (convert tensors to Python types)
+                        # Include trend_vector if available (convert tensors to Python types and denormalize)
                         if 'trend_vector' in prediction:
                             trend_vec = prediction['trend_vector']
-                            # Convert any tensors to Python native types
+                            # Get normalization params for denormalization
+                            norm_params_for_denorm = {}
+                            if market_data_info and 'norm_params' in market_data_info:
+                                norm_params_for_denorm = market_data_info['norm_params']
+
+                            # Convert any tensors to Python native types and denormalize price values
                             if isinstance(trend_vec, dict):
                                 serialized_trend = {}
                                 for key, value in trend_vec.items():
                                     if hasattr(value, 'numel'):  # Tensor
                                         if value.numel() == 1:  # Scalar tensor
-                                            serialized_trend[key] = value.item()
+                                            val = value.item()
+                                            # Denormalize price_delta if it's a price-related value
+                                            if key == 'price_delta' and norm_params_for_denorm:
+                                                val = self._denormalize_price_value(val, norm_params_for_denorm, '1m')
+                                            serialized_trend[key] = val
                                         else:  # Multi-element tensor
-                                            serialized_trend[key] = value.detach().cpu().tolist()
+                                            val_list = value.detach().cpu().tolist()
+                                            # Denormalize pivot_prices if it's a price array (can be nested)
+                                            if key == 'pivot_prices' and norm_params_for_denorm:
+                                                val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
+                                            serialized_trend[key] = val_list
                                     elif hasattr(value, 'tolist'):  # Other array-like
-                                        serialized_trend[key] = value.tolist()
+                                        val_list = value.tolist()
+                                        if key == 'pivot_prices' and norm_params_for_denorm:
+                                            val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
+                                        serialized_trend[key] = val_list
                                     elif isinstance(value, (list, tuple)):
                                         # Recursively convert list/tuple of tensors
-                                        serialized_trend[key] = []
+                                        serialized_list = []
                                         for v in value:
                                             if hasattr(v, 'numel'):
                                                 if v.numel() == 1:
-                                                    serialized_trend[key].append(v.item())
+                                                    val = v.item()
+                                                    if key == 'pivot_prices' and norm_params_for_denorm:
+                                                        val = self._denormalize_price_value(val, norm_params_for_denorm, '1m')
+                                                    serialized_list.append(val)
                                                 else:
-                                                    serialized_trend[key].append(v.detach().cpu().tolist())
+                                                    val_list = v.detach().cpu().tolist()
+                                                    if key == 'pivot_prices' and norm_params_for_denorm:
+                                                        # Handle nested arrays (pivot_prices is [[p1, p2, p3, ...]])
+                                                        val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
+                                                    serialized_list.append(val_list)
                                             elif hasattr(v, 'tolist'):
-                                                serialized_trend[key].append(v.tolist())
+                                                val_list = v.tolist()
+                                                if key == 'pivot_prices' and norm_params_for_denorm:
+                                                    # Handle nested arrays
+                                                    val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
+                                                serialized_list.append(val_list)
+                                            elif isinstance(v, (list, tuple)):
+                                                # Nested list - handle pivot_prices structure
+                                                if key == 'pivot_prices' and norm_params_for_denorm:
+                                                    nested_denorm = self._denormalize_nested_price_array(list(v), norm_params_for_denorm, '1m')
+                                                    serialized_list.append(nested_denorm)
+                                                else:
+                                                    serialized_list.append(list(v))
                                             else:
-                                                serialized_trend[key].append(v)
+                                                serialized_list.append(v)
+                                        serialized_trend[key] = serialized_list
                                     else:
-                                        serialized_trend[key] = value
+                                        # Denormalize price_delta if it's a scalar
+                                        if key == 'price_delta' and isinstance(value, (int, float)) and norm_params_for_denorm:
+                                            serialized_trend[key] = self._denormalize_price_value(value, norm_params_for_denorm, '1m')
+                                        else:
+                                            serialized_trend[key] = value
+
+                                # Denormalize vector array if it contains price deltas
+                                if 'vector' in serialized_trend and isinstance(serialized_trend['vector'], list) and norm_params_for_denorm:
+                                    vector = serialized_trend['vector']
+                                    if len(vector) > 0 and isinstance(vector[0], list) and len(vector[0]) > 0:
+                                        # vector is [[price_delta, time_delta]]
+                                        price_delta_norm = vector[0][0]
+                                        price_delta_denorm = self._denormalize_price_value(price_delta_norm, norm_params_for_denorm, '1m')
+                                        serialized_trend['vector'] = [[price_delta_denorm, vector[0][1]]]
+
                                 prediction_data['trend_vector'] = serialized_trend
                             else:
                                 prediction_data['trend_vector'] = trend_vec
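The serialization above reduces to .item() for scalar tensors and .detach().cpu().tolist() for everything else, with price-keyed entries routed through the denormalization helpers; a standalone sketch of just the tensor-to-Python part (PyTorch assumed, values illustrative):

import torch

trend_vec = {
    'price_delta': torch.tensor(0.5),            # scalar tensor -> .item()
    'pivot_prices': torch.tensor([[0.25, 0.75]]),  # multi-element tensor -> .tolist()
    'confidence': 0.8,                           # already a plain Python float
}

serialized = {}
for key, value in trend_vec.items():
    if hasattr(value, 'numel'):
        serialized[key] = value.item() if value.numel() == 1 else value.detach().cpu().tolist()
    else:
        serialized[key] = value
# serialized is now JSON-safe: {'price_delta': 0.5, 'pivot_prices': [[0.25, 0.75]], 'confidence': 0.8}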
@@ -4870,3 +4963,82 @@ class RealTrainingAdapter:
             return ((current_price - entry_price) / entry_price) * 100  # Percentage
         else:  # short
             return ((entry_price - current_price) / entry_price) * 100  # Percentage
+
+    def _denormalize_price_value(self, normalized_value: float, norm_params: Dict, timeframe: str = '1m') -> float:
+        """
+        Denormalize a single price value using normalization parameters
+
+        Args:
+            normalized_value: Normalized price value (0-1 range)
+            norm_params: Dictionary of normalization parameters by timeframe
+            timeframe: Timeframe to use for denormalization (default: '1m')
+
+        Returns:
+            Denormalized price value
+        """
+        try:
+            if timeframe in norm_params:
+                params = norm_params[timeframe]
+                price_min = params.get('price_min', 0.0)
+                price_max = params.get('price_max', 1.0)
+                if price_max > price_min:
+                    # Denormalize: price = normalized * (max - min) + min
+                    return float(normalized_value * (price_max - price_min) + price_min)
+            # Fallback: return as-is if no params available
+            return float(normalized_value)
+        except Exception as e:
+            logger.debug(f"Error denormalizing price value: {e}")
+            return float(normalized_value)
+
+    def _denormalize_price_array(self, normalized_array: list, norm_params: Dict, timeframe: str = '1m') -> list:
+        """
+        Denormalize an array of price values using normalization parameters
+
+        Args:
+            normalized_array: List of normalized price values (0-1 range)
+            norm_params: Dictionary of normalization parameters by timeframe
+            timeframe: Timeframe to use for denormalization (default: '1m')
+
+        Returns:
+            List of denormalized price values
+        """
+        try:
+            if timeframe in norm_params:
+                params = norm_params[timeframe]
+                price_min = params.get('price_min', 0.0)
+                price_max = params.get('price_max', 1.0)
+                if price_max > price_min:
+                    # Denormalize each value: price = normalized * (max - min) + min
+                    return [float(v * (price_max - price_min) + price_min) if isinstance(v, (int, float)) else v
+                            for v in normalized_array]
+            # Fallback: return as-is if no params available
+            return [float(v) if isinstance(v, (int, float)) else v for v in normalized_array]
+        except Exception as e:
+            logger.debug(f"Error denormalizing price array: {e}")
+            return [float(v) if isinstance(v, (int, float)) else v for v in normalized_array]
+
+    def _denormalize_nested_price_array(self, normalized_array: list, norm_params: Dict, timeframe: str = '1m') -> list:
+        """
+        Denormalize a nested array of price values (e.g., [[p1, p2, p3], [p4, p5, p6]])
+
+        Args:
+            normalized_array: Nested list of normalized price values
+            norm_params: Dictionary of normalization parameters by timeframe
+            timeframe: Timeframe to use for denormalization (default: '1m')
+
+        Returns:
+            Nested list of denormalized price values
+        """
+        try:
+            result = []
+            for item in normalized_array:
+                if isinstance(item, (list, tuple)):
+                    # Recursively denormalize nested arrays
+                    result.append(self._denormalize_price_array(list(item), norm_params, timeframe))
+                else:
+                    # Single value - denormalize it
+                    result.append(self._denormalize_price_value(item, norm_params, timeframe) if isinstance(item, (int, float)) else item)
+            return result
+        except Exception as e:
+            logger.debug(f"Error denormalizing nested price array: {e}")
+            return normalized_array
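A short usage sketch of the three new helpers; the adapter instance and the norm_params values are illustrative, the formula is the one in the code above:

norm_params = {'1m': {'price_min': 2900.0, 'price_max': 3100.0}}  # illustrative values

adapter._denormalize_price_value(0.5, norm_params)                    # -> 3000.0
adapter._denormalize_price_array([0.0, 0.5, 1.0], norm_params)        # -> [2900.0, 3000.0, 3100.0]
adapter._denormalize_nested_price_array([[0.25, 0.75]], norm_params)  # -> [[2950.0, 3050.0]]
# With an unknown timeframe or missing params the inputs are returned unchanged.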