Merge commit 'd49a473ed6f4aef55bfdd47d6370e53582be6b7b' into cleanup
This commit is contained in:
387
utils/audit_plotter.py
Normal file
387
utils/audit_plotter.py
Normal file
@@ -0,0 +1,387 @@
|
||||
"""
|
||||
Audit Plotter
|
||||
|
||||
Create PNG snapshots of model input data at inference time:
|
||||
- Subplot 1: 1s candlesticks for recent window
|
||||
- Subplot 2: COB bucket volumes and imbalance near current price
|
||||
|
||||
Windows-safe, ASCII-only logging messages.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import math
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple
|
||||
|
||||
try:
|
||||
import matplotlib
|
||||
# Use a non-interactive backend suitable for headless servers
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
except Exception:
|
||||
matplotlib = None # type: ignore
|
||||
plt = None # type: ignore
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_recent_ohlcv(base_data, max_bars: int = 120) -> Tuple[List[datetime], List[float], List[float], List[float], List[float]]:
    """Return the most recent bars as parallel (time, open, high, low, close) lists.

    The 1s series (``base_data.ohlcv_1s``) is preferred; when it is missing,
    empty, or shorter than 5 bars the 1m series (``base_data.ohlcv_1m``) is
    used instead.

    Args:
        base_data: Object that may expose ``ohlcv_1s`` / ``ohlcv_1m`` sequences
            of bars with ``timestamp``, ``open``, ``high``, ``low``, ``close``.
        max_bars: Maximum number of trailing bars to return. Values <= 0
            yield empty lists.

    Returns:
        Tuple of five lists: timestamps, opens, highs, lows, closes.
    """
    series = getattr(base_data, "ohlcv_1s", None) or []
    if len(series) < 5:
        # Too little 1s history to be useful; fall back to 1m bars.
        series = getattr(base_data, "ohlcv_1m", None) or []

    # ``series[-0:]`` would return the whole series, so guard non-positive limits.
    recent = series[-max_bars:] if max_bars > 0 else []

    times = [bar.timestamp for bar in recent]
    opens = [float(bar.open) for bar in recent]
    highs = [float(bar.high) for bar in recent]
    lows = [float(bar.low) for bar in recent]
    closes = [float(bar.close) for bar in recent]
    return times, opens, highs, lows, closes
|
||||
|
||||
|
||||
def _extract_timeframe_ohlcv(base_data, timeframe_attr: str, max_bars: int = 60) -> Tuple[List[datetime], List[float], List[float], List[float], List[float]]:
    """Extract the trailing bars of one timeframe as parallel lists.

    Args:
        base_data: Object holding the bar series as attributes.
        timeframe_attr: Name of the attribute to read (e.g. ``"ohlcv_1m"``).
            A missing, ``None`` or empty attribute yields empty lists.
        max_bars: Maximum number of trailing bars to return. Values <= 0
            yield empty lists.

    Returns:
        Tuple of five lists: timestamps, opens, highs, lows, closes.
    """
    series = getattr(base_data, timeframe_attr, None) or []

    # ``series[-0:]`` would return the whole series, so guard non-positive limits.
    recent = series[-max_bars:] if max_bars > 0 else []

    times = [bar.timestamp for bar in recent]
    opens = [float(bar.open) for bar in recent]
    highs = [float(bar.high) for bar in recent]
    lows = [float(bar.low) for bar in recent]
    closes = [float(bar.close) for bar in recent]
    return times, opens, highs, lows, closes
|
||||
|
||||
|
||||
def _plot_candlesticks(ax, times, opens, highs, lows, closes, title):
    """Plot a candlestick chart on the given axis.

    Bars are placed at integer x positions (0..len(times)-1); the timestamps
    themselves are only used to decide whether there is anything to draw.

    Args:
        ax: Matplotlib-style axes to draw on.
        times, opens, highs, lows, closes: Parallel per-bar sequences.
        title: Panel title; also used in the "No ... data" placeholder.
    """
    if not times:
        ax.text(0.5, 0.5, f"No {title} data", ha="center", va="center", transform=ax.transAxes)
        # Same font size as the populated panels so the grid looks uniform.
        ax.set_title(title, fontsize=10)
        return

    x = list(range(len(times)))

    # High-low wicks
    ax.vlines(x, lows, highs, color="#444444", linewidth=0.8)

    # Candle bodies: green when close >= open, red otherwise.
    min_body = 1e-9  # keeps flat (open == close) candles from having zero height
    bodies = [close - open_ for open_, close in zip(opens, closes)]
    bottoms = [min(open_, close) for open_, close in zip(opens, closes)]
    colors = ["#00aa55" if body >= 0 else "#cc3333" for body in bodies]
    heights = [max(min_body, abs(body)) for body in bodies]

    ax.bar(x, heights, bottom=bottoms, color=colors, width=0.6, align="center",
           edgecolor="#222222", linewidth=0.3)

    ax.set_title(title, fontsize=10)
    ax.grid(True, linestyle=":", linewidth=0.4, alpha=0.6)

    # Annotate the most recent close in the top-left corner
    if closes:
        ax.text(0.02, 0.98, f"${closes[-1]:.2f}", transform=ax.transAxes,
                verticalalignment='top', fontsize=8, fontweight='bold')
|
||||
|
||||
|
||||
def _plot_data_summary(ax, base_data, symbol):
    """Render a text panel summarising how much input data is available.

    Lists bar counts for the ETH 1s/1m/1h/1d and BTC 1s series, the COB
    state (bucket count, raw level counts, "No data" or "Missing") and the
    number of technical indicators. Attributes that are absent or ``None``
    are reported as zero-length instead of raising.

    Args:
        ax: Matplotlib-style axes to draw on (its frame is switched off).
        base_data: Object holding the model inputs as attributes.
        symbol: Accepted for signature parity with the other panel helpers;
            not used by this panel.
    """
    ax.axis('off')

    stats = []

    # ETH timeframes
    for tf, attr in (("1s", "ohlcv_1s"), ("1m", "ohlcv_1m"), ("1h", "ohlcv_1h"), ("1d", "ohlcv_1d")):
        bars = getattr(base_data, attr, None) or []
        stats.append(f"ETH {tf}: {len(bars)} bars")

    # BTC data
    btc_bars = getattr(base_data, "btc_ohlcv_1s", None) or []
    stats.append(f"BTC 1s: {len(btc_bars)} bars")

    # COB data
    cob = getattr(base_data, "cob_data", None)
    if not cob:
        stats.append("COB: Missing")
    elif getattr(cob, "price_buckets", None):
        stats.append(f"COB buckets: {len(cob.price_buckets)}")
    elif hasattr(cob, "bids") and hasattr(cob, "asks"):
        bids = getattr(cob, "bids", None) or []
        asks = getattr(cob, "asks", None) or []
        stats.append(f"COB levels: {len(bids)}b/{len(asks)}a")
    else:
        stats.append("COB: No data")

    # Technical indicators
    tech_indicators = getattr(base_data, "technical_indicators", None) or {}
    stats.append(f"Tech indicators: {len(tech_indicators)}")

    # Heading followed by one line per statistic, top to bottom in axes coordinates
    y_pos = 0.9
    ax.text(0.05, y_pos, "Data Summary:", fontweight='bold', transform=ax.transAxes)
    y_pos -= 0.12

    for stat in stats:
        ax.text(0.05, y_pos, stat, fontsize=9, transform=ax.transAxes)
        y_pos -= 0.1

    ax.set_title("Input Data Stats", fontsize=10)
|
||||
|
||||
|
||||
def _plot_cob_data(ax, prices, bid_v, ask_v, imb, current_price, symbol):
    """Plot COB bucket volumes (left axis) and imbalance (right axis).

    Args:
        ax: Matplotlib-style axes to draw on; a twin y-axis is created for
            the imbalance line.
        prices: Bucket prices (x values).
        bid_v, ask_v: Per-bucket bid and ask volumes.
        imb: Per-bucket imbalance values.
        current_price: When > 0, x values are shown as offsets from this
            price and a vertical marker is drawn at 0; otherwise raw prices
            are used.
        symbol: Symbol name shown in the title / placeholder text.
    """
    if not prices:
        # Axes coordinates keep the placeholder centred whatever the data limits are.
        ax.text(0.5, 0.5, f"No COB data for {symbol}", ha="center", va="center",
                transform=ax.transAxes)
        ax.set_title("COB Data - No Data Available")
        return

    have_price = current_price > 0

    # Normalize x as offsets around current price if available
    if have_price:
        xvals = [p - current_price for p in prices]
        ax.axvline(0.0, color="#666666", linestyle="--", linewidth=1.0, alpha=0.7)
        ax.set_xlabel("Price offset from current")
    else:
        xvals = prices
        ax.set_xlabel("Price")

    # Bid/ask volumes on the primary axis
    ax.plot(xvals, bid_v, label="Bid Volume", color="#2c7fb8", linewidth=1.5)
    ax.plot(xvals, ask_v, label="Ask Volume", color="#d95f0e", linewidth=1.5)

    # Imbalance on a secondary y-axis
    imb_ax = ax.twinx()
    imb_ax.plot(xvals, imb, label="Imbalance", color="#6a3d9a", linewidth=2, alpha=0.8)
    imb_ax.set_ylabel("Imbalance", color="#6a3d9a")
    imb_ax.tick_params(axis='y', labelcolor="#6a3d9a")

    ax.set_ylabel("Volume")
    ax.grid(True, linestyle=":", linewidth=0.6, alpha=0.6)

    # One legend covering the lines of both axes
    handles, labels = ax.get_legend_handles_labels()
    imb_handles, imb_labels = imb_ax.get_legend_handles_labels()
    ax.legend(handles + imb_handles, labels + imb_labels, loc="upper right")

    # Title with current price info
    price_info = f" (${current_price:.2f})" if have_price else ""
    ax.set_title(f"COB Price Buckets - {symbol}{price_info}", fontsize=11)
|
||||
|
||||
|
||||
def _parse_book_level(level):
    """Return ``(price, size)`` for one raw order-book level, or ``None`` if its shape is unknown.

    Accepts mappings with ``price`` / ``size`` keys and lists or tuples laid
    out as ``[price, size, ...]``.
    """
    if isinstance(level, dict):
        return float(level.get("price", 0)), float(level.get("size", 0))
    if isinstance(level, (list, tuple)) and len(level) >= 2:
        return float(level[0]), float(level[1])
    return None


def _add_levels_to_buckets(buckets, levels, volume_key, bucket_size, max_levels=50):
    """Add ``size * price`` of the first ``max_levels`` levels to ``buckets[bucket][volume_key]``."""
    for level in levels[:max_levels]:
        parsed = _parse_book_level(level)
        if parsed is None:
            continue
        price, size = parsed
        if not (price > 0 and size > 0):
            continue
        bucket_price = round(price / bucket_size) * bucket_size
        bucket = buckets.setdefault(bucket_price, {"bid_volume": 0.0, "ask_volume": 0.0})
        bucket[volume_key] += size * price


def _buckets_to_series(buckets, max_buckets):
    """Turn a ``{price: {"bid_volume", "ask_volume"}}`` mapping into sorted parallel lists."""
    prices = sorted(buckets.keys())[:max_buckets]
    bid_vol = []
    ask_vol = []
    imb = []
    for price in prices:
        bucket = buckets.get(price, {})
        bid = float(bucket.get("bid_volume", 0.0))
        ask = float(bucket.get("ask_volume", 0.0))
        total = bid + ask
        bid_vol.append(bid)
        ask_vol.append(ask)
        # Imbalance in [-1, 1]; an empty bucket is reported as 0.
        imb.append((bid - ask) / (total if total > 0 else 1.0))
    return prices, bid_vol, ask_vol, imb


def _extract_cob(base_data, max_buckets: int = 40):
    """Return sorted price buckets and metrics from ``base_data.cob_data``.

    Pre-aggregated ``price_buckets`` are used when present. Otherwise, if raw
    ``bids`` and ``asks`` are both non-empty and ``current_price`` is positive,
    the first 50 levels of each side are aggregated into buckets of
    ``cob.bucket_size`` (1.0 when that attribute is missing, falsy or not a
    positive number), weighting each level by ``size * price``.

    Args:
        base_data: Object that may expose a ``cob_data`` attribute.
        max_buckets: Number of lowest-priced buckets to keep after sorting.

    Returns:
        Tuple ``(prices, bid_volumes, ask_volumes, imbalances)`` of parallel
        lists; four empty lists when no usable COB data is available.
    """
    cob = getattr(base_data, "cob_data", None)
    if cob is None:
        return [], [], [], []

    # Preferred source: buckets already computed upstream
    price_buckets = getattr(cob, "price_buckets", None)
    if price_buckets:
        return _buckets_to_series(price_buckets, max_buckets)

    # Fallback: aggregate raw bids/asks
    bids = getattr(cob, "bids", None) or []
    asks = getattr(cob, "asks", None) or []
    current_price = getattr(cob, "current_price", 0.0) or 0.0
    if not (bids and asks and current_price > 0):
        return [], [], [], []

    # Honour the configured bucket size; fall back to 1.0 for unusable values.
    try:
        bucket_size = float(getattr(cob, "bucket_size", None) or 1.0)
    except (TypeError, ValueError):
        bucket_size = 1.0
    if not bucket_size > 0:
        bucket_size = 1.0

    buckets = {}
    _add_levels_to_buckets(buckets, bids, "bid_volume", bucket_size)
    _add_levels_to_buckets(buckets, asks, "ask_volume", bucket_size)

    if not buckets:
        return [], [], [], []
    return _buckets_to_series(buckets, max_buckets)
|
||||
|
||||
|
||||
def _log_audit_inputs(base_data, model_name: str, symbol: str) -> None:
    """Log how much data of each kind is available for the audit image."""
    logger.info(f"Creating audit image for {model_name} - {symbol}")

    series_attrs = (
        ("ETH 1s", "ohlcv_1s"),
        ("ETH 1m", "ohlcv_1m"),
        ("ETH 1h", "ohlcv_1h"),
        ("ETH 1d", "ohlcv_1d"),
        ("BTC 1s", "btc_ohlcv_1s"),
    )
    for label, attr in series_attrs:
        if hasattr(base_data, attr):
            # ``or []`` keeps a None-valued attribute from aborting the whole image.
            logger.info(f"{label} data: {len(getattr(base_data, attr) or [])} bars")

    cob = getattr(base_data, "cob_data", None)
    if not cob:
        logger.warning("No COB data available for audit image")
        return

    logger.info(f"COB data available: current_price={getattr(cob, 'current_price', 'N/A')}")
    if getattr(cob, "price_buckets", None):
        logger.info(f"COB price buckets: {len(cob.price_buckets)} buckets")
    elif hasattr(cob, 'bids') and hasattr(cob, 'asks'):
        bid_count = len(getattr(cob, 'bids', None) or [])
        ask_count = len(getattr(cob, 'asks', None) or [])
        logger.info(f"COB raw data: {bid_count} bids, {ask_count} asks")
    else:
        logger.info("COB data exists but no price_buckets or bids/asks found")


def save_inference_audit_image(base_data, model_name: str, symbol: str, out_root: str = "audit_inputs") -> str:
    """Save a comprehensive PNG snapshot of input data with all timeframes and COB data.

    The figure is a 3x3 grid: five candlestick panels (ETH 1s/1m/1h/1d,
    BTC 1s), a data-summary panel, and a COB panel spanning the bottom row
    with an optional heatmap inset. The file is written to
    ``{out_root}/{YYYYMMDD}/{HHMMSS_micro}_{symbol}_{model}.png`` (UTC).

    Args:
        base_data: Object holding the model inputs as attributes.
        model_name: Model identifier used in the title and file name.
        symbol: Trading symbol used in the title and file name.
        out_root: Root directory for audit images.

    Returns:
        Path of the written image, or ``""`` when matplotlib is unavailable
        or anything fails (failures are logged, never raised).
    """
    if matplotlib is None or plt is None:
        logger.warning("matplotlib not available; skipping audit image")
        return ""

    # Local import: the module-level import only brings in ``datetime``.
    from datetime import timezone

    fig = None
    try:
        _log_audit_inputs(base_data, model_name, symbol)

        # One timestamp for directory, file name and title so they cannot
        # disagree when the call straddles a second or day boundary.
        now = datetime.now(timezone.utc)
        out_dir = os.path.join(out_root, now.strftime("%Y%m%d"))
        os.makedirs(out_dir, exist_ok=True)

        # File name: {ts}_{symbol}_{model}.png with path-hostile characters replaced
        unsafe_chars = str.maketrans({"/": "-", "\\": "-", ":": "-"})
        safe_symbol = symbol.translate(unsafe_chars)
        safe_model = model_name.translate(unsafe_chars)
        fname = f"{now.strftime('%H%M%S_%f')}_{safe_symbol}_{safe_model}.png"
        out_path = os.path.join(out_dir, fname)

        # COB series and reference price (None-valued price is treated as unknown)
        prices, bid_v, ask_v, imb = _extract_cob(base_data)
        cob = getattr(base_data, "cob_data", None)
        current_price = float(getattr(cob, "current_price", 0.0) or 0.0)

        fig = plt.figure(figsize=(16, 12), dpi=110)
        gs = fig.add_gridspec(3, 3, height_ratios=[2, 2, 1.5], width_ratios=[1, 1, 1])

        # (grid cell, attribute, bar limit, title) for each candlestick panel
        candle_panels = (
            ((0, 0), "ohlcv_1s", 60, "ETH 1s (last 60)"),
            ((0, 1), "ohlcv_1m", 60, "ETH 1m (last 60)"),
            ((0, 2), "ohlcv_1h", 24, "ETH 1h (last 24)"),
            ((1, 0), "ohlcv_1d", 30, "ETH 1d (last 30)"),
            ((1, 1), "btc_ohlcv_1s", 60, "BTC 1s (last 60)"),
        )
        for (row, col), attr, max_bars, title in candle_panels:
            panel_ax = fig.add_subplot(gs[row, col])
            times, opens, highs, lows, closes = _extract_timeframe_ohlcv(base_data, attr, max_bars)
            _plot_candlesticks(panel_ax, times, opens, highs, lows, closes, title)

        # Data summary (middle right)
        _plot_data_summary(fig.add_subplot(gs[1, 2]), base_data, symbol)

        # COB data (bottom, spanning all columns)
        cob_ax = fig.add_subplot(gs[2, :])
        _plot_cob_data(cob_ax, prices, bid_v, ask_v, imb, current_price, symbol)

        # Optional heatmap inset on the COB panel; best-effort, never fatal
        try:
            heat_times = getattr(base_data, 'cob_heatmap_times', [])
            heat_prices = getattr(base_data, 'cob_heatmap_prices', [])
            heat_vals = getattr(base_data, 'cob_heatmap_values', [])
            if heat_times and heat_prices and heat_vals:
                import numpy as np
                from mpl_toolkits.axes_grid1.inset_locator import inset_axes

                inset_ax = inset_axes(cob_ax, width="25%", height="100%", loc='upper right', borderpad=1)
                z = np.array(heat_vals, dtype=float)
                if z.size > 0:
                    # Scale every column by its own maximum before display
                    col_max = np.maximum(z.max(axis=0), 1e-9)
                    zn = (z / col_max).T
                    inset_ax.imshow(zn, aspect='auto', origin='lower', cmap='turbo')
                    inset_ax.set_title('COB Heatmap', fontsize=8)
                    inset_ax.set_xticks([])
                    inset_ax.set_yticks([])
        except Exception as _hm_ex:
            logger.debug(f"Audit heatmap overlay skipped: {_hm_ex}")

        # Overall title with model and timestamp info
        fig.suptitle(f"{model_name} - {safe_symbol} - {now.strftime('%H:%M:%S')}",
                     fontsize=14, fontweight='bold')

        fig.tight_layout()
        fig.savefig(out_path, bbox_inches="tight")

        logger.info(f"Saved comprehensive audit image: {out_path}")
        return out_path
    except Exception as ex:
        logger.error(f"Failed to save audit image: {ex}")
        return ""
    finally:
        # Close only the figure created here; other pyplot figures are left alone.
        if fig is not None:
            try:
                plt.close(fig)
            except Exception as close_ex:
                logger.debug(f"Could not close audit figure: {close_ex}")
|
||||
|
||||
|
||||
295
utils/cache_manager.py
Normal file
295
utils/cache_manager.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Cache Manager for Trading System
|
||||
|
||||
Utilities for managing and cleaning up cache files, including:
|
||||
- Parquet file validation and repair
|
||||
- Cache cleanup and maintenance
|
||||
- Cache health monitoring
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CacheManager:
    """Manages cache files for the trading system"""

    # Exact messages returned by validate_parquet_file; other methods classify on them.
    _MSG_EMPTY = "File is empty"
    # Prefix used when no Parquet engine can be imported, i.e. the file could not
    # be checked at all. Such files must never be treated as corrupted.
    _ENGINE_UNAVAILABLE = "Parquet engine unavailable"

    def __init__(self, cache_dirs: Optional[List[str]] = None):
        """
        Initialize cache manager

        Args:
            cache_dirs: List of cache directories to manage. Defaults to
                data/cache, data/monthly_cache and data/pivot_cache.
                Missing directories are created.
        """
        self.cache_dirs = cache_dirs or [
            "data/cache",
            "data/monthly_cache",
            "data/pivot_cache",
        ]

        # Ensure cache directories exist
        for cache_dir in self.cache_dirs:
            Path(cache_dir).mkdir(parents=True, exist_ok=True)

    def validate_parquet_file(self, file_path: Path) -> Tuple[bool, Optional[str]]:
        """
        Validate a Parquet file

        A file is valid when it exists, is non-empty, can be read by pandas,
        has at least one row and contains the OHLCV columns.

        Args:
            file_path: Path to the Parquet file (a plain string is accepted too)

        Returns:
            Tuple of (is_valid, error_message); error_message is None when valid
        """
        try:
            file_path = Path(file_path)

            if not file_path.exists():
                return False, "File does not exist"

            if file_path.stat().st_size == 0:
                return False, self._MSG_EMPTY

            # Try to read the file
            df = pd.read_parquet(file_path)

            if df.empty:
                return False, "File contains no data"

            # Check for required columns (basic validation)
            required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
            missing_columns = [col for col in required_columns if col not in df.columns]
            if missing_columns:
                return False, f"Missing required columns: {missing_columns}"

            return True, None

        except ImportError as e:
            # No usable Parquet engine: this says nothing about the file itself.
            return False, f"{self._ENGINE_UNAVAILABLE}: {e}"

        except Exception as e:
            error_str = str(e).lower()
            corrupted_indicators = (
                "parquet magic bytes not found",
                "corrupted",
                "couldn't deserialize thrift",
                "don't know what type",
                "invalid parquet file",
                "unexpected end of file",
                "invalid metadata",
            )

            if any(indicator in error_str for indicator in corrupted_indicators):
                return False, f"Corrupted Parquet file: {e}"
            return False, f"Validation error: {e}"

    def scan_cache_health(self) -> Dict[str, Dict]:
        """
        Scan all cache directories for file health

        Only ``*.parquet`` files directly inside each directory are examined.
        Files that disappear while the scan is running are ignored.

        Returns:
            Dictionary keyed by cache directory with counts (total, valid,
            corrupted, empty, unverified), total size in MB, the list of
            corrupted files and the list of files older than 7 days
        """
        health_report = {}

        for cache_dir in self.cache_dirs:
            cache_path = Path(cache_dir)
            if not cache_path.exists():
                continue

            dir_report = {
                'total_files': 0,
                'valid_files': 0,
                'corrupted_files': 0,
                'empty_files': 0,
                'unverified_files': 0,
                'total_size_mb': 0.0,
                'corrupted_files_list': [],
                'old_files': []
            }

            for file_path in cache_path.glob("*.parquet"):
                try:
                    file_stat = file_path.stat()
                except OSError:
                    # Removed (or became unreadable) between listing and stat.
                    continue

                dir_report['total_files'] += 1
                file_size_mb = file_stat.st_size / (1024 * 1024)
                dir_report['total_size_mb'] += file_size_mb

                # Check file age
                file_age = datetime.now() - datetime.fromtimestamp(file_stat.st_mtime)
                if file_age > timedelta(days=7):  # Files older than 7 days
                    dir_report['old_files'].append({
                        'file': str(file_path),
                        'age_days': file_age.days,
                        'size_mb': file_size_mb
                    })

                # Validate file
                is_valid, error_msg = self.validate_parquet_file(file_path)

                if is_valid:
                    dir_report['valid_files'] += 1
                elif error_msg == self._MSG_EMPTY:
                    dir_report['empty_files'] += 1
                elif error_msg.startswith(self._ENGINE_UNAVAILABLE):
                    dir_report['unverified_files'] += 1
                else:
                    dir_report['corrupted_files'] += 1
                    dir_report['corrupted_files_list'].append({
                        'file': str(file_path),
                        'error': error_msg,
                        'size_mb': file_size_mb
                    })

            health_report[cache_dir] = dir_report

        return health_report

    @staticmethod
    def _delete_or_report(file_path: Path, kind: str, reason: str, dry_run: bool) -> str:
        """Delete ``file_path`` (or only report it when ``dry_run``) and return the report entry."""
        if dry_run:
            logger.info(f"Would delete {kind} file: {file_path} ({reason})")
            return f"WOULD DELETE: {file_path} ({reason})"

        try:
            file_path.unlink()
        except OSError as e:
            logger.error(f"Failed to delete {kind} file {file_path}: {e}")
            return f"FAILED TO DELETE: {file_path} ({e})"

        logger.info(f"Deleted {kind} file: {file_path}")
        return f"DELETED: {file_path} ({reason})"

    def cleanup_corrupted_files(self, dry_run: bool = True) -> Dict[str, List[str]]:
        """
        Clean up cache files that fail validation

        Files that could not be checked because no Parquet engine is
        installed are reported as ``SKIPPED`` and never deleted.

        Args:
            dry_run: If True, only report what would be deleted

        Returns:
            Dictionary of report entries (WOULD DELETE / DELETED /
            FAILED TO DELETE / SKIPPED) by directory
        """
        deleted_files = {}

        for cache_dir in self.cache_dirs:
            cache_path = Path(cache_dir)
            if not cache_path.exists():
                continue

            entries = deleted_files[cache_dir] = []

            for file_path in cache_path.glob("*.parquet"):
                is_valid, error_msg = self.validate_parquet_file(file_path)
                if is_valid:
                    continue

                if error_msg.startswith(self._ENGINE_UNAVAILABLE):
                    logger.warning(f"Skipping unverifiable file: {file_path} ({error_msg})")
                    entries.append(f"SKIPPED: {file_path} ({error_msg})")
                    continue

                entries.append(self._delete_or_report(file_path, "corrupted", error_msg, dry_run))

        return deleted_files

    def cleanup_old_files(self, days_to_keep: int = 7, dry_run: bool = True) -> Dict[str, List[str]]:
        """
        Clean up old cache files

        Age is judged by file modification time.

        Args:
            days_to_keep: Number of days to keep files
            dry_run: If True, only report what would be deleted

        Returns:
            Dictionary of report entries by directory
        """
        deleted_files = {}
        now = datetime.now()
        cutoff_date = now - timedelta(days=days_to_keep)

        for cache_dir in self.cache_dirs:
            cache_path = Path(cache_dir)
            if not cache_path.exists():
                continue

            entries = deleted_files[cache_dir] = []

            for file_path in cache_path.glob("*.parquet"):
                try:
                    file_mtime = datetime.fromtimestamp(file_path.stat().st_mtime)
                except OSError:
                    # Removed between listing and stat; nothing left to clean up.
                    continue

                if file_mtime >= cutoff_date:
                    continue

                age_days = (now - file_mtime).days
                entries.append(
                    self._delete_or_report(file_path, "old", f"age: {age_days} days", dry_run)
                )

        return deleted_files

    def get_cache_summary(self) -> Dict[str, object]:
        """Get a summary of cache usage aggregated over all cache directories"""
        health_report = self.scan_cache_health()
        reports = list(health_report.values())

        total_files = sum(report['total_files'] for report in reports)
        total_valid = sum(report['valid_files'] for report in reports)
        total_corrupted = sum(report['corrupted_files'] for report in reports)
        total_size_mb = sum(report['total_size_mb'] for report in reports)

        return {
            'total_files': total_files,
            'valid_files': total_valid,
            'corrupted_files': total_corrupted,
            'health_percentage': (total_valid / total_files * 100) if total_files > 0 else 0,
            'total_size_mb': total_size_mb,
            'directories': health_report
        }

    def emergency_cache_reset(self, confirm: bool = False) -> bool:
        """
        Emergency cache reset - deletes all cache files

        Removes every regular file directly inside each cache directory
        (any extension); sub-directories are left untouched.

        Args:
            confirm: Must be True to actually delete files

        Returns:
            True if reset was performed
        """
        if not confirm:
            logger.warning("Emergency cache reset called but not confirmed")
            return False

        deleted_count = 0

        for cache_dir in self.cache_dirs:
            cache_path = Path(cache_dir)
            if not cache_path.exists():
                continue

            for file_path in cache_path.glob("*"):
                try:
                    if file_path.is_file():
                        file_path.unlink()
                        deleted_count += 1
                except Exception as e:
                    logger.error(f"Failed to delete {file_path}: {e}")

        logger.warning(f"Emergency cache reset completed: deleted {deleted_count} files")
        return True
|
||||
|
||||
# Process-wide CacheManager singleton, created lazily by get_cache_manager()
_cache_manager_instance = None


def get_cache_manager() -> CacheManager:
    """Return the shared CacheManager, building it with default settings on first use"""
    global _cache_manager_instance

    if _cache_manager_instance is not None:
        return _cache_manager_instance

    _cache_manager_instance = CacheManager()
    return _cache_manager_instance
|
||||
|
||||
def cleanup_corrupted_cache(dry_run: bool = True) -> Dict[str, List[str]]:
    """Shortcut: run cleanup_corrupted_files on the global cache manager"""
    return get_cache_manager().cleanup_corrupted_files(dry_run=dry_run)
|
||||
|
||||
def get_cache_health() -> Dict[str, object]:
    """Convenience function to get cache health summary

    Returns:
        The dictionary produced by ``get_cache_summary()`` on the global
        cache manager.
    """
    # The annotation uses ``object``: the previous ``any`` is the builtin
    # function, not a type.
    cache_manager = get_cache_manager()
    return cache_manager.get_cache_summary()
|
||||
547
utils/checkpoint_manager.py
Normal file
547
utils/checkpoint_manager.py
Normal file
@@ -0,0 +1,547 @@
|
||||
"""
|
||||
Checkpoint Manager
|
||||
|
||||
This module provides functionality for managing model checkpoints, including:
|
||||
- Saving checkpoints with metadata
|
||||
- Loading the best checkpoint based on performance metrics
|
||||
- Cleaning up old or underperforming checkpoints
|
||||
- Database-backed metadata storage for efficient access
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import glob
|
||||
import logging
|
||||
import shutil
|
||||
import torch
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
|
||||
from .database_manager import get_database_manager, CheckpointMetadata
|
||||
from .text_logger import get_text_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Process-wide CheckpointManager singleton, created lazily by get_checkpoint_manager()
_checkpoint_manager_instance = None


