remove ws, fix predictions

This commit is contained in:
Dobromir Popov
2025-12-10 00:26:57 +02:00
parent 992d6de25b
commit c21d8cbea1
6 changed files with 526 additions and 284 deletions

View File

@@ -3589,8 +3589,7 @@ class RealTrainingAdapter:
if model_name == 'Transformer' and self.orchestrator:
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if trainer and trainer.model:
# Get recent market data
market_data, norm_params = self._get_realtime_market_data(symbol, data_provider)
# Use provided market_data and norm_params (already fetched by caller)
if not market_data:
return None
@@ -4493,15 +4492,22 @@ class RealTrainingAdapter:
time.sleep(1)
continue
# Make prediction using the model
prediction = self._make_realtime_prediction(model_name, symbol, data_provider)
# Make prediction using the model - returns tuple (prediction_dict, market_data_dict)
prediction_result = self._make_realtime_prediction(model_name, symbol, data_provider)
# Unpack tuple: prediction is the dict, market_data_info contains norm_params
if prediction_result is None:
time.sleep(1)
continue
prediction, market_data_info = prediction_result
# Register inference frame reference for later training when actual candle arrives
# This stores a reference (timestamp range) instead of copying 600 candles
# The reference allows us to retrieve the exact data from DuckDB when training
if prediction and self.training_coordinator:
# Get norm_params for storage in reference
_, norm_params = self._get_realtime_market_data(symbol, data_provider)
if prediction and self.training_coordinator and market_data_info:
# Get norm_params from market_data_info
norm_params = market_data_info.get('norm_params', {})
self._register_inference_frame(session, symbol, timeframe, prediction, data_provider, norm_params)
if prediction:
@@ -4554,10 +4560,41 @@ class RealTrainingAdapter:
# Store prediction for visualization (INCLUDE predicted_candle for ghost candles!)
if self.orchestrator and hasattr(self.orchestrator, 'store_transformer_prediction'):
# Get denormalized predicted_price (should already be denormalized from _make_realtime_prediction_internal)
predicted_price = prediction.get('predicted_price')
# Always get actual current_price from latest candle to ensure it's denormalized
# This is more reliable than trusting get_current_price which might return normalized values
actual_current_price = current_price
try:
df_latest = data_provider.get_historical_data(symbol, timeframe, limit=1, refresh=False)
if df_latest is not None and not df_latest.empty:
actual_current_price = float(df_latest['close'].iloc[-1])
else:
# Try other timeframes
for tf in ['1m', '1h', '1d']:
if tf != timeframe:
df_tf = data_provider.get_historical_data(symbol, tf, limit=1, refresh=False)
if df_tf is not None and not df_tf.empty:
actual_current_price = float(df_tf['close'].iloc[-1])
break
except Exception as e:
logger.debug(f"Error getting actual price from candle: {e}")
# Fallback: if current_price looks normalized (< 1000 for ETH/USDT), try to denormalize
if current_price < 1000 and symbol == 'ETH/USDT': # ETH should be > 1000, normalized would be < 1
if market_data_info and 'norm_params' in market_data_info:
norm_params = market_data_info['norm_params']
if '1m' in norm_params:
params = norm_params['1m']
price_min = params['price_min']
price_max = params['price_max']
# Denormalize: price = normalized * (max - min) + min
actual_current_price = float(current_price * (price_max - price_min) + price_min)
prediction_data = {
'timestamp': datetime.now(timezone.utc).isoformat(),
'current_price': current_price,
'predicted_price': prediction.get('predicted_price', current_price),
'current_price': actual_current_price, # Use denormalized price
'predicted_price': predicted_price if predicted_price is not None else actual_current_price,
'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0,
'confidence': prediction['confidence'],
'action': prediction['action'],
@@ -4596,45 +4633,101 @@ class RealTrainingAdapter:
if predicted_price_val is not None:
prediction_data['predicted_price'] = predicted_price_val
prediction_data['price_change'] = ((predicted_price_val - current_price) / current_price) * 100
# Calculate price_change using denormalized prices
prediction_data['price_change'] = ((predicted_price_val - actual_current_price) / actual_current_price) * 100
else:
prediction_data['predicted_price'] = prediction.get('predicted_price', current_price)
prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0
# Fallback: use predicted_price from prediction dict (should be denormalized)
fallback_predicted = prediction.get('predicted_price')
if fallback_predicted is not None:
prediction_data['predicted_price'] = fallback_predicted
prediction_data['price_change'] = ((fallback_predicted - actual_current_price) / actual_current_price) * 100
else:
prediction_data['predicted_price'] = actual_current_price
prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0
else:
# Fallback to estimated price if no candle prediction
logger.warning(f"!!! No predicted_candle in prediction object - ghost candles will not appear!")
prediction_data['predicted_price'] = prediction.get('predicted_price', current_price * (1.01 if prediction['action'] == 'BUY' else 0.99))
prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0
# Include trend_vector if available (convert tensors to Python types)
# Include trend_vector if available (convert tensors to Python types and denormalize)
if 'trend_vector' in prediction:
trend_vec = prediction['trend_vector']
# Convert any tensors to Python native types
# Get normalization params for denormalization
norm_params_for_denorm = {}
if market_data_info and 'norm_params' in market_data_info:
norm_params_for_denorm = market_data_info['norm_params']
# Convert any tensors to Python native types and denormalize price values
if isinstance(trend_vec, dict):
serialized_trend = {}
for key, value in trend_vec.items():
if hasattr(value, 'numel'): # Tensor
if value.numel() == 1: # Scalar tensor
serialized_trend[key] = value.item()
val = value.item()
# Denormalize price_delta if it's a price-related value
if key == 'price_delta' and norm_params_for_denorm:
val = self._denormalize_price_value(val, norm_params_for_denorm, '1m')
serialized_trend[key] = val
else: # Multi-element tensor
serialized_trend[key] = value.detach().cpu().tolist()
val_list = value.detach().cpu().tolist()
# Denormalize pivot_prices if it's a price array (can be nested)
if key == 'pivot_prices' and norm_params_for_denorm:
val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
serialized_trend[key] = val_list
elif hasattr(value, 'tolist'): # Other array-like
serialized_trend[key] = value.tolist()
val_list = value.tolist()
if key == 'pivot_prices' and norm_params_for_denorm:
val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
serialized_trend[key] = val_list
elif isinstance(value, (list, tuple)):
# Recursively convert list/tuple of tensors
serialized_trend[key] = []
serialized_list = []
for v in value:
if hasattr(v, 'numel'):
if v.numel() == 1:
serialized_trend[key].append(v.item())
val = v.item()
if key == 'pivot_prices' and norm_params_for_denorm:
val = self._denormalize_price_value(val, norm_params_for_denorm, '1m')
serialized_list.append(val)
else:
serialized_trend[key].append(v.detach().cpu().tolist())
val_list = v.detach().cpu().tolist()
if key == 'pivot_prices' and norm_params_for_denorm:
# Handle nested arrays (pivot_prices is [[p1, p2, p3, ...]])
val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
serialized_list.append(val_list)
elif hasattr(v, 'tolist'):
serialized_trend[key].append(v.tolist())
val_list = v.tolist()
if key == 'pivot_prices' and norm_params_for_denorm:
# Handle nested arrays
val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
serialized_list.append(val_list)
elif isinstance(v, (list, tuple)):
# Nested list - handle pivot_prices structure
if key == 'pivot_prices' and norm_params_for_denorm:
nested_denorm = self._denormalize_nested_price_array(list(v), norm_params_for_denorm, '1m')
serialized_list.append(nested_denorm)
else:
serialized_list.append(list(v))
else:
serialized_trend[key].append(v)
serialized_list.append(v)
serialized_trend[key] = serialized_list
else:
serialized_trend[key] = value
# Denormalize price_delta if it's a scalar
if key == 'price_delta' and isinstance(value, (int, float)) and norm_params_for_denorm:
serialized_trend[key] = self._denormalize_price_value(value, norm_params_for_denorm, '1m')
else:
serialized_trend[key] = value
# Denormalize vector array if it contains price deltas
if 'vector' in serialized_trend and isinstance(serialized_trend['vector'], list) and norm_params_for_denorm:
vector = serialized_trend['vector']
if len(vector) > 0 and isinstance(vector[0], list) and len(vector[0]) > 0:
# vector is [[price_delta, time_delta]]
price_delta_norm = vector[0][0]
price_delta_denorm = self._denormalize_price_value(price_delta_norm, norm_params_for_denorm, '1m')
serialized_trend['vector'] = [[price_delta_denorm, vector[0][1]]]
prediction_data['trend_vector'] = serialized_trend
else:
prediction_data['trend_vector'] = trend_vec
@@ -4870,3 +4963,82 @@ class RealTrainingAdapter:
return ((current_price - entry_price) / entry_price) * 100 # Percentage
else: # short
return ((entry_price - current_price) / entry_price) * 100 # Percentage
def _denormalize_price_value(self, normalized_value: float, norm_params: Dict, timeframe: str = '1m') -> float:
"""
Denormalize a single price value using normalization parameters
Args:
normalized_value: Normalized price value (0-1 range)
norm_params: Dictionary of normalization parameters by timeframe
timeframe: Timeframe to use for denormalization (default: '1m')
Returns:
Denormalized price value
"""
try:
if timeframe in norm_params:
params = norm_params[timeframe]
price_min = params.get('price_min', 0.0)
price_max = params.get('price_max', 1.0)
if price_max > price_min:
# Denormalize: price = normalized * (max - min) + min
return float(normalized_value * (price_max - price_min) + price_min)
# Fallback: return as-is if no params available
return float(normalized_value)
except Exception as e:
logger.debug(f"Error denormalizing price value: {e}")
return float(normalized_value)
def _denormalize_price_array(self, normalized_array: list, norm_params: Dict, timeframe: str = '1m') -> list:
"""
Denormalize an array of price values using normalization parameters
Args:
normalized_array: List of normalized price values (0-1 range)
norm_params: Dictionary of normalization parameters by timeframe
timeframe: Timeframe to use for denormalization (default: '1m')
Returns:
List of denormalized price values
"""
try:
if timeframe in norm_params:
params = norm_params[timeframe]
price_min = params.get('price_min', 0.0)
price_max = params.get('price_max', 1.0)
if price_max > price_min:
# Denormalize each value: price = normalized * (max - min) + min
return [float(v * (price_max - price_min) + price_min) if isinstance(v, (int, float)) else v
for v in normalized_array]
# Fallback: return as-is if no params available
return [float(v) if isinstance(v, (int, float)) else v for v in normalized_array]
except Exception as e:
logger.debug(f"Error denormalizing price array: {e}")
return [float(v) if isinstance(v, (int, float)) else v for v in normalized_array]
def _denormalize_nested_price_array(self, normalized_array: list, norm_params: Dict, timeframe: str = '1m') -> list:
"""
Denormalize a nested array of price values (e.g., [[p1, p2, p3], [p4, p5, p6]])
Args:
normalized_array: Nested list of normalized price values
norm_params: Dictionary of normalization parameters by timeframe
timeframe: Timeframe to use for denormalization (default: '1m')
Returns:
Nested list of denormalized price values
"""
try:
result = []
for item in normalized_array:
if isinstance(item, (list, tuple)):
# Recursively denormalize nested arrays
result.append(self._denormalize_price_array(list(item), norm_params, timeframe))
else:
# Single value - denormalize it
result.append(self._denormalize_price_value(item, norm_params, timeframe) if isinstance(item, (int, float)) else item)
return result
except Exception as e:
logger.debug(f"Error denormalizing nested price array: {e}")
return normalized_array