inference works
This commit is contained in:
@ -332,7 +332,7 @@ class TradingOrchestrator:
|
|||||||
self.model_states['dqn']['checkpoint_filename'] = 'none (fresh start)'
|
self.model_states['dqn']['checkpoint_filename'] = 'none (fresh start)'
|
||||||
logger.info("DQN starting fresh - no checkpoint found")
|
logger.info("DQN starting fresh - no checkpoint found")
|
||||||
|
|
||||||
logger.info(f"DQN Agent initialized: {state_size} state features, {action_size} actions")
|
logger.info(f"DQN Agent initialized: {actual_state_size} state features, {action_size} actions")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("DQN Agent not available")
|
logger.warning("DQN Agent not available")
|
||||||
self.rl_agent = None
|
self.rl_agent = None
|
||||||
@ -474,6 +474,7 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
# CRITICAL: Register models with the model registry
|
# CRITICAL: Register models with the model registry
|
||||||
logger.info("Registering models with model registry...")
|
logger.info("Registering models with model registry...")
|
||||||
|
logger.info(f"Model registry before registration: {len(self.model_registry.models)} models")
|
||||||
|
|
||||||
# Import model interfaces
|
# Import model interfaces
|
||||||
# These are now imported at the top of the file
|
# These are now imported at the top of the file
|
||||||
@ -482,8 +483,11 @@ class TradingOrchestrator:
|
|||||||
if self.rl_agent:
|
if self.rl_agent:
|
||||||
try:
|
try:
|
||||||
rl_interface = RLAgentInterface(self.rl_agent, name="dqn_agent")
|
rl_interface = RLAgentInterface(self.rl_agent, name="dqn_agent")
|
||||||
self.register_model(rl_interface, weight=0.2)
|
success = self.register_model(rl_interface, weight=0.2)
|
||||||
|
if success:
|
||||||
logger.info("RL Agent registered successfully")
|
logger.info("RL Agent registered successfully")
|
||||||
|
else:
|
||||||
|
logger.error("Failed to register RL Agent - register_model returned False")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to register RL Agent: {e}")
|
logger.error(f"Failed to register RL Agent: {e}")
|
||||||
|
|
||||||
@ -491,8 +495,11 @@ class TradingOrchestrator:
|
|||||||
if self.cnn_model:
|
if self.cnn_model:
|
||||||
try:
|
try:
|
||||||
cnn_interface = CNNModelInterface(self.cnn_model, name="enhanced_cnn")
|
cnn_interface = CNNModelInterface(self.cnn_model, name="enhanced_cnn")
|
||||||
self.register_model(cnn_interface, weight=0.25)
|
success = self.register_model(cnn_interface, weight=0.25)
|
||||||
|
if success:
|
||||||
logger.info("CNN Model registered successfully")
|
logger.info("CNN Model registered successfully")
|
||||||
|
else:
|
||||||
|
logger.error("Failed to register CNN Model - register_model returned False")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to register CNN Model: {e}")
|
logger.error(f"Failed to register CNN Model: {e}")
|
||||||
|
|
||||||
@ -596,6 +603,8 @@ class TradingOrchestrator:
|
|||||||
# Normalize weights after all registrations
|
# Normalize weights after all registrations
|
||||||
self._normalize_weights()
|
self._normalize_weights()
|
||||||
logger.info(f"Current model weights: {self.model_weights}")
|
logger.info(f"Current model weights: {self.model_weights}")
|
||||||
|
logger.info(f"Model registry after registration: {len(self.model_registry.models)} models")
|
||||||
|
logger.info(f"Registered models: {list(self.model_registry.models.keys())}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error initializing ML models: {e}")
|
logger.error(f"Error initializing ML models: {e}")
|
||||||
@ -2080,14 +2089,7 @@ class TradingOrchestrator:
|
|||||||
# Store prediction in SQLite database for training
|
# Store prediction in SQLite database for training
|
||||||
logger.debug(f"Added CNN prediction to database: {prediction}")
|
logger.debug(f"Added CNN prediction to database: {prediction}")
|
||||||
|
|
||||||
# Store CNN prediction as inference record
|
# Note: Inference data will be stored in main prediction loop to avoid duplication
|
||||||
await self._store_inference_data_async(
|
|
||||||
model_name="enhanced_cnn",
|
|
||||||
model_input=base_data,
|
|
||||||
prediction=prediction,
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
symbol=symbol
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error using CNN adapter: {e}")
|
logger.error(f"Error using CNN adapter: {e}")
|
||||||
@ -2139,14 +2141,7 @@ class TradingOrchestrator:
|
|||||||
)
|
)
|
||||||
predictions.append(pred)
|
predictions.append(pred)
|
||||||
|
|
||||||
# Store CNN fallback prediction as inference record
|
# Note: Inference data will be stored in main prediction loop to avoid duplication
|
||||||
await self._store_inference_data_async(
|
|
||||||
model_name=model.name,
|
|
||||||
model_input=base_data,
|
|
||||||
prediction=pred,
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
symbol=symbol
|
|
||||||
)
|
|
||||||
|
|
||||||
# Capture for dashboard
|
# Capture for dashboard
|
||||||
current_price = self._get_current_price(symbol)
|
current_price = self._get_current_price(symbol)
|
||||||
|
Binary file not shown.
31
reset_db_manager.py
Normal file
31
reset_db_manager.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Script to reset the database manager instance to trigger migration in running system
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from utils.database_manager import reset_database_manager
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Set up logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Reset the database manager to trigger migration"""
|
||||||
|
try:
|
||||||
|
logger.info("Resetting database manager to trigger migration...")
|
||||||
|
reset_database_manager()
|
||||||
|
logger.info("✅ Database manager reset successfully!")
|
||||||
|
logger.info("The migration will run automatically on the next database access.")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Failed to reset database manager: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = main()
|
||||||
|
sys.exit(0 if success else 1)
|
46
test_db_migration.py
Normal file
46
test_db_migration.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script to verify database migration works correctly
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from utils.database_manager import get_database_manager, reset_database_manager
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Set up logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def test_migration():
|
||||||
|
"""Test the database migration"""
|
||||||
|
try:
|
||||||
|
logger.info("Testing database migration...")
|
||||||
|
|
||||||
|
# Reset the database manager to force re-initialization
|
||||||
|
reset_database_manager()
|
||||||
|
|
||||||
|
# Get a new instance (this will trigger migration)
|
||||||
|
db_manager = get_database_manager()
|
||||||
|
|
||||||
|
# Test if we can access the input_features_blob column
|
||||||
|
with db_manager._get_connection() as conn:
|
||||||
|
cursor = conn.execute("PRAGMA table_info(inference_records)")
|
||||||
|
columns = [row[1] for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
if 'input_features_blob' in columns:
|
||||||
|
logger.info("✅ input_features_blob column exists - migration successful!")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error("❌ input_features_blob column missing - migration failed!")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Migration test failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = test_migration()
|
||||||
|
sys.exit(0 if success else 1)
|
52
test_model_registry.py
Normal file
52
test_model_registry.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Add the project root to the path
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def test_model_registry():
|
||||||
|
"""Test the model registry state"""
|
||||||
|
try:
|
||||||
|
from core.orchestrator import TradingOrchestrator
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
|
||||||
|
logger.info("Testing model registry...")
|
||||||
|
|
||||||
|
# Initialize data provider
|
||||||
|
data_provider = DataProvider()
|
||||||
|
|
||||||
|
# Initialize orchestrator
|
||||||
|
orchestrator = TradingOrchestrator(data_provider=data_provider)
|
||||||
|
|
||||||
|
# Check model registry state
|
||||||
|
logger.info(f"Model registry models: {len(orchestrator.model_registry.models)}")
|
||||||
|
logger.info(f"Registered models: {list(orchestrator.model_registry.models.keys())}")
|
||||||
|
|
||||||
|
# Check if models were created
|
||||||
|
logger.info(f"RL Agent: {orchestrator.rl_agent is not None}")
|
||||||
|
logger.info(f"CNN Model: {orchestrator.cnn_model is not None}")
|
||||||
|
logger.info(f"CNN Adapter: {orchestrator.cnn_adapter is not None}")
|
||||||
|
|
||||||
|
# Check model weights
|
||||||
|
logger.info(f"Model weights: {orchestrator.model_weights}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error testing model registry: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = test_model_registry()
|
||||||
|
if success:
|
||||||
|
logger.info("✅ Model registry test completed successfully")
|
||||||
|
else:
|
||||||
|
logger.error("❌ Model registry test failed")
|
@ -124,6 +124,29 @@ class DatabaseManager:
|
|||||||
|
|
||||||
logger.info(f"Database initialized at {self.db_path}")
|
logger.info(f"Database initialized at {self.db_path}")
|
||||||
|
|
||||||
|
# Run migrations to handle schema updates
|
||||||
|
self._run_migrations()
|
||||||
|
|
||||||
|
def _run_migrations(self):
|
||||||
|
"""Run database migrations to handle schema updates"""
|
||||||
|
try:
|
||||||
|
with self._get_connection() as conn:
|
||||||
|
# Check if input_features_blob column exists
|
||||||
|
cursor = conn.execute("PRAGMA table_info(inference_records)")
|
||||||
|
columns = [row[1] for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
if 'input_features_blob' not in columns:
|
||||||
|
logger.info("Adding input_features_blob column to inference_records table")
|
||||||
|
conn.execute("ALTER TABLE inference_records ADD COLUMN input_features_blob BLOB")
|
||||||
|
conn.commit()
|
||||||
|
logger.info("Successfully added input_features_blob column")
|
||||||
|
else:
|
||||||
|
logger.debug("input_features_blob column already exists")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error running database migrations: {e}")
|
||||||
|
# If migration fails, we can still continue without the blob column
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _get_connection(self):
|
def _get_connection(self):
|
||||||
"""Get database connection with proper error handling"""
|
"""Get database connection with proper error handling"""
|
||||||
@ -145,11 +168,18 @@ class DatabaseManager:
|
|||||||
"""Log an inference record"""
|
"""Log an inference record"""
|
||||||
try:
|
try:
|
||||||
with self._get_connection() as conn:
|
with self._get_connection() as conn:
|
||||||
# Serialize input features if provided
|
# Check if input_features_blob column exists
|
||||||
|
cursor = conn.execute("PRAGMA table_info(inference_records)")
|
||||||
|
columns = [row[1] for row in cursor.fetchall()]
|
||||||
|
has_blob_column = 'input_features_blob' in columns
|
||||||
|
|
||||||
|
# Serialize input features if provided and column exists
|
||||||
input_features_blob = None
|
input_features_blob = None
|
||||||
if record.input_features is not None:
|
if record.input_features is not None and has_blob_column:
|
||||||
input_features_blob = record.input_features.tobytes()
|
input_features_blob = record.input_features.tobytes()
|
||||||
|
|
||||||
|
if has_blob_column:
|
||||||
|
# Use full query with blob column
|
||||||
conn.execute("""
|
conn.execute("""
|
||||||
INSERT INTO inference_records (
|
INSERT INTO inference_records (
|
||||||
model_name, timestamp, symbol, action, confidence,
|
model_name, timestamp, symbol, action, confidence,
|
||||||
@ -170,6 +200,29 @@ class DatabaseManager:
|
|||||||
record.checkpoint_id,
|
record.checkpoint_id,
|
||||||
json.dumps(record.metadata) if record.metadata else None
|
json.dumps(record.metadata) if record.metadata else None
|
||||||
))
|
))
|
||||||
|
else:
|
||||||
|
# Fallback query without blob column
|
||||||
|
logger.warning("input_features_blob column missing, storing without full features")
|
||||||
|
conn.execute("""
|
||||||
|
INSERT INTO inference_records (
|
||||||
|
model_name, timestamp, symbol, action, confidence,
|
||||||
|
probabilities, input_features_hash,
|
||||||
|
processing_time_ms, memory_usage_mb, checkpoint_id, metadata
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""", (
|
||||||
|
record.model_name,
|
||||||
|
record.timestamp.isoformat(),
|
||||||
|
record.symbol,
|
||||||
|
record.action,
|
||||||
|
record.confidence,
|
||||||
|
json.dumps(record.probabilities),
|
||||||
|
record.input_features_hash,
|
||||||
|
record.processing_time_ms,
|
||||||
|
record.memory_usage_mb,
|
||||||
|
record.checkpoint_id,
|
||||||
|
json.dumps(record.metadata) if record.metadata else None
|
||||||
|
))
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -343,7 +396,8 @@ class DatabaseManager:
|
|||||||
for row in cursor.fetchall():
|
for row in cursor.fetchall():
|
||||||
# Deserialize input features if available
|
# Deserialize input features if available
|
||||||
input_features = None
|
input_features = None
|
||||||
if row['input_features_blob']:
|
# Check if the column exists in the row (handles missing column gracefully)
|
||||||
|
if 'input_features_blob' in row.keys() and row['input_features_blob']:
|
||||||
try:
|
try:
|
||||||
# Reconstruct numpy array from bytes
|
# Reconstruct numpy array from bytes
|
||||||
input_features = np.frombuffer(row['input_features_blob'], dtype=np.float32)
|
input_features = np.frombuffer(row['input_features_blob'], dtype=np.float32)
|
||||||
@ -412,6 +466,15 @@ class DatabaseManager:
|
|||||||
cutoff_time = datetime.now() - timedelta(hours=hours_back)
|
cutoff_time = datetime.now() - timedelta(hours=hours_back)
|
||||||
|
|
||||||
with self._get_connection() as conn:
|
with self._get_connection() as conn:
|
||||||
|
# Check if input_features_blob column exists before querying
|
||||||
|
cursor = conn.execute("PRAGMA table_info(inference_records)")
|
||||||
|
columns = [row[1] for row in cursor.fetchall()]
|
||||||
|
has_blob_column = 'input_features_blob' in columns
|
||||||
|
|
||||||
|
if not has_blob_column:
|
||||||
|
logger.warning("input_features_blob column not found, returning empty list")
|
||||||
|
return []
|
||||||
|
|
||||||
if symbol:
|
if symbol:
|
||||||
cursor = conn.execute("""
|
cursor = conn.execute("""
|
||||||
SELECT * FROM inference_records
|
SELECT * FROM inference_records
|
||||||
@ -494,3 +557,9 @@ def get_database_manager(db_path: str = "data/trading_system.db") -> DatabaseMan
|
|||||||
_db_manager_instance = DatabaseManager(db_path)
|
_db_manager_instance = DatabaseManager(db_path)
|
||||||
|
|
||||||
return _db_manager_instance
|
return _db_manager_instance
|
||||||
|
|
||||||
|
def reset_database_manager():
|
||||||
|
"""Reset the database manager instance to force re-initialization"""
|
||||||
|
global _db_manager_instance
|
||||||
|
_db_manager_instance = None
|
||||||
|
logger.info("Database manager instance reset - will re-initialize on next access")
|
Reference in New Issue
Block a user