inference works
This commit is contained in:
@ -332,7 +332,7 @@ class TradingOrchestrator:
|
||||
self.model_states['dqn']['checkpoint_filename'] = 'none (fresh start)'
|
||||
logger.info("DQN starting fresh - no checkpoint found")
|
||||
|
||||
logger.info(f"DQN Agent initialized: {state_size} state features, {action_size} actions")
|
||||
logger.info(f"DQN Agent initialized: {actual_state_size} state features, {action_size} actions")
|
||||
except ImportError:
|
||||
logger.warning("DQN Agent not available")
|
||||
self.rl_agent = None
|
||||
@ -474,6 +474,7 @@ class TradingOrchestrator:
|
||||
|
||||
# CRITICAL: Register models with the model registry
|
||||
logger.info("Registering models with model registry...")
|
||||
logger.info(f"Model registry before registration: {len(self.model_registry.models)} models")
|
||||
|
||||
# Import model interfaces
|
||||
# These are now imported at the top of the file
|
||||
@ -482,8 +483,11 @@ class TradingOrchestrator:
|
||||
if self.rl_agent:
|
||||
try:
|
||||
rl_interface = RLAgentInterface(self.rl_agent, name="dqn_agent")
|
||||
self.register_model(rl_interface, weight=0.2)
|
||||
logger.info("RL Agent registered successfully")
|
||||
success = self.register_model(rl_interface, weight=0.2)
|
||||
if success:
|
||||
logger.info("RL Agent registered successfully")
|
||||
else:
|
||||
logger.error("Failed to register RL Agent - register_model returned False")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register RL Agent: {e}")
|
||||
|
||||
@ -491,8 +495,11 @@ class TradingOrchestrator:
|
||||
if self.cnn_model:
|
||||
try:
|
||||
cnn_interface = CNNModelInterface(self.cnn_model, name="enhanced_cnn")
|
||||
self.register_model(cnn_interface, weight=0.25)
|
||||
logger.info("CNN Model registered successfully")
|
||||
success = self.register_model(cnn_interface, weight=0.25)
|
||||
if success:
|
||||
logger.info("CNN Model registered successfully")
|
||||
else:
|
||||
logger.error("Failed to register CNN Model - register_model returned False")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register CNN Model: {e}")
|
||||
|
||||
@ -596,6 +603,8 @@ class TradingOrchestrator:
|
||||
# Normalize weights after all registrations
|
||||
self._normalize_weights()
|
||||
logger.info(f"Current model weights: {self.model_weights}")
|
||||
logger.info(f"Model registry after registration: {len(self.model_registry.models)} models")
|
||||
logger.info(f"Registered models: {list(self.model_registry.models.keys())}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing ML models: {e}")
|
||||
@ -2080,14 +2089,7 @@ class TradingOrchestrator:
|
||||
# Store prediction in SQLite database for training
|
||||
logger.debug(f"Added CNN prediction to database: {prediction}")
|
||||
|
||||
# Store CNN prediction as inference record
|
||||
await self._store_inference_data_async(
|
||||
model_name="enhanced_cnn",
|
||||
model_input=base_data,
|
||||
prediction=prediction,
|
||||
timestamp=datetime.now(),
|
||||
symbol=symbol
|
||||
)
|
||||
# Note: Inference data will be stored in main prediction loop to avoid duplication
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error using CNN adapter: {e}")
|
||||
@ -2139,14 +2141,7 @@ class TradingOrchestrator:
|
||||
)
|
||||
predictions.append(pred)
|
||||
|
||||
# Store CNN fallback prediction as inference record
|
||||
await self._store_inference_data_async(
|
||||
model_name=model.name,
|
||||
model_input=base_data,
|
||||
prediction=pred,
|
||||
timestamp=datetime.now(),
|
||||
symbol=symbol
|
||||
)
|
||||
# Note: Inference data will be stored in main prediction loop to avoid duplication
|
||||
|
||||
# Capture for dashboard
|
||||
current_price = self._get_current_price(symbol)
|
||||
|
Reference in New Issue
Block a user