def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_checkpoints: int = 10, metric_name: str = "accuracy") -> 'CheckpointManager':
    """
    Return the shared checkpoint manager, creating it on first use

    The arguments are only consulted when no instance exists yet; later
    calls return the already-built manager whatever is passed.

    Args:
        checkpoint_dir: Directory to store checkpoints
        max_checkpoints: Maximum number of checkpoints to keep
        metric_name: Metric to use for ranking checkpoints

    Returns:
        CheckpointManager: Global checkpoint manager instance
    """
    global _checkpoint_manager_instance

    if _checkpoint_manager_instance is not None:
        return _checkpoint_manager_instance

    _checkpoint_manager_instance = CheckpointManager(
        checkpoint_dir=checkpoint_dir,
        max_checkpoints=max_checkpoints,
        metric_name=metric_name,
    )
    return _checkpoint_manager_instance
|
||||
|
||||
def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Optional[Dict[str, Any]] = None, checkpoint_dir: str = "models/checkpoints") -> Any:
    """
    Save a checkpoint with metadata to both filesystem and database

    Files are written under ``{checkpoint_dir}/{model_name}/`` using the id
    ``{model_name}_{YYYYMMDD_HHMMSS}``: the model itself, plus a legacy
    ``*_metadata.json`` file. Old checkpoints are then cleaned up through the
    global checkpoint manager; a failure in that housekeeping step is logged
    but does not turn a successful save into a failure.

    Args:
        model: The model to save. Its own ``save(path)`` method is used when
            present; otherwise its ``state_dict()`` is saved with torch.
        model_name: Name of the model
        model_type: Type of the model ('cnn', 'rl', etc.)
        performance_metrics: Performance metrics
        training_metadata: Additional training metadata
        checkpoint_dir: Directory to store checkpoints

    Returns:
        Any: Checkpoint metadata object, or None if saving failed
    """
    try:
        training_metadata = training_metadata or {}

        # Create checkpoint directory and the per-model sub-directory
        os.makedirs(checkpoint_dir, exist_ok=True)
        model_dir = os.path.join(checkpoint_dir, model_name)
        os.makedirs(model_dir, exist_ok=True)

        # Second-resolution timestamp doubles as the checkpoint id suffix
        timestamp = datetime.now()
        timestamp_str = timestamp.strftime("%Y%m%d_%H%M%S")
        checkpoint_id = f"{model_name}_{timestamp_str}"
        checkpoint_path = os.path.join(model_dir, checkpoint_id)
        torch_path = f"{checkpoint_path}.pt"

        # Save model
        if hasattr(model, 'save'):
            # Use model's save method if available.
            # NOTE(review): the path is passed without the ".pt" suffix, while the
            # size and database record below refer to torch_path - confirm that
            # model.save() writes to "<path>.pt".
            model.save(checkpoint_path)
        else:
            # Otherwise, save state_dict
            torch.save({
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
                'model_name': model_name,
                'model_type': model_type,
                'timestamp': timestamp_str,
                'checkpoint_id': checkpoint_id
            }, torch_path)

        # Calculate file size (0.0 when nothing exists at torch_path)
        file_size_mb = os.path.getsize(torch_path) / (1024 * 1024) if os.path.exists(torch_path) else 0.0

        # Save metadata to database
        db_manager = get_database_manager()
        checkpoint_metadata = CheckpointMetadata(
            checkpoint_id=checkpoint_id,
            model_name=model_name,
            model_type=model_type,
            timestamp=timestamp,
            performance_metrics=performance_metrics,
            training_metadata=training_metadata,
            file_path=torch_path,
            file_size_mb=file_size_mb,
            is_active=True  # New checkpoint is active by default
        )

        if db_manager.save_checkpoint_metadata(checkpoint_metadata):
            # Log checkpoint save event to text file
            text_logger = get_text_logger()
            text_logger.log_checkpoint_event(
                model_name=model_name,
                event_type="SAVED",
                checkpoint_id=checkpoint_id,
                details=f"loss={performance_metrics.get('loss', 'N/A')}, size={file_size_mb:.1f}MB"
            )
            logger.info(f"Checkpoint saved: {checkpoint_id}")
        else:
            logger.warning(f"Failed to save checkpoint metadata to database: {checkpoint_id}")

        # Also save legacy JSON metadata for backward compatibility
        legacy_metadata = {
            'model_name': model_name,
            'model_type': model_type,
            'timestamp': timestamp_str,
            'performance_metrics': performance_metrics,
            'training_metadata': training_metadata,
            'checkpoint_id': checkpoint_id,
            'performance_score': performance_metrics.get('accuracy', performance_metrics.get('reward', 0.0)),
            'created_at': timestamp_str
        }

        with open(f"{checkpoint_path}_metadata.json", 'w') as f:
            json.dump(legacy_metadata, f, indent=2)

        # Housekeeping only: the checkpoint is already on disk, so a cleanup
        # failure must not be reported to the caller as a failed save.
        try:
            checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
            checkpoint_manager._cleanup_checkpoints(model_name)
        except Exception as cleanup_error:
            logger.warning(f"Checkpoint cleanup failed for {model_name}: {cleanup_error}")

        # Return metadata as an object for backward compatibility
        class CheckpointMetadataObj:
            def __init__(self, metadata):
                for key, value in metadata.items():
                    setattr(self, key, value)
                # Add database fields
                self.checkpoint_id = checkpoint_id
                self.loss = performance_metrics.get('loss', performance_metrics.get('accuracy', 0.0))

        return CheckpointMetadataObj(legacy_metadata)

    except Exception as e:
        # exc_info keeps the traceback, which the message alone does not carry
        logger.error(f"Error saving checkpoint: {e}", exc_info=True)
        return None
