wip
This commit is contained in:
@@ -32,18 +32,20 @@ except ImportError:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def parse_timestamp_to_utc(timestamp_str: str) -> datetime:
|
def parse_timestamp_to_utc(timestamp_str) -> datetime:
|
||||||
"""
|
"""
|
||||||
Unified timestamp parser that handles all formats and ensures UTC timezone.
|
Unified timestamp parser that handles all formats and ensures UTC timezone.
|
||||||
|
|
||||||
Handles:
|
Handles:
|
||||||
|
- pandas Timestamp objects
|
||||||
|
- datetime objects
|
||||||
- ISO format with timezone: '2025-10-27T14:00:00+00:00'
|
- ISO format with timezone: '2025-10-27T14:00:00+00:00'
|
||||||
- ISO format with Z: '2025-10-27T14:00:00Z'
|
- ISO format with Z: '2025-10-27T14:00:00Z'
|
||||||
- Space-separated with seconds: '2025-10-27 14:00:00'
|
- Space-separated with seconds: '2025-10-27 14:00:00'
|
||||||
- Space-separated without seconds: '2025-10-27 14:00'
|
- Space-separated without seconds: '2025-10-27 14:00'
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
timestamp_str: Timestamp string in various formats
|
timestamp_str: Timestamp string, pandas Timestamp, or datetime object
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Timezone-aware datetime object in UTC
|
Timezone-aware datetime object in UTC
|
||||||
@@ -51,6 +53,23 @@ def parse_timestamp_to_utc(timestamp_str: str) -> datetime:
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If timestamp cannot be parsed
|
ValueError: If timestamp cannot be parsed
|
||||||
"""
|
"""
|
||||||
|
# Handle pandas Timestamp objects
|
||||||
|
if hasattr(timestamp_str, 'to_pydatetime'):
|
||||||
|
dt = timestamp_str.to_pydatetime()
|
||||||
|
if dt.tzinfo is None:
|
||||||
|
dt = dt.replace(tzinfo=timezone.utc)
|
||||||
|
return dt
|
||||||
|
|
||||||
|
# Handle datetime objects directly
|
||||||
|
if isinstance(timestamp_str, datetime):
|
||||||
|
if timestamp_str.tzinfo is None:
|
||||||
|
return timestamp_str.replace(tzinfo=timezone.utc)
|
||||||
|
return timestamp_str
|
||||||
|
|
||||||
|
# Convert to string if not already
|
||||||
|
if not isinstance(timestamp_str, str):
|
||||||
|
timestamp_str = str(timestamp_str)
|
||||||
|
|
||||||
if not timestamp_str:
|
if not timestamp_str:
|
||||||
raise ValueError("Empty timestamp string")
|
raise ValueError("Empty timestamp string")
|
||||||
|
|
||||||
@@ -2445,7 +2464,8 @@ class RealTrainingAdapter:
|
|||||||
def start_realtime_inference(self, model_name: str, symbol: str, data_provider,
|
def start_realtime_inference(self, model_name: str, symbol: str, data_provider,
|
||||||
enable_live_training: bool = True,
|
enable_live_training: bool = True,
|
||||||
train_every_candle: bool = False,
|
train_every_candle: bool = False,
|
||||||
timeframe: str = '1m') -> str:
|
timeframe: str = '1m',
|
||||||
|
training_strategy = None) -> str:
|
||||||
"""
|
"""
|
||||||
Start real-time inference using orchestrator's REAL prediction methods
|
Start real-time inference using orchestrator's REAL prediction methods
|
||||||
|
|
||||||
@@ -2453,9 +2473,10 @@ class RealTrainingAdapter:
|
|||||||
model_name: Name of model to use for inference
|
model_name: Name of model to use for inference
|
||||||
symbol: Trading symbol
|
symbol: Trading symbol
|
||||||
data_provider: Data provider for market data
|
data_provider: Data provider for market data
|
||||||
enable_live_training: If True, automatically train on L2 pivots
|
enable_live_training: If True, automatically train (deprecated - use training_strategy)
|
||||||
train_every_candle: If True, train on every new candle (computationally expensive)
|
train_every_candle: If True, train on every candle (deprecated - use training_strategy)
|
||||||
timeframe: Timeframe for candle-based training (default: 1m)
|
timeframe: Timeframe for candle-based training (default: 1m)
|
||||||
|
training_strategy: TrainingStrategyManager for making training decisions
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
inference_id: Unique ID for this inference session
|
inference_id: Unique ID for this inference session
|
||||||
@@ -2482,6 +2503,8 @@ class RealTrainingAdapter:
|
|||||||
'train_every_candle': train_every_candle,
|
'train_every_candle': train_every_candle,
|
||||||
'timeframe': timeframe,
|
'timeframe': timeframe,
|
||||||
'data_provider': data_provider,
|
'data_provider': data_provider,
|
||||||
|
'training_strategy': training_strategy, # Strategy manager for training decisions
|
||||||
|
'pending_action': None, # Action to train on (set by strategy manager)
|
||||||
'metrics': {
|
'metrics': {
|
||||||
'accuracy': 0.0,
|
'accuracy': 0.0,
|
||||||
'loss': 0.0,
|
'loss': 0.0,
|
||||||
@@ -2585,10 +2608,18 @@ class RealTrainingAdapter:
|
|||||||
# Extract action
|
# Extract action
|
||||||
action_probs = outputs.get('action_probs')
|
action_probs = outputs.get('action_probs')
|
||||||
if action_probs is not None:
|
if action_probs is not None:
|
||||||
action_idx = torch.argmax(action_probs, dim=-1).item()
|
# Handle different tensor shapes: [batch, 3] or [3]
|
||||||
confidence = action_probs[0, action_idx].item()
|
if action_probs.dim() == 1:
|
||||||
|
# Shape [3] - single prediction
|
||||||
|
action_idx = torch.argmax(action_probs, dim=0).item()
|
||||||
|
confidence = action_probs[action_idx].item()
|
||||||
|
else:
|
||||||
|
# Shape [batch, 3] - take first batch item
|
||||||
|
action_idx = torch.argmax(action_probs[0], dim=0).item()
|
||||||
|
confidence = action_probs[0, action_idx].item()
|
||||||
|
|
||||||
actions = ['BUY', 'SELL', 'HOLD']
|
# Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
|
||||||
|
actions = ['HOLD', 'BUY', 'SELL']
|
||||||
action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
|
action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
|
||||||
|
|
||||||
# Handle predicted candles - DENORMALIZE them
|
# Handle predicted candles - DENORMALIZE them
|
||||||
@@ -2613,21 +2644,29 @@ class RealTrainingAdapter:
|
|||||||
# Note: raw_candle[0] is the list of 5 values
|
# Note: raw_candle[0] is the list of 5 values
|
||||||
candle_values = raw_candle[0]
|
candle_values = raw_candle[0]
|
||||||
|
|
||||||
|
# Ensure all values are Python floats (not numpy scalars or tensors)
|
||||||
|
def to_float(v):
|
||||||
|
if hasattr(v, 'item'):
|
||||||
|
return float(v.item())
|
||||||
|
return float(v)
|
||||||
|
|
||||||
denorm_candle = [
|
denorm_candle = [
|
||||||
candle_values[0] * (price_max - price_min) + price_min, # Open
|
to_float(candle_values[0] * (price_max - price_min) + price_min), # Open
|
||||||
candle_values[1] * (price_max - price_min) + price_min, # High
|
to_float(candle_values[1] * (price_max - price_min) + price_min), # High
|
||||||
candle_values[2] * (price_max - price_min) + price_min, # Low
|
to_float(candle_values[2] * (price_max - price_min) + price_min), # Low
|
||||||
candle_values[3] * (price_max - price_min) + price_min, # Close
|
to_float(candle_values[3] * (price_max - price_min) + price_min), # Close
|
||||||
candle_values[4] * (vol_max - vol_min) + vol_min # Volume
|
to_float(candle_values[4] * (vol_max - vol_min) + vol_min) # Volume
|
||||||
]
|
]
|
||||||
predicted_candles_denorm[tf] = denorm_candle
|
predicted_candles_denorm[tf] = denorm_candle
|
||||||
|
|
||||||
# Calculate predicted price from candle close
|
# Calculate predicted price from candle close (ensure Python float)
|
||||||
predicted_price = None
|
predicted_price = None
|
||||||
if '1m' in predicted_candles_denorm:
|
if '1m' in predicted_candles_denorm:
|
||||||
predicted_price = predicted_candles_denorm['1m'][3] # Close price
|
close_val = predicted_candles_denorm['1m'][3]
|
||||||
|
predicted_price = float(close_val.item() if hasattr(close_val, 'item') else close_val)
|
||||||
elif '1s' in predicted_candles_denorm:
|
elif '1s' in predicted_candles_denorm:
|
||||||
predicted_price = predicted_candles_denorm['1s'][3]
|
close_val = predicted_candles_denorm['1s'][3]
|
||||||
|
predicted_price = float(close_val.item() if hasattr(close_val, 'item') else close_val)
|
||||||
elif outputs.get('price_prediction') is not None:
|
elif outputs.get('price_prediction') is not None:
|
||||||
# Fallback to price_prediction head if available (normalized)
|
# Fallback to price_prediction head if available (normalized)
|
||||||
# This would need separate denormalization based on reference price
|
# This would need separate denormalization based on reference price
|
||||||
@@ -2755,42 +2794,61 @@ class RealTrainingAdapter:
|
|||||||
logger.debug(traceback.format_exc())
|
logger.debug(traceback.format_exc())
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
def _train_on_new_candle(self, session: Dict, symbol: str, timeframe: str, data_provider):
|
def _train_on_new_candle(self, session: Dict, symbol: str, timeframe: str, data_provider) -> Dict:
|
||||||
"""Train model on new candle when it closes"""
|
"""
|
||||||
|
Train model on new candle - Pure model interface with NO business logic
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Training session containing pending_action set by app
|
||||||
|
symbol: Trading symbol
|
||||||
|
timeframe: Timeframe for training
|
||||||
|
data_provider: Data provider for fetching candles
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with training metrics: {loss, accuracy, training_steps}
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Get latest candle
|
# Get latest candles
|
||||||
df = data_provider.get_historical_data(symbol, timeframe, limit=2)
|
df = data_provider.get_historical_data(symbol, timeframe, limit=2)
|
||||||
if df is None or len(df) < 2:
|
if df is None or len(df) < 2:
|
||||||
return
|
return {'success': False, 'error': 'Insufficient data'}
|
||||||
|
|
||||||
# Check if we have a new candle
|
# Check if we have a new candle
|
||||||
latest_candle_time = df.index[-1]
|
latest_candle_time = df.index[-1]
|
||||||
if session['last_candle_time'] == latest_candle_time:
|
if session['last_candle_time'] == latest_candle_time:
|
||||||
return # Same candle, no training needed
|
return {'success': False, 'error': 'Same candle, no training needed'}
|
||||||
|
|
||||||
logger.debug(f"New candle detected: {latest_candle_time} (last: {session['last_candle_time']})")
|
logger.debug(f"New candle detected: {latest_candle_time} (last: {session['last_candle_time']})")
|
||||||
session['last_candle_time'] = latest_candle_time
|
session['last_candle_time'] = latest_candle_time
|
||||||
|
|
||||||
# Get the completed candle (second to last)
|
# Get the completed candle (second to last) and next candle
|
||||||
completed_candle = df.iloc[-2]
|
completed_candle = df.iloc[-2]
|
||||||
next_candle = df.iloc[-1]
|
next_candle = df.iloc[-1]
|
||||||
|
|
||||||
# Calculate if the prediction would have been correct
|
# Get action from session (set by app's training strategy)
|
||||||
|
action_label = session.get('pending_action')
|
||||||
|
if not action_label:
|
||||||
|
return {'success': False, 'error': 'No pending_action in session'}
|
||||||
|
|
||||||
|
# Fetch market state for training
|
||||||
|
market_state = self._fetch_market_state_for_candle(symbol, completed_candle.name, data_provider)
|
||||||
|
|
||||||
|
# Calculate price change
|
||||||
price_change = (next_candle['close'] - completed_candle['close']) / completed_candle['close']
|
price_change = (next_candle['close'] - completed_candle['close']) / completed_candle['close']
|
||||||
|
|
||||||
# Create training sample
|
# Create training sample
|
||||||
training_sample = {
|
training_sample = {
|
||||||
'symbol': symbol,
|
'symbol': symbol,
|
||||||
'timestamp': completed_candle.name,
|
'timestamp': completed_candle.name,
|
||||||
'market_state': self._fetch_market_state_for_candle(symbol, completed_candle.name, data_provider),
|
'market_state': market_state,
|
||||||
'action': 'BUY' if price_change > 0.001 else ('SELL' if price_change < -0.001 else 'HOLD'),
|
'action': action_label,
|
||||||
'entry_price': float(completed_candle['close']),
|
'entry_price': float(completed_candle['close']),
|
||||||
'exit_price': float(next_candle['close']),
|
'exit_price': float(next_candle['close']),
|
||||||
'profit_loss_pct': price_change * 100,
|
'profit_loss_pct': price_change * 100,
|
||||||
'direction': 'LONG' if price_change > 0 else 'SHORT'
|
'direction': 'LONG' if action_label == 'BUY' else ('SHORT' if action_label == 'SELL' else 'HOLD')
|
||||||
}
|
}
|
||||||
|
|
||||||
# Train on this sample
|
# Train based on model type
|
||||||
model_name = session['model_name']
|
model_name = session['model_name']
|
||||||
if model_name == 'Transformer':
|
if model_name == 'Transformer':
|
||||||
self._train_transformer_on_sample(training_sample)
|
self._train_transformer_on_sample(training_sample)
|
||||||
@@ -2801,15 +2859,25 @@ class RealTrainingAdapter:
|
|||||||
session['metrics']['accuracy'] = sum(self.realtime_training_metrics['accuracies']) / len(self.realtime_training_metrics['accuracies'])
|
session['metrics']['accuracy'] = sum(self.realtime_training_metrics['accuracies']) / len(self.realtime_training_metrics['accuracies'])
|
||||||
session['metrics']['steps'] = self.realtime_training_metrics['total_steps']
|
session['metrics']['steps'] = self.realtime_training_metrics['total_steps']
|
||||||
|
|
||||||
logger.info(f"Trained on candle: {symbol} {timeframe} @ {completed_candle.name} (change: {price_change:+.2%})")
|
logger.info(f"Trained on candle: {symbol} {timeframe} @ {completed_candle.name} action={action_label} (change: {price_change:+.2%})")
|
||||||
|
|
||||||
|
return {
|
||||||
|
'success': True,
|
||||||
|
'loss': session['metrics']['loss'],
|
||||||
|
'accuracy': session['metrics']['accuracy'],
|
||||||
|
'training_steps': session['metrics']['steps']
|
||||||
|
}
|
||||||
|
|
||||||
|
return {'success': False, 'error': f'Unsupported model: {model_name}'}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error training on new candle: {e}")
|
logger.warning(f"Error training on new candle: {e}")
|
||||||
|
return {'success': False, 'error': str(e)}
|
||||||
|
|
||||||
def _fetch_market_state_for_candle(self, symbol: str, timestamp, data_provider) -> Dict:
|
def _fetch_market_state_for_candle(self, symbol: str, timestamp, data_provider) -> Dict:
|
||||||
"""Fetch market state at a specific candle time"""
|
"""Fetch market state with OHLCV data for model training"""
|
||||||
try:
|
try:
|
||||||
# Simplified version - get recent data
|
# Get market state with OHLCV data only (NO business logic)
|
||||||
market_state = {'timeframes': {}, 'secondary_timeframes': {}}
|
market_state = {'timeframes': {}, 'secondary_timeframes': {}}
|
||||||
|
|
||||||
for tf in ['1s', '1m', '1h', '1d']:
|
for tf in ['1s', '1m', '1h', '1d']:
|
||||||
@@ -3192,8 +3260,15 @@ class RealTrainingAdapter:
|
|||||||
# Extract action prediction
|
# Extract action prediction
|
||||||
action_probs = outputs.get('action_probs')
|
action_probs = outputs.get('action_probs')
|
||||||
if action_probs is not None:
|
if action_probs is not None:
|
||||||
action_idx = torch.argmax(action_probs, dim=-1).item()
|
# Handle different tensor shapes: [batch, 3] or [3]
|
||||||
confidence = action_probs[0, action_idx].item()
|
if action_probs.dim() == 1:
|
||||||
|
# Shape [3] - single prediction
|
||||||
|
action_idx = torch.argmax(action_probs, dim=0).item()
|
||||||
|
confidence = action_probs[action_idx].item()
|
||||||
|
else:
|
||||||
|
# Shape [batch, 3] - take first batch item
|
||||||
|
action_idx = torch.argmax(action_probs[0], dim=0).item()
|
||||||
|
confidence = action_probs[0, action_idx].item()
|
||||||
|
|
||||||
# Map to BUY/SELL/HOLD
|
# Map to BUY/SELL/HOLD
|
||||||
actions = ['BUY', 'SELL', 'HOLD']
|
actions = ['BUY', 'SELL', 'HOLD']
|
||||||
@@ -3291,29 +3366,125 @@ class RealTrainingAdapter:
|
|||||||
else:
|
else:
|
||||||
logger.info(f"Live Signal (NOT executed): {signal['action']} @ {signal['price']:.2f} (conf: {signal['confidence']:.2f}) - {self._get_rejection_reason(session, signal)}")
|
logger.info(f"Live Signal (NOT executed): {signal['action']} @ {signal['price']:.2f} (conf: {signal['confidence']:.2f}) - {self._get_rejection_reason(session, signal)}")
|
||||||
|
|
||||||
# Store prediction for visualization WITH predicted_candle data for ghost candles
|
# Store prediction for visualization (INCLUDE predicted_candle for ghost candles!)
|
||||||
if self.orchestrator and hasattr(self.orchestrator, 'store_transformer_prediction'):
|
if self.orchestrator and hasattr(self.orchestrator, 'store_transformer_prediction'):
|
||||||
stored_prediction = {
|
prediction_data = {
|
||||||
'timestamp': datetime.now(timezone.utc).isoformat(),
|
'timestamp': datetime.now(timezone.utc).isoformat(),
|
||||||
'current_price': current_price,
|
'current_price': current_price,
|
||||||
'predicted_price': prediction.get('predicted_price', current_price * (1.01 if prediction['action'] == 'BUY' else 0.99)),
|
'predicted_price': prediction.get('predicted_price', current_price),
|
||||||
'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0,
|
'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0,
|
||||||
'confidence': prediction['confidence'],
|
'confidence': prediction['confidence'],
|
||||||
'action': prediction['action'],
|
'action': prediction['action'],
|
||||||
'horizon_minutes': 10,
|
'horizon_minutes': 10,
|
||||||
'source': 'live_inference'
|
'source': 'live_inference'
|
||||||
}
|
}
|
||||||
# Include predicted_candle for ghost candle visualization
|
|
||||||
|
# Include REAL predicted_candle from model (for ghost candles)
|
||||||
if 'predicted_candle' in prediction and prediction['predicted_candle']:
|
if 'predicted_candle' in prediction and prediction['predicted_candle']:
|
||||||
stored_prediction['predicted_candle'] = prediction['predicted_candle']
|
# Ensure predicted_candle values are Python native types (not tensors)
|
||||||
stored_prediction['next_candles'] = prediction['predicted_candle'] # Alias for compatibility
|
predicted_candle_clean = {}
|
||||||
logger.debug(f"Stored prediction with {len(prediction['predicted_candle'])} timeframe candles")
|
for tf, candle_data in prediction['predicted_candle'].items():
|
||||||
|
if isinstance(candle_data, (list, tuple)):
|
||||||
|
# Convert list/tuple elements to Python scalars
|
||||||
|
predicted_candle_clean[tf] = [
|
||||||
|
float(v.item() if hasattr(v, 'item') else v)
|
||||||
|
for v in candle_data
|
||||||
|
]
|
||||||
|
elif hasattr(candle_data, 'tolist'):
|
||||||
|
# Tensor array - convert to list
|
||||||
|
predicted_candle_clean[tf] = [float(v) for v in candle_data.tolist()]
|
||||||
|
else:
|
||||||
|
predicted_candle_clean[tf] = candle_data
|
||||||
|
|
||||||
self.orchestrator.store_transformer_prediction(symbol, stored_prediction)
|
prediction_data['predicted_candle'] = predicted_candle_clean
|
||||||
|
|
||||||
# Per-candle training mode
|
# Use actual predicted price from candle close (ensure it's a Python float)
|
||||||
if train_every_candle:
|
predicted_price_val = None
|
||||||
self._train_on_new_candle(session, symbol, timeframe, data_provider)
|
if '1m' in predicted_candle_clean:
|
||||||
|
close_val = predicted_candle_clean['1m'][3]
|
||||||
|
predicted_price_val = float(close_val.item() if hasattr(close_val, 'item') else close_val)
|
||||||
|
elif '1s' in predicted_candle_clean:
|
||||||
|
close_val = predicted_candle_clean['1s'][3]
|
||||||
|
predicted_price_val = float(close_val.item() if hasattr(close_val, 'item') else close_val)
|
||||||
|
|
||||||
|
if predicted_price_val is not None:
|
||||||
|
prediction_data['predicted_price'] = predicted_price_val
|
||||||
|
prediction_data['price_change'] = ((predicted_price_val - current_price) / current_price) * 100
|
||||||
|
else:
|
||||||
|
prediction_data['predicted_price'] = prediction.get('predicted_price', current_price)
|
||||||
|
prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0
|
||||||
|
else:
|
||||||
|
# Fallback to estimated price if no candle prediction
|
||||||
|
prediction_data['predicted_price'] = prediction.get('predicted_price', current_price * (1.01 if prediction['action'] == 'BUY' else 0.99))
|
||||||
|
prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0
|
||||||
|
|
||||||
|
# Include trend_vector if available (convert tensors to Python types)
|
||||||
|
if 'trend_vector' in prediction:
|
||||||
|
trend_vec = prediction['trend_vector']
|
||||||
|
# Convert any tensors to Python native types
|
||||||
|
if isinstance(trend_vec, dict):
|
||||||
|
serialized_trend = {}
|
||||||
|
for key, value in trend_vec.items():
|
||||||
|
if hasattr(value, 'numel'): # Tensor
|
||||||
|
if value.numel() == 1: # Scalar tensor
|
||||||
|
serialized_trend[key] = value.item()
|
||||||
|
else: # Multi-element tensor
|
||||||
|
serialized_trend[key] = value.detach().cpu().tolist()
|
||||||
|
elif hasattr(value, 'tolist'): # Other array-like
|
||||||
|
serialized_trend[key] = value.tolist()
|
||||||
|
elif isinstance(value, (list, tuple)):
|
||||||
|
# Recursively convert list/tuple of tensors
|
||||||
|
serialized_trend[key] = []
|
||||||
|
for v in value:
|
||||||
|
if hasattr(v, 'numel'):
|
||||||
|
if v.numel() == 1:
|
||||||
|
serialized_trend[key].append(v.item())
|
||||||
|
else:
|
||||||
|
serialized_trend[key].append(v.detach().cpu().tolist())
|
||||||
|
elif hasattr(v, 'tolist'):
|
||||||
|
serialized_trend[key].append(v.tolist())
|
||||||
|
else:
|
||||||
|
serialized_trend[key].append(v)
|
||||||
|
else:
|
||||||
|
serialized_trend[key] = value
|
||||||
|
prediction_data['trend_vector'] = serialized_trend
|
||||||
|
else:
|
||||||
|
prediction_data['trend_vector'] = trend_vec
|
||||||
|
|
||||||
|
self.orchestrator.store_transformer_prediction(symbol, prediction_data)
|
||||||
|
|
||||||
|
# Training decision using strategy manager
|
||||||
|
training_strategy = session.get('training_strategy')
|
||||||
|
if training_strategy and training_strategy.mode != 'none':
|
||||||
|
# Get pivot markers for training decision
|
||||||
|
pivot_markers = {}
|
||||||
|
if hasattr(training_strategy, 'dashboard') and training_strategy.dashboard:
|
||||||
|
try:
|
||||||
|
df = data_provider.get_historical_data(symbol, timeframe, limit=200)
|
||||||
|
if df is not None and len(df) >= 10:
|
||||||
|
pivot_markers = training_strategy.dashboard._get_pivot_markers_for_timeframe(symbol, timeframe, df)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not get pivot markers: {e}")
|
||||||
|
|
||||||
|
# Get current candle timestamp
|
||||||
|
df_current = data_provider.get_historical_data(symbol, timeframe, limit=1)
|
||||||
|
if df_current is not None and len(df_current) > 0:
|
||||||
|
current_timestamp = df_current.index[-1]
|
||||||
|
|
||||||
|
# Ask strategy manager if we should train
|
||||||
|
should_train, action_data = training_strategy.should_train_on_candle(
|
||||||
|
symbol, timeframe, current_timestamp, pivot_markers
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_train and action_data:
|
||||||
|
# Set action in session for training
|
||||||
|
session['pending_action'] = action_data['action']
|
||||||
|
|
||||||
|
# Call pure training method
|
||||||
|
train_result = self._train_on_new_candle(session, symbol, timeframe, data_provider)
|
||||||
|
|
||||||
|
if train_result.get('success'):
|
||||||
|
logger.info(f"Training completed: {action_data['action']} (reason: {action_data.get('reason', 'unknown')})")
|
||||||
|
|
||||||
# Sleep based on timeframe
|
# Sleep based on timeframe
|
||||||
sleep_time = self._get_sleep_time_for_timeframe(timeframe)
|
sleep_time = self._get_sleep_time_for_timeframe(timeframe)
|
||||||
@@ -3321,6 +3492,8 @@ class RealTrainingAdapter:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in inference loop: {e}")
|
logger.error(f"Error in inference loop: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
logger.info(f"Inference loop stopped: {inference_id}")
|
logger.info(f"Inference loop stopped: {inference_id}")
|
||||||
|
|||||||
@@ -1,28 +1,5 @@
|
|||||||
{
|
{
|
||||||
"annotations": [
|
"annotations": [
|
||||||
{
|
|
||||||
"annotation_id": "dc35c362-6174-4db4-b4db-8cc58a4ba8e5",
|
|
||||||
"symbol": "ETH/USDT",
|
|
||||||
"timeframe": "1h",
|
|
||||||
"entry": {
|
|
||||||
"timestamp": "2025-10-07 13:00",
|
|
||||||
"price": 4755,
|
|
||||||
"index": 28
|
|
||||||
},
|
|
||||||
"exit": {
|
|
||||||
"timestamp": "2025-10-11 21:00",
|
|
||||||
"price": 3643.33,
|
|
||||||
"index": 63
|
|
||||||
},
|
|
||||||
"direction": "SHORT",
|
|
||||||
"profit_loss_pct": 23.378969505783388,
|
|
||||||
"notes": "",
|
|
||||||
"created_at": "2025-10-24T22:33:26.187249",
|
|
||||||
"market_context": {
|
|
||||||
"entry_state": {},
|
|
||||||
"exit_state": {}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"annotation_id": "5d5c4354-12dd-4e0c-92a8-eff631a5dfab",
|
"annotation_id": "5d5c4354-12dd-4e0c-92a8-eff631a5dfab",
|
||||||
"symbol": "ETH/USDT",
|
"symbol": "ETH/USDT",
|
||||||
@@ -115,29 +92,6 @@
|
|||||||
"exit_state": {}
|
"exit_state": {}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"annotation_id": "46cc0e20-0bfb-498c-9358-71b52a003d0f",
|
|
||||||
"symbol": "ETH/USDT",
|
|
||||||
"timeframe": "1s",
|
|
||||||
"entry": {
|
|
||||||
"timestamp": "2025-11-22 12:50",
|
|
||||||
"price": 2712.11,
|
|
||||||
"index": 26
|
|
||||||
},
|
|
||||||
"exit": {
|
|
||||||
"timestamp": "2025-11-22 12:53:06",
|
|
||||||
"price": 2721.44,
|
|
||||||
"index": 45
|
|
||||||
},
|
|
||||||
"direction": "LONG",
|
|
||||||
"profit_loss_pct": 0.3440125953593301,
|
|
||||||
"notes": "",
|
|
||||||
"created_at": "2025-11-22T15:19:00.480166",
|
|
||||||
"market_context": {
|
|
||||||
"entry_state": {},
|
|
||||||
"exit_state": {}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"annotation_id": "b01fe6b2-7724-495e-ab01-3f3d3aa0da5d",
|
"annotation_id": "b01fe6b2-7724-495e-ab01-3f3d3aa0da5d",
|
||||||
"symbol": "ETH/USDT",
|
"symbol": "ETH/USDT",
|
||||||
@@ -160,10 +114,33 @@
|
|||||||
"entry_state": {},
|
"entry_state": {},
|
||||||
"exit_state": {}
|
"exit_state": {}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"annotation_id": "19a566fa-63bb-4ce3-9f30-116127c9fe95",
|
||||||
|
"symbol": "ETH/USDT",
|
||||||
|
"timeframe": "1s",
|
||||||
|
"entry": {
|
||||||
|
"timestamp": "2025-11-22 22:25:17",
|
||||||
|
"price": 2761.97,
|
||||||
|
"index": 35
|
||||||
|
},
|
||||||
|
"exit": {
|
||||||
|
"timestamp": "2025-11-22 22:30:40",
|
||||||
|
"price": 2760.15,
|
||||||
|
"index": 49
|
||||||
|
},
|
||||||
|
"direction": "SHORT",
|
||||||
|
"profit_loss_pct": 0.06589499523889503,
|
||||||
|
"notes": "",
|
||||||
|
"created_at": "2025-11-22T22:35:55.606071+00:00",
|
||||||
|
"market_context": {
|
||||||
|
"entry_state": {},
|
||||||
|
"exit_state": {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"total_annotations": 7,
|
"total_annotations": 6,
|
||||||
"last_updated": "2025-11-22T15:31:43.940190"
|
"last_updated": "2025-11-22T22:35:55.606373+00:00"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -17,7 +17,7 @@ from flask import Flask, render_template, request, jsonify, send_file
|
|||||||
from dash import Dash, html
|
from dash import Dash, html
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
from typing import Optional, Dict, List, Any
|
from typing import Optional, Dict, List, Any, Tuple
|
||||||
import json
|
import json
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -370,8 +370,8 @@ class BacktestRunner:
|
|||||||
action_idx = torch.argmax(action_probs, dim=-1).item()
|
action_idx = torch.argmax(action_probs, dim=-1).item()
|
||||||
confidence = action_probs[0, action_idx].item()
|
confidence = action_probs[0, action_idx].item()
|
||||||
|
|
||||||
# Map to BUY/SELL/HOLD
|
# Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
|
||||||
actions = ['BUY', 'SELL', 'HOLD']
|
actions = ['HOLD', 'BUY', 'SELL']
|
||||||
if action_idx < len(actions):
|
if action_idx < len(actions):
|
||||||
action = actions[action_idx]
|
action = actions[action_idx]
|
||||||
else:
|
else:
|
||||||
@@ -490,6 +490,194 @@ class BacktestRunner:
|
|||||||
state['stop_requested'] = True
|
state['stop_requested'] = True
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingStrategyManager:
|
||||||
|
"""
|
||||||
|
Manages training strategies and decisions - Separates business logic from model interface
|
||||||
|
|
||||||
|
Training Modes:
|
||||||
|
- 'none': No training (inference only)
|
||||||
|
- 'every_candle': Train on every completed candle
|
||||||
|
- 'pivots_only': Train only on pivot points (BUY at L pivots, SELL at H pivots)
|
||||||
|
- 'manual': Training triggered manually by user button
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_provider, training_adapter):
|
||||||
|
self.data_provider = data_provider
|
||||||
|
self.training_adapter = training_adapter
|
||||||
|
self.mode = 'none' # Default: no training
|
||||||
|
self.dashboard = None # Set by dashboard after initialization
|
||||||
|
|
||||||
|
# Statistics tracking
|
||||||
|
self.stats = {
|
||||||
|
'total_trained': 0,
|
||||||
|
'by_action': {'BUY': 0, 'SELL': 0, 'HOLD': 0},
|
||||||
|
'profitable': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
def should_train_on_candle(self, symbol: str, timeframe: str, candle_timestamp, pivot_markers: Dict = None) -> Tuple[bool, Optional[Dict]]:
|
||||||
|
"""
|
||||||
|
Decide if we should train on this candle based on current mode
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol
|
||||||
|
timeframe: Candle timeframe
|
||||||
|
candle_timestamp: Timestamp of the candle
|
||||||
|
pivot_markers: Dict of pivot markers (timestamp -> pivot data)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (should_train: bool, action_data: Optional[Dict])
|
||||||
|
action_data contains: {'action': 'BUY'/'SELL'/'HOLD', 'pivot_level': int, 'pivot_strength': float}
|
||||||
|
"""
|
||||||
|
if self.mode == 'none':
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
elif self.mode == 'every_candle':
|
||||||
|
# Train on every candle - determine action from price movement or pivots
|
||||||
|
action_data = self._get_action_for_candle(symbol, timeframe, candle_timestamp, pivot_markers)
|
||||||
|
return True, action_data
|
||||||
|
|
||||||
|
elif self.mode == 'pivots_only':
|
||||||
|
# Train only on pivot candles
|
||||||
|
return self._is_pivot_candle(candle_timestamp, pivot_markers)
|
||||||
|
|
||||||
|
elif self.mode == 'manual':
|
||||||
|
# Manual training - don't auto-train
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
def _get_action_for_candle(self, symbol: str, timeframe: str, candle_timestamp, pivot_markers: Dict = None) -> Dict:
|
||||||
|
"""
|
||||||
|
Determine action for any candle (pivot or non-pivot)
|
||||||
|
For pivot candles: BUY at L, SELL at H
|
||||||
|
For non-pivot candles: Use price movement thresholds
|
||||||
|
"""
|
||||||
|
# First check if it's a pivot candle
|
||||||
|
is_pivot, pivot_action = self._is_pivot_candle(candle_timestamp, pivot_markers)
|
||||||
|
if is_pivot and pivot_action:
|
||||||
|
return pivot_action
|
||||||
|
|
||||||
|
# Not a pivot - use price movement based logic
|
||||||
|
# Get recent candles to determine trend
|
||||||
|
df = self.data_provider.get_historical_data(symbol, timeframe, limit=5)
|
||||||
|
if df is None or len(df) < 3:
|
||||||
|
return {'action': 'HOLD', 'reason': 'insufficient_data'}
|
||||||
|
|
||||||
|
# Simple momentum: if price going up, BUY, if going down, SELL
|
||||||
|
recent_change = (df.iloc[-1]['close'] - df.iloc[-3]['close']) / df.iloc[-3]['close']
|
||||||
|
|
||||||
|
if recent_change > 0.0005: # 0.05% up
|
||||||
|
action = 'BUY'
|
||||||
|
elif recent_change < -0.0005: # 0.05% down
|
||||||
|
action = 'SELL'
|
||||||
|
else:
|
||||||
|
action = 'HOLD'
|
||||||
|
|
||||||
|
return {
|
||||||
|
'action': action,
|
||||||
|
'reason': 'price_movement',
|
||||||
|
'change_pct': recent_change * 100
|
||||||
|
}
|
||||||
|
|
||||||
|
def _is_pivot_candle(self, timestamp, pivot_markers: Dict = None) -> Tuple[bool, Optional[Dict]]:
|
||||||
|
"""
|
||||||
|
Check if candle is a pivot point and return action
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_pivot: bool, action_data: Optional[Dict])
|
||||||
|
"""
|
||||||
|
if not pivot_markers:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
candle_timestamp = str(timestamp)
|
||||||
|
candle_pivots = pivot_markers.get(candle_timestamp, {})
|
||||||
|
|
||||||
|
if not candle_pivots:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
# BUY at L pivots (lows - support levels)
|
||||||
|
if 'lows' in candle_pivots and len(candle_pivots['lows']) > 0:
|
||||||
|
best_low = max(candle_pivots['lows'], key=lambda p: p.get('level', 0))
|
||||||
|
pivot_level = best_low.get('level', 1)
|
||||||
|
pivot_strength = best_low.get('strength', 0.5)
|
||||||
|
|
||||||
|
logger.info(f"L{pivot_level}L pivot detected @ {timestamp}, strength={pivot_strength:.2f} → BUY signal")
|
||||||
|
|
||||||
|
return True, {
|
||||||
|
'action': 'BUY',
|
||||||
|
'pivot_level': pivot_level,
|
||||||
|
'pivot_strength': pivot_strength,
|
||||||
|
'reason': 'low_pivot'
|
||||||
|
}
|
||||||
|
|
||||||
|
# SELL at H pivots (highs - resistance levels)
|
||||||
|
elif 'highs' in candle_pivots and len(candle_pivots['highs']) > 0:
|
||||||
|
best_high = max(candle_pivots['highs'], key=lambda p: p.get('level', 0))
|
||||||
|
pivot_level = best_high.get('level', 1)
|
||||||
|
pivot_strength = best_high.get('strength', 0.5)
|
||||||
|
|
||||||
|
logger.info(f"L{pivot_level}H pivot detected @ {timestamp}, strength={pivot_strength:.2f} → SELL signal")
|
||||||
|
|
||||||
|
return True, {
|
||||||
|
'action': 'SELL',
|
||||||
|
'pivot_level': pivot_level,
|
||||||
|
'pivot_strength': pivot_strength,
|
||||||
|
'reason': 'high_pivot'
|
||||||
|
}
|
||||||
|
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
def train_manually(self, symbol: str, timeframe: str, action: str) -> Dict:
|
||||||
|
"""
|
||||||
|
Manually trigger training with specified action
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol
|
||||||
|
timeframe: Timeframe
|
||||||
|
action: Action to train ('BUY', 'SELL', or 'HOLD')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Training result dict with metrics
|
||||||
|
"""
|
||||||
|
logger.info(f"Manual training triggered: {action} on {symbol} {timeframe}")
|
||||||
|
|
||||||
|
# Create action data
|
||||||
|
action_data = {
|
||||||
|
'action': action,
|
||||||
|
'reason': 'manual_trigger'
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update stats
|
||||||
|
self.stats['total_trained'] += 1
|
||||||
|
self.stats['by_action'][action] = self.stats['by_action'].get(action, 0) + 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
'success': True,
|
||||||
|
'action': action,
|
||||||
|
'triggered_by': 'manual'
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict:
|
||||||
|
"""Get training statistics"""
|
||||||
|
total = self.stats['total_trained']
|
||||||
|
if total == 0:
|
||||||
|
return {
|
||||||
|
'total_trained': 0,
|
||||||
|
'by_action': {'BUY': '0%', 'SELL': '0%', 'HOLD': '0%'},
|
||||||
|
'mode': self.mode
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
'total_trained': total,
|
||||||
|
'by_action': {
|
||||||
|
'BUY': f"{(self.stats['by_action'].get('BUY', 0) / total * 100):.1f}%",
|
||||||
|
'SELL': f"{(self.stats['by_action'].get('SELL', 0) / total * 100):.1f}%",
|
||||||
|
'HOLD': f"{(self.stats['by_action'].get('HOLD', 0) / total * 100):.1f}%"
|
||||||
|
},
|
||||||
|
'mode': self.mode
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class AnnotationDashboard:
|
class AnnotationDashboard:
|
||||||
"""Main annotation dashboard application"""
|
"""Main annotation dashboard application"""
|
||||||
|
|
||||||
@@ -586,12 +774,19 @@ class AnnotationDashboard:
|
|||||||
self.annotation_manager = AnnotationManager()
|
self.annotation_manager = AnnotationManager()
|
||||||
# Use REAL training adapter - NO SIMULATION!
|
# Use REAL training adapter - NO SIMULATION!
|
||||||
self.training_adapter = RealTrainingAdapter(None, self.data_provider)
|
self.training_adapter = RealTrainingAdapter(None, self.data_provider)
|
||||||
|
# Initialize training strategy manager (controls training decisions)
|
||||||
|
self.training_strategy = TrainingStrategyManager(self.data_provider, self.training_adapter)
|
||||||
|
self.training_strategy.dashboard = self
|
||||||
# Pass socketio to training adapter for live trade updates
|
# Pass socketio to training adapter for live trade updates
|
||||||
if self.has_socketio and self.socketio:
|
if self.has_socketio and self.socketio:
|
||||||
self.training_adapter.socketio = self.socketio
|
self.training_adapter.socketio = self.socketio
|
||||||
# Backtest runner for replaying visible chart with predictions
|
# Backtest runner for replaying visible chart with predictions
|
||||||
self.backtest_runner = BacktestRunner()
|
self.backtest_runner = BacktestRunner()
|
||||||
|
|
||||||
|
# Prediction cache for training: stores inference inputs/outputs to compare with actual candles
|
||||||
|
# Format: {symbol: {timeframe: [{'timestamp': ts, 'inputs': {...}, 'outputs': {...}, 'norm_params': {...}}, ...]}}
|
||||||
|
self.prediction_cache = {}
|
||||||
|
|
||||||
# Check if we should auto-load a model at startup
|
# Check if we should auto-load a model at startup
|
||||||
auto_load_model = os.getenv('AUTO_LOAD_MODEL', 'Transformer') # Default: Transformer
|
auto_load_model = os.getenv('AUTO_LOAD_MODEL', 'Transformer') # Default: Transformer
|
||||||
|
|
||||||
@@ -2121,14 +2316,21 @@ class AnnotationDashboard:
|
|||||||
|
|
||||||
@self.server.route('/api/realtime-inference/start', methods=['POST'])
|
@self.server.route('/api/realtime-inference/start', methods=['POST'])
|
||||||
def start_realtime_inference():
|
def start_realtime_inference():
|
||||||
"""Start real-time inference mode with optional training modes"""
|
"""Start real-time inference mode with configurable training strategy"""
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
model_name = data.get('model_name')
|
model_name = data.get('model_name')
|
||||||
symbol = data.get('symbol', 'ETH/USDT')
|
symbol = data.get('symbol', 'ETH/USDT')
|
||||||
timeframe = data.get('timeframe', '1m')
|
timeframe = data.get('timeframe', '1m')
|
||||||
enable_live_training = data.get('enable_live_training', False) # Pivot-based training
|
|
||||||
train_every_candle = data.get('train_every_candle', False) # Per-candle training
|
# New unified training_mode parameter
|
||||||
|
training_mode = data.get('training_mode', 'none') # 'none', 'every_candle', 'pivots_only', 'manual'
|
||||||
|
|
||||||
|
# Backward compatibility with old parameters
|
||||||
|
if 'enable_live_training' in data or 'train_every_candle' in data:
|
||||||
|
enable_live_training = data.get('enable_live_training', False)
|
||||||
|
train_every_candle = data.get('train_every_candle', False)
|
||||||
|
training_mode = 'every_candle' if train_every_candle else ('pivots_only' if enable_live_training else 'none')
|
||||||
|
|
||||||
if not self.training_adapter:
|
if not self.training_adapter:
|
||||||
return jsonify({
|
return jsonify({
|
||||||
@@ -2139,18 +2341,21 @@ class AnnotationDashboard:
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
# Start real-time inference with optional training modes
|
# Set training mode on strategy manager
|
||||||
|
self.training_strategy.mode = training_mode
|
||||||
|
logger.info(f"Training strategy mode set to: {training_mode}")
|
||||||
|
|
||||||
|
# Start real-time inference - pass strategy manager for training decisions
|
||||||
inference_id = self.training_adapter.start_realtime_inference(
|
inference_id = self.training_adapter.start_realtime_inference(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
symbol=symbol,
|
symbol=symbol,
|
||||||
data_provider=self.data_provider,
|
data_provider=self.data_provider,
|
||||||
enable_live_training=enable_live_training,
|
enable_live_training=(training_mode != 'none'),
|
||||||
train_every_candle=train_every_candle,
|
train_every_candle=(training_mode == 'every_candle'),
|
||||||
timeframe=timeframe
|
timeframe=timeframe,
|
||||||
|
training_strategy=self.training_strategy # Pass strategy manager
|
||||||
)
|
)
|
||||||
|
|
||||||
training_mode = "per-candle" if train_every_candle else ("pivot-based" if enable_live_training else "inference-only")
|
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'success': True,
|
'success': True,
|
||||||
'inference_id': inference_id,
|
'inference_id': inference_id,
|
||||||
@@ -2259,20 +2464,17 @@ class AnnotationDashboard:
|
|||||||
if hasattr(self.orchestrator, 'recent_transformer_predictions') and symbol in self.orchestrator.recent_transformer_predictions:
|
if hasattr(self.orchestrator, 'recent_transformer_predictions') and symbol in self.orchestrator.recent_transformer_predictions:
|
||||||
transformer_preds = list(self.orchestrator.recent_transformer_predictions[symbol])
|
transformer_preds = list(self.orchestrator.recent_transformer_predictions[symbol])
|
||||||
if transformer_preds:
|
if transformer_preds:
|
||||||
# Use the most recent stored prediction (from inference loop)
|
# Convert any remaining tensors to Python types before JSON serialization
|
||||||
predictions['transformer'] = transformer_preds[-1]
|
transformer_pred = transformer_preds[-1].copy()
|
||||||
logger.debug(f"Using stored prediction: {list(transformer_preds[-1].keys())}")
|
predictions['transformer'] = self._serialize_prediction(transformer_pred)
|
||||||
else:
|
|
||||||
# Fallback: generate new prediction if no stored predictions
|
|
||||||
transformer_pred = self._get_live_transformer_prediction(symbol)
|
|
||||||
if transformer_pred:
|
|
||||||
predictions['transformer'] = transformer_pred
|
|
||||||
|
|
||||||
if predictions:
|
if predictions:
|
||||||
response['prediction'] = predictions
|
response['prediction'] = predictions
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Error getting predictions: {e}")
|
logger.debug(f"Error getting predictions: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
|
||||||
return jsonify(response)
|
return jsonify(response)
|
||||||
|
|
||||||
@@ -2322,10 +2524,101 @@ class AnnotationDashboard:
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
|
||||||
|
def train_manual():
|
||||||
|
"""Manually trigger training on current candle with specified action"""
|
||||||
|
try:
|
||||||
|
data = request.get_json()
|
||||||
|
inference_id = data.get('inference_id')
|
||||||
|
action = data.get('action', 'HOLD')
|
||||||
|
|
||||||
|
if not self.training_adapter:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Training adapter not available'
|
||||||
|
})
|
||||||
|
|
||||||
|
# Get active inference session
|
||||||
|
if not hasattr(self.training_adapter, 'inference_sessions'):
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'error': 'No active inference sessions'
|
||||||
|
})
|
||||||
|
|
||||||
|
session = self.training_adapter.inference_sessions.get(inference_id)
|
||||||
|
if not session:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Inference session not found'
|
||||||
|
})
|
||||||
|
|
||||||
|
# Set pending action for training
|
||||||
|
session['pending_action'] = action
|
||||||
|
|
||||||
|
# Get session parameters
|
||||||
|
symbol = session.get('symbol', 'ETH/USDT')
|
||||||
|
timeframe = session.get('timeframe', '1m')
|
||||||
|
data_provider = session.get('data_provider')
|
||||||
|
|
||||||
|
# Call training method
|
||||||
|
train_result = self.training_adapter._train_on_new_candle(
|
||||||
|
session, symbol, timeframe, data_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
if train_result.get('success'):
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'action': action,
|
||||||
|
'metrics': {
|
||||||
|
'loss': train_result.get('loss', 0.0),
|
||||||
|
'accuracy': train_result.get('accuracy', 0.0),
|
||||||
|
'training_steps': train_result.get('training_steps', 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'error': train_result.get('error', 'Training failed')
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in manual training: {e}")
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
})
|
||||||
|
|
||||||
# WebSocket event handlers (if SocketIO is available)
|
# WebSocket event handlers (if SocketIO is available)
|
||||||
if self.has_socketio:
|
if self.has_socketio:
|
||||||
self._setup_websocket_handlers()
|
self._setup_websocket_handlers()
|
||||||
|
|
||||||
|
def _serialize_prediction(self, prediction: Dict) -> Dict:
|
||||||
|
"""Convert PyTorch tensors in prediction dict to JSON-serializable Python types"""
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
serialized = {}
|
||||||
|
for key, value in prediction.items():
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
if value.numel() == 1: # Scalar tensor
|
||||||
|
serialized[key] = value.item()
|
||||||
|
else: # Multi-element tensor
|
||||||
|
serialized[key] = value.detach().cpu().tolist()
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
serialized[key] = self._serialize_prediction(value) # Recursive
|
||||||
|
elif isinstance(value, (list, tuple)):
|
||||||
|
serialized[key] = [
|
||||||
|
v.item() if isinstance(v, torch.Tensor) and v.numel() == 1 else
|
||||||
|
(v.detach().cpu().tolist() if isinstance(v, torch.Tensor) else v)
|
||||||
|
for v in value
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
serialized[key] = value
|
||||||
|
return serialized
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error serializing prediction: {e}")
|
||||||
|
# Fallback: return as-is (might fail JSON serialization but won't crash)
|
||||||
|
return prediction
|
||||||
|
|
||||||
def _setup_websocket_handlers(self):
|
def _setup_websocket_handlers(self):
|
||||||
"""Setup WebSocket event handlers for real-time updates"""
|
"""Setup WebSocket event handlers for real-time updates"""
|
||||||
if not self.has_socketio:
|
if not self.has_socketio:
|
||||||
@@ -2748,35 +3041,209 @@ class AnnotationDashboard:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
|
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
|
||||||
"""Get live prediction from model"""
|
"""
|
||||||
|
Get live prediction from model using trainer inference
|
||||||
|
|
||||||
|
Caches inference data (inputs/outputs) for later training when actual candle arrives.
|
||||||
|
This allows us to:
|
||||||
|
1. Compare predicted vs actual candle values
|
||||||
|
2. Calculate loss
|
||||||
|
3. Do backpropagation with correct outputs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with prediction results including predicted_candle for ghost candle display
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
|
if not self.orchestrator:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Get recent candles for prediction
|
# Get trainer from orchestrator
|
||||||
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=200)
|
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
|
||||||
if not candles or len(candles) < 200:
|
if not trainer or not trainer.model:
|
||||||
|
logger.debug("No transformer trainer available for live prediction")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# TODO: Implement actual prediction logic
|
# Get market data using training adapter's method (reuses existing logic)
|
||||||
# For now, return placeholder
|
if not hasattr(self.training_adapter, '_get_realtime_market_data'):
|
||||||
import random
|
logger.warning("Training adapter missing _get_realtime_market_data method")
|
||||||
|
return None
|
||||||
|
|
||||||
|
market_data, norm_params = self.training_adapter._get_realtime_market_data(symbol, self.data_provider)
|
||||||
|
if not market_data:
|
||||||
|
logger.debug(f"No market data available for {symbol} {timeframe}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Make prediction with model
|
||||||
|
import torch
|
||||||
|
timestamp = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
trainer.model.eval()
|
||||||
|
outputs = trainer.model(**market_data)
|
||||||
|
|
||||||
|
# Extract action prediction
|
||||||
|
action_probs = outputs.get('action_probs')
|
||||||
|
if action_probs is None:
|
||||||
|
logger.debug("No action_probs in model output")
|
||||||
|
return None
|
||||||
|
|
||||||
|
action_idx = torch.argmax(action_probs, dim=-1).item()
|
||||||
|
confidence = action_probs[0, action_idx].item()
|
||||||
|
|
||||||
|
# Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
|
||||||
|
actions = ['HOLD', 'BUY', 'SELL']
|
||||||
|
action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
|
||||||
|
|
||||||
|
# Extract predicted candles and denormalize
|
||||||
|
predicted_candles_raw = {}
|
||||||
|
if 'next_candles' in outputs:
|
||||||
|
for tf, tensor in outputs['next_candles'].items():
|
||||||
|
predicted_candles_raw[tf] = tensor.detach().cpu().numpy().tolist()
|
||||||
|
|
||||||
|
# Denormalize predicted candles
|
||||||
|
predicted_candles_denorm = {}
|
||||||
|
if predicted_candles_raw and norm_params:
|
||||||
|
for tf, raw_candle in predicted_candles_raw.items():
|
||||||
|
if tf in norm_params:
|
||||||
|
params = norm_params[tf]
|
||||||
|
price_min = params['price_min']
|
||||||
|
price_max = params['price_max']
|
||||||
|
vol_min = params['volume_min']
|
||||||
|
vol_max = params['volume_max']
|
||||||
|
|
||||||
|
# raw_candle is [1, 5] list
|
||||||
|
candle_values = raw_candle[0]
|
||||||
|
|
||||||
|
denorm_candle = [
|
||||||
|
candle_values[0] * (price_max - price_min) + price_min, # Open
|
||||||
|
candle_values[1] * (price_max - price_min) + price_min, # High
|
||||||
|
candle_values[2] * (price_max - price_min) + price_min, # Low
|
||||||
|
candle_values[3] * (price_max - price_min) + price_min, # Close
|
||||||
|
candle_values[4] * (vol_max - vol_min) + vol_min # Volume
|
||||||
|
]
|
||||||
|
predicted_candles_denorm[tf] = denorm_candle
|
||||||
|
|
||||||
|
# Get predicted price from candle close
|
||||||
|
predicted_price = None
|
||||||
|
if timeframe in predicted_candles_denorm:
|
||||||
|
predicted_price = predicted_candles_denorm[timeframe][3] # Close
|
||||||
|
elif '1m' in predicted_candles_denorm:
|
||||||
|
predicted_price = predicted_candles_denorm['1m'][3]
|
||||||
|
elif '1s' in predicted_candles_denorm:
|
||||||
|
predicted_price = predicted_candles_denorm['1s'][3]
|
||||||
|
|
||||||
|
# CACHE inference data for later training
|
||||||
|
# Store inputs, outputs, and normalization params so we can train when actual candle arrives
|
||||||
|
if symbol not in self.prediction_cache:
|
||||||
|
self.prediction_cache[symbol] = {}
|
||||||
|
if timeframe not in self.prediction_cache[symbol]:
|
||||||
|
self.prediction_cache[symbol][timeframe] = []
|
||||||
|
|
||||||
|
# Store cached inference data (convert tensors to CPU for storage)
|
||||||
|
cached_data = {
|
||||||
|
'timestamp': timestamp,
|
||||||
|
'symbol': symbol,
|
||||||
|
'timeframe': timeframe,
|
||||||
|
'model_inputs': {k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
|
||||||
|
for k, v in market_data.items()},
|
||||||
|
'model_outputs': {k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
|
||||||
|
for k, v in outputs.items()},
|
||||||
|
'normalization_params': norm_params,
|
||||||
|
'predicted_candle': predicted_candles_denorm.get(timeframe),
|
||||||
|
'prediction_steps': prediction_steps
|
||||||
|
}
|
||||||
|
|
||||||
|
self.prediction_cache[symbol][timeframe].append(cached_data)
|
||||||
|
|
||||||
|
# Keep only last 100 predictions per symbol/timeframe to prevent memory bloat
|
||||||
|
if len(self.prediction_cache[symbol][timeframe]) > 100:
|
||||||
|
self.prediction_cache[symbol][timeframe] = self.prediction_cache[symbol][timeframe][-100:]
|
||||||
|
|
||||||
|
logger.debug(f"Cached prediction for {symbol} {timeframe} @ {timestamp.isoformat()}")
|
||||||
|
|
||||||
|
# Return prediction result (same format as before for compatibility)
|
||||||
return {
|
return {
|
||||||
'symbol': symbol,
|
'symbol': symbol,
|
||||||
'timeframe': timeframe,
|
'timeframe': timeframe,
|
||||||
'timestamp': datetime.now(timezone.utc).isoformat(),
|
'timestamp': timestamp.isoformat(),
|
||||||
'action': random.choice(['BUY', 'SELL', 'HOLD']),
|
'action': action,
|
||||||
'confidence': random.uniform(0.6, 0.95),
|
'confidence': confidence,
|
||||||
'predicted_price': candles[-1].get('close', 0) * (1 + random.uniform(-0.01, 0.01)),
|
'predicted_price': predicted_price,
|
||||||
|
'predicted_candle': predicted_candles_denorm,
|
||||||
'prediction_steps': prediction_steps
|
'prediction_steps': prediction_steps
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting live prediction: {e}")
|
logger.error(f"Error getting live prediction: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def run(self, host='127.0.0.1', port=8052, debug=False):
|
def get_cached_predictions_for_training(self, symbol: str, timeframe: str, actual_candle_timestamp) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Retrieve cached predictions that match a specific candle timestamp for training
|
||||||
|
|
||||||
|
When an actual candle arrives, we can:
|
||||||
|
1. Find cached predictions made before this candle
|
||||||
|
2. Compare predicted vs actual candle values
|
||||||
|
3. Calculate loss and do backpropagation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol
|
||||||
|
timeframe: Timeframe
|
||||||
|
actual_candle_timestamp: Timestamp of the actual candle that just arrived
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of cached prediction dicts that should be trained on
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if symbol not in self.prediction_cache:
|
||||||
|
return []
|
||||||
|
if timeframe not in self.prediction_cache[symbol]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Find predictions made before this candle timestamp
|
||||||
|
# Predictions should be for candles that have now completed
|
||||||
|
matching_predictions = []
|
||||||
|
actual_time = actual_candle_timestamp if isinstance(actual_candle_timestamp, datetime) else datetime.fromisoformat(str(actual_candle_timestamp).replace('Z', '+00:00'))
|
||||||
|
|
||||||
|
for cached_pred in self.prediction_cache[symbol][timeframe]:
|
||||||
|
pred_time = cached_pred['timestamp']
|
||||||
|
if isinstance(pred_time, str):
|
||||||
|
pred_time = datetime.fromisoformat(pred_time.replace('Z', '+00:00'))
|
||||||
|
|
||||||
|
# Prediction should be for a candle that comes after the prediction time
|
||||||
|
# We match predictions that were made before the actual candle closed
|
||||||
|
if pred_time < actual_time:
|
||||||
|
matching_predictions.append(cached_pred)
|
||||||
|
|
||||||
|
return matching_predictions
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting cached predictions for training: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def clear_old_cached_predictions(self, symbol: str, timeframe: str, before_timestamp: datetime):
|
||||||
|
"""
|
||||||
|
Clear cached predictions older than a certain timestamp
|
||||||
|
|
||||||
|
Useful for cleaning up old predictions that are no longer needed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if symbol not in self.prediction_cache:
|
||||||
|
return
|
||||||
|
if timeframe not in self.prediction_cache[symbol]:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.prediction_cache[symbol][timeframe] = [
|
||||||
|
pred for pred in self.prediction_cache[symbol][timeframe]
|
||||||
|
if pred['timestamp'] >= before_timestamp
|
||||||
|
]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error clearing old cached predictions: {e}")
|
||||||
|
|
||||||
|
def run(self, host='127.0.0.1', port=8051, debug=False):
|
||||||
"""Run the application"""
|
"""Run the application"""
|
||||||
logger.info(f"Starting Annotation Dashboard on http://{host}:{port}")
|
logger.info(f"Starting Annotation Dashboard on http://{host}:{port}")
|
||||||
|
|
||||||
|
|||||||
@@ -100,6 +100,10 @@
|
|||||||
<i class="fas fa-stop"></i>
|
<i class="fas fa-stop"></i>
|
||||||
Stop Inference
|
Stop Inference
|
||||||
</button>
|
</button>
|
||||||
|
<button class="btn btn-warning btn-sm w-100 mt-1" id="manual-train-btn" style="display: none;">
|
||||||
|
<i class="fas fa-hand-pointer"></i>
|
||||||
|
Train on Current Candle (Manual)
|
||||||
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Backtest on Visible Chart -->
|
<!-- Backtest on Visible Chart -->
|
||||||
@@ -628,7 +632,7 @@
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Helper function to start inference with different modes
|
// Helper function to start inference with different modes
|
||||||
function startInference(enableLiveTraining, trainEveryCandle) {
|
function startInference(trainingMode) {
|
||||||
const modelName = document.getElementById('model-select').value;
|
const modelName = document.getElementById('model-select').value;
|
||||||
|
|
||||||
if (!modelName) {
|
if (!modelName) {
|
||||||
@@ -639,7 +643,7 @@
|
|||||||
// Get timeframe
|
// Get timeframe
|
||||||
const timeframe = document.getElementById('primary-timeframe-select').value;
|
const timeframe = document.getElementById('primary-timeframe-select').value;
|
||||||
|
|
||||||
// Start real-time inference
|
// Start real-time inference with unified training_mode parameter
|
||||||
fetch('/api/realtime-inference/start', {
|
fetch('/api/realtime-inference/start', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
@@ -647,8 +651,7 @@
|
|||||||
model_name: modelName,
|
model_name: modelName,
|
||||||
symbol: appState.currentSymbol,
|
symbol: appState.currentSymbol,
|
||||||
timeframe: timeframe,
|
timeframe: timeframe,
|
||||||
enable_live_training: enableLiveTraining,
|
training_mode: trainingMode // 'none', 'every_candle', 'pivots_only', 'manual'
|
||||||
train_every_candle: trainEveryCandle
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
@@ -664,6 +667,11 @@
|
|||||||
document.getElementById('inference-status').style.display = 'block';
|
document.getElementById('inference-status').style.display = 'block';
|
||||||
document.getElementById('inference-controls').style.display = 'block';
|
document.getElementById('inference-controls').style.display = 'block';
|
||||||
|
|
||||||
|
// Show manual training button if in manual mode
|
||||||
|
if (trainingMode === 'manual') {
|
||||||
|
document.getElementById('manual-train-btn').style.display = 'block';
|
||||||
|
}
|
||||||
|
|
||||||
// Display active timeframe
|
// Display active timeframe
|
||||||
document.getElementById('active-timeframe').textContent = timeframe.toUpperCase();
|
document.getElementById('active-timeframe').textContent = timeframe.toUpperCase();
|
||||||
|
|
||||||
@@ -708,15 +716,15 @@
|
|||||||
|
|
||||||
// Button handlers for different inference modes
|
// Button handlers for different inference modes
|
||||||
document.getElementById('start-inference-btn').addEventListener('click', function () {
|
document.getElementById('start-inference-btn').addEventListener('click', function () {
|
||||||
startInference(false, false); // No training
|
startInference('none'); // No training (inference only)
|
||||||
});
|
});
|
||||||
|
|
||||||
document.getElementById('start-inference-pivot-btn').addEventListener('click', function () {
|
document.getElementById('start-inference-pivot-btn').addEventListener('click', function () {
|
||||||
startInference(true, false); // Pivot-based training
|
startInference('pivots_only'); // Pivot-based training
|
||||||
});
|
});
|
||||||
|
|
||||||
document.getElementById('start-inference-candle-btn').addEventListener('click', function () {
|
document.getElementById('start-inference-candle-btn').addEventListener('click', function () {
|
||||||
startInference(false, true); // Per-candle training
|
startInference('every_candle'); // Per-candle training
|
||||||
});
|
});
|
||||||
|
|
||||||
document.getElementById('stop-inference-btn').addEventListener('click', function () {
|
document.getElementById('stop-inference-btn').addEventListener('click', function () {
|
||||||
@@ -736,6 +744,7 @@
|
|||||||
document.getElementById('start-inference-pivot-btn').style.display = 'block';
|
document.getElementById('start-inference-pivot-btn').style.display = 'block';
|
||||||
document.getElementById('start-inference-candle-btn').style.display = 'block';
|
document.getElementById('start-inference-candle-btn').style.display = 'block';
|
||||||
document.getElementById('stop-inference-btn').style.display = 'none';
|
document.getElementById('stop-inference-btn').style.display = 'none';
|
||||||
|
document.getElementById('manual-train-btn').style.display = 'none';
|
||||||
document.getElementById('inference-status').style.display = 'none';
|
document.getElementById('inference-status').style.display = 'none';
|
||||||
document.getElementById('inference-controls').style.display = 'none';
|
document.getElementById('inference-controls').style.display = 'none';
|
||||||
|
|
||||||
@@ -763,6 +772,42 @@
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Manual training button handler
|
||||||
|
document.getElementById('manual-train-btn').addEventListener('click', function () {
|
||||||
|
if (!currentInferenceId) {
|
||||||
|
showError('No active inference session');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get user's action choice (could add a dropdown, for now use BUY as example)
|
||||||
|
const action = prompt('Enter action (BUY, SELL, or HOLD):', 'BUY');
|
||||||
|
if (!action || !['BUY', 'SELL', 'HOLD'].includes(action.toUpperCase())) {
|
||||||
|
showError('Invalid action. Must be BUY, SELL, or HOLD');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trigger manual training
|
||||||
|
fetch('/api/realtime-inference/train-manual', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({
|
||||||
|
inference_id: currentInferenceId,
|
||||||
|
action: action.toUpperCase()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.then(response => response.json())
|
||||||
|
.then(data => {
|
||||||
|
if (data.success) {
|
||||||
|
showSuccess(`Manual training completed: ${data.action} (${data.metrics ? 'Loss: ' + data.metrics.loss.toFixed(4) : ''})`);
|
||||||
|
} else {
|
||||||
|
showError('Manual training failed: ' + (data.error || 'Unknown error'));
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch(error => {
|
||||||
|
showError('Network error: ' + error.message);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
// Backtest controls
|
// Backtest controls
|
||||||
let currentBacktestId = null;
|
let currentBacktestId = null;
|
||||||
let backtestPollInterval = null;
|
let backtestPollInterval = null;
|
||||||
|
|||||||
@@ -162,7 +162,8 @@ class DeepMultiScaleAttention(nn.Module):
|
|||||||
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
scores.masked_fill_(mask == 0, -1e9)
|
# Use non-inplace version to avoid gradient computation issues
|
||||||
|
scores = scores.masked_fill(mask == 0, -1e9)
|
||||||
|
|
||||||
attention = F.softmax(scores, dim=-1)
|
attention = F.softmax(scores, dim=-1)
|
||||||
attention = self.dropout(attention)
|
attention = self.dropout(attention)
|
||||||
@@ -1089,8 +1090,11 @@ class TradingTransformerTrainer:
|
|||||||
pct_start=0.1
|
pct_start=0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Loss functions
|
# Loss functions with class weights
|
||||||
self.action_criterion = nn.CrossEntropyLoss()
|
# Pivot-based training: BUY at L pivots, SELL at H pivots (naturally balanced)
|
||||||
|
# Weights: [HOLD=0, BUY=1, SELL=2] - equal weighting for pivot-based trades
|
||||||
|
class_weights = torch.tensor([0.5, 1.0, 1.0], dtype=torch.float32, device=self.device)
|
||||||
|
self.action_criterion = nn.CrossEntropyLoss(weight=class_weights)
|
||||||
self.price_criterion = nn.MSELoss()
|
self.price_criterion = nn.MSELoss()
|
||||||
self.confidence_criterion = nn.BCELoss()
|
self.confidence_criterion = nn.BCELoss()
|
||||||
|
|
||||||
@@ -1182,19 +1186,30 @@ class TradingTransformerTrainer:
|
|||||||
Returns:
|
Returns:
|
||||||
Denormalized OHLCV tensor
|
Denormalized OHLCV tensor
|
||||||
"""
|
"""
|
||||||
denorm = normalized_candle.clone()
|
# Avoid inplace operations by creating new tensors instead of slice assignment
|
||||||
|
|
||||||
# Denormalize OHLC (first 4 values)
|
|
||||||
price_min = norm_params.get('price_min', 0.0)
|
price_min = norm_params.get('price_min', 0.0)
|
||||||
price_max = norm_params.get('price_max', 1.0)
|
price_max = norm_params.get('price_max', 1.0)
|
||||||
if price_max > price_min:
|
|
||||||
denorm[..., :4] = denorm[..., :4] * (price_max - price_min) + price_min
|
|
||||||
|
|
||||||
# Denormalize volume (5th value)
|
|
||||||
volume_min = norm_params.get('volume_min', 0.0)
|
volume_min = norm_params.get('volume_min', 0.0)
|
||||||
volume_max = norm_params.get('volume_max', 1.0)
|
volume_max = norm_params.get('volume_max', 1.0)
|
||||||
|
|
||||||
|
# Denormalize OHLC (first 4 values) - create new tensor, no inplace operations
|
||||||
|
if price_max > price_min:
|
||||||
|
price_scale = (price_max - price_min)
|
||||||
|
price_offset = price_min
|
||||||
|
denorm_ohlc = normalized_candle[..., :4] * price_scale + price_offset
|
||||||
|
else:
|
||||||
|
denorm_ohlc = normalized_candle[..., :4]
|
||||||
|
|
||||||
|
# Denormalize volume (5th value) - create new tensor, no inplace operations
|
||||||
if volume_max > volume_min:
|
if volume_max > volume_min:
|
||||||
denorm[..., 4] = denorm[..., 4] * (volume_max - volume_min) + volume_min
|
volume_scale = (volume_max - volume_min)
|
||||||
|
volume_offset = volume_min
|
||||||
|
denorm_volume = (normalized_candle[..., 4:5] * volume_scale + volume_offset)
|
||||||
|
else:
|
||||||
|
denorm_volume = normalized_candle[..., 4:5]
|
||||||
|
|
||||||
|
# Concatenate OHLC and Volume to create final tensor (no inplace operations)
|
||||||
|
denorm = torch.cat([denorm_ohlc, denorm_volume], dim=-1)
|
||||||
|
|
||||||
return denorm
|
return denorm
|
||||||
|
|
||||||
@@ -1675,9 +1690,46 @@ class TradingTransformerTrainer:
|
|||||||
"""Load model and training state"""
|
"""Load model and training state"""
|
||||||
checkpoint = torch.load(path, map_location=self.device)
|
checkpoint = torch.load(path, map_location=self.device)
|
||||||
|
|
||||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
# Load model state (with strict=False to handle architecture changes)
|
||||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
try:
|
||||||
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
self.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error loading model state dict: {e}, continuing with partial load")
|
||||||
|
|
||||||
|
# Load optimizer state (handle mismatched states gracefully)
|
||||||
|
try:
|
||||||
|
optimizer_state = checkpoint.get('optimizer_state_dict')
|
||||||
|
if optimizer_state:
|
||||||
|
try:
|
||||||
|
# Try to load optimizer state
|
||||||
|
self.optimizer.load_state_dict(optimizer_state)
|
||||||
|
except (KeyError, ValueError, RuntimeError) as e:
|
||||||
|
logger.warning(f"Error loading optimizer state: {e}. Resetting optimizer.")
|
||||||
|
# Recreate optimizer (same pattern as __init__)
|
||||||
|
self.optimizer = torch.optim.AdamW(
|
||||||
|
self.model.parameters(),
|
||||||
|
lr=self.config.learning_rate,
|
||||||
|
weight_decay=self.config.weight_decay
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning("No optimizer state found in checkpoint. Using fresh optimizer.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error loading optimizer state: {e}. Resetting optimizer.")
|
||||||
|
# Recreate optimizer (same pattern as __init__)
|
||||||
|
self.optimizer = torch.optim.AdamW(
|
||||||
|
self.model.parameters(),
|
||||||
|
lr=self.config.learning_rate,
|
||||||
|
weight_decay=self.config.weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load scheduler state
|
||||||
|
try:
|
||||||
|
scheduler_state = checkpoint.get('scheduler_state_dict')
|
||||||
|
if scheduler_state:
|
||||||
|
self.scheduler.load_state_dict(scheduler_state)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error loading scheduler state: {e}, continuing without scheduler state")
|
||||||
|
|
||||||
self.training_history = checkpoint.get('training_history', self.training_history)
|
self.training_history = checkpoint.get('training_history', self.training_history)
|
||||||
|
|
||||||
logger.info(f"Model loaded from {path}")
|
logger.info(f"Model loaded from {path}")
|
||||||
|
|||||||
208
PLACEHOLDERS_AND_MISSING_IMPLEMENTATIONS.md
Normal file
208
PLACEHOLDERS_AND_MISSING_IMPLEMENTATIONS.md
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
# Placeholders and Missing Implementations Report
|
||||||
|
|
||||||
|
**Generated**: 2025-11-23
|
||||||
|
**Purpose**: Identify all TODO, placeholder, and missing implementations that violate "no synthetic data" policy
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🔴 **CRITICAL - Synthetic Data Violations**
|
||||||
|
|
||||||
|
### 1. **core/negative_case_trainer.py** (Lines 396-397)
|
||||||
|
**Issue**: Uses `np.random.uniform()` for synthetic training improvements
|
||||||
|
```python
|
||||||
|
session.loss_improvement = np.random.uniform(0.1, 0.5) # 10-50% improvement
|
||||||
|
session.accuracy_improvement = np.random.uniform(0.05, 0.2) # 5-20% improvement
|
||||||
|
```
|
||||||
|
**Fix Required**: Calculate actual improvements from real training metrics
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🟡 **HIGH PRIORITY - Missing Core Functionality**
|
||||||
|
|
||||||
|
### 2. **core/orchestrator.py** (Line 2020)
|
||||||
|
**Issue**: `_get_all_predictions()` returns empty list - not implemented
|
||||||
|
```python
|
||||||
|
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
|
||||||
|
predictions = []
|
||||||
|
# TODO: Implement proper prediction gathering from all registered models
|
||||||
|
logger.warning(f"_get_all_predictions not fully implemented for {symbol}")
|
||||||
|
return predictions
|
||||||
|
```
|
||||||
|
**Impact**: Orchestrator cannot gather predictions from all models for decision fusion
|
||||||
|
**Fix Required**: Implement actual prediction gathering from registered models
|
||||||
|
|
||||||
|
### 3. **ANNOTATE/core/real_training_adapter.py** (Line 2339-2341)
|
||||||
|
**Issue**: Extrema training uses placeholder loss
|
||||||
|
```python
|
||||||
|
# TODO: Implement actual extrema training
|
||||||
|
session.current_loss = 0.5 / (epoch + 1) # Placeholder
|
||||||
|
```
|
||||||
|
**Fix Required**: Implement real extrema training logic
|
||||||
|
|
||||||
|
### 4. **ANNOTATE/core/real_training_adapter.py** (Line 1577)
|
||||||
|
**Issue**: Placeholder `time_in_position_minutes`
|
||||||
|
```python
|
||||||
|
time_in_position_minutes = 1.0 # Placeholder, will be more accurate with actual timestamps
|
||||||
|
```
|
||||||
|
**Fix Required**: Calculate actual time from entry timestamp
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🟢 **MEDIUM PRIORITY - Missing Features**
|
||||||
|
|
||||||
|
### 5. **web/clean_dashboard.py** (Lines 8759, 8768)
|
||||||
|
**Issue**: TODO for technical indicators and pivot points
|
||||||
|
```python
|
||||||
|
def _get_technical_indicators(self, symbol: str) -> Dict[str, float]:
|
||||||
|
# TODO: Implement technical indicators calculation
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _get_pivot_points(self, symbol: str) -> List['PivotPoint']:
|
||||||
|
# TODO: Implement pivot points calculation
|
||||||
|
return []
|
||||||
|
```
|
||||||
|
**Note**: Pivot points ARE implemented elsewhere (Williams Market Structure), but not in this method
|
||||||
|
**Fix Required**: Implement or delegate to existing pivot calculation
|
||||||
|
|
||||||
|
### 6. **web/clean_dashboard.py** (Line 8665)
|
||||||
|
**Issue**: TODO for cross-model predictions
|
||||||
|
```python
|
||||||
|
last_predictions={} # TODO: Add cross-model predictions
|
||||||
|
```
|
||||||
|
**Fix Required**: Gather predictions from all models (similar to orchestrator)
|
||||||
|
|
||||||
|
### 7. **web/clean_dashboard.py** (Line 8696)
|
||||||
|
**Issue**: TODO for technical indicators in bar data
|
||||||
|
```python
|
||||||
|
indicators={} # TODO: Add technical indicators
|
||||||
|
```
|
||||||
|
**Fix Required**: Calculate and include technical indicators
|
||||||
|
|
||||||
|
### 8. **web/clean_dashboard.py** (Line 9542)
|
||||||
|
**Issue**: Placeholder features array
|
||||||
|
```python
|
||||||
|
'features': [current_price, 0, 0, 0, 0] # Placeholder features
|
||||||
|
```
|
||||||
|
**Fix Required**: Extract real features from market data
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🔵 **LOW PRIORITY - Acceptable Placeholders**
|
||||||
|
|
||||||
|
### 9. **ANNOTATE/core/real_training_adapter.py** (Line 1421)
|
||||||
|
**Issue**: Placeholder COB data (zeros)
|
||||||
|
```python
|
||||||
|
# Create placeholder COB data (zeros if not available)
|
||||||
|
cob_data = torch.zeros(1, target_seq_len, 100, dtype=torch.float32)
|
||||||
|
```
|
||||||
|
**Status**: ✅ **ACCEPTABLE** - Returns zeros when COB data unavailable (not synthetic, just missing)
|
||||||
|
|
||||||
|
### 10. **ANNOTATE/core/real_training_adapter.py** (Lines 2746-2748)
|
||||||
|
**Issue**: Placeholder tech/market/COB data (zeros)
|
||||||
|
```python
|
||||||
|
data['tech_data'] = torch.zeros(1, 40, dtype=torch.float32)
|
||||||
|
data['market_data'] = torch.zeros(1, 30, dtype=torch.float32)
|
||||||
|
data['cob_data'] = torch.zeros(1, 600, 100, dtype=torch.float32)
|
||||||
|
```
|
||||||
|
**Status**: ✅ **ACCEPTABLE** - Model requires these inputs, zeros are safe defaults when unavailable
|
||||||
|
|
||||||
|
### 11. **ANNOTATE/core/real_training_adapter.py** (Line 2887)
|
||||||
|
**Issue**: Placeholder action in annotation conversion
|
||||||
|
```python
|
||||||
|
'action': 'BUY', # Placeholder, not used for candle prediction training
|
||||||
|
```
|
||||||
|
**Status**: ✅ **ACCEPTABLE** - Comment indicates it's not used for this training path
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## ⚠️ **QUESTIONABLE - Needs Review**
|
||||||
|
|
||||||
|
### 12. **core/orchestrator.py** (Lines 2098-2101)
|
||||||
|
**Issue**: Uses `random.uniform()` for tie-breaking
|
||||||
|
```python
|
||||||
|
import random
|
||||||
|
for action in action_scores:
|
||||||
|
# Add tiny random noise (±0.001) to break exact ties
|
||||||
|
action_scores[action] += random.uniform(-0.001, 0.001)
|
||||||
|
```
|
||||||
|
**Status**: ⚠️ **QUESTIONABLE** - This is for tie-breaking, not synthetic data generation
|
||||||
|
**Recommendation**: Consider deterministic tie-breaking (e.g., alphabetical order) instead
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📋 **Other TODOs Found**
|
||||||
|
|
||||||
|
### 13. **ANNOTATE/web/app.py** (Lines 2745, 2850)
|
||||||
|
**Issue**: Hardcoded symbols
|
||||||
|
```python
|
||||||
|
for symbol in ['ETH/USDT', 'BTC/USDT']: # TODO: Get from active subscriptions
|
||||||
|
symbol = 'ETH/USDT' # TODO: Get from active trading pair
|
||||||
|
```
|
||||||
|
**Fix Required**: Get from active subscriptions/trading pairs
|
||||||
|
|
||||||
|
### 14. **ANNOTATE/web/static/js/chart_manager.js** (Line 1193)
|
||||||
|
**Issue**: TODO for visual markers
|
||||||
|
```python
|
||||||
|
# TODO: Add visual markers using Plotly annotations
|
||||||
|
```
|
||||||
|
**Fix Required**: Add visual markers if needed
|
||||||
|
|
||||||
|
### 15. **ANNOTATE/web/templates/components/inference_panel.html** (Line 259)
|
||||||
|
**Issue**: TODO for Plotly chart update
|
||||||
|
```python
|
||||||
|
// TODO: Update Plotly chart with prediction marker
|
||||||
|
```
|
||||||
|
**Fix Required**: Implement prediction marker update
|
||||||
|
|
||||||
|
### 16. **ANNOTATE/core/data_loader.py** (Line 434)
|
||||||
|
**Issue**: MEXC time range fetch not implemented
|
||||||
|
```python
|
||||||
|
logger.warning("MEXC time range fetch not implemented yet")
|
||||||
|
```
|
||||||
|
**Fix Required**: Implement MEXC time range fetch or remove if not needed
|
||||||
|
|
||||||
|
### 17. **core/multi_horizon_prediction_manager.py** (Lines 690-713)
|
||||||
|
**Issue**: Placeholder methods for CNN/RL feature preparation
|
||||||
|
```python
|
||||||
|
def _prepare_cnn_features_for_horizon(...) -> np.ndarray:
|
||||||
|
"""Prepare CNN features for specific horizon (placeholder - not yet implemented)"""
|
||||||
|
return np.array([]) # Return empty array instead of synthetic data
|
||||||
|
|
||||||
|
def _prepare_rl_state_for_horizon(...) -> np.ndarray:
|
||||||
|
"""Prepare RL state for specific horizon (placeholder - not yet implemented)"""
|
||||||
|
return np.array([]) # Return empty array instead of synthetic data
|
||||||
|
```
|
||||||
|
**Status**: ✅ **ACCEPTABLE** - Returns empty arrays (not synthetic data), methods not yet needed
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📊 **Summary**
|
||||||
|
|
||||||
|
### By Priority:
|
||||||
|
- **🔴 Critical (Synthetic Data)**: 1 issue
|
||||||
|
- **🟡 High Priority (Missing Core)**: 3 issues
|
||||||
|
- **🟢 Medium Priority (Missing Features)**: 4 issues
|
||||||
|
- **🔵 Low Priority (Acceptable)**: 3 issues
|
||||||
|
- **⚠️ Questionable**: 1 issue
|
||||||
|
- **📋 Other TODOs**: 5 issues
|
||||||
|
|
||||||
|
### Total Issues Found: 17
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🎯 **Recommended Fix Order**
|
||||||
|
|
||||||
|
1. **Fix synthetic data violation** (negative_case_trainer.py) - **URGENT**
|
||||||
|
2. **Implement `_get_all_predictions()`** (orchestrator.py) - **HIGH**
|
||||||
|
3. **Implement extrema training** (real_training_adapter.py) - **HIGH**
|
||||||
|
4. **Fix time_in_position calculation** (real_training_adapter.py) - **MEDIUM**
|
||||||
|
5. **Implement technical indicators** (clean_dashboard.py) - **MEDIUM**
|
||||||
|
6. **Review tie-breaking logic** (orchestrator.py) - **LOW**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## ✅ **Already Fixed**
|
||||||
|
|
||||||
|
- ✅ `_get_live_prediction()` - Now uses real model inference with caching
|
||||||
|
- ✅ Ghost candles - Now includes `predicted_candle` in predictions
|
||||||
|
- ✅ JSON serialization - Fixed tensor serialization errors
|
||||||
@@ -392,9 +392,13 @@ class NegativeCaseTrainer:
|
|||||||
case.retraining_count += 1
|
case.retraining_count += 1
|
||||||
case.last_retrained = datetime.now()
|
case.last_retrained = datetime.now()
|
||||||
|
|
||||||
# Calculate improvements (simulated)
|
# Calculate improvements from actual training metrics (NO SYNTHETIC DATA)
|
||||||
session.loss_improvement = np.random.uniform(0.1, 0.5) # 10-50% improvement
|
# If actual training metrics are not available, set to 0.0 instead of random values
|
||||||
session.accuracy_improvement = np.random.uniform(0.05, 0.2) # 5-20% improvement
|
# TODO: Replace with actual model training that returns real loss/accuracy improvements
|
||||||
|
session.loss_improvement = 0.0 # Set to 0 until real training metrics available
|
||||||
|
session.accuracy_improvement = 0.0 # Set to 0 until real training metrics available
|
||||||
|
|
||||||
|
logger.warning(f"Training session completed but improvements not calculated - intensive training not yet implemented")
|
||||||
|
|
||||||
# Store training session results
|
# Store training session results
|
||||||
self._store_training_session(session)
|
self._store_training_session(session)
|
||||||
|
|||||||
@@ -2097,13 +2097,20 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
# Choose best action - safe way to handle max with key function
|
# Choose best action - safe way to handle max with key function
|
||||||
if action_scores:
|
if action_scores:
|
||||||
# Add small random component to break ties and prevent pure bias
|
# Break exact ties deterministically (NO RANDOM DATA)
|
||||||
import random
|
# Use action order as tie-breaker: BUY > SELL > HOLD
|
||||||
for action in action_scores:
|
action_order = {'BUY': 3, 'SELL': 2, 'HOLD': 1}
|
||||||
# Add tiny random noise (±0.001) to break exact ties
|
|
||||||
action_scores[action] += random.uniform(-0.001, 0.001)
|
# Find max score
|
||||||
|
max_score = max(action_scores.values())
|
||||||
|
|
||||||
|
# If multiple actions have same score, prefer BUY > SELL > HOLD
|
||||||
|
tied_actions = [action for action, score in action_scores.items() if score == max_score]
|
||||||
|
if len(tied_actions) > 1:
|
||||||
|
best_action = max(tied_actions, key=lambda a: action_order.get(a, 0))
|
||||||
|
else:
|
||||||
|
best_action = tied_actions[0]
|
||||||
|
|
||||||
best_action = max(action_scores.keys(), key=lambda k: action_scores[k])
|
|
||||||
best_confidence = action_scores[best_action]
|
best_confidence = action_scores[best_action]
|
||||||
|
|
||||||
# DEBUG: Log action scores to understand bias
|
# DEBUG: Log action scores to understand bias
|
||||||
|
|||||||
Reference in New Issue
Block a user