training fixes
This commit is contained in:
@ -494,11 +494,8 @@ class EnhancedCNN(nn.Module):
|
||||
|
||||
return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions
|
||||
|
||||
def act(self, state, explore=True):
|
||||
def act(self, state, explore=True) -> Tuple[int, float, List[float]]:
|
||||
"""Enhanced action selection with ultra massive model predictions"""
|
||||
if explore and np.random.random() < 0.1: # 10% random exploration
|
||||
return np.random.choice(self.n_actions)
|
||||
|
||||
self.eval()
|
||||
|
||||
# Accept both NumPy arrays and already-built torch tensors
|
||||
@ -511,14 +508,16 @@ class EnhancedCNN(nn.Module):
|
||||
state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
|
||||
if state_tensor.dim() == 1:
|
||||
state_tensor = state_tensor.unsqueeze(0)
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor)
|
||||
|
||||
# Apply softmax to get action probabilities
|
||||
action_probs = torch.softmax(q_values, dim=1)
|
||||
action = torch.argmax(action_probs, dim=1).item()
|
||||
|
||||
action_probs_tensor = torch.softmax(q_values, dim=1)
|
||||
action_idx = int(torch.argmax(action_probs_tensor, dim=1).item())
|
||||
confidence = float(action_probs_tensor[0, action_idx].item()) # Confidence of the chosen action
|
||||
action_probs = action_probs_tensor.squeeze(0).tolist() # Convert to list of floats for return
|
||||
|
||||
# Log advanced predictions for better decision making
|
||||
if hasattr(self, '_log_predictions') and self._log_predictions:
|
||||
# Log volatility prediction
|
||||
@ -547,7 +546,7 @@ class EnhancedCNN(nn.Module):
|
||||
logger.info(f" Market Regime: {regime_labels[regime_class]} ({regime[regime_class]:.3f})")
|
||||
logger.info(f" Risk Level: {risk_labels[risk_class]} ({risk[risk_class]:.3f})")
|
||||
|
||||
return action
|
||||
return action_idx, confidence, action_probs
|
||||
|
||||
def save(self, path):
|
||||
"""Save model weights and architecture"""
|
||||
|
@ -33,7 +33,7 @@ from .data_provider import DataProvider
|
||||
from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
|
||||
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface, ModelRegistry
|
||||
from NN.models.cob_rl_model import COBRLModelInterface # Specific import for COB RL Interface
|
||||
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface # Import from new file
|
||||
from NN.models.model_interfaces import ModelInterface as NNModelInterface, CNNModelInterface as NNCNNModelInterface, RLAgentInterface as NNRLAgentInterface, ExtremaTrainerInterface as NNExtremaTrainerInterface # Import from new file
|
||||
from core.extrema_trainer import ExtremaTrainer # Import ExtremaTrainer for its interface
|
||||
|
||||
# Import COB integration for real-time market microstructure data
|
||||
@ -435,15 +435,21 @@ class TradingOrchestrator:
|
||||
super().__init__(name)
|
||||
self.model = model
|
||||
|
||||
def predict(self, data):
|
||||
def predict(self, data=None):
|
||||
try:
|
||||
# Use available methods from ExtremaTrainer
|
||||
if hasattr(self.model, 'detect_extrema'):
|
||||
return self.model.detect_extrema(data)
|
||||
elif hasattr(self.model, 'get_pivot_signals'):
|
||||
return self.model.get_pivot_signals(data)
|
||||
# Return a default prediction if no methods available
|
||||
return {'action': 'HOLD', 'confidence': 0.5}
|
||||
# ExtremaTrainer provides context features, not a direct prediction
|
||||
# We assume 'data' here is the 'symbol' string to pass to get_context_features_for_model
|
||||
if not isinstance(data, str):
|
||||
logger.warning(f"ExtremaTrainerInterface.predict received non-string data: {type(data)}. Cannot get context features.")
|
||||
return None
|
||||
|
||||
features = self.model.get_context_features_for_model(symbol=data)
|
||||
if features is not None and features.size > 0:
|
||||
# The presence of features indicates a signal. We'll return a generic HOLD
|
||||
# with a neutral confidence. This can be refined if ExtremaTrainer provides
|
||||
# more specific BUY/SELL signals directly.
|
||||
return {'action': 'HOLD', 'confidence': 0.5, 'probabilities': {'BUY': 0.33, 'SELL': 0.33, 'HOLD': 0.34}}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error in extrema trainer prediction: {e}")
|
||||
return None
|
||||
@ -1016,7 +1022,14 @@ class TradingOrchestrator:
|
||||
if hasattr(model.model, 'act'):
|
||||
# Flatten / reshape enhanced_features as needed…
|
||||
x = self._prepare_cnn_input(enhanced_features)
|
||||
action_probs, confidence = model.model.act(x, explore=False)
|
||||
|
||||
# Debugging: Print the type and content of x before passing to act()
|
||||
logger.debug(f"CNN input (x) type: {type(x)}, shape: {x.shape}, content sample: {x.flatten()[:5]}...")
|
||||
|
||||
action_idx, confidence, action_probs = model.model.act(x, explore=False)
|
||||
|
||||
# Debugging: Print the type and content of the unpacked values
|
||||
logger.debug(f"CNN act() returned: action_idx={action_idx} (type={type(action_idx)}), confidence={confidence} (type={type(confidence)}), action_probs={action_probs[:5]}... (type={type(action_probs)})")
|
||||
else:
|
||||
# fallback to generic predict
|
||||
result = model.predict(enhanced_features)
|
||||
@ -1029,7 +1042,7 @@ class TradingOrchestrator:
|
||||
logger.warning(f"CNN inference failed for {symbol}@{timeframe}: {e}")
|
||||
continue # skip this timeframe entirely
|
||||
|
||||
# 4) If we still don’t have valid probs, skip
|
||||
# 4) If we still don't have valid probs, skip
|
||||
if action_probs is None:
|
||||
continue
|
||||
|
||||
@ -1175,11 +1188,15 @@ class TradingOrchestrator:
|
||||
action_probs = prediction_result.get('probabilities', None)
|
||||
confidence = prediction_result.get('confidence', 0.7)
|
||||
else:
|
||||
# Assume it's just action probabilities
|
||||
# Assume it's just action probabilities (e.g., a list or numpy array)
|
||||
action_probs = prediction_result
|
||||
confidence = 0.7 # Default confidence
|
||||
|
||||
if action_probs is not None:
|
||||
# Ensure action_probs is a numpy array for argmax
|
||||
if not isinstance(action_probs, np.ndarray):
|
||||
action_probs = np.array(action_probs)
|
||||
|
||||
action_names = ['SELL', 'HOLD', 'BUY']
|
||||
best_action_idx = np.argmax(action_probs)
|
||||
best_action = action_names[best_action_idx]
|
||||
|
Reference in New Issue
Block a user