live updates wip
This commit is contained in:
@@ -131,6 +131,23 @@ class AnnotationDashboard:
|
||||
static_folder='static'
|
||||
)
|
||||
|
||||
# Initialize SocketIO for WebSocket support
|
||||
try:
|
||||
from flask_socketio import SocketIO, emit
|
||||
self.socketio = SocketIO(
|
||||
self.server,
|
||||
cors_allowed_origins="*",
|
||||
async_mode='threading',
|
||||
logger=False,
|
||||
engineio_logger=False
|
||||
)
|
||||
self.has_socketio = True
|
||||
logger.info("✅ SocketIO initialized for real-time updates")
|
||||
except ImportError:
|
||||
self.socketio = None
|
||||
self.has_socketio = False
|
||||
logger.warning("⚠️ flask-socketio not installed - live updates will use polling")
|
||||
|
||||
# Suppress werkzeug request logs (reduce noise from polling endpoints)
|
||||
werkzeug_logger = logging.getLogger('werkzeug')
|
||||
werkzeug_logger.setLevel(logging.WARNING) # Only show warnings and errors, not INFO
|
||||
@@ -556,8 +573,13 @@ class AnnotationDashboard:
|
||||
def index():
|
||||
"""Main dashboard page - loads existing annotations"""
|
||||
try:
|
||||
# Get all existing annotations
|
||||
annotations = self.annotation_manager.get_annotations()
|
||||
# Get symbols and timeframes from config
|
||||
symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT'])
|
||||
timeframes = self.config.get('timeframes', ['1s', '1m', '1h', '1d'])
|
||||
current_symbol = symbols[0] if symbols else 'ETH/USDT'
|
||||
|
||||
# Get annotations filtered by current symbol
|
||||
annotations = self.annotation_manager.get_annotations(symbol=current_symbol)
|
||||
|
||||
# Convert to serializable format
|
||||
annotations_data = []
|
||||
@@ -580,15 +602,11 @@ class AnnotationDashboard:
|
||||
'created_at': ann_dict.get('created_at')
|
||||
})
|
||||
|
||||
logger.info(f"Loading dashboard with {len(annotations_data)} existing annotations")
|
||||
|
||||
# Get symbols and timeframes from config
|
||||
symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT'])
|
||||
timeframes = self.config.get('timeframes', ['1s', '1m', '1h', '1d'])
|
||||
logger.info(f"Loading dashboard with {len(annotations_data)} annotations for {current_symbol}")
|
||||
|
||||
# Prepare template data
|
||||
template_data = {
|
||||
'current_symbol': symbols[0] if symbols else 'ETH/USDT', # Use first symbol as default
|
||||
'current_symbol': current_symbol,
|
||||
'symbols': symbols,
|
||||
'timeframes': timeframes,
|
||||
'annotations': annotations_data
|
||||
@@ -1112,6 +1130,52 @@ class AnnotationDashboard:
|
||||
}
|
||||
})
|
||||
|
||||
@self.server.route('/api/get-annotations', methods=['POST'])
|
||||
def get_annotations_api():
|
||||
"""Get annotations filtered by symbol"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
symbol = data.get('symbol', 'ETH/USDT')
|
||||
|
||||
# Get annotations for this symbol
|
||||
annotations = self.annotation_manager.get_annotations(symbol=symbol)
|
||||
|
||||
# Convert to serializable format
|
||||
annotations_data = []
|
||||
for ann in annotations:
|
||||
if hasattr(ann, '__dict__'):
|
||||
ann_dict = ann.__dict__
|
||||
else:
|
||||
ann_dict = ann
|
||||
|
||||
annotations_data.append({
|
||||
'annotation_id': ann_dict.get('annotation_id'),
|
||||
'symbol': ann_dict.get('symbol'),
|
||||
'timeframe': ann_dict.get('timeframe'),
|
||||
'entry': ann_dict.get('entry'),
|
||||
'exit': ann_dict.get('exit'),
|
||||
'direction': ann_dict.get('direction'),
|
||||
'profit_loss_pct': ann_dict.get('profit_loss_pct'),
|
||||
'notes': ann_dict.get('notes', ''),
|
||||
'created_at': ann_dict.get('created_at')
|
||||
})
|
||||
|
||||
logger.info(f"Returning {len(annotations_data)} annotations for {symbol}")
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'annotations': annotations_data,
|
||||
'symbol': symbol,
|
||||
'count': len(annotations_data)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting annotations: {e}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@self.server.route('/api/export-annotations', methods=['POST'])
|
||||
def export_annotations():
|
||||
"""Export annotations to file"""
|
||||
@@ -1158,17 +1222,20 @@ class AnnotationDashboard:
|
||||
model_name = data['model_name']
|
||||
annotation_ids = data.get('annotation_ids', [])
|
||||
|
||||
# If no specific annotations provided, use all
|
||||
# CRITICAL: Get current symbol to filter annotations
|
||||
current_symbol = data.get('symbol', 'ETH/USDT')
|
||||
|
||||
# If no specific annotations provided, use all for current symbol
|
||||
if not annotation_ids:
|
||||
annotations = self.annotation_manager.get_annotations()
|
||||
annotations = self.annotation_manager.get_annotations(symbol=current_symbol)
|
||||
annotation_ids = [
|
||||
a.annotation_id if hasattr(a, 'annotation_id') else a.get('annotation_id')
|
||||
for a in annotations
|
||||
]
|
||||
logger.info(f"Using all {len(annotation_ids)} annotations for {current_symbol}")
|
||||
|
||||
# Load test cases from disk (they were auto-generated when annotations were saved)
|
||||
# CRITICAL: Filter by current symbol to avoid cross-symbol training
|
||||
current_symbol = data.get('symbol', 'ETH/USDT')
|
||||
# Filter by current symbol to avoid cross-symbol training
|
||||
all_test_cases = self.annotation_manager.get_all_test_cases(symbol=current_symbol)
|
||||
|
||||
# Filter to selected annotations
|
||||
@@ -1566,11 +1633,160 @@ class AnnotationDashboard:
|
||||
'message': str(e)
|
||||
}
|
||||
})
|
||||
|
||||
# WebSocket event handlers (if SocketIO is available)
|
||||
if self.has_socketio:
|
||||
self._setup_websocket_handlers()
|
||||
|
||||
def _setup_websocket_handlers(self):
|
||||
"""Setup WebSocket event handlers for real-time updates"""
|
||||
if not self.has_socketio:
|
||||
return
|
||||
|
||||
@self.socketio.on('connect')
|
||||
def handle_connect():
|
||||
"""Handle client connection"""
|
||||
logger.info(f"WebSocket client connected")
|
||||
from flask_socketio import emit
|
||||
emit('connection_response', {'status': 'connected', 'message': 'Connected to ANNOTATE live updates'})
|
||||
|
||||
@self.socketio.on('disconnect')
|
||||
def handle_disconnect():
|
||||
"""Handle client disconnection"""
|
||||
logger.info(f"WebSocket client disconnected")
|
||||
|
||||
@self.socketio.on('subscribe_live_updates')
|
||||
def handle_subscribe(data):
|
||||
"""Subscribe to live chart and prediction updates"""
|
||||
from flask_socketio import emit, join_room
|
||||
symbol = data.get('symbol', 'ETH/USDT')
|
||||
timeframe = data.get('timeframe', '1s')
|
||||
room = f"{symbol}_{timeframe}"
|
||||
|
||||
join_room(room)
|
||||
logger.info(f"Client subscribed to live updates: {room}")
|
||||
emit('subscription_confirmed', {'room': room, 'symbol': symbol, 'timeframe': timeframe})
|
||||
|
||||
# Start live update thread if not already running
|
||||
if not hasattr(self, '_live_update_thread') or not self._live_update_thread.is_alive():
|
||||
self._start_live_update_thread()
|
||||
|
||||
@self.socketio.on('request_prediction')
|
||||
def handle_prediction_request(data):
|
||||
"""Handle manual prediction request"""
|
||||
from flask_socketio import emit
|
||||
try:
|
||||
symbol = data.get('symbol', 'ETH/USDT')
|
||||
timeframe = data.get('timeframe', '1s')
|
||||
prediction_steps = data.get('prediction_steps', 1)
|
||||
|
||||
# Get prediction from model
|
||||
prediction = self._get_live_prediction(symbol, timeframe, prediction_steps)
|
||||
|
||||
emit('prediction_update', prediction)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling prediction request: {e}")
|
||||
emit('prediction_error', {'error': str(e)})
|
||||
|
||||
def _start_live_update_thread(self):
|
||||
"""Start background thread for live updates"""
|
||||
import threading
|
||||
|
||||
def live_update_worker():
|
||||
"""Background worker for live updates"""
|
||||
import time
|
||||
from flask_socketio import emit
|
||||
|
||||
logger.info("Live update thread started")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Get active rooms (symbol_timeframe combinations)
|
||||
# For now, update all subscribed clients every second
|
||||
|
||||
# Get latest chart data
|
||||
if self.data_provider:
|
||||
for symbol in ['ETH/USDT', 'BTC/USDT']: # TODO: Get from active subscriptions
|
||||
for timeframe in ['1s', '1m']:
|
||||
room = f"{symbol}_{timeframe}"
|
||||
|
||||
# Get latest candle
|
||||
try:
|
||||
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=1)
|
||||
if candles and len(candles) > 0:
|
||||
latest_candle = candles[-1]
|
||||
|
||||
# Emit chart update
|
||||
self.socketio.emit('chart_update', {
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'candle': {
|
||||
'timestamp': latest_candle.get('timestamp'),
|
||||
'open': latest_candle.get('open'),
|
||||
'high': latest_candle.get('high'),
|
||||
'low': latest_candle.get('low'),
|
||||
'close': latest_candle.get('close'),
|
||||
'volume': latest_candle.get('volume')
|
||||
}
|
||||
}, room=room)
|
||||
|
||||
# Get prediction if model is loaded
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer'):
|
||||
prediction = self._get_live_prediction(symbol, timeframe, 1)
|
||||
if prediction:
|
||||
self.socketio.emit('prediction_update', prediction, room=room)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting data for {symbol} {timeframe}: {e}")
|
||||
|
||||
time.sleep(1) # Update every second
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in live update thread: {e}")
|
||||
time.sleep(5) # Wait longer on error
|
||||
|
||||
self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True)
|
||||
self._live_update_thread.start()
|
||||
|
||||
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
|
||||
"""Get live prediction from model"""
|
||||
try:
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
|
||||
return None
|
||||
|
||||
# Get recent candles for prediction
|
||||
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=200)
|
||||
if not candles or len(candles) < 200:
|
||||
return None
|
||||
|
||||
# TODO: Implement actual prediction logic
|
||||
# For now, return placeholder
|
||||
import random
|
||||
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'action': random.choice(['BUY', 'SELL', 'HOLD']),
|
||||
'confidence': random.uniform(0.6, 0.95),
|
||||
'predicted_price': candles[-1].get('close', 0) * (1 + random.uniform(-0.01, 0.01)),
|
||||
'prediction_steps': prediction_steps
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting live prediction: {e}")
|
||||
return None
|
||||
|
||||
def run(self, host='127.0.0.1', port=8051, debug=False):
|
||||
"""Run the application"""
|
||||
logger.info(f"Starting Annotation Dashboard on http://{host}:{port}")
|
||||
self.server.run(host=host, port=port, debug=debug)
|
||||
|
||||
if self.has_socketio:
|
||||
logger.info("✅ Running with WebSocket support (SocketIO)")
|
||||
self.socketio.run(self.server, host=host, port=port, debug=debug, allow_unsafe_werkzeug=True)
|
||||
else:
|
||||
logger.warning("⚠️ Running without WebSocket support - install flask-socketio for live updates")
|
||||
self.server.run(host=host, port=port, debug=debug)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
Reference in New Issue
Block a user