"""
@self.server.route('/api/recalculate-pivots', methods=['POST'])
def recalculate_pivots():
"""Recalculate pivot points for merged data using cached data from data_loader"""
try:
data = request.get_json()
symbol = data.get('symbol', 'ETH/USDT')
timeframe = data.get('timeframe')
# Timestamps/OHLCV from the frontend are ignored; we use the backend's own data source for consistency
if not timeframe:
return jsonify({
'success': False,
'error': {'code': 'INVALID_REQUEST', 'message': 'Missing timeframe'}
})
pivot_logger.info(f"Recalculating pivots for {symbol} {timeframe} using backend data")
if not self.data_loader:
return jsonify({
'success': False,
'error': {'code': 'DATA_LOADER_UNAVAILABLE', 'message': 'Data loader not available'}
})
# Fetch latest data from data_loader (which should have the updated cache/DB from previous calls)
# We get enough history for proper pivot calculation
df = self.data_loader.get_data(
symbol=symbol,
timeframe=timeframe,
limit=2500, # Enough for context
direction='latest'
)
if df is None or df.empty:
logger.warning(f"No data found for {symbol} {timeframe} to recalculate pivots")
return jsonify({
'success': True,
'pivot_markers': {}
})
# Recalculate pivot markers
pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df)
pivot_logger.info(f"Recalculated {len(pivot_markers)} pivot candles")
return jsonify({
'success': True,
'pivot_markers': pivot_markers
})
except Exception as e:
logger.error(f"Error recalculating pivots: {e}")
return jsonify({
'success': False,
'error': {'code': 'RECALC_ERROR', 'message': str(e)}
})
@self.server.route('/api/chart-data', methods=['POST'])
def get_chart_data():
"""Get chart data for specified symbol and timeframes with infinite scroll support"""
try:
data = request.get_json()
symbol = data.get('symbol', 'ETH/USDT')
timeframes = data.get('timeframes', ['1s', '1m', '1h', '1d'])
start_time_str = data.get('start_time')
end_time_str = data.get('end_time')
limit = data.get('limit', 2500) # Default 2500 candles for training
direction = data.get('direction', 'latest') # 'latest', 'before', or 'after'
webui_logger.info(f"Chart data request: {symbol} {timeframes} direction={direction} limit={limit}")
if start_time_str:
webui_logger.info(f" start_time: {start_time_str}")
if end_time_str:
webui_logger.info(f" end_time: {end_time_str}")
if not self.data_loader:
return jsonify({
'success': False,
'error': {
'code': 'DATA_LOADER_UNAVAILABLE',
'message': 'Data loader not available'
}
})
# Parse time strings if provided
start_time = datetime.fromisoformat(start_time_str.replace('Z', '+00:00')) if start_time_str else None
end_time = datetime.fromisoformat(end_time_str.replace('Z', '+00:00')) if end_time_str else None
# Fetch data for each timeframe using data loader
# This will automatically:
# 1. Check DuckDB first
# 2. Fetch from API if not in cache
# 3. Store in DuckDB for future use
chart_data = {}
for timeframe in timeframes:
df = self.data_loader.get_data(
symbol=symbol,
timeframe=timeframe,
start_time=start_time,
end_time=end_time,
limit=limit,
direction=direction
)
if df is not None and not df.empty:
webui_logger.info(f" {timeframe}: {len(df)} candles ({df.index[0]} to {df.index[-1]})")
# Get pivot points for this timeframe (only if we have enough context)
pivot_markers = {}
if len(df) >= 50:
pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df)
# Convert to format suitable for Plotly
chart_data[timeframe] = {
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': df['open'].tolist(),
'high': df['high'].tolist(),
'low': df['low'].tolist(),
'close': df['close'].tolist(),
'volume': df['volume'].tolist(),
'pivot_markers': pivot_markers # Optional: only present if pivots exist
}
else:
logger.warning(f" {timeframe}: No data returned")
# Get pivot bounds for the symbol
pivot_bounds = None
if self.data_provider:
try:
pivot_bounds = self.data_provider.get_pivot_bounds(symbol)
if pivot_bounds:
logger.info(f"Found pivot bounds for {symbol}: {len(pivot_bounds.pivot_support_levels)} support, {len(pivot_bounds.pivot_resistance_levels)} resistance")
except Exception as e:
logger.error(f"Error getting pivot bounds: {e}")
return jsonify({
'success': True,
'chart_data': chart_data,
'pivot_bounds': {
'support_levels': pivot_bounds.pivot_support_levels,
'resistance_levels': pivot_bounds.pivot_resistance_levels,
'price_range': {
'min': pivot_bounds.price_min,
'max': pivot_bounds.price_max
},
'volume_range': {
'min': pivot_bounds.volume_min,
'max': pivot_bounds.volume_max
},
'timeframe': '1m', # Pivot bounds are calculated from 1m data
'period': '30 days', # Monthly data
'total_levels': len(pivot_bounds.pivot_support_levels) + len(pivot_bounds.pivot_resistance_levels)
} if pivot_bounds else None
})
except Exception as e:
logger.error(f"Error fetching chart data: {e}")
return jsonify({
'success': False,
'error': {
'code': 'CHART_DATA_ERROR',
'message': str(e)
}
})
@self.server.route('/api/save-annotation', methods=['POST'])
def save_annotation():
"""Save a new annotation with full market context"""
try:
data = request.get_json()
# Capture market state at entry and exit times using data provider
entry_market_state = {}
exit_market_state = {}
if self.data_provider:
try:
# Parse timestamps
entry_time = datetime.fromisoformat(data['entry']['timestamp'].replace('Z', '+00:00'))
exit_time = datetime.fromisoformat(data['exit']['timestamp'].replace('Z', '+00:00'))
# Use the new data provider method to get market state at entry time
entry_market_state = self.data_provider.get_market_state_at_time(
symbol=data['symbol'],
timestamp=entry_time,
context_window_minutes=5
)
# Use the new data provider method to get market state at exit time
exit_market_state = self.data_provider.get_market_state_at_time(
symbol=data['symbol'],
timestamp=exit_time,
context_window_minutes=5
)
logger.info(f"Captured market state: {len(entry_market_state)} timeframes at entry, {len(exit_market_state)} at exit")
except Exception as e:
logger.error(f"Error capturing market state: {e}")
import traceback
traceback.print_exc()
# Create annotation with market context
annotation = self.annotation_manager.create_annotation(
entry_point=data['entry'],
exit_point=data['exit'],
symbol=data['symbol'],
timeframe=data['timeframe'],
entry_market_state=entry_market_state,
exit_market_state=exit_market_state
)
# Collect market snapshots for SQLite storage
market_snapshots = {}
if self.data_loader:
try:
# Get OHLCV data for all timeframes around the annotation time
entry_time = datetime.fromisoformat(data['entry']['timestamp'].replace('Z', '+00:00'))
exit_time = datetime.fromisoformat(data['exit']['timestamp'].replace('Z', '+00:00'))
# Get data from 5 minutes before entry to 5 minutes after exit
start_time = entry_time - timedelta(minutes=5)
end_time = exit_time + timedelta(minutes=5)
for timeframe in ['1s', '1m', '1h', '1d']:
df = self.data_loader.get_data(
symbol=data['symbol'],
timeframe=timeframe,
start_time=start_time,
end_time=end_time,
limit=1500
)
if df is not None and not df.empty:
market_snapshots[timeframe] = df
logger.info(f"Collected {len(market_snapshots)} timeframes for annotation storage")
except Exception as e:
logger.error(f"Error collecting market snapshots: {e}")
# Save annotation with market snapshots
self.annotation_manager.save_annotation(
annotation=annotation,
market_snapshots=market_snapshots
)
# Automatically generate test case with ±5min data
try:
test_case = self.annotation_manager.generate_test_case(
annotation,
data_provider=self.data_provider,
auto_save=True
)
# Log test case details
market_state = test_case.get('market_state', {})
timeframes_with_data = [k for k in market_state.keys() if k.startswith('ohlcv_')]
logger.info(f"Auto-generated test case: {test_case['test_case_id']}")
logger.info(f" Timeframes: {timeframes_with_data}")
for tf_key in timeframes_with_data:
candle_count = len(market_state[tf_key].get('timestamps', []))
logger.info(f" {tf_key}: {candle_count} candles")
if 'training_labels' in market_state:
logger.info(f" Training labels: {len(market_state['training_labels'].get('labels_1m', []))} labels")
except Exception as e:
logger.error(f"Failed to auto-generate test case: {e}")
import traceback
traceback.print_exc()
return jsonify({
'success': True,
'annotation': annotation.__dict__ if hasattr(annotation, '__dict__') else annotation
})
except Exception as e:
logger.error(f"Error saving annotation: {e}")
return jsonify({
'success': False,
'error': {
'code': 'SAVE_ANNOTATION_ERROR',
'message': str(e)
}
})
@self.server.route('/api/delete-annotation', methods=['POST'])
def delete_annotation():
"""Delete an annotation"""
try:
data = request.get_json()
annotation_id = data['annotation_id']
# Delete annotation and check if it was found
deleted = self.annotation_manager.delete_annotation(annotation_id)
if deleted:
return jsonify({
'success': True,
'message': 'Annotation deleted successfully'
})
else:
return jsonify({
'success': False,
'error': {
'code': 'ANNOTATION_NOT_FOUND',
'message': f'Annotation {annotation_id} not found'
}
})
except Exception as e:
logger.error(f"Error deleting annotation: {e}", exc_info=True)
return jsonify({
'success': False,
'error': {
'code': 'DELETE_ANNOTATION_ERROR',
'message': str(e)
}
})
@self.server.route('/api/clear-all-annotations', methods=['POST'])
def clear_all_annotations():
"""Clear all annotations"""
try:
data = request.get_json() or {}
symbol = data.get('symbol', None)
# Use the efficient clear_all_annotations method
deleted_count = self.annotation_manager.clear_all_annotations(symbol=symbol)
if deleted_count == 0:
return jsonify({
'success': True,
'deleted_count': 0,
'message': 'No annotations to clear'
})
logger.info(f"Cleared {deleted_count} annotations" + (f" for symbol {symbol}" if symbol else ""))
return jsonify({
'success': True,
'deleted_count': deleted_count,
'message': f'Cleared {deleted_count} annotations'
})
except Exception as e:
logger.error(f"Error clearing all annotations: {e}")
import traceback
logger.error(traceback.format_exc())
return jsonify({
'success': False,
'error': {
'code': 'CLEAR_ALL_ANNOTATIONS_ERROR',
'message': str(e)
}
})
@self.server.route('/api/refresh-data', methods=['POST'])
def refresh_data():
"""Refresh chart data from data provider"""
try:
data = request.get_json()
symbol = data.get('symbol', 'ETH/USDT')
timeframes = data.get('timeframes', ['1s', '1m', '1h', '1d'])
logger.info(f"Refreshing data for {symbol} with timeframes: {timeframes}")
# Force refresh data from data provider
chart_data = {}
if self.data_provider:
for timeframe in timeframes:
try:
# Force refresh by setting refresh=True
df = self.data_provider.get_historical_data(
symbol=symbol,
timeframe=timeframe,
limit=1000,
refresh=True
)
if df is not None and not df.empty:
# Get pivot markers for this timeframe
pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df)
chart_data[timeframe] = {
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': df['open'].tolist(),
'high': df['high'].tolist(),
'low': df['low'].tolist(),
'close': df['close'].tolist(),
'volume': df['volume'].tolist(),
'pivot_markers': pivot_markers # Optional: only present if pivots exist
}
logger.info(f"Refreshed {timeframe}: {len(df)} candles")
else:
logger.warning(f"No data available for {timeframe}")
except Exception as e:
logger.error(f"Error refreshing {timeframe} data: {e}")
# Get pivot bounds for the symbol
pivot_bounds = None
if self.data_provider:
try:
pivot_bounds = self.data_provider.get_pivot_bounds(symbol)
if pivot_bounds:
logger.info(f"Found pivot bounds for {symbol}: {len(pivot_bounds.pivot_support_levels)} support, {len(pivot_bounds.pivot_resistance_levels)} resistance")
except Exception as e:
logger.error(f"Error getting pivot bounds: {e}")
return jsonify({
'success': True,
'chart_data': chart_data,
'pivot_bounds': {
'support_levels': pivot_bounds.pivot_support_levels,
'resistance_levels': pivot_bounds.pivot_resistance_levels,
'price_range': {
'min': pivot_bounds.price_min,
'max': pivot_bounds.price_max
},
'volume_range': {
'min': pivot_bounds.volume_min,
'max': pivot_bounds.volume_max
},
'timeframe': '1m', # Pivot bounds are calculated from 1m data
'period': '30 days', # Monthly data
'total_levels': len(pivot_bounds.pivot_support_levels) + len(pivot_bounds.pivot_resistance_levels)
} if pivot_bounds else None,
'message': f'Refreshed data for {symbol}'
})
except Exception as e:
logger.error(f"Error refreshing data: {e}")
return jsonify({
'success': False,
'error': {
'code': 'REFRESH_DATA_ERROR',
'message': str(e)
}
})
@self.server.route('/api/generate-test-case', methods=['POST'])
def generate_test_case():
"""Generate test case from annotation"""
try:
data = request.get_json()
annotation_id = data['annotation_id']
# Get annotation
annotations = self.annotation_manager.get_annotations()
annotation = next((a for a in annotations
if (a.annotation_id if hasattr(a, 'annotation_id')
else a.get('annotation_id')) == annotation_id), None)
if not annotation:
return jsonify({
'success': False,
'error': {
'code': 'ANNOTATION_NOT_FOUND',
'message': 'Annotation not found'
}
})
# Generate test case with market context
test_case = self.annotation_manager.generate_test_case(
annotation,
data_provider=self.data_provider
)
return jsonify({
'success': True,
'test_case': test_case
})
except Exception as e:
logger.error(f"Error generating test case: {e}")
return jsonify({
'success': False,
'error': {
'code': 'GENERATE_TESTCASE_ERROR',
'message': str(e)
}
})
@self.server.route('/api/get-annotations', methods=['POST'])
def get_annotations_api():
"""Get annotations filtered by symbol"""
try:
data = request.get_json()
symbol = data.get('symbol', 'ETH/USDT')
# Get annotations for this symbol
annotations = self.annotation_manager.get_annotations(symbol=symbol)
# Convert to serializable format
annotations_data = []
for ann in annotations:
if hasattr(ann, '__dict__'):
ann_dict = ann.__dict__
else:
ann_dict = ann
annotations_data.append({
'annotation_id': ann_dict.get('annotation_id'),
'symbol': ann_dict.get('symbol'),
'timeframe': ann_dict.get('timeframe'),
'entry': ann_dict.get('entry'),
'exit': ann_dict.get('exit'),
'direction': ann_dict.get('direction'),
'profit_loss_pct': ann_dict.get('profit_loss_pct'),
'notes': ann_dict.get('notes', ''),
'created_at': ann_dict.get('created_at')
})
logger.info(f"Returning {len(annotations_data)} annotations for {symbol}")
return jsonify({
'success': True,
'annotations': annotations_data,
'symbol': symbol,
'count': len(annotations_data)
})
except Exception as e:
logger.error(f"Error getting annotations: {e}")
return jsonify({
'success': False,
'error': str(e)
})
@self.server.route('/api/export-annotations', methods=['POST'])
def export_annotations():
"""Export annotations to file"""
try:
data = request.get_json()
symbol = data.get('symbol')
format_type = data.get('format', 'json')
# Get annotations
annotations = self.annotation_manager.get_annotations(symbol=symbol)
# Export to file
output_path = self.annotation_manager.export_annotations(
annotations=annotations,
format_type=format_type
)
return send_file(output_path, as_attachment=True)
except Exception as e:
logger.error(f"Error exporting annotations: {e}")
return jsonify({
'success': False,
'error': {
'code': 'EXPORT_ERROR',
'message': str(e)
}
})
@self.server.route('/api/train-model', methods=['POST'])
def train_model():
"""Start model training with annotations"""
try:
if not self.training_adapter:
return jsonify({
'success': False,
'error': {
'code': 'TRAINING_UNAVAILABLE',
'message': 'Real training adapter not available'
}
})
data = request.get_json()
model_name = data['model_name']
annotation_ids = data.get('annotation_ids', [])
# CRITICAL: Get current symbol to filter annotations
current_symbol = data.get('symbol', 'ETH/USDT')
# Get primary timeframe for display (optional)
timeframe = data.get('timeframe', '1m')
# If no specific annotations provided, use all for current symbol
if not annotation_ids:
annotations = self.annotation_manager.get_annotations(symbol=current_symbol)
annotation_ids = [
a.annotation_id if hasattr(a, 'annotation_id') else a.get('annotation_id')
for a in annotations
]
logger.info(f"Using all {len(annotation_ids)} annotations for {current_symbol}")
# Load test cases from disk (they were auto-generated when annotations were saved)
# Filter by current symbol to avoid cross-symbol training
all_test_cases = self.annotation_manager.get_all_test_cases(symbol=current_symbol)
# Filter to selected annotations
test_cases = [
tc for tc in all_test_cases
if tc['test_case_id'].replace('annotation_', '') in annotation_ids
]
if not test_cases:
return jsonify({
'success': False,
'error': {
'code': 'NO_TEST_CASES',
'message': f'No test cases found for {len(annotation_ids)} annotations'
}
})
logger.info(f"Starting REAL training with {len(test_cases)} test cases ({len(annotation_ids)} annotations) for model {model_name} on {timeframe}")
# Start REAL training (NO SIMULATION!)
training_id = self.training_adapter.start_training(
model_name=model_name,
test_cases=test_cases,
annotation_count=len(annotation_ids),
timeframe=timeframe
)
return jsonify({
'success': True,
'training_id': training_id,
'test_cases_count': len(test_cases)
})
except Exception as e:
logger.error(f"Error starting training: {e}")
return jsonify({
'success': False,
'error': {
'code': 'TRAINING_ERROR',
'message': str(e)
}
})
@self.server.route('/api/training-progress', methods=['POST'])
def get_training_progress():
"""Get training progress"""
try:
if not self.training_adapter:
return jsonify({
'success': False,
'error': {
'code': 'TRAINING_UNAVAILABLE',
'message': 'Real training adapter not available'
}
})
data = request.get_json()
training_id = data['training_id']
progress = self.training_adapter.get_training_progress(training_id)
return jsonify({
'success': True,
'progress': progress
})
except Exception as e:
logger.error(f"Error getting training progress: {e}")
return jsonify({
'success': False,
'error': {
'code': 'PROGRESS_ERROR',
'message': str(e)
}
})
# Backtest API Endpoints
@self.server.route('/api/backtest', methods=['POST'])
def start_backtest():
"""Start backtest on visible chart data"""
try:
data = request.get_json()
model_name = data['model_name']
symbol = data['symbol']
timeframe = data['timeframe']
start_time = data.get('start_time')
end_time = data.get('end_time')
# Get the loaded model
if model_name not in self.loaded_models:
return jsonify({
'success': False,
'error': f'Model {model_name} not loaded. Please load it first.'
})
model = self.loaded_models[model_name]
# Generate backtest ID
backtest_id = str(uuid.uuid4())
# Start backtest in background
self.backtest_runner.start_backtest(
backtest_id=backtest_id,
model=model,
data_provider=self.data_provider,
symbol=symbol,
timeframe=timeframe,
orchestrator=self.orchestrator,
start_time=start_time,
end_time=end_time
)
# Get initial state
progress = self.backtest_runner.get_progress(backtest_id)
return jsonify({
'success': True,
'backtest_id': backtest_id,
'total_candles': progress.get('total_candles', 0)
})
except Exception as e:
logger.error(f"Error starting backtest: {e}", exc_info=True)
return jsonify({
'success': False,
'error': str(e)
})
@self.server.route('/api/backtest/progress/<backtest_id>', methods=['GET'])
def get_backtest_progress(backtest_id):
"""Get backtest progress"""
try:
progress = self.backtest_runner.get_progress(backtest_id)
return jsonify(progress)
except Exception as e:
logger.error(f"Error getting backtest progress: {e}")
return jsonify({
'success': False,
'error': str(e)
})
@self.server.route('/api/backtest/stop', methods=['POST'])
def stop_backtest():
"""Stop running backtest"""
try:
data = request.get_json()
backtest_id = data['backtest_id']
self.backtest_runner.stop_backtest(backtest_id)
return jsonify({
'success': True
})
except Exception as e:
logger.error(f"Error stopping backtest: {e}")
return jsonify({
'success': False,
'error': str(e)
})
@self.server.route('/api/active-training', methods=['GET'])
def get_active_training():
"""
Get the currently active training session (if any).
Allows the UI to resume tracking after a page reload or across multiple clients.
"""
try:
if not self.training_adapter:
return jsonify({
'success': False,
'active': False,
'error': {
'code': 'TRAINING_UNAVAILABLE',
'message': 'Real training adapter not available'
}
})
active_session = self.training_adapter.get_active_training_session()
if active_session:
return jsonify({
'success': True,
'active': True,
'session': active_session
})
else:
return jsonify({
'success': True,
'active': False
})
except Exception as e:
logger.error(f"Error getting active training: {e}")
return jsonify({
'success': False,
'active': False,
'error': {
'code': 'ACTIVE_TRAINING_ERROR',
'message': str(e)
}
})
# Live Training API Endpoints
@self.server.route('/api/live-training/start', methods=['POST'])
def start_live_training():
"""Start live inference and training mode"""
try:
if not self.orchestrator:
return jsonify({
'success': False,
'error': 'Orchestrator not available'
}), 500
if self.orchestrator.start_live_training():
return jsonify({
'success': True,
'status': 'started',
'message': 'Live training mode started'
})
else:
return jsonify({
'success': False,
'error': 'Failed to start live training'
}), 500
except Exception as e:
logger.error(f"Error starting live training: {e}")
return jsonify({
'success': False,
'error': str(e)
}), 500
@self.server.route('/api/live-training/stop', methods=['POST'])
def stop_live_training():
"""Stop live inference and training mode"""
try:
if not self.orchestrator:
return jsonify({
'success': False,
'error': 'Orchestrator not available'
}), 500
if self.orchestrator.stop_live_training():
return jsonify({
'success': True,
'status': 'stopped',
'message': 'Live training mode stopped'
})
else:
return jsonify({
'success': False,
'error': 'Failed to stop live training'
}), 500
except Exception as e:
logger.error(f"Error stopping live training: {e}")
return jsonify({
'success': False,
'error': str(e)
}), 500
@self.server.route('/api/live-training/status', methods=['GET'])
def get_live_training_status():
"""Get live training status and statistics"""
try:
if not self.orchestrator:
return jsonify({
'success': False,
'active': False,
'error': 'Orchestrator not available'
})
is_active = self.orchestrator.is_live_training_active()
stats = self.orchestrator.get_live_training_stats() if is_active else {}
return jsonify({
'success': True,
'active': is_active,
'stats': stats
})
except Exception as e:
logger.error(f"Error getting live training status: {e}")
return jsonify({
'success': False,
'active': False,
'error': str(e)
})
@self.server.route('/api/available-models', methods=['GET'])
def get_available_models():
"""Get list of available models with their load status"""
try:
# Ensure self.available_models is a list
if not isinstance(self.available_models, list):
logger.warning(f"self.available_models is not a list: {type(self.available_models)}. Resetting to default.")
self.available_models = ['Transformer', 'COB_RL', 'CNN', 'DQN']
# Ensure self.loaded_models exists (it's a dict)
if not hasattr(self, 'loaded_models'):
self.loaded_models = {}
# Build model state dict with checkpoint info
logger.info(f"Building model states for {len(self.available_models)} models: {self.available_models}")
logger.info(f"Currently loaded models: {list(self.loaded_models.keys())}")
model_states = []
for model_name in self.available_models:
# Check if model is in loaded_models dict
is_loaded = model_name in self.loaded_models and self.loaded_models[model_name] is not None
# Get checkpoint info (even for unloaded models)
checkpoint_info = None
# If loaded, get from orchestrator
if is_loaded and self.orchestrator:
checkpoint_attr = f"{model_name.lower()}_checkpoint_info"
if hasattr(self.orchestrator, checkpoint_attr):
cp_info = getattr(self.orchestrator, checkpoint_attr)
if cp_info and cp_info.get('status') == 'loaded':
checkpoint_info = {
'filename': cp_info.get('filename', 'unknown'),
'epoch': cp_info.get('epoch', 0),
'loss': cp_info.get('loss', 0.0),
'accuracy': cp_info.get('accuracy', 0.0),
'loaded_at': cp_info.get('loaded_at', ''),
'source': 'loaded'
}
# If not loaded, try to read best checkpoint from disk (filename parsing only)
if not checkpoint_info:
try:
cp_info = self._get_best_checkpoint_info(model_name)
if cp_info:
checkpoint_info = cp_info
checkpoint_info['source'] = 'disk'
except Exception as e:
logger.warning(f"Could not read checkpoint for {model_name}: {e}")
# Continue without checkpoint info - not critical
model_states.append({
'name': model_name,
'loaded': is_loaded,
'can_train': is_loaded,
'can_infer': is_loaded,
'checkpoint': checkpoint_info # Checkpoint metadata (loaded or from disk)
})
logger.info(f"Returning {len(model_states)} model states")
return jsonify({
'success': True,
'models': model_states,
'loaded_count': len(self.loaded_models),
'available_count': len(self.available_models)
})
except Exception as e:
logger.error(f"Error getting available models: {e}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
# Return a fallback list so the UI doesn't hang
return jsonify({
'success': True,
'models': [
{'name': 'Transformer', 'loaded': False, 'can_train': False, 'can_infer': False},
{'name': 'COB_RL', 'loaded': False, 'can_train': False, 'can_infer': False}
],
'loaded_count': 0,
'available_count': 2,
'error': str(e)
})
@self.server.route('/api/load-model', methods=['POST'])
def load_model():
"""Load a specific model on demand"""
try:
data = request.get_json()
model_name = data.get('model_name')
if not model_name:
return jsonify({
'success': False,
'error': 'model_name is required'
})
# Load the model
result = self._load_model_lazy(model_name)
return jsonify(result)
except Exception as e:
logger.error(f"Error in load_model endpoint: {e}")
return jsonify({
'success': False,
'error': str(e)
})
@self.server.route('/api/realtime-inference/start', methods=['POST'])
def start_realtime_inference():
"""Start real-time inference mode with configurable training strategy"""
try:
data = request.get_json()
model_name = data.get('model_name')
symbol = data.get('symbol', 'ETH/USDT')
timeframe = data.get('timeframe', '1m')
# New unified training_mode parameter
training_mode = data.get('training_mode', 'none') # 'none', 'every_candle', 'pivots_only', 'manual'
# Backward compatibility with old parameters
if 'enable_live_training' in data or 'train_every_candle' in data:
enable_live_training = data.get('enable_live_training', False)
train_every_candle = data.get('train_every_candle', False)
training_mode = 'every_candle' if train_every_candle else ('pivots_only' if enable_live_training else 'none')
if not self.training_adapter:
return jsonify({
'success': False,
'error': {
'code': 'TRAINING_UNAVAILABLE',
'message': 'Real training adapter not available'
}
})
# Set training mode on strategy manager
self.training_strategy.mode = training_mode
logger.info(f"Training strategy mode set to: {training_mode}")
# Start real-time inference - pass strategy manager for training decisions
inference_id = self.training_adapter.start_realtime_inference(
model_name=model_name,
symbol=symbol,
data_provider=self.data_provider,
enable_live_training=(training_mode != 'none'),
train_every_candle=(training_mode == 'every_candle'),
timeframe=timeframe,
training_strategy=self.training_strategy # Pass strategy manager
)
return jsonify({
'success': True,
'inference_id': inference_id,
'training_mode': training_mode,
'timeframe': timeframe
})
except Exception as e:
logger.error(f"Error starting real-time inference: {e}")
return jsonify({
'success': False,
'error': {
'code': 'INFERENCE_START_ERROR',
'message': str(e)
}
})
@self.server.route('/api/realtime-inference/stop', methods=['POST'])
def stop_realtime_inference():
"""Stop real-time inference mode"""
try:
data = request.get_json()
inference_id = data.get('inference_id')
if not self.training_adapter:
return jsonify({
'success': False,
'error': {
'code': 'TRAINING_UNAVAILABLE',
'message': 'Real training adapter not available'
}
})
self.training_adapter.stop_realtime_inference(inference_id)
return jsonify({
'success': True
})
except Exception as e:
logger.error(f"Error stopping real-time inference: {e}")
return jsonify({
'success': False,
'error': {
'code': 'INFERENCE_STOP_ERROR',
'message': str(e)
}
})
@self.server.route('/api/live-updates', methods=['POST'])
def get_live_updates():
"""Get live chart and prediction updates (polling endpoint)"""
try:
data = request.get_json()
symbol = data.get('symbol', 'ETH/USDT')
timeframe = data.get('timeframe', '1m')
response = {
'success': True,
'chart_update': None,
'prediction': None
}
# Get latest candle for the requested timeframe
if self.orchestrator and self.orchestrator.data_provider:
try:
# Get latest candle
ohlcv_data = self.orchestrator.data_provider.get_ohlcv_data(symbol, timeframe, limit=1)
if ohlcv_data and len(ohlcv_data) > 0:
latest_candle = ohlcv_data[-1]
response['chart_update'] = {
'symbol': symbol,
'timeframe': timeframe,
'candle': {
'timestamp': latest_candle[0],
'open': float(latest_candle[1]),
'high': float(latest_candle[2]),
'low': float(latest_candle[3]),
'close': float(latest_candle[4]),
'volume': float(latest_candle[5])
}
}
except Exception as e:
logger.debug(f"Error getting latest candle: {e}")
# Get latest model predictions
if self.orchestrator:
try:
# Get latest predictions from orchestrator
predictions = {}
# DQN predictions
if hasattr(self.orchestrator, 'recent_dqn_predictions') and symbol in self.orchestrator.recent_dqn_predictions:
dqn_preds = list(self.orchestrator.recent_dqn_predictions[symbol])
if dqn_preds:
predictions['dqn'] = dqn_preds[-1]
# CNN predictions
if hasattr(self.orchestrator, 'recent_cnn_predictions') and symbol in self.orchestrator.recent_cnn_predictions:
cnn_preds = list(self.orchestrator.recent_cnn_predictions[symbol])
if cnn_preds:
predictions['cnn'] = cnn_preds[-1]
# Transformer predictions with next_candles for ghost candles
# First check if there are stored predictions from the inference loop
if hasattr(self.orchestrator, 'recent_transformer_predictions') and symbol in self.orchestrator.recent_transformer_predictions:
transformer_preds = list(self.orchestrator.recent_transformer_predictions[symbol])
if transformer_preds:
# Convert any remaining tensors to Python types before JSON serialization
transformer_pred = transformer_preds[-1].copy()
predictions['transformer'] = self._serialize_prediction(transformer_pred)
if predictions:
response['prediction'] = predictions
except Exception as e:
logger.debug(f"Error getting predictions: {e}")
import traceback
logger.debug(traceback.format_exc())
return jsonify(response)
except Exception as e:
logger.error(f"Error in live updates: {e}")
return jsonify({
'success': False,
'error': str(e)
})
@self.server.route('/api/realtime-inference/signals', methods=['GET'])
def get_realtime_signals():
"""Get latest real-time inference signals"""
try:
if not self.training_adapter:
return jsonify({
'success': False,
'error': {
'code': 'TRAINING_UNAVAILABLE',
'message': 'Real training adapter not available'
}
})
signals = self.training_adapter.get_latest_signals()
# Get metrics from active inference sessions
metrics = {'accuracy': 0.0, 'loss': 0.0}
if hasattr(self.training_adapter, 'inference_sessions'):
for session in self.training_adapter.inference_sessions.values():
if 'metrics' in session:
metrics = session['metrics']
break
return jsonify({
'success': True,
'signals': signals,
'metrics': metrics
})
except Exception as e:
logger.error(f"Error getting signals: {e}")
return jsonify({
'success': False,
'error': {
'code': 'SIGNALS_ERROR',
'message': str(e)
}
})
@self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
def train_manual():
"""Manually trigger training on current candle with specified action"""
try:
data = request.get_json()
inference_id = data.get('inference_id')
action = data.get('action', 'HOLD')
if not self.training_adapter:
return jsonify({
'success': False,
'error': 'Training adapter not available'
})
# Get active inference session
if not hasattr(self.training_adapter, 'inference_sessions'):
return jsonify({
'success': False,
'error': 'No active inference sessions'
})
session = self.training_adapter.inference_sessions.get(inference_id)
if not session:
return jsonify({
'success': False,
'error': 'Inference session not found'
})
# Set pending action for training
session['pending_action'] = action
# Get session parameters
symbol = session.get('symbol', 'ETH/USDT')
timeframe = session.get('timeframe', '1m')
data_provider = session.get('data_provider')
# Call training method
train_result = self.training_adapter._train_on_new_candle(
session, symbol, timeframe, data_provider
)
if train_result.get('success'):
return jsonify({
'success': True,
'action': action,
'metrics': {
'loss': train_result.get('loss', 0.0),
'accuracy': train_result.get('accuracy', 0.0),
'training_steps': train_result.get('training_steps', 0)
}
})
else:
return jsonify({
'success': False,
'error': train_result.get('error', 'Training failed')
})
except Exception as e:
logger.error(f"Error in manual training: {e}")
return jsonify({
'success': False,
'error': str(e)
})
# WebSocket event handlers (if SocketIO is available)
if self.has_socketio:
self._setup_websocket_handlers()
def _serialize_prediction(self, prediction: Dict) -> Dict:
"""Convert PyTorch tensors in prediction dict to JSON-serializable Python types"""
try:
import torch
serialized = {}
for key, value in prediction.items():
if isinstance(value, torch.Tensor):
if value.numel() == 1: # Scalar tensor
serialized[key] = value.item()
else: # Multi-element tensor
serialized[key] = value.detach().cpu().tolist()
elif isinstance(value, dict):
serialized[key] = self._serialize_prediction(value) # Recursive
elif isinstance(value, (list, tuple)):
serialized[key] = [
v.item() if isinstance(v, torch.Tensor) and v.numel() == 1 else
(v.detach().cpu().tolist() if isinstance(v, torch.Tensor) else v)
for v in value
]
else:
serialized[key] = value
return serialized
except Exception as e:
logger.warning(f"Error serializing prediction: {e}")
# Fallback: return as-is (might fail JSON serialization but won't crash)
return prediction
def _setup_websocket_handlers(self):
"""Setup WebSocket event handlers for real-time updates"""
if not self.has_socketio:
return
@self.socketio.on('connect')
def handle_connect():
"""Handle client connection"""
logger.info(f"WebSocket client connected")
from flask_socketio import emit
emit('connection_response', {'status': 'connected', 'message': 'Connected to ANNOTATE live updates'})
@self.socketio.on('disconnect')
def handle_disconnect():
"""Handle client disconnection"""
logger.info(f"WebSocket client disconnected")
@self.socketio.on('subscribe_live_updates')
def handle_subscribe(data):
"""Subscribe to live chart and prediction updates"""
from flask_socketio import emit, join_room
symbol = data.get('symbol', 'ETH/USDT')
timeframe = data.get('timeframe', '1s')
room = f"{symbol}_{timeframe}"
join_room(room)
logger.info(f"Client subscribed to live updates: {room}")
emit('subscription_confirmed', {'room': room, 'symbol': symbol, 'timeframe': timeframe})
# Start live update thread if not already running
if not hasattr(self, '_live_update_thread') or not self._live_update_thread.is_alive():
self._start_live_update_thread()
@self.socketio.on('request_prediction')
def handle_prediction_request(data):
"""Handle manual prediction request"""
from flask_socketio import emit
try:
symbol = data.get('symbol', 'ETH/USDT')
timeframe = data.get('timeframe', '1s')
prediction_steps = data.get('prediction_steps', 1)
# Get prediction from model
prediction = self._get_live_prediction(symbol, timeframe, prediction_steps)
emit('prediction_update', prediction)
except Exception as e:
logger.error(f"Error handling prediction request: {e}")
emit('prediction_error', {'error': str(e)})
@self.socketio.on('prediction_accuracy')
def handle_prediction_accuracy(data):
"""
Handle validated prediction accuracy - trigger incremental training
This is called when the frontend validates a prediction against the actual candle.
We use the result to incrementally train the model for continuous improvement.
"""
from flask_socketio import emit
try:
timeframe = data.get('timeframe')
timestamp = data.get('timestamp')
predicted = data.get('predicted') # [O, H, L, C, V]
actual = data.get('actual') # [O, H, L, C]
errors = data.get('errors') # {open, high, low, close}
pct_errors = data.get('pctErrors')
direction_correct = data.get('directionCorrect')
accuracy = data.get('accuracy')
# Accuracy is formatted and compared numerically below, so require it as well
if not all([timeframe, timestamp, predicted, actual]) or accuracy is None:
logger.warning("Incomplete prediction accuracy data received")
return
logger.info(f"[{timeframe}] Prediction validated: {accuracy:.1f}% accuracy, direction: {direction_correct}")
if pct_errors:
logger.debug(f" Errors: O={pct_errors['open']:.2f}% H={pct_errors['high']:.2f}% L={pct_errors['low']:.2f}% C={pct_errors['close']:.2f}%")
# Trigger incremental training on this validated prediction
self._train_on_validated_prediction(
timeframe=timeframe,
timestamp=timestamp,
predicted=predicted,
actual=actual,
errors=errors,
direction_correct=direction_correct,
accuracy=accuracy
)
# Send confirmation back to frontend
emit('training_update', {
'status': 'training_triggered',
'timestamp': timestamp,
'accuracy': accuracy,
'message': 'Incremental training triggered on validated prediction'
})
except Exception as e:
logger.error(f"Error handling prediction accuracy: {e}", exc_info=True)
emit('training_error', {'error': str(e)})
def _start_live_update_thread(self):
"""Start background thread for live updates"""
import threading
def live_update_worker():
"""Background worker for live updates"""
import time
from flask_socketio import emit
logger.info("Live update thread started")
while True:
try:
# Get active rooms (symbol_timeframe combinations)
# For now, update all subscribed clients every second
# Get latest chart data
if self.data_provider:
for symbol in ['ETH/USDT', 'BTC/USDT']: # TODO: Get from active subscriptions
for timeframe in ['1s', '1m']:
room = f"{symbol}_{timeframe}"
# Get latest candles (need last 2 to determine confirmation status)
try:
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=2)
if candles and len(candles) > 0:
latest_candle = candles[-1]
# Determine if the candle is confirmed (closed): a candle counts as
# confirmed once the next candle has started, so if two candles were
# returned the earlier one is closed.
is_confirmed = len(candles) >= 2
# Format timestamp consistently
timestamp = latest_candle.get('timestamp')
if isinstance(timestamp, str):
# Already formatted
formatted_timestamp = timestamp
else:
# Convert to ISO string then format
from datetime import datetime
if isinstance(timestamp, datetime):
formatted_timestamp = timestamp.strftime('%Y-%m-%d %H:%M:%S')
else:
formatted_timestamp = str(timestamp)
# Emit chart update with full candle data
self.socketio.emit('chart_update', {
'symbol': symbol,
'timeframe': timeframe,
'candle': {
'timestamp': formatted_timestamp,
'open': float(latest_candle.get('open', 0)),
'high': float(latest_candle.get('high', 0)),
'low': float(latest_candle.get('low', 0)),
'close': float(latest_candle.get('close', 0)),
'volume': float(latest_candle.get('volume', 0))
},
'is_confirmed': is_confirmed, # True if this candle is closed/confirmed
'has_previous': len(candles) >= 2 # True if we have previous candle for validation
}, room=room)
# Get prediction if model is loaded
if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer'):
prediction = self._get_live_prediction(symbol, timeframe, 1)
if prediction:
self.socketio.emit('prediction_update', prediction, room=room)
except Exception as e:
logger.debug(f"Error getting data for {symbol} {timeframe}: {e}")
time.sleep(1) # Update every second
except Exception as e:
logger.error(f"Error in live update thread: {e}")
time.sleep(5) # Wait longer on error
self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True)
self._live_update_thread.start()
def _get_live_transformer_prediction(self, symbol: str = 'ETH/USDT'):
"""
Generate live transformer prediction with next_candles for ghost candle display
"""
try:
if not self.orchestrator:
logger.debug("No orchestrator - cannot generate predictions")
return None
if not hasattr(self.orchestrator, 'primary_transformer'):
logger.debug("Orchestrator has no primary_transformer - enable training first")
return None
transformer = self.orchestrator.primary_transformer
if not transformer:
logger.debug("primary_transformer is None - model not loaded yet")
return None
transformer.eval()
# Get recent market data
price_data_1s = self.data_provider.get_ohlcv(symbol, '1s', limit=200) if self.data_provider else None
price_data_1m = self.data_provider.get_ohlcv(symbol, '1m', limit=150) if self.data_provider else None
price_data_1h = self.data_provider.get_ohlcv(symbol, '1h', limit=24) if self.data_provider else None
price_data_1d = self.data_provider.get_ohlcv(symbol, '1d', limit=14) if self.data_provider else None
btc_data_1m = self.data_provider.get_ohlcv('BTC/USDT', '1m', limit=150) if self.data_provider else None
if not price_data_1m or len(price_data_1m) < 10:
return None
import torch
import numpy as np
device = next(transformer.parameters()).device
def ohlcv_to_tensor(data, limit=None):
if not data:
return None
data = data[-limit:] if limit and len(data) > limit else data
arr = np.array([[d['open'], d['high'], d['low'], d['close'], d['volume']] for d in data], dtype=np.float32)
return torch.from_numpy(arr).unsqueeze(0).to(device)
inputs = {
'price_data_1s': ohlcv_to_tensor(price_data_1s, 200),
'price_data_1m': ohlcv_to_tensor(price_data_1m, 150),
'price_data_1h': ohlcv_to_tensor(price_data_1h, 24),
'price_data_1d': ohlcv_to_tensor(price_data_1d, 14),
'btc_data_1m': ohlcv_to_tensor(btc_data_1m, 150)
}
# Forward pass
with torch.no_grad():
outputs = transformer(**inputs)
# Extract next_candles
next_candles = outputs.get('next_candles', {})
if not next_candles:
return None
# Convert to JSON-serializable format
predicted_candle = {}
for tf, candle_tensor in next_candles.items():
if candle_tensor is not None:
candle_values = candle_tensor.squeeze(0).cpu().numpy().tolist()
predicted_candle[tf] = candle_values
current_price = price_data_1m[-1]['close']
predicted_1m_close = predicted_candle.get('1m', [0,0,0,current_price,0])[3]
price_change = (predicted_1m_close - current_price) / current_price
if price_change > 0.001:
action = 'BUY'
elif price_change < -0.001:
action = 'SELL'
else:
action = 'HOLD'
confidence = 0.7
if 'confidence' in outputs:
conf_tensor = outputs['confidence']
confidence = float(conf_tensor.squeeze(0).cpu().numpy()[0])
prediction = {
'timestamp': datetime.now().isoformat(),
'symbol': symbol,
'action': action,
'confidence': confidence,
'predicted_price': predicted_1m_close,
'current_price': current_price,
'price_change': price_change,
'predicted_candle': predicted_candle, # This is what frontend needs!
'type': 'transformer_prediction'
}
# Store for tracking
self.orchestrator.store_transformer_prediction(symbol, prediction)
logger.debug(f"Generated transformer prediction with {len(predicted_candle)} timeframes for ghost candles")
return prediction
except Exception as e:
logger.error(f"Error generating live transformer prediction: {e}", exc_info=True)
return None
def _train_on_validated_prediction(self, timeframe: str, timestamp: str, predicted: list,
actual: list, errors: dict, direction_correct: bool, accuracy: float):
"""
Incrementally train model on validated prediction
This implements online learning where each validated prediction becomes
a training sample, with loss weighting based on prediction accuracy.
"""
try:
if not self.training_adapter:
logger.warning("Training adapter not available for incremental training")
return
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
logger.warning("Transformer model not available for incremental training")
return
# Get the transformer trainer
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if not trainer:
logger.warning("Transformer trainer not available")
return
# Calculate sample weight based on accuracy
# Low accuracy predictions get higher weight (we need to learn from mistakes)
# High accuracy predictions get lower weight (model already knows this)
if accuracy < 50:
sample_weight = 3.0 # Learn hard from bad predictions
elif accuracy < 70:
sample_weight = 2.0 # Moderate learning
elif accuracy < 85:
sample_weight = 1.0 # Normal learning
else:
sample_weight = 0.5 # Light touch-up for good predictions
# Also weight by direction correctness
if not direction_correct:
sample_weight *= 1.5 # Wrong direction is critical - learn more
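# Worked example of the weighting above: a 45%-accurate, wrong-direction
# prediction trains at 3.0 * 1.5 = 4.5x, while a 90%-accurate,
# correct-direction prediction trains at 0.5x.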
logger.info(f"[{timeframe}] Incremental training: accuracy={accuracy:.1f}%, weight={sample_weight:.1f}x")
# Create training sample from validated prediction
# We need to fetch the market state at that timestamp
symbol = 'ETH/USDT' # TODO: Get from active trading pair
training_sample = {
'symbol': symbol,
'timestamp': timestamp,
'predicted_candle': predicted, # [O, H, L, C, V]
'actual_candle': actual, # [O, H, L, C]
'errors': errors,
'accuracy': accuracy,
'direction_correct': direction_correct,
'sample_weight': sample_weight
}
# Get market state at that timestamp
try:
market_state = self._fetch_market_state_at_timestamp(symbol, timestamp, timeframe)
training_sample['market_state'] = market_state
except Exception as e:
logger.warning(f"Could not fetch market state: {e}")
return
# Convert to transformer batch format
batch = self.training_adapter._convert_prediction_to_batch(training_sample, timeframe)
if not batch:
logger.warning("Could not convert validated prediction to training batch")
return
# Train on this batch with sample weighting
import torch
with torch.enable_grad():
trainer.model.train()
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
if result:
loss = result.get('total_loss', 0)
candle_accuracy = result.get('candle_accuracy', 0)
logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
# Save checkpoint periodically (every 10 incremental steps)
if not hasattr(self, '_incremental_training_steps'):
self._incremental_training_steps = 0
self._incremental_training_steps += 1
if self._incremental_training_steps % 10 == 0:
logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps")
trainer.save_checkpoint(
filepath=None, # Auto-generate path
metadata={
'training_type': 'incremental_online',
'steps': self._incremental_training_steps,
'last_accuracy': accuracy
}
)
except Exception as e:
logger.error(f"Error in incremental training: {e}", exc_info=True)
def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
"""Fetch market state at a specific timestamp for training"""
try:
from datetime import datetime
import pandas as pd
# Parse timestamp
ts = pd.Timestamp(timestamp)
# Get historical data for multiple timeframes
market_state = {'timeframes': {}, 'secondary_timeframes': {}}
for tf in ['1s', '1m', '1h']:
try:
df = self.data_provider.get_historical_data(symbol, tf, limit=200)
if df is not None and not df.empty:
# Find data up to (but not including) the target timestamp
df_before = df[df.index < ts]
if not df_before.empty:
recent = df_before.tail(200)
market_state['timeframes'][tf] = {
'timestamps': recent.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': recent['open'].tolist(),
'high': recent['high'].tolist(),
'low': recent['low'].tolist(),
'close': recent['close'].tolist(),
'volume': recent['volume'].tolist()
}
except Exception as e:
logger.warning(f"Could not fetch {tf} data: {e}")
return market_state
except Exception as e:
logger.error(f"Error fetching market state: {e}")
return {}
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
"""
Get live prediction from model using trainer inference
Caches inference data (inputs/outputs) for later training when the actual candle arrives.
This allows us to:
1. Compare predicted vs actual candle values
2. Calculate loss
3. Do backpropagation with correct outputs
Returns:
Dict with prediction results including predicted_candle for ghost candle display
"""
try:
if not self.orchestrator:
return None
# Get trainer from orchestrator
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if not trainer or not trainer.model:
logger.debug("No transformer trainer available for live prediction")
return None
# Get market data using training adapter's method (reuses existing logic)
if not hasattr(self.training_adapter, '_get_realtime_market_data'):
logger.warning("Training adapter missing _get_realtime_market_data method")
return None
market_data, norm_params = self.training_adapter._get_realtime_market_data(symbol, self.data_provider)
if not market_data:
logger.debug(f"No market data available for {symbol} {timeframe}")
return None
# Make prediction with model
import torch
timestamp = datetime.now(timezone.utc)
with torch.no_grad():
trainer.model.eval()
outputs = trainer.model(**market_data)
# Extract action prediction
action_probs = outputs.get('action_probs')
if action_probs is None:
logger.debug("No action_probs in model output")
return None
action_idx = torch.argmax(action_probs, dim=-1).item()
confidence = action_probs[0, action_idx].item()
# Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
actions = ['HOLD', 'BUY', 'SELL']
action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
# Extract predicted candles and denormalize
predicted_candles_raw = {}
if 'next_candles' in outputs:
for tf, tensor in outputs['next_candles'].items():
predicted_candles_raw[tf] = tensor.detach().cpu().numpy().tolist()
# Denormalize predicted candles
predicted_candles_denorm = {}
if predicted_candles_raw and norm_params:
for tf, raw_candle in predicted_candles_raw.items():
if tf in norm_params:
params = norm_params[tf]
price_min = params['price_min']
price_max = params['price_max']
vol_min = params['volume_min']
vol_max = params['volume_max']
# raw_candle is [1, 5] list
candle_values = raw_candle[0]
denorm_candle = [
candle_values[0] * (price_max - price_min) + price_min, # Open
candle_values[1] * (price_max - price_min) + price_min, # High
candle_values[2] * (price_max - price_min) + price_min, # Low
candle_values[3] * (price_max - price_min) + price_min, # Close
candle_values[4] * (vol_max - vol_min) + vol_min # Volume
]
predicted_candles_denorm[tf] = denorm_candle
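# Worked example of the min-max denormalization above: with price_min=2000
# and price_max=2500, a normalized close of 0.62 maps back to
# 0.62 * (2500 - 2000) + 2000 = 2310.0.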
# Get predicted price from candle close
predicted_price = None
if timeframe in predicted_candles_denorm:
predicted_price = predicted_candles_denorm[timeframe][3] # Close
elif '1m' in predicted_candles_denorm:
predicted_price = predicted_candles_denorm['1m'][3]
elif '1s' in predicted_candles_denorm:
predicted_price = predicted_candles_denorm['1s'][3]
# CACHE inference data for later training
# Store inputs, outputs, and normalization params so we can train when actual candle arrives
if symbol not in self.prediction_cache:
self.prediction_cache[symbol] = {}
if timeframe not in self.prediction_cache[symbol]:
self.prediction_cache[symbol][timeframe] = []
# Store cached inference data (convert tensors to CPU for storage)
cached_data = {
'timestamp': timestamp,
'symbol': symbol,
'timeframe': timeframe,
'model_inputs': {k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
for k, v in market_data.items()},
'model_outputs': {k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
for k, v in outputs.items()},
'normalization_params': norm_params,
'predicted_candle': predicted_candles_denorm.get(timeframe),
'prediction_steps': prediction_steps
}
self.prediction_cache[symbol][timeframe].append(cached_data)
# Keep only last 100 predictions per symbol/timeframe to prevent memory bloat
if len(self.prediction_cache[symbol][timeframe]) > 100:
self.prediction_cache[symbol][timeframe] = self.prediction_cache[symbol][timeframe][-100:]
logger.debug(f"Cached prediction for {symbol} {timeframe} @ {timestamp.isoformat()}")
# Return prediction result (same format as before for compatibility)
return {
'symbol': symbol,
'timeframe': timeframe,
'timestamp': timestamp.isoformat(),
'action': action,
'confidence': confidence,
'predicted_price': predicted_price,
'predicted_candle': predicted_candles_denorm,
'prediction_steps': prediction_steps
}
except Exception as e:
logger.error(f"Error getting live prediction: {e}")
import traceback
logger.debug(traceback.format_exc())
return None
def get_cached_predictions_for_training(self, symbol: str, timeframe: str, actual_candle_timestamp) -> List[Dict]:
"""
Retrieve cached predictions that match a specific candle timestamp for training
When an actual candle arrives, we can:
1. Find cached predictions made before this candle
2. Compare predicted vs actual candle values
3. Calculate loss and do backpropagation
Args:
symbol: Trading symbol
timeframe: Timeframe
actual_candle_timestamp: Timestamp of the actual candle that just arrived
Returns:
List of cached prediction dicts that should be trained on
"""
try:
if symbol not in self.prediction_cache:
return []
if timeframe not in self.prediction_cache[symbol]:
return []
# Find predictions made before this candle timestamp
# Predictions should be for candles that have now completed
matching_predictions = []
actual_time = actual_candle_timestamp if isinstance(actual_candle_timestamp, datetime) else datetime.fromisoformat(str(actual_candle_timestamp).replace('Z', '+00:00'))
for cached_pred in self.prediction_cache[symbol][timeframe]:
pred_time = cached_pred['timestamp']
if isinstance(pred_time, str):
pred_time = datetime.fromisoformat(pred_time.replace('Z', '+00:00'))
# Prediction should be for a candle that comes after the prediction time
# We match predictions that were made before the actual candle closed
if pred_time < actual_time:
matching_predictions.append(cached_pred)
return matching_predictions
except Exception as e:
logger.error(f"Error getting cached predictions for training: {e}")
return []
def clear_old_cached_predictions(self, symbol: str, timeframe: str, before_timestamp: datetime):
"""
Clear cached predictions older than a certain timestamp
Useful for cleaning up old predictions that are no longer needed
"""
try:
if symbol not in self.prediction_cache:
return
if timeframe not in self.prediction_cache[symbol]:
return
self.prediction_cache[symbol][timeframe] = [
pred for pred in self.prediction_cache[symbol][timeframe]
if pred['timestamp'] >= before_timestamp
]
except Exception as e:
logger.debug(f"Error clearing old cached predictions: {e}")
def run(self, host='127.0.0.1', port=8051, debug=False):
"""Run the application"""
logger.info(f"Starting Annotation Dashboard on http://{host}:{port}")
if self.has_socketio:
logger.info("Running with WebSocket support (SocketIO)")
self.socketio.run(self.server, host=host, port=port, debug=debug, allow_unsafe_werkzeug=True)
else:
logger.warning("Running without WebSocket support - install flask-socketio for live updates")
self.server.run(host=host, port=port, debug=debug)
def main():
"""Main entry point"""
logger.info("=" * 80)
logger.info("ANNOTATE Application Starting")
logger.info(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
logger.info("=" * 80)
# Print logging channel configuration
from utils.logging_config import print_channel_status
print_channel_status()
dashboard = AnnotationDashboard()
dashboard.run(debug=True)
if __name__ == '__main__':
main()