Dobromir Popov
2025-12-10 00:45:41 +02:00
parent c21d8cbea1
commit fadfa8c741
5 changed files with 256 additions and 117 deletions


@@ -163,8 +163,12 @@ class RealTrainingAdapter:
        # CRITICAL: Training lock to prevent concurrent model access
        # Multiple threads (batch training + per-candle training) can corrupt
        # the computation graph if they access the model simultaneously
+       # Use RLock (reentrant lock) to allow same thread to acquire multiple times
        import threading
-       self._training_lock = threading.Lock()
+       self._training_lock = threading.RLock()
+       # Track which thread currently holds the training lock (for debugging)
+       self._training_lock_holder = None

        # Use orchestrator's inference training coordinator (if available)
        # This reduces duplication and centralizes coordination logic
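Note: the switch from `Lock` to `RLock` matters because the same thread can re-enter code that takes the training lock again; a plain `Lock` would deadlock on itself in that case. A minimal standalone sketch of the difference (illustrative names, not project code):

```python
import threading

lock = threading.RLock()  # with a plain threading.Lock() the nested acquire below would hang

def outer():
    with lock:      # first acquisition
        inner()     # same thread calls back into locked code

def inner():
    with lock:      # second acquisition by the same thread succeeds with RLock
        print("reentrant acquisition succeeded")

outer()
```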
@@ -4142,7 +4146,16 @@ class RealTrainingAdapter:
            # CRITICAL: Acquire training lock to prevent concurrent model access
            # This prevents "inplace operation" errors when batch training runs simultaneously
            import torch
-           with self._training_lock:
+           import threading
+
+           # Try to acquire lock with timeout to prevent deadlock
+           lock_acquired = self._training_lock.acquire(timeout=5.0)
+           if not lock_acquired:
+               logger.warning("Could not acquire training lock within 5 seconds - skipping this training step")
+               return
+
+           try:
+               self._training_lock_holder = threading.current_thread().name
                with torch.enable_grad():
                    trainer.model.train()
                    result = trainer.train_step(batch, accumulate_gradients=False)
@@ -4193,6 +4206,10 @@ class RealTrainingAdapter:
                            improved=improved
                        )
                        self.realtime_training_metrics['last_checkpoint_step'] = self.realtime_training_metrics['total_steps']
+           finally:
+               # CRITICAL: Always release the lock, even if an exception occurs
+               self._training_lock_holder = None
+               self._training_lock.release()
        except Exception as e:
            logger.warning(f"Error training transformer on sample: {e}")
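Note: the commit uses an explicit acquire-with-timeout followed by `try`/`finally` release, as shown above. The same behaviour can be packaged as a small context manager; a sketch under assumed names (`acquire_or_skip` and the 5-second timeout are illustrative, not part of the commit):

```python
import threading
from contextlib import contextmanager

@contextmanager
def acquire_or_skip(lock: threading.RLock, timeout: float = 5.0):
    """Yield True if the lock was acquired within `timeout`, else False.

    The lock is always released on exit if it was acquired.
    """
    acquired = lock.acquire(timeout=timeout)
    try:
        yield acquired
    finally:
        if acquired:
            lock.release()

# Usage sketch
training_lock = threading.RLock()
with acquire_or_skip(training_lock) as got_lock:
    if not got_lock:
        print("lock busy - skipping this training step")
    else:
        print("lock held - run the training step")
```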


@@ -2626,6 +2626,95 @@ class AnnotationDashboard:
                    'error': str(e)
                })

+       @self.server.route('/api/live-updates-batch', methods=['POST'])
+       def get_live_updates_batch():
+           """Get live chart and prediction updates for multiple timeframes (optimized batch endpoint)"""
+           try:
+               data = request.get_json() or {}
+               symbol = data.get('symbol', 'ETH/USDT')
+               timeframes = data.get('timeframes', ['1m'])
+
+               response = {
+                   'success': True,
+                   'chart_updates': {},  # Dict of timeframe -> chart_update
+                   'prediction': None    # Single prediction for all timeframes
+               }
+
+               # Get latest candle for each requested timeframe
+               if self.data_loader:
+                   for timeframe in timeframes:
+                       try:
+                           df = self.data_loader.get_data(symbol, timeframe, limit=2, direction='latest')
+                           if df is not None and not df.empty:
+                               latest_candle = df.iloc[-1]
+
+                               # Format timestamp as ISO string
+                               timestamp = latest_candle.name
+                               if hasattr(timestamp, 'isoformat'):
+                                   if timestamp.tzinfo is not None:
+                                       timestamp_str = timestamp.astimezone(timezone.utc).isoformat()
+                                   else:
+                                       timestamp_str = timestamp.isoformat() + 'Z'
+                               else:
+                                   timestamp_str = str(timestamp)
+
+                               is_confirmed = len(df) >= 2
+
+                               response['chart_updates'][timeframe] = {
+                                   'symbol': symbol,
+                                   'timeframe': timeframe,
+                                   'candle': {
+                                       'timestamp': timestamp_str,
+                                       'open': float(latest_candle['open']),
+                                       'high': float(latest_candle['high']),
+                                       'low': float(latest_candle['low']),
+                                       'close': float(latest_candle['close']),
+                                       'volume': float(latest_candle['volume'])
+                                   },
+                                   'is_confirmed': is_confirmed
+                               }
+                       except Exception as e:
+                           logger.debug(f"Error getting candle for {timeframe}: {e}")
+
+               # Get latest model predictions (same for all timeframes)
+               if self.orchestrator:
+                   try:
+                       predictions = {}
+
+                       # DQN predictions
+                       if hasattr(self.orchestrator, 'recent_dqn_predictions') and symbol in self.orchestrator.recent_dqn_predictions:
+                           dqn_preds = list(self.orchestrator.recent_dqn_predictions[symbol])
+                           if dqn_preds:
+                               predictions['dqn'] = dqn_preds[-1]
+
+                       # CNN predictions
+                       if hasattr(self.orchestrator, 'recent_cnn_predictions') and symbol in self.orchestrator.recent_cnn_predictions:
+                           cnn_preds = list(self.orchestrator.recent_cnn_predictions[symbol])
+                           if cnn_preds:
+                               predictions['cnn'] = cnn_preds[-1]
+
+                       # Transformer predictions
+                       if hasattr(self.orchestrator, 'recent_transformer_predictions') and symbol in self.orchestrator.recent_transformer_predictions:
+                           transformer_preds = list(self.orchestrator.recent_transformer_predictions[symbol])
+                           if transformer_preds:
+                               transformer_pred = transformer_preds[-1].copy()
+                               predictions['transformer'] = self._serialize_prediction(transformer_pred)
+
+                       if predictions:
+                           response['prediction'] = predictions
+                   except Exception as e:
+                       logger.debug(f"Error getting predictions: {e}")
+
+               return jsonify(response)
+
+           except Exception as e:
+               logger.error(f"Error in batch live updates: {e}")
+               return jsonify({
+                   'success': False,
+                   'error': str(e)
+               })
+
        @self.server.route('/api/realtime-inference/signals', methods=['GET'])
        def get_realtime_signals():
            """Get latest real-time inference signals"""

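For reference, the new batch endpoint takes one symbol plus a list of timeframes and returns per-timeframe candles alongside a single prediction payload. A hedged client sketch (the host/port are assumptions for illustration; the endpoint path and response keys come from the commit):

```python
import requests

# Assumed local dev server; adjust host/port for your deployment
resp = requests.post(
    "http://localhost:8050/api/live-updates-batch",
    json={"symbol": "ETH/USDT", "timeframes": ["1s", "1m", "1h"]},
    timeout=5,
)
data = resp.json()

if data.get("success"):
    # chart_updates maps timeframe -> {symbol, timeframe, candle, is_confirmed}
    for timeframe, update in data.get("chart_updates", {}).items():
        candle = update["candle"]
        print(timeframe, candle["timestamp"], candle["close"])

    # prediction (if present) is a dict like {"dqn": ..., "cnn": ..., "transformer": ...}
    prediction = data.get("prediction")
    if prediction:
        print("models reporting:", list(prediction.keys()))
```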

@@ -3059,30 +3059,37 @@ class ChartManager {
        let targetPrice = currentPrice;

-       // CRITICAL FIX: Check if price_delta is normalized (< 1.0) or real price change
-       if (trendVector.price_delta !== undefined && trendVector.price_delta !== null) {
-           const priceDelta = parseFloat(trendVector.price_delta);
-
-           // If price_delta is very small (< 1.0), it's likely normalized - scale it
-           if (Math.abs(priceDelta) < 1.0) {
-               // Normalized value - treat as percentage of current price
-               targetPrice = currentPrice * (1 + priceDelta);
-           } else {
-               // Real price delta - add directly
-               targetPrice = currentPrice + priceDelta;
-           }
-       } else {
-           // Fallback: Use direction and steepness
-           const direction = trendVector.direction === 'up' ? 1 :
-                             (trendVector.direction === 'down' ? -1 : 0);
-           const steepness = parseFloat(trendVector.steepness) || 0; // 0 to 1
-
-           // Estimate price change based on steepness (max 1% move per projection period)
-           const maxChange = 0.01 * currentPrice;
-           const projectedChange = maxChange * steepness * direction;
-           targetPrice = currentPrice + projectedChange;
-       }
+       // CRITICAL FIX: Use calculated_direction and calculated_steepness from trend_vector
+       // The price_delta in trend_vector is the pivot range, not the predicted change
+       // We should use direction and steepness to estimate the trend
+       const direction = parseFloat(trendVector.calculated_direction) || 0; // -1, 0, or 1
+       const steepness = parseFloat(trendVector.calculated_steepness) || 0;
+
+       // Steepness is in price units, but we need to scale it reasonably
+       // If steepness is > 100, it's likely in absolute price units (too large)
+       // Scale it down to a reasonable percentage move
+       let priceChange = 0;
+       if (steepness > 0) {
+           // If steepness is large (> 10), treat it as absolute price change but cap it
+           if (steepness > 10) {
+               // Cap at 2% of current price
+               const maxChange = 0.02 * currentPrice;
+               priceChange = Math.min(steepness, maxChange) * direction;
+           } else {
+               // Small steepness - use as percentage
+               priceChange = (steepness / 100) * currentPrice * direction;
+           }
+       } else {
+           // Fallback: Use angle if available
+           const angle = parseFloat(trendVector.calculated_angle) || 0;
+           // Angle is in radians, convert to price change
+           // Small angle = small change, large angle = large change
+           priceChange = Math.tan(angle) * currentPrice * 0.01; // Scale down
+       }
+
+       targetPrice = currentPrice + priceChange;

        // Sanity check: Don't let target price go to 0 or negative
        if (targetPrice <= 0 || !isFinite(targetPrice)) {
            console.warn('Invalid target price calculated:', targetPrice, 'using current price instead');
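The new projection reduces to a small piece of arithmetic: large steepness values are treated as absolute price moves capped at 2% of the current price, small values as a percentage move, and the fallback derives the move from tan(angle) scaled down. A Python transcription of the same arithmetic (the function name is illustrative; the real code is the JavaScript above):

```python
import math

def project_target_price(current_price: float, direction: float,
                         steepness: float, angle: float = 0.0) -> float:
    """Mirror of the chart's trend-vector projection, for illustration only."""
    if steepness > 0:
        if steepness > 10:
            # Treat as absolute price change, capped at 2% of current price
            price_change = min(steepness, 0.02 * current_price) * direction
        else:
            # Small steepness: interpret as a percentage move
            price_change = (steepness / 100) * current_price * direction
    else:
        # Fallback: derive the move from the trend angle (radians), scaled down
        price_change = math.tan(angle) * current_price * 0.01

    target = current_price + price_change
    # Sanity check: never return a non-positive or non-finite price
    if target <= 0 or not math.isfinite(target):
        return current_price
    return target

# Example: 3000 USDT, upward trend, steepness 250 -> capped at +2%
print(project_target_price(3000.0, 1.0, 250.0))  # 3060.0
```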


@@ -57,28 +57,42 @@ class LiveUpdatesPolling {
    }

    _poll() {
-       // Poll each subscription
+       // OPTIMIZATION: Batch all subscriptions into a single API call
+       // Group by symbol to reduce API calls from 4 to 1
+       const symbolGroups = {};
        this.subscriptions.forEach(sub => {
-           fetch('/api/live-updates', {
+           if (!symbolGroups[sub.symbol]) {
+               symbolGroups[sub.symbol] = [];
+           }
+           symbolGroups[sub.symbol].push(sub.timeframe);
+       });
+
+       // Make one call per symbol with all timeframes
+       Object.entries(symbolGroups).forEach(([symbol, timeframes]) => {
+           fetch('/api/live-updates-batch', {
                method: 'POST',
                headers: { 'Content-Type': 'application/json' },
                body: JSON.stringify({
-                   symbol: sub.symbol,
-                   timeframe: sub.timeframe
+                   symbol: symbol,
+                   timeframes: timeframes
                })
            })
            .then(response => response.json())
            .then(data => {
                if (data.success) {
-                   // Handle chart update (even if null, predictions should still be processed)
-                   if (data.chart_update && this.onChartUpdate) {
-                       this.onChartUpdate(data.chart_update);
+                   // Handle chart updates for each timeframe
+                   if (data.chart_updates && this.onChartUpdate) {
+                       // chart_updates is an object: { '1s': {...}, '1m': {...}, ... }
+                       Object.entries(data.chart_updates).forEach(([timeframe, update]) => {
+                           if (update) {
+                               this.onChartUpdate(update);
+                           }
+                       });
                    }

-                   // CRITICAL FIX: Handle prediction update properly
-                   // data.prediction is already in format { transformer: {...}, dqn: {...}, cnn: {...} }
+                   // Handle prediction update (single prediction for all timeframes)
+                   // data.prediction is in format { transformer: {...}, dqn: {...}, cnn: {...} }
                    if (data.prediction && this.onPredictionUpdate) {
-                       // Log prediction data for debugging
                        console.log('[Live Updates] Received prediction data:', {
                            has_transformer: !!data.prediction.transformer,
                            has_dqn: !!data.prediction.dqn,
@@ -88,10 +102,7 @@ class LiveUpdatesPolling {
                            has_predicted_candle: !!data.prediction.transformer?.predicted_candle
                        });

-                       // Pass the prediction object directly (it's already in the correct format)
                        this.onPredictionUpdate(data.prediction);
-                   } else if (!data.prediction) {
-                       console.debug('[Live Updates] No prediction data in response');
                    }
                } else {
                    console.debug('[Live Updates] Response not successful:', data);
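The client-side change is essentially a group-by: instead of one request per (symbol, timeframe) subscription, subscriptions are bucketed by symbol and each bucket becomes a single batch request. The same bucketing in a few lines of Python, with made-up data, purely to show the call-count reduction (the real implementation is the JavaScript above):

```python
from collections import defaultdict

subscriptions = [
    {"symbol": "ETH/USDT", "timeframe": "1s"},
    {"symbol": "ETH/USDT", "timeframe": "1m"},
    {"symbol": "ETH/USDT", "timeframe": "1h"},
    {"symbol": "BTC/USDT", "timeframe": "1m"},
]

# Bucket timeframes per symbol: 4 subscriptions -> 2 batch requests
symbol_groups = defaultdict(list)
for sub in subscriptions:
    symbol_groups[sub["symbol"]].append(sub["timeframe"])

for symbol, timeframes in symbol_groups.items():
    payload = {"symbol": symbol, "timeframes": timeframes}
    print("POST /api/live-updates-batch", payload)
```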


@@ -38,6 +38,10 @@ class DuckDBStorage:
        # Connect to DuckDB
        self.conn = duckdb.connect(str(self.db_path))

+       # CRITICAL: DuckDB connections are NOT thread-safe
+       # All database operations must be serialized with this lock
+       self._conn_lock = threading.RLock()  # Use RLock to allow reentrant calls from same thread
+
        # Batch logging for compact output
        self._batch_buffer = []  # List of (symbol, timeframe, count, total) tuples
        self._batch_lock = threading.Lock()
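Because a single duckdb connection object is shared across threads, every call that touches it is wrapped in this lock in the hunks below. A minimal standalone sketch of the pattern (in-memory database and made-up table, purely for illustration):

```python
import threading
import duckdb

conn = duckdb.connect(":memory:")   # shared connection, not thread-safe by itself
conn_lock = threading.RLock()       # serialize all access to it

conn.execute("CREATE TABLE ticks (symbol VARCHAR, price DOUBLE)")

def insert_tick(symbol: str, price: float) -> None:
    # Every operation on the shared connection goes through the lock
    with conn_lock:
        conn.execute("INSERT INTO ticks VALUES (?, ?)", [symbol, price])

threads = [threading.Thread(target=insert_tick, args=("ETH/USDT", 3000.0 + i))
           for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

with conn_lock:
    print(conn.execute("SELECT COUNT(*) FROM ticks").fetchone()[0])  # 8
```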
@@ -54,9 +58,10 @@ class DuckDBStorage:
    def _init_schema(self):
        """Initialize database schema - all data in DuckDB tables"""
+       # CRITICAL: Schema initialization must be serialized
+       with self._conn_lock:
            # Create OHLCV data table - stores ALL candles
            self.conn.execute("""
                CREATE SEQUENCE IF NOT EXISTS ohlcv_id_seq START 1
            """)
            self.conn.execute("""
@@ -207,34 +212,36 @@ class DuckDBStorage:
            # Insert data directly into DuckDB (ignore duplicates)
            # Note: id column is auto-generated, so we don't include it
            # Using INSERT OR IGNORE for better DuckDB compatibility
+           # CRITICAL: All DuckDB operations must be serialized with lock
+           with self._conn_lock:
                self.conn.execute("""
                    INSERT OR IGNORE INTO ohlcv_data (symbol, timeframe, timestamp, open, high, low, close, volume, created_at)
                    SELECT symbol, timeframe, timestamp, open, high, low, close, volume, created_at
                    FROM df_insert
                """)

                # Update metadata
                result = self.conn.execute("""
                    SELECT
                        MIN(timestamp) as first_ts,
                        MAX(timestamp) as last_ts,
                        COUNT(*) as count
                    FROM ohlcv_data
                    WHERE symbol = ? AND timeframe = ?
                """, (symbol, timeframe)).fetchone()

                # Handle case where no data exists yet
                if result is None or result[0] is None:
                    first_ts, last_ts, count = 0, 0, 0
                else:
                    first_ts, last_ts, count = result

                now_ts = int(datetime.now().timestamp() * 1000)
                self.conn.execute("""
                    INSERT OR REPLACE INTO cache_metadata
                    (symbol, timeframe, parquet_path, first_timestamp, last_timestamp, candle_count, last_update)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (symbol, timeframe, '', first_ts, last_ts, count, now_ts))

            # Add to batch buffer instead of logging immediately
            with self._batch_lock:
@@ -303,8 +310,9 @@ class DuckDBStorage:
            if limit:
                query += f" LIMIT {limit}"

-           # Execute query
+           # Execute query with thread-safe lock
+           with self._conn_lock:
                df = self.conn.execute(query, params).df()

            if df.empty:
                return None
@@ -341,7 +349,8 @@ class DuckDBStorage:
                WHERE symbol = ? AND timeframe = ?
            """

+           with self._conn_lock:
                result = self.conn.execute(query, [symbol, timeframe]).fetchone()

            if result and result[0] is not None:
                last_timestamp = pd.to_datetime(result[0], unit='ms', utc=True)
@@ -385,7 +394,8 @@ class DuckDBStorage:
                limit
            ]

+           with self._conn_lock:
                df = self.conn.execute(query, params).df()

            if df.empty:
                return None
@@ -449,14 +459,15 @@ class DuckDBStorage:
            df_copy.to_parquet(parquet_file, index=False, compression='snappy')

            # Store annotation metadata in DuckDB
+           with self._conn_lock:
                self.conn.execute("""
                    INSERT OR REPLACE INTO annotations
                    (annotation_id, symbol, timeframe, direction,
                     entry_timestamp, entry_price, exit_timestamp, exit_price,
                     profit_loss_pct, notes, created_at, market_context,
                     model_features, pivot_data, parquet_path)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    annotation_id,
                    annotation_data.get('symbol'),
                    annotation_data.get('timeframe'),
@@ -495,15 +506,16 @@
        """
        try:
            # Get annotation metadata
+           with self._conn_lock:
                result = self.conn.execute("""
                    SELECT * FROM annotations WHERE annotation_id = ?
                """, (annotation_id,)).fetchone()

                if not result:
                    return None

                # Parse annotation data
                columns = [desc[0] for desc in self.conn.description]
                annotation = dict(zip(columns, result))

            # Parse JSON fields
@@ -520,11 +532,12 @@ class DuckDBStorage:
            timeframe = parquet_file.stem

            # Query parquet directly with DuckDB
+           with self._conn_lock:
                df = self.conn.execute(f"""
                    SELECT timestamp, open, high, low, close, volume
                    FROM read_parquet('{parquet_file}')
                    ORDER BY timestamp
                """).df()

            if not df.empty:
                df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
@@ -550,12 +563,13 @@ class DuckDBStorage:
            DataFrame with results
        """
        try:
+           with self._conn_lock:
                if params:
                    result = self.conn.execute(query, params)
                else:
                    result = self.conn.execute(query)

                return result.df()
        except Exception as e:
            logger.error(f"Error executing query: {e}")
@@ -564,26 +578,27 @@ class DuckDBStorage:
    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics"""
        try:
+           with self._conn_lock:
                # Get OHLCV stats
                ohlcv_stats = self.conn.execute("""
                    SELECT symbol, timeframe, candle_count, first_timestamp, last_timestamp
                    FROM cache_metadata
                    ORDER BY symbol, timeframe
                """).df()

                if not ohlcv_stats.empty:
                    ohlcv_stats['first_timestamp'] = pd.to_datetime(ohlcv_stats['first_timestamp'], unit='ms')
                    ohlcv_stats['last_timestamp'] = pd.to_datetime(ohlcv_stats['last_timestamp'], unit='ms')

                # Get annotation count
                annotation_count = self.conn.execute("""
                    SELECT COUNT(*) as count FROM annotations
                """).fetchone()[0]

                # Get total candles
                total_candles = self.conn.execute("""
                    SELECT SUM(candle_count) as total FROM cache_metadata
                """).fetchone()[0] or 0

            return {
                'ohlcv_stats': ohlcv_stats.to_dict('records') if not ohlcv_stats.empty else [],