gogo2/utils/tensorboard_logger.py
#!/usr/bin/env python3
"""
TensorBoard Logger Utility

This module provides a centralized way to log training metrics to TensorBoard.
It ensures consistent logging across different training components.
"""

import os
import logging
from datetime import datetime
from typing import Dict, Any, Optional

# Import conditionally to handle missing dependencies gracefully
try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_AVAILABLE = True
except ImportError:
    TENSORBOARD_AVAILABLE = False

logger = logging.getLogger(__name__)


class TensorBoardLogger:
    """
    Centralized TensorBoard logging utility for training metrics

    This class provides a consistent interface for logging metrics to TensorBoard
    across different training components.
    """

    def __init__(self,
                 log_dir: Optional[str] = None,
                 experiment_name: Optional[str] = None,
                 enabled: bool = True):
        """
        Initialize TensorBoard logger

        Args:
            log_dir: Base directory for TensorBoard logs (default: 'runs')
            experiment_name: Name of the experiment (default: timestamp)
            enabled: Whether TensorBoard logging is enabled
        """
        self.enabled = enabled and TENSORBOARD_AVAILABLE
        self.writer = None

        if not self.enabled:
            if not TENSORBOARD_AVAILABLE:
                logger.warning("TensorBoard not available. Install with: pip install tensorboard")
            return

        # Set up log directory
        if log_dir is None:
            log_dir = "runs"

        # Create experiment name if not provided
        if experiment_name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            experiment_name = f"training_{timestamp}"

        # Create full log path
        self.log_dir = os.path.join(log_dir, experiment_name)

        # Create writer
        try:
            self.writer = SummaryWriter(log_dir=self.log_dir)
            logger.info(f"TensorBoard logging enabled at: {self.log_dir}")
        except Exception as e:
            logger.error(f"Failed to initialize TensorBoard: {e}")
            self.enabled = False
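
    # Example (hypothetical path): with the defaults above, the writer lands at
    # something like "runs/training_20250722_154459", so a single
    # `tensorboard --logdir runs` picks up every experiment under the same base.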

    def log_scalar(self, tag: str, value: float, step: int) -> None:
        """
        Log a scalar value to TensorBoard

        Args:
            tag: Metric name
            value: Metric value
            step: Training step
        """
        if not self.enabled or self.writer is None:
            return
        try:
            self.writer.add_scalar(tag, value, step)
        except Exception as e:
            logger.warning(f"Failed to log scalar {tag}: {e}")

    def log_scalars(self, main_tag: str, tag_value_dict: Dict[str, float], step: int) -> None:
        """
        Log multiple scalar values with the same main tag

        Args:
            main_tag: Main tag for the metrics
            tag_value_dict: Dictionary of tag names to values
            step: Training step
        """
        if not self.enabled or self.writer is None:
            return
        try:
            self.writer.add_scalars(main_tag, tag_value_dict, step)
        except Exception as e:
            logger.warning(f"Failed to log scalars for {main_tag}: {e}")

    def log_histogram(self, tag: str, values, step: int) -> None:
        """
        Log a histogram to TensorBoard

        Args:
            tag: Histogram name
            values: Values to create histogram from
            step: Training step
        """
        if not self.enabled or self.writer is None:
            return
        try:
            self.writer.add_histogram(tag, values, step)
        except Exception as e:
            logger.warning(f"Failed to log histogram {tag}: {e}")

    def log_training_metrics(self,
                             metrics: Dict[str, Any],
                             step: int,
                             prefix: str = "Training") -> None:
        """
        Log training metrics to TensorBoard

        Args:
            metrics: Dictionary of metric names to values
            step: Training step
            prefix: Prefix for metric names
        """
        if not self.enabled or self.writer is None:
            return
        for name, value in metrics.items():
            if isinstance(value, (int, float)):
                self.log_scalar(f"{prefix}/{name}", value, step)
            elif hasattr(value, "shape"):  # For numpy arrays or tensors
                # log_histogram already catches and logs its own failures,
                # so no bare `except: pass` is needed here
                self.log_histogram(f"{prefix}/{name}", value, step)
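
    # Example (hypothetical metric names): plain numbers go to add_scalar,
    # anything with a .shape (arrays/tensors) goes to add_histogram.
    #   tb.log_training_metrics({"loss": 0.42, "epsilon": 0.1, "q_values": q_batch}, step=100)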

    def log_model_metrics(self,
                          model_name: str,
                          metrics: Dict[str, Any],
                          step: int) -> None:
        """
        Log model-specific metrics to TensorBoard

        Args:
            model_name: Name of the model
            metrics: Dictionary of metric names to values
            step: Training step
        """
        if not self.enabled or self.writer is None:
            return
        for name, value in metrics.items():
            if isinstance(value, (int, float)):
                self.log_scalar(f"Model/{model_name}/{name}", value, step)

    def log_reward_metrics(self,
                           symbol: str,
                           metrics: Dict[str, float],
                           step: int) -> None:
        """
        Log reward-related metrics to TensorBoard

        Args:
            symbol: Trading symbol
            metrics: Dictionary of metric names to values
            step: Training step
        """
        if not self.enabled or self.writer is None:
            return
        for name, value in metrics.items():
            self.log_scalar(f"Rewards/{symbol}/{name}", value, step)

    def log_state_metrics(self,
                          symbol: str,
                          state_info: Dict[str, Any],
                          step: int) -> None:
        """
        Log state-related metrics to TensorBoard

        Args:
            symbol: Trading symbol
            state_info: Dictionary of state information
            step: Training step
        """
        if not self.enabled or self.writer is None:
            return
        # Log state size
        if "size" in state_info:
            self.log_scalar(f"State/{symbol}/Size", state_info["size"], step)
        # Log state quality
        if "quality" in state_info:
            self.log_scalar(f"State/{symbol}/Quality", state_info["quality"], step)
        # Log feature counts
        if "feature_counts" in state_info:
            for feature_type, count in state_info["feature_counts"].items():
                self.log_scalar(f"State/{symbol}/Features/{feature_type}", count, step)

    def close(self) -> None:
        """Close the TensorBoard writer"""
        if self.enabled and self.writer is not None:
            try:
                self.writer.close()
                logger.info("TensorBoard writer closed")
            except Exception as e:
                logger.warning(f"Error closing TensorBoard writer: {e}")