wip
@@ -32,18 +32,20 @@ except ImportError:
 logger = logging.getLogger(__name__)
 
 
-def parse_timestamp_to_utc(timestamp_str: str) -> datetime:
+def parse_timestamp_to_utc(timestamp_str) -> datetime:
     """
     Unified timestamp parser that handles all formats and ensures UTC timezone.
 
     Handles:
+    - pandas Timestamp objects
+    - datetime objects
     - ISO format with timezone: '2025-10-27T14:00:00+00:00'
     - ISO format with Z: '2025-10-27T14:00:00Z'
     - Space-separated with seconds: '2025-10-27 14:00:00'
     - Space-separated without seconds: '2025-10-27 14:00'
 
     Args:
-        timestamp_str: Timestamp string in various formats
+        timestamp_str: Timestamp string, pandas Timestamp, or datetime object
 
     Returns:
         Timezone-aware datetime object in UTC
@@ -51,6 +53,23 @@ def parse_timestamp_to_utc(timestamp_str: str) -> datetime:
     Raises:
         ValueError: If timestamp cannot be parsed
     """
+    # Handle pandas Timestamp objects
+    if hasattr(timestamp_str, 'to_pydatetime'):
+        dt = timestamp_str.to_pydatetime()
+        if dt.tzinfo is None:
+            dt = dt.replace(tzinfo=timezone.utc)
+        return dt
+
+    # Handle datetime objects directly
+    if isinstance(timestamp_str, datetime):
+        if timestamp_str.tzinfo is None:
+            return timestamp_str.replace(tzinfo=timezone.utc)
+        return timestamp_str
+
+    # Convert to string if not already
+    if not isinstance(timestamp_str, str):
+        timestamp_str = str(timestamp_str)
+
     if not timestamp_str:
         raise ValueError("Empty timestamp string")
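For reference, a few illustrative calls the relaxed signature is meant to accept; this assumes pandas is installed and parse_timestamp_to_utc is in scope (the calls themselves are not part of the diff):

    from datetime import datetime, timezone
    import pandas as pd

    # Each call should return a timezone-aware datetime pinned to UTC.
    parse_timestamp_to_utc('2025-10-27T14:00:00Z')                        # ISO with Z suffix
    parse_timestamp_to_utc('2025-10-27 14:00')                            # space-separated, no seconds
    parse_timestamp_to_utc(pd.Timestamp('2025-10-27 14:00:00'))           # naive pandas Timestamp, assumed UTC
    parse_timestamp_to_utc(datetime(2025, 10, 27, 14, tzinfo=timezone.utc))  # already tz-aware, returned as-is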
@@ -2445,7 +2464,8 @@ class RealTrainingAdapter:
     def start_realtime_inference(self, model_name: str, symbol: str, data_provider,
                                  enable_live_training: bool = True,
                                  train_every_candle: bool = False,
-                                 timeframe: str = '1m') -> str:
+                                 timeframe: str = '1m',
+                                 training_strategy = None) -> str:
         """
         Start real-time inference using orchestrator's REAL prediction methods
@@ -2453,9 +2473,10 @@ class RealTrainingAdapter:
             model_name: Name of model to use for inference
             symbol: Trading symbol
             data_provider: Data provider for market data
-            enable_live_training: If True, automatically train on L2 pivots
-            train_every_candle: If True, train on every new candle (computationally expensive)
+            enable_live_training: If True, automatically train (deprecated - use training_strategy)
+            train_every_candle: If True, train on every candle (deprecated - use training_strategy)
             timeframe: Timeframe for candle-based training (default: 1m)
+            training_strategy: TrainingStrategyManager for making training decisions
 
         Returns:
             inference_id: Unique ID for this inference session
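A sketch of how a caller might pass the new keyword; adapter, data_provider and strategy are app-level objects assumed to exist, and the symbol is illustrative:

    def launch_inference(adapter, data_provider, strategy):
        # The strategy manager now drives training decisions instead of the two boolean flags.
        return adapter.start_realtime_inference(
            model_name='Transformer',
            symbol='ETH/USDT',
            data_provider=data_provider,
            timeframe='1m',
            training_strategy=strategy,
        )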
@@ -2482,6 +2503,8 @@ class RealTrainingAdapter:
             'train_every_candle': train_every_candle,
             'timeframe': timeframe,
             'data_provider': data_provider,
+            'training_strategy': training_strategy,  # Strategy manager for training decisions
+            'pending_action': None,  # Action to train on (set by strategy manager)
             'metrics': {
                 'accuracy': 0.0,
                 'loss': 0.0,
@@ -2585,10 +2608,18 @@ class RealTrainingAdapter:
             # Extract action
             action_probs = outputs.get('action_probs')
             if action_probs is not None:
-                action_idx = torch.argmax(action_probs, dim=-1).item()
-                confidence = action_probs[0, action_idx].item()
+                # Handle different tensor shapes: [batch, 3] or [3]
+                if action_probs.dim() == 1:
+                    # Shape [3] - single prediction
+                    action_idx = torch.argmax(action_probs, dim=0).item()
+                    confidence = action_probs[action_idx].item()
+                else:
+                    # Shape [batch, 3] - take first batch item
+                    action_idx = torch.argmax(action_probs[0], dim=0).item()
+                    confidence = action_probs[0, action_idx].item()
 
-                actions = ['BUY', 'SELL', 'HOLD']
+                # Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
+                actions = ['HOLD', 'BUY', 'SELL']
                 action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
 
             # Handle predicted candles - DENORMALIZE them
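A standalone illustration of why the shape check matters; the tensors below are made up, but the indexing mirrors the block above:

    import torch

    def extract_action(action_probs):
        # Accept either a [3] vector or a [batch, 3] matrix of probabilities.
        if action_probs.dim() == 1:
            idx = torch.argmax(action_probs, dim=0).item()
            conf = action_probs[idx].item()
        else:
            idx = torch.argmax(action_probs[0], dim=0).item()
            conf = action_probs[0, idx].item()
        return ['HOLD', 'BUY', 'SELL'][idx], conf

    print(extract_action(torch.tensor([0.1, 0.7, 0.2])))    # -> ('BUY', ~0.7)
    print(extract_action(torch.tensor([[0.2, 0.1, 0.7]])))  # -> ('SELL', ~0.7)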
@@ -2613,21 +2644,29 @@ class RealTrainingAdapter:
                     # Note: raw_candle[0] is the list of 5 values
                     candle_values = raw_candle[0]
 
+                    # Ensure all values are Python floats (not numpy scalars or tensors)
+                    def to_float(v):
+                        if hasattr(v, 'item'):
+                            return float(v.item())
+                        return float(v)
+
                     denorm_candle = [
-                        candle_values[0] * (price_max - price_min) + price_min,  # Open
-                        candle_values[1] * (price_max - price_min) + price_min,  # High
-                        candle_values[2] * (price_max - price_min) + price_min,  # Low
-                        candle_values[3] * (price_max - price_min) + price_min,  # Close
-                        candle_values[4] * (vol_max - vol_min) + vol_min  # Volume
+                        to_float(candle_values[0] * (price_max - price_min) + price_min),  # Open
+                        to_float(candle_values[1] * (price_max - price_min) + price_min),  # High
+                        to_float(candle_values[2] * (price_max - price_min) + price_min),  # Low
+                        to_float(candle_values[3] * (price_max - price_min) + price_min),  # Close
+                        to_float(candle_values[4] * (vol_max - vol_min) + vol_min)  # Volume
                     ]
                     predicted_candles_denorm[tf] = denorm_candle
 
-            # Calculate predicted price from candle close
+            # Calculate predicted price from candle close (ensure Python float)
             predicted_price = None
             if '1m' in predicted_candles_denorm:
-                predicted_price = predicted_candles_denorm['1m'][3]  # Close price
+                close_val = predicted_candles_denorm['1m'][3]
+                predicted_price = float(close_val.item() if hasattr(close_val, 'item') else close_val)
             elif '1s' in predicted_candles_denorm:
-                predicted_price = predicted_candles_denorm['1s'][3]
+                close_val = predicted_candles_denorm['1s'][3]
+                predicted_price = float(close_val.item() if hasattr(close_val, 'item') else close_val)
             elif outputs.get('price_prediction') is not None:
                 # Fallback to price_prediction head if available (normalized)
                 # This would need separate denormalization based on reference price
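The candle values come out of the model min-max normalized to [0, 1]; the denormalization above simply inverts that scaling. A minimal sketch with made-up bounds:

    # price_min/price_max and vol_min/vol_max are the per-symbol scaling bounds
    # used at normalization time; the values here are illustrative.
    def denormalize_candle(norm_ohlcv, price_min, price_max, vol_min, vol_max):
        o, h, l, c, v = (float(x) for x in norm_ohlcv)   # force plain Python floats
        price_range = price_max - price_min
        return [o * price_range + price_min,
                h * price_range + price_min,
                l * price_range + price_min,
                c * price_range + price_min,
                v * (vol_max - vol_min) + vol_min]

    print(denormalize_candle([0.5, 0.75, 0.25, 0.5, 0.2], 100.0, 200.0, 0.0, 10.0))
    # -> [150.0, 175.0, 125.0, 150.0, 2.0]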
@@ -2755,42 +2794,61 @@ class RealTrainingAdapter:
             logger.debug(traceback.format_exc())
             return None, None
 
-    def _train_on_new_candle(self, session: Dict, symbol: str, timeframe: str, data_provider):
-        """Train model on new candle when it closes"""
+    def _train_on_new_candle(self, session: Dict, symbol: str, timeframe: str, data_provider) -> Dict:
+        """
+        Train model on new candle - Pure model interface with NO business logic
+
+        Args:
+            session: Training session containing pending_action set by app
+            symbol: Trading symbol
+            timeframe: Timeframe for training
+            data_provider: Data provider for fetching candles
+
+        Returns:
+            Dict with training metrics: {loss, accuracy, training_steps}
+        """
         try:
-            # Get latest candle
+            # Get latest candles
             df = data_provider.get_historical_data(symbol, timeframe, limit=2)
             if df is None or len(df) < 2:
-                return
+                return {'success': False, 'error': 'Insufficient data'}
 
             # Check if we have a new candle
             latest_candle_time = df.index[-1]
             if session['last_candle_time'] == latest_candle_time:
-                return  # Same candle, no training needed
+                return {'success': False, 'error': 'Same candle, no training needed'}
 
             logger.debug(f"New candle detected: {latest_candle_time} (last: {session['last_candle_time']})")
             session['last_candle_time'] = latest_candle_time
 
-            # Get the completed candle (second to last)
+            # Get the completed candle (second to last) and next candle
            completed_candle = df.iloc[-2]
            next_candle = df.iloc[-1]
 
-            # Calculate if the prediction would have been correct
+            # Get action from session (set by app's training strategy)
+            action_label = session.get('pending_action')
+            if not action_label:
+                return {'success': False, 'error': 'No pending_action in session'}
+
+            # Fetch market state for training
+            market_state = self._fetch_market_state_for_candle(symbol, completed_candle.name, data_provider)
+
             # Calculate price change
             price_change = (next_candle['close'] - completed_candle['close']) / completed_candle['close']
 
             # Create training sample
             training_sample = {
                 'symbol': symbol,
                 'timestamp': completed_candle.name,
-                'market_state': self._fetch_market_state_for_candle(symbol, completed_candle.name, data_provider),
-                'action': 'BUY' if price_change > 0.001 else ('SELL' if price_change < -0.001 else 'HOLD'),
+                'market_state': market_state,
+                'action': action_label,
                 'entry_price': float(completed_candle['close']),
                 'exit_price': float(next_candle['close']),
                 'profit_loss_pct': price_change * 100,
-                'direction': 'LONG' if price_change > 0 else 'SHORT'
+                'direction': 'LONG' if action_label == 'BUY' else ('SHORT' if action_label == 'SELL' else 'HOLD')
             }
 
-            # Train on this sample
+            # Train based on model type
             model_name = session['model_name']
             if model_name == 'Transformer':
                 self._train_transformer_on_sample(training_sample)
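Because the method above no longer decides what to train on, some caller has to set pending_action first. A hypothetical glue function (names are illustrative; the real call site is in the inference loop later in this diff):

    def train_if_decided(adapter, session, data_provider, action_label,
                         symbol='ETH/USDT', timeframe='1m'):
        session['pending_action'] = action_label   # decision made outside the adapter
        result = adapter._train_on_new_candle(session, symbol, timeframe, data_provider)
        if result.get('success'):
            print(f"loss={result['loss']:.4f} accuracy={result['accuracy']:.2%}")
        else:
            print(f"skipped: {result['error']}")
        return result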
@@ -2801,15 +2859,25 @@ class RealTrainingAdapter:
                 session['metrics']['accuracy'] = sum(self.realtime_training_metrics['accuracies']) / len(self.realtime_training_metrics['accuracies'])
                 session['metrics']['steps'] = self.realtime_training_metrics['total_steps']
 
-                logger.info(f"Trained on candle: {symbol} {timeframe} @ {completed_candle.name} (change: {price_change:+.2%})")
+                logger.info(f"Trained on candle: {symbol} {timeframe} @ {completed_candle.name} action={action_label} (change: {price_change:+.2%})")
+
+                return {
+                    'success': True,
+                    'loss': session['metrics']['loss'],
+                    'accuracy': session['metrics']['accuracy'],
+                    'training_steps': session['metrics']['steps']
+                }
+
+            return {'success': False, 'error': f'Unsupported model: {model_name}'}
 
         except Exception as e:
             logger.warning(f"Error training on new candle: {e}")
+            return {'success': False, 'error': str(e)}
 
     def _fetch_market_state_for_candle(self, symbol: str, timestamp, data_provider) -> Dict:
-        """Fetch market state at a specific candle time"""
+        """Fetch market state with OHLCV data for model training"""
         try:
-            # Simplified version - get recent data
+            # Get market state with OHLCV data only (NO business logic)
             market_state = {'timeframes': {}, 'secondary_timeframes': {}}
 
             for tf in ['1s', '1m', '1h', '1d']:
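A rough sketch of the payload the loop above assembles, assuming get_historical_data returns a pandas DataFrame with open/high/low/close/volume columns; the real helper in this file may add more fields:

    def build_market_state(data_provider, symbol, timeframes=('1s', '1m', '1h', '1d'), limit=200):
        market_state = {'timeframes': {}, 'secondary_timeframes': {}}
        for tf in timeframes:
            df = data_provider.get_historical_data(symbol, tf, limit=limit)
            if df is None or df.empty:
                continue
            market_state['timeframes'][tf] = {
                col: [float(v) for v in df[col]]   # plain floats, JSON-safe
                for col in ('open', 'high', 'low', 'close', 'volume')
            }
        return market_state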
@@ -3192,8 +3260,15 @@ class RealTrainingAdapter:
             # Extract action prediction
             action_probs = outputs.get('action_probs')
             if action_probs is not None:
-                action_idx = torch.argmax(action_probs, dim=-1).item()
-                confidence = action_probs[0, action_idx].item()
+                # Handle different tensor shapes: [batch, 3] or [3]
+                if action_probs.dim() == 1:
+                    # Shape [3] - single prediction
+                    action_idx = torch.argmax(action_probs, dim=0).item()
+                    confidence = action_probs[action_idx].item()
+                else:
+                    # Shape [batch, 3] - take first batch item
+                    action_idx = torch.argmax(action_probs[0], dim=0).item()
+                    confidence = action_probs[0, action_idx].item()
 
                 # Map to BUY/SELL/HOLD
                 actions = ['BUY', 'SELL', 'HOLD']
@@ -3291,29 +3366,125 @@ class RealTrainingAdapter:
                     else:
                         logger.info(f"Live Signal (NOT executed): {signal['action']} @ {signal['price']:.2f} (conf: {signal['confidence']:.2f}) - {self._get_rejection_reason(session, signal)}")
 
-                # Store prediction for visualization WITH predicted_candle data for ghost candles
+                # Store prediction for visualization (INCLUDE predicted_candle for ghost candles!)
                 if self.orchestrator and hasattr(self.orchestrator, 'store_transformer_prediction'):
-                    stored_prediction = {
+                    prediction_data = {
                         'timestamp': datetime.now(timezone.utc).isoformat(),
                         'current_price': current_price,
-                        'predicted_price': prediction.get('predicted_price', current_price * (1.01 if prediction['action'] == 'BUY' else 0.99)),
+                        'predicted_price': prediction.get('predicted_price', current_price),
                         'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0,
                         'confidence': prediction['confidence'],
                         'action': prediction['action'],
                         'horizon_minutes': 10,
                         'source': 'live_inference'
                     }
-                    # Include predicted_candle for ghost candle visualization
+
+                    # Include REAL predicted_candle from model (for ghost candles)
                     if 'predicted_candle' in prediction and prediction['predicted_candle']:
-                        stored_prediction['predicted_candle'] = prediction['predicted_candle']
-                        stored_prediction['next_candles'] = prediction['predicted_candle']  # Alias for compatibility
-                        logger.debug(f"Stored prediction with {len(prediction['predicted_candle'])} timeframe candles")
-
-                    self.orchestrator.store_transformer_prediction(symbol, stored_prediction)
+                        # Ensure predicted_candle values are Python native types (not tensors)
+                        predicted_candle_clean = {}
+                        for tf, candle_data in prediction['predicted_candle'].items():
+                            if isinstance(candle_data, (list, tuple)):
+                                # Convert list/tuple elements to Python scalars
+                                predicted_candle_clean[tf] = [
+                                    float(v.item() if hasattr(v, 'item') else v)
+                                    for v in candle_data
+                                ]
+                            elif hasattr(candle_data, 'tolist'):
+                                # Tensor array - convert to list
+                                predicted_candle_clean[tf] = [float(v) for v in candle_data.tolist()]
+                            else:
+                                predicted_candle_clean[tf] = candle_data
+
+                        prediction_data['predicted_candle'] = predicted_candle_clean
+
+                        # Use actual predicted price from candle close (ensure it's a Python float)
+                        predicted_price_val = None
+                        if '1m' in predicted_candle_clean:
+                            close_val = predicted_candle_clean['1m'][3]
+                            predicted_price_val = float(close_val.item() if hasattr(close_val, 'item') else close_val)
+                        elif '1s' in predicted_candle_clean:
+                            close_val = predicted_candle_clean['1s'][3]
+                            predicted_price_val = float(close_val.item() if hasattr(close_val, 'item') else close_val)
+
+                        if predicted_price_val is not None:
+                            prediction_data['predicted_price'] = predicted_price_val
+                            prediction_data['price_change'] = ((predicted_price_val - current_price) / current_price) * 100
+                        else:
+                            prediction_data['predicted_price'] = prediction.get('predicted_price', current_price)
+                            prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0
+                    else:
+                        # Fallback to estimated price if no candle prediction
+                        prediction_data['predicted_price'] = prediction.get('predicted_price', current_price * (1.01 if prediction['action'] == 'BUY' else 0.99))
+                        prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0
+
+                    # Include trend_vector if available (convert tensors to Python types)
+                    if 'trend_vector' in prediction:
+                        trend_vec = prediction['trend_vector']
+                        # Convert any tensors to Python native types
+                        if isinstance(trend_vec, dict):
+                            serialized_trend = {}
+                            for key, value in trend_vec.items():
+                                if hasattr(value, 'numel'):  # Tensor
+                                    if value.numel() == 1:  # Scalar tensor
+                                        serialized_trend[key] = value.item()
+                                    else:  # Multi-element tensor
+                                        serialized_trend[key] = value.detach().cpu().tolist()
+                                elif hasattr(value, 'tolist'):  # Other array-like
+                                    serialized_trend[key] = value.tolist()
+                                elif isinstance(value, (list, tuple)):
+                                    # Recursively convert list/tuple of tensors
+                                    serialized_trend[key] = []
+                                    for v in value:
+                                        if hasattr(v, 'numel'):
+                                            if v.numel() == 1:
+                                                serialized_trend[key].append(v.item())
+                                            else:
+                                                serialized_trend[key].append(v.detach().cpu().tolist())
+                                        elif hasattr(v, 'tolist'):
+                                            serialized_trend[key].append(v.tolist())
+                                        else:
+                                            serialized_trend[key].append(v)
+                                else:
+                                    serialized_trend[key] = value
+                            prediction_data['trend_vector'] = serialized_trend
+                        else:
+                            prediction_data['trend_vector'] = trend_vec
+
+                    self.orchestrator.store_transformer_prediction(symbol, prediction_data)
 
-                # Per-candle training mode
-                if train_every_candle:
-                    self._train_on_new_candle(session, symbol, timeframe, data_provider)
+                # Training decision using strategy manager
+                training_strategy = session.get('training_strategy')
+                if training_strategy and training_strategy.mode != 'none':
+                    # Get pivot markers for training decision
+                    pivot_markers = {}
+                    if hasattr(training_strategy, 'dashboard') and training_strategy.dashboard:
+                        try:
+                            df = data_provider.get_historical_data(symbol, timeframe, limit=200)
+                            if df is not None and len(df) >= 10:
+                                pivot_markers = training_strategy.dashboard._get_pivot_markers_for_timeframe(symbol, timeframe, df)
+                        except Exception as e:
+                            logger.debug(f"Could not get pivot markers: {e}")
+
+                    # Get current candle timestamp
+                    df_current = data_provider.get_historical_data(symbol, timeframe, limit=1)
+                    if df_current is not None and len(df_current) > 0:
+                        current_timestamp = df_current.index[-1]
+
+                        # Ask strategy manager if we should train
+                        should_train, action_data = training_strategy.should_train_on_candle(
+                            symbol, timeframe, current_timestamp, pivot_markers
+                        )
+
+                        if should_train and action_data:
+                            # Set action in session for training
+                            session['pending_action'] = action_data['action']
+
+                            # Call pure training method
+                            train_result = self._train_on_new_candle(session, symbol, timeframe, data_provider)
+
+                            if train_result.get('success'):
+                                logger.info(f"Training completed: {action_data['action']} (reason: {action_data.get('reason', 'unknown')})")
 
                 # Sleep based on timeframe
                 sleep_time = self._get_sleep_time_for_timeframe(timeframe)
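The per-key conversion above exists because torch tensors are not JSON-serializable. A generic recursive helper would do the same job; this is a hypothetical alternative, not what the diff adds:

    import torch

    def to_native(value):
        # Recursively convert tensors and array-likes into plain Python types.
        if hasattr(value, 'numel'):          # torch.Tensor
            return value.item() if value.numel() == 1 else value.detach().cpu().tolist()
        if hasattr(value, 'tolist'):         # numpy arrays / scalars
            return value.tolist()
        if isinstance(value, dict):
            return {k: to_native(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [to_native(v) for v in value]
        return value

    print(to_native({'angle': torch.tensor(0.3), 'vec': torch.tensor([1.0, 2.0])}))
    # -> {'angle': ~0.3, 'vec': [1.0, 2.0]}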
@@ -3321,6 +3492,8 @@ class RealTrainingAdapter:
 
             except Exception as e:
                 logger.error(f"Error in inference loop: {e}")
+                import traceback
+                logger.error(f"Traceback: {traceback.format_exc()}")
                 time.sleep(5)
 
         logger.info(f"Inference loop stopped: {inference_id}")