|
||||
|
||||
def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
    """
    Load the best checkpoint based on performance metrics using database metadata

    The database is asked for the checkpoint with the best "accuracy". When it
    has no entry for the model, the legacy JSON metadata files handled by the
    checkpoint manager are used instead.

    Args:
        model_name: Name of the model
        checkpoint_dir: Directory holding the legacy file-based checkpoints

    Returns:
        Optional[Tuple[str, Any]]: Path to the best checkpoint and a metadata
        object with backward-compatible attributes, or None if no checkpoint
        was found, its file is missing, or an error occurred
    """
    try:
        # First try to get from database (fast metadata access)
        db_manager = get_database_manager()
        checkpoint_metadata = db_manager.get_best_checkpoint_metadata(model_name, "accuracy")

        if not checkpoint_metadata:
            # Silent fallback to the legacy file-based approach
            checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
            checkpoint_path, legacy_metadata = checkpoint_manager.load_best_checkpoint(model_name)

            if not checkpoint_path:
                return None

            # Convert legacy metadata to object
            class CheckpointMetadataObj:
                def __init__(self, metadata):
                    for key, value in metadata.items():
                        setattr(self, key, value)

                    # Resolve the metrics up front: they are needed for `loss`
                    # even when `performance_score` is already present, and
                    # legacy files store them under either 'metrics' or
                    # 'performance_metrics'.
                    metrics = (
                        getattr(self, 'metrics', None)
                        or getattr(self, 'performance_metrics', None)
                        or {}
                    )

                    # Add performance score if not present
                    if not hasattr(self, 'performance_score'):
                        primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
                        self.performance_score = metrics.get(primary_metric, 0.0)

                    # Add created_at if not present
                    if not hasattr(self, 'created_at'):
                        self.created_at = getattr(self, 'timestamp', 'unknown')

                    # Add loss for compatibility
                    self.loss = metrics.get('loss', self.performance_score)
                    self.checkpoint_id = getattr(self, 'checkpoint_id', f"{model_name}_unknown")

            return f"{checkpoint_path}.pt", CheckpointMetadataObj(legacy_metadata)

        # Check if checkpoint file exists
        if not os.path.exists(checkpoint_metadata.file_path):
            logger.warning(f"Checkpoint file not found: {checkpoint_metadata.file_path}")
            return None

        # Log checkpoint load event to text file
        text_logger = get_text_logger()
        text_logger.log_checkpoint_event(
            model_name=model_name,
            event_type="LOADED",
            checkpoint_id=checkpoint_metadata.checkpoint_id,
            details=f"loss={checkpoint_metadata.performance_metrics.get('loss', 'N/A')}"
        )

        # Convert database metadata to object for backward compatibility
        class CheckpointMetadataObj:
            def __init__(self, db_metadata: CheckpointMetadata):
                self.checkpoint_id = db_metadata.checkpoint_id
                self.model_name = db_metadata.model_name
                self.model_type = db_metadata.model_type
                self.timestamp = db_metadata.timestamp.strftime("%Y%m%d_%H%M%S")
                self.performance_metrics = db_metadata.performance_metrics
                self.training_metadata = db_metadata.training_metadata
                self.file_path = db_metadata.file_path
                self.file_size_mb = db_metadata.file_size_mb
                self.is_active = db_metadata.is_active

                # Backward compatibility fields
                self.metrics = db_metadata.performance_metrics
                self.metadata = db_metadata.training_metadata
                self.created_at = self.timestamp
                self.performance_score = db_metadata.performance_metrics.get(
                    'accuracy', db_metadata.performance_metrics.get('reward', 0.0))
                self.loss = db_metadata.performance_metrics.get('loss', self.performance_score)

        return checkpoint_metadata.file_path, CheckpointMetadataObj(checkpoint_metadata)

    except Exception as e:
        logger.error(f"Error loading best checkpoint: {e}")
        return None
