fix merge

Dobromir Popov
2025-10-02 23:50:08 +03:00
parent 8654e08028
commit a468c75c47
13 changed files with 150 additions and 14309 deletions


@@ -267,17 +267,6 @@ class COBRLModelInterface(ModelInterface):
        logger.info(f"COB RL Model Interface initialized on {self.device}")

<<<<<<< HEAD
    def predict(self, cob_features) -> Dict[str, Any]:
=======
    def to(self, device):
        """PyTorch-style device movement method"""
        self.device = device
        self.model = self.model.to(device)
        return self

    def predict(self, cob_features: np.ndarray) -> Dict[str, Any]:
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
        """Make prediction using the model"""
        self.model.eval()
        with torch.no_grad():
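Both sides keep a `predict` method; the incoming branch additionally adds a PyTorch-style `to()` and a typed signature, which the resolution retains. A minimal runnable sketch of that device-movement pattern, with a stand-in `nn.Linear` model and an illustrative `'output'` key (neither is taken from the real COB RL network):

```python
# Sketch of the device-movement pattern the incoming branch adds, assuming
# the interface wraps a torch.nn.Module; the Linear stand-in and the
# 'output' key are illustrative only.
from typing import Any, Dict

import numpy as np
import torch
import torch.nn as nn

class COBRLModelInterface:
    def __init__(self) -> None:
        self.device = torch.device('cpu')
        self.model = nn.Linear(8, 3)  # stand-in for the real network

    def to(self, device):
        """PyTorch-style device movement: move the wrapped model, remember the device."""
        self.device = device
        self.model = self.model.to(device)
        return self  # returning self makes calls chainable, like nn.Module.to()

    def predict(self, cob_features: np.ndarray) -> Dict[str, Any]:
        """Make prediction using the wrapped model."""
        self.model.eval()
        with torch.no_grad():
            x = torch.as_tensor(cob_features, dtype=torch.float32, device=self.device)
            return {'output': self.model(x).cpu().numpy()}

# usage: COBRLModelInterface().to(torch.device('cpu')).predict(np.zeros(8))
```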


@@ -4,11 +4,6 @@ import torch.optim as optim
import numpy as np
from collections import deque
import random
<<<<<<< HEAD
from typing import Tuple, List
=======
from typing import Tuple, List, Dict, Any
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
import os
import sys
import logging


@@ -27,18 +27,6 @@ import torch
import torch.nn as nn
import torch.optim as optim
<<<<<<< HEAD
# Import prediction tracking
from core.prediction_database import get_prediction_db
=======
# Import checkpoint management
try:
    from utils.checkpoint_manager import get_checkpoint_manager, save_checkpoint
    CHECKPOINT_MANAGER_AVAILABLE = True
except ImportError:
    CHECKPOINT_MANAGER_AVAILABLE = False
    logger.warning("Checkpoint manager not available. Model persistence will be disabled.")
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b

logger = logging.getLogger(__name__)
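One caveat with the incoming side as shown in this hunk: `logger.warning` runs inside the `except ImportError` block, but the module logger is only created afterwards, so a failed import would raise `NameError`. A minimal sketch of the same optional-dependency guard with the logger defined first:

```python
# Same optional-dependency guard as the incoming branch, but the module
# logger is created before the try/except so the ImportError path can log.
import logging

logger = logging.getLogger(__name__)

try:
    from utils.checkpoint_manager import get_checkpoint_manager, save_checkpoint
    CHECKPOINT_MANAGER_AVAILABLE = True
except ImportError:
    CHECKPOINT_MANAGER_AVAILABLE = False
    logger.warning("Checkpoint manager not available. Model persistence will be disabled.")
```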
@@ -1878,33 +1866,6 @@ class EnhancedRealtimeTrainingSystem:
                outputs = model(features_tensor)

<<<<<<< HEAD
                # Extract logits from model output (model returns a dictionary)
                if isinstance(outputs, dict):
                    logits = outputs['logits']
                elif isinstance(outputs, tuple):
                    logits = outputs[0]  # First element is usually logits
                else:
                    logits = outputs

                # Ensure logits is a tensor
                if not isinstance(logits, torch.Tensor):
                    logger.error(f"CNN output is not a tensor: {type(logits)}")
                    return 0.0
=======
                # FIXED: Handle case where model returns tuple (extract the logits)
                if isinstance(outputs, tuple):
                    # Assume the first element is the main output (logits)
                    logits = outputs[0]
                elif isinstance(outputs, dict):
                    # Handle dictionary output (get main prediction)
                    logits = outputs.get('logits', outputs.get('predictions', outputs.get('output', list(outputs.values())[0])))
                else:
                    # Single tensor output
                    logits = outputs
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b

                loss = criterion(logits, targets_tensor)
                loss.backward()
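Both branches solve the same problem, unpacking logits from a model that may return a dict, a tuple, or a bare tensor; HEAD adds a tensor-type check, while the incoming side adds fallback dict keys. A hedged sketch that combines the two (the `extract_logits` helper name is an assumption, not from either branch):

```python
# Combination of both branches' unpacking: tuple -> first element,
# dict -> known keys then first value, bare tensor -> as-is; returns None
# for non-tensor results so the caller can log and bail, as HEAD did.
# The extract_logits name is hypothetical.
from typing import Any, Optional

import torch

def extract_logits(outputs: Any) -> Optional[torch.Tensor]:
    if isinstance(outputs, tuple):
        logits = outputs[0]  # first element is usually the main output
    elif isinstance(outputs, dict):
        logits = outputs.get(
            'logits',
            outputs.get('predictions',
                        outputs.get('output', next(iter(outputs.values())))))
    else:
        logits = outputs
    return logits if isinstance(logits, torch.Tensor) else None
```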
@@ -2404,46 +2365,6 @@ class EnhancedRealtimeTrainingSystem:
            if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
                    and self.orchestrator.rl_agent):
<<<<<<< HEAD
                # Use RL agent to make prediction
                current_state = self._get_dqn_state_features(symbol)
                if current_state is None:
                    return
                action = self.orchestrator.rl_agent.act(current_state, explore=False)

                # Get Q-values separately if available
                if hasattr(self.orchestrator.rl_agent, 'policy_net'):
                    with torch.no_grad():
                        state_tensor = torch.FloatTensor(current_state).unsqueeze(0).to(self.orchestrator.rl_agent.device)
                        q_values_tensor = self.orchestrator.rl_agent.policy_net(state_tensor)
                        if isinstance(q_values_tensor, tuple):
                            q_values = q_values_tensor[0].cpu().numpy()[0].tolist()
                else:
                    q_values = [0.33, 0.33, 0.34]  # Default uniform distribution

                confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33
=======
                # Get action from DQN agent
                action = self.orchestrator.rl_agent.act(current_state, explore=False)

                # Get Q-values by manually calling the model
                q_values = self._get_dqn_q_values(current_state)

                # Calculate confidence from Q-values
                if q_values is not None and len(q_values) > 0:
                    # Convert to probabilities and get confidence
                    probs = torch.softmax(torch.tensor(q_values), dim=0).numpy()
                    confidence = float(max(probs))
                    q_values = q_values.tolist() if hasattr(q_values, 'tolist') else list(q_values)
                else:
                    confidence = 0.33
                    q_values = [0.33, 0.33, 0.34]  # Default uniform distribution

                # Handle case where action is None (HOLD)
                if action is None:
                    action = 2  # Map None to HOLD action
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
            else:
                # Fallback to technical analysis-based prediction
                action, q_values, confidence = self._technical_analysis_prediction(symbol)
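The kept logic derives confidence by softmaxing the Q-values and taking the top probability, with a uniform fallback when none are available. A minimal self-contained sketch of that computation (the `confidence_from_q` helper name is hypothetical):

```python
# Sketch of the kept confidence computation: softmax the Q-values and take
# the top probability, with the same uniform fallback as the diff.
from typing import List, Optional, Sequence, Tuple

import torch

def confidence_from_q(q_values: Optional[Sequence[float]]) -> Tuple[float, List[float]]:
    if q_values is not None and len(q_values) > 0:
        probs = torch.softmax(torch.as_tensor(q_values, dtype=torch.float32), dim=0)
        return float(probs.max()), [float(q) for q in q_values]
    return 0.33, [0.33, 0.33, 0.34]  # default uniform distribution

# e.g. confidence_from_q([1.2, 0.1, -0.5]) -> (~0.66, [1.2, 0.1, -0.5])
```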
@@ -2484,21 +2405,6 @@
            self.last_prediction_time[symbol] = int(current_time)

<<<<<<< HEAD
            # Robust action labeling
            if action is None:
                action_label = 'HOLD'
            elif action == 0:
                action_label = 'SELL'
            elif action == 1:
                action_label = 'BUY'
            else:
                action_label = 'UNKNOWN'

            logger.info(f"Forward DQN prediction: {symbol} action={action_label} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
=======
            logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} price=${current_price:.2f} target={target_time.strftime('%H:%M:%S')} dims={len(current_state)}")
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b

        except Exception as e:
            logger.error(f"Error generating forward DQN prediction: {e}")
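Note that the two sides disagree on the index-to-label mapping (HEAD: 0 = SELL, 1 = BUY; incoming: `['BUY','SELL','HOLD']`, i.e. 0 = BUY), and the incoming one-liner raises `TypeError` when `action` is `None`, a case HEAD handled explicitly. A hedged sketch of a single lookup that keeps HEAD's guards with the incoming ordering; the `ACTION_LABELS` name and the ordering are assumptions to verify against the agent's action space:

```python
# One shared table so the log line and any downstream consumers agree on the
# mapping; ordering follows the incoming branch (0 = BUY) and is an assumption.
ACTION_LABELS = ('BUY', 'SELL', 'HOLD')

def action_label(action):
    """Map a DQN action index (or None) to a label without raising."""
    if action is None:
        return 'HOLD'  # HEAD mapped a missing action to HOLD
    if 0 <= action < len(ACTION_LABELS):
        return ACTION_LABELS[action]
    return 'UNKNOWN'  # out-of-range fallback, as in HEAD

# action_label(0) -> 'BUY', action_label(None) -> 'HOLD', action_label(7) -> 'UNKNOWN'
```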