WIP: model checkpoint (CP) storage/loading

models are aware of current position
fix the kill-stale-process task
This commit is contained in:
Dobromir Popov
2025-07-29 14:51:40 +03:00
parent f34b2a46a2
commit afde58bc40
7 changed files with 472 additions and 82 deletions

View File

@ -103,6 +103,9 @@ class BaseDataInput:
# Market microstructure data
market_microstructure: Dict[str, Any] = field(default_factory=dict)
# Position and trading state information
position_info: Dict[str, Any] = field(default_factory=dict)
def get_feature_vector(self) -> np.ndarray:
"""
Convert BaseDataInput to standardized feature vector for models
@ -174,7 +177,7 @@ class BaseDataInput:
features.extend(indicator_values[:100]) # Take first 100 indicators
features.extend([0.0] * max(0, 100 - len(indicator_values))) # Pad to exactly 100
# Last predictions from other models (FIXED SIZE: 50 features)
# Last predictions from other models (FIXED SIZE: 45 features)
prediction_features = []
for model_output in self.last_predictions.values():
prediction_features.extend([
@ -184,8 +187,18 @@ class BaseDataInput:
model_output.predictions.get('hold_probability', 0.0),
model_output.predictions.get('expected_reward', 0.0)
])
features.extend(prediction_features[:50]) # Take first 50 prediction features
features.extend([0.0] * max(0, 50 - len(prediction_features))) # Pad to exactly 50
features.extend(prediction_features[:45]) # Take first 45 prediction features
features.extend([0.0] * max(0, 45 - len(prediction_features))) # Pad to exactly 45
# Position and trading state information (FIXED SIZE: 5 features)
position_features = [
1.0 if self.position_info.get('has_position', False) else 0.0,
self.position_info.get('position_pnl', 0.0),
self.position_info.get('position_size', 0.0),
self.position_info.get('entry_price', 0.0),
self.position_info.get('time_in_position_minutes', 0.0)
]
features.extend(position_features) # Exactly 5 position features
# CRITICAL: Ensure EXACTLY the fixed feature size
if len(features) > FIXED_FEATURE_SIZE:

View File

@ -806,38 +806,45 @@ class TradingOrchestrator:
if hasattr(self.cob_rl_agent, "to"):
self.cob_rl_agent.to(self.device)
# Load best checkpoint and capture initial state (using database metadata)
# Load best checkpoint and capture initial state (using checkpoint manager)
checkpoint_loaded = False
if hasattr(self.cob_rl_agent, "load_model"):
try:
self.cob_rl_agent.load_model() # This loads the state into the model
db_manager = get_database_manager()
checkpoint_metadata = db_manager.get_best_checkpoint_metadata(
"cob_rl"
try:
from utils.checkpoint_manager import load_best_checkpoint
# Try to load checkpoint using checkpoint manager
result = load_best_checkpoint("cob_rl")
if result:
file_path, metadata = result
# Load the checkpoint into the model
checkpoint = torch.load(file_path, map_location=self.device)
# Load model state
if 'model_state_dict' in checkpoint:
self.cob_rl_agent.model.load_state_dict(checkpoint['model_state_dict'])
if 'optimizer_state_dict' in checkpoint and hasattr(self.cob_rl_agent, 'optimizer'):
self.cob_rl_agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Update model states
self.model_states["cob_rl"]["initial_loss"] = (
metadata.performance_metrics.get("loss", 0.0)
)
if checkpoint_metadata:
self.model_states["cob_rl"]["initial_loss"] = (
checkpoint_metadata.training_metadata.get(
"initial_loss", None
)
)
self.model_states["cob_rl"]["current_loss"] = (
checkpoint_metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["cob_rl"]["best_loss"] = (
checkpoint_metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["cob_rl"]["checkpoint_loaded"] = True
self.model_states["cob_rl"][
"checkpoint_filename"
] = checkpoint_metadata.checkpoint_id
checkpoint_loaded = True
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
logger.info(
f"COB RL checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})"
)
except Exception as e:
logger.warning(f"Error loading COB RL checkpoint: {e}")
self.model_states["cob_rl"]["current_loss"] = (
metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["cob_rl"]["best_loss"] = (
metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["cob_rl"]["checkpoint_loaded"] = True
self.model_states["cob_rl"][
"checkpoint_filename"
] = metadata.checkpoint_id
checkpoint_loaded = True
loss_str = f"{metadata.performance_metrics.get('loss', 0.0):.4f}"
logger.info(
f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})"
)
except Exception as e:
logger.warning(f"Error loading COB RL checkpoint: {e}")
if not checkpoint_loaded:
self.model_states["cob_rl"]["initial_loss"] = None
@ -1020,7 +1027,63 @@ class TradingOrchestrator:
except Exception as e:
logger.error(f"Failed to register COB RL Agent: {e}")
# Decision model will be initialized elsewhere if needed
# Register Decision Fusion Model
if hasattr(self, 'decision_fusion_network') and self.decision_fusion_network:
try:
class DecisionFusionModelInterface(ModelInterface):
    """Adapter that exposes a decision fusion network through ModelInterface.

    Wraps a torch module and turns its raw output into a BUY/SELL/HOLD
    prediction dictionary.
    """

    def __init__(self, model, name: str):
        """Store the wrapped network and register the interface name.

        Args:
            model: The decision fusion network; ``predict`` only uses it
                when it exposes a ``forward`` attribute.
            name: Name passed on to ``ModelInterface.__init__``.
        """
        super().__init__(name)
        self.model = model

    def predict(self, data):
        """Run the network on ``data`` and return an action prediction.

        Args:
            data: A ``np.ndarray`` or ``torch.Tensor`` of input features.
                A 1-D input gets a batch dimension prepended.

        Returns:
            A dict with ``"action"``, ``"confidence"`` and
            ``"probabilities"`` keys, or ``None`` when the model has no
            ``forward``, the input type is unsupported, or any error
            occurs during inference (the error is logged).
        """
        try:
            if not hasattr(self.model, "forward"):
                return None

            # Convert data to tensor if needed
            if isinstance(data, np.ndarray):
                data = torch.from_numpy(data).float()
            elif not isinstance(data, torch.Tensor):
                logger.warning(f"Decision fusion received unexpected data type: {type(data)}")
                return None

            # Ensure data has correct shape
            if data.dim() == 1:
                data = data.unsqueeze(0)  # Add batch dimension

            # A tensor built from numpy lives on the CPU; move the input to
            # wherever the network's weights are so inference does not fail
            # with a device mismatch when the model runs on an accelerator.
            try:
                target_device = next(self.model.parameters()).device
            except (AttributeError, StopIteration, TypeError):
                target_device = None  # Nothing to inspect; leave data as is
            if target_device is not None:
                data = data.to(target_device)

            with torch.no_grad():
                self.model.eval()
                output = self.model(data)
                # NOTE(review): assumes the network emits exactly three
                # values ordered BUY, SELL, HOLD for a single sample, and
                # that they are already probabilities - confirm against
                # the network definition.
                probabilities = output.squeeze().cpu().numpy()

            # Convert to action prediction
            action_idx = np.argmax(probabilities)
            actions = ["BUY", "SELL", "HOLD"]
            action = actions[action_idx]
            confidence = float(probabilities[action_idx])

            return {
                "action": action,
                "confidence": confidence,
                "probabilities": {
                    "BUY": float(probabilities[0]),
                    "SELL": float(probabilities[1]),
                    "HOLD": float(probabilities[2])
                }
            }
        except Exception as e:
            # Best-effort prediction: log and signal "no prediction"
            # instead of propagating into the caller.
            logger.error(f"Error in Decision Fusion prediction: {e}")
            return None

    def get_memory_usage(self) -> float:
        """Return a fixed memory-usage figure for this model, in MB."""
        return 25.0  # MB
decision_fusion_interface = DecisionFusionModelInterface(
self.decision_fusion_network, name="decision_fusion"
)
self.register_model(decision_fusion_interface, weight=0.3)
logger.info("Decision Fusion Model registered successfully")
except Exception as e:
logger.error(f"Failed to register Decision Fusion Model: {e}")
# Normalize weights after all registrations
self._normalize_weights()
@ -3204,6 +3267,8 @@ class TradingOrchestrator:
price_change_pct,
time_diff_minutes,
inference_price is not None, # Add price prediction flag
symbol, # Pass symbol for position lookup
None, # Let method determine position status
)
# Update model performance tracking
@ -3309,15 +3374,21 @@ class TradingOrchestrator:
price_change_pct: float,
time_diff_minutes: float,
has_price_prediction: bool = False,
symbol: str = None,
has_position: bool = None,
) -> tuple[float, bool]:
"""
Calculate sophisticated reward based on prediction accuracy, confidence, and price movement magnitude
Now considers position status when evaluating HOLD decisions
Args:
predicted_action: The predicted action ('BUY', 'SELL', 'HOLD')
prediction_confidence: Model's confidence in the prediction (0.0 to 1.0)
price_change_pct: Actual price change percentage
time_diff_minutes: Time elapsed since prediction
has_price_prediction: Whether the model made a price prediction
symbol: Trading symbol (for position lookup)
has_position: Whether we currently have a position (if None, will be looked up)
Returns:
tuple: (reward, was_correct)
@ -3326,6 +3397,12 @@ class TradingOrchestrator:
# Base thresholds for determining correctness
movement_threshold = 0.1 # 0.1% minimum movement to consider significant
# Determine current position status if not provided
if has_position is None and symbol:
has_position = self._has_open_position(symbol)
elif has_position is None:
has_position = False
# Determine if prediction was directionally correct
was_correct = False
directional_accuracy = 0.0
@ -3341,10 +3418,25 @@ class TradingOrchestrator:
0, -price_change_pct
) # Positive for downward movement
elif predicted_action == "HOLD":
was_correct = abs(price_change_pct) < movement_threshold
directional_accuracy = max(
0, movement_threshold - abs(price_change_pct)
) # Positive for stability
# HOLD evaluation now considers position status
if has_position:
# If we have a position, HOLD is correct if price moved favorably or stayed stable
# This prevents penalizing HOLD when we're already in a profitable position
if price_change_pct > 0: # Price went up while holding - good
was_correct = True
directional_accuracy = price_change_pct # Reward based on profit
elif abs(price_change_pct) < movement_threshold: # Price stable - neutral
was_correct = True
directional_accuracy = movement_threshold - abs(price_change_pct)
else: # Price dropped while holding - bad, but less penalty than wrong direction
was_correct = False
directional_accuracy = max(0, movement_threshold - abs(price_change_pct)) * 0.5
else:
# If we don't have a position, HOLD is correct if price stayed relatively stable
was_correct = abs(price_change_pct) < movement_threshold
directional_accuracy = max(
0, movement_threshold - abs(price_change_pct)
) # Positive for stability
# Calculate magnitude-based multiplier (higher rewards for larger correct movements)
magnitude_multiplier = min(
@ -3404,12 +3496,19 @@ class TradingOrchestrator:
except Exception as e:
logger.error(f"Error calculating sophisticated reward: {e}")
# Fallback to simple reward
simple_correct = (
(predicted_action == "BUY" and price_change_pct > 0.1)
or (predicted_action == "SELL" and price_change_pct < -0.1)
or (predicted_action == "HOLD" and abs(price_change_pct) < 0.1)
)
# Fallback to simple reward with position awareness
has_position = self._has_open_position(symbol) if symbol else False
if predicted_action == "HOLD" and has_position:
# If holding a position, HOLD is correct if price didn't drop significantly
simple_correct = price_change_pct > -0.2 # Allow small losses while holding
else:
# Standard evaluation for other cases
simple_correct = (
(predicted_action == "BUY" and price_change_pct > 0.1)
or (predicted_action == "SELL" and price_change_pct < -0.1)
or (predicted_action == "HOLD" and abs(price_change_pct) < 0.1)
)
return (1.0 if simple_correct else -0.5, simple_correct)
async def _train_model_on_outcome(
@ -5225,30 +5324,37 @@ class TradingOrchestrator:
try:
from utils.checkpoint_manager import load_best_checkpoint
# Try multiple checkpoint names for decision fusion
checkpoint_names = ["decision_fusion", "decision", "fusion"]
checkpoint_loaded = False
for checkpoint_name in checkpoint_names:
try:
result = load_best_checkpoint(checkpoint_name, checkpoint_name)
if result:
file_path, metadata = result
self.decision_fusion_network.load(file_path)
self.model_states["decision"]["checkpoint_loaded"] = True
self.model_states["decision"][
"checkpoint_filename"
] = metadata.checkpoint_id
logger.info(
f"Decision fusion network loaded from checkpoint: {metadata.checkpoint_id}"
)
checkpoint_loaded = True
break
except Exception as e:
logger.debug(f"Failed to load checkpoint '{checkpoint_name}': {e}")
continue
if not checkpoint_loaded:
# Try to load decision fusion checkpoint
result = load_best_checkpoint("decision_fusion")
if result:
file_path, metadata = result
# Load the checkpoint into the network
checkpoint = torch.load(file_path, map_location=self.device)
# Load model state
if 'model_state_dict' in checkpoint:
self.decision_fusion_network.load_state_dict(checkpoint['model_state_dict'])
# Update model states
self.model_states["decision"]["initial_loss"] = (
metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["decision"]["current_loss"] = (
metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["decision"]["best_loss"] = (
metadata.performance_metrics.get("loss", 0.0)
)
self.model_states["decision"]["checkpoint_loaded"] = True
self.model_states["decision"][
"checkpoint_filename"
] = metadata.checkpoint_id
loss_str = f"{metadata.performance_metrics.get('loss', 0.0):.4f}"
logger.info(
f"Decision fusion network loaded from checkpoint: {metadata.checkpoint_id} (loss={loss_str})"
)
else:
logger.info(
"No existing decision fusion checkpoint found, starting fresh"
)
@ -7416,11 +7522,45 @@ class TradingOrchestrator:
symbol: Trading symbol
Returns:
BaseDataInput with consistent data structure
BaseDataInput with consistent data structure and position information
"""
try:
# Use data provider's optimized build_base_data_input method
return self.data_provider.build_base_data_input(symbol)
base_data = self.data_provider.build_base_data_input(symbol)
if base_data:
# Add position information to the base data
current_price = self.data_provider.get_current_price(symbol)
has_position = self._has_open_position(symbol)
position_pnl = self._get_current_position_pnl(symbol, current_price) if current_price else 0.0
# Get additional position details if available
position_size = 0.0
entry_price = 0.0
time_in_position_minutes = 0.0
if has_position and self.trading_executor and hasattr(self.trading_executor, "get_current_position"):
try:
position = self.trading_executor.get_current_position(symbol)
if position:
position_size = position.get("size", 0.0)
entry_price = position.get("price", 0.0)
entry_time = position.get("entry_time")
if entry_time:
time_in_position_minutes = (datetime.now() - entry_time).total_seconds() / 60.0
except Exception as e:
logger.debug(f"Error getting position details for {symbol}: {e}")
# Add position information to base data
base_data.position_info = {
'has_position': has_position,
'position_pnl': position_pnl,
'position_size': position_size,
'entry_price': entry_price,
'time_in_position_minutes': time_in_position_minutes
}
return base_data
except Exception as e:
logger.error(f"Error building BaseDataInput for {symbol}: {e}")