|
||||
|
||||
class CheckpointManager:
    """
    Manages model checkpoints with performance-based optimization

    This class:
    1. Saves checkpoints with metadata
    2. Loads the best checkpoint based on performance metrics
    3. Cleans up old or underperforming checkpoints

    On disk every checkpoint is a pair of files inside
    ``<checkpoint_dir>/<model_name>/``: ``<base>.pt`` holds the model and
    ``<base>_metadata.json`` holds its metadata.
    """

    # Suffix that turns a checkpoint base path into its metadata file path
    _METADATA_SUFFIX = "_metadata.json"

    def __init__(self, checkpoint_dir: str, max_checkpoints: int = 10, metric_name: str = "accuracy"):
        """
        Initialize the checkpoint manager

        Args:
            checkpoint_dir: Directory to store checkpoints
            max_checkpoints: Maximum number of checkpoints to keep
            metric_name: Metric to use for ranking checkpoints
        """
        self.checkpoint_dir = checkpoint_dir
        self.max_checkpoints = max_checkpoints
        self.metric_name = metric_name

        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        logger.info(f"CheckpointManager initialized with checkpoint_dir: {checkpoint_dir}")

    def _metric_value(self, metadata: Dict[str, Any]) -> float:
        """Return the ranking metric stored in a metadata dict, or 0.0 if absent.

        Metrics are looked up under 'metrics' first and 'performance_metrics'
        second, because metadata files in the same directory may use either
        key. Non-numeric values count as 0.0 so sorting never fails on them.
        """
        metrics = metadata.get('metrics') or metadata.get('performance_metrics') or {}
        value = metrics.get(self.metric_name, 0.0)
        return value if isinstance(value, (int, float)) else 0.0

    def _rank_key(self, metadata: Dict[str, Any]) -> Tuple[float, str]:
        """Sort key: ranking metric first, timestamp second so ties favour newer checkpoints."""
        return self._metric_value(metadata), str(metadata.get('timestamp', ''))

    def _read_checkpoints(self, model_name: str, require_model_file: bool) -> List[Tuple[str, Dict[str, Any]]]:
        """Load (checkpoint base path, metadata) pairs for every readable metadata file.

        Args:
            model_name: Name of the model
            require_model_file: Skip checkpoints whose ``.pt`` file is missing

        Returns:
            Unsorted list of checkpoint base paths and their metadata
        """
        model_dir = os.path.join(self.checkpoint_dir, model_name)
        # Escape both parts so glob characters in a model name are matched literally
        pattern = os.path.join(
            glob.escape(model_dir),
            f"{glob.escape(model_name)}_*{self._METADATA_SUFFIX}",
        )

        entries = []
        for metadata_file in glob.glob(pattern):
            try:
                with open(metadata_file, 'r') as f:
                    metadata = json.load(f)
            except Exception as e:
                logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
                continue

            # Get checkpoint path (remove _metadata.json)
            checkpoint_path = metadata_file[:-len(self._METADATA_SUFFIX)]

            if require_model_file and not os.path.exists(f"{checkpoint_path}.pt"):
                logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
                continue

            entries.append((checkpoint_path, metadata))
        return entries

    def save_checkpoint(self, model_name: str, model_path: str, metrics: Dict[str, float], metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Save a checkpoint with metadata

        Args:
            model_name: Name of the model
            model_path: Path to the model file
            metrics: Performance metrics
            metadata: Additional metadata

        Returns:
            str: Path to the saved checkpoint (without extension), or "" on failure
        """
        try:
            # Create timestamp
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            # Create checkpoint directory
            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
            os.makedirs(checkpoint_dir, exist_ok=True)

            # Create checkpoint path
            checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_{timestamp}")

            # Copy model file to checkpoint path
            shutil.copy2(model_path, f"{checkpoint_path}.pt")

            # Create metadata
            checkpoint_metadata = {
                'model_name': model_name,
                'timestamp': timestamp,
                'metrics': metrics,
                'metadata': metadata or {}
            }

            # Save metadata
            with open(f"{checkpoint_path}{self._METADATA_SUFFIX}", 'w') as f:
                json.dump(checkpoint_metadata, f, indent=2)

            logger.info(f"Saved checkpoint to {checkpoint_path}")

            # Clean up old checkpoints
            self._cleanup_checkpoints(model_name)

            return checkpoint_path

        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")
            return ""

    def load_best_checkpoint(self, model_name: str) -> Tuple[str, Dict[str, Any]]:
        """
        Load the best checkpoint based on performance metrics

        Args:
            model_name: Name of the model

        Returns:
            Tuple[str, Dict[str, Any]]: Path to the best checkpoint (without
            extension) and its metadata, or ("", {}) when none is available
        """
        try:
            checkpoints = self._read_checkpoints(model_name, require_model_file=True)

            if not checkpoints:
                # No more scattered "No checkpoints found" logs - handled by database system
                return "", {}

            # Highest metric wins; ties go to the newest timestamp
            best_checkpoint_path, best_checkpoint_metadata = max(
                checkpoints, key=lambda entry: self._rank_key(entry[1]))

            logger.info(f"Best checkpoint for {model_name}: {best_checkpoint_path}")

            return best_checkpoint_path, best_checkpoint_metadata

        except Exception as e:
            logger.error(f"Error loading best checkpoint: {e}")
            return "", {}

    def _cleanup_checkpoints(self, model_name: str) -> int:
        """
        Clean up old or underperforming checkpoints

        Args:
            model_name: Name of the model

        Returns:
            int: Number of checkpoints deleted
        """
        try:
            # Model files are not required here, so metadata left without a
            # model file still takes part in the cleanup
            checkpoints = self._read_checkpoints(model_name, require_model_file=False)

            if len(checkpoints) <= self.max_checkpoints:
                return 0

            # Sort by metric (highest first)
            checkpoints.sort(key=lambda entry: self._rank_key(entry[1]), reverse=True)

            # Keep only the best checkpoints
            checkpoints_to_delete = checkpoints[self.max_checkpoints:]

            # Delete checkpoints
            deleted_count = 0
            for checkpoint_path, _ in checkpoints_to_delete:
                try:
                    # Delete model file, then metadata file
                    for stale_file in (f"{checkpoint_path}.pt", f"{checkpoint_path}{self._METADATA_SUFFIX}"):
                        if os.path.exists(stale_file):
                            os.remove(stale_file)

                    deleted_count += 1

                except Exception as e:
                    logger.error(f"Error deleting checkpoint {checkpoint_path}: {e}")

            logger.info(f"Deleted {deleted_count} old checkpoints for {model_name}")

            return deleted_count

        except Exception as e:
            logger.error(f"Error cleaning up checkpoints: {e}")
            return 0

    def get_all_checkpoints(self, model_name: str) -> List[Tuple[str, Dict[str, Any]]]:
        """
        Get all checkpoints for a model

        Args:
            model_name: Name of the model

        Returns:
            List[Tuple[str, Dict[str, Any]]]: List of checkpoint paths and
            metadata, newest timestamp first
        """
        try:
            checkpoints = self._read_checkpoints(model_name, require_model_file=True)

            # Sort by timestamp (newest first)
            checkpoints.sort(key=lambda x: x[1].get('timestamp', ''), reverse=True)

            return checkpoints

        except Exception as e:
            logger.error(f"Error getting all checkpoints: {e}")
            return []

    def get_checkpoint_stats(self) -> Dict[str, Any]:
        """
        Get statistics about all checkpoints

        Returns:
            Dict[str, Any]: 'total_checkpoints', 'total_size_mb' and a 'models'
            mapping of per-model 'checkpoints' and 'size_mb'
        """
        try:
            stats = {
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {}
            }

            # Iterate through all model directories
            for model_dir in os.listdir(self.checkpoint_dir):
                model_path = os.path.join(self.checkpoint_dir, model_dir)
                if not os.path.isdir(model_path):
                    continue

                # Count checkpoints for this model
                checkpoint_files = glob.glob(
                    os.path.join(glob.escape(model_path), f"{glob.escape(model_dir)}_*.pt"))
                model_checkpoints = len(checkpoint_files)

                # Calculate total size for this model
                model_size_mb = 0.0
                for checkpoint_file in checkpoint_files:
                    try:
                        size_bytes = os.path.getsize(checkpoint_file)
                        model_size_mb += size_bytes / (1024 * 1024)  # Convert to MB
                    except OSError:
                        # File vanished or is unreadable; leave it out of the total
                        pass

                stats['models'][model_dir] = {
                    'checkpoints': model_checkpoints,
                    'size_mb': round(model_size_mb, 2)
                }

                stats['total_checkpoints'] += model_checkpoints
                stats['total_size_mb'] += model_size_mb

            stats['total_size_mb'] = round(stats['total_size_mb'], 2)

            return stats

        except Exception as e:
            logger.error(f"Error getting checkpoint stats: {e}")
            return {
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {}
            }
|
||||
565
utils/database_manager.py
Normal file
565
utils/database_manager.py
Normal file
@@ -0,0 +1,565 @@
|
||||
"""
|
||||
Database Manager for Trading System
|
||||
|
||||
Manages SQLite database for:
|
||||
1. Inference records logging
|
||||
2. Checkpoint metadata storage
|
||||
3. Model performance tracking
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class InferenceRecord:
    """A single logged model inference.

    The first nine fields are required and positional; the remaining three
    are optional and default to None.
    """

    # Which model produced the prediction, when, and for which symbol
    model_name: str
    timestamp: datetime
    symbol: str

    # The prediction itself
    action: str
    confidence: float
    probabilities: Dict[str, float]

    # Hash of input features for deduplication
    input_features_hash: str

    # Cost of producing the prediction
    processing_time_ms: float
    memory_usage_mb: float

    # Full input features for training
    input_features: Optional[np.ndarray] = None
    checkpoint_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
@dataclass
class CheckpointMetadata:
    """Descriptive record for one saved model checkpoint.

    All fields except ``is_active`` are required.
    """

    # Identity of the checkpoint and the model it belongs to
    checkpoint_id: str
    model_name: str
    model_type: str
    timestamp: datetime

    # Metric values and free-form training details recorded with the checkpoint
    performance_metrics: Dict[str, float]
    training_metadata: Dict[str, Any]

    # Location and size of the checkpoint file
    file_path: str
    file_size_mb: float

    # Currently loaded checkpoint
    is_active: bool = False
|
||||
|
||||
class DatabaseManager:
|
||||
"""Manages SQLite database for trading system logging and metadata"""
|
||||
|
||||
def __init__(self, db_path: str = "data/trading_system.db"):
    """Remember the database location and prepare it for use.

    Args:
        db_path: Path of the SQLite database file
    """
    self.db_path = db_path

    # The directory has to exist before the database is initialized
    self._ensure_db_directory()
    self._initialize_database()
|
||||
|
||||
def _ensure_db_directory(self):
    """Ensure database directory exists

    A db_path without a directory component (a bare file name) needs no
    directory, so nothing is created in that case.
    """
    db_dir = os.path.dirname(self.db_path)
    # os.makedirs("") raises FileNotFoundError, so skip the empty dirname
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)
|
||||
|
||||
def _initialize_database(self):
    """Create the tables and indexes if they do not exist yet, then run migrations."""
    schema_statements = (
        # Inference records table
        """
        CREATE TABLE IF NOT EXISTS inference_records (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            model_name TEXT NOT NULL,
            timestamp TEXT NOT NULL,
            symbol TEXT NOT NULL,
            action TEXT NOT NULL,
            confidence REAL NOT NULL,
            probabilities TEXT NOT NULL, -- JSON
            input_features_hash TEXT NOT NULL,
            input_features_blob BLOB, -- Store full input features for training
            processing_time_ms REAL NOT NULL,
            memory_usage_mb REAL NOT NULL,
            checkpoint_id TEXT,
            metadata TEXT, -- JSON
            created_at TEXT DEFAULT CURRENT_TIMESTAMP
        )
        """,
        # Checkpoint metadata table
        """
        CREATE TABLE IF NOT EXISTS checkpoint_metadata (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            checkpoint_id TEXT UNIQUE NOT NULL,
            model_name TEXT NOT NULL,
            model_type TEXT NOT NULL,
            timestamp TEXT NOT NULL,
            performance_metrics TEXT NOT NULL, -- JSON
            training_metadata TEXT NOT NULL, -- JSON
            file_path TEXT NOT NULL,
            file_size_mb REAL NOT NULL,
            is_active BOOLEAN DEFAULT FALSE,
            created_at TEXT DEFAULT CURRENT_TIMESTAMP
        )
        """,
        # Model performance tracking table
        """
        CREATE TABLE IF NOT EXISTS model_performance (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            model_name TEXT NOT NULL,
            date TEXT NOT NULL,
            total_predictions INTEGER DEFAULT 0,
            correct_predictions INTEGER DEFAULT 0,
            accuracy REAL DEFAULT 0.0,
            avg_confidence REAL DEFAULT 0.0,
            avg_processing_time_ms REAL DEFAULT 0.0,
            created_at TEXT DEFAULT CURRENT_TIMESTAMP,
            UNIQUE(model_name, date)
        )
        """,
        # Indexes for better performance
        "CREATE INDEX IF NOT EXISTS idx_inference_model_timestamp ON inference_records(model_name, timestamp)",
        "CREATE INDEX IF NOT EXISTS idx_inference_symbol ON inference_records(symbol)",
        "CREATE INDEX IF NOT EXISTS idx_checkpoint_model ON checkpoint_metadata(model_name)",
        "CREATE INDEX IF NOT EXISTS idx_checkpoint_active ON checkpoint_metadata(is_active)",
    )

    with self._get_connection() as conn:
        for statement in schema_statements:
            conn.execute(statement)

    logger.info(f"Database initialized at {self.db_path}")

    # Run migrations to handle schema updates
    self._run_migrations()
|
||||
|
||||
def _run_migrations(self):
    """Bring an existing database up to the current schema.

    The only migration adds the input_features_blob column to
    inference_records when it is missing. Failures are logged and not
    re-raised.
    """
    try:
        with self._get_connection() as conn:
            # Index 1 of each PRAGMA table_info row is read as the column name
            table_info = conn.execute("PRAGMA table_info(inference_records)").fetchall()
            existing_columns = [info[1] for info in table_info]

            if 'input_features_blob' in existing_columns:
                logger.debug("input_features_blob column already exists")
            else:
                logger.info("Adding input_features_blob column to inference_records table")
                conn.execute("ALTER TABLE inference_records ADD COLUMN input_features_blob BLOB")
                conn.commit()
                logger.info("Successfully added input_features_blob column")

    except Exception as e:
        logger.error(f"Error running database migrations: {e}")
        # If migration fails, we can still continue without the blob column
|
||||
|
||||
@contextmanager
def _get_connection(self):
    """Yield a SQLite connection to self.db_path and always close it afterwards.

    Rows are returned as sqlite3.Row objects. If connecting or the body of
    the ``with`` block raises, the connection (when one was opened) is rolled
    back, the error is logged and the exception is re-raised. Nothing is
    committed here; callers commit themselves.
    """
    handle = None
    try:
        try:
            handle = sqlite3.connect(self.db_path, timeout=30.0)
            handle.row_factory = sqlite3.Row  # Enable dict-like access
            yield handle
        except Exception as exc:
            if handle is not None:
                handle.rollback()
            logger.error(f"Database error: {exc}")
            raise
    finally:
        if handle is not None:
            handle.close()
|
||||
|
||||
def log_inference(self, record: InferenceRecord) -> bool:
    """Log an inference record

    Input features are stored as raw float32 bytes, the dtype used when the
    blob is decoded again with np.frombuffer. When the inference_records
    table has no input_features_blob column, the record is stored without
    its features and a warning is logged.

    Args:
        record: The inference to store

    Returns:
        True if the record was inserted and committed, False on any error
    """
    try:
        with self._get_connection() as conn:
            # Check if input_features_blob column exists
            cursor = conn.execute("PRAGMA table_info(inference_records)")
            has_blob_column = any(row[1] == 'input_features_blob' for row in cursor.fetchall())

            # Column name -> value, in insertion order
            values = {
                'model_name': record.model_name,
                'timestamp': record.timestamp.isoformat(),
                'symbol': record.symbol,
                'action': record.action,
                'confidence': record.confidence,
                'probabilities': json.dumps(record.probabilities),
                'input_features_hash': record.input_features_hash,
                'processing_time_ms': record.processing_time_ms,
                'memory_usage_mb': record.memory_usage_mb,
                'checkpoint_id': record.checkpoint_id,
                'metadata': json.dumps(record.metadata) if record.metadata else None,
            }

            if has_blob_column:
                # Coerce to float32 first: writing another dtype's bytes would be
                # misread when the blob is decoded as float32
                values['input_features_blob'] = (
                    None if record.input_features is None
                    else np.asarray(record.input_features, dtype=np.float32).tobytes()
                )
            else:
                logger.warning("input_features_blob column missing, storing without full features")

            # Column names are the fixed literals above; all values are bound parameters
            columns = ", ".join(values)
            placeholders = ", ".join("?" for _ in values)
            conn.execute(
                f"INSERT INTO inference_records ({columns}) VALUES ({placeholders})",
                tuple(values.values()),
            )

            conn.commit()
            return True
    except Exception as e:
        logger.error(f"Failed to log inference record: {e}")
        return False
|
||||
|
||||
def save_checkpoint_metadata(self, metadata: CheckpointMetadata) -> bool:
    """Save checkpoint metadata

    An existing row with the same checkpoint_id is replaced. When the saved
    checkpoint is active, every other checkpoint of the same model is marked
    inactive first; saving an inactive checkpoint leaves the current active
    one untouched.

    Args:
        metadata: The checkpoint metadata to store

    Returns:
        True if the row was written and committed, False on any error
    """
    try:
        with self._get_connection() as conn:
            if metadata.is_active:
                # Only a new active checkpoint may displace the current one
                conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = FALSE
                    WHERE model_name = ?
                """, (metadata.model_name,))

            # Insert or replace the new checkpoint metadata
            conn.execute("""
                INSERT OR REPLACE INTO checkpoint_metadata (
                    checkpoint_id, model_name, model_type, timestamp,
                    performance_metrics, training_metadata, file_path,
                    file_size_mb, is_active
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                metadata.checkpoint_id,
                metadata.model_name,
                metadata.model_type,
                metadata.timestamp.isoformat(),
                json.dumps(metadata.performance_metrics),
                json.dumps(metadata.training_metadata),
                metadata.file_path,
                metadata.file_size_mb,
                metadata.is_active
            ))
            conn.commit()
            return True
    except Exception as e:
        logger.error(f"Failed to save checkpoint metadata: {e}")
        return False
|
||||
|
||||
def get_checkpoint_metadata(self, model_name: str, checkpoint_id: str = None) -> Optional[CheckpointMetadata]:
    """Fetch stored metadata for one checkpoint without loading the model itself.

    With a checkpoint_id, that specific checkpoint of the model is looked up.
    Without one, the model's active checkpoint is returned (the most recent
    by timestamp if several are flagged active).

    Returns:
        The matching CheckpointMetadata, or None when nothing matches or an
        error occurs
    """
    if checkpoint_id:
        # Get specific checkpoint
        query = """
            SELECT * FROM checkpoint_metadata
            WHERE model_name = ? AND checkpoint_id = ?
        """
        params = (model_name, checkpoint_id)
    else:
        # Get active checkpoint
        query = """
            SELECT * FROM checkpoint_metadata
            WHERE model_name = ? AND is_active = TRUE
            ORDER BY timestamp DESC LIMIT 1
        """
        params = (model_name,)

    try:
        with self._get_connection() as conn:
            row = conn.execute(query, params).fetchone()
            if not row:
                return None

            return CheckpointMetadata(
                checkpoint_id=row['checkpoint_id'],
                model_name=row['model_name'],
                model_type=row['model_type'],
                timestamp=datetime.fromisoformat(row['timestamp']),
                performance_metrics=json.loads(row['performance_metrics']),
                training_metadata=json.loads(row['training_metadata']),
                file_path=row['file_path'],
                file_size_mb=row['file_size_mb'],
                is_active=bool(row['is_active'])
            )
    except Exception as e:
        logger.error(f"Failed to get checkpoint metadata: {e}")
        return None
|
||||
|
||||
def get_best_checkpoint_metadata(self, model_name: str, metric_name: str = "accuracy") -> Optional[CheckpointMetadata]:
    """Return the model's checkpoint with the highest value of the given metric.

    The metric is read from the JSON stored in the performance_metrics column
    via SQLite's json_extract.

    Returns:
        The top-ranked CheckpointMetadata, or None when the model has no
        checkpoints or an error occurs
    """
    ranking_query = """
        SELECT * FROM checkpoint_metadata
        WHERE model_name = ?
        ORDER BY json_extract(performance_metrics, '$.' || ?) DESC
        LIMIT 1
    """
    try:
        with self._get_connection() as conn:
            best_row = conn.execute(ranking_query, (model_name, metric_name)).fetchone()
            if not best_row:
                return None

            return CheckpointMetadata(
                checkpoint_id=best_row['checkpoint_id'],
                model_name=best_row['model_name'],
                model_type=best_row['model_type'],
                timestamp=datetime.fromisoformat(best_row['timestamp']),
                performance_metrics=json.loads(best_row['performance_metrics']),
                training_metadata=json.loads(best_row['training_metadata']),
                file_path=best_row['file_path'],
                file_size_mb=best_row['file_size_mb'],
                is_active=bool(best_row['is_active'])
            )
    except Exception as e:
        logger.error(f"Failed to get best checkpoint metadata: {e}")
        return None
