#!/usr/bin/env python3
"""
TensorBoard Logger Utility

This module provides a centralized way to log training metrics to TensorBoard.
It ensures consistent logging across different training components.
"""

import os
import logging
from pathlib import Path
from datetime import datetime
from typing import Dict, Any, Optional, Union, List

# Import conditionally to handle missing dependencies gracefully
try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_AVAILABLE = True
except ImportError:
    TENSORBOARD_AVAILABLE = False

logger = logging.getLogger(__name__)

class TensorBoardLogger:
    """
    Centralized TensorBoard logging utility for training metrics.

    Provides a consistent interface for logging scalars, grouped scalars, and
    histograms to TensorBoard across different training components.  Every
    logging method degrades to a silent no-op when TensorBoard is unavailable
    or logging is disabled, so callers never need to guard their calls.

    The logger may also be used as a context manager, which closes the
    underlying writer on exit.
    """

    def __init__(self,
                 log_dir: Optional[str] = None,
                 experiment_name: Optional[str] = None,
                 enabled: bool = True):
        """
        Initialize TensorBoard logger.

        Args:
            log_dir: Base directory for TensorBoard logs (default: 'runs')
            experiment_name: Name of the experiment (default: 'training_<timestamp>')
            enabled: Whether TensorBoard logging is enabled
        """
        # Logging is only truly enabled when the optional dependency imported.
        self.enabled = enabled and TENSORBOARD_AVAILABLE
        self.writer = None

        if not self.enabled:
            if not TENSORBOARD_AVAILABLE:
                logger.warning("TensorBoard not available. Install with: pip install tensorboard")
            return

        # Set up log directory
        if log_dir is None:
            log_dir = "runs"

        # Create experiment name if not provided
        if experiment_name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            experiment_name = f"training_{timestamp}"

        # Full log path; only set when logging is enabled (matches prior behavior).
        self.log_dir = os.path.join(log_dir, experiment_name)

        # Create writer; on failure, disable logging rather than crash training.
        try:
            self.writer = SummaryWriter(log_dir=self.log_dir)
            logger.info(f"TensorBoard logging enabled at: {self.log_dir}")
        except Exception as e:
            logger.error(f"Failed to initialize TensorBoard: {e}")
            self.enabled = False

    def _can_log(self) -> bool:
        """Return True when a usable writer exists; all log methods no-op otherwise."""
        return self.enabled and self.writer is not None

    def log_scalar(self, tag: str, value: float, step: int) -> None:
        """
        Log a scalar value to TensorBoard.

        Args:
            tag: Metric name
            value: Metric value
            step: Training step
        """
        if not self._can_log():
            return

        try:
            self.writer.add_scalar(tag, value, step)
        except Exception as e:
            # Logging must never interrupt training; warn and continue.
            logger.warning(f"Failed to log scalar {tag}: {e}")

    def log_scalars(self, main_tag: str, tag_value_dict: Dict[str, float], step: int) -> None:
        """
        Log multiple scalar values with the same main tag.

        Args:
            main_tag: Main tag for the metrics
            tag_value_dict: Dictionary of tag names to values
            step: Training step
        """
        if not self._can_log():
            return

        try:
            self.writer.add_scalars(main_tag, tag_value_dict, step)
        except Exception as e:
            logger.warning(f"Failed to log scalars for {main_tag}: {e}")

    def log_histogram(self, tag: str, values, step: int) -> None:
        """
        Log a histogram to TensorBoard.

        Args:
            tag: Histogram name
            values: Values to create histogram from (numpy array or tensor)
            step: Training step
        """
        if not self._can_log():
            return

        try:
            self.writer.add_histogram(tag, values, step)
        except Exception as e:
            logger.warning(f"Failed to log histogram {tag}: {e}")

    def log_training_metrics(self,
                             metrics: Dict[str, Any],
                             step: int,
                             prefix: str = "Training") -> None:
        """
        Log training metrics to TensorBoard.

        Numeric values are logged as scalars; array-like values (anything with
        a ``shape`` attribute, e.g. numpy arrays or tensors) are logged as
        histograms.  Other value types are skipped.

        Args:
            metrics: Dictionary of metric names to values
            step: Training step
            prefix: Prefix for metric names
        """
        if not self._can_log():
            return

        for name, value in metrics.items():
            if isinstance(value, (int, float)):
                self.log_scalar(f"{prefix}/{name}", value, step)
            elif hasattr(value, "shape"):  # numpy arrays or tensors
                # log_histogram already catches and warns on failure, so the
                # original bare `except: pass` (which even swallowed
                # KeyboardInterrupt) is unnecessary here.
                self.log_histogram(f"{prefix}/{name}", value, step)

    def log_model_metrics(self,
                          model_name: str,
                          metrics: Dict[str, Any],
                          step: int) -> None:
        """
        Log model-specific metrics to TensorBoard under ``Model/<name>/...``.

        Args:
            model_name: Name of the model
            metrics: Dictionary of metric names to values (non-numeric skipped)
            step: Training step
        """
        if not self._can_log():
            return

        for name, value in metrics.items():
            if isinstance(value, (int, float)):
                self.log_scalar(f"Model/{model_name}/{name}", value, step)

    def log_reward_metrics(self,
                           symbol: str,
                           metrics: Dict[str, float],
                           step: int) -> None:
        """
        Log reward-related metrics to TensorBoard under ``Rewards/<symbol>/...``.

        Args:
            symbol: Trading symbol
            metrics: Dictionary of metric names to values
            step: Training step
        """
        if not self._can_log():
            return

        for name, value in metrics.items():
            self.log_scalar(f"Rewards/{symbol}/{name}", value, step)

    def log_state_metrics(self,
                          symbol: str,
                          state_info: Dict[str, Any],
                          step: int) -> None:
        """
        Log state-related metrics to TensorBoard under ``State/<symbol>/...``.

        Recognized keys: ``size`` and ``quality`` (scalars), and
        ``feature_counts`` (a dict of feature-type -> count).  Other keys are
        ignored.

        Args:
            symbol: Trading symbol
            state_info: Dictionary of state information
            step: Training step
        """
        if not self._can_log():
            return

        # Log state size
        if "size" in state_info:
            self.log_scalar(f"State/{symbol}/Size", state_info["size"], step)

        # Log state quality
        if "quality" in state_info:
            self.log_scalar(f"State/{symbol}/Quality", state_info["quality"], step)

        # Log feature counts
        if "feature_counts" in state_info:
            for feature_type, count in state_info["feature_counts"].items():
                self.log_scalar(f"State/{symbol}/Features/{feature_type}", count, step)

    def flush(self) -> None:
        """Flush pending events to disk without closing the writer."""
        if not self._can_log():
            return

        try:
            self.writer.flush()
        except Exception as e:
            logger.warning(f"Error flushing TensorBoard writer: {e}")

    def close(self) -> None:
        """Close the TensorBoard writer"""
        if self.enabled and self.writer is not None:
            try:
                self.writer.close()
                logger.info("TensorBoard writer closed")
            except Exception as e:
                logger.warning(f"Error closing TensorBoard writer: {e}")

    def __enter__(self) -> "TensorBoardLogger":
        """Support ``with TensorBoardLogger(...) as tb:`` usage."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the writer when the context exits."""
        self.close()