remove ws, fix predictions
@@ -89,14 +89,32 @@ class HistoricalDataLoader:
        try:
            # FORCE refresh for 1s/1m if requesting latest data OR incremental update
            force_refresh = (timeframe in ['1s', '1m'] and (bypass_cache or (not start_time and not end_time)))

            # Try to get data from DataProvider's cached data first (most efficient)
            if hasattr(self.data_provider, 'cached_data'):
                with self.data_provider.data_lock:
                    cached_df = self.data_provider.cached_data.get(symbol, {}).get(timeframe)

                if cached_df is not None and not cached_df.empty:
                    # Use cached data if we have enough candles
                    if len(cached_df) >= min(limit, 100):  # Use cached if we have at least 100 candles
                    # If time range is specified, check if cached data covers it
                    use_cached_data = True
                    if start_time or end_time:
                        if isinstance(cached_df.index, pd.DatetimeIndex):
                            cache_start = cached_df.index.min()
                            cache_end = cached_df.index.max()

                            # Check if requested range is within cached range
                            if start_time and start_time < cache_start:
                                use_cached_data = False
                            elif end_time and end_time > cache_end:
                                use_cached_data = False
                            elif start_time and end_time:
                                # Both specified - check if range overlaps
                                if end_time < cache_start or start_time > cache_end:
                                    use_cached_data = False

                    # Use cached data if we have enough candles and it covers the range
                    if use_cached_data and len(cached_df) >= min(limit, 100):  # Use cached if we have at least 100 candles
                        elapsed_ms = (time.time() - start_time_ms) * 1000
                        logger.debug(f"DataProvider cache hit for {symbol} {timeframe} ({len(cached_df)} candles, {elapsed_ms:.1f}ms)")
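For reference (not part of the commit), the coverage check above can be read as a small standalone predicate. A minimal sketch assuming a pandas DatetimeIndex; the helper name is illustrative only:

import pandas as pd

def cached_range_covers(cached_df: pd.DataFrame, start_time=None, end_time=None) -> bool:
    # Illustrative helper (not in the codebase): mirrors the use_cached_data logic above.
    if cached_df.empty or not isinstance(cached_df.index, pd.DatetimeIndex):
        return False
    cache_start, cache_end = cached_df.index.min(), cached_df.index.max()
    use_cached = True
    if start_time and start_time < cache_start:
        use_cached = False
    elif end_time and end_time > cache_end:
        use_cached = False
    elif start_time and end_time and (end_time < cache_start or start_time > cache_end):
        use_cached = False
    return use_cached

The same check is repeated in the fallback path further down, so a shared helper like this would avoid the duplication the comments already acknowledge.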
@@ -109,9 +127,12 @@ class HistoricalDataLoader:
                            limit
                        )

                        # Cache in memory
                        self.memory_cache[cache_key] = (filtered_df, datetime.now())
                        return filtered_df
                        # Only return cached data if filter produced results
                        if filtered_df is not None and not filtered_df.empty:
                            # Cache in memory
                            self.memory_cache[cache_key] = (filtered_df, datetime.now())
                            return filtered_df
                        # If filter returned empty, fall through to fetch from DuckDB/API

            # Try unified storage first if available
            if hasattr(self.data_provider, 'is_unified_storage_enabled') and \
@@ -156,28 +177,47 @@ class HistoricalDataLoader:
            except Exception as e:
                logger.debug(f"Unified storage not available, falling back to cached data: {e}")

            # Fallback to existing cached data method
            # Use DataProvider's cached data if available
            # Fallback to existing cached data method (duplicate check - should not reach here if first check worked)
            # This is kept for backward compatibility but should rarely execute
            if hasattr(self.data_provider, 'cached_data'):
                if symbol in self.data_provider.cached_data:
                    if timeframe in self.data_provider.cached_data[symbol]:
                        df = self.data_provider.cached_data[symbol][timeframe]

                        if df is not None and not df.empty:
                            # Filter by time range with direction support
                            df = self._filter_by_time_range(
                                df.copy(),
                                start_time,
                                end_time,
                                direction,
                                limit
                            )
                            # Check if cached data covers the requested time range
                            use_cached_data = True
                            if start_time or end_time:
                                if isinstance(df.index, pd.DatetimeIndex):
                                    cache_start = df.index.min()
                                    cache_end = df.index.max()

                                    if start_time and start_time < cache_start:
                                        use_cached_data = False
                                    elif end_time and end_time > cache_end:
                                        use_cached_data = False
                                    elif start_time and end_time:
                                        if end_time < cache_start or start_time > cache_end:
                                            use_cached_data = False

                            # Cache in memory
                            self.memory_cache[cache_key] = (df.copy(), datetime.now())

                            logger.info(f"Loaded {len(df)} candles for {symbol} {timeframe}")
                            return df
                            if use_cached_data:
                                # Filter by time range with direction support
                                df = self._filter_by_time_range(
                                    df.copy(),
                                    start_time,
                                    end_time,
                                    direction,
                                    limit
                                )

                                # Only return if filter produced results
                                if df is not None and not df.empty:
                                    # Cache in memory
                                    self.memory_cache[cache_key] = (df.copy(), datetime.now())

                                    logger.info(f"Loaded {len(df)} candles for {symbol} {timeframe}")
                                    return df
                                # If filter returned empty or range not covered, fall through to fetch from DuckDB/API

            # Check DuckDB first for historical data (always check for infinite scroll)
            if self.data_provider.duckdb_storage and (start_time or end_time):
@@ -198,7 +238,7 @@ class HistoricalDataLoader:
                    self.memory_cache[cache_key] = (df.copy(), datetime.now())
                    return df
                else:
                    logger.info(f"📡 No data in DuckDB, fetching from exchange API for {symbol} {timeframe}")
                    logger.info(f"No data in DuckDB, fetching from exchange API for {symbol} {timeframe}")

                    # Fetch from exchange API with time range
                    df = self._fetch_from_exchange_api(
@@ -212,7 +252,7 @@ class HistoricalDataLoader:
                if df is not None and not df.empty:
                    elapsed_ms = (time.time() - start_time_ms) * 1000
                    logger.info(f"🌐 Exchange API hit for {symbol} {timeframe} ({len(df)} candles, {elapsed_ms:.1f}ms)")
                    logger.info(f"Exchange API hit for {symbol} {timeframe} ({len(df)} candles, {elapsed_ms:.1f}ms)")

                    # Store in DuckDB for future use
                    if self.data_provider.duckdb_storage:
@@ -3589,8 +3589,7 @@ class RealTrainingAdapter:
        if model_name == 'Transformer' and self.orchestrator:
            trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
            if trainer and trainer.model:
                # Get recent market data
                market_data, norm_params = self._get_realtime_market_data(symbol, data_provider)
                # Use provided market_data and norm_params (already fetched by caller)
                if not market_data:
                    return None
@@ -4493,15 +4492,22 @@ class RealTrainingAdapter:
                    time.sleep(1)
                    continue

                # Make prediction using the model
                prediction = self._make_realtime_prediction(model_name, symbol, data_provider)
                # Make prediction using the model - returns tuple (prediction_dict, market_data_dict)
                prediction_result = self._make_realtime_prediction(model_name, symbol, data_provider)

                # Unpack tuple: prediction is the dict, market_data_info contains norm_params
                if prediction_result is None:
                    time.sleep(1)
                    continue

                prediction, market_data_info = prediction_result

                # Register inference frame reference for later training when actual candle arrives
                # This stores a reference (timestamp range) instead of copying 600 candles
                # The reference allows us to retrieve the exact data from DuckDB when training
                if prediction and self.training_coordinator:
                    # Get norm_params for storage in reference
                    _, norm_params = self._get_realtime_market_data(symbol, data_provider)
                if prediction and self.training_coordinator and market_data_info:
                    # Get norm_params from market_data_info
                    norm_params = market_data_info.get('norm_params', {})
                    self._register_inference_frame(session, symbol, timeframe, prediction, data_provider, norm_params)

                if prediction:
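For clarity (not part of the commit), the tuple contract the loop now relies on looks roughly like this. All concrete values are illustrative, and only the 'norm_params' key is actually read by the surrounding code:

# Illustrative shape of the value returned by _make_realtime_prediction
prediction = {
    'action': 'BUY',               # or 'SELL'
    'confidence': 0.62,
    'predicted_price': 3142.7,     # already denormalized
}
market_data_info = {
    'norm_params': {
        '1m': {'price_min': 3050.0, 'price_max': 3200.0},
    },
}
prediction_result = (prediction, market_data_info)   # or None when no data is available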
@@ -4554,10 +4560,41 @@ class RealTrainingAdapter:
                    # Store prediction for visualization (INCLUDE predicted_candle for ghost candles!)
                    if self.orchestrator and hasattr(self.orchestrator, 'store_transformer_prediction'):
                        # Get denormalized predicted_price (should already be denormalized from _make_realtime_prediction_internal)
                        predicted_price = prediction.get('predicted_price')

                        # Always get actual current_price from latest candle to ensure it's denormalized
                        # This is more reliable than trusting get_current_price which might return normalized values
                        actual_current_price = current_price
                        try:
                            df_latest = data_provider.get_historical_data(symbol, timeframe, limit=1, refresh=False)
                            if df_latest is not None and not df_latest.empty:
                                actual_current_price = float(df_latest['close'].iloc[-1])
                            else:
                                # Try other timeframes
                                for tf in ['1m', '1h', '1d']:
                                    if tf != timeframe:
                                        df_tf = data_provider.get_historical_data(symbol, tf, limit=1, refresh=False)
                                        if df_tf is not None and not df_tf.empty:
                                            actual_current_price = float(df_tf['close'].iloc[-1])
                                            break
                        except Exception as e:
                            logger.debug(f"Error getting actual price from candle: {e}")
                            # Fallback: if current_price looks normalized (< 1000 for ETH/USDT), try to denormalize
                            if current_price < 1000 and symbol == 'ETH/USDT':  # ETH should be > 1000, normalized would be < 1
                                if market_data_info and 'norm_params' in market_data_info:
                                    norm_params = market_data_info['norm_params']
                                    if '1m' in norm_params:
                                        params = norm_params['1m']
                                        price_min = params['price_min']
                                        price_max = params['price_max']
                                        # Denormalize: price = normalized * (max - min) + min
                                        actual_current_price = float(current_price * (price_max - price_min) + price_min)

                        prediction_data = {
                            'timestamp': datetime.now(timezone.utc).isoformat(),
                            'current_price': current_price,
                            'predicted_price': prediction.get('predicted_price', current_price),
                            'current_price': actual_current_price,  # Use denormalized price
                            'predicted_price': predicted_price if predicted_price is not None else actual_current_price,
                            'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0,
                            'confidence': prediction['confidence'],
                            'action': prediction['action'],
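As a quick sanity check on the fallback above (not part of the commit), the min-max denormalization works out as follows with illustrative bounds:

# price = normalized * (price_max - price_min) + price_min
price_min, price_max = 3050.0, 3200.0   # illustrative 1m bounds
normalized = 0.42
denormalized = normalized * (price_max - price_min) + price_min
assert abs(denormalized - 3113.0) < 1e-6  # 0.42 * 150 + 3050 = 3113.0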
@@ -4596,45 +4633,101 @@ class RealTrainingAdapter:
                            if predicted_price_val is not None:
                                prediction_data['predicted_price'] = predicted_price_val
                                prediction_data['price_change'] = ((predicted_price_val - current_price) / current_price) * 100
                                # Calculate price_change using denormalized prices
                                prediction_data['price_change'] = ((predicted_price_val - actual_current_price) / actual_current_price) * 100
                            else:
                                prediction_data['predicted_price'] = prediction.get('predicted_price', current_price)
                                prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0
                                # Fallback: use predicted_price from prediction dict (should be denormalized)
                                fallback_predicted = prediction.get('predicted_price')
                                if fallback_predicted is not None:
                                    prediction_data['predicted_price'] = fallback_predicted
                                    prediction_data['price_change'] = ((fallback_predicted - actual_current_price) / actual_current_price) * 100
                                else:
                                    prediction_data['predicted_price'] = actual_current_price
                                    prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0
                        else:
                            # Fallback to estimated price if no candle prediction
                            logger.warning(f"!!! No predicted_candle in prediction object - ghost candles will not appear!")
                            prediction_data['predicted_price'] = prediction.get('predicted_price', current_price * (1.01 if prediction['action'] == 'BUY' else 0.99))
                            prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0

                        # Include trend_vector if available (convert tensors to Python types)
                        # Include trend_vector if available (convert tensors to Python types and denormalize)
                        if 'trend_vector' in prediction:
                            trend_vec = prediction['trend_vector']
                            # Convert any tensors to Python native types
                            # Get normalization params for denormalization
                            norm_params_for_denorm = {}
                            if market_data_info and 'norm_params' in market_data_info:
                                norm_params_for_denorm = market_data_info['norm_params']

                            # Convert any tensors to Python native types and denormalize price values
                            if isinstance(trend_vec, dict):
                                serialized_trend = {}
                                for key, value in trend_vec.items():
                                    if hasattr(value, 'numel'):  # Tensor
                                        if value.numel() == 1:  # Scalar tensor
                                            serialized_trend[key] = value.item()
                                            val = value.item()
                                            # Denormalize price_delta if it's a price-related value
                                            if key == 'price_delta' and norm_params_for_denorm:
                                                val = self._denormalize_price_value(val, norm_params_for_denorm, '1m')
                                            serialized_trend[key] = val
                                        else:  # Multi-element tensor
                                            serialized_trend[key] = value.detach().cpu().tolist()
                                            val_list = value.detach().cpu().tolist()
                                            # Denormalize pivot_prices if it's a price array (can be nested)
                                            if key == 'pivot_prices' and norm_params_for_denorm:
                                                val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
                                            serialized_trend[key] = val_list
                                    elif hasattr(value, 'tolist'):  # Other array-like
                                        serialized_trend[key] = value.tolist()
                                        val_list = value.tolist()
                                        if key == 'pivot_prices' and norm_params_for_denorm:
                                            val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
                                        serialized_trend[key] = val_list
                                    elif isinstance(value, (list, tuple)):
                                        # Recursively convert list/tuple of tensors
                                        serialized_trend[key] = []
                                        serialized_list = []
                                        for v in value:
                                            if hasattr(v, 'numel'):
                                                if v.numel() == 1:
                                                    serialized_trend[key].append(v.item())
                                                    val = v.item()
                                                    if key == 'pivot_prices' and norm_params_for_denorm:
                                                        val = self._denormalize_price_value(val, norm_params_for_denorm, '1m')
                                                    serialized_list.append(val)
                                                else:
                                                    serialized_trend[key].append(v.detach().cpu().tolist())
                                                    val_list = v.detach().cpu().tolist()
                                                    if key == 'pivot_prices' and norm_params_for_denorm:
                                                        # Handle nested arrays (pivot_prices is [[p1, p2, p3, ...]])
                                                        val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
                                                    serialized_list.append(val_list)
                                            elif hasattr(v, 'tolist'):
                                                serialized_trend[key].append(v.tolist())
                                                val_list = v.tolist()
                                                if key == 'pivot_prices' and norm_params_for_denorm:
                                                    # Handle nested arrays
                                                    val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
                                                serialized_list.append(val_list)
                                            elif isinstance(v, (list, tuple)):
                                                # Nested list - handle pivot_prices structure
                                                if key == 'pivot_prices' and norm_params_for_denorm:
                                                    nested_denorm = self._denormalize_nested_price_array(list(v), norm_params_for_denorm, '1m')
                                                    serialized_list.append(nested_denorm)
                                                else:
                                                    serialized_list.append(list(v))
                                            else:
                                                serialized_trend[key].append(v)
                                                serialized_list.append(v)
                                        serialized_trend[key] = serialized_list
                                    else:
                                        serialized_trend[key] = value
                                        # Denormalize price_delta if it's a scalar
                                        if key == 'price_delta' and isinstance(value, (int, float)) and norm_params_for_denorm:
                                            serialized_trend[key] = self._denormalize_price_value(value, norm_params_for_denorm, '1m')
                                        else:
                                            serialized_trend[key] = value

                                # Denormalize vector array if it contains price deltas
                                if 'vector' in serialized_trend and isinstance(serialized_trend['vector'], list) and norm_params_for_denorm:
                                    vector = serialized_trend['vector']
                                    if len(vector) > 0 and isinstance(vector[0], list) and len(vector[0]) > 0:
                                        # vector is [[price_delta, time_delta]]
                                        price_delta_norm = vector[0][0]
                                        price_delta_denorm = self._denormalize_price_value(price_delta_norm, norm_params_for_denorm, '1m')
                                        serialized_trend['vector'] = [[price_delta_denorm, vector[0][1]]]

                                prediction_data['trend_vector'] = serialized_trend
                            else:
                                prediction_data['trend_vector'] = trend_vec
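The branches above repeat the same two steps (tensor-to-Python conversion, then denormalization of price fields). A simplified sketch of that idea (not part of the commit), assuming torch tensors and ignoring the special nesting of pivot_prices handled in the real code:

import torch

def to_python(value):
    # Convert tensors (and nested containers of tensors) to plain Python types.
    if isinstance(value, torch.Tensor):
        return value.item() if value.numel() == 1 else value.detach().cpu().tolist()
    if isinstance(value, (list, tuple)):
        return [to_python(v) for v in value]
    return value

def denorm_prices(value, norm_params, denorm_scalar, timeframe='1m'):
    # Apply min-max denormalization to a scalar or an arbitrarily nested list of prices.
    if isinstance(value, (int, float)):
        return denorm_scalar(value, norm_params, timeframe)
    if isinstance(value, list):
        return [denorm_prices(v, norm_params, denorm_scalar, timeframe) for v in value]
    return value

With helpers like these, price_delta and pivot_prices could share one code path instead of the per-branch handling above.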
@@ -4870,3 +4963,82 @@ class RealTrainingAdapter:
            return ((current_price - entry_price) / entry_price) * 100  # Percentage
        else:  # short
            return ((entry_price - current_price) / entry_price) * 100  # Percentage

    def _denormalize_price_value(self, normalized_value: float, norm_params: Dict, timeframe: str = '1m') -> float:
        """
        Denormalize a single price value using normalization parameters

        Args:
            normalized_value: Normalized price value (0-1 range)
            norm_params: Dictionary of normalization parameters by timeframe
            timeframe: Timeframe to use for denormalization (default: '1m')

        Returns:
            Denormalized price value
        """
        try:
            if timeframe in norm_params:
                params = norm_params[timeframe]
                price_min = params.get('price_min', 0.0)
                price_max = params.get('price_max', 1.0)
                if price_max > price_min:
                    # Denormalize: price = normalized * (max - min) + min
                    return float(normalized_value * (price_max - price_min) + price_min)
            # Fallback: return as-is if no params available
            return float(normalized_value)
        except Exception as e:
            logger.debug(f"Error denormalizing price value: {e}")
            return float(normalized_value)

    def _denormalize_price_array(self, normalized_array: list, norm_params: Dict, timeframe: str = '1m') -> list:
        """
        Denormalize an array of price values using normalization parameters

        Args:
            normalized_array: List of normalized price values (0-1 range)
            norm_params: Dictionary of normalization parameters by timeframe
            timeframe: Timeframe to use for denormalization (default: '1m')

        Returns:
            List of denormalized price values
        """
        try:
            if timeframe in norm_params:
                params = norm_params[timeframe]
                price_min = params.get('price_min', 0.0)
                price_max = params.get('price_max', 1.0)
                if price_max > price_min:
                    # Denormalize each value: price = normalized * (max - min) + min
                    return [float(v * (price_max - price_min) + price_min) if isinstance(v, (int, float)) else v
                            for v in normalized_array]
            # Fallback: return as-is if no params available
            return [float(v) if isinstance(v, (int, float)) else v for v in normalized_array]
        except Exception as e:
            logger.debug(f"Error denormalizing price array: {e}")
            return [float(v) if isinstance(v, (int, float)) else v for v in normalized_array]

    def _denormalize_nested_price_array(self, normalized_array: list, norm_params: Dict, timeframe: str = '1m') -> list:
        """
        Denormalize a nested array of price values (e.g., [[p1, p2, p3], [p4, p5, p6]])

        Args:
            normalized_array: Nested list of normalized price values
            norm_params: Dictionary of normalization parameters by timeframe
            timeframe: Timeframe to use for denormalization (default: '1m')

        Returns:
            Nested list of denormalized price values
        """
        try:
            result = []
            for item in normalized_array:
                if isinstance(item, (list, tuple)):
                    # Recursively denormalize nested arrays
                    result.append(self._denormalize_price_array(list(item), norm_params, timeframe))
                else:
                    # Single value - denormalize it
                    result.append(self._denormalize_price_value(item, norm_params, timeframe) if isinstance(item, (int, float)) else item)
            return result
        except Exception as e:
            logger.debug(f"Error denormalizing nested price array: {e}")
            return normalized_array
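A short usage example for the helpers above (values are illustrative, and `adapter` stands for a RealTrainingAdapter instance):

norm_params = {'1m': {'price_min': 3050.0, 'price_max': 3200.0}}

adapter._denormalize_price_value(0.5, norm_params)                       # -> 3125.0
adapter._denormalize_nested_price_array([[0.0, 0.5, 1.0]], norm_params)  # -> [[3050.0, 3125.0, 3200.0]]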
ANNOTATE/core/we need to fully move the Inference Trai (new file, 5 lines)
@@ -0,0 +1,5 @@
We need to fully move the Inference Training Coordinator functions into the Orchestrator - both classes have overlapping responsibilities and only one should exist.

InferenceFrameReference should also live in core/data_models.py.

We do not need a core folder in the ANNOTATE app. We should refactor and move those classes into the main /core folder. This is a design flaw: we should naturally have only one "core". The purpose of the ANNOTATE app is to provide a UI for creating test cases and annotating data, and also for running inference and training. All implementations should live in the main system and only be referenced and used from the ANNOTATE app.
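If InferenceFrameReference does move to core/data_models.py, one possible shape - inferred from how the reference is described in the diff above (a timestamp range plus norm_params instead of ~600 copied candles), not the current definition - is:

from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict

@dataclass
class InferenceFrameReference:
    """Reference to the market-data frame used for an inference (field names are assumptions)."""
    symbol: str
    timeframe: str
    start_timestamp: datetime                      # start of the candle range stored in DuckDB
    end_timestamp: datetime                        # end of the candle range stored in DuckDB
    norm_params: Dict[str, Dict[str, float]] = field(default_factory=dict)
    prediction: Dict[str, Any] = field(default_factory=dict)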