|
||||
|
||||
def list_checkpoints(self, model_name: str) -> List[CheckpointMetadata]:
    """Return every stored checkpoint of a model, newest timestamp first.

    Returns:
        A list of CheckpointMetadata; empty when the model has none or an
        error occurs
    """
    try:
        with self._get_connection() as conn:
            rows = conn.execute("""
                SELECT * FROM checkpoint_metadata
                WHERE model_name = ?
                ORDER BY timestamp DESC
            """, (model_name,)).fetchall()

            return [
                CheckpointMetadata(
                    checkpoint_id=entry['checkpoint_id'],
                    model_name=entry['model_name'],
                    model_type=entry['model_type'],
                    timestamp=datetime.fromisoformat(entry['timestamp']),
                    performance_metrics=json.loads(entry['performance_metrics']),
                    training_metadata=json.loads(entry['training_metadata']),
                    file_path=entry['file_path'],
                    file_size_mb=entry['file_size_mb'],
                    is_active=bool(entry['is_active'])
                )
                for entry in rows
            ]
    except Exception as e:
        logger.error(f"Failed to list checkpoints: {e}")
        return []
|
||||
|
||||
def set_active_checkpoint(self, model_name: str, checkpoint_id: str) -> bool:
    """Set a checkpoint as active for a model

    The model's other checkpoints are marked inactive in the same statement.
    If the requested checkpoint does not exist for the model, nothing is
    changed, so the currently active checkpoint stays active.

    Args:
        model_name: Name of the model
        checkpoint_id: Checkpoint to activate

    Returns:
        True if the checkpoint exists and is now active, False if it does
        not exist or an error occurred
    """
    try:
        with self._get_connection() as conn:
            # Verify the target first: deactivating everything and then failing
            # to activate would leave the model without an active checkpoint
            target = conn.execute("""
                SELECT 1 FROM checkpoint_metadata
                WHERE model_name = ? AND checkpoint_id = ?
                LIMIT 1
            """, (model_name, checkpoint_id)).fetchone()
            if target is None:
                return False

            # The comparison yields 1 for the chosen checkpoint and 0 for the
            # rest, so activation and deactivation happen in one statement
            conn.execute("""
                UPDATE checkpoint_metadata
                SET is_active = (checkpoint_id = ?)
                WHERE model_name = ?
            """, (checkpoint_id, model_name))

            conn.commit()
            return True
    except Exception as e:
        logger.error(f"Failed to set active checkpoint: {e}")
        return False
|
||||
|
||||
def get_recent_inferences(self, model_name: str, limit: int = 100) -> List[InferenceRecord]:
    """Return up to `limit` inference records of a model, newest timestamp first.

    A stored feature blob is decoded as a flat float32 array; a record whose
    blob is absent, empty or undecodable gets input_features=None.

    Returns:
        A list of InferenceRecord; empty when there are none or an error occurs
    """
    try:
        with self._get_connection() as conn:
            rows = conn.execute("""
                SELECT * FROM inference_records
                WHERE model_name = ?
                ORDER BY timestamp DESC
                LIMIT ?
            """, (model_name, limit)).fetchall()

            records = []
            for row in rows:
                # The blob column may be absent from the row; treat that like an empty blob
                blob = row['input_features_blob'] if 'input_features_blob' in row.keys() else None

                features = None
                if blob:
                    try:
                        # Reconstruct numpy array from bytes
                        features = np.frombuffer(blob, dtype=np.float32)
                    except Exception as e:
                        logger.warning(f"Failed to deserialize input features: {e}")

                raw_metadata = row['metadata']
                records.append(InferenceRecord(
                    model_name=row['model_name'],
                    timestamp=datetime.fromisoformat(row['timestamp']),
                    symbol=row['symbol'],
                    action=row['action'],
                    confidence=row['confidence'],
                    probabilities=json.loads(row['probabilities']),
                    input_features_hash=row['input_features_hash'],
                    processing_time_ms=row['processing_time_ms'],
                    memory_usage_mb=row['memory_usage_mb'],
                    input_features=features,
                    checkpoint_id=row['checkpoint_id'],
                    metadata=json.loads(raw_metadata) if raw_metadata else None
                ))
            return records
    except Exception as e:
        logger.error(f"Failed to get recent inferences: {e}")
        return []
|
||||
|
||||
def update_model_performance(self, model_name: str, date: str,
                             total_predictions: int, correct_predictions: int,
                             avg_confidence: float, avg_processing_time: float) -> bool:
    """Store one day's performance statistics for a model.

    Accuracy is correct_predictions / total_predictions, or 0.0 when there
    were no predictions. The row is written with INSERT OR REPLACE.

    Returns:
        True if the row was written and committed, False on any error
    """
    try:
        if total_predictions > 0:
            accuracy = correct_predictions / total_predictions
        else:
            accuracy = 0.0

        row_values = (
            model_name, date, total_predictions, correct_predictions,
            accuracy, avg_confidence, avg_processing_time
        )

        with self._get_connection() as conn:
            conn.execute("""
                INSERT OR REPLACE INTO model_performance (
                    model_name, date, total_predictions, correct_predictions,
                    accuracy, avg_confidence, avg_processing_time_ms
                ) VALUES (?, ?, ?, ?, ?, ?, ?)
            """, row_values)
            conn.commit()
            return True
    except Exception as e:
        logger.error(f"Failed to update model performance: {e}")
        return False
|
||||
|
||||
def get_inference_records_for_training(self, model_name: str,
                                       symbol: str = None,
                                       hours_back: int = 24,
                                       limit: int = 1000) -> List[InferenceRecord]:
    """
    Get inference records with input features for training feedback

    Args:
        model_name: Name of the model
        symbol: Optional symbol filter (``None`` or empty means all symbols)
        hours_back: How many hours back to look
        limit: Maximum number of records

    Returns:
        List of InferenceRecord with input_features populated, newest
        first. An empty list is returned when the ``input_features_blob``
        column does not exist or when any exception is raised (the error
        is logged).
    """
    try:
        cutoff_time = datetime.now() - timedelta(hours=hours_back)

        with self._get_connection() as conn:
            # Check if input_features_blob column exists before querying
            cursor = conn.execute("PRAGMA table_info(inference_records)")
            columns = [info[1] for info in cursor.fetchall()]
            if 'input_features_blob' not in columns:
                logger.warning("input_features_blob column not found, returning empty list")
                return []

            # Build one parameterized query instead of two near-identical
            # copies. Only fixed SQL fragments are concatenated; every
            # caller-supplied value is bound through a placeholder.
            conditions = ["model_name = ?"]
            params = [model_name]
            if symbol:
                conditions.append("symbol = ?")
                params.append(symbol)
            conditions.append("timestamp >= ?")
            params.append(cutoff_time.isoformat())
            conditions.append("input_features_blob IS NOT NULL")
            params.append(limit)

            query = (
                "SELECT * FROM inference_records WHERE "
                + " AND ".join(conditions)
                + " ORDER BY timestamp DESC LIMIT ?"
            )
            cursor = conn.execute(query, params)

            records = []
            # NOTE(review): rows are indexed by column name, so the
            # connection presumably uses sqlite3.Row as row factory.
            for row in cursor.fetchall():
                blob = row['input_features_blob']
                if not blob:
                    # An empty blob passes the NOT NULL filter but holds no
                    # features; such a record is useless for training.
                    continue
                try:
                    input_features = np.frombuffer(blob, dtype=np.float32)
                except Exception as e:
                    logger.warning(f"Failed to deserialize input features: {e}")
                    continue  # Skip records with corrupted features

                records.append(InferenceRecord(
                    model_name=row['model_name'],
                    timestamp=datetime.fromisoformat(row['timestamp']),
                    symbol=row['symbol'],
                    action=row['action'],
                    confidence=row['confidence'],
                    probabilities=json.loads(row['probabilities']),
                    input_features_hash=row['input_features_hash'],
                    processing_time_ms=row['processing_time_ms'],
                    memory_usage_mb=row['memory_usage_mb'],
                    input_features=input_features,
                    checkpoint_id=row['checkpoint_id'],
                    metadata=json.loads(row['metadata']) if row['metadata'] else None
                ))

            return records

    except Exception as e:
        logger.error(f"Failed to get inference records for training: {e}")
        return []
|
||||
|
||||
def cleanup_old_records(self, days_to_keep: int = 30) -> bool:
    """Delete inference records older than ``days_to_keep`` days.

    Args:
        days_to_keep: Age threshold in days. Must not be negative: a
            negative value would place the cutoff in the future and wipe
            every record, so it is rejected.

    Returns:
        True when the delete statement ran and was committed (even if no
        rows matched), False on invalid input or any exception (logged).
    """
    try:
        if days_to_keep < 0:
            logger.error(f"Failed to cleanup old records: days_to_keep must be >= 0, got {days_to_keep}")
            return False

        cutoff_date = datetime.now() - timedelta(days=days_to_keep)

        with self._get_connection() as conn:
            # Timestamps are compared as ISO-8601 text.
            # NOTE(review): assumes rows were written with isoformat() too.
            cursor = conn.execute(
                """
                DELETE FROM inference_records
                WHERE timestamp < ?
                """,
                (cutoff_date.isoformat(),),
            )
            deleted_count = cursor.rowcount
            conn.commit()

        if deleted_count > 0:
            logger.info(f"Cleaned up {deleted_count} old inference records")

        return True
    except Exception as e:
        logger.error(f"Failed to cleanup old records: {e}")
        return False
|
||||
|
||||
# Global database manager instance (created lazily by get_database_manager)
_db_manager_instance = None


def get_database_manager(db_path: str = "data/trading_system.db") -> DatabaseManager:
    """Get the global database manager instance.

    The instance is created on first call. ``db_path`` is only used for
    that first creation; later calls return the existing instance and
    ignore the argument until the instance is reset.

    Args:
        db_path: Path handed to ``DatabaseManager`` when the instance is
            created.

    Returns:
        The shared ``DatabaseManager`` instance.
    """
    global _db_manager_instance

    if _db_manager_instance is None:
        _db_manager_instance = DatabaseManager(db_path)

    return _db_manager_instance
|
||||
|
||||
def reset_database_manager():
    """Reset the database manager instance to force re-initialization.

    Only the module-level reference is cleared; nothing is called on the
    previous instance.
    """
    global _db_manager_instance
    _db_manager_instance = None
    logger.info("Database manager instance reset - will re-initialize on next access")
|
||||
234
utils/inference_logger.py
Normal file
234
utils/inference_logger.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
Inference Logger
|
||||
|
||||
Centralized logging system for model inferences with database storage
|
||||
Eliminates scattered logging throughout the codebase
|
||||
"""
|
||||
|
||||
import time
|
||||
import hashlib
|
||||
import logging
|
||||
import psutil
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
||||
from .database_manager import get_database_manager, InferenceRecord
|
||||
from .text_logger import get_text_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class InferenceLogger:
    """Centralized inference logging system.

    Each inference is written twice: once to the database through the
    shared database manager and once through the shared text logger.
    """

    def __init__(self):
        self.db_manager = get_database_manager()
        self.text_logger = get_text_logger()
        # Handle on the current process, used for memory readings.
        self._process = psutil.Process()

    def log_inference(self,
                      model_name: str,
                      symbol: str,
                      action: str,
                      confidence: float,
                      probabilities: Dict[str, float],
                      input_features: Union[np.ndarray, Dict, List],
                      processing_time_ms: float,
                      checkpoint_id: Optional[str] = None,
                      metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Log a model inference with all relevant details

        Args:
            model_name: Name of the model making the prediction
            symbol: Trading symbol
            action: Predicted action (BUY/SELL/HOLD)
            confidence: Confidence score (0.0 to 1.0)
            probabilities: Action probabilities dict
            input_features: Input features used for prediction. Arrays,
                lists and tuples are stored as float32 arrays; any other
                type (e.g. a dict) is hashed but stored without features.
            processing_time_ms: Time taken for inference in milliseconds
            checkpoint_id: ID of the checkpoint used
            metadata: Additional metadata

        Returns:
            bool: True if both the database and the text log succeeded
        """
        try:
            # Create feature hash for deduplication (taken from the input
            # as given, before any float32 conversion).
            feature_hash = self._hash_features(input_features)

            # Get current memory usage
            memory_usage_mb = self._get_memory_usage()

            # Convert input features to numpy array if needed
            features_array = None
            if isinstance(input_features, np.ndarray):
                features_array = input_features.astype(np.float32)
            elif isinstance(input_features, (list, tuple)):
                features_array = np.array(input_features, dtype=np.float32)

            record = InferenceRecord(
                model_name=model_name,
                timestamp=datetime.now(),
                symbol=symbol,
                action=action,
                confidence=confidence,
                probabilities=probabilities,
                input_features_hash=feature_hash,
                processing_time_ms=processing_time_ms,
                memory_usage_mb=memory_usage_mb,
                input_features=features_array,
                checkpoint_id=checkpoint_id,
                metadata=metadata
            )

            # Log to database
            db_success = self.db_manager.log_inference(record)

            # Log to text file (attempted even when the database write failed)
            text_success = self.text_logger.log_inference(
                model_name=model_name,
                symbol=symbol,
                action=action,
                confidence=confidence,
                processing_time_ms=processing_time_ms,
                checkpoint_id=checkpoint_id
            )

            # Success is deliberately silent to keep runtime logs quiet.
            if not db_success:
                logger.error(f"Failed to log inference for {model_name}")

            # Coerce so the annotated bool holds even if a backend
            # returns None or another non-bool value.
            return bool(db_success and text_success)

        except Exception as e:
            logger.error(f"Error logging inference: {e}")
            return False

    def _hash_features(self, features: Union[np.ndarray, Dict, List]) -> str:
        """Create a 16-hex-character hash of input features for deduplication.

        Arrays are hashed over their raw bytes; dicts over their items in
        key order; everything else over ``str(features)``. MD5 is used as a
        fingerprint only, not for security.
        """
        try:
            if isinstance(features, np.ndarray):
                payload = features.tobytes()
            elif isinstance(features, dict):
                try:
                    items = sorted(features.items())
                except TypeError:
                    # Keys of mixed types cannot be ordered; sort by repr so
                    # the hash stays reproducible instead of time-based.
                    items = sorted(features.items(), key=lambda kv: repr(kv[0]))
                payload = str(items).encode()
            else:
                payload = str(features).encode()
            return hashlib.md5(payload).hexdigest()[:16]
        except Exception:
            # Last resort: time-based value, unique but not reproducible.
            return hashlib.md5(str(time.time()).encode()).hexdigest()[:16]

    def _get_memory_usage(self) -> float:
        """Get current resident memory usage in MB (0.0 if unavailable)."""
        try:
            return self._process.memory_info().rss / (1024 * 1024)
        except Exception:
            return 0.0

    @staticmethod
    def _empty_stats() -> Dict[str, Any]:
        """Return a fresh statistics dict describing 'no inferences'."""
        return {
            'total_inferences': 0,
            'avg_confidence': 0.0,
            'avg_processing_time_ms': 0.0,
            'action_distribution': {},
            'symbol_distribution': {}
        }

    def get_model_stats(self, model_name: str, hours: int = 24) -> Dict[str, Any]:
        """Get inference statistics for a model.

        At most the 1000 most recent records returned by the database
        manager are considered, then filtered to the last ``hours`` hours.

        Args:
            model_name: Name of the model.
            hours: Size of the time window in hours.

        Returns:
            Dict with totals, averages and action/symbol distributions.
            ``latest_inference`` is only present when at least one record
            falls inside the window. An empty dict is returned on error.
        """
        try:
            recent_inferences = self.db_manager.get_recent_inferences(model_name, limit=1000)
            if not recent_inferences:
                return self._empty_stats()

            # Filter by time window
            cutoff_time = datetime.now() - timedelta(hours=hours)
            recent_inferences = [r for r in recent_inferences if r.timestamp >= cutoff_time]
            if not recent_inferences:
                return self._empty_stats()

            total_inferences = len(recent_inferences)
            confidence_sum = 0
            processing_time_sum = 0
            action_counts = {}
            symbol_counts = {}
            for record in recent_inferences:
                confidence_sum += record.confidence
                processing_time_sum += record.processing_time_ms
                action_counts[record.action] = action_counts.get(record.action, 0) + 1
                symbol_counts[record.symbol] = symbol_counts.get(record.symbol, 0) + 1

            # max() does not depend on the order the records arrive in.
            latest = max(r.timestamp for r in recent_inferences)

            return {
                'total_inferences': total_inferences,
                'avg_confidence': confidence_sum / total_inferences,
                'avg_processing_time_ms': processing_time_sum / total_inferences,
                'action_distribution': action_counts,
                'symbol_distribution': symbol_counts,
                'latest_inference': latest.isoformat()
            }

        except Exception as e:
            logger.error(f"Error getting model stats: {e}")
            return {}

    def cleanup_old_logs(self, days_to_keep: int = 30) -> bool:
        """Clean up old inference logs by delegating to the database manager."""
        return self.db_manager.cleanup_old_records(days_to_keep)
