training fixes
This commit is contained in:
@ -494,11 +494,8 @@ class EnhancedCNN(nn.Module):
|
|||||||
|
|
||||||
return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions
|
return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions
|
||||||
|
|
||||||
def act(self, state, explore=True):
|
def act(self, state, explore=True) -> Tuple[int, float, List[float]]:
|
||||||
"""Enhanced action selection with ultra massive model predictions"""
|
"""Enhanced action selection with ultra massive model predictions"""
|
||||||
if explore and np.random.random() < 0.1: # 10% random exploration
|
|
||||||
return np.random.choice(self.n_actions)
|
|
||||||
|
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
# Accept both NumPy arrays and already-built torch tensors
|
# Accept both NumPy arrays and already-built torch tensors
|
||||||
@ -511,14 +508,16 @@ class EnhancedCNN(nn.Module):
|
|||||||
state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
|
state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
|
||||||
if state_tensor.dim() == 1:
|
if state_tensor.dim() == 1:
|
||||||
state_tensor = state_tensor.unsqueeze(0)
|
state_tensor = state_tensor.unsqueeze(0)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor)
|
q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor)
|
||||||
|
|
||||||
# Apply softmax to get action probabilities
|
# Apply softmax to get action probabilities
|
||||||
action_probs = torch.softmax(q_values, dim=1)
|
action_probs_tensor = torch.softmax(q_values, dim=1)
|
||||||
action = torch.argmax(action_probs, dim=1).item()
|
action_idx = int(torch.argmax(action_probs_tensor, dim=1).item())
|
||||||
|
confidence = float(action_probs_tensor[0, action_idx].item()) # Confidence of the chosen action
|
||||||
|
action_probs = action_probs_tensor.squeeze(0).tolist() # Convert to list of floats for return
|
||||||
|
|
||||||
# Log advanced predictions for better decision making
|
# Log advanced predictions for better decision making
|
||||||
if hasattr(self, '_log_predictions') and self._log_predictions:
|
if hasattr(self, '_log_predictions') and self._log_predictions:
|
||||||
# Log volatility prediction
|
# Log volatility prediction
|
||||||
@ -547,7 +546,7 @@ class EnhancedCNN(nn.Module):
|
|||||||
logger.info(f" Market Regime: {regime_labels[regime_class]} ({regime[regime_class]:.3f})")
|
logger.info(f" Market Regime: {regime_labels[regime_class]} ({regime[regime_class]:.3f})")
|
||||||
logger.info(f" Risk Level: {risk_labels[risk_class]} ({risk[risk_class]:.3f})")
|
logger.info(f" Risk Level: {risk_labels[risk_class]} ({risk[risk_class]:.3f})")
|
||||||
|
|
||||||
return action
|
return action_idx, confidence, action_probs
|
||||||
|
|
||||||
def save(self, path):
|
def save(self, path):
|
||||||
"""Save model weights and architecture"""
|
"""Save model weights and architecture"""
|
||||||
|
@ -33,7 +33,7 @@ from .data_provider import DataProvider
|
|||||||
from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
|
from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
|
||||||
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface, ModelRegistry
|
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface, ModelRegistry
|
||||||
from NN.models.cob_rl_model import COBRLModelInterface # Specific import for COB RL Interface
|
from NN.models.cob_rl_model import COBRLModelInterface # Specific import for COB RL Interface
|
||||||
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface # Import from new file
|
from NN.models.model_interfaces import ModelInterface as NNModelInterface, CNNModelInterface as NNCNNModelInterface, RLAgentInterface as NNRLAgentInterface, ExtremaTrainerInterface as NNExtremaTrainerInterface # Import from new file
|
||||||
from core.extrema_trainer import ExtremaTrainer # Import ExtremaTrainer for its interface
|
from core.extrema_trainer import ExtremaTrainer # Import ExtremaTrainer for its interface
|
||||||
|
|
||||||
# Import COB integration for real-time market microstructure data
|
# Import COB integration for real-time market microstructure data
|
||||||
@ -435,15 +435,21 @@ class TradingOrchestrator:
|
|||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def predict(self, data):
|
def predict(self, data=None):
|
||||||
try:
|
try:
|
||||||
# Use available methods from ExtremaTrainer
|
# ExtremaTrainer provides context features, not a direct prediction
|
||||||
if hasattr(self.model, 'detect_extrema'):
|
# We assume 'data' here is the 'symbol' string to pass to get_context_features_for_model
|
||||||
return self.model.detect_extrema(data)
|
if not isinstance(data, str):
|
||||||
elif hasattr(self.model, 'get_pivot_signals'):
|
logger.warning(f"ExtremaTrainerInterface.predict received non-string data: {type(data)}. Cannot get context features.")
|
||||||
return self.model.get_pivot_signals(data)
|
return None
|
||||||
# Return a default prediction if no methods available
|
|
||||||
return {'action': 'HOLD', 'confidence': 0.5}
|
features = self.model.get_context_features_for_model(symbol=data)
|
||||||
|
if features is not None and features.size > 0:
|
||||||
|
# The presence of features indicates a signal. We'll return a generic HOLD
|
||||||
|
# with a neutral confidence. This can be refined if ExtremaTrainer provides
|
||||||
|
# more specific BUY/SELL signals directly.
|
||||||
|
return {'action': 'HOLD', 'confidence': 0.5, 'probabilities': {'BUY': 0.33, 'SELL': 0.33, 'HOLD': 0.34}}
|
||||||
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in extrema trainer prediction: {e}")
|
logger.error(f"Error in extrema trainer prediction: {e}")
|
||||||
return None
|
return None
|
||||||
@ -1016,7 +1022,14 @@ class TradingOrchestrator:
|
|||||||
if hasattr(model.model, 'act'):
|
if hasattr(model.model, 'act'):
|
||||||
# Flatten / reshape enhanced_features as needed…
|
# Flatten / reshape enhanced_features as needed…
|
||||||
x = self._prepare_cnn_input(enhanced_features)
|
x = self._prepare_cnn_input(enhanced_features)
|
||||||
action_probs, confidence = model.model.act(x, explore=False)
|
|
||||||
|
# Debugging: Print the type and content of x before passing to act()
|
||||||
|
logger.debug(f"CNN input (x) type: {type(x)}, shape: {x.shape}, content sample: {x.flatten()[:5]}...")
|
||||||
|
|
||||||
|
action_idx, confidence, action_probs = model.model.act(x, explore=False)
|
||||||
|
|
||||||
|
# Debugging: Print the type and content of the unpacked values
|
||||||
|
logger.debug(f"CNN act() returned: action_idx={action_idx} (type={type(action_idx)}), confidence={confidence} (type={type(confidence)}), action_probs={action_probs[:5]}... (type={type(action_probs)})")
|
||||||
else:
|
else:
|
||||||
# fallback to generic predict
|
# fallback to generic predict
|
||||||
result = model.predict(enhanced_features)
|
result = model.predict(enhanced_features)
|
||||||
@ -1029,7 +1042,7 @@ class TradingOrchestrator:
|
|||||||
logger.warning(f"CNN inference failed for {symbol}@{timeframe}: {e}")
|
logger.warning(f"CNN inference failed for {symbol}@{timeframe}: {e}")
|
||||||
continue # skip this timeframe entirely
|
continue # skip this timeframe entirely
|
||||||
|
|
||||||
# 4) If we still don’t have valid probs, skip
|
# 4) If we still don't have valid probs, skip
|
||||||
if action_probs is None:
|
if action_probs is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -1175,11 +1188,15 @@ class TradingOrchestrator:
|
|||||||
action_probs = prediction_result.get('probabilities', None)
|
action_probs = prediction_result.get('probabilities', None)
|
||||||
confidence = prediction_result.get('confidence', 0.7)
|
confidence = prediction_result.get('confidence', 0.7)
|
||||||
else:
|
else:
|
||||||
# Assume it's just action probabilities
|
# Assume it's just action probabilities (e.g., a list or numpy array)
|
||||||
action_probs = prediction_result
|
action_probs = prediction_result
|
||||||
confidence = 0.7 # Default confidence
|
confidence = 0.7 # Default confidence
|
||||||
|
|
||||||
if action_probs is not None:
|
if action_probs is not None:
|
||||||
|
# Ensure action_probs is a numpy array for argmax
|
||||||
|
if not isinstance(action_probs, np.ndarray):
|
||||||
|
action_probs = np.array(action_probs)
|
||||||
|
|
||||||
action_names = ['SELL', 'HOLD', 'BUY']
|
action_names = ['SELL', 'HOLD', 'BUY']
|
||||||
best_action_idx = np.argmax(action_probs)
|
best_action_idx = np.argmax(action_probs)
|
||||||
best_action = action_names[best_action_idx]
|
best_action = action_names[best_action_idx]
|
||||||
|
Reference in New Issue
Block a user