training fixes

This commit is contained in:
Dobromir Popov
2025-07-14 00:47:44 +03:00
parent ebf65494a8
commit e76b1b16dc
2 changed files with 37 additions and 21 deletions

View File

@ -494,11 +494,8 @@ class EnhancedCNN(nn.Module):
return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions
def act(self, state, explore=True):
def act(self, state, explore=True) -> Tuple[int, float, List[float]]:
"""Enhanced action selection with ultra massive model predictions"""
if explore and np.random.random() < 0.1: # 10% random exploration
return np.random.choice(self.n_actions)
self.eval()
# Accept both NumPy arrays and already-built torch tensors
@ -511,14 +508,16 @@ class EnhancedCNN(nn.Module):
state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
if state_tensor.dim() == 1:
state_tensor = state_tensor.unsqueeze(0)
with torch.no_grad():
q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor)
# Apply softmax to get action probabilities
action_probs = torch.softmax(q_values, dim=1)
action = torch.argmax(action_probs, dim=1).item()
action_probs_tensor = torch.softmax(q_values, dim=1)
action_idx = int(torch.argmax(action_probs_tensor, dim=1).item())
confidence = float(action_probs_tensor[0, action_idx].item()) # Confidence of the chosen action
action_probs = action_probs_tensor.squeeze(0).tolist() # Convert to list of floats for return
# Log advanced predictions for better decision making
if hasattr(self, '_log_predictions') and self._log_predictions:
# Log volatility prediction
@ -547,7 +546,7 @@ class EnhancedCNN(nn.Module):
logger.info(f" Market Regime: {regime_labels[regime_class]} ({regime[regime_class]:.3f})")
logger.info(f" Risk Level: {risk_labels[risk_class]} ({risk[risk_class]:.3f})")
return action
return action_idx, confidence, action_probs
def save(self, path):
"""Save model weights and architecture"""

View File

@ -33,7 +33,7 @@ from .data_provider import DataProvider
from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface, ModelRegistry
from NN.models.cob_rl_model import COBRLModelInterface # Specific import for COB RL Interface
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface # Import from new file
from NN.models.model_interfaces import ModelInterface as NNModelInterface, CNNModelInterface as NNCNNModelInterface, RLAgentInterface as NNRLAgentInterface, ExtremaTrainerInterface as NNExtremaTrainerInterface # Import from new file
from core.extrema_trainer import ExtremaTrainer # Import ExtremaTrainer for its interface
# Import COB integration for real-time market microstructure data
@ -435,15 +435,21 @@ class TradingOrchestrator:
super().__init__(name)
self.model = model
def predict(self, data):
def predict(self, data=None):
try:
# Use available methods from ExtremaTrainer
if hasattr(self.model, 'detect_extrema'):
return self.model.detect_extrema(data)
elif hasattr(self.model, 'get_pivot_signals'):
return self.model.get_pivot_signals(data)
# Return a default prediction if no methods available
return {'action': 'HOLD', 'confidence': 0.5}
# ExtremaTrainer provides context features, not a direct prediction
# We assume 'data' here is the 'symbol' string to pass to get_context_features_for_model
if not isinstance(data, str):
logger.warning(f"ExtremaTrainerInterface.predict received non-string data: {type(data)}. Cannot get context features.")
return None
features = self.model.get_context_features_for_model(symbol=data)
if features is not None and features.size > 0:
# The presence of features indicates a signal. We'll return a generic HOLD
# with a neutral confidence. This can be refined if ExtremaTrainer provides
# more specific BUY/SELL signals directly.
return {'action': 'HOLD', 'confidence': 0.5, 'probabilities': {'BUY': 0.33, 'SELL': 0.33, 'HOLD': 0.34}}
return None
except Exception as e:
logger.error(f"Error in extrema trainer prediction: {e}")
return None
@ -1016,7 +1022,14 @@ class TradingOrchestrator:
if hasattr(model.model, 'act'):
# Flatten / reshape enhanced_features as needed…
x = self._prepare_cnn_input(enhanced_features)
action_probs, confidence = model.model.act(x, explore=False)
# Debugging: Print the type and content of x before passing to act()
logger.debug(f"CNN input (x) type: {type(x)}, shape: {x.shape}, content sample: {x.flatten()[:5]}...")
action_idx, confidence, action_probs = model.model.act(x, explore=False)
# Debugging: Print the type and content of the unpacked values
logger.debug(f"CNN act() returned: action_idx={action_idx} (type={type(action_idx)}), confidence={confidence} (type={type(confidence)}), action_probs={action_probs[:5]}... (type={type(action_probs)})")
else:
# fallback to generic predict
result = model.predict(enhanced_features)
@ -1029,7 +1042,7 @@ class TradingOrchestrator:
logger.warning(f"CNN inference failed for {symbol}@{timeframe}: {e}")
continue # skip this timeframe entirely
# 4) If we still dont have valid probs, skip
# 4) If we still don't have valid probs, skip
if action_probs is None:
continue
@ -1175,11 +1188,15 @@ class TradingOrchestrator:
action_probs = prediction_result.get('probabilities', None)
confidence = prediction_result.get('confidence', 0.7)
else:
# Assume it's just action probabilities
# Assume it's just action probabilities (e.g., a list or numpy array)
action_probs = prediction_result
confidence = 0.7 # Default confidence
if action_probs is not None:
# Ensure action_probs is a numpy array for argmax
if not isinstance(action_probs, np.ndarray):
action_probs = np.array(action_probs)
action_names = ['SELL', 'HOLD', 'BUY']
best_action_idx = np.argmax(action_probs)
best_action = action_names[best_action_idx]