Merge commit 'd49a473ed6f4aef55bfdd47d6370e53582be6b7b' into cleanup

Dobromir Popov committed on 2025-10-01 00:32:19 +03:00

353 changed files with 81004 additions and 35899 deletions

utils/audit_plotter.py (new file, 387 lines)

@@ -0,0 +1,387 @@
"""
Audit Plotter
Create PNG snapshots of model input data at inference time:
- Subplot 1: 1s candlesticks for recent window
- Subplot 2: COB bucket volumes and imbalance near current price
Windows-safe, ASCII-only logging messages.
"""
from __future__ import annotations
import os
import math
import logging
from datetime import datetime
from typing import List, Tuple
try:
import matplotlib
# Use a non-interactive backend suitable for headless servers
matplotlib.use("Agg")
import matplotlib.pyplot as plt
except Exception:
matplotlib = None # type: ignore
plt = None # type: ignore
logger = logging.getLogger(__name__)
def _extract_recent_ohlcv(base_data, max_bars: int = 120) -> Tuple[List[datetime], List[float], List[float], List[float], List[float]]:
"""Return recent 1s OHLCV arrays (time, open, high, low, close). Falls back to 1m if needed."""
series = base_data.ohlcv_1s if getattr(base_data, "ohlcv_1s", None) else []
if not series or len(series) < 5:
series = base_data.ohlcv_1m if getattr(base_data, "ohlcv_1m", None) else []
series = series[-max_bars:] if series else []
times = [b.timestamp for b in series]
opens = [float(b.open) for b in series]
highs = [float(b.high) for b in series]
lows = [float(b.low) for b in series]
closes = [float(b.close) for b in series]
return times, opens, highs, lows, closes
def _extract_timeframe_ohlcv(base_data, timeframe_attr: str, max_bars: int = 60) -> Tuple[List[datetime], List[float], List[float], List[float], List[float]]:
"""Extract OHLCV data for a specific timeframe attribute."""
series = getattr(base_data, timeframe_attr, []) if hasattr(base_data, timeframe_attr) else []
series = series[-max_bars:] if series else []
if not series:
return [], [], [], [], []
times = [b.timestamp for b in series]
opens = [float(b.open) for b in series]
highs = [float(b.high) for b in series]
lows = [float(b.low) for b in series]
closes = [float(b.close) for b in series]
return times, opens, highs, lows, closes
def _plot_candlesticks(ax, times, opens, highs, lows, closes, title):
"""Plot candlestick chart on given axis."""
if not times:
ax.text(0.5, 0.5, f"No {title} data", ha="center", va="center", transform=ax.transAxes)
ax.set_title(title)
return
x = list(range(len(times)))
# Plot high-low wicks
ax.vlines(x, lows, highs, color="#444444", linewidth=0.8)
# Plot body as rectangles
bodies = [closes[i] - opens[i] for i in range(len(opens))]
bottoms = [min(opens[i], closes[i]) for i in range(len(opens))]
colors = ["#00aa55" if bodies[i] >= 0 else "#cc3333" for i in range(len(bodies))]
heights = [abs(bodies[i]) if abs(bodies[i]) > 1e-9 else 1e-9 for i in range(len(bodies))]
ax.bar(x, heights, bottom=bottoms, color=colors, width=0.6, align="center",
edgecolor="#222222", linewidth=0.3)
ax.set_title(title, fontsize=10)
ax.grid(True, linestyle=":", linewidth=0.4, alpha=0.6)
# Show recent price
if closes:
ax.text(0.02, 0.98, f"${closes[-1]:.2f}", transform=ax.transAxes,
verticalalignment='top', fontsize=8, fontweight='bold')
def _plot_data_summary(ax, base_data, symbol):
"""Plot data summary statistics."""
ax.axis('off')
# Collect data statistics
stats = []
# ETH timeframes
for tf, attr in [("1s", "ohlcv_1s"), ("1m", "ohlcv_1m"), ("1h", "ohlcv_1h"), ("1d", "ohlcv_1d")]:
data = getattr(base_data, attr, []) if hasattr(base_data, attr) else []
stats.append(f"ETH {tf}: {len(data)} bars")
# BTC data
btc_data = getattr(base_data, "btc_ohlcv_1s", []) if hasattr(base_data, "btc_ohlcv_1s") else []
stats.append(f"BTC 1s: {len(btc_data)} bars")
# COB data
cob = getattr(base_data, "cob_data", None)
if cob:
if hasattr(cob, "price_buckets") and cob.price_buckets:
stats.append(f"COB buckets: {len(cob.price_buckets)}")
elif hasattr(cob, "bids") and hasattr(cob, "asks"):
bids = getattr(cob, "bids", [])
asks = getattr(cob, "asks", [])
stats.append(f"COB levels: {len(bids)}b/{len(asks)}a")
else:
stats.append("COB: No data")
else:
stats.append("COB: Missing")
# Technical indicators
tech_indicators = getattr(base_data, "technical_indicators", {}) if hasattr(base_data, "technical_indicators") else {}
stats.append(f"Tech indicators: {len(tech_indicators)}")
# Display stats
y_pos = 0.9
ax.text(0.05, y_pos, "Data Summary:", fontweight='bold', transform=ax.transAxes)
y_pos -= 0.12
for stat in stats:
ax.text(0.05, y_pos, stat, fontsize=9, transform=ax.transAxes)
y_pos -= 0.1
ax.set_title("Input Data Stats", fontsize=10)
def _plot_cob_data(ax, prices, bid_v, ask_v, imb, current_price, symbol):
"""Plot COB data with bid/ask volumes and imbalance."""
if not prices:
ax.text(0.5, 0.5, f"No COB data for {symbol}", ha="center", va="center")
ax.set_title("COB Data - No Data Available")
return
# Normalize x as offsets around current price if available
if current_price > 0:
xvals = [p - current_price for p in prices]
ax.axvline(0.0, color="#666666", linestyle="--", linewidth=1.0, alpha=0.7)
ax.set_xlabel("Price offset from current")
else:
xvals = prices
ax.set_xlabel("Price")
# Plot bid/ask volumes
ax.plot(xvals, bid_v, label="Bid Volume", color="#2c7fb8", linewidth=1.5)
ax.plot(xvals, ask_v, label="Ask Volume", color="#d95f0e", linewidth=1.5)
# Secondary axis for imbalance
ax2 = ax.twinx()
ax2.plot(xvals, imb, label="Imbalance", color="#6a3d9a", linewidth=2, alpha=0.8)
ax2.set_ylabel("Imbalance", color="#6a3d9a")
ax2.tick_params(axis='y', labelcolor="#6a3d9a")
ax.set_ylabel("Volume")
ax.grid(True, linestyle=":", linewidth=0.6, alpha=0.6)
# Combined legend
lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines + lines2, labels + labels2, loc="upper right")
# Title with current price info
price_info = f" (${current_price:.2f})" if current_price > 0 else ""
ax.set_title(f"COB Price Buckets - {symbol}{price_info}", fontsize=11)
def _extract_cob(base_data, max_buckets: int = 40):
"""Return sorted price buckets and metrics from COBData."""
cob = getattr(base_data, "cob_data", None)
# Try to get price buckets from COBData object
if cob is not None and hasattr(cob, "price_buckets") and cob.price_buckets:
# Sort by price and clip
prices = sorted(list(cob.price_buckets.keys()))[:max_buckets]
bid_vol = []
ask_vol = []
imb = []
for p in prices:
bucket = cob.price_buckets.get(p, {})
b = float(bucket.get("bid_volume", 0.0))
a = float(bucket.get("ask_volume", 0.0))
bid_vol.append(b)
ask_vol.append(a)
denom = (b + a) if (b + a) > 0 else 1.0
imb.append((b - a) / denom)
return prices, bid_vol, ask_vol, imb
# Fallback: try to extract from raw bids/asks if available
if cob is not None:
# Check if we have raw bids/asks data
bids = getattr(cob, "bids", []) if hasattr(cob, "bids") else []
asks = getattr(cob, "asks", []) if hasattr(cob, "asks") else []
current_price = getattr(cob, "current_price", 0.0) if hasattr(cob, "current_price") else 0.0
if bids and asks and current_price > 0:
# Create price buckets from raw data
            bucket_size = float(cob.bucket_size) if getattr(cob, "bucket_size", None) else 1.0
buckets = {}
# Process bids
for bid in bids[:50]: # Top 50 levels
if isinstance(bid, dict):
price = float(bid.get("price", 0))
size = float(bid.get("size", 0))
elif isinstance(bid, list) and len(bid) >= 2:
price = float(bid[0])
size = float(bid[1])
else:
continue
if price > 0 and size > 0:
bucket_price = round(price / bucket_size) * bucket_size
if bucket_price not in buckets:
buckets[bucket_price] = {"bid_volume": 0.0, "ask_volume": 0.0}
buckets[bucket_price]["bid_volume"] += size * price
# Process asks
for ask in asks[:50]: # Top 50 levels
if isinstance(ask, dict):
price = float(ask.get("price", 0))
size = float(ask.get("size", 0))
elif isinstance(ask, list) and len(ask) >= 2:
price = float(ask[0])
size = float(ask[1])
else:
continue
if price > 0 and size > 0:
bucket_price = round(price / bucket_size) * bucket_size
if bucket_price not in buckets:
buckets[bucket_price] = {"bid_volume": 0.0, "ask_volume": 0.0}
buckets[bucket_price]["ask_volume"] += size * price
if buckets:
# Sort by price and clip
prices = sorted(list(buckets.keys()))[:max_buckets]
bid_vol = []
ask_vol = []
imb = []
for p in prices:
bucket = buckets.get(p, {})
b = float(bucket.get("bid_volume", 0.0))
a = float(bucket.get("ask_volume", 0.0))
bid_vol.append(b)
ask_vol.append(a)
denom = (b + a) if (b + a) > 0 else 1.0
imb.append((b - a) / denom)
return prices, bid_vol, ask_vol, imb
# No COB data available
return [], [], [], []
def save_inference_audit_image(base_data, model_name: str, symbol: str, out_root: str = "audit_inputs") -> str:
"""Save a comprehensive PNG snapshot of input data with all timeframes and COB data."""
if matplotlib is None or plt is None:
logger.warning("matplotlib not available; skipping audit image")
return ""
try:
# Debug: Log what data we have
logger.info(f"Creating audit image for {model_name} - {symbol}")
if hasattr(base_data, 'ohlcv_1s'):
logger.info(f"ETH 1s data: {len(base_data.ohlcv_1s)} bars")
if hasattr(base_data, 'ohlcv_1m'):
logger.info(f"ETH 1m data: {len(base_data.ohlcv_1m)} bars")
if hasattr(base_data, 'ohlcv_1h'):
logger.info(f"ETH 1h data: {len(base_data.ohlcv_1h)} bars")
if hasattr(base_data, 'ohlcv_1d'):
logger.info(f"ETH 1d data: {len(base_data.ohlcv_1d)} bars")
if hasattr(base_data, 'btc_ohlcv_1s'):
logger.info(f"BTC 1s data: {len(base_data.btc_ohlcv_1s)} bars")
if hasattr(base_data, 'cob_data') and base_data.cob_data:
cob = base_data.cob_data
logger.info(f"COB data available: current_price={getattr(cob, 'current_price', 'N/A')}")
if hasattr(cob, 'price_buckets') and cob.price_buckets:
logger.info(f"COB price buckets: {len(cob.price_buckets)} buckets")
elif hasattr(cob, 'bids') and hasattr(cob, 'asks'):
logger.info(f"COB raw data: {len(getattr(cob, 'bids', []))} bids, {len(getattr(cob, 'asks', []))} asks")
else:
logger.info("COB data exists but no price_buckets or bids/asks found")
else:
logger.warning("No COB data available for audit image")
# Ensure output directory structure
day_dir = datetime.utcnow().strftime("%Y%m%d")
out_dir = os.path.join(out_root, day_dir)
os.makedirs(out_dir, exist_ok=True)
# File name: {ts}_{symbol}_{model}.png (ASCII-only)
ts_str = datetime.utcnow().strftime("%H%M%S_%f")
safe_symbol = symbol.replace("/", "-")
fname = f"{ts_str}_{safe_symbol}_{model_name}.png"
out_path = os.path.join(out_dir, fname)
# Extract all timeframe data
eth_1s_times, eth_1s_o, eth_1s_h, eth_1s_l, eth_1s_c = _extract_timeframe_ohlcv(base_data, "ohlcv_1s", 60)
eth_1m_times, eth_1m_o, eth_1m_h, eth_1m_l, eth_1m_c = _extract_timeframe_ohlcv(base_data, "ohlcv_1m", 60)
eth_1h_times, eth_1h_o, eth_1h_h, eth_1h_l, eth_1h_c = _extract_timeframe_ohlcv(base_data, "ohlcv_1h", 24)
eth_1d_times, eth_1d_o, eth_1d_h, eth_1d_l, eth_1d_c = _extract_timeframe_ohlcv(base_data, "ohlcv_1d", 30)
btc_1s_times, btc_1s_o, btc_1s_h, btc_1s_l, btc_1s_c = _extract_timeframe_ohlcv(base_data, "btc_ohlcv_1s", 60)
# Extract COB data
prices, bid_v, ask_v, imb = _extract_cob(base_data)
        current_price = float(getattr(getattr(base_data, "cob_data", None), "current_price", 0.0) or 0.0)
# Create comprehensive figure with multiple subplots
fig = plt.figure(figsize=(16, 12), dpi=110)
gs = fig.add_gridspec(3, 3, height_ratios=[2, 2, 1.5], width_ratios=[1, 1, 1])
# ETH 1s data (top left)
ax1 = fig.add_subplot(gs[0, 0])
_plot_candlesticks(ax1, eth_1s_times, eth_1s_o, eth_1s_h, eth_1s_l, eth_1s_c, f"ETH 1s (last 60)")
# ETH 1m data (top middle)
ax2 = fig.add_subplot(gs[0, 1])
_plot_candlesticks(ax2, eth_1m_times, eth_1m_o, eth_1m_h, eth_1m_l, eth_1m_c, f"ETH 1m (last 60)")
# ETH 1h data (top right)
ax3 = fig.add_subplot(gs[0, 2])
_plot_candlesticks(ax3, eth_1h_times, eth_1h_o, eth_1h_h, eth_1h_l, eth_1h_c, f"ETH 1h (last 24)")
# ETH 1d data (middle left)
ax4 = fig.add_subplot(gs[1, 0])
_plot_candlesticks(ax4, eth_1d_times, eth_1d_o, eth_1d_h, eth_1d_l, eth_1d_c, f"ETH 1d (last 30)")
# BTC 1s data (middle middle)
ax5 = fig.add_subplot(gs[1, 1])
_plot_candlesticks(ax5, btc_1s_times, btc_1s_o, btc_1s_h, btc_1s_l, btc_1s_c, f"BTC 1s (last 60)")
# Data summary (middle right)
ax6 = fig.add_subplot(gs[1, 2])
_plot_data_summary(ax6, base_data, symbol)
# COB data (bottom, spanning all columns) with optional heatmap overlay above it
ax7 = fig.add_subplot(gs[2, :])
_plot_cob_data(ax7, prices, bid_v, ask_v, imb, current_price, symbol)
# Optional: append a small heatmap figure to the side if available
try:
heat_times = getattr(base_data, 'cob_heatmap_times', [])
heat_prices = getattr(base_data, 'cob_heatmap_prices', [])
heat_vals = getattr(base_data, 'cob_heatmap_values', [])
if heat_times and heat_prices and heat_vals:
import numpy as np
# Create an inset axes on ax7 for compact heatmap
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
inset_ax = inset_axes(ax7, width="25%", height="100%", loc='upper right', borderpad=1)
z = np.array(heat_vals, dtype=float)
if z.size > 0:
col_max = np.maximum(z.max(axis=0), 1e-9)
zn = (z / col_max).T
inset_ax.imshow(zn, aspect='auto', origin='lower', cmap='turbo')
inset_ax.set_title('COB Heatmap', fontsize=8)
inset_ax.set_xticks([])
inset_ax.set_yticks([])
except Exception as _hm_ex:
logger.debug(f"Audit heatmap overlay skipped: {_hm_ex}")
# Add overall title with model and timestamp info
fig.suptitle(f"{model_name} - {safe_symbol} - {datetime.utcnow().strftime('%H:%M:%S')}",
fontsize=14, fontweight='bold')
fig.tight_layout()
fig.savefig(out_path, bbox_inches="tight")
plt.close(fig)
logger.info(f"Saved comprehensive audit image: {out_path}")
return out_path
except Exception as ex:
logger.error(f"Failed to save audit image: {ex}")
try:
plt.close("all")
except Exception:
pass
return ""

utils/cache_manager.py (new file, 295 lines)

@@ -0,0 +1,295 @@
"""
Cache Manager for Trading System
Utilities for managing and cleaning up cache files, including:
- Parquet file validation and repair
- Cache cleanup and maintenance
- Cache health monitoring
"""
import os
import logging
import pandas as pd
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
class CacheManager:
"""Manages cache files for the trading system"""
    def __init__(self, cache_dirs: Optional[List[str]] = None):
"""
Initialize cache manager
Args:
cache_dirs: List of cache directories to manage
"""
self.cache_dirs = cache_dirs or [
"data/cache",
"data/monthly_cache",
"data/pivot_cache"
]
# Ensure cache directories exist
for cache_dir in self.cache_dirs:
Path(cache_dir).mkdir(parents=True, exist_ok=True)
def validate_parquet_file(self, file_path: Path) -> Tuple[bool, Optional[str]]:
"""
Validate a Parquet file
Args:
file_path: Path to the Parquet file
Returns:
Tuple of (is_valid, error_message)
"""
try:
if not file_path.exists():
return False, "File does not exist"
if file_path.stat().st_size == 0:
return False, "File is empty"
# Try to read the file
df = pd.read_parquet(file_path)
if df.empty:
return False, "File contains no data"
# Check for required columns (basic validation)
required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
return False, f"Missing required columns: {missing_columns}"
return True, None
except Exception as e:
error_str = str(e).lower()
corrupted_indicators = [
"parquet magic bytes not found",
"corrupted",
"couldn't deserialize thrift",
"don't know what type",
"invalid parquet file",
"unexpected end of file",
"invalid metadata"
]
if any(indicator in error_str for indicator in corrupted_indicators):
return False, f"Corrupted Parquet file: {e}"
else:
return False, f"Validation error: {e}"
def scan_cache_health(self) -> Dict[str, Dict]:
"""
Scan all cache directories for file health
Returns:
Dictionary with cache health information
"""
health_report = {}
for cache_dir in self.cache_dirs:
cache_path = Path(cache_dir)
if not cache_path.exists():
continue
dir_report = {
'total_files': 0,
'valid_files': 0,
'corrupted_files': 0,
'empty_files': 0,
'total_size_mb': 0.0,
'corrupted_files_list': [],
'old_files': []
}
# Scan all Parquet files
for file_path in cache_path.glob("*.parquet"):
dir_report['total_files'] += 1
file_size_mb = file_path.stat().st_size / (1024 * 1024)
dir_report['total_size_mb'] += file_size_mb
# Check file age
file_age = datetime.now() - datetime.fromtimestamp(file_path.stat().st_mtime)
if file_age > timedelta(days=7): # Files older than 7 days
dir_report['old_files'].append({
'file': str(file_path),
'age_days': file_age.days,
'size_mb': file_size_mb
})
# Validate file
is_valid, error_msg = self.validate_parquet_file(file_path)
if is_valid:
dir_report['valid_files'] += 1
else:
if "empty" in error_msg.lower():
dir_report['empty_files'] += 1
else:
dir_report['corrupted_files'] += 1
dir_report['corrupted_files_list'].append({
'file': str(file_path),
'error': error_msg,
'size_mb': file_size_mb
})
health_report[cache_dir] = dir_report
return health_report
def cleanup_corrupted_files(self, dry_run: bool = True) -> Dict[str, List[str]]:
"""
Clean up corrupted cache files
Args:
dry_run: If True, only report what would be deleted
Returns:
Dictionary of deleted files by directory
"""
deleted_files = {}
for cache_dir in self.cache_dirs:
cache_path = Path(cache_dir)
if not cache_path.exists():
continue
deleted_files[cache_dir] = []
for file_path in cache_path.glob("*.parquet"):
is_valid, error_msg = self.validate_parquet_file(file_path)
if not is_valid:
if dry_run:
deleted_files[cache_dir].append(f"WOULD DELETE: {file_path} ({error_msg})")
logger.info(f"Would delete corrupted file: {file_path} ({error_msg})")
else:
try:
file_path.unlink()
deleted_files[cache_dir].append(f"DELETED: {file_path} ({error_msg})")
logger.info(f"Deleted corrupted file: {file_path}")
except Exception as e:
deleted_files[cache_dir].append(f"FAILED TO DELETE: {file_path} ({e})")
logger.error(f"Failed to delete corrupted file {file_path}: {e}")
return deleted_files
def cleanup_old_files(self, days_to_keep: int = 7, dry_run: bool = True) -> Dict[str, List[str]]:
"""
Clean up old cache files
Args:
days_to_keep: Number of days to keep files
dry_run: If True, only report what would be deleted
Returns:
Dictionary of deleted files by directory
"""
deleted_files = {}
cutoff_date = datetime.now() - timedelta(days=days_to_keep)
for cache_dir in self.cache_dirs:
cache_path = Path(cache_dir)
if not cache_path.exists():
continue
deleted_files[cache_dir] = []
for file_path in cache_path.glob("*.parquet"):
file_mtime = datetime.fromtimestamp(file_path.stat().st_mtime)
if file_mtime < cutoff_date:
age_days = (datetime.now() - file_mtime).days
if dry_run:
deleted_files[cache_dir].append(f"WOULD DELETE: {file_path} (age: {age_days} days)")
logger.info(f"Would delete old file: {file_path} (age: {age_days} days)")
else:
try:
file_path.unlink()
deleted_files[cache_dir].append(f"DELETED: {file_path} (age: {age_days} days)")
logger.info(f"Deleted old file: {file_path}")
except Exception as e:
deleted_files[cache_dir].append(f"FAILED TO DELETE: {file_path} ({e})")
logger.error(f"Failed to delete old file {file_path}: {e}")
return deleted_files
    def get_cache_summary(self) -> Dict[str, Any]:
"""Get a summary of cache usage"""
health_report = self.scan_cache_health()
total_files = sum(report['total_files'] for report in health_report.values())
total_valid = sum(report['valid_files'] for report in health_report.values())
total_corrupted = sum(report['corrupted_files'] for report in health_report.values())
total_size_mb = sum(report['total_size_mb'] for report in health_report.values())
return {
'total_files': total_files,
'valid_files': total_valid,
'corrupted_files': total_corrupted,
'health_percentage': (total_valid / total_files * 100) if total_files > 0 else 0,
'total_size_mb': total_size_mb,
'directories': health_report
}
def emergency_cache_reset(self, confirm: bool = False) -> bool:
"""
Emergency cache reset - deletes all cache files
Args:
confirm: Must be True to actually delete files
Returns:
True if reset was performed
"""
if not confirm:
logger.warning("Emergency cache reset called but not confirmed")
return False
deleted_count = 0
for cache_dir in self.cache_dirs:
cache_path = Path(cache_dir)
if not cache_path.exists():
continue
for file_path in cache_path.glob("*"):
try:
if file_path.is_file():
file_path.unlink()
deleted_count += 1
except Exception as e:
logger.error(f"Failed to delete {file_path}: {e}")
logger.warning(f"Emergency cache reset completed: deleted {deleted_count} files")
return True
# Global cache manager instance
_cache_manager_instance = None
def get_cache_manager() -> CacheManager:
"""Get the global cache manager instance"""
global _cache_manager_instance
if _cache_manager_instance is None:
_cache_manager_instance = CacheManager()
return _cache_manager_instance
def cleanup_corrupted_cache(dry_run: bool = True) -> Dict[str, List[str]]:
"""Convenience function to clean up corrupted cache files"""
cache_manager = get_cache_manager()
return cache_manager.cleanup_corrupted_files(dry_run=dry_run)
def get_cache_health() -> Dict[str, Any]:
"""Convenience function to get cache health summary"""
cache_manager = get_cache_manager()
return cache_manager.get_cache_summary()
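
For reference, a short usage sketch of CacheManager; the single cache directory below is illustrative (the default constructor also covers data/monthly_cache and data/pivot_cache).

from utils.cache_manager import CacheManager

manager = CacheManager(cache_dirs=["data/cache"])
summary = manager.get_cache_summary()
print(f"{summary['valid_files']}/{summary['total_files']} files healthy "
      f"({summary['health_percentage']:.1f}%), {summary['total_size_mb']:.1f} MB total")

# Dry run first: nothing is deleted, the report only lists what would be removed
for cache_dir, entries in manager.cleanup_corrupted_files(dry_run=True).items():
    for entry in entries:
        print(cache_dir, entry)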

utils/checkpoint_manager.py (new file, 547 lines)

@@ -0,0 +1,547 @@
"""
Checkpoint Manager
This module provides functionality for managing model checkpoints, including:
- Saving checkpoints with metadata
- Loading the best checkpoint based on performance metrics
- Cleaning up old or underperforming checkpoints
- Database-backed metadata storage for efficient access
"""
import os
import json
import glob
import logging
import shutil
import torch
import hashlib
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
from .database_manager import get_database_manager, CheckpointMetadata
from .text_logger import get_text_logger
logger = logging.getLogger(__name__)
# Global checkpoint manager instance
_checkpoint_manager_instance = None
def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_checkpoints: int = 10, metric_name: str = "accuracy") -> 'CheckpointManager':
"""
Get the global checkpoint manager instance
Args:
checkpoint_dir: Directory to store checkpoints
max_checkpoints: Maximum number of checkpoints to keep
metric_name: Metric to use for ranking checkpoints
Returns:
CheckpointManager: Global checkpoint manager instance
"""
global _checkpoint_manager_instance
if _checkpoint_manager_instance is None:
_checkpoint_manager_instance = CheckpointManager(
checkpoint_dir=checkpoint_dir,
max_checkpoints=max_checkpoints,
metric_name=metric_name
)
return _checkpoint_manager_instance
def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Dict[str, Any] = None, checkpoint_dir: str = "models/checkpoints") -> Any:
"""
Save a checkpoint with metadata to both filesystem and database
Args:
model: The model to save
model_name: Name of the model
model_type: Type of the model ('cnn', 'rl', etc.)
performance_metrics: Performance metrics
training_metadata: Additional training metadata
checkpoint_dir: Directory to store checkpoints
Returns:
Any: Checkpoint metadata
"""
try:
# Create checkpoint directory
os.makedirs(checkpoint_dir, exist_ok=True)
# Create timestamp
timestamp = datetime.now()
timestamp_str = timestamp.strftime("%Y%m%d_%H%M%S")
# Create checkpoint path
model_dir = os.path.join(checkpoint_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp_str}")
checkpoint_id = f"{model_name}_{timestamp_str}"
# Save model
torch_path = f"{checkpoint_path}.pt"
if hasattr(model, 'save'):
# Use model's save method if available
model.save(checkpoint_path)
else:
# Otherwise, save state_dict
torch.save({
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
'model_name': model_name,
'model_type': model_type,
'timestamp': timestamp_str,
'checkpoint_id': checkpoint_id
}, torch_path)
# Calculate file size
file_size_mb = os.path.getsize(torch_path) / (1024 * 1024) if os.path.exists(torch_path) else 0.0
# Save metadata to database
db_manager = get_database_manager()
checkpoint_metadata = CheckpointMetadata(
checkpoint_id=checkpoint_id,
model_name=model_name,
model_type=model_type,
timestamp=timestamp,
performance_metrics=performance_metrics,
training_metadata=training_metadata or {},
file_path=torch_path,
file_size_mb=file_size_mb,
is_active=True # New checkpoint is active by default
)
# Save to database
if db_manager.save_checkpoint_metadata(checkpoint_metadata):
# Log checkpoint save event to text file
text_logger = get_text_logger()
text_logger.log_checkpoint_event(
model_name=model_name,
event_type="SAVED",
checkpoint_id=checkpoint_id,
details=f"loss={performance_metrics.get('loss', 'N/A')}, size={file_size_mb:.1f}MB"
)
logger.info(f"Checkpoint saved: {checkpoint_id}")
else:
logger.warning(f"Failed to save checkpoint metadata to database: {checkpoint_id}")
# Also save legacy JSON metadata for backward compatibility
legacy_metadata = {
'model_name': model_name,
'model_type': model_type,
'timestamp': timestamp_str,
'performance_metrics': performance_metrics,
'training_metadata': training_metadata or {},
'checkpoint_id': checkpoint_id,
'performance_score': performance_metrics.get('accuracy', performance_metrics.get('reward', 0.0)),
'created_at': timestamp_str
}
with open(f"{checkpoint_path}_metadata.json", 'w') as f:
json.dump(legacy_metadata, f, indent=2)
# Get checkpoint manager and clean up old checkpoints
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_manager._cleanup_checkpoints(model_name)
# Return metadata as an object for backward compatibility
class CheckpointMetadataObj:
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
# Add database fields
self.checkpoint_id = checkpoint_id
self.loss = performance_metrics.get('loss', performance_metrics.get('accuracy', 0.0))
return CheckpointMetadataObj(legacy_metadata)
except Exception as e:
logger.error(f"Error saving checkpoint: {e}")
return None
def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
"""
Load the best checkpoint based on performance metrics using database metadata
Args:
model_name: Name of the model
checkpoint_dir: Directory to store checkpoints
Returns:
Optional[Tuple[str, Any]]: Path to the best checkpoint and its metadata, or None if not found
"""
try:
# First try to get from database (fast metadata access)
db_manager = get_database_manager()
checkpoint_metadata = db_manager.get_best_checkpoint_metadata(model_name, "accuracy")
if not checkpoint_metadata:
# Fallback to legacy file-based approach (no more scattered "No checkpoints found" logs)
pass # Silent fallback
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_path, legacy_metadata = checkpoint_manager.load_best_checkpoint(model_name)
if not checkpoint_path:
return None
# Convert legacy metadata to object
class CheckpointMetadataObj:
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
# Add performance score if not present
if not hasattr(self, 'performance_score'):
metrics = getattr(self, 'metrics', {})
primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
self.performance_score = metrics.get(primary_metric, 0.0)
# Add created_at if not present
if not hasattr(self, 'created_at'):
self.created_at = getattr(self, 'timestamp', 'unknown')
# Add loss for compatibility
                    self.loss = getattr(self, 'metrics', {}).get('loss', self.performance_score)
self.checkpoint_id = getattr(self, 'checkpoint_id', f"{model_name}_unknown")
return f"{checkpoint_path}.pt", CheckpointMetadataObj(legacy_metadata)
# Check if checkpoint file exists
if not os.path.exists(checkpoint_metadata.file_path):
logger.warning(f"Checkpoint file not found: {checkpoint_metadata.file_path}")
return None
# Log checkpoint load event to text file
text_logger = get_text_logger()
text_logger.log_checkpoint_event(
model_name=model_name,
event_type="LOADED",
checkpoint_id=checkpoint_metadata.checkpoint_id,
details=f"loss={checkpoint_metadata.performance_metrics.get('loss', 'N/A')}"
)
# Convert database metadata to object for backward compatibility
class CheckpointMetadataObj:
def __init__(self, db_metadata: CheckpointMetadata):
self.checkpoint_id = db_metadata.checkpoint_id
self.model_name = db_metadata.model_name
self.model_type = db_metadata.model_type
self.timestamp = db_metadata.timestamp.strftime("%Y%m%d_%H%M%S")
self.performance_metrics = db_metadata.performance_metrics
self.training_metadata = db_metadata.training_metadata
self.file_path = db_metadata.file_path
self.file_size_mb = db_metadata.file_size_mb
self.is_active = db_metadata.is_active
# Backward compatibility fields
self.metrics = db_metadata.performance_metrics
self.metadata = db_metadata.training_metadata
self.created_at = self.timestamp
self.performance_score = db_metadata.performance_metrics.get('accuracy',
db_metadata.performance_metrics.get('reward', 0.0))
self.loss = db_metadata.performance_metrics.get('loss', self.performance_score)
return checkpoint_metadata.file_path, CheckpointMetadataObj(checkpoint_metadata)
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return None
class CheckpointManager:
"""
Manages model checkpoints with performance-based optimization
This class:
1. Saves checkpoints with metadata
2. Loads the best checkpoint based on performance metrics
3. Cleans up old or underperforming checkpoints
"""
def __init__(self, checkpoint_dir: str, max_checkpoints: int = 10, metric_name: str = "accuracy"):
"""
Initialize the checkpoint manager
Args:
checkpoint_dir: Directory to store checkpoints
max_checkpoints: Maximum number of checkpoints to keep
metric_name: Metric to use for ranking checkpoints
"""
self.checkpoint_dir = checkpoint_dir
self.max_checkpoints = max_checkpoints
self.metric_name = metric_name
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
logger.info(f"CheckpointManager initialized with checkpoint_dir: {checkpoint_dir}")
def save_checkpoint(self, model_name: str, model_path: str, metrics: Dict[str, float], metadata: Dict[str, Any] = None) -> str:
"""
Save a checkpoint with metadata
Args:
model_name: Name of the model
model_path: Path to the model file
metrics: Performance metrics
metadata: Additional metadata
Returns:
str: Path to the saved checkpoint
"""
try:
# Create timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create checkpoint directory
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
os.makedirs(checkpoint_dir, exist_ok=True)
# Create checkpoint path
checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_{timestamp}")
# Copy model file to checkpoint path
shutil.copy2(model_path, f"{checkpoint_path}.pt")
# Create metadata
checkpoint_metadata = {
'model_name': model_name,
'timestamp': timestamp,
'metrics': metrics,
'metadata': metadata or {}
}
# Save metadata
with open(f"{checkpoint_path}_metadata.json", 'w') as f:
json.dump(checkpoint_metadata, f, indent=2)
logger.info(f"Saved checkpoint to {checkpoint_path}")
# Clean up old checkpoints
self._cleanup_checkpoints(model_name)
return checkpoint_path
except Exception as e:
logger.error(f"Error saving checkpoint: {e}")
return ""
def load_best_checkpoint(self, model_name: str) -> Tuple[str, Dict[str, Any]]:
"""
Load the best checkpoint based on performance metrics
Args:
model_name: Name of the model
Returns:
Tuple[str, Dict[str, Any]]: Path to the best checkpoint and its metadata
"""
try:
# Find all checkpoint metadata files
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
if not metadata_files:
# No more scattered "No checkpoints found" logs - handled by database system
return "", {}
# Load metadata for each checkpoint
checkpoints = []
for metadata_file in metadata_files:
try:
with open(metadata_file, 'r') as f:
metadata = json.load(f)
# Get checkpoint path (remove _metadata.json)
checkpoint_path = metadata_file[:-14]
# Check if model file exists
if not os.path.exists(f"{checkpoint_path}.pt"):
logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
continue
checkpoints.append((checkpoint_path, metadata))
except Exception as e:
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
if not checkpoints:
# No more scattered logs - handled by database system
return "", {}
# Sort by metric (highest first)
checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)
# Return best checkpoint
best_checkpoint_path = checkpoints[0][0]
best_checkpoint_metadata = checkpoints[0][1]
logger.info(f"Best checkpoint for {model_name}: {best_checkpoint_path}")
return best_checkpoint_path, best_checkpoint_metadata
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return "", {}
def _cleanup_checkpoints(self, model_name: str) -> int:
"""
Clean up old or underperforming checkpoints
Args:
model_name: Name of the model
Returns:
int: Number of checkpoints deleted
"""
try:
# Find all checkpoint metadata files
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
if not metadata_files or len(metadata_files) <= self.max_checkpoints:
return 0
# Load metadata for each checkpoint
checkpoints = []
for metadata_file in metadata_files:
try:
with open(metadata_file, 'r') as f:
metadata = json.load(f)
# Get checkpoint path (remove _metadata.json)
checkpoint_path = metadata_file[:-14]
checkpoints.append((checkpoint_path, metadata))
except Exception as e:
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
# Sort by metric (highest first)
checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)
# Keep only the best checkpoints
checkpoints_to_delete = checkpoints[self.max_checkpoints:]
# Delete checkpoints
deleted_count = 0
for checkpoint_path, _ in checkpoints_to_delete:
try:
# Delete model file
if os.path.exists(f"{checkpoint_path}.pt"):
os.remove(f"{checkpoint_path}.pt")
# Delete metadata file
if os.path.exists(f"{checkpoint_path}_metadata.json"):
os.remove(f"{checkpoint_path}_metadata.json")
deleted_count += 1
except Exception as e:
logger.error(f"Error deleting checkpoint {checkpoint_path}: {e}")
logger.info(f"Deleted {deleted_count} old checkpoints for {model_name}")
return deleted_count
except Exception as e:
logger.error(f"Error cleaning up checkpoints: {e}")
return 0
def get_all_checkpoints(self, model_name: str) -> List[Tuple[str, Dict[str, Any]]]:
"""
Get all checkpoints for a model
Args:
model_name: Name of the model
Returns:
List[Tuple[str, Dict[str, Any]]]: List of checkpoint paths and metadata
"""
try:
# Find all checkpoint metadata files
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
if not metadata_files:
return []
# Load metadata for each checkpoint
checkpoints = []
for metadata_file in metadata_files:
try:
with open(metadata_file, 'r') as f:
metadata = json.load(f)
# Get checkpoint path (remove _metadata.json)
checkpoint_path = metadata_file[:-14]
# Check if model file exists
if not os.path.exists(f"{checkpoint_path}.pt"):
logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
continue
checkpoints.append((checkpoint_path, metadata))
except Exception as e:
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
# Sort by timestamp (newest first)
checkpoints.sort(key=lambda x: x[1].get('timestamp', ''), reverse=True)
return checkpoints
except Exception as e:
logger.error(f"Error getting all checkpoints: {e}")
return []
def get_checkpoint_stats(self) -> Dict[str, Any]:
"""
Get statistics about all checkpoints
Returns:
Dict[str, Any]: Statistics about checkpoints
"""
try:
stats = {
'total_checkpoints': 0,
'total_size_mb': 0.0,
'models': {}
}
# Iterate through all model directories
for model_dir in os.listdir(self.checkpoint_dir):
model_path = os.path.join(self.checkpoint_dir, model_dir)
if not os.path.isdir(model_path):
continue
# Count checkpoints for this model
checkpoint_files = glob.glob(os.path.join(model_path, f"{model_dir}_*.pt"))
model_checkpoints = len(checkpoint_files)
# Calculate total size for this model
model_size_mb = 0.0
for checkpoint_file in checkpoint_files:
try:
size_bytes = os.path.getsize(checkpoint_file)
model_size_mb += size_bytes / (1024 * 1024) # Convert to MB
except OSError:
pass
stats['models'][model_dir] = {
'checkpoints': model_checkpoints,
'size_mb': round(model_size_mb, 2)
}
stats['total_checkpoints'] += model_checkpoints
stats['total_size_mb'] += model_size_mb
stats['total_size_mb'] = round(stats['total_size_mb'], 2)
return stats
except Exception as e:
logger.error(f"Error getting checkpoint stats: {e}")
return {
'total_checkpoints': 0,
'total_size_mb': 0.0,
'models': {}
}
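
For reference, a sketch of the intended save/load round trip using the module-level helpers. The nn.Linear model is a hypothetical stand-in for a real CNN/RL model; note that save_checkpoint also records metadata through utils.database_manager and utils.text_logger as a side effect.

import torch.nn as nn

from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint

model = nn.Linear(8, 3)  # placeholder model; real callers pass their trading networks
saved = save_checkpoint(
    model,
    model_name="demo_cnn",
    model_type="cnn",
    performance_metrics={"accuracy": 0.61, "loss": 0.72},
    training_metadata={"epochs": 5},
)
if saved:
    result = load_best_checkpoint("demo_cnn")
    if result:
        checkpoint_path, meta = result
        print(checkpoint_path, meta.checkpoint_id, meta.performance_metrics.get("loss"))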

utils/database_manager.py (new file, 565 lines)

@@ -0,0 +1,565 @@
"""
Database Manager for Trading System
Manages SQLite database for:
1. Inference records logging
2. Checkpoint metadata storage
3. Model performance tracking
"""
import sqlite3
import json
import logging
import os
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from contextlib import contextmanager
from dataclasses import dataclass, asdict
logger = logging.getLogger(__name__)
@dataclass
class InferenceRecord:
"""Structure for inference logging"""
model_name: str
timestamp: datetime
symbol: str
action: str
confidence: float
probabilities: Dict[str, float]
input_features_hash: str # Hash of input features for deduplication
processing_time_ms: float
memory_usage_mb: float
input_features: Optional[np.ndarray] = None # Full input features for training
checkpoint_id: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
@dataclass
class CheckpointMetadata:
"""Structure for checkpoint metadata"""
checkpoint_id: str
model_name: str
model_type: str
timestamp: datetime
performance_metrics: Dict[str, float]
training_metadata: Dict[str, Any]
file_path: str
file_size_mb: float
is_active: bool = False # Currently loaded checkpoint
class DatabaseManager:
"""Manages SQLite database for trading system logging and metadata"""
def __init__(self, db_path: str = "data/trading_system.db"):
self.db_path = db_path
self._ensure_db_directory()
self._initialize_database()
def _ensure_db_directory(self):
"""Ensure database directory exists"""
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
def _initialize_database(self):
"""Initialize database tables"""
with self._get_connection() as conn:
# Inference records table
conn.execute("""
CREATE TABLE IF NOT EXISTS inference_records (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_name TEXT NOT NULL,
timestamp TEXT NOT NULL,
symbol TEXT NOT NULL,
action TEXT NOT NULL,
confidence REAL NOT NULL,
probabilities TEXT NOT NULL, -- JSON
input_features_hash TEXT NOT NULL,
input_features_blob BLOB, -- Store full input features for training
processing_time_ms REAL NOT NULL,
memory_usage_mb REAL NOT NULL,
checkpoint_id TEXT,
metadata TEXT, -- JSON
created_at TEXT DEFAULT CURRENT_TIMESTAMP
)
""")
# Checkpoint metadata table
conn.execute("""
CREATE TABLE IF NOT EXISTS checkpoint_metadata (
id INTEGER PRIMARY KEY AUTOINCREMENT,
checkpoint_id TEXT UNIQUE NOT NULL,
model_name TEXT NOT NULL,
model_type TEXT NOT NULL,
timestamp TEXT NOT NULL,
performance_metrics TEXT NOT NULL, -- JSON
training_metadata TEXT NOT NULL, -- JSON
file_path TEXT NOT NULL,
file_size_mb REAL NOT NULL,
is_active BOOLEAN DEFAULT FALSE,
created_at TEXT DEFAULT CURRENT_TIMESTAMP
)
""")
# Model performance tracking table
conn.execute("""
CREATE TABLE IF NOT EXISTS model_performance (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_name TEXT NOT NULL,
date TEXT NOT NULL,
total_predictions INTEGER DEFAULT 0,
correct_predictions INTEGER DEFAULT 0,
accuracy REAL DEFAULT 0.0,
avg_confidence REAL DEFAULT 0.0,
avg_processing_time_ms REAL DEFAULT 0.0,
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
UNIQUE(model_name, date)
)
""")
# Create indexes for better performance
conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_model_timestamp ON inference_records(model_name, timestamp)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_symbol ON inference_records(symbol)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_model ON checkpoint_metadata(model_name)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_active ON checkpoint_metadata(is_active)")
logger.info(f"Database initialized at {self.db_path}")
# Run migrations to handle schema updates
self._run_migrations()
def _run_migrations(self):
"""Run database migrations to handle schema updates"""
try:
with self._get_connection() as conn:
# Check if input_features_blob column exists
cursor = conn.execute("PRAGMA table_info(inference_records)")
columns = [row[1] for row in cursor.fetchall()]
if 'input_features_blob' not in columns:
logger.info("Adding input_features_blob column to inference_records table")
conn.execute("ALTER TABLE inference_records ADD COLUMN input_features_blob BLOB")
conn.commit()
logger.info("Successfully added input_features_blob column")
else:
logger.debug("input_features_blob column already exists")
except Exception as e:
logger.error(f"Error running database migrations: {e}")
# If migration fails, we can still continue without the blob column
@contextmanager
def _get_connection(self):
"""Get database connection with proper error handling"""
conn = None
try:
conn = sqlite3.connect(self.db_path, timeout=30.0)
conn.row_factory = sqlite3.Row # Enable dict-like access
yield conn
except Exception as e:
if conn:
conn.rollback()
logger.error(f"Database error: {e}")
raise
finally:
if conn:
conn.close()
def log_inference(self, record: InferenceRecord) -> bool:
"""Log an inference record"""
try:
with self._get_connection() as conn:
# Check if input_features_blob column exists
cursor = conn.execute("PRAGMA table_info(inference_records)")
columns = [row[1] for row in cursor.fetchall()]
has_blob_column = 'input_features_blob' in columns
# Serialize input features if provided and column exists
input_features_blob = None
if record.input_features is not None and has_blob_column:
input_features_blob = record.input_features.tobytes()
if has_blob_column:
# Use full query with blob column
conn.execute("""
INSERT INTO inference_records (
model_name, timestamp, symbol, action, confidence,
probabilities, input_features_hash, input_features_blob,
processing_time_ms, memory_usage_mb, checkpoint_id, metadata
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
record.model_name,
record.timestamp.isoformat(),
record.symbol,
record.action,
record.confidence,
json.dumps(record.probabilities),
record.input_features_hash,
input_features_blob,
record.processing_time_ms,
record.memory_usage_mb,
record.checkpoint_id,
json.dumps(record.metadata) if record.metadata else None
))
else:
# Fallback query without blob column
logger.warning("input_features_blob column missing, storing without full features")
conn.execute("""
INSERT INTO inference_records (
model_name, timestamp, symbol, action, confidence,
probabilities, input_features_hash,
processing_time_ms, memory_usage_mb, checkpoint_id, metadata
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
record.model_name,
record.timestamp.isoformat(),
record.symbol,
record.action,
record.confidence,
json.dumps(record.probabilities),
record.input_features_hash,
record.processing_time_ms,
record.memory_usage_mb,
record.checkpoint_id,
json.dumps(record.metadata) if record.metadata else None
))
conn.commit()
return True
except Exception as e:
logger.error(f"Failed to log inference record: {e}")
return False
def save_checkpoint_metadata(self, metadata: CheckpointMetadata) -> bool:
"""Save checkpoint metadata"""
try:
with self._get_connection() as conn:
# First, set all other checkpoints for this model as inactive
conn.execute("""
UPDATE checkpoint_metadata
SET is_active = FALSE
WHERE model_name = ?
""", (metadata.model_name,))
# Insert or replace the new checkpoint metadata
conn.execute("""
INSERT OR REPLACE INTO checkpoint_metadata (
checkpoint_id, model_name, model_type, timestamp,
performance_metrics, training_metadata, file_path,
file_size_mb, is_active
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
metadata.checkpoint_id,
metadata.model_name,
metadata.model_type,
metadata.timestamp.isoformat(),
json.dumps(metadata.performance_metrics),
json.dumps(metadata.training_metadata),
metadata.file_path,
metadata.file_size_mb,
metadata.is_active
))
conn.commit()
return True
except Exception as e:
logger.error(f"Failed to save checkpoint metadata: {e}")
return False
def get_checkpoint_metadata(self, model_name: str, checkpoint_id: str = None) -> Optional[CheckpointMetadata]:
"""Get checkpoint metadata without loading the actual model"""
try:
with self._get_connection() as conn:
if checkpoint_id:
# Get specific checkpoint
cursor = conn.execute("""
SELECT * FROM checkpoint_metadata
WHERE model_name = ? AND checkpoint_id = ?
""", (model_name, checkpoint_id))
else:
# Get active checkpoint
cursor = conn.execute("""
SELECT * FROM checkpoint_metadata
WHERE model_name = ? AND is_active = TRUE
ORDER BY timestamp DESC LIMIT 1
""", (model_name,))
row = cursor.fetchone()
if row:
return CheckpointMetadata(
checkpoint_id=row['checkpoint_id'],
model_name=row['model_name'],
model_type=row['model_type'],
timestamp=datetime.fromisoformat(row['timestamp']),
performance_metrics=json.loads(row['performance_metrics']),
training_metadata=json.loads(row['training_metadata']),
file_path=row['file_path'],
file_size_mb=row['file_size_mb'],
is_active=bool(row['is_active'])
)
return None
except Exception as e:
logger.error(f"Failed to get checkpoint metadata: {e}")
return None
def get_best_checkpoint_metadata(self, model_name: str, metric_name: str = "accuracy") -> Optional[CheckpointMetadata]:
"""Get best checkpoint metadata based on performance metric"""
try:
with self._get_connection() as conn:
cursor = conn.execute("""
SELECT * FROM checkpoint_metadata
WHERE model_name = ?
ORDER BY json_extract(performance_metrics, '$.' || ?) DESC
LIMIT 1
""", (model_name, metric_name))
row = cursor.fetchone()
if row:
return CheckpointMetadata(
checkpoint_id=row['checkpoint_id'],
model_name=row['model_name'],
model_type=row['model_type'],
timestamp=datetime.fromisoformat(row['timestamp']),
performance_metrics=json.loads(row['performance_metrics']),
training_metadata=json.loads(row['training_metadata']),
file_path=row['file_path'],
file_size_mb=row['file_size_mb'],
is_active=bool(row['is_active'])
)
return None
except Exception as e:
logger.error(f"Failed to get best checkpoint metadata: {e}")
return None
def list_checkpoints(self, model_name: str) -> List[CheckpointMetadata]:
"""List all checkpoints for a model"""
try:
with self._get_connection() as conn:
cursor = conn.execute("""
SELECT * FROM checkpoint_metadata
WHERE model_name = ?
ORDER BY timestamp DESC
""", (model_name,))
checkpoints = []
for row in cursor.fetchall():
checkpoints.append(CheckpointMetadata(
checkpoint_id=row['checkpoint_id'],
model_name=row['model_name'],
model_type=row['model_type'],
timestamp=datetime.fromisoformat(row['timestamp']),
performance_metrics=json.loads(row['performance_metrics']),
training_metadata=json.loads(row['training_metadata']),
file_path=row['file_path'],
file_size_mb=row['file_size_mb'],
is_active=bool(row['is_active'])
))
return checkpoints
except Exception as e:
logger.error(f"Failed to list checkpoints: {e}")
return []
def set_active_checkpoint(self, model_name: str, checkpoint_id: str) -> bool:
"""Set a checkpoint as active for a model"""
try:
with self._get_connection() as conn:
# First, set all checkpoints for this model as inactive
conn.execute("""
UPDATE checkpoint_metadata
SET is_active = FALSE
WHERE model_name = ?
""", (model_name,))
# Set the specified checkpoint as active
cursor = conn.execute("""
UPDATE checkpoint_metadata
SET is_active = TRUE
WHERE model_name = ? AND checkpoint_id = ?
""", (model_name, checkpoint_id))
conn.commit()
return cursor.rowcount > 0
except Exception as e:
logger.error(f"Failed to set active checkpoint: {e}")
return False
def get_recent_inferences(self, model_name: str, limit: int = 100) -> List[InferenceRecord]:
"""Get recent inference records for a model"""
try:
with self._get_connection() as conn:
cursor = conn.execute("""
SELECT * FROM inference_records
WHERE model_name = ?
ORDER BY timestamp DESC
LIMIT ?
""", (model_name, limit))
records = []
for row in cursor.fetchall():
# Deserialize input features if available
input_features = None
# Check if the column exists in the row (handles missing column gracefully)
if 'input_features_blob' in row.keys() and row['input_features_blob']:
try:
# Reconstruct numpy array from bytes
input_features = np.frombuffer(row['input_features_blob'], dtype=np.float32)
except Exception as e:
logger.warning(f"Failed to deserialize input features: {e}")
records.append(InferenceRecord(
model_name=row['model_name'],
timestamp=datetime.fromisoformat(row['timestamp']),
symbol=row['symbol'],
action=row['action'],
confidence=row['confidence'],
probabilities=json.loads(row['probabilities']),
input_features_hash=row['input_features_hash'],
processing_time_ms=row['processing_time_ms'],
memory_usage_mb=row['memory_usage_mb'],
input_features=input_features,
checkpoint_id=row['checkpoint_id'],
metadata=json.loads(row['metadata']) if row['metadata'] else None
))
return records
except Exception as e:
logger.error(f"Failed to get recent inferences: {e}")
return []
def update_model_performance(self, model_name: str, date: str,
total_predictions: int, correct_predictions: int,
avg_confidence: float, avg_processing_time: float) -> bool:
"""Update daily model performance statistics"""
try:
accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
with self._get_connection() as conn:
conn.execute("""
INSERT OR REPLACE INTO model_performance (
model_name, date, total_predictions, correct_predictions,
accuracy, avg_confidence, avg_processing_time_ms
) VALUES (?, ?, ?, ?, ?, ?, ?)
""", (
model_name, date, total_predictions, correct_predictions,
accuracy, avg_confidence, avg_processing_time
))
conn.commit()
return True
except Exception as e:
logger.error(f"Failed to update model performance: {e}")
return False
def get_inference_records_for_training(self, model_name: str,
symbol: str = None,
hours_back: int = 24,
limit: int = 1000) -> List[InferenceRecord]:
"""
Get inference records with input features for training feedback
Args:
model_name: Name of the model
symbol: Optional symbol filter
hours_back: How many hours back to look
limit: Maximum number of records
Returns:
List of InferenceRecord with input_features populated
"""
try:
cutoff_time = datetime.now() - timedelta(hours=hours_back)
with self._get_connection() as conn:
# Check if input_features_blob column exists before querying
cursor = conn.execute("PRAGMA table_info(inference_records)")
columns = [row[1] for row in cursor.fetchall()]
has_blob_column = 'input_features_blob' in columns
if not has_blob_column:
logger.warning("input_features_blob column not found, returning empty list")
return []
if symbol:
cursor = conn.execute("""
SELECT * FROM inference_records
WHERE model_name = ? AND symbol = ? AND timestamp >= ?
AND input_features_blob IS NOT NULL
ORDER BY timestamp DESC
LIMIT ?
""", (model_name, symbol, cutoff_time.isoformat(), limit))
else:
cursor = conn.execute("""
SELECT * FROM inference_records
WHERE model_name = ? AND timestamp >= ?
AND input_features_blob IS NOT NULL
ORDER BY timestamp DESC
LIMIT ?
""", (model_name, cutoff_time.isoformat(), limit))
records = []
for row in cursor.fetchall():
# Deserialize input features
input_features = None
if row['input_features_blob']:
try:
input_features = np.frombuffer(row['input_features_blob'], dtype=np.float32)
except Exception as e:
logger.warning(f"Failed to deserialize input features: {e}")
continue # Skip records with corrupted features
records.append(InferenceRecord(
model_name=row['model_name'],
timestamp=datetime.fromisoformat(row['timestamp']),
symbol=row['symbol'],
action=row['action'],
confidence=row['confidence'],
probabilities=json.loads(row['probabilities']),
input_features_hash=row['input_features_hash'],
processing_time_ms=row['processing_time_ms'],
memory_usage_mb=row['memory_usage_mb'],
input_features=input_features,
checkpoint_id=row['checkpoint_id'],
metadata=json.loads(row['metadata']) if row['metadata'] else None
))
return records
except Exception as e:
logger.error(f"Failed to get inference records for training: {e}")
return []
def cleanup_old_records(self, days_to_keep: int = 30) -> bool:
"""Clean up old inference records"""
try:
cutoff_date = datetime.now() - timedelta(days=days_to_keep)
with self._get_connection() as conn:
cursor = conn.execute("""
DELETE FROM inference_records
WHERE timestamp < ?
""", (cutoff_date.isoformat(),))
deleted_count = cursor.rowcount
conn.commit()
if deleted_count > 0:
logger.info(f"Cleaned up {deleted_count} old inference records")
return True
except Exception as e:
logger.error(f"Failed to cleanup old records: {e}")
return False
# Global database manager instance
_db_manager_instance = None
def get_database_manager(db_path: str = "data/trading_system.db") -> DatabaseManager:
"""Get the global database manager instance"""
global _db_manager_instance
if _db_manager_instance is None:
_db_manager_instance = DatabaseManager(db_path)
return _db_manager_instance
def reset_database_manager():
"""Reset the database manager instance to force re-initialization"""
global _db_manager_instance
_db_manager_instance = None
logger.info("Database manager instance reset - will re-initialize on next access")

utils/inference_logger.py (new file, 234 lines)

@@ -0,0 +1,234 @@
"""
Inference Logger
Centralized logging system for model inferences with database storage
Eliminates scattered logging throughout the codebase
"""
import time
import hashlib
import logging
import psutil
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass
import numpy as np
from .database_manager import get_database_manager, InferenceRecord
from .text_logger import get_text_logger
logger = logging.getLogger(__name__)
class InferenceLogger:
"""Centralized inference logging system"""
def __init__(self):
self.db_manager = get_database_manager()
self.text_logger = get_text_logger()
self._process = psutil.Process()
def log_inference(self,
model_name: str,
symbol: str,
action: str,
confidence: float,
probabilities: Dict[str, float],
input_features: Union[np.ndarray, Dict, List],
processing_time_ms: float,
checkpoint_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Log a model inference with all relevant details
Args:
model_name: Name of the model making the prediction
symbol: Trading symbol
action: Predicted action (BUY/SELL/HOLD)
confidence: Confidence score (0.0 to 1.0)
probabilities: Action probabilities dict
input_features: Input features used for prediction
processing_time_ms: Time taken for inference in milliseconds
checkpoint_id: ID of the checkpoint used
metadata: Additional metadata
Returns:
bool: True if logged successfully
"""
try:
# Create feature hash for deduplication
feature_hash = self._hash_features(input_features)
# Get current memory usage
memory_usage_mb = self._get_memory_usage()
# Convert input features to numpy array if needed
features_array = None
if isinstance(input_features, np.ndarray):
features_array = input_features.astype(np.float32)
elif isinstance(input_features, (list, tuple)):
features_array = np.array(input_features, dtype=np.float32)
# Create inference record
record = InferenceRecord(
model_name=model_name,
timestamp=datetime.now(),
symbol=symbol,
action=action,
confidence=confidence,
probabilities=probabilities,
input_features_hash=feature_hash,
processing_time_ms=processing_time_ms,
memory_usage_mb=memory_usage_mb,
input_features=features_array,
checkpoint_id=checkpoint_id,
metadata=metadata
)
# Log to database
db_success = self.db_manager.log_inference(record)
# Log to text file
text_success = self.text_logger.log_inference(
model_name=model_name,
symbol=symbol,
action=action,
confidence=confidence,
processing_time_ms=processing_time_ms,
checkpoint_id=checkpoint_id
)
            # Stay quiet on success; the text log already provides a human-readable record
            if not db_success:
                logger.error(f"Failed to log inference for {model_name}")
return db_success and text_success
except Exception as e:
logger.error(f"Error logging inference: {e}")
return False
def _hash_features(self, features: Union[np.ndarray, Dict, List]) -> str:
"""Create a hash of input features for deduplication"""
try:
if isinstance(features, np.ndarray):
# Hash numpy array
return hashlib.md5(features.tobytes()).hexdigest()[:16]
elif isinstance(features, (dict, list)):
# Hash dict or list by converting to string
feature_str = str(sorted(features.items()) if isinstance(features, dict) else features)
return hashlib.md5(feature_str.encode()).hexdigest()[:16]
else:
# Hash string representation
return hashlib.md5(str(features).encode()).hexdigest()[:16]
except Exception:
# Fallback to timestamp-based hash
return hashlib.md5(str(time.time()).encode()).hexdigest()[:16]
def _get_memory_usage(self) -> float:
"""Get current memory usage in MB"""
try:
return self._process.memory_info().rss / (1024 * 1024)
except Exception:
return 0.0
def get_model_stats(self, model_name: str, hours: int = 24) -> Dict[str, Any]:
"""Get inference statistics for a model"""
try:
# Get recent inferences
recent_inferences = self.db_manager.get_recent_inferences(model_name, limit=1000)
if not recent_inferences:
return {
'total_inferences': 0,
'avg_confidence': 0.0,
'avg_processing_time_ms': 0.0,
'action_distribution': {},
'symbol_distribution': {}
}
# Filter by time window
cutoff_time = datetime.now() - timedelta(hours=hours)
recent_inferences = [r for r in recent_inferences if r.timestamp >= cutoff_time]
if not recent_inferences:
return {
'total_inferences': 0,
'avg_confidence': 0.0,
'avg_processing_time_ms': 0.0,
'action_distribution': {},
'symbol_distribution': {}
}
# Calculate statistics
total_inferences = len(recent_inferences)
avg_confidence = sum(r.confidence for r in recent_inferences) / total_inferences
avg_processing_time = sum(r.processing_time_ms for r in recent_inferences) / total_inferences
# Action distribution
action_counts = {}
for record in recent_inferences:
action_counts[record.action] = action_counts.get(record.action, 0) + 1
# Symbol distribution
symbol_counts = {}
for record in recent_inferences:
symbol_counts[record.symbol] = symbol_counts.get(record.symbol, 0) + 1
return {
'total_inferences': total_inferences,
'avg_confidence': avg_confidence,
'avg_processing_time_ms': avg_processing_time,
'action_distribution': action_counts,
'symbol_distribution': symbol_counts,
'latest_inference': recent_inferences[0].timestamp.isoformat() if recent_inferences else None
}
except Exception as e:
logger.error(f"Error getting model stats: {e}")
return {}
def cleanup_old_logs(self, days_to_keep: int = 30) -> bool:
"""Clean up old inference logs"""
return self.db_manager.cleanup_old_records(days_to_keep)
# Global inference logger instance
_inference_logger_instance = None
def get_inference_logger() -> InferenceLogger:
"""Get the global inference logger instance"""
global _inference_logger_instance
if _inference_logger_instance is None:
_inference_logger_instance = InferenceLogger()
return _inference_logger_instance
def log_model_inference(model_name: str,
symbol: str,
action: str,
confidence: float,
probabilities: Dict[str, float],
input_features: Union[np.ndarray, Dict, List],
processing_time_ms: float,
checkpoint_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Convenience function to log model inference
This is the main function that should be called throughout the codebase
instead of scattered logger.info() calls
"""
inference_logger = get_inference_logger()
return inference_logger.log_inference(
model_name=model_name,
symbol=symbol,
action=action,
confidence=confidence,
probabilities=probabilities,
input_features=input_features,
processing_time_ms=processing_time_ms,
checkpoint_id=checkpoint_id,
metadata=metadata
)
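A minimal usage sketch for the log_model_inference convenience wrapper defined above; the model name, symbol format, and feature vector are illustrative values, not part of this diff.

import numpy as np
from utils.inference_logger import log_model_inference

log_model_inference(
    model_name="dqn_agent",                      # hypothetical model name
    symbol="ETH/USDT",                           # illustrative symbol format
    action="BUY",
    confidence=0.72,
    probabilities={"BUY": 0.72, "SELL": 0.18, "HOLD": 0.10},
    input_features=np.zeros(128, dtype=np.float32),
    processing_time_ms=4.2,
    checkpoint_id=None,
)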

View File

@@ -1,164 +0,0 @@
#!/usr/bin/env python3
"""
TensorBoard Launcher with Automatic Port Management
This script launches TensorBoard with automatic port fallback if the preferred port is in use.
It also kills any stale debug instances that might be running.
Usage:
python launch_tensorboard.py --logdir=path/to/logs --preferred-port=6007 --port-range=6000-7000
"""
import os
import sys
import subprocess
import argparse
import logging
from pathlib import Path
# Add project root to path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
sys.path.append(project_root)
from utils.port_manager import get_port_with_fallback, kill_stale_debug_instances
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('tensorboard_launcher')
def launch_tensorboard(logdir, port, host='localhost', open_browser=True):
"""
Launch TensorBoard on the specified port
Args:
logdir (str): Path to log directory
port (int): Port to use
host (str): Host to bind to
open_browser (bool): Whether to open browser automatically
Returns:
subprocess.Popen: Process object
"""
cmd = [
sys.executable, "-m", "tensorboard.main",
f"--logdir={logdir}",
f"--port={port}",
f"--host={host}"
]
# Add --load_fast=false to improve startup times
cmd.append("--load_fast=false")
# Control whether to open browser
if not open_browser:
cmd.append("--window_title=TensorBoard")
logger.info(f"Launching TensorBoard: {' '.join(cmd)}")
# Use subprocess.Popen to start TensorBoard without waiting for it to finish
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
bufsize=1
)
# Log the first few lines of output to confirm it's starting correctly
line_count = 0
for line in process.stdout:
logger.info(f"TensorBoard: {line.strip()}")
line_count += 1
# Check if TensorBoard has started successfully
if "TensorBoard" in line and "http://" in line:
url = line.strip().split("http://")[1].split(" ")[0]
logger.info(f"TensorBoard available at: http://{url}")
# Only log the first few lines
if line_count >= 10:
break
# Continue reading output in background to prevent pipe from filling
def read_output():
for line in process.stdout:
pass
import threading
threading.Thread(target=read_output, daemon=True).start()
return process
def main():
parser = argparse.ArgumentParser(description='Launch TensorBoard with automatic port management')
parser.add_argument('--logdir', type=str, default='NN/models/saved/logs',
help='Directory containing TensorBoard event files')
parser.add_argument('--preferred-port', type=int, default=6007,
help='Preferred port to use')
parser.add_argument('--port-range', type=str, default='6000-7000',
help='Port range to try if preferred port is unavailable (format: min-max)')
parser.add_argument('--host', type=str, default='localhost',
help='Host to bind to')
parser.add_argument('--no-browser', action='store_true',
help='Do not open browser automatically')
parser.add_argument('--kill-stale', action='store_true',
help='Kill stale debug instances before starting')
args = parser.parse_args()
# Parse port range
try:
min_port, max_port = map(int, args.port_range.split('-'))
except ValueError:
logger.error(f"Invalid port range format: {args.port_range}. Use format: min-max")
return 1
# Kill stale instances if requested
if args.kill_stale:
logger.info("Killing stale debug instances...")
count, _ = kill_stale_debug_instances()
logger.info(f"Killed {count} stale instances")
# Get an available port
try:
port = get_port_with_fallback(args.preferred_port, min_port, max_port)
logger.info(f"Using port {port} for TensorBoard")
except RuntimeError as e:
logger.error(str(e))
return 1
# Ensure log directory exists
logdir = os.path.abspath(args.logdir)
os.makedirs(logdir, exist_ok=True)
# Launch TensorBoard
process = launch_tensorboard(
logdir=logdir,
port=port,
host=args.host,
open_browser=not args.no_browser
)
# Wait for process to end (it shouldn't unless there's an error or user kills it)
try:
return_code = process.wait()
if return_code != 0:
logger.error(f"TensorBoard exited with code {return_code}")
return return_code
except KeyboardInterrupt:
logger.info("Received keyboard interrupt, shutting down TensorBoard...")
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
logger.warning("TensorBoard didn't terminate gracefully, forcing kill")
process.kill()
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -1,241 +0,0 @@
#!/usr/bin/env python
"""
Model utilities for robust saving and loading of PyTorch models
"""
import os
import logging
import torch
import shutil
import gc
import json
from typing import Any, Dict, Optional, Union
logger = logging.getLogger(__name__)
def robust_save(model: Any, path: str, include_optimizer: bool = True) -> bool:
"""
Robust model saving with multiple fallback approaches
Args:
model: The model object to save (should have policy_net, target_net, optimizer, epsilon attributes)
path: Path to save the model
include_optimizer: Whether to include optimizer state in the save
Returns:
bool: True if successful, False otherwise
"""
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
# Backup path in case the main save fails
backup_path = f"{path}.backup"
# Clean up GPU memory before saving
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Prepare checkpoint data
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'epsilon': getattr(model, 'epsilon', 0.0),
'state_size': getattr(model, 'state_size', None),
'action_size': getattr(model, 'action_size', None),
'hidden_size': getattr(model, 'hidden_size', None),
}
# Add optimizer state if requested and available
if include_optimizer and hasattr(model, 'optimizer') and model.optimizer is not None:
checkpoint['optimizer'] = model.optimizer.state_dict()
# Attempt 1: Try with default settings in a separate file first
try:
logger.info(f"Saving model to {backup_path} (attempt 1)")
torch.save(checkpoint, backup_path)
logger.info(f"Successfully saved to {backup_path}")
# If backup worked, copy to the actual path
if os.path.exists(backup_path):
shutil.copy(backup_path, path)
logger.info(f"Copied backup to {path}")
return True
except Exception as e:
logger.warning(f"First save attempt failed: {e}")
# Attempt 2: Try with pickle protocol 2 (more compatible)
try:
logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
torch.save(checkpoint, path, pickle_protocol=2)
logger.info(f"Successfully saved to {path} with pickle_protocol=2")
return True
except Exception as e:
logger.warning(f"Second save attempt failed: {e}")
# Attempt 3: Try without optimizer state (which can be large and cause issues)
try:
logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
checkpoint_no_opt = {k: v for k, v in checkpoint.items() if k != 'optimizer'}
torch.save(checkpoint_no_opt, path)
logger.info(f"Successfully saved to {path} without optimizer state")
return True
except Exception as e:
logger.warning(f"Third save attempt failed: {e}")
# Attempt 4: Try with torch.jit.save instead
try:
logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
# Save policy network using jit
scripted_policy = torch.jit.script(model.policy_net)
torch.jit.save(scripted_policy, f"{path}.policy.jit")
# Save target network using jit
scripted_target = torch.jit.script(model.target_net)
torch.jit.save(scripted_target, f"{path}.target.jit")
# Save parameters separately as JSON
params = {
'epsilon': float(getattr(model, 'epsilon', 0.0)),
'state_size': int(getattr(model, 'state_size', 0)),
'action_size': int(getattr(model, 'action_size', 0)),
'hidden_size': int(getattr(model, 'hidden_size', 0))
}
with open(f"{path}.params.json", "w") as f:
json.dump(params, f)
logger.info(f"Successfully saved model components with jit.save")
return True
except Exception as e:
logger.error(f"All save attempts failed: {e}")
return False
def robust_load(model: Any, path: str, device: Optional[torch.device] = None) -> bool:
"""
Robust model loading with fallback approaches
Args:
model: The model object to load into
path: Path to load the model from
device: Device to load the model on
Returns:
bool: True if successful, False otherwise
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Try regular PyTorch load first
try:
logger.info(f"Loading model from {path}")
if os.path.exists(path):
checkpoint = torch.load(path, map_location=device)
# Load network states
if 'policy_net' in checkpoint:
model.policy_net.load_state_dict(checkpoint['policy_net'])
if 'target_net' in checkpoint:
model.target_net.load_state_dict(checkpoint['target_net'])
# Load other attributes
if 'epsilon' in checkpoint:
model.epsilon = checkpoint['epsilon']
if 'optimizer' in checkpoint and hasattr(model, 'optimizer') and model.optimizer is not None:
try:
model.optimizer.load_state_dict(checkpoint['optimizer'])
except Exception as e:
logger.warning(f"Failed to load optimizer state: {e}")
logger.info("Successfully loaded model")
return True
except Exception as e:
logger.warning(f"Regular load failed: {e}")
# Try loading JIT saved components
try:
policy_path = f"{path}.policy.jit"
target_path = f"{path}.target.jit"
params_path = f"{path}.params.json"
if all(os.path.exists(p) for p in [policy_path, target_path, params_path]):
logger.info(f"Loading JIT model components")
# Load JIT models (this is more complex and may need model reconstruction)
# For now, just log that we found JIT files
logger.info("Found JIT model files, but loading them requires special handling")
with open(params_path, 'r') as f:
params = json.load(f)
logger.info(f"Model parameters: {params}")
# Note: Actually loading JIT models would require recreating the model architecture
# This is a placeholder for future implementation
return False
except Exception as e:
logger.error(f"JIT load failed: {e}")
logger.error(f"All load attempts failed for {path}")
return False
def get_model_info(path: str) -> Dict[str, Any]:
"""
Get information about a saved model
Args:
path: Path to the model file
Returns:
dict: Model information
"""
info = {
'exists': False,
'size_bytes': 0,
'has_optimizer': False,
'parameters': {}
}
try:
if os.path.exists(path):
info['exists'] = True
info['size_bytes'] = os.path.getsize(path)
# Try to load and inspect
checkpoint = torch.load(path, map_location='cpu')
info['has_optimizer'] = 'optimizer' in checkpoint
# Extract parameter info
for key in ['epsilon', 'state_size', 'action_size', 'hidden_size']:
if key in checkpoint:
info['parameters'][key] = checkpoint[key]
except Exception as e:
logger.warning(f"Failed to get model info for {path}: {e}")
return info
def verify_save_load_cycle(model: Any, test_path: str) -> bool:
"""
Test that a model can be saved and loaded correctly
Args:
model: Model to test
test_path: Path for test file
Returns:
bool: True if save/load cycle successful
"""
try:
# Save the model
if not robust_save(model, test_path):
return False
# Create a new model instance (this would need model creation logic)
# For now, just verify the file exists and has content
if os.path.exists(test_path) and os.path.getsize(test_path) > 0:
logger.info("Save/load cycle verification successful")
# Clean up test file
os.remove(test_path)
return True
else:
return False
except Exception as e:
logger.error(f"Save/load cycle verification failed: {e}")
return False

View File

@@ -1,238 +0,0 @@
#!/usr/bin/env python3
"""
Port Management Utility
This script provides utilities to:
1. Find available ports in a specified range
2. Kill stale processes running on specific ports
3. Kill all debug/training instances
Usage:
- As a module: import port_manager and use its functions
- Directly: python port_manager.py --kill-stale --min-port 6000 --max-port 7000
"""
import os
import sys
import socket
import argparse
import psutil
import logging
import time
import signal
from typing import List, Tuple, Optional, Set
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('port_manager')
# Define process names to look for when killing stale instances
DEBUG_PROCESS_KEYWORDS = [
'tensorboard',
'python train_',
'realtime.py',
'train_rl_with_realtime.py'
]
def is_port_in_use(port: int) -> bool:
"""
Check if a port is in use
Args:
port (int): Port number to check
Returns:
bool: True if port is in use, False otherwise
"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
def find_available_port(start_port: int, end_port: int) -> Optional[int]:
"""
Find an available port in the specified range
Args:
start_port (int): Lower bound of port range
end_port (int): Upper bound of port range
Returns:
Optional[int]: Available port number or None if no ports available
"""
for port in range(start_port, end_port + 1):
if not is_port_in_use(port):
return port
return None
def get_process_by_port(port: int) -> List[psutil.Process]:
"""
Get processes using a specific port
Args:
port (int): Port number to check
Returns:
List[psutil.Process]: List of processes using the port
"""
processes = []
for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
try:
for conn in proc.connections(kind='inet'):
if conn.laddr.port == port:
processes.append(proc)
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
pass
return processes
def kill_process_by_port(port: int) -> Tuple[int, List[str]]:
"""
Kill processes using a specific port
Args:
port (int): Port number to check
Returns:
Tuple[int, List[str]]: Count of killed processes and their names
"""
processes = get_process_by_port(port)
killed = []
for proc in processes:
try:
proc_name = " ".join(proc.cmdline()) if proc.cmdline() else proc.name()
logger.info(f"Terminating process {proc.pid}: {proc_name}")
proc.terminate()
killed.append(proc_name)
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass
# Give processes time to terminate gracefully
if processes:
time.sleep(0.5)
# Force kill any remaining processes
for proc in processes:
try:
if proc.is_running():
logger.info(f"Force killing process {proc.pid}")
proc.kill()
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass
return len(killed), killed
def kill_stale_debug_instances() -> Tuple[int, Set[str]]:
"""
Kill all stale debug and training instances based on process names
Returns:
Tuple[int, Set[str]]: Count of killed processes and their names
"""
killed_count = 0
killed_procs = set()
for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
try:
cmd = " ".join(proc.cmdline()) if proc.cmdline() else proc.name()
# Check if this is a debug/training process we should kill
if any(keyword in cmd for keyword in DEBUG_PROCESS_KEYWORDS):
logger.info(f"Terminating stale process {proc.pid}: {cmd}")
proc.terminate()
killed_count += 1
killed_procs.add(cmd)
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
pass
# Give processes time to terminate
if killed_count > 0:
time.sleep(1)
# Force kill any remaining processes
for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
try:
cmd = " ".join(proc.cmdline()) if proc.cmdline() else proc.name()
if any(keyword in cmd for keyword in DEBUG_PROCESS_KEYWORDS) and proc.is_running():
logger.info(f"Force killing stale process {proc.pid}")
proc.kill()
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
pass
return killed_count, killed_procs
def get_port_with_fallback(preferred_port: int, min_port: int, max_port: int) -> int:
"""
Try to use preferred port, fall back to any available port in range
Args:
preferred_port (int): Preferred port to use
min_port (int): Minimum port in fallback range
max_port (int): Maximum port in fallback range
Returns:
int: Available port number
"""
# First try the preferred port
if not is_port_in_use(preferred_port):
return preferred_port
# If preferred port is in use, try to free it
logger.info(f"Preferred port {preferred_port} is in use, attempting to free it")
kill_count, _ = kill_process_by_port(preferred_port)
if kill_count > 0 and not is_port_in_use(preferred_port):
logger.info(f"Successfully freed port {preferred_port}")
return preferred_port
# If we couldn't free the preferred port, find another available port
logger.info(f"Looking for available port in range {min_port}-{max_port}")
available_port = find_available_port(min_port, max_port)
if available_port:
logger.info(f"Using alternative port: {available_port}")
return available_port
else:
# If no ports are available, force kill processes in the entire range
logger.warning(f"No available ports in range {min_port}-{max_port}, freeing ports")
for port in range(min_port, max_port + 1):
kill_process_by_port(port)
# Try again
available_port = find_available_port(min_port, max_port)
if available_port:
logger.info(f"Using port {available_port} after freeing")
return available_port
else:
logger.error(f"Could not find available port even after freeing range {min_port}-{max_port}")
raise RuntimeError(f"No available ports in range {min_port}-{max_port}")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Port management utility')
parser.add_argument('--kill-stale', action='store_true', help='Kill all stale debug instances')
parser.add_argument('--free-port', type=int, help='Free a specific port')
parser.add_argument('--find-port', action='store_true', help='Find an available port')
parser.add_argument('--min-port', type=int, default=6000, help='Minimum port in range')
parser.add_argument('--max-port', type=int, default=7000, help='Maximum port in range')
parser.add_argument('--preferred-port', type=int, help='Preferred port to use')
args = parser.parse_args()
if args.kill_stale:
count, procs = kill_stale_debug_instances()
logger.info(f"Killed {count} stale processes")
for proc in procs:
logger.info(f" - {proc}")
if args.free_port:
count, killed = kill_process_by_port(args.free_port)
logger.info(f"Killed {count} processes using port {args.free_port}")
for proc in killed:
logger.info(f" - {proc}")
if args.find_port or args.preferred_port:
preferred = args.preferred_port if args.preferred_port else args.min_port
port = get_port_with_fallback(preferred, args.min_port, args.max_port)
print(port) # Print only the port number for easy capture in scripts

156
utils/text_logger.py Normal file
View File

@@ -0,0 +1,156 @@
"""
Text File Logger for Trading System
Simple text file logging for tracking inference records and system events
Provides human-readable logs alongside database storage
"""
import os
import logging
from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path
logger = logging.getLogger(__name__)
class TextLogger:
"""Simple text file logger for trading system events"""
def __init__(self, log_dir: str = "logs"):
self.log_dir = Path(log_dir)
self.log_dir.mkdir(exist_ok=True)
# Create separate log files for different types of events
self.inference_log = self.log_dir / "inference_records.txt"
self.checkpoint_log = self.log_dir / "checkpoint_events.txt"
self.system_log = self.log_dir / "system_events.txt"
def log_inference(self, model_name: str, symbol: str, action: str,
confidence: float, processing_time_ms: float,
                     checkpoint_id: Optional[str] = None) -> bool:
"""Log inference record to text file"""
try:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
checkpoint_info = f" [checkpoint: {checkpoint_id}]" if checkpoint_id else ""
log_entry = (
f"{timestamp} | {model_name:15} | {symbol:10} | "
f"{action:4} | conf={confidence:.3f} | "
f"time={processing_time_ms:6.1f}ms{checkpoint_info}\n"
)
with open(self.inference_log, 'a', encoding='utf-8') as f:
f.write(log_entry)
f.flush()
return True
except Exception as e:
logger.error(f"Failed to log inference to text file: {e}")
return False
def log_checkpoint_event(self, model_name: str, event_type: str,
checkpoint_id: str, details: str = "") -> bool:
"""Log checkpoint events (save, load, etc.)"""
try:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
details_str = f" - {details}" if details else ""
log_entry = (
f"{timestamp} | {model_name:15} | {event_type:10} | "
f"{checkpoint_id}{details_str}\n"
)
with open(self.checkpoint_log, 'a', encoding='utf-8') as f:
f.write(log_entry)
f.flush()
return True
except Exception as e:
logger.error(f"Failed to log checkpoint event to text file: {e}")
return False
def log_system_event(self, event_type: str, message: str,
component: str = "system") -> bool:
"""Log general system events"""
try:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_entry = (
f"{timestamp} | {component:15} | {event_type:10} | {message}\n"
)
with open(self.system_log, 'a', encoding='utf-8') as f:
f.write(log_entry)
f.flush()
return True
except Exception as e:
logger.error(f"Failed to log system event to text file: {e}")
return False
def get_recent_inferences(self, lines: int = 50) -> str:
"""Get recent inference records from text file"""
try:
if not self.inference_log.exists():
return "No inference records found"
with open(self.inference_log, 'r', encoding='utf-8') as f:
all_lines = f.readlines()
recent_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines
return ''.join(recent_lines)
except Exception as e:
logger.error(f"Failed to read inference log: {e}")
return f"Error reading log: {e}"
def get_recent_checkpoint_events(self, lines: int = 20) -> str:
"""Get recent checkpoint events from text file"""
try:
if not self.checkpoint_log.exists():
return "No checkpoint events found"
with open(self.checkpoint_log, 'r', encoding='utf-8') as f:
all_lines = f.readlines()
recent_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines
return ''.join(recent_lines)
except Exception as e:
logger.error(f"Failed to read checkpoint log: {e}")
return f"Error reading log: {e}"
def cleanup_old_logs(self, max_lines: int = 10000) -> bool:
"""Keep only the most recent log entries"""
try:
for log_file in [self.inference_log, self.checkpoint_log, self.system_log]:
if log_file.exists():
with open(log_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
if len(lines) > max_lines:
# Keep only the most recent lines
recent_lines = lines[-max_lines:]
with open(log_file, 'w', encoding='utf-8') as f:
f.writelines(recent_lines)
logger.info(f"Cleaned up {log_file.name}: kept {len(recent_lines)} lines")
return True
except Exception as e:
logger.error(f"Failed to cleanup logs: {e}")
return False
# Global text logger instance
_text_logger_instance = None
def get_text_logger(log_dir: str = "logs") -> TextLogger:
"""Get the global text logger instance"""
global _text_logger_instance
if _text_logger_instance is None:
_text_logger_instance = TextLogger(log_dir)
return _text_logger_instance
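A minimal usage sketch for TextLogger; the component name, checkpoint id, and message text are illustrative.

from utils.text_logger import get_text_logger

text_log = get_text_logger(log_dir="logs")
text_log.log_system_event("startup", "dashboard initialized", component="dashboard")
text_log.log_checkpoint_event("dqn_agent", "save", "dqn_agent_20250101_000000",
                              details="score=0.81")
print(text_log.get_recent_inferences(lines=10))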

252
utils/timezone_utils.py Normal file
View File

@@ -0,0 +1,252 @@
#!/usr/bin/env python3
"""
Centralized timezone utilities for the trading system
This module provides consistent timezone handling across all components:
- All external data (Binance, MEXC) comes in UTC
- All internal processing uses Europe/Sofia timezone
- All timestamps stored in database are timezone-aware
- All NN model inputs use consistent timezone
"""
import pytz
import pandas as pd
from datetime import datetime, timezone
from typing import Union, Optional
import logging
logger = logging.getLogger(__name__)
# Define timezone constants
UTC = pytz.UTC
SOFIA_TZ = pytz.timezone('Europe/Sofia')
SYSTEM_TIMEZONE = SOFIA_TZ # Our system's primary timezone
def get_system_timezone():
"""Get the system's primary timezone (Europe/Sofia)"""
return SYSTEM_TIMEZONE
def get_utc_timezone():
"""Get UTC timezone"""
return UTC
def now_utc() -> datetime:
"""Get current time in UTC"""
return datetime.now(UTC)
def now_sofia() -> datetime:
"""Get current time in Sofia timezone"""
return datetime.now(SOFIA_TZ)
def now_system() -> datetime:
"""Get current time in system timezone (Sofia)"""
return now_sofia()
def to_utc(dt: Union[datetime, pd.Timestamp]) -> datetime:
"""Convert datetime to UTC timezone"""
if dt is None:
return None
if isinstance(dt, pd.Timestamp):
dt = dt.to_pydatetime()
if dt.tzinfo is None:
# Assume it's in system timezone if no timezone info
dt = SYSTEM_TIMEZONE.localize(dt)
return dt.astimezone(UTC)
def to_sofia(dt: Union[datetime, pd.Timestamp]) -> datetime:
"""Convert datetime to Sofia timezone"""
if dt is None:
return None
if isinstance(dt, pd.Timestamp):
dt = dt.to_pydatetime()
if dt.tzinfo is None:
# Assume it's UTC if no timezone info (common for external data)
dt = UTC.localize(dt)
return dt.astimezone(SOFIA_TZ)
def to_system_timezone(dt: Union[datetime, pd.Timestamp]) -> datetime:
"""Convert datetime to system timezone (Sofia)"""
return to_sofia(dt)
def normalize_timestamp(timestamp: Union[int, float, str, datetime, pd.Timestamp],
source_tz: Optional[pytz.BaseTzInfo] = None) -> datetime:
"""
Normalize various timestamp formats to system timezone (Sofia)
Args:
timestamp: Timestamp in various formats
source_tz: Source timezone (defaults to UTC for external data)
Returns:
datetime: Timezone-aware datetime in system timezone
"""
if timestamp is None:
return now_system()
# Default source timezone is UTC (most external APIs use UTC)
if source_tz is None:
source_tz = UTC
try:
# Handle different timestamp formats
if isinstance(timestamp, (int, float)):
            # Unix timestamp: assume seconds; values this large are milliseconds,
            # so convert them to seconds first
            if timestamp > 1e10:  # milliseconds since epoch
                timestamp = timestamp / 1000
dt = datetime.fromtimestamp(timestamp, tz=source_tz)
elif isinstance(timestamp, str):
# String timestamp
dt = pd.to_datetime(timestamp)
if dt.tzinfo is None:
dt = source_tz.localize(dt)
elif isinstance(timestamp, pd.Timestamp):
dt = timestamp.to_pydatetime()
if dt.tzinfo is None:
dt = source_tz.localize(dt)
elif isinstance(timestamp, datetime):
dt = timestamp
if dt.tzinfo is None:
dt = source_tz.localize(dt)
else:
logger.warning(f"Unknown timestamp format: {type(timestamp)}")
return now_system()
# Convert to system timezone
return dt.astimezone(SYSTEM_TIMEZONE)
except Exception as e:
logger.error(f"Error normalizing timestamp {timestamp}: {e}")
return now_system()
def normalize_dataframe_timestamps(df: pd.DataFrame,
timestamp_col: str = 'timestamp',
source_tz: Optional[pytz.BaseTzInfo] = None) -> pd.DataFrame:
"""
Normalize timestamps in a DataFrame to system timezone
Args:
df: DataFrame with timestamp column
timestamp_col: Name of timestamp column
source_tz: Source timezone (defaults to UTC)
Returns:
DataFrame with normalized timestamps
"""
if df.empty or timestamp_col not in df.columns:
return df
if source_tz is None:
source_tz = UTC
try:
# Convert to datetime if not already
if not pd.api.types.is_datetime64_any_dtype(df[timestamp_col]):
df[timestamp_col] = pd.to_datetime(df[timestamp_col])
# Handle timezone
if df[timestamp_col].dt.tz is None:
# Localize to source timezone first
df[timestamp_col] = df[timestamp_col].dt.tz_localize(source_tz)
# Convert to system timezone
df[timestamp_col] = df[timestamp_col].dt.tz_convert(SYSTEM_TIMEZONE)
return df
except Exception as e:
logger.error(f"Error normalizing DataFrame timestamps: {e}")
return df
def normalize_dataframe_index(df: pd.DataFrame,
source_tz: Optional[pytz.BaseTzInfo] = None) -> pd.DataFrame:
"""
Normalize DataFrame index timestamps to system timezone
Args:
df: DataFrame with datetime index
source_tz: Source timezone (defaults to UTC)
Returns:
DataFrame with normalized index
"""
if df.empty or not isinstance(df.index, pd.DatetimeIndex):
return df
if source_tz is None:
source_tz = UTC
try:
# Handle timezone
if df.index.tz is None:
# Localize to source timezone first
df.index = df.index.tz_localize(source_tz)
# Convert to system timezone
df.index = df.index.tz_convert(SYSTEM_TIMEZONE)
return df
except Exception as e:
logger.error(f"Error normalizing DataFrame index: {e}")
return df
def format_timestamp_for_display(dt: datetime, format_str: str = '%H:%M:%S') -> str:
"""
Format timestamp for display in system timezone
Args:
dt: Datetime to format
format_str: Format string
Returns:
Formatted timestamp string
"""
if dt is None:
return now_system().strftime(format_str)
try:
# Convert to system timezone if needed
if isinstance(dt, datetime):
if dt.tzinfo is None:
dt = UTC.localize(dt)
dt = dt.astimezone(SYSTEM_TIMEZONE)
return dt.strftime(format_str)
except Exception as e:
logger.error(f"Error formatting timestamp {dt}: {e}")
return now_system().strftime(format_str)
def get_timezone_offset_hours() -> float:
    """Get current timezone offset from UTC in hours"""
    # Use the aware system time directly; subtracting a naive datetime from an
    # aware one would raise a TypeError
    offset = now_system().utcoffset()
    return offset.total_seconds() / 3600 if offset is not None else 0.0
def is_market_hours() -> bool:
"""Check if it's currently market hours (24/7 for crypto, but useful for logging)"""
# Crypto markets are 24/7, but this can be useful for other purposes
return True
def log_timezone_info():
"""Log current timezone information for debugging"""
now_utc_time = now_utc()
now_sofia_time = now_sofia()
offset_hours = get_timezone_offset_hours()
    logger.info("Timezone Info:")
logger.info(f" UTC Time: {now_utc_time}")
logger.info(f" Sofia Time: {now_sofia_time}")
logger.info(f" Offset: {offset_hours:+.1f} hours from UTC")
logger.info(f" System Timezone: {SYSTEM_TIMEZONE}")
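A minimal usage sketch for the timezone helpers above, normalizing a millisecond Unix timestamp and a naive UTC DataFrame index to the system timezone; the timestamp and price values are illustrative.

import pandas as pd
from utils.timezone_utils import normalize_timestamp, normalize_dataframe_index

ts = normalize_timestamp(1727740800000)      # ms since epoch, treated as UTC
print(ts.tzinfo)                             # Europe/Sofia

df = pd.DataFrame(
    {"close": [2650.0, 2651.5]},
    index=pd.to_datetime(["2025-10-01 00:00:00", "2025-10-01 00:00:01"]),
)
df = normalize_dataframe_index(df)           # localize as UTC, convert to Sofia
print(df.index.tz)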

View File

@@ -0,0 +1,190 @@
#!/usr/bin/env python3
"""
Training Integration for Checkpoint Management
"""
import logging
import torch
from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path
from .checkpoint_manager import get_checkpoint_manager, load_best_checkpoint
logger = logging.getLogger(__name__)
class TrainingIntegration:
def __init__(self, enable_wandb: bool = True):
self.enable_wandb = enable_wandb
self.checkpoint_manager = get_checkpoint_manager()
def save_cnn_checkpoint(self,
cnn_model,
model_name: str,
epoch: int,
train_accuracy: float,
val_accuracy: float,
train_loss: float,
val_loss: float,
                            training_time_hours: Optional[float] = None) -> bool:
try:
performance_metrics = {
'accuracy': train_accuracy,
'val_accuracy': val_accuracy,
'loss': train_loss,
'val_loss': val_loss
}
training_metadata = {
'epoch': epoch,
'training_time_hours': training_time_hours,
'total_parameters': self._count_parameters(cnn_model)
}
if self.enable_wandb:
try:
import wandb
if wandb.run is not None:
wandb.log({
f"{model_name}/train_accuracy": train_accuracy,
f"{model_name}/val_accuracy": val_accuracy,
f"{model_name}/train_loss": train_loss,
f"{model_name}/val_loss": val_loss,
f"{model_name}/epoch": epoch
})
except Exception as e:
logger.warning(f"Error logging to W&B: {e}")
            # Save the model first to get the path (make sure the target directory exists)
            model_path = f"models/{model_name}_temp.pt"
            Path(model_path).parent.mkdir(parents=True, exist_ok=True)
            torch.save(cnn_model.state_dict(), model_path)
metadata = self.checkpoint_manager.save_checkpoint(
model_name=model_name,
model_path=model_path,
model_type='cnn',
performance_metrics=performance_metrics,
training_metadata=training_metadata
)
if metadata:
logger.info(f"CNN checkpoint saved: {metadata.checkpoint_id}")
return True
else:
logger.info(f"CNN checkpoint not saved (performance not improved)")
return False
except Exception as e:
logger.error(f"Error saving CNN checkpoint: {e}")
return False
def save_rl_checkpoint(self,
rl_agent,
model_name: str,
episode: int,
avg_reward: float,
best_reward: float,
epsilon: float,
                           total_pnl: Optional[float] = None) -> bool:
try:
performance_metrics = {
'reward': avg_reward,
'best_reward': best_reward
}
if total_pnl is not None:
performance_metrics['pnl'] = total_pnl
training_metadata = {
'episode': episode,
'epsilon': epsilon,
'total_parameters': self._count_parameters(rl_agent)
}
if self.enable_wandb:
try:
import wandb
if wandb.run is not None:
wandb.log({
f"{model_name}/avg_reward": avg_reward,
f"{model_name}/best_reward": best_reward,
f"{model_name}/epsilon": epsilon,
f"{model_name}/episode": episode
})
if total_pnl is not None:
wandb.log({f"{model_name}/total_pnl": total_pnl})
except Exception as e:
logger.warning(f"Error logging to W&B: {e}")
            # Save the model first to get the path (make sure the target directory exists)
            model_path = f"models/{model_name}_temp.pt"
            Path(model_path).parent.mkdir(parents=True, exist_ok=True)
            torch.save(rl_agent.state_dict() if hasattr(rl_agent, 'state_dict') else rl_agent, model_path)
metadata = self.checkpoint_manager.save_checkpoint(
model_name=model_name,
model_path=model_path,
model_type='rl',
performance_metrics=performance_metrics,
training_metadata=training_metadata
)
if metadata:
logger.info(f"RL checkpoint saved: {metadata.checkpoint_id}")
return True
else:
logger.info(f"RL checkpoint not saved (performance not improved)")
return False
except Exception as e:
logger.error(f"Error saving RL checkpoint: {e}")
return False
def load_best_model(self, model_name: str, model_class=None):
try:
result = self.checkpoint_manager.load_best_checkpoint(model_name)
if not result:
logger.warning(f"No checkpoint found for model: {model_name}")
return None
file_path, metadata = result
checkpoint = torch.load(file_path, map_location='cpu')
logger.info(f"Loaded best checkpoint for {model_name}:")
logger.info(f" Performance score: {metadata.performance_score:.4f}")
logger.info(f" Created: {metadata.created_at}")
if model_class and 'model_state_dict' in checkpoint:
model = model_class()
model.load_state_dict(checkpoint['model_state_dict'])
return model
return checkpoint
except Exception as e:
logger.error(f"Error loading best model {model_name}: {e}")
return None
def _count_parameters(self, model) -> int:
try:
if hasattr(model, 'parameters'):
return sum(p.numel() for p in model.parameters())
elif hasattr(model, 'policy_net'):
policy_params = sum(p.numel() for p in model.policy_net.parameters())
target_params = sum(p.numel() for p in model.target_net.parameters()) if hasattr(model, 'target_net') else 0
return policy_params + target_params
else:
return 0
except Exception:
return 0
_training_integration = None
def get_training_integration() -> TrainingIntegration:
global _training_integration
if _training_integration is None:
_training_integration = TrainingIntegration()
return _training_integration
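A minimal usage sketch for TrainingIntegration, assuming this module lives under utils/ like its siblings (the diff header above omits the file name); the stand-in model, model name, and metric values are illustrative.

import torch.nn as nn
from utils.training_integration import get_training_integration

model = nn.Sequential(nn.Linear(16, 3))      # stand-in for the real CNN

integration = get_training_integration()
saved = integration.save_cnn_checkpoint(
    cnn_model=model,
    model_name="enhanced_cnn",               # hypothetical model name
    epoch=12,
    train_accuracy=0.61,
    val_accuracy=0.58,
    train_loss=0.92,
    val_loss=1.01,
    training_time_hours=0.4,
)
print("checkpoint kept" if saved else "checkpoint skipped (no improvement)")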