inference predictions fix
@@ -11,7 +11,8 @@ import sqlite3
 import json
 import logging
 import os
-from datetime import datetime
+import numpy as np
+from datetime import datetime, timedelta
 from typing import Dict, List, Optional, Any, Tuple
 from contextlib import contextmanager
 from dataclasses import dataclass, asdict
@@ -30,6 +31,7 @@ class InferenceRecord:
     input_features_hash: str  # Hash of input features for deduplication
     processing_time_ms: float
     memory_usage_mb: float
+    input_features: Optional[np.ndarray] = None  # Full input features for training
     checkpoint_id: Optional[str] = None
     metadata: Optional[Dict[str, Any]] = None
 
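
Note: the new input_features field rides alongside the existing input_features_hash, whose computation this commit never shows. A minimal sketch of one way such a deduplication hash could work (hash_features is a hypothetical helper, not part of this commit):

    import hashlib
    import numpy as np

    def hash_features(features: np.ndarray) -> str:
        # Hash the raw float32 bytes so identical feature vectors
        # map to the same deduplication key.
        return hashlib.sha256(features.astype(np.float32).tobytes()).hexdigest()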
@@ -72,6 +74,7 @@ class DatabaseManager:
                     confidence REAL NOT NULL,
                     probabilities TEXT NOT NULL,  -- JSON
                     input_features_hash TEXT NOT NULL,
+                    input_features_blob BLOB,  -- Store full input features for training
                     processing_time_ms REAL NOT NULL,
                     memory_usage_mb REAL NOT NULL,
                     checkpoint_id TEXT,
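
Note: adding input_features_blob to the CREATE TABLE statement only affects freshly created databases. An existing inference_records table would need a one-off migration; a sketch, assuming SQLite and a hypothetical database path:

    import sqlite3

    conn = sqlite3.connect("inference.db")  # hypothetical path
    try:
        # SQLite's ALTER TABLE supports adding a nullable column in place.
        conn.execute("ALTER TABLE inference_records ADD COLUMN input_features_blob BLOB")
        conn.commit()
    except sqlite3.OperationalError:
        pass  # column already exists
    conn.close()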
@@ -142,12 +145,17 @@ class DatabaseManager:
         """Log an inference record"""
         try:
             with self._get_connection() as conn:
+                # Serialize input features if provided
+                input_features_blob = None
+                if record.input_features is not None:
+                    input_features_blob = record.input_features.tobytes()
+
                 conn.execute("""
                     INSERT INTO inference_records (
                         model_name, timestamp, symbol, action, confidence,
-                        probabilities, input_features_hash, processing_time_ms,
-                        memory_usage_mb, checkpoint_id, metadata
-                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                        probabilities, input_features_hash, input_features_blob,
+                        processing_time_ms, memory_usage_mb, checkpoint_id, metadata
+                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                 """, (
                     record.model_name,
                     record.timestamp.isoformat(),
@@ -156,6 +164,7 @@ class DatabaseManager:
                     record.confidence,
                     json.dumps(record.probabilities),
                     record.input_features_hash,
+                    input_features_blob,
                     record.processing_time_ms,
                     record.memory_usage_mb,
                     record.checkpoint_id,
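
Note on the serialization choice: tobytes() stores only the raw buffer, so shape and dtype are lost and the reader must know both. A round-trip sketch illustrating the caveat:

    import numpy as np

    features = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    blob = features.tobytes()  # raw bytes; shape and dtype are not recorded

    restored = np.frombuffer(blob, dtype=np.float32)  # always comes back 1-D
    assert np.array_equal(restored, features.ravel())
    # Recovering the original 2-D layout requires the shape from elsewhere:
    restored_2d = restored.reshape(features.shape)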
@@ -332,6 +341,15 @@ class DatabaseManager:
 
                 records = []
                 for row in cursor.fetchall():
+                    # Deserialize input features if available
+                    input_features = None
+                    if row['input_features_blob']:
+                        try:
+                            # Reconstruct numpy array from bytes
+                            input_features = np.frombuffer(row['input_features_blob'], dtype=np.float32)
+                        except Exception as e:
+                            logger.warning(f"Failed to deserialize input features: {e}")
+
                     records.append(InferenceRecord(
                         model_name=row['model_name'],
                         timestamp=datetime.fromisoformat(row['timestamp']),
@@ -342,6 +360,7 @@ class DatabaseManager:
                         input_features_hash=row['input_features_hash'],
                         processing_time_ms=row['processing_time_ms'],
                         memory_usage_mb=row['memory_usage_mb'],
+                        input_features=input_features,
                         checkpoint_id=row['checkpoint_id'],
                         metadata=json.loads(row['metadata']) if row['metadata'] else None
                     ))
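
One more caveat with this deserialization path: np.frombuffer over a bytes object yields a read-only view, so consumers that mutate features in place need a copy. A sketch:

    import numpy as np

    blob = np.arange(4, dtype=np.float32).tobytes()
    arr = np.frombuffer(blob, dtype=np.float32)
    assert arr.flags.writeable is False  # view over immutable bytes
    arr = arr.copy()  # writable copy for in-place transforms
    arr += 1.0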
@@ -373,6 +392,75 @@ class DatabaseManager:
             logger.error(f"Failed to update model performance: {e}")
             return False
 
+    def get_inference_records_for_training(self, model_name: str,
+                                           symbol: str = None,
+                                           hours_back: int = 24,
+                                           limit: int = 1000) -> List[InferenceRecord]:
+        """
+        Get inference records with input features for training feedback
+
+        Args:
+            model_name: Name of the model
+            symbol: Optional symbol filter
+            hours_back: How many hours back to look
+            limit: Maximum number of records
+
+        Returns:
+            List of InferenceRecord with input_features populated
+        """
+        try:
+            cutoff_time = datetime.now() - timedelta(hours=hours_back)
+
+            with self._get_connection() as conn:
+                if symbol:
+                    cursor = conn.execute("""
+                        SELECT * FROM inference_records
+                        WHERE model_name = ? AND symbol = ? AND timestamp >= ?
+                        AND input_features_blob IS NOT NULL
+                        ORDER BY timestamp DESC
+                        LIMIT ?
+                    """, (model_name, symbol, cutoff_time.isoformat(), limit))
+                else:
+                    cursor = conn.execute("""
+                        SELECT * FROM inference_records
+                        WHERE model_name = ? AND timestamp >= ?
+                        AND input_features_blob IS NOT NULL
+                        ORDER BY timestamp DESC
+                        LIMIT ?
+                    """, (model_name, cutoff_time.isoformat(), limit))
+
+                records = []
+                for row in cursor.fetchall():
+                    # Deserialize input features
+                    input_features = None
+                    if row['input_features_blob']:
+                        try:
+                            input_features = np.frombuffer(row['input_features_blob'], dtype=np.float32)
+                        except Exception as e:
+                            logger.warning(f"Failed to deserialize input features: {e}")
+                            continue  # Skip records with corrupted features
+
+                    records.append(InferenceRecord(
+                        model_name=row['model_name'],
+                        timestamp=datetime.fromisoformat(row['timestamp']),
+                        symbol=row['symbol'],
+                        action=row['action'],
+                        confidence=row['confidence'],
+                        probabilities=json.loads(row['probabilities']),
+                        input_features_hash=row['input_features_hash'],
+                        processing_time_ms=row['processing_time_ms'],
+                        memory_usage_mb=row['memory_usage_mb'],
+                        input_features=input_features,
+                        checkpoint_id=row['checkpoint_id'],
+                        metadata=json.loads(row['metadata']) if row['metadata'] else None
+                    ))
+
+                return records
+
+        except Exception as e:
+            logger.error(f"Failed to get inference records for training: {e}")
+            return []
+
     def cleanup_old_records(self, days_to_keep: int = 30) -> bool:
         """Clean up old inference records"""
         try:
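
A usage sketch for the new retrieval method (db_manager, the model name, the symbol, and train_batch are all hypothetical placeholders):

    # Pull recent inferences that still carry their input features,
    # for a training-feedback pass.
    records = db_manager.get_inference_records_for_training(
        model_name="cnn_model",  # placeholder
        symbol="ETH/USDT",       # placeholder
        hours_back=48,
        limit=500,
    )
    for rec in records:
        # Guaranteed by the input_features_blob IS NOT NULL filter.
        assert rec.input_features is not None
        train_batch.append((rec.input_features, rec.action))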
@@ -61,6 +61,13 @@ class InferenceLogger:
         # Get current memory usage
         memory_usage_mb = self._get_memory_usage()
 
+        # Convert input features to numpy array if needed
+        features_array = None
+        if isinstance(input_features, np.ndarray):
+            features_array = input_features.astype(np.float32)
+        elif isinstance(input_features, (list, tuple)):
+            features_array = np.array(input_features, dtype=np.float32)
+
         # Create inference record
         record = InferenceRecord(
             model_name=model_name,
@@ -72,6 +79,7 @@ class InferenceLogger:
             input_features_hash=feature_hash,
             processing_time_ms=processing_time_ms,
             memory_usage_mb=memory_usage_mb,
+            input_features=features_array,
             checkpoint_id=checkpoint_id,
             metadata=metadata
         )
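
Design note on the conversion above: the two isinstance branches could be collapsed with np.asarray, which accepts ndarray, list, and tuple alike and only copies when a dtype change is needed. A sketch (slight behavioral difference: it would also coerce other array-likes instead of leaving features_array as None):

    import numpy as np

    def to_features_array(input_features):
        # Collapsed equivalent of the isinstance branches above.
        if input_features is None:
            return None
        return np.asarray(input_features, dtype=np.float32)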