Files
gogo2/core/duckdb_storage.py
2025-11-06 16:26:38 +02:00

607 lines
23 KiB
Python

"""
DuckDB Storage - unified storage for OHLCV candles and trade annotations.

All tabular data lives in a single embedded DuckDB database file:
- ``ohlcv_data``: every stored candle, unique per (symbol, timeframe, timestamp)
- ``annotations``: annotation metadata plus JSON market-context / model-feature /
  pivot fields
- ``cache_metadata``: per (symbol, timeframe) candle count and first/last timestamp

Market snapshots attached to an annotation are written as Parquet files (one per
timeframe) and read back directly with DuckDB's ``read_parquet``.

Timestamps are stored as Unix epoch milliseconds (BIGINT columns).

This replaces the dual SQLite + Parquet system with a single unified solution.
"""
import duckdb
import logging
import pandas as pd
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
import json
import threading
logger = logging.getLogger(__name__)
class DuckDBStorage:
    """Unified storage using DuckDB with native Parquet support.

    OHLCV candles, annotation metadata and cache metadata are kept in DuckDB
    tables of one database file; the market snapshots attached to an annotation
    are written as one Parquet file per timeframe and read back with DuckDB's
    ``read_parquet``. Store calls are summarised in a compact, batched log line.
    """

    def __init__(self, db_path: str = "cache/trading_data.duckdb"):
        """Open (or create) the DuckDB database and make sure the schema exists.

        Args:
            db_path: Path of the DuckDB database file. Its parent directory is
                created if missing; annotation Parquet snapshots are stored in an
                ``annotation_snapshots`` directory next to it.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

        # Parquet storage directory (only for annotation snapshots)
        self.parquet_dir = self.db_path.parent / "annotation_snapshots"
        self.parquet_dir.mkdir(parents=True, exist_ok=True)

        # Connect to DuckDB
        self.conn = duckdb.connect(str(self.db_path))

        # Batch logging for compact output
        self._batch_buffer = []  # List of (symbol, timeframe, count, total) tuples
        self._batch_lock = threading.Lock()
        self._batch_flush_timer = None
        self._batch_flush_delay = 0.5  # Flush after 0.5 seconds of inactivity
        self._batch_timer_lock = threading.Lock()
        self._flush_in_progress = False

        # Initialize schema; don't leak the open connection if that fails
        try:
            self._init_schema()
        except Exception:
            self.conn.close()
            raise

        logger.info(f"DuckDB storage initialized: {self.db_path}")
        logger.info(f"Annotation snapshots: {self.parquet_dir}")

    def _init_schema(self):
        """Create the sequence, tables and indexes if they do not exist yet."""
        # Create OHLCV data table - stores ALL candles
        self.conn.execute("""
            CREATE SEQUENCE IF NOT EXISTS ohlcv_id_seq START 1
        """)
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS ohlcv_data (
                id INTEGER PRIMARY KEY DEFAULT nextval('ohlcv_id_seq'),
                symbol VARCHAR NOT NULL,
                timeframe VARCHAR NOT NULL,
                timestamp BIGINT NOT NULL,
                open DOUBLE NOT NULL,
                high DOUBLE NOT NULL,
                low DOUBLE NOT NULL,
                close DOUBLE NOT NULL,
                volume DOUBLE NOT NULL,
                created_at BIGINT NOT NULL,
                UNIQUE(symbol, timeframe, timestamp)
            )
        """)

        # Create indexes for fast queries
        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_ohlcv_symbol_timeframe
            ON ohlcv_data(symbol, timeframe)
        """)
        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_ohlcv_timestamp
            ON ohlcv_data(timestamp)
        """)

        # Create annotations table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS annotations (
                annotation_id VARCHAR PRIMARY KEY,
                symbol VARCHAR NOT NULL,
                timeframe VARCHAR NOT NULL,
                direction VARCHAR NOT NULL,
                entry_timestamp BIGINT NOT NULL,
                entry_price DOUBLE NOT NULL,
                exit_timestamp BIGINT NOT NULL,
                exit_price DOUBLE NOT NULL,
                profit_loss_pct DOUBLE NOT NULL,
                notes TEXT,
                created_at BIGINT NOT NULL,
                market_context JSON,
                model_features JSON,
                pivot_data JSON,
                parquet_path VARCHAR
            )
        """)

        # Create cache metadata table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS cache_metadata (
                symbol VARCHAR NOT NULL,
                timeframe VARCHAR NOT NULL,
                parquet_path VARCHAR,
                first_timestamp BIGINT NOT NULL,
                last_timestamp BIGINT NOT NULL,
                candle_count INTEGER NOT NULL,
                last_update BIGINT NOT NULL,
                PRIMARY KEY (symbol, timeframe)
            )
        """)

        logger.info("DuckDB schema initialized (all data in tables)")

    @staticmethod
    def _to_unix_ms(values: pd.Series) -> pd.Series:
        """Convert a datetime Series to Unix epoch milliseconds.

        Non-datetime Series are returned unchanged (they are assumed to hold
        epoch milliseconds already). Timezone-aware values are converted to UTC;
        naive values are treated as UTC. The explicit cast to nanosecond
        resolution keeps ``// 10**6`` correct under pandas 2.x, where datetime
        columns may carry ``[s]``, ``[ms]`` or ``[us]`` resolution.
        """
        if not pd.api.types.is_datetime64_any_dtype(values):
            return values
        if getattr(values.dt, "tz", None) is not None:
            values = values.dt.tz_convert("UTC").dt.tz_localize(None)
        return values.astype("datetime64[ns]").astype("int64") // 10**6

    @staticmethod
    def _load_json(value: Any) -> Any:
        """Decode a JSON column value; NULL/empty values decode to ``{}``."""
        if value is None or (isinstance(value, str) and value == ""):
            return {}
        if isinstance(value, (dict, list)):
            return value
        return json.loads(value)

    @staticmethod
    def _sql_string_literal(text: Any) -> str:
        """Quote ``str(text)`` as a SQL string literal (embedded quotes doubled)."""
        return "'" + str(text).replace("'", "''") + "'"

    @staticmethod
    def _parse_iso_timestamp(value: Any) -> Any:
        """Parse an ISO-8601 string (a trailing 'Z' means UTC); other values pass through."""
        if isinstance(value, str):
            return datetime.fromisoformat(value.replace('Z', '+00:00'))
        return value

    def _schedule_batch_flush(self):
        """(Re)start the flush timer so the batch log is emitted after a quiet period.

        Each call cancels the pending timer, so the flush happens
        ``_batch_flush_delay`` seconds after the most recent call.
        """
        with self._batch_timer_lock:
            if self._batch_flush_timer:
                self._batch_flush_timer.cancel()
            self._batch_flush_timer = threading.Timer(self._batch_flush_delay, self._flush_batch_log)
            self._batch_flush_timer.daemon = True
            self._batch_flush_timer.start()

    def _flush_batch_log(self):
        """Log all buffered store events as one compact line and clear the buffer.

        Invoked by the ``threading.Timer`` started in ``_schedule_batch_flush``
        and synchronously from ``close()``.
        """
        with self._batch_lock:
            if not self._batch_buffer or self._flush_in_progress:
                return
            self._flush_in_progress = True
            try:
                # Group by symbol for better readability
                symbol_groups: Dict[str, List[Tuple[str, int, int]]] = {}
                for symbol, timeframe, count, total in self._batch_buffer:
                    symbol_groups.setdefault(symbol, []).append((timeframe, count, total))

                # Build compact log message
                parts = []
                for symbol in sorted(symbol_groups):
                    symbol_parts = [
                        f"{timeframe}({count}, total: {total})"
                        for timeframe, count, total in sorted(symbol_groups[symbol])
                    ]
                    parts.append(f"{symbol}: {', '.join(symbol_parts)}")

                log_msg = "Stored candles batch: " + " | ".join(parts)
                logger.info(log_msg)
            finally:
                # Always reset, so a failure while logging cannot leave the flag
                # stuck (which would block every later flush and grow the buffer)
                self._batch_buffer.clear()
                self._flush_in_progress = False

        # Forget the timer only if it is the one running this flush; a newer
        # timer scheduled in the meantime must remain cancellable by close().
        with self._batch_timer_lock:
            if self._batch_flush_timer is threading.current_thread():
                self._batch_flush_timer = None

    def store_ohlcv_data(self, symbol: str, timeframe: str, df: pd.DataFrame) -> int:
        """
        Store OHLCV candles in ``ohlcv_data`` and refresh ``cache_metadata``.

        Rows whose (symbol, timeframe, timestamp) already exist are skipped
        (``ON CONFLICT DO NOTHING``). The metadata row is recomputed from the
        whole table for this symbol/timeframe after the insert.

        Args:
            symbol: Trading symbol
            timeframe: Timeframe
            df: DataFrame with 'open', 'high', 'low', 'close', 'volume' columns and
                either a 'timestamp' column or a timestamp index. Timestamps may be
                datetimes (naive treated as UTC) or integers already in epoch ms.

        Returns:
            Number of rows submitted (``len(df)``), which includes rows skipped as
            duplicates; 0 for empty input or on error.
        """
        if df is None or df.empty:
            return 0

        try:
            df_copy = df.copy()

            # Ensure timestamp column, expressed in Unix milliseconds
            if 'timestamp' not in df_copy.columns:
                df_copy['timestamp'] = df_copy.index
            df_copy['timestamp'] = self._to_unix_ms(df_copy['timestamp'])

            # Add metadata
            df_copy['symbol'] = symbol
            df_copy['timeframe'] = timeframe
            df_copy['created_at'] = int(datetime.now().timestamp() * 1000)

            columns = ['symbol', 'timeframe', 'timestamp', 'open', 'high', 'low', 'close', 'volume', 'created_at']
            # DuckDB resolves "FROM df_insert" below to this local DataFrame
            # (replacement scan), so the variable name must match the SQL.
            df_insert = df_copy[columns]

            # id is auto-generated by the sequence, so it is not inserted
            self.conn.execute("""
                INSERT INTO ohlcv_data (symbol, timeframe, timestamp, open, high, low, close, volume, created_at)
                SELECT symbol, timeframe, timestamp, open, high, low, close, volume, created_at
                FROM df_insert
                ON CONFLICT DO NOTHING
            """)

            # Update metadata from the full table contents
            first_ts, last_ts, count = self.conn.execute("""
                SELECT
                    MIN(timestamp) as first_ts,
                    MAX(timestamp) as last_ts,
                    COUNT(*) as count
                FROM ohlcv_data
                WHERE symbol = ? AND timeframe = ?
            """, (symbol, timeframe)).fetchone()

            now_ts = int(datetime.now().timestamp() * 1000)
            self.conn.execute("""
                INSERT OR REPLACE INTO cache_metadata
                (symbol, timeframe, parquet_path, first_timestamp, last_timestamp, candle_count, last_update)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            """, (symbol, timeframe, '', first_ts, last_ts, count, now_ts))

            # Add to batch buffer instead of logging immediately
            with self._batch_lock:
                self._batch_buffer.append((symbol, timeframe, len(df), count))
                self._schedule_batch_flush()

            return len(df)

        except Exception as e:
            logger.error(f"Error storing OHLCV data: {e}")
            import traceback
            traceback.print_exc()
            return 0

    def get_ohlcv_data(self, symbol: str, timeframe: str,
                       start_time: Optional[datetime] = None,
                       end_time: Optional[datetime] = None,
                       limit: Optional[int] = None,
                       direction: str = 'latest') -> Optional[pd.DataFrame]:
        """
        Query OHLCV data directly from DuckDB table

        Filter datetimes are converted with ``datetime.timestamp()``, which
        interprets naive values as local time; pass timezone-aware datetimes to
        avoid ambiguity.

        Args:
            symbol: Trading symbol
            timeframe: Timeframe
            start_time: Start time filter
            end_time: End time filter
            limit: Maximum number of candles (falsy = no limit). Applied after
                ordering: 'latest'/'before' keep the newest matching candles,
                'after' keeps the oldest ones after start_time.
            direction: 'latest' (most recent), 'before' (older data), 'after' (newer data)

        Returns:
            DataFrame indexed by UTC 'timestamp', sorted ascending; None if no
            rows match or the query fails.
        """
        try:
            query = """
                SELECT timestamp, open, high, low, close, volume
                FROM ohlcv_data
                WHERE symbol = ? AND timeframe = ?
            """
            params: List[Any] = [symbol, timeframe]

            if direction == 'before' and end_time:
                # Older data: candles strictly BEFORE end_time
                query += " AND timestamp < ?"
                params.append(int(end_time.timestamp() * 1000))
                query += " ORDER BY timestamp DESC"
            elif direction == 'after' and start_time:
                # Newer data: candles strictly AFTER start_time
                query += " AND timestamp > ?"
                params.append(int(start_time.timestamp() * 1000))
                query += " ORDER BY timestamp ASC"
            else:
                # Default: most recent data within the (optional) range
                if start_time:
                    query += " AND timestamp >= ?"
                    params.append(int(start_time.timestamp() * 1000))
                if end_time:
                    query += " AND timestamp <= ?"
                    params.append(int(end_time.timestamp() * 1000))
                query += " ORDER BY timestamp DESC"

            if limit:
                # Bound parameter (as in get_ohlcv_data_since_timestamp) instead of
                # string interpolation, so a non-integer limit cannot alter the SQL
                query += " LIMIT ?"
                params.append(int(limit))

            df = self.conn.execute(query, params).df()
            if df.empty:
                return None

            df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
            df = df.set_index('timestamp').sort_index()

            logger.debug(f"Retrieved {len(df)} candles for {symbol} {timeframe} from DuckDB (direction={direction})")
            return df

        except Exception as e:
            logger.error(f"Error retrieving OHLCV data: {e}")
            import traceback
            traceback.print_exc()
            return None

    def get_last_timestamp(self, symbol: str, timeframe: str) -> Optional[datetime]:
        """
        Get the newest stored candle timestamp for a symbol/timeframe.

        Args:
            symbol: Trading symbol
            timeframe: Timeframe

        Returns:
            Timezone-aware UTC ``pd.Timestamp``, or None if no data exists or the
            query fails
        """
        try:
            query = """
                SELECT MAX(timestamp) as last_timestamp
                FROM ohlcv_data
                WHERE symbol = ? AND timeframe = ?
            """
            result = self.conn.execute(query, [symbol, timeframe]).fetchone()
            if result and result[0] is not None:
                last_timestamp = pd.to_datetime(result[0], unit='ms', utc=True)
                logger.debug(f"Last timestamp for {symbol} {timeframe}: {last_timestamp}")
                return last_timestamp
            return None

        except Exception as e:
            logger.error(f"Error getting last timestamp for {symbol} {timeframe}: {e}")
            return None

    def get_ohlcv_data_since_timestamp(self, symbol: str, timeframe: str,
                                       since_timestamp: datetime,
                                       limit: int = 1500) -> Optional[pd.DataFrame]:
        """
        Get candles strictly newer than a timestamp, oldest first, capped at limit.

        Args:
            symbol: Trading symbol
            timeframe: Timeframe
            since_timestamp: Only candles with a later timestamp are returned
            limit: Maximum number of candles (default 1500)

        Returns:
            DataFrame indexed by UTC 'timestamp' in ascending order, or None if no
            rows match or the query fails
        """
        try:
            query = """
                SELECT timestamp, open, high, low, close, volume
                FROM ohlcv_data
                WHERE symbol = ? AND timeframe = ? AND timestamp > ?
                ORDER BY timestamp ASC
                LIMIT ?
            """
            params = [
                symbol,
                timeframe,
                int(since_timestamp.timestamp() * 1000),
                limit
            ]

            df = self.conn.execute(query, params).df()
            if df.empty:
                return None

            df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
            df = df.set_index('timestamp')

            logger.debug(f"Retrieved {len(df)} candles for {symbol} {timeframe} since {since_timestamp}")
            return df

        except Exception as e:
            logger.error(f"Error retrieving OHLCV data since timestamp: {e}")
            return None

    def store_annotation(self, annotation_id: str, annotation_data: Dict[str, Any],
                         market_snapshots: Dict[str, pd.DataFrame],
                         model_predictions: Optional[List[Dict]] = None) -> bool:
        """
        Store annotation metadata in DuckDB and its market snapshots as Parquet

        Each non-empty snapshot is written to
        ``<parquet_dir>/annotations/<annotation_id>/<timeframe>.parquet`` with
        timestamps in Unix milliseconds; the metadata row is upserted with
        ``INSERT OR REPLACE``. Entry/exit timestamps are validated before any
        file is written, so a bad annotation leaves no orphaned Parquet files.

        Args:
            annotation_id: Unique annotation ID (also used as a directory name)
            annotation_data: Annotation metadata. Keys read: 'symbol', 'timeframe',
                'direction', 'entry' / 'exit' (dicts with 'timestamp' and 'price';
                timestamps as ISO strings or datetimes), 'profit_loss_pct', 'notes',
                'entry_market_state', 'model_features', 'pivot_data'
            market_snapshots: Dict of {timeframe: DataFrame} with OHLCV data
            model_predictions: Accepted for interface compatibility; not stored
                by this method

        Returns:
            True if successful, False on any error (the error is logged)
        """
        try:
            entry = annotation_data.get('entry') or {}
            exit_info = annotation_data.get('exit') or {}
            entry_time = self._parse_iso_timestamp(entry.get('timestamp'))
            exit_time = self._parse_iso_timestamp(exit_info.get('timestamp'))
            if entry_time is None or exit_time is None:
                raise ValueError("annotation is missing an entry or exit timestamp")
            # Convert up front so an unusable timestamp fails before any file I/O
            entry_ms = int(entry_time.timestamp() * 1000)
            exit_ms = int(exit_time.timestamp() * 1000)

            # NOTE(review): annotation_id is used verbatim as a path component;
            # if it can come from untrusted input, validate it (reject '/', '..').
            annotation_parquet_dir = self.parquet_dir / "annotations" / annotation_id
            annotation_parquet_dir.mkdir(parents=True, exist_ok=True)

            written = 0
            for timeframe, df in market_snapshots.items():
                if df is None or df.empty:
                    continue
                df_copy = df.copy()
                if 'timestamp' not in df_copy.columns:
                    df_copy['timestamp'] = df_copy.index
                df_copy['timestamp'] = self._to_unix_ms(df_copy['timestamp'])

                parquet_file = annotation_parquet_dir / f"{timeframe}.parquet"
                df_copy.to_parquet(parquet_file, index=False, compression='snappy')
                written += 1

            self.conn.execute("""
                INSERT OR REPLACE INTO annotations
                (annotation_id, symbol, timeframe, direction,
                 entry_timestamp, entry_price, exit_timestamp, exit_price,
                 profit_loss_pct, notes, created_at, market_context,
                 model_features, pivot_data, parquet_path)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                annotation_id,
                annotation_data.get('symbol'),
                annotation_data.get('timeframe'),
                annotation_data.get('direction'),
                entry_ms,
                entry.get('price'),
                exit_ms,
                exit_info.get('price'),
                annotation_data.get('profit_loss_pct'),
                annotation_data.get('notes', ''),
                int(datetime.now().timestamp() * 1000),
                json.dumps(annotation_data.get('entry_market_state', {})),
                json.dumps(annotation_data.get('model_features', {})),
                json.dumps(annotation_data.get('pivot_data', {})),
                str(annotation_parquet_dir)
            ))

            logger.info(f"Stored annotation {annotation_id} with {written} timeframes")
            return True

        except Exception as e:
            logger.error(f"Error storing annotation: {e}")
            import traceback
            traceback.print_exc()
            return False

    def get_annotation(self, annotation_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve an annotation row together with its Parquet market snapshots

        Args:
            annotation_id: Annotation ID

        Returns:
            Dict of the annotation's columns, with the JSON columns decoded
            (NULL -> {}) and an 'ohlcv_snapshots' entry mapping timeframe to a
            DataFrame indexed by UTC timestamp; None if not found or on error
        """
        try:
            result = self.conn.execute("""
                SELECT * FROM annotations WHERE annotation_id = ?
            """, (annotation_id,)).fetchone()
            if not result:
                return None

            columns = [desc[0] for desc in self.conn.description]
            annotation = dict(zip(columns, result))

            # NULL-safe decode: these columns are nullable in the schema
            for key in ('market_context', 'model_features', 'pivot_data'):
                annotation[key] = self._load_json(annotation.get(key))

            annotation['ohlcv_snapshots'] = {}
            parquet_path = annotation.get('parquet_path')
            # An empty/NULL path must not fall back to Path('') == CWD
            parquet_dir = Path(parquet_path) if parquet_path else None
            if parquet_dir is not None and parquet_dir.exists():
                for parquet_file in sorted(parquet_dir.glob('*.parquet')):
                    timeframe = parquet_file.stem
                    # Quote the path as a SQL literal so a "'" in it can't break the query
                    df = self.conn.execute(f"""
                        SELECT timestamp, open, high, low, close, volume
                        FROM read_parquet({self._sql_string_literal(parquet_file)})
                        ORDER BY timestamp
                    """).df()
                    if not df.empty:
                        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
                        df = df.set_index('timestamp')
                        annotation['ohlcv_snapshots'][timeframe] = df

            logger.info(f"Retrieved annotation {annotation_id} with {len(annotation['ohlcv_snapshots'])} timeframes")
            return annotation

        except Exception as e:
            logger.error(f"Error retrieving annotation: {e}")
            return None

    def query_sql(self, query: str, params: Optional[List] = None) -> pd.DataFrame:
        """
        Execute arbitrary SQL query (including Parquet queries)

        Args:
            query: SQL query
            params: Query parameters (omitted from the call when empty/None)

        Returns:
            DataFrame with results; an empty DataFrame if the query fails
        """
        try:
            if params:
                result = self.conn.execute(query, params)
            else:
                result = self.conn.execute(query)
            return result.df()
        except Exception as e:
            logger.error(f"Error executing query: {e}")
            return pd.DataFrame()

    def get_cache_stats(self) -> Dict[str, Any]:
        """Summarise ``cache_metadata`` and the annotation count.

        Returns:
            Dict with 'ohlcv_stats' (list of per symbol/timeframe records, with
            naive datetime first/last timestamps), 'annotation_count' and
            'total_candles'; {} on error.
        """
        try:
            ohlcv_stats = self.conn.execute("""
                SELECT symbol, timeframe, candle_count, first_timestamp, last_timestamp
                FROM cache_metadata
                ORDER BY symbol, timeframe
            """).df()

            if not ohlcv_stats.empty:
                ohlcv_stats['first_timestamp'] = pd.to_datetime(ohlcv_stats['first_timestamp'], unit='ms')
                ohlcv_stats['last_timestamp'] = pd.to_datetime(ohlcv_stats['last_timestamp'], unit='ms')

            annotation_count = self.conn.execute("""
                SELECT COUNT(*) as count FROM annotations
            """).fetchone()[0]

            total_candles = self.conn.execute("""
                SELECT SUM(candle_count) as total FROM cache_metadata
            """).fetchone()[0] or 0

            return {
                'ohlcv_stats': ohlcv_stats.to_dict('records') if not ohlcv_stats.empty else [],
                'annotation_count': annotation_count,
                'total_candles': total_candles
            }

        except Exception as e:
            logger.error(f"Error getting cache stats: {e}")
            return {}

    def close(self):
        """Flush pending batch logs and close the database connection.

        The connection is closed even if the final log flush raises.
        """
        # Cancel any pending timer; the flush below runs synchronously instead
        with self._batch_timer_lock:
            if self._batch_flush_timer:
                self._batch_flush_timer.cancel()
                self._batch_flush_timer = None

        try:
            self._flush_batch_log()
        finally:
            if self.conn:
                self.conn.close()
                logger.info("DuckDB connection closed")