ui state, models toggle
@@ -284,7 +284,19 @@ class TradingOrchestrator:
         self.enhanced_rl_training = enhanced_rl_training

-        # Determine the device to use (GPU if available, else CPU)
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        # Initialize device - probe CUDA and fall back to CPU on errors
+        if torch.cuda.is_available():
+            try:
+                # Test CUDA availability with a small allocation
+                test_tensor = torch.tensor([1.0]).cuda()
+                self.device = torch.device("cuda")
+                logger.info("CUDA device initialized successfully")
+            except Exception as e:
+                logger.warning(f"CUDA initialization failed: {e}, falling back to CPU")
+                self.device = torch.device("cpu")
+        else:
+            self.device = torch.device("cpu")

         logger.info(f"Using device: {self.device}")

         # Configuration - AGGRESSIVE for more training data
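Note: everything that consumes `self.device` still has to opt in; a minimal sketch of the intended usage, assuming the fusion network shown later in this commit (variable names illustrative):

    # Move the model once after the probe; .to() is a no-op when already on that device
    self.decision_fusion_network = self.decision_fusion_network.to(self.device)
    # Build tensors directly on the same device to avoid implicit CPU<->GPU copies
    features = torch.tensor(feature_list, dtype=torch.float32, device=self.device)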
@@ -389,7 +401,20 @@ class TradingOrchestrator:
         self.fusion_training_history: List[Any] = []
         self.last_fusion_inputs: Dict[str, Any] = (
             {}
-        )  # Fix: Explicitly initialize as dictionary
+        )
+
+        # Model toggle states - control which models contribute to decisions
+        self.model_toggle_states = {
+            "dqn": {"inference_enabled": True, "training_enabled": True},
+            "cnn": {"inference_enabled": True, "training_enabled": True},
+            "cob_rl": {"inference_enabled": True, "training_enabled": True},
+            "decision_fusion": {"inference_enabled": True, "training_enabled": True},
+            "transformer": {"inference_enabled": True, "training_enabled": True},
+        }
+
+        # UI state persistence
+        self.ui_state_file = "data/ui_state.json"
+        self._load_ui_state()  # Restore persisted toggle states
         self.fusion_checkpoint_frequency: int = 50  # Save every 50 decisions
         self.fusion_decisions_count: int = 0
         self.fusion_training_data: List[Any] = (
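For reference, the `data/ui_state.json` file written by `_save_ui_state()` (added below) would look roughly like this; the values shown are illustrative:

    {
        "model_toggle_states": {
            "dqn": {"inference_enabled": true, "training_enabled": true},
            "cnn": {"inference_enabled": true, "training_enabled": false},
            "cob_rl": {"inference_enabled": true, "training_enabled": true},
            "decision_fusion": {"inference_enabled": true, "training_enabled": true},
            "transformer": {"inference_enabled": true, "training_enabled": true}
        },
        "timestamp": "2025-01-01T00:00:00"
    }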
@@ -1309,6 +1334,57 @@ class TradingOrchestrator:
         else:
             logger.info("No saved orchestrator state found. Starting fresh.")

+    def _load_ui_state(self):
+        """Load UI state from file"""
+        try:
+            if os.path.exists(self.ui_state_file):
+                with open(self.ui_state_file, "r") as f:
+                    ui_state = json.load(f)
+                    if "model_toggle_states" in ui_state:
+                        self.model_toggle_states.update(ui_state["model_toggle_states"])
+                logger.info(f"UI state loaded from {self.ui_state_file}")
+        except Exception as e:
+            logger.error(f"Error loading UI state: {e}")
+
+    def _save_ui_state(self):
+        """Save UI state to file"""
+        try:
+            os.makedirs(os.path.dirname(self.ui_state_file), exist_ok=True)
+            ui_state = {
+                "model_toggle_states": self.model_toggle_states,
+                "timestamp": datetime.now().isoformat(),
+            }
+            with open(self.ui_state_file, "w") as f:
+                json.dump(ui_state, f, indent=4)
+            logger.debug(f"UI state saved to {self.ui_state_file}")
+        except Exception as e:
+            logger.error(f"Error saving UI state: {e}")
+
+    def get_model_toggle_state(self, model_name: str) -> Dict[str, bool]:
+        """Get toggle state for a model"""
+        return self.model_toggle_states.get(
+            model_name, {"inference_enabled": True, "training_enabled": True}
+        )
+
+    def set_model_toggle_state(
+        self,
+        model_name: str,
+        inference_enabled: Optional[bool] = None,
+        training_enabled: Optional[bool] = None,
+    ):
+        """Set toggle state for a model and persist it"""
+        if model_name not in self.model_toggle_states:
+            self.model_toggle_states[model_name] = {
+                "inference_enabled": True,
+                "training_enabled": True,
+            }
+
+        if inference_enabled is not None:
+            self.model_toggle_states[model_name]["inference_enabled"] = inference_enabled
+        if training_enabled is not None:
+            self.model_toggle_states[model_name]["training_enabled"] = training_enabled
+
+        self._save_ui_state()
+        state = self.model_toggle_states[model_name]
+        logger.info(
+            f"Model {model_name} toggle state updated: "
+            f"inference={state['inference_enabled']}, "
+            f"training={state['training_enabled']}"
+        )
+
+    def is_model_inference_enabled(self, model_name: str) -> bool:
+        """Check if model inference is enabled"""
+        return self.model_toggle_states.get(model_name, {}).get("inference_enabled", True)
+
+    def is_model_training_enabled(self, model_name: str) -> bool:
+        """Check if model training is enabled"""
+        return self.model_toggle_states.get(model_name, {}).get("training_enabled", True)
+
     async def start_continuous_trading(self, symbols: Optional[List[str]] = None):
         """Start the continuous trading loop, using a decision model and trading executor"""
         if symbols is None:
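A minimal usage sketch of the new toggle API, assuming an `orchestrator` instance is already constructed:

    # Turn off live inference for the CNN while leaving its training on;
    # the change is persisted to data/ui_state.json immediately
    orchestrator.set_model_toggle_state("cnn", inference_enabled=False)

    # Downstream code gates on the flags, defaulting to enabled for unknown models
    if orchestrator.is_model_inference_enabled("cnn"):
        pass  # include CNN predictions in decision making
    if orchestrator.is_model_training_enabled("cnn"):
        pass  # feed the CNN new training samples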
@@ -1846,11 +1922,14 @@ class TradingOrchestrator:
             logger.debug(f"No fallback prediction available for {symbol}")
             return None

-        # Choose decision method based on configuration
+        # Choose decision method based on configuration and toggle state
+        decision_fusion_inference_enabled = self.is_model_inference_enabled("decision_fusion")
+
         if (
             self.decision_fusion_enabled
             and self.decision_fusion_mode == "neural"
             and self.decision_fusion_network is not None
+            and decision_fusion_inference_enabled
         ):
             # Use neural decision fusion
             decision = self._make_decision_fusion_decision(
@@ -1861,6 +1940,11 @@ class TradingOrchestrator:
             )
         else:
             # Use programmatic decision combination
+            if not decision_fusion_inference_enabled:
+                logger.info(f"Decision fusion model disabled, using programmatic mode for {symbol}")
+            else:
+                logger.debug(f"Using programmatic decision combination for {symbol}")
+
             decision = self._combine_predictions(
                 symbol=symbol,
                 price=current_price,
@@ -4490,8 +4574,13 @@ class TradingOrchestrator:
         action_scores = {"BUY": 0.0, "SELL": 0.0, "HOLD": 0.0}
         total_weight = 0.0

-        # Process all predictions
+        # Process all predictions (filter out disabled models)
         for pred in predictions:
+            # Check if model inference is enabled
+            if not self.is_model_inference_enabled(pred.model_name):
+                logger.debug(f"Skipping disabled model {pred.model_name} in decision making")
+                continue
+
             # Get model weight
             model_weight = self.model_weights.get(pred.model_name, 0.1)
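The surrounding loop is a weighted vote over per-model predictions; a self-contained sketch of the mechanism (the `Prediction` fields are inferred from the attributes used here, so treat the shape as an assumption):

    from dataclasses import dataclass

    @dataclass
    class Prediction:
        model_name: str
        action: str        # "BUY" | "SELL" | "HOLD"
        confidence: float  # 0.0 - 1.0

    def combine(predictions, model_weights, is_enabled):
        action_scores = {"BUY": 0.0, "SELL": 0.0, "HOLD": 0.0}
        total_weight = 0.0
        for pred in predictions:
            if not is_enabled(pred.model_name):
                continue  # disabled models contribute nothing to the vote
            weight = model_weights.get(pred.model_name, 0.1)  # default for unknown models
            action_scores[pred.action] += weight * pred.confidence
            total_weight += weight
        # Highest weighted score wins; total_weight can normalize a combined confidence
        return max(action_scores, key=action_scores.get)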
@@ -4914,16 +5003,17 @@ class TradingOrchestrator:
                 self.fc4 = nn.Linear(hidden_size // 2, 3)  # BUY, SELL, HOLD

                 self.dropout = nn.Dropout(0.3)
-                self.batch_norm1 = nn.BatchNorm1d(hidden_size)
-                self.batch_norm2 = nn.BatchNorm1d(hidden_size)
-                self.batch_norm3 = nn.BatchNorm1d(hidden_size // 2)
+                # Use LayerNorm instead of BatchNorm1d for single-sample training compatibility
+                self.layer_norm1 = nn.LayerNorm(hidden_size)
+                self.layer_norm2 = nn.LayerNorm(hidden_size)
+                self.layer_norm3 = nn.LayerNorm(hidden_size // 2)

             def forward(self, x):
-                x = torch.relu(self.batch_norm1(self.fc1(x)))
+                x = torch.relu(self.layer_norm1(self.fc1(x)))
                 x = self.dropout(x)
-                x = torch.relu(self.batch_norm2(self.fc2(x)))
+                x = torch.relu(self.layer_norm2(self.fc2(x)))
                 x = self.dropout(x)
-                x = torch.relu(self.batch_norm3(self.fc3(x)))
+                x = torch.relu(self.layer_norm3(self.fc3(x)))
                 x = self.dropout(x)
                 return torch.softmax(self.fc4(x), dim=1)
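Why the swap matters: `nn.BatchNorm1d` in training mode needs more than one sample per channel to compute batch statistics, while `nn.LayerNorm` normalizes within each sample. A minimal repro of the difference:

    import torch
    import torch.nn as nn

    x = torch.randn(1, 64)  # single-sample batch, as in this online-training setup

    bn = nn.BatchNorm1d(64).train()
    try:
        bn(x)
    except ValueError as e:
        print(f"BatchNorm1d rejects batch size 1 in train mode: {e}")

    ln = nn.LayerNorm(64).train()
    print(ln(x).shape)  # torch.Size([1, 64]) - per-sample statistics, no batch needed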
@@ -4980,18 +5070,30 @@ class TradingOrchestrator:
         try:
             from utils.checkpoint_manager import load_best_checkpoint

             result = load_best_checkpoint("decision", "decision")
             if result:
                 file_path, metadata = result
                 self.decision_fusion_network.load(file_path)
                 self.model_states["decision"]["checkpoint_loaded"] = True
                 self.model_states["decision"][
                     "checkpoint_filename"
                 ] = metadata.checkpoint_id
                 logger.info(
                     f"Decision fusion network loaded from checkpoint: {metadata.checkpoint_id}"
                 )
             else:
+                # Try multiple checkpoint names for decision fusion
+                checkpoint_names = ["decision_fusion", "decision", "fusion"]
+                checkpoint_loaded = False
+
+                for checkpoint_name in checkpoint_names:
+                    try:
+                        result = load_best_checkpoint(checkpoint_name, checkpoint_name)
+                        if result:
+                            file_path, metadata = result
+                            self.decision_fusion_network.load(file_path)
+                            self.model_states["decision"]["checkpoint_loaded"] = True
+                            self.model_states["decision"][
+                                "checkpoint_filename"
+                            ] = metadata.checkpoint_id
+                            logger.info(
+                                f"Decision fusion network loaded from checkpoint: {metadata.checkpoint_id}"
+                            )
+                            checkpoint_loaded = True
+                            break
+                    except Exception as e:
+                        logger.debug(f"Failed to load checkpoint '{checkpoint_name}': {e}")
+                        continue
+
+                if not checkpoint_loaded:
+                    logger.info(
+                        "No existing decision fusion checkpoint found, starting fresh"
+                    )
@@ -5113,6 +5215,15 @@ class TradingOrchestrator:
             elif len(features) > expected_size:
                 features = features[:expected_size]

+            # Log input feature statistics for debugging
+            if len(features) > 0:
+                feature_array = np.array(features)
+                logger.debug(
+                    f"Decision fusion input features: size={len(features)}, "
+                    f"mean={np.mean(feature_array):.4f}, "
+                    f"std={np.std(feature_array):.4f}, "
+                    f"min={np.min(feature_array):.4f}, "
+                    f"max={np.max(feature_array):.4f}"
+                )
+
             return torch.tensor(
                 features, dtype=torch.float32, device=self.device
             ).unsqueeze(0)
@@ -5399,6 +5510,15 @@ class TradingOrchestrator:
         if input_features is None:
             logger.warning("No input features found for decision fusion training")
             return

+        # Validate input features
+        if not isinstance(input_features, torch.Tensor):
+            logger.warning(f"Invalid input features type: {type(input_features)}")
+            return
+
+        if input_features.dim() != 2 or input_features.size(0) != 1:
+            logger.warning(f"Invalid input features shape: {input_features.shape}")
+            return
+
         # Create target based on outcome
         predicted_action = record.get("action", "HOLD")
@@ -5433,24 +5553,19 @@ class TradingOrchestrator:

             optimizer.zero_grad()

-            # Temporarily disable batch normalization for single sample training
-            for module in self.decision_fusion_network.modules():
-                if isinstance(module, nn.BatchNorm1d):
-                    module.eval()  # Use running statistics instead of batch statistics
-
-            # Forward pass - handle single sample properly
+            # Forward pass - LayerNorm works with single samples
             output = self.decision_fusion_network(input_features)
             loss = criterion(output, target.unsqueeze(0))

+            # Log training details for debugging
+            logger.debug(
+                f"Decision fusion training: input_shape={input_features.shape}, "
+                f"output_shape={output.shape}, target_shape={target.unsqueeze(0).shape}, "
+                f"loss={loss.item():.4f}"
+            )
+
             # Backward pass
             loss.backward()
             optimizer.step()

-            # Re-enable batch normalization for future training
-            for module in self.decision_fusion_network.modules():
-                if isinstance(module, nn.BatchNorm1d):
-                    module.train()
-
             # Set back to eval mode for inference
             self.decision_fusion_network.eval()
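The resulting single-sample update, condensed into a standalone sketch; the network, loss, and learning rate below are stand-ins, not the orchestrator's actual configuration:

    import torch
    import torch.nn as nn

    # Stand-in for the LayerNorm-based fusion network above (shapes illustrative)
    net = nn.Sequential(
        nn.Linear(100, 64), nn.LayerNorm(64), nn.ReLU(), nn.Dropout(0.3),
        nn.Linear(64, 3), nn.Softmax(dim=1),
    )
    criterion = nn.MSELoss()  # assumption: a loss that accepts probabilities
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    features = torch.randn(1, 100)            # one sample, shape (1, input_size)
    target = torch.tensor([[1.0, 0.0, 0.0]])  # one-hot BUY target, as in this commit

    net.train()                # dropout active; LayerNorm is fine with batch size 1
    optimizer.zero_grad()
    loss = criterion(net(features), target)
    loss.backward()
    optimizer.step()
    net.eval()                 # back to eval mode so live inference skips dropout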
@@ -6241,8 +6356,8 @@ class TradingOrchestrator:
         # Track if any model was trained for checkpoint saving
         models_trained = []

-        # Train DQN agent if available
-        if self.rl_agent and hasattr(self.rl_agent, "add_experience"):
+        # Train DQN agent if available and enabled
+        if self.rl_agent and hasattr(self.rl_agent, "add_experience") and self.is_model_training_enabled("dqn"):
             try:
                 # Create state representation
                 state = self._create_state_for_training(symbol, market_data)
@@ -6271,8 +6386,8 @@ class TradingOrchestrator:
             except Exception as e:
                 logger.debug(f"Error training DQN on decision: {e}")

-        # Train CNN model if available
-        if self.cnn_model and hasattr(self.cnn_model, "add_training_sample"):
+        # Train CNN model if available and enabled
+        if self.cnn_model and hasattr(self.cnn_model, "add_training_sample") and self.is_model_training_enabled("cnn"):
             try:
                 # Create CNN input features
                 cnn_features = self._create_cnn_features_for_training(
@@ -6298,8 +6413,8 @@ class TradingOrchestrator:
             except Exception as e:
                 logger.debug(f"Error training CNN on decision: {e}")

-        # Train COB RL model if available and we have COB data
-        if self.cob_rl_agent and symbol in self.latest_cob_data:
+        # Train COB RL model if available, enabled, and we have COB data
+        if self.cob_rl_agent and symbol in self.latest_cob_data and self.is_model_training_enabled("cob_rl"):
             try:
                 cob_data = self.latest_cob_data[symbol]
                 if hasattr(self.cob_rl_agent, "add_experience"):
@@ -6322,6 +6437,33 @@ class TradingOrchestrator:
             except Exception as e:
                 logger.debug(f"Error training COB RL on decision: {e}")

+        # Train decision fusion model if available and enabled
+        if self.decision_fusion_network and self.is_model_training_enabled("decision_fusion"):
+            try:
+                # Create decision fusion input
+                fusion_input = self._create_decision_fusion_training_input(
+                    symbol, market_data
+                )
+
+                # Create target based on action
+                target_mapping = {
+                    "BUY": [1, 0, 0],
+                    "SELL": [0, 1, 0],
+                    "HOLD": [0, 0, 1],
+                }
+                target = target_mapping.get(action, [0, 0, 1])
+
+                # Add training sample
+                self.decision_fusion_network.add_training_sample(
+                    fusion_input, target, weight=confidence
+                )
+
+                models_trained.append("decision_fusion")
+                logger.debug(f"🤝 Added decision fusion training sample: {action} {symbol}")
+
+            except Exception as e:
+                logger.debug(f"Error training decision fusion on decision: {e}")
+
         # CRITICAL FIX: Save checkpoints after training
         if models_trained:
             self._save_training_checkpoints(models_trained, confidence)
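The target mapping above is plain one-hot encoding with HOLD as the fallback class; an equivalent helper, shown only for clarity (the orchestrator itself keeps the literal dict):

    import torch

    ACTIONS = ["BUY", "SELL", "HOLD"]

    def encode_action(action: str) -> torch.Tensor:
        # Unknown actions fall back to HOLD, matching target_mapping.get(action, [0, 0, 1])
        index = ACTIONS.index(action) if action in ACTIONS else 2
        return torch.nn.functional.one_hot(torch.tensor(index), num_classes=3).float()

    print(encode_action("SELL"))  # tensor([0., 1., 0.])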
@@ -6363,6 +6505,10 @@ class TradingOrchestrator:
                 model_obj = self.cob_rl_agent
                 current_loss = 1.0 - performance_score

+            elif model_name == "decision_fusion" and self.decision_fusion_network:
+                model_obj = self.decision_fusion_network
+                current_loss = 1.0 - performance_score
+
             if model_obj and current_loss is not None:
                 # Check if this is the best performance so far
                 model_state = self.model_states.get(model_name, {})
@@ -6372,11 +6518,11 @@ class TradingOrchestrator:
                 model_state["current_loss"] = current_loss
                 model_state["last_training"] = datetime.now()

-                # Save checkpoint if performance improved or periodic save
+                # Save checkpoint if performance improved or every 3rd training
                 should_save = (
                     current_loss < best_loss  # Performance improved
-                    or self.training_iterations % 100
-                    == 0  # Periodic save every 100 iterations
+                    or self.training_iterations % 3
+                    == 0  # Save every 3rd training iteration
                 )

                 if should_save:
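The new cadence in isolation, with spot checks (pure Python, no dependencies):

    def should_save(current_loss: float, best_loss: float, iteration: int) -> bool:
        # Save when the loss improves, or on every 3rd training iteration
        return current_loss < best_loss or iteration % 3 == 0

    assert should_save(0.5, 0.9, 1)      # improved loss always saves
    assert should_save(0.9, 0.5, 3)      # iteration 3 saves regardless of loss
    assert not should_save(0.9, 0.5, 4)  # no improvement and off-cadence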
@@ -6548,6 +6694,38 @@ class TradingOrchestrator:
             logger.debug(f"Error creating COB state for training: {e}")
             return np.zeros(2000)

+    def _create_decision_fusion_training_input(self, symbol: str, market_data: Dict) -> np.ndarray:
+        """Create decision fusion training input from market data"""
+        try:
+            # Extract features from market data
+            ohlcv_data = market_data.get("ohlcv", [])
+            if not ohlcv_data:
+                return np.zeros(100)  # Default state size
+
+            # Extract features from recent candles
+            features = []
+            for candle in ohlcv_data[-20:]:  # Last 20 candles
+                features.extend(
+                    [
+                        candle.get("open", 0),
+                        candle.get("high", 0),
+                        candle.get("low", 0),
+                        candle.get("close", 0),
+                        candle.get("volume", 0),
+                    ]
+                )
+
+            # Pad or truncate to expected size
+            state = np.array(features[:100])
+            if len(state) < 100:
+                state = np.pad(state, (0, 100 - len(state)), "constant")
+
+            return state
+
+        except Exception as e:
+            logger.debug(f"Error creating decision fusion input: {e}")
+            return np.zeros(100)
+
     def _check_signal_confirmation(
         self, symbol: str, signal_data: Dict
     ) -> Optional[str]:
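With 20 candles x 5 OHLCV fields the vector is exactly 100 wide, so the pad/truncate only fires for shorter histories; a quick standalone check of the layout:

    import numpy as np

    candles = [{"open": 1.0, "high": 2.0, "low": 0.5, "close": 1.5, "volume": 10.0}] * 7

    features = []
    for candle in candles[-20:]:
        features.extend([candle[k] for k in ("open", "high", "low", "close", "volume")])

    state = np.array(features[:100], dtype=float)
    if len(state) < 100:
        state = np.pad(state, (0, 100 - len(state)), "constant")  # zero-fill the tail

    print(state.shape)   # (100,)
    print(state[:5])     # first candle: [ 1.   2.   0.5  1.5 10. ]
    print(state[35:40])  # past 7 candles * 5 fields = 35 values: all zeros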
@@ -6684,175 +6862,6 @@ class TradingOrchestrator:
         except Exception as e:
             logger.error(f"Database cleanup failed: {e}")

-    def _save_training_checkpoints(
-        self, models_trained: List[str], performance_score: float
-    ):
-        """Save checkpoints for trained models if performance improved
-
-        This is CRITICAL for preserving training progress across restarts.
-        """
-        try:
-            if not self.checkpoint_manager:
-                return
-
-            # Increment training counter
-            self.training_iterations += 1
-
-            # Save checkpoints for each trained model
-            for model_name in models_trained:
-                try:
-                    model_obj = None
-                    current_loss = None
-                    model_type = model_name
-
-                    # Get model object and calculate current performance
-                    if model_name == "dqn" and self.rl_agent:
-                        model_obj = self.rl_agent
-                        # Use current loss from model state or estimate from performance
-                        current_loss = self.model_states["dqn"].get("current_loss")
-                        if current_loss is None:
-                            # Estimate loss from performance score (inverse relationship)
-                            current_loss = max(0.001, 1.0 - performance_score)
-
-                        # Update model state tracking
-                        self.model_states["dqn"]["current_loss"] = current_loss
-
-                        # If this is the first loss value, set it as initial and best
-                        if self.model_states["dqn"]["initial_loss"] is None:
-                            self.model_states["dqn"]["initial_loss"] = current_loss
-                        if (
-                            self.model_states["dqn"]["best_loss"] is None
-                            or current_loss < self.model_states["dqn"]["best_loss"]
-                        ):
-                            self.model_states["dqn"]["best_loss"] = current_loss
-
-                    elif model_name == "cnn" and self.cnn_model:
-                        model_obj = self.cnn_model
-                        # Use current loss from model state or estimate from performance
-                        current_loss = self.model_states["cnn"].get("current_loss")
-                        if current_loss is None:
-                            # Estimate loss from performance score (inverse relationship)
-                            current_loss = max(0.001, 1.0 - performance_score)
-
-                        # Update model state tracking
-                        self.model_states["cnn"]["current_loss"] = current_loss
-
-                        # If this is the first loss value, set it as initial and best
-                        if self.model_states["cnn"]["initial_loss"] is None:
-                            self.model_states["cnn"]["initial_loss"] = current_loss
-                        if (
-                            self.model_states["cnn"]["best_loss"] is None
-                            or current_loss < self.model_states["cnn"]["best_loss"]
-                        ):
-                            self.model_states["cnn"]["best_loss"] = current_loss
-
-                    elif model_name == "cob_rl" and self.cob_rl_agent:
-                        model_obj = self.cob_rl_agent
-                        # Use current loss from model state or estimate from performance
-                        current_loss = self.model_states["cob_rl"].get("current_loss")
-                        if current_loss is None:
-                            # Estimate loss from performance score (inverse relationship)
-                            current_loss = max(0.001, 1.0 - performance_score)
-
-                        # Update model state tracking
-                        self.model_states["cob_rl"]["current_loss"] = current_loss
-
-                        # If this is the first loss value, set it as initial and best
-                        if self.model_states["cob_rl"]["initial_loss"] is None:
-                            self.model_states["cob_rl"]["initial_loss"] = current_loss
-                        if (
-                            self.model_states["cob_rl"]["best_loss"] is None
-                            or current_loss < self.model_states["cob_rl"]["best_loss"]
-                        ):
-                            self.model_states["cob_rl"]["best_loss"] = current_loss
-
-                    elif (
-                        model_name == "extrema"
-                        and hasattr(self, "extrema_trainer")
-                        and self.extrema_trainer
-                    ):
-                        model_obj = self.extrema_trainer
-                        # Use current loss from model state or estimate from performance
-                        current_loss = self.model_states["extrema"].get("current_loss")
-                        if current_loss is None:
-                            # Estimate loss from performance score (inverse relationship)
-                            current_loss = max(0.001, 1.0 - performance_score)
-
-                        # Update model state tracking
-                        self.model_states["extrema"]["current_loss"] = current_loss
-
-                        # If this is the first loss value, set it as initial and best
-                        if self.model_states["extrema"]["initial_loss"] is None:
-                            self.model_states["extrema"]["initial_loss"] = current_loss
-                        if (
-                            self.model_states["extrema"]["best_loss"] is None
-                            or current_loss < self.model_states["extrema"]["best_loss"]
-                        ):
-                            self.model_states["extrema"]["best_loss"] = current_loss
-
-                    # Skip if we couldn't get a model object
-                    if model_obj is None:
-                        continue
-
-                    # Prepare performance metrics for checkpoint
-                    performance_metrics = {
-                        "loss": current_loss,
-                        "accuracy": performance_score,  # Use confidence as a proxy for accuracy
-                    }
-
-                    # Prepare training metadata
-                    training_metadata = {
-                        "training_iteration": self.training_iterations,
-                        "timestamp": datetime.now().isoformat(),
-                    }
-
-                    # Save checkpoint using checkpoint manager
-                    from utils.checkpoint_manager import save_checkpoint
-
-                    checkpoint_metadata = save_checkpoint(
-                        model=model_obj,
-                        model_name=model_name,
-                        model_type=model_type,
-                        performance_metrics=performance_metrics,
-                        training_metadata=training_metadata,
-                    )
-
-                    if checkpoint_metadata:
-                        logger.info(
-                            f"Saved checkpoint for {model_name}: {checkpoint_metadata.checkpoint_id} (loss={current_loss:.4f})"
-                        )
-
-                    # Also save periodically based on training iterations
-                    if self.training_iterations % 100 == 0:
-                        # Force save every 100 training iterations regardless of performance
-                        checkpoint_metadata = save_checkpoint(
-                            model=model_obj,
-                            model_name=model_name,
-                            model_type=model_type,
-                            performance_metrics=performance_metrics,
-                            training_metadata=training_metadata,
-                            force_save=True,
-                        )
-                        if checkpoint_metadata:
-                            logger.info(
-                                f"Periodic checkpoint saved for {model_name}: {checkpoint_metadata.checkpoint_id}"
-                            )
-
-                except Exception as e:
-                    logger.error(f"Error saving checkpoint for {model_name}: {e}")
-
-        except Exception as e:
-            logger.error(f"Error in _save_training_checkpoints: {e}")
-
-    def _schedule_database_cleanup(self):
-        """Schedule periodic database cleanup"""
-        try:
-            # Clean up old inference records (keep 30 days)
-            self.inference_logger.cleanup_old_logs(days_to_keep=30)
-            logger.info("Database cleanup completed")
-        except Exception as e:
-            logger.error(f"Database cleanup failed: {e}")
-
     def log_model_inference(
         self,
         model_name: str,