fix merge
@@ -267,17 +267,6 @@ class COBRLModelInterface(ModelInterface):
        logger.info(f"COB RL Model Interface initialized on {self.device}")

<<<<<<< HEAD
    def predict(self, cob_features) -> Dict[str, Any]:
=======
    def to(self, device):
        """PyTorch-style device movement method"""
        self.device = device
        self.model = self.model.to(device)
        return self

    def predict(self, cob_features: np.ndarray) -> Dict[str, Any]:
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
        """Make prediction using the model"""
        self.model.eval()
        with torch.no_grad():
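For illustration, a minimal self-contained sketch of the PyTorch-style to(device) pattern that the incoming branch adds here. The wrapper class and its linear layer are hypothetical stand-ins, not the real COBRLModelInterface.

import torch
import torch.nn as nn

class DummyWrapper:
    def __init__(self):
        self.model = nn.Linear(8, 3)        # stand-in for the wrapped network
        self.device = torch.device('cpu')

    def to(self, device):
        """Mirror nn.Module.to(): move the wrapped model and remember the device."""
        self.device = torch.device(device)
        self.model = self.model.to(self.device)
        return self                          # returning self allows chained calls

wrapper = DummyWrapper().to('cuda' if torch.cuda.is_available() else 'cpu')
print(wrapper.device)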
@@ -4,11 +4,6 @@ import torch.optim as optim
import numpy as np
from collections import deque
import random
<<<<<<< HEAD
from typing import Tuple, List
=======
from typing import Tuple, List, Dict, Any
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
import os
import sys
import logging

@@ -27,18 +27,6 @@ import torch
import torch.nn as nn
import torch.optim as optim

<<<<<<< HEAD
# Import prediction tracking
from core.prediction_database import get_prediction_db
=======
# Import checkpoint management
try:
    from utils.checkpoint_manager import get_checkpoint_manager, save_checkpoint
    CHECKPOINT_MANAGER_AVAILABLE = True
except ImportError:
    CHECKPOINT_MANAGER_AVAILABLE = False
    logger.warning("Checkpoint manager not available. Model persistence will be disabled.")
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b

logger = logging.getLogger(__name__)

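The incoming branch guards an optional dependency behind try/except ImportError plus a module-level flag. A self-contained sketch of that pattern follows; the package, function, and flag names below are placeholders, not taken from this repository.

import logging
logger = logging.getLogger(__name__)

try:
    from some_optional_package import save_state      # placeholder optional import
    PERSISTENCE_AVAILABLE = True
except ImportError:
    PERSISTENCE_AVAILABLE = False
    logger.warning("Optional persistence package not available; saving disabled.")

def maybe_save(obj, path: str) -> bool:
    """Persist obj only when the optional dependency imported cleanly."""
    if not PERSISTENCE_AVAILABLE:
        return False
    save_state(obj, path)
    return True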
@@ -1878,33 +1866,6 @@ class EnhancedRealtimeTrainingSystem:
            outputs = model(features_tensor)

<<<<<<< HEAD
            # Extract logits from model output (model returns a dictionary)
            if isinstance(outputs, dict):
                logits = outputs['logits']
            elif isinstance(outputs, tuple):
                logits = outputs[0]  # First element is usually logits
            else:
                logits = outputs

            # Ensure logits is a tensor
            if not isinstance(logits, torch.Tensor):
                logger.error(f"CNN output is not a tensor: {type(logits)}")
                return 0.0

=======
            # FIXED: Handle case where model returns tuple (extract the logits)
            if isinstance(outputs, tuple):
                # Assume the first element is the main output (logits)
                logits = outputs[0]
            elif isinstance(outputs, dict):
                # Handle dictionary output (get main prediction)
                logits = outputs.get('logits', outputs.get('predictions', outputs.get('output', list(outputs.values())[0])))
            else:
                # Single tensor output
                logits = outputs

>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
            loss = criterion(logits, targets_tensor)

            loss.backward()
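Both sides of this conflict implement the same idea: normalize whatever the model returns (dict, tuple, or bare tensor) down to a logits tensor before computing the loss. A standalone sketch of that normalization; the helper name is illustrative, not part of the codebase.

import torch

def extract_logits(outputs):
    """Normalize a model's return value (dict, tuple, or tensor) to a logits tensor."""
    if isinstance(outputs, dict):
        # Prefer an explicit 'logits' entry, then common fallbacks, then any value.
        for key in ('logits', 'predictions', 'output'):
            if key in outputs:
                return outputs[key]
        return next(iter(outputs.values()))
    if isinstance(outputs, tuple):
        return outputs[0]       # first element is conventionally the main output
    return outputs              # already a bare tensor

# All three output shapes map to the same tensor.
t = torch.randn(2, 3)
assert extract_logits({'logits': t}) is t
assert extract_logits((t, None)) is t
assert extract_logits(t) is t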
@@ -2404,46 +2365,6 @@ class EnhancedRealtimeTrainingSystem:
            if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
                and self.orchestrator.rl_agent):

<<<<<<< HEAD
                # Use RL agent to make prediction
                current_state = self._get_dqn_state_features(symbol)
                if current_state is None:
                    return
                action = self.orchestrator.rl_agent.act(current_state, explore=False)
                # Get Q-values separately if available
                if hasattr(self.orchestrator.rl_agent, 'policy_net'):
                    with torch.no_grad():
                        state_tensor = torch.FloatTensor(current_state).unsqueeze(0).to(self.orchestrator.rl_agent.device)
                        q_values_tensor = self.orchestrator.rl_agent.policy_net(state_tensor)
                        if isinstance(q_values_tensor, tuple):
                            q_values = q_values_tensor[0].cpu().numpy()[0].tolist()
                        else:
                            q_values = [0.33, 0.33, 0.34]  # Default uniform distribution

                confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33

=======
                # Get action from DQN agent
                action = self.orchestrator.rl_agent.act(current_state, explore=False)

                # Get Q-values by manually calling the model
                q_values = self._get_dqn_q_values(current_state)

                # Calculate confidence from Q-values
                if q_values is not None and len(q_values) > 0:
                    # Convert to probabilities and get confidence
                    probs = torch.softmax(torch.tensor(q_values), dim=0).numpy()
                    confidence = float(max(probs))
                    q_values = q_values.tolist() if hasattr(q_values, 'tolist') else list(q_values)
                else:
                    confidence = 0.33
                    q_values = [0.33, 0.33, 0.34]  # Default uniform distribution

                # Handle case where action is None (HOLD)
                if action is None:
                    action = 2  # Map None to HOLD action

>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
            else:
                # Fallback to technical analysis-based prediction
                action, q_values, confidence = self._technical_analysis_prediction(symbol)
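The incoming branch derives confidence from a softmax over the Q-values, whereas the HEAD side uses max(q)/sum(q), which misbehaves when Q-values are negative or sum to roughly zero. A small self-contained sketch of the softmax approach; the function name is illustrative.

import torch

def confidence_from_q_values(q_values) -> float:
    """Turn raw Q-values into a [0, 1] confidence via softmax; fall back to 1/3 if empty."""
    if q_values is None or len(q_values) == 0:
        return 0.33
    probs = torch.softmax(torch.tensor(q_values, dtype=torch.float32), dim=0)
    return float(probs.max())

print(confidence_from_q_values([1.2, -0.4, 0.1]))    # well-defined with mixed-sign Q-values
print(confidence_from_q_values([-3.0, -1.0, -2.0]))  # max/sum would be misleading here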
@@ -2484,21 +2405,6 @@ class EnhancedRealtimeTrainingSystem:
            self.last_prediction_time[symbol] = int(current_time)

<<<<<<< HEAD
            # Robust action labeling
            if action is None:
                action_label = 'HOLD'
            elif action == 0:
                action_label = 'SELL'
            elif action == 1:
                action_label = 'BUY'
            else:
                action_label = 'UNKNOWN'

            logger.info(f"Forward DQN prediction: {symbol} action={action_label} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
=======
            logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} price=${current_price:.2f} target={target_time.strftime('%H:%M:%S')} dims={len(current_state)}")
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b

        except Exception as e:
            logger.error(f"Error generating forward DQN prediction: {e}")

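The HEAD side labels actions defensively before logging (treating None as HOLD and unknown indices explicitly), while the incoming side indexes a list directly and would raise on action=None. A hedged sketch combining both behaviours; note the two branches disagree on the index-to-label mapping, so the mapping below is only an assumption, not the project's canonical order.

# Illustrative only: HEAD uses 0=SELL/1=BUY, the incoming branch uses 0=BUY/1=SELL/2=HOLD.
ACTION_LABELS = {0: 'SELL', 1: 'BUY', 2: 'HOLD'}

def label_action(action) -> str:
    """Map an action index (or None) to a label without risking an IndexError."""
    if action is None:
        return 'HOLD'
    return ACTION_LABELS.get(action, 'UNKNOWN')

print(label_action(None))  # HOLD
print(label_action(7))     # UNKNOWN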