wip: model checkpoint storage/loading; models are aware of current position; fix: kill stale processing task
@@ -103,6 +103,9 @@ class BaseDataInput:
    # Market microstructure data
    market_microstructure: Dict[str, Any] = field(default_factory=dict)

    # Position and trading state information
    position_info: Dict[str, Any] = field(default_factory=dict)

    def get_feature_vector(self) -> np.ndarray:
        """
        Convert BaseDataInput to standardized feature vector for models
@@ -174,7 +177,7 @@ class BaseDataInput:
        features.extend(indicator_values[:100])  # Take first 100 indicators
        features.extend([0.0] * max(0, 100 - len(indicator_values)))  # Pad to exactly 100

        # Last predictions from other models (FIXED SIZE: 50 features)
        # Last predictions from other models (FIXED SIZE: 45 features)
        prediction_features = []
        for model_output in self.last_predictions.values():
            prediction_features.extend([
@@ -184,8 +187,18 @@ class BaseDataInput:
                model_output.predictions.get('hold_probability', 0.0),
                model_output.predictions.get('expected_reward', 0.0)
            ])
        features.extend(prediction_features[:50])  # Take first 50 prediction features
        features.extend([0.0] * max(0, 50 - len(prediction_features)))  # Pad to exactly 50
        features.extend(prediction_features[:45])  # Take first 45 prediction features
        features.extend([0.0] * max(0, 45 - len(prediction_features)))  # Pad to exactly 45

        # Position and trading state information (FIXED SIZE: 5 features)
        position_features = [
            1.0 if self.position_info.get('has_position', False) else 0.0,
            self.position_info.get('position_pnl', 0.0),
            self.position_info.get('position_size', 0.0),
            self.position_info.get('entry_price', 0.0),
            self.position_info.get('time_in_position_minutes', 0.0)
        ]
        features.extend(position_features)  # Exactly 5 position features

        # CRITICAL: Ensure EXACTLY the fixed feature size
        if len(features) > FIXED_FEATURE_SIZE:
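Illustrative sketch (not part of the diff): each block of the feature vector above is truncated or zero-padded to a fixed width, so the concatenated result always has the same length no matter how many indicators or model predictions are available. A minimal, self-contained version of that slice-and-pad idiom, with hypothetical block sizes mirroring the layout (100 indicator values, 45 prediction features, 5 position features):

import numpy as np

def pad_block(values, size):
    """Truncate or zero-pad an iterable of floats to exactly `size` entries."""
    block = list(values)[:size]
    block.extend([0.0] * (size - len(block)))
    return block

features = []
features.extend(pad_block([0.1, 0.2, 0.3], 100))                # indicator block
features.extend(pad_block([0.6, 0.1, 0.3, 0.05], 45))           # prediction block
features.extend(pad_block([1.0, 12.5, 0.01, 3050.0, 42.0], 5))  # position block
vector = np.asarray(features, dtype=np.float32)
assert vector.shape == (150,)  # 100 + 45 + 5 in this sketch; the real FIXED_FEATURE_SIZE differs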
@@ -806,38 +806,45 @@ class TradingOrchestrator:
        if hasattr(self.cob_rl_agent, "to"):
            self.cob_rl_agent.to(self.device)

        # Load best checkpoint and capture initial state (using database metadata)
        # Load best checkpoint and capture initial state (using checkpoint manager)
        checkpoint_loaded = False
        if hasattr(self.cob_rl_agent, "load_model"):
            try:
                self.cob_rl_agent.load_model()  # This loads the state into the model
                db_manager = get_database_manager()
                checkpoint_metadata = db_manager.get_best_checkpoint_metadata(
                    "cob_rl"
        try:
            from utils.checkpoint_manager import load_best_checkpoint

            # Try to load checkpoint using checkpoint manager
            result = load_best_checkpoint("cob_rl")
            if result:
                file_path, metadata = result
                # Load the checkpoint into the model
                checkpoint = torch.load(file_path, map_location=self.device)

                # Load model state
                if 'model_state_dict' in checkpoint:
                    self.cob_rl_agent.model.load_state_dict(checkpoint['model_state_dict'])
                if 'optimizer_state_dict' in checkpoint and hasattr(self.cob_rl_agent, 'optimizer'):
                    self.cob_rl_agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

                # Update model states
                self.model_states["cob_rl"]["initial_loss"] = (
                    metadata.performance_metrics.get("loss", 0.0)
                )
                if checkpoint_metadata:
                    self.model_states["cob_rl"]["initial_loss"] = (
                        checkpoint_metadata.training_metadata.get(
                            "initial_loss", None
                        )
                    )
                    self.model_states["cob_rl"]["current_loss"] = (
                        checkpoint_metadata.performance_metrics.get("loss", 0.0)
                    )
                    self.model_states["cob_rl"]["best_loss"] = (
                        checkpoint_metadata.performance_metrics.get("loss", 0.0)
                    )
                    self.model_states["cob_rl"]["checkpoint_loaded"] = True
                    self.model_states["cob_rl"][
                        "checkpoint_filename"
                    ] = checkpoint_metadata.checkpoint_id
                    checkpoint_loaded = True
                    loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
                    logger.info(
                        f"COB RL checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})"
                    )
            except Exception as e:
                logger.warning(f"Error loading COB RL checkpoint: {e}")
                self.model_states["cob_rl"]["current_loss"] = (
                    metadata.performance_metrics.get("loss", 0.0)
                )
                self.model_states["cob_rl"]["best_loss"] = (
                    metadata.performance_metrics.get("loss", 0.0)
                )
                self.model_states["cob_rl"]["checkpoint_loaded"] = True
                self.model_states["cob_rl"][
                    "checkpoint_filename"
                ] = metadata.checkpoint_id
                checkpoint_loaded = True
                loss_str = f"{metadata.performance_metrics.get('loss', 0.0):.4f}"
                logger.info(
                    f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})"
                )
        except Exception as e:
            logger.warning(f"Error loading COB RL checkpoint: {e}")

        if not checkpoint_loaded:
            self.model_states["cob_rl"]["initial_loss"] = None
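Illustrative sketch (not part of the diff): the new loading path asks the checkpoint manager for the best checkpoint, torch.loads the file, restores the model and optimizer state dicts, and records the metadata in model_states. A condensed restatement of that flow, assuming load_best_checkpoint returns a (file_path, metadata) tuple as the diff implies; the helper name and update style below are hypothetical:

import torch

def restore_best_checkpoint(agent, model_states, device, load_best_checkpoint, name="cob_rl"):
    """Best-effort restore of the best saved checkpoint into an agent; returns True on success."""
    try:
        result = load_best_checkpoint(name)
        if not result:
            return False
        file_path, metadata = result
        checkpoint = torch.load(file_path, map_location=device)
        if "model_state_dict" in checkpoint:
            agent.model.load_state_dict(checkpoint["model_state_dict"])
        if "optimizer_state_dict" in checkpoint and hasattr(agent, "optimizer"):
            agent.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        loss = metadata.performance_metrics.get("loss", 0.0)
        model_states[name].update(
            initial_loss=loss,
            current_loss=loss,
            best_loss=loss,
            checkpoint_loaded=True,
            checkpoint_filename=metadata.checkpoint_id,
        )
        return True
    except Exception:
        return False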
@@ -1020,7 +1027,63 @@ class TradingOrchestrator:
        except Exception as e:
            logger.error(f"Failed to register COB RL Agent: {e}")

        # Decision model will be initialized elsewhere if needed
        # Register Decision Fusion Model
        if hasattr(self, 'decision_fusion_network') and self.decision_fusion_network:
            try:
                class DecisionFusionModelInterface(ModelInterface):
                    def __init__(self, model, name: str):
                        super().__init__(name)
                        self.model = model

                    def predict(self, data):
                        try:
                            if hasattr(self.model, "forward"):
                                # Convert data to tensor if needed
                                if isinstance(data, np.ndarray):
                                    data = torch.from_numpy(data).float()
                                elif not isinstance(data, torch.Tensor):
                                    logger.warning(f"Decision fusion received unexpected data type: {type(data)}")
                                    return None

                                # Ensure data has correct shape
                                if data.dim() == 1:
                                    data = data.unsqueeze(0)  # Add batch dimension

                                with torch.no_grad():
                                    self.model.eval()
                                    output = self.model(data)
                                    probabilities = output.squeeze().cpu().numpy()

                                # Convert to action prediction
                                action_idx = np.argmax(probabilities)
                                actions = ["BUY", "SELL", "HOLD"]
                                action = actions[action_idx]
                                confidence = float(probabilities[action_idx])

                                return {
                                    "action": action,
                                    "confidence": confidence,
                                    "probabilities": {
                                        "BUY": float(probabilities[0]),
                                        "SELL": float(probabilities[1]),
                                        "HOLD": float(probabilities[2])
                                    }
                                }
                            return None
                        except Exception as e:
                            logger.error(f"Error in Decision Fusion prediction: {e}")
                            return None

                    def get_memory_usage(self) -> float:
                        return 25.0  # MB

                decision_fusion_interface = DecisionFusionModelInterface(
                    self.decision_fusion_network, name="decision_fusion"
                )
                self.register_model(decision_fusion_interface, weight=0.3)
                logger.info("Decision Fusion Model registered successfully")
            except Exception as e:
                logger.error(f"Failed to register Decision Fusion Model: {e}")

        # Normalize weights after all registrations
        self._normalize_weights()
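Illustrative note (not part of the diff): the fusion model is registered with weight=0.3 and _normalize_weights() is then called. Assuming that helper simply rescales the registered weights to sum to 1.0 (an assumption, not code from this repo), the effect looks like this:

def normalize_weights(model_weights: dict) -> dict:
    """Rescale model weights in place so they sum to 1.0; leave them unchanged if the total is zero."""
    total = sum(model_weights.values())
    if total > 0:
        for name in model_weights:
            model_weights[name] /= total
    return model_weights

weights = {"dqn": 0.5, "cnn": 0.4, "decision_fusion": 0.3}  # hypothetical registry
print(normalize_weights(weights))  # decision_fusion ends up at 0.3 / 1.2 = 0.25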
@@ -3204,6 +3267,8 @@ class TradingOrchestrator:
                price_change_pct,
                time_diff_minutes,
                inference_price is not None,  # Add price prediction flag
                symbol,  # Pass symbol for position lookup
                None,  # Let method determine position status
            )

            # Update model performance tracking
@@ -3309,15 +3374,21 @@ class TradingOrchestrator:
        price_change_pct: float,
        time_diff_minutes: float,
        has_price_prediction: bool = False,
        symbol: str = None,
        has_position: bool = None,
    ) -> tuple[float, bool]:
        """
        Calculate sophisticated reward based on prediction accuracy, confidence, and price movement magnitude
        Now considers position status when evaluating HOLD decisions

        Args:
            predicted_action: The predicted action ('BUY', 'SELL', 'HOLD')
            prediction_confidence: Model's confidence in the prediction (0.0 to 1.0)
            price_change_pct: Actual price change percentage
            time_diff_minutes: Time elapsed since prediction
            has_price_prediction: Whether the model made a price prediction
            symbol: Trading symbol (for position lookup)
            has_position: Whether we currently have a position (if None, will be looked up)

        Returns:
            tuple: (reward, was_correct)
@@ -3326,6 +3397,12 @@ class TradingOrchestrator:
            # Base thresholds for determining correctness
            movement_threshold = 0.1  # 0.1% minimum movement to consider significant

            # Determine current position status if not provided
            if has_position is None and symbol:
                has_position = self._has_open_position(symbol)
            elif has_position is None:
                has_position = False

            # Determine if prediction was directionally correct
            was_correct = False
            directional_accuracy = 0.0
@@ -3341,10 +3418,25 @@ class TradingOrchestrator:
                    0, -price_change_pct
                )  # Positive for downward movement
            elif predicted_action == "HOLD":
                was_correct = abs(price_change_pct) < movement_threshold
                directional_accuracy = max(
                    0, movement_threshold - abs(price_change_pct)
                )  # Positive for stability
                # HOLD evaluation now considers position status
                if has_position:
                    # If we have a position, HOLD is correct if price moved favorably or stayed stable
                    # This prevents penalizing HOLD when we're already in a profitable position
                    if price_change_pct > 0:  # Price went up while holding - good
                        was_correct = True
                        directional_accuracy = price_change_pct  # Reward based on profit
                    elif abs(price_change_pct) < movement_threshold:  # Price stable - neutral
                        was_correct = True
                        directional_accuracy = movement_threshold - abs(price_change_pct)
                    else:  # Price dropped while holding - bad, but less penalty than wrong direction
                        was_correct = False
                        directional_accuracy = max(0, movement_threshold - abs(price_change_pct)) * 0.5
                else:
                    # If we don't have a position, HOLD is correct if price stayed relatively stable
                    was_correct = abs(price_change_pct) < movement_threshold
                    directional_accuracy = max(
                        0, movement_threshold - abs(price_change_pct)
                    )  # Positive for stability

            # Calculate magnitude-based multiplier (higher rewards for larger correct movements)
            magnitude_multiplier = min(
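Illustrative sketch (not part of the diff): the position-aware HOLD branch above can be read as a small pure function. The constants mirror the diff (0.1% movement threshold, a halved accuracy term when a held position moves against us); the function itself is only a restatement for clarity:

def evaluate_hold(price_change_pct: float, has_position: bool, movement_threshold: float = 0.1):
    """Return (was_correct, directional_accuracy) for a HOLD prediction."""
    if has_position:
        if price_change_pct > 0:                        # price rose while holding
            return True, price_change_pct
        if abs(price_change_pct) < movement_threshold:  # price roughly flat
            return True, movement_threshold - abs(price_change_pct)
        # price dropped while holding: incorrect, but with a softened accuracy term
        return False, max(0.0, movement_threshold - abs(price_change_pct)) * 0.5
    # without a position, only a flat market makes HOLD correct
    was_correct = abs(price_change_pct) < movement_threshold
    return was_correct, max(0.0, movement_threshold - abs(price_change_pct))

print(evaluate_hold(0.25, has_position=True))   # (True, 0.25)
print(evaluate_hold(-0.3, has_position=True))   # (False, 0.0)
print(evaluate_hold(0.05, has_position=False))  # roughly (True, 0.05)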
@@ -3404,12 +3496,19 @@ class TradingOrchestrator:

        except Exception as e:
            logger.error(f"Error calculating sophisticated reward: {e}")
            # Fallback to simple reward
            simple_correct = (
                (predicted_action == "BUY" and price_change_pct > 0.1)
                or (predicted_action == "SELL" and price_change_pct < -0.1)
                or (predicted_action == "HOLD" and abs(price_change_pct) < 0.1)
            )
            # Fallback to simple reward with position awareness
            has_position = self._has_open_position(symbol) if symbol else False

            if predicted_action == "HOLD" and has_position:
                # If holding a position, HOLD is correct if price didn't drop significantly
                simple_correct = price_change_pct > -0.2  # Allow small losses while holding
            else:
                # Standard evaluation for other cases
                simple_correct = (
                    (predicted_action == "BUY" and price_change_pct > 0.1)
                    or (predicted_action == "SELL" and price_change_pct < -0.1)
                    or (predicted_action == "HOLD" and abs(price_change_pct) < 0.1)
                )
            return (1.0 if simple_correct else -0.5, simple_correct)

    async def _train_model_on_outcome(
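Worked example (illustrative, using the thresholds from the diff): in the fallback path a HOLD with an open position is judged only on whether the price avoided a significant drop, while a HOLD without a position still requires a near-flat market:

def fallback_hold_correct(price_change_pct: float, has_position: bool) -> bool:
    """Fallback correctness check for HOLD, mirroring the thresholds above."""
    if has_position:
        return price_change_pct > -0.2   # tolerate a small drawdown while holding
    return abs(price_change_pct) < 0.1   # otherwise require a roughly flat market

print(fallback_hold_correct(-0.1, True))   # True  -> reward 1.0
print(fallback_hold_correct(-0.1, False))  # False -> reward -0.5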
@@ -5225,30 +5324,37 @@ class TradingOrchestrator:
            try:
                from utils.checkpoint_manager import load_best_checkpoint

                # Try multiple checkpoint names for decision fusion
                checkpoint_names = ["decision_fusion", "decision", "fusion"]
                checkpoint_loaded = False

                for checkpoint_name in checkpoint_names:
                    try:
                        result = load_best_checkpoint(checkpoint_name, checkpoint_name)
                        if result:
                            file_path, metadata = result
                            self.decision_fusion_network.load(file_path)
                            self.model_states["decision"]["checkpoint_loaded"] = True
                            self.model_states["decision"][
                                "checkpoint_filename"
                            ] = metadata.checkpoint_id
                            logger.info(
                                f"Decision fusion network loaded from checkpoint: {metadata.checkpoint_id}"
                            )
                            checkpoint_loaded = True
                            break
                    except Exception as e:
                        logger.debug(f"Failed to load checkpoint '{checkpoint_name}': {e}")
                        continue

                if not checkpoint_loaded:
                # Try to load decision fusion checkpoint
                result = load_best_checkpoint("decision_fusion")
                if result:
                    file_path, metadata = result
                    # Load the checkpoint into the network
                    checkpoint = torch.load(file_path, map_location=self.device)

                    # Load model state
                    if 'model_state_dict' in checkpoint:
                        self.decision_fusion_network.load_state_dict(checkpoint['model_state_dict'])

                    # Update model states
                    self.model_states["decision"]["initial_loss"] = (
                        metadata.performance_metrics.get("loss", 0.0)
                    )
                    self.model_states["decision"]["current_loss"] = (
                        metadata.performance_metrics.get("loss", 0.0)
                    )
                    self.model_states["decision"]["best_loss"] = (
                        metadata.performance_metrics.get("loss", 0.0)
                    )
                    self.model_states["decision"]["checkpoint_loaded"] = True
                    self.model_states["decision"][
                        "checkpoint_filename"
                    ] = metadata.checkpoint_id

                    loss_str = f"{metadata.performance_metrics.get('loss', 0.0):.4f}"
                    logger.info(
                        f"Decision fusion network loaded from checkpoint: {metadata.checkpoint_id} (loss={loss_str})"
                    )
                else:
                    logger.info(
                        "No existing decision fusion checkpoint found, starting fresh"
                    )
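Illustrative note (not part of the diff): the removed branch walked a list of candidate checkpoint names and fell back to the next one on failure, whereas the replacement loads the single canonical "decision_fusion" name directly. The retired fallback pattern, in miniature (the loader is passed in as a parameter, so no specific API is assumed):

def load_first_available(load_best_checkpoint, names=("decision_fusion", "decision", "fusion")):
    """Return (file_path, metadata) for the first name that yields a checkpoint, else None."""
    for name in names:
        try:
            result = load_best_checkpoint(name)
            if result:
                return result
        except Exception:
            continue  # try the next candidate name
    return None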
@@ -7416,11 +7522,45 @@ class TradingOrchestrator:
            symbol: Trading symbol

        Returns:
            BaseDataInput with consistent data structure
            BaseDataInput with consistent data structure and position information
        """
        try:
            # Use data provider's optimized build_base_data_input method
            return self.data_provider.build_base_data_input(symbol)
            base_data = self.data_provider.build_base_data_input(symbol)

            if base_data:
                # Add position information to the base data
                current_price = self.data_provider.get_current_price(symbol)
                has_position = self._has_open_position(symbol)
                position_pnl = self._get_current_position_pnl(symbol, current_price) if current_price else 0.0

                # Get additional position details if available
                position_size = 0.0
                entry_price = 0.0
                time_in_position_minutes = 0.0

                if has_position and self.trading_executor and hasattr(self.trading_executor, "get_current_position"):
                    try:
                        position = self.trading_executor.get_current_position(symbol)
                        if position:
                            position_size = position.get("size", 0.0)
                            entry_price = position.get("price", 0.0)
                            entry_time = position.get("entry_time")
                            if entry_time:
                                time_in_position_minutes = (datetime.now() - entry_time).total_seconds() / 60.0
                    except Exception as e:
                        logger.debug(f"Error getting position details for {symbol}: {e}")

                # Add position information to base data
                base_data.position_info = {
                    'has_position': has_position,
                    'position_pnl': position_pnl,
                    'position_size': position_size,
                    'entry_price': entry_price,
                    'time_in_position_minutes': time_in_position_minutes
                }

            return base_data

        except Exception as e:
            logger.error(f"Error building BaseDataInput for {symbol}: {e}")
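Illustrative sketch (not part of the diff): the position_info dict assembled here is exactly what get_feature_vector in BaseDataInput turns into the five fixed position features. A minimal check of that hand-off, using a trimmed stand-in for BaseDataInput (the dataclass below is hypothetical; the field names follow the diff):

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class PositionOnlyInput:  # trimmed stand-in, not the real BaseDataInput
    position_info: Dict[str, Any] = field(default_factory=dict)

    def position_features(self):
        return [
            1.0 if self.position_info.get('has_position', False) else 0.0,
            self.position_info.get('position_pnl', 0.0),
            self.position_info.get('position_size', 0.0),
            self.position_info.get('entry_price', 0.0),
            self.position_info.get('time_in_position_minutes', 0.0),
        ]

base_data = PositionOnlyInput()
base_data.position_info = {
    'has_position': True,
    'position_pnl': 12.5,
    'position_size': 0.01,
    'entry_price': 3050.0,
    'time_in_position_minutes': 42.0,
}
print(base_data.position_features())  # [1.0, 12.5, 0.01, 3050.0, 42.0]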