|
||||
|
||||
# Global inference logger instance (created lazily by get_inference_logger)
_inference_logger_instance = None


def get_inference_logger() -> InferenceLogger:
    """Get the global inference logger instance, creating it on first use."""
    global _inference_logger_instance

    if _inference_logger_instance is None:
        _inference_logger_instance = InferenceLogger()

    return _inference_logger_instance
|
||||
|
||||
def log_model_inference(model_name: str,
                        symbol: str,
                        action: str,
                        confidence: float,
                        probabilities: Dict[str, float],
                        input_features: Union[np.ndarray, Dict, List],
                        processing_time_ms: float,
                        checkpoint_id: Optional[str] = None,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
    """
    Convenience function to log model inference

    This is the main function that should be called throughout the codebase
    instead of scattered logger.info() calls. All arguments are forwarded
    unchanged to ``log_inference`` on the global inference logger, and its
    result is returned.
    """
    return get_inference_logger().log_inference(
        model_name=model_name,
        symbol=symbol,
        action=action,
        confidence=confidence,
        probabilities=probabilities,
        input_features=input_features,
        processing_time_ms=processing_time_ms,
        checkpoint_id=checkpoint_id,
        metadata=metadata
    )
|
||||
@@ -1,164 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TensorBoard Launcher with Automatic Port Management
|
||||
|
||||
This script launches TensorBoard with automatic port fallback if the preferred port is in use.
|
||||
It also kills any stale debug instances that might be running.
|
||||
|
||||
Usage:
|
||||
python launch_tensorboard.py --logdir=path/to/logs --preferred-port=6007 --port-range=6000-7000
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
if project_root not in sys.path:
|
||||
sys.path.append(project_root)
|
||||
|
||||
from utils.port_manager import get_port_with_fallback, kill_stale_debug_instances
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger('tensorboard_launcher')
|
||||
|
||||
def launch_tensorboard(logdir, port, host='localhost', open_browser=True):
    """
    Launch TensorBoard on the specified port

    Args:
        logdir (str): Path to log directory
        port (int): Port to use
        host (str): Host to bind to
        open_browser (bool): Whether to open browser automatically.
            NOTE(review): no browser is opened by this function in either
            case; when False the only effect is an extra
            ``--window_title=TensorBoard`` flag.

    Returns:
        subprocess.Popen: Process object of the running TensorBoard
    """
    import threading

    cmd = [
        sys.executable, "-m", "tensorboard.main",
        f"--logdir={logdir}",
        f"--port={port}",
        f"--host={host}"
    ]

    # Original author's note: --load_fast=false is added to improve startup times
    cmd.append("--load_fast=false")

    if not open_browser:
        cmd.append("--window_title=TensorBoard")

    logger.info(f"Launching TensorBoard: {' '.join(cmd)}")

    # Use subprocess.Popen to start TensorBoard without waiting for it to finish
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        bufsize=1
    )

    # Log the first few lines of output to confirm it's starting correctly.
    # Stop as soon as the URL banner is seen: waiting for a fixed number of
    # lines would block here for as long as TensorBoard prints fewer.
    for line_count, line in enumerate(process.stdout, start=1):
        text = line.strip()
        logger.info(f"TensorBoard: {text}")

        if "TensorBoard" in line and "http://" in line:
            url = text.split("http://")[1].split(" ")[0]
            logger.info(f"TensorBoard available at: http://{url}")
            break

        # Only log the first few lines
        if line_count >= 10:
            break

    # Continue reading output in background to prevent pipe from filling
    def read_output():
        for _ in process.stdout:
            pass

    threading.Thread(target=read_output, daemon=True).start()

    return process
|
||||
|
||||
def main():
    """Parse command-line options, pick a port and run TensorBoard.

    Returns:
        int: Process exit code. 1 for an invalid port range or when no port
        could be obtained, TensorBoard's own return code when it exits by
        itself, 0 after a keyboard interrupt.
    """
    parser = argparse.ArgumentParser(description='Launch TensorBoard with automatic port management')
    parser.add_argument('--logdir', type=str, default='NN/models/saved/logs',
                        help='Directory containing TensorBoard event files')
    parser.add_argument('--preferred-port', type=int, default=6007,
                        help='Preferred port to use')
    parser.add_argument('--port-range', type=str, default='6000-7000',
                        help='Port range to try if preferred port is unavailable (format: min-max)')
    parser.add_argument('--host', type=str, default='localhost',
                        help='Host to bind to')
    parser.add_argument('--no-browser', action='store_true',
                        help='Do not open browser automatically')
    parser.add_argument('--kill-stale', action='store_true',
                        help='Kill stale debug instances before starting')

    args = parser.parse_args()

    # Parse port range (ValueError covers both non-numeric parts and a
    # wrong number of '-'-separated parts)
    try:
        min_port, max_port = map(int, args.port_range.split('-'))
    except ValueError:
        logger.error(f"Invalid port range format: {args.port_range}. Use format: min-max")
        return 1

    # Kill stale instances if requested
    if args.kill_stale:
        logger.info("Killing stale debug instances...")
        count, _ = kill_stale_debug_instances()
        logger.info(f"Killed {count} stale instances")

    # Get an available port
    try:
        port = get_port_with_fallback(args.preferred_port, min_port, max_port)
        logger.info(f"Using port {port} for TensorBoard")
    except RuntimeError as e:
        logger.error(str(e))
        return 1

    # Ensure log directory exists
    logdir = os.path.abspath(args.logdir)
    os.makedirs(logdir, exist_ok=True)

    # Launch TensorBoard
    process = launch_tensorboard(
        logdir=logdir,
        port=port,
        host=args.host,
        open_browser=not args.no_browser
    )

    # Wait for process to end (it shouldn't unless there's an error or user kills it)
    try:
        return_code = process.wait()
        if return_code != 0:
            logger.error(f"TensorBoard exited with code {return_code}")
        return return_code
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt, shutting down TensorBoard...")
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            logger.warning("TensorBoard didn't terminate gracefully, forcing kill")
            process.kill()
            # Reap the child so it does not linger as a zombie.
            process.wait()

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
@@ -1,241 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Model utilities for robust saving and loading of PyTorch models
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
import shutil
|
||||
import gc
|
||||
import json
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def robust_save(model: Any, path: str, include_optimizer: bool = True) -> bool:
    """
    Robust model saving with multiple fallback approaches

    Attempts, in order: (1) ``torch.save`` to ``<path>.backup`` followed by
    a copy to ``path``; (2) ``torch.save`` with pickle protocol 2;
    (3) ``torch.save`` without the optimizer state; (4) TorchScript export
    of both networks plus a JSON file with the scalar parameters.

    Args:
        model: The model object to save (should have policy_net, target_net, optimizer, epsilon attributes)
        path: Path to save the model
        include_optimizer: Whether to include optimizer state in the save

    Returns:
        bool: True if successful, False otherwise

    Raises:
        AttributeError: If ``model`` has no ``policy_net`` or ``target_net``;
            the checkpoint is assembled before the guarded save attempts.
    """
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # Written first; kept on disk next to the real file after success
    backup_path = f"{path}.backup"

    # Clean up GPU memory before saving
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    # Prepare checkpoint data
    checkpoint = {
        'policy_net': model.policy_net.state_dict(),
        'target_net': model.target_net.state_dict(),
        'epsilon': getattr(model, 'epsilon', 0.0),
        'state_size': getattr(model, 'state_size', None),
        'action_size': getattr(model, 'action_size', None),
        'hidden_size': getattr(model, 'hidden_size', None),
    }

    # Add optimizer state if requested and available
    if include_optimizer and hasattr(model, 'optimizer') and model.optimizer is not None:
        checkpoint['optimizer'] = model.optimizer.state_dict()

    # Attempt 1: Try with default settings in a separate file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        torch.save(checkpoint, backup_path)
        logger.info(f"Successfully saved to {backup_path}")

        # If backup worked, copy to the actual path
        if os.path.exists(backup_path):
            shutil.copy(backup_path, path)
            logger.info(f"Copied backup to {path}")
            return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: Try with pickle protocol 2 (more compatible)
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        torch.save(checkpoint, path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: Try without optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint_no_opt = {k: v for k, v in checkpoint.items() if k != 'optimizer'}
        torch.save(checkpoint_no_opt, path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: Try with torch.jit.save instead
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")

        # Save policy network using jit
        scripted_policy = torch.jit.script(model.policy_net)
        torch.jit.save(scripted_policy, f"{path}.policy.jit")

        # Save target network using jit
        scripted_target = torch.jit.script(model.target_net)
        torch.jit.save(scripted_target, f"{path}.target.jit")

        # Save parameters separately as JSON. ``or 0`` covers attributes
        # that exist but are None, which would make int()/float() raise
        # and fail this last fallback after the networks were written.
        params = {
            'epsilon': float(getattr(model, 'epsilon', 0.0) or 0.0),
            'state_size': int(getattr(model, 'state_size', 0) or 0),
            'action_size': int(getattr(model, 'action_size', 0) or 0),
            'hidden_size': int(getattr(model, 'hidden_size', 0) or 0)
        }
        with open(f"{path}.params.json", "w") as f:
            json.dump(params, f)

        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False
|
||||
|
||||
def robust_load(model: Any, path: str, device: Optional[torch.device] = None) -> bool:
    """
    Robust model loading with fallback approaches

    Args:
        model: The model object to load into
        path: Path to load the model from
        device: Device to load the model on (defaults to CUDA when
            available, otherwise CPU)

    Returns:
        bool: True if the regular checkpoint at ``path`` was loaded, False
        otherwise. TorchScript fallback files are only detected and
        logged, never loaded, so that branch always yields False.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Try regular PyTorch load first
    try:
        logger.info(f"Loading model from {path}")
        if os.path.exists(path):
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(path, map_location=device)

            # Load network states
            has_policy = 'policy_net' in checkpoint
            has_target = 'target_net' in checkpoint
            if has_policy:
                model.policy_net.load_state_dict(checkpoint['policy_net'])
            if has_target:
                model.target_net.load_state_dict(checkpoint['target_net'])
            if not (has_policy or has_target):
                # Success is still reported below, but make it visible
                # that no network weights were actually restored.
                logger.warning(f"Checkpoint {path} contains neither policy_net nor target_net")

            # Load other attributes
            if 'epsilon' in checkpoint:
                model.epsilon = checkpoint['epsilon']
            if 'optimizer' in checkpoint and hasattr(model, 'optimizer') and model.optimizer is not None:
                try:
                    model.optimizer.load_state_dict(checkpoint['optimizer'])
                except Exception as e:
                    # A stale optimizer state is not fatal for the load.
                    logger.warning(f"Failed to load optimizer state: {e}")

            logger.info("Successfully loaded model")
            return True
    except Exception as e:
        logger.warning(f"Regular load failed: {e}")

    # Try loading JIT saved components
    try:
        policy_path = f"{path}.policy.jit"
        target_path = f"{path}.target.jit"
        params_path = f"{path}.params.json"

        if all(os.path.exists(p) for p in [policy_path, target_path, params_path]):
            logger.info("Loading JIT model components")

            # Load JIT models (this is more complex and may need model reconstruction)
            # For now, just log that we found JIT files
            logger.info("Found JIT model files, but loading them requires special handling")
            with open(params_path, 'r') as f:
                params = json.load(f)
            logger.info(f"Model parameters: {params}")

            # Note: Actually loading JIT models would require recreating the model architecture
            # This is a placeholder for future implementation
            return False
    except Exception as e:
        logger.error(f"JIT load failed: {e}")

    logger.error(f"All load attempts failed for {path}")
    return False
|
||||
|
||||
def get_model_info(path: str) -> Dict[str, Any]:
    """
    Get information about a saved model

    Args:
        path: Path to the model file

    Returns:
        dict: Model information with the keys ``exists``, ``size_bytes``,
        ``has_optimizer`` and ``parameters`` (the subset of epsilon,
        state_size, action_size, hidden_size found in the checkpoint).
        If the file cannot be inspected, the fields gathered so far are
        returned and a warning is logged.
    """
    info = {
        'exists': False,
        'size_bytes': 0,
        'has_optimizer': False,
        'parameters': {}
    }

    try:
        if os.path.exists(path):
            info['exists'] = True
            info['size_bytes'] = os.path.getsize(path)

            # Try to load and inspect.
            # NOTE(security): torch.load unpickles the file; only inspect
            # checkpoints from trusted sources.
            checkpoint = torch.load(path, map_location='cpu')

            # Files holding something other than a checkpoint dict carry
            # none of the expected keys.
            if isinstance(checkpoint, dict):
                info['has_optimizer'] = 'optimizer' in checkpoint

                # Extract parameter info
                for key in ['epsilon', 'state_size', 'action_size', 'hidden_size']:
                    if key in checkpoint:
                        info['parameters'][key] = checkpoint[key]
    except Exception as e:
        logger.warning(f"Failed to get model info for {path}: {e}")

    return info
|
||||
|
||||
def verify_save_load_cycle(model: Any, test_path: str) -> bool:
    """
    Test that a model can be saved and loaded correctly

    Currently only the save half is verified: the model is saved with
    ``robust_save`` and the resulting file must exist and be non-empty.
    The file is not loaded back.

    Args:
        model: Model to test
        test_path: Path for test file (removed again on success)

    Returns:
        bool: True if save/load cycle successful
    """
    try:
        # Save the model
        if not robust_save(model, test_path):
            return False

        # Create a new model instance (this would need model creation logic)
        # For now, just verify the file exists and has content
        if os.path.exists(test_path) and os.path.getsize(test_path) > 0:
            logger.info("Save/load cycle verification successful")
            # Clean up test file
            os.remove(test_path)
            # robust_save also leaves "<path>.backup" next to the file;
            # remove it too so the check leaves nothing behind.
            backup_path = f"{test_path}.backup"
            if os.path.exists(backup_path):
                os.remove(backup_path)
            return True
        return False
    except Exception as e:
        logger.error(f"Save/load cycle verification failed: {e}")
        return False
|
||||
@@ -1,238 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Port Management Utility
|
||||
|
||||
This script provides utilities to:
|
||||
1. Find available ports in a specified range
|
||||
2. Kill stale processes running on specific ports
|
||||
3. Kill all debug/training instances
|
||||
|
||||
Usage:
|
||||
- As a module: import port_manager and use its functions
|
||||
- Directly: python port_manager.py --kill-stale --min-port 6000 --max-port 7000
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import socket
|
||||
import argparse
|
||||
import psutil
|
||||
import logging
|
||||
import time
|
||||
import signal
|
||||
from typing import List, Tuple, Optional, Set
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger('port_manager')
|
||||
|
||||
# Substrings looked for in a process command line when killing stale
# instances; a process matching any one of them is treated as stale.
DEBUG_PROCESS_KEYWORDS = [
    'tensorboard',
    'python train_',
    'realtime.py',
    'train_rl_with_realtime.py'
]
|
||||
|
||||
def is_port_in_use(port: int) -> bool:
    """
    Check if a port is in use

    The check is a TCP connect to ``localhost`` over IPv4: the port counts
    as in use only when something accepts the connection.

    Args:
        port (int): Port number to check

    Returns:
        bool: True if port is in use, False otherwise
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Bound the probe so scanning a range cannot hang on one port;
        # a timed-out connect yields a non-zero code, i.e. "not in use".
        s.settimeout(1.0)
        return s.connect_ex(('localhost', port)) == 0
|
||||
|
||||
def find_available_port(start_port: int, end_port: int) -> Optional[int]:
    """
    Find an available port in the specified range

    Ports are probed in ascending order and the first free one wins.

    Args:
        start_port (int): Lower bound of port range
        end_port (int): Upper bound of port range (inclusive)

    Returns:
        Optional[int]: Available port number or None if no ports available
    """
    candidates = range(start_port, end_port + 1)
    return next((port for port in candidates if not is_port_in_use(port)), None)
|
||||
|
||||
def get_process_by_port(port: int) -> List[psutil.Process]:
    """
    Get processes using a specific port

    A process matches when any of its inet connections has ``port`` as the
    local port. Processes that vanish or cannot be inspected are skipped.

    Args:
        port (int): Port number to check

    Returns:
        List[psutil.Process]: List of processes using the port, each
        process listed once
    """
    processes = []
    for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
        try:
            for conn in proc.connections(kind='inet'):
                # laddr can be empty for a socket without a local address;
                # guard it so .port cannot raise AttributeError.
                if conn.laddr and conn.laddr.port == port:
                    processes.append(proc)
                    # One entry per process, however many sockets match.
                    break
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            pass
    return processes
|
||||
|
||||
def kill_process_by_port(port: int) -> Tuple[int, List[str]]:
    """
    Kill processes using a specific port

    Each process is first asked to terminate; after a short grace period
    any that is still running is force-killed.

    Args:
        port (int): Port number to check

    Returns:
        Tuple[int, List[str]]: Count of killed processes and their names
        (the command line when available, else the process name)
    """
    # Deduplicate by PID so one process with several sockets on the port
    # is terminated and counted only once.
    processes = list({proc.pid: proc for proc in get_process_by_port(port)}.values())
    killed = []

    for proc in processes:
        try:
            cmdline = proc.cmdline()
            proc_name = " ".join(cmdline) if cmdline else proc.name()
            logger.info(f"Terminating process {proc.pid}: {proc_name}")
            proc.terminate()
            killed.append(proc_name)
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass

    # Give processes time to terminate gracefully
    if processes:
        time.sleep(0.5)

    # Force kill any remaining processes
    for proc in processes:
        try:
            if proc.is_running():
                logger.info(f"Force killing process {proc.pid}")
                proc.kill()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass

    return len(killed), killed
|
||||
|
||||
def kill_stale_debug_instances() -> Tuple[int, Set[str]]:
    """
    Kill all stale debug and training instances based on process names

    A process is considered stale when its command line (or its name, if
    no command line is available) contains any DEBUG_PROCESS_KEYWORDS
    entry. The calling process and its parent are never touched: the
    launcher's own command line contains 'tensorboard', so without this
    exclusion it would terminate itself.

    Returns:
        Tuple[int, Set[str]]: Count of killed processes and their names
    """
    killed_count = 0
    killed_procs = set()
    protected_pids = {os.getpid(), os.getppid()}

    def _is_stale(proc) -> bool:
        """Tell whether *proc* matches a keyword and is not protected."""
        if proc.pid in protected_pids:
            return False
        cmdline = proc.cmdline()
        cmd = " ".join(cmdline) if cmdline else proc.name()
        return any(keyword in cmd for keyword in DEBUG_PROCESS_KEYWORDS)

    for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
        try:
            if _is_stale(proc):
                cmdline = proc.cmdline()
                cmd = " ".join(cmdline) if cmdline else proc.name()
                logger.info(f"Terminating stale process {proc.pid}: {cmd}")
                proc.terminate()
                killed_count += 1
                killed_procs.add(cmd)
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            pass

    # Give processes time to terminate
    if killed_count > 0:
        time.sleep(1)

    # Force kill any remaining processes (re-scan, so matching processes
    # that survived the terminate request are caught)
    for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
        try:
            if _is_stale(proc) and proc.is_running():
                logger.info(f"Force killing stale process {proc.pid}")
                proc.kill()
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            pass

    return killed_count, killed_procs
|
||||
|
||||
def get_port_with_fallback(preferred_port: int, min_port: int, max_port: int) -> int:
    """
    Return a usable port, preferring ``preferred_port``.

    Strategy, in order:
      1. use the preferred port if it is free;
      2. otherwise kill whatever holds it and use it if that freed it;
      3. otherwise pick any free port in ``min_port``..``max_port``;
      4. otherwise kill the holders of every port in the range and retry.

    Args:
        preferred_port (int): Preferred port to use
        min_port (int): Minimum port in fallback range
        max_port (int): Maximum port in fallback range

    Returns:
        int: Available port number

    Raises:
        RuntimeError: if no port is available even after freeing the range
    """
    # 1. Preferred port already free
    if not is_port_in_use(preferred_port):
        return preferred_port

    # 2. Try to reclaim the preferred port
    logger.info(f"Preferred port {preferred_port} is in use, attempting to free it")
    freed_count, _ = kill_process_by_port(preferred_port)
    if freed_count > 0 and not is_port_in_use(preferred_port):
        logger.info(f"Successfully freed port {preferred_port}")
        return preferred_port

    # 3. Fall back to any free port in the range
    logger.info(f"Looking for available port in range {min_port}-{max_port}")
    candidate = find_available_port(min_port, max_port)
    if candidate:
        logger.info(f"Using alternative port: {candidate}")
        return candidate

    # 4. Last resort: free the whole range, then look once more
    logger.warning(f"No available ports in range {min_port}-{max_port}, freeing ports")
    for port in range(min_port, max_port + 1):
        kill_process_by_port(port)

    candidate = find_available_port(min_port, max_port)
    if not candidate:
        logger.error(f"Could not find available port even after freeing range {min_port}-{max_port}")
        raise RuntimeError(f"No available ports in range {min_port}-{max_port}")

    logger.info(f"Using port {candidate} after freeing")
    return candidate
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: each flag triggers one independent action,
    # and several flags may be combined in a single invocation.
    parser = argparse.ArgumentParser(description='Port management utility')
    parser.add_argument(
        '--kill-stale', action='store_true',
        help='Kill all stale debug instances')
    parser.add_argument(
        '--free-port', type=int,
        help='Free a specific port')
    parser.add_argument(
        '--find-port', action='store_true',
        help='Find an available port')
    parser.add_argument(
        '--min-port', type=int, default=6000,
        help='Minimum port in range')
    parser.add_argument(
        '--max-port', type=int, default=7000,
        help='Maximum port in range')
    parser.add_argument(
        '--preferred-port', type=int,
        help='Preferred port to use')

    args = parser.parse_args()

    if args.kill_stale:
        count, procs = kill_stale_debug_instances()
        logger.info(f"Killed {count} stale processes")
        for proc in procs:
            logger.info(f" - {proc}")

    if args.free_port:
        count, killed = kill_process_by_port(args.free_port)
        logger.info(f"Killed {count} processes using port {args.free_port}")
        for proc in killed:
            logger.info(f" - {proc}")

    if args.find_port or args.preferred_port:
        # Without an explicit preference, start from the bottom of the range
        preferred = args.preferred_port or args.min_port
        port = get_port_with_fallback(preferred, args.min_port, args.max_port)
        print(port)  # Print only the port number for easy capture in scripts
|
||||
156
utils/text_logger.py
Normal file
156
utils/text_logger.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Text File Logger for Trading System
|
||||
|
||||
Simple text file logging for tracking inference records and system events
|
||||
Provides human-readable logs alongside database storage
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TextLogger:
    """Simple text file logger for trading system events

    Appends human-readable, pipe-separated records to three files inside
    ``log_dir``: ``inference_records.txt``, ``checkpoint_events.txt`` and
    ``system_events.txt``. Every logging method returns ``True`` on success
    and ``False`` (after logging the error) on failure; it never raises.
    """

    def __init__(self, log_dir: str = "logs"):
        """Create ``log_dir`` (including missing parents) and set the file paths."""
        self.log_dir = Path(log_dir)
        # parents=True so that a nested directory such as "logs/text" works too
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Create separate log files for different types of events
        self.inference_log = self.log_dir / "inference_records.txt"
        self.checkpoint_log = self.log_dir / "checkpoint_events.txt"
        self.system_log = self.log_dir / "system_events.txt"

    @staticmethod
    def _timestamp() -> str:
        """Return the current local time formatted for a log line."""
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    @staticmethod
    def _append(path: Path, entry: str) -> None:
        """Append one already-formatted entry to ``path``."""
        with open(path, 'a', encoding='utf-8') as f:
            f.write(entry)
            f.flush()

    @staticmethod
    def _tail(path: Path, lines: int) -> str:
        """Return the last ``lines`` lines of ``path`` as a single string."""
        if lines <= 0:
            # A [-0:] slice would return the whole file instead of nothing
            return ""
        with open(path, 'r', encoding='utf-8') as f:
            return ''.join(f.readlines()[-lines:])

    def log_inference(self, model_name: str, symbol: str, action: str,
                      confidence: float, processing_time_ms: float,
                      checkpoint_id: Optional[str] = None) -> bool:
        """Log inference record to text file

        Args:
            model_name: Name of the model that produced the prediction
            symbol: Symbol the prediction refers to
            action: Predicted action label
            confidence: Confidence value, written with three decimals
            processing_time_ms: Inference duration in milliseconds
            checkpoint_id: Optional checkpoint identifier appended to the line

        Returns:
            bool: True if the record was written, False otherwise
        """
        try:
            checkpoint_info = f" [checkpoint: {checkpoint_id}]" if checkpoint_id else ""
            log_entry = (
                f"{self._timestamp()} | {model_name:15} | {symbol:10} | "
                f"{action:4} | conf={confidence:.3f} | "
                f"time={processing_time_ms:6.1f}ms{checkpoint_info}\n"
            )
            self._append(self.inference_log, log_entry)
            return True
        except Exception as e:
            logger.error(f"Failed to log inference to text file: {e}")
            return False

    def log_checkpoint_event(self, model_name: str, event_type: str,
                             checkpoint_id: str, details: str = "") -> bool:
        """Log checkpoint events (save, load, etc.)

        Returns:
            bool: True if the record was written, False otherwise
        """
        try:
            details_str = f" - {details}" if details else ""
            log_entry = (
                f"{self._timestamp()} | {model_name:15} | {event_type:10} | "
                f"{checkpoint_id}{details_str}\n"
            )
            self._append(self.checkpoint_log, log_entry)
            return True
        except Exception as e:
            logger.error(f"Failed to log checkpoint event to text file: {e}")
            return False

    def log_system_event(self, event_type: str, message: str,
                         component: str = "system") -> bool:
        """Log general system events

        Returns:
            bool: True if the record was written, False otherwise
        """
        try:
            log_entry = (
                f"{self._timestamp()} | {component:15} | {event_type:10} | {message}\n"
            )
            self._append(self.system_log, log_entry)
            return True
        except Exception as e:
            logger.error(f"Failed to log system event to text file: {e}")
            return False

    def get_recent_inferences(self, lines: int = 50) -> str:
        """Get recent inference records from text file

        Returns the last ``lines`` records (an empty string when ``lines`` is
        not positive), a placeholder message when the file does not exist,
        or an error message when it cannot be read.
        """
        try:
            if not self.inference_log.exists():
                return "No inference records found"
            return self._tail(self.inference_log, lines)
        except Exception as e:
            logger.error(f"Failed to read inference log: {e}")
            return f"Error reading log: {e}"

    def get_recent_checkpoint_events(self, lines: int = 20) -> str:
        """Get recent checkpoint events from text file

        Returns the last ``lines`` records (an empty string when ``lines`` is
        not positive), a placeholder message when the file does not exist,
        or an error message when it cannot be read.
        """
        try:
            if not self.checkpoint_log.exists():
                return "No checkpoint events found"
            return self._tail(self.checkpoint_log, lines)
        except Exception as e:
            logger.error(f"Failed to read checkpoint log: {e}")
            return f"Error reading log: {e}"

    def cleanup_old_logs(self, max_lines: int = 10000) -> bool:
        """Keep only the most recent log entries

        Each of the three log files that exists and holds more than
        ``max_lines`` lines is rewritten with only its last ``max_lines``
        lines. A non-positive ``max_lines`` empties such files.

        Returns:
            bool: True if all files were processed, False on the first error
        """
        try:
            for log_file in (self.inference_log, self.checkpoint_log, self.system_log):
                if not log_file.exists():
                    continue

                with open(log_file, 'r', encoding='utf-8') as f:
                    lines = f.readlines()

                if len(lines) <= max_lines:
                    continue

                # Keep only the most recent lines; guard against the [-0:] slice
                recent_lines = lines[-max_lines:] if max_lines > 0 else []
                with open(log_file, 'w', encoding='utf-8') as f:
                    f.writelines(recent_lines)

                logger.info(f"Cleaned up {log_file.name}: kept {len(recent_lines)} lines")

            return True
        except Exception as e:
            logger.error(f"Failed to cleanup logs: {e}")
            return False
|
||||
|
||||
# Process-wide TextLogger, created lazily on first request
_text_logger_instance = None


def get_text_logger(log_dir: str = "logs") -> TextLogger:
    """Return the shared TextLogger, creating it on the first call.

    ``log_dir`` is only used when the instance is created; later calls
    return the existing instance regardless of the argument.
    """
    global _text_logger_instance

    shared = _text_logger_instance
    if shared is None:
        shared = _text_logger_instance = TextLogger(log_dir)
    return shared
|
||||
252
utils/timezone_utils.py
Normal file
252
utils/timezone_utils.py
Normal file
@@ -0,0 +1,252 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Centralized timezone utilities for the trading system
|
||||
|
||||
This module provides consistent timezone handling across all components:
|
||||
- All external data (Binance, MEXC) comes in UTC
|
||||
- All internal processing uses Europe/Sofia timezone
|
||||
- All timestamps stored in database are timezone-aware
|
||||
- All NN model inputs use consistent timezone
|
||||
"""
|
||||
|
||||
import pytz
|
||||
import pandas as pd
|
||||
from datetime import datetime, timezone
|
||||
from typing import Union, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Timezone constants shared by every helper in this module
UTC = pytz.UTC
# The system's primary timezone is Europe/Sofia; both names refer to the same object
SYSTEM_TIMEZONE = SOFIA_TZ = pytz.timezone('Europe/Sofia')
|
||||
|
||||
def get_system_timezone():
    """Return the timezone object used as the system's primary timezone (Europe/Sofia)."""
    return SYSTEM_TIMEZONE
|
||||
|
||||
def get_utc_timezone():
    """Return the module's UTC timezone object."""
    return UTC
|
||||
|
||||
def now_utc() -> datetime:
    """Return the current moment as a timezone-aware datetime in UTC."""
    return datetime.now(tz=UTC)
|
||||
|
||||
def now_sofia() -> datetime:
    """Return the current moment as a timezone-aware datetime in Europe/Sofia."""
    return datetime.now(tz=SOFIA_TZ)
|
||||
|
||||
def now_system() -> datetime:
    """Return the current moment in the system timezone (delegates to now_sofia)."""
    return now_sofia()
|
||||
|
||||
def to_utc(dt: Union[datetime, pd.Timestamp]) -> datetime:
    """Convert a datetime or pandas Timestamp to UTC.

    ``None`` is passed through unchanged. A value without timezone
    information is interpreted as system-timezone time before conversion.
    """
    if dt is None:
        return None

    moment = dt.to_pydatetime() if isinstance(dt, pd.Timestamp) else dt

    # Assume it's in system timezone if no timezone info
    aware = moment if moment.tzinfo is not None else SYSTEM_TIMEZONE.localize(moment)
    return aware.astimezone(UTC)
|
||||
|
||||
def to_sofia(dt: Union[datetime, pd.Timestamp]) -> datetime:
    """Convert a datetime or pandas Timestamp to the Europe/Sofia timezone.

    ``None`` is passed through unchanged. A value without timezone
    information is interpreted as UTC (common for external data).
    """
    if dt is None:
        return None

    moment = dt.to_pydatetime() if isinstance(dt, pd.Timestamp) else dt

    # Assume it's UTC if no timezone info (common for external data)
    aware = moment if moment.tzinfo is not None else UTC.localize(moment)
    return aware.astimezone(SOFIA_TZ)
|
||||
|
||||
def to_system_timezone(dt: Union[datetime, pd.Timestamp]) -> datetime:
    """Convert a datetime or pandas Timestamp to the system timezone (delegates to to_sofia)."""
    return to_sofia(dt)
|
||||
|
||||
def normalize_timestamp(timestamp: Union[int, float, str, datetime, pd.Timestamp],
                        source_tz: Optional[pytz.BaseTzInfo] = None) -> datetime:
    """
    Convert a timestamp given in one of several formats to the system timezone.

    Accepted inputs are Unix timestamps (values above 1e10 are treated as
    milliseconds), strings parseable by ``pd.to_datetime``, pandas
    Timestamps and datetimes. Values without timezone information are
    localized to ``source_tz``. ``None``, an unsupported type, or any
    conversion error yields the current system time instead.

    Args:
        timestamp: Timestamp in various formats
        source_tz: Source timezone (defaults to UTC for external data)

    Returns:
        datetime: Timezone-aware datetime in system timezone
    """
    if timestamp is None:
        return now_system()

    # Default source timezone is UTC (most external APIs use UTC)
    source_tz = UTC if source_tz is None else source_tz

    try:
        if isinstance(timestamp, (int, float)):
            # Unix timestamp: seconds, or milliseconds when the value is large
            if timestamp > 1e10:
                timestamp /= 1000
            moment = datetime.fromtimestamp(timestamp, tz=source_tz)
        elif isinstance(timestamp, str):
            moment = pd.to_datetime(timestamp)
        elif isinstance(timestamp, pd.Timestamp):
            # Checked before datetime because pd.Timestamp is a datetime subclass
            moment = timestamp.to_pydatetime()
        elif isinstance(timestamp, datetime):
            moment = timestamp
        else:
            logger.warning(f"Unknown timestamp format: {type(timestamp)}")
            return now_system()

        # Attach the source timezone to naive values before converting
        if moment.tzinfo is None:
            moment = source_tz.localize(moment)

        return moment.astimezone(SYSTEM_TIMEZONE)

    except Exception as e:
        logger.error(f"Error normalizing timestamp {timestamp}: {e}")
        return now_system()
|
||||
|
||||
def normalize_dataframe_timestamps(df: pd.DataFrame,
                                   timestamp_col: str = 'timestamp',
                                   source_tz: Optional[pytz.BaseTzInfo] = None) -> pd.DataFrame:
    """
    Convert a DataFrame's timestamp column to the system timezone, in place.

    The column is parsed with ``pd.to_datetime`` if it is not already a
    datetime dtype, localized to ``source_tz`` when timezone-naive, and
    then converted to the system timezone. The same DataFrame object is
    modified and returned. An empty frame or a missing column is returned
    untouched; on error the frame is returned as it stands at that point.

    Args:
        df: DataFrame with timestamp column
        timestamp_col: Name of timestamp column
        source_tz: Source timezone (defaults to UTC)

    Returns:
        DataFrame with normalized timestamps
    """
    if df.empty or timestamp_col not in df.columns:
        return df

    tz = UTC if source_tz is None else source_tz

    try:
        # Convert to datetime if not already
        if not pd.api.types.is_datetime64_any_dtype(df[timestamp_col]):
            df[timestamp_col] = pd.to_datetime(df[timestamp_col])

        stamps = df[timestamp_col]
        if stamps.dt.tz is None:
            # Localize to source timezone first
            stamps = stamps.dt.tz_localize(tz)
            df[timestamp_col] = stamps

        # Convert to system timezone
        df[timestamp_col] = stamps.dt.tz_convert(SYSTEM_TIMEZONE)

    except Exception as e:
        logger.error(f"Error normalizing DataFrame timestamps: {e}")

    return df
|
||||
|
||||
def normalize_dataframe_index(df: pd.DataFrame,
                              source_tz: Optional[pytz.BaseTzInfo] = None) -> pd.DataFrame:
    """
    Convert a DataFrame's DatetimeIndex to the system timezone, in place.

    A timezone-naive index is first localized to ``source_tz``. The same
    DataFrame object is modified and returned. Frames that are empty or
    whose index is not a DatetimeIndex are returned untouched; on error
    the frame is returned as it stands at that point.

    Args:
        df: DataFrame with datetime index
        source_tz: Source timezone (defaults to UTC)

    Returns:
        DataFrame with normalized index
    """
    if df.empty or not isinstance(df.index, pd.DatetimeIndex):
        return df

    tz = UTC if source_tz is None else source_tz

    try:
        if df.index.tz is None:
            # Localize to source timezone first
            df.index = df.index.tz_localize(tz)

        # Convert to system timezone
        df.index = df.index.tz_convert(SYSTEM_TIMEZONE)

    except Exception as e:
        logger.error(f"Error normalizing DataFrame index: {e}")

    return df
|
||||
|
||||
def format_timestamp_for_display(dt: datetime, format_str: str = '%H:%M:%S') -> str:
    """
    Render a timestamp as text in the system timezone.

    Datetimes without timezone information are treated as UTC. When ``dt``
    is ``None`` or formatting fails, the current system time is formatted
    instead.

    Args:
        dt: Datetime to format
        format_str: Format string

    Returns:
        Formatted timestamp string
    """
    if dt is None:
        return now_system().strftime(format_str)

    try:
        if not isinstance(dt, datetime):
            # Non-datetime objects are formatted as-is, without conversion
            return dt.strftime(format_str)

        if dt.tzinfo is None:
            dt = UTC.localize(dt)
        dt = dt.astimezone(SYSTEM_TIMEZONE)
        return dt.strftime(format_str)

    except Exception as e:
        logger.error(f"Error formatting timestamp {dt}: {e}")
        return now_system().strftime(format_str)
|
||||
|
||||
def get_timezone_offset_hours() -> float:
    """Get current timezone offset from UTC in hours

    Returns:
        float: Offset of the system timezone from UTC at the current moment,
        in hours (positive when the system timezone is ahead of UTC).
    """
    # now_system() is timezone-aware, so its utcoffset() is the offset itself.
    # Subtracting a naive UTC datetime from it would raise TypeError.
    return now_system().utcoffset().total_seconds() / 3600
|
||||
|
||||
def is_market_hours() -> bool:
    """Report whether the market is currently open.

    Always returns True: crypto markets are 24/7, but the hook can be
    useful for other purposes such as logging.
    """
    return True
|
||||
|
||||
def log_timezone_info():
    """Write the current UTC time, Sofia time, UTC offset and system timezone to the log."""
    utc_time = now_utc()
    sofia_time = now_sofia()
    offset = get_timezone_offset_hours()

    report = (
        "Timezone Info:",
        f" UTC Time: {utc_time}",
        f" Sofia Time: {sofia_time}",
        f" Offset: {offset:+.1f} hours from UTC",
        f" System Timezone: {SYSTEM_TIMEZONE}",
    )
    for line in report:
        logger.info(line)
|
||||
190
utils/training_integration.py
Normal file
190
utils/training_integration.py
Normal file
@@ -0,0 +1,190 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Training Integration for Checkpoint Management
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from .checkpoint_manager import get_checkpoint_manager, load_best_checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TrainingIntegration:
    """Glue between training loops and the checkpoint manager.

    Collects performance metrics, optionally mirrors them to Weights &
    Biases when a run is active, writes the model weights to a temporary
    file under ``models/`` and hands that file to the checkpoint manager.
    """

    def __init__(self, enable_wandb: bool = True):
        """Store the W&B switch and obtain the checkpoint manager."""
        self.enable_wandb = enable_wandb
        self.checkpoint_manager = get_checkpoint_manager()

    def save_cnn_checkpoint(self,
                            cnn_model,
                            model_name: str,
                            epoch: int,
                            train_accuracy: float,
                            val_accuracy: float,
                            train_loss: float,
                            val_loss: float,
                            training_time_hours: Optional[float] = None) -> bool:
        """Save a CNN checkpoint through the checkpoint manager.

        Args:
            cnn_model: Model exposing ``state_dict()``
            model_name: Name under which the checkpoint is registered
            epoch: Current training epoch
            train_accuracy: Training accuracy
            val_accuracy: Validation accuracy
            train_loss: Training loss
            val_loss: Validation loss
            training_time_hours: Optional elapsed training time

        Returns:
            bool: True if the checkpoint manager returned metadata for the
            checkpoint, False if it returned nothing or an error occurred
        """
        try:
            performance_metrics = {
                'accuracy': train_accuracy,
                'val_accuracy': val_accuracy,
                'loss': train_loss,
                'val_loss': val_loss
            }

            training_metadata = {
                'epoch': epoch,
                'training_time_hours': training_time_hours,
                'total_parameters': self._count_parameters(cnn_model)
            }

            self._log_to_wandb({
                f"{model_name}/train_accuracy": train_accuracy,
                f"{model_name}/val_accuracy": val_accuracy,
                f"{model_name}/train_loss": train_loss,
                f"{model_name}/val_loss": val_loss,
                f"{model_name}/epoch": epoch
            })

            return self._store_checkpoint(
                cnn_model.state_dict(), model_name, 'cnn', 'CNN',
                performance_metrics, training_metadata
            )

        except Exception as e:
            logger.error(f"Error saving CNN checkpoint: {e}")
            return False

    def save_rl_checkpoint(self,
                           rl_agent,
                           model_name: str,
                           episode: int,
                           avg_reward: float,
                           best_reward: float,
                           epsilon: float,
                           total_pnl: Optional[float] = None) -> bool:
        """Save an RL agent checkpoint through the checkpoint manager.

        Args:
            rl_agent: Agent to save; its ``state_dict()`` is stored when it
                has one, otherwise the object itself is stored
            model_name: Name under which the checkpoint is registered
            episode: Current training episode
            avg_reward: Average reward
            best_reward: Best reward so far
            epsilon: Current exploration rate
            total_pnl: Optional total PnL, recorded only when given

        Returns:
            bool: True if the checkpoint manager returned metadata for the
            checkpoint, False if it returned nothing or an error occurred
        """
        try:
            performance_metrics = {
                'reward': avg_reward,
                'best_reward': best_reward
            }
            if total_pnl is not None:
                performance_metrics['pnl'] = total_pnl

            training_metadata = {
                'episode': episode,
                'epsilon': epsilon,
                'total_parameters': self._count_parameters(rl_agent)
            }

            payloads = [{
                f"{model_name}/avg_reward": avg_reward,
                f"{model_name}/best_reward": best_reward,
                f"{model_name}/epsilon": epsilon,
                f"{model_name}/episode": episode
            }]
            if total_pnl is not None:
                # Logged as a separate W&B call, after the main metrics
                payloads.append({f"{model_name}/total_pnl": total_pnl})
            self._log_to_wandb(*payloads)

            payload = rl_agent.state_dict() if hasattr(rl_agent, 'state_dict') else rl_agent
            return self._store_checkpoint(
                payload, model_name, 'rl', 'RL',
                performance_metrics, training_metadata
            )

        except Exception as e:
            logger.error(f"Error saving RL checkpoint: {e}")
            return False

    def load_best_model(self, model_name: str, model_class=None):
        """Load the best checkpoint recorded for ``model_name``.

        Args:
            model_name: Name the checkpoint was registered under
            model_class: Optional zero-argument model factory

        Returns:
            An instance of ``model_class`` with weights loaded when
            ``model_class`` is given and the checkpoint contains a
            ``'model_state_dict'`` entry; otherwise the raw loaded
            checkpoint object; ``None`` if no checkpoint exists or loading
            fails.
        """
        try:
            result = self.checkpoint_manager.load_best_checkpoint(model_name)
            if not result:
                logger.warning(f"No checkpoint found for model: {model_name}")
                return None

            file_path, metadata = result

            # NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources
            checkpoint = torch.load(file_path, map_location='cpu')

            logger.info(f"Loaded best checkpoint for {model_name}:")
            logger.info(f" Performance score: {metadata.performance_score:.4f}")
            logger.info(f" Created: {metadata.created_at}")

            if model_class and 'model_state_dict' in checkpoint:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
                return model

            return checkpoint

        except Exception as e:
            logger.error(f"Error loading best model {model_name}: {e}")
            return None

    def _log_to_wandb(self, *payloads: Dict[str, Any]) -> None:
        """Send each payload to W&B when enabled and a run is active; never raises."""
        if not self.enable_wandb:
            return
        try:
            import wandb
            if wandb.run is not None:
                for payload in payloads:
                    wandb.log(payload)
        except Exception as e:
            logger.warning(f"Error logging to W&B: {e}")

    def _store_checkpoint(self, payload, model_name: str, model_type: str, label: str,
                          performance_metrics: Dict[str, Any],
                          training_metadata: Dict[str, Any]) -> bool:
        """Write ``payload`` to the temp model file and register it with the checkpoint manager."""
        # Save the model first to get the path
        model_path = f"models/{model_name}_temp.pt"
        # torch.save does not create missing directories
        Path(model_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(payload, model_path)

        metadata = self.checkpoint_manager.save_checkpoint(
            model_name=model_name,
            model_path=model_path,
            model_type=model_type,
            performance_metrics=performance_metrics,
            training_metadata=training_metadata
        )

        if metadata:
            logger.info(f"{label} checkpoint saved: {metadata.checkpoint_id}")
            return True

        logger.info(f"{label} checkpoint not saved (performance not improved)")
        return False

    def _count_parameters(self, model) -> int:
        """Count parameter elements of ``model``; returns 0 when they cannot be determined.

        Uses ``model.parameters()`` when available, otherwise sums the
        parameters of ``model.policy_net`` and, if present, ``model.target_net``.
        """
        try:
            if hasattr(model, 'parameters'):
                return sum(p.numel() for p in model.parameters())
            if hasattr(model, 'policy_net'):
                policy_params = sum(p.numel() for p in model.policy_net.parameters())
                target_params = (
                    sum(p.numel() for p in model.target_net.parameters())
                    if hasattr(model, 'target_net') else 0
                )
                return policy_params + target_params
            return 0
        except Exception:
            return 0
|
||||
|
||||
# Process-wide TrainingIntegration, created lazily on first request
_training_integration = None


def get_training_integration() -> TrainingIntegration:
    """Return the shared TrainingIntegration, creating it with defaults on the first call."""
    global _training_integration

    shared = _training_integration
    if shared is None:
        shared = _training_integration = TrainingIntegration()
    return shared
|
||||
Reference in New Issue
Block a user