"""Online latent space visualization callback with dimensionality reduction.
This callback learns a 2D projection of high-dimensional features while preserving
neighborhood structure using a contrastive loss between high-D and low-D similarities.
"""
from functools import partial
from typing import Dict, Literal, Optional, Union
import numpy as np
import torch
from hydra.utils import instantiate
from lightning.pytorch import LightningModule, Trainer
from loguru import logger as logging
from torch import Tensor
from ..utils.distance_metrics import compute_pairwise_distances_chunked
from .queue import find_or_create_queue_callback
from .utils import TrainableCallback
class LatentViz(TrainableCallback):
"""Online latent visualization callback with neighborhood-preserving dimensionality reduction.
This callback learns a 2D projection that preserves neighborhood structure from
high-dimensional features. It uses a contrastive loss that attracts neighbors
and repels non-neighbors in the 2D space.

    The loss function is:

        L = -∑_{ij} P_{ij} log Q_{ij} - ∑_{i, j ∈ Neg(i)} log(1 - Q_{ij})

    where:
    - P_{ij} is the high-D neighborhood graph (based on k-NN)
    - Q_{ij} = (1 + ||z_i - z_j||^2)^(-1) is the Student-t similarity in the learned 2D space
    - Neg(i) is the set of negative samples for point i

Args:
name: Unique identifier for this callback instance.
input: Key in batch dict containing input features to visualize.
target: Optional key in batch dict containing labels for coloring plots.
If None, points will be plotted without color coding.
projection: The projection module to train (maps high-D to 2D). Can be:
- nn.Module instance
- callable that returns a module
- Hydra config to instantiate
queue_length: Size of the circular buffer for features.
k_neighbors: Number of nearest neighbors for building P matrix.
n_negatives: Number of negative samples per positive pair.
        optimizer: Optimizer configuration. If None, defaults to AdamW
            (lr=1e-3, weight_decay=1e-2), which works well for dimensionality-reduction heads.
scheduler: Learning rate scheduler configuration. If None, uses ConstantLR.
accumulate_grad_batches: Number of batches to accumulate gradients.
update_interval: Update projection network every N training batches (default: 10).
warmup_epochs: Number of epochs to wait before starting projection training (default: 0).
Allows main model to stabilize before learning 2D projections.
distance_metric: Metric for computing distances in high-D space.
plot_interval: Interval (in epochs) for plotting 2D visualization.
save_dir: Optional directory to save plots. If None, saves to 'latent_viz_{name}'.
input_dim: Expected dimensionality of input features (for queue).
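
    Example:
        A minimal, illustrative setup (the batch keys 'embedding' and 'label' and
        the small MLP head below are assumptions; adapt them to your pipeline)::

            import torch
            from lightning.pytorch import Trainer

            # Hypothetical projection head mapping 512-D features to 2-D
            projection_head = torch.nn.Sequential(
                torch.nn.Linear(512, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 2),
            )
            viz = LatentViz(
                name="latent_viz",
                input="embedding",   # key your model logs its features under
                target="label",      # optional labels used only for coloring
                projection=projection_head,
                queue_length=2048,
                k_neighbors=15,
                n_negatives=5,
                warmup_epochs=5,
                plot_interval=10,
                input_dim=512,
            )
            trainer = Trainer(callbacks=[viz], max_epochs=100)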
"""
def __init__(
self,
name: str,
input: str,
target: Optional[str],
projection: torch.nn.Module,
queue_length: int = 2048,
k_neighbors: int = 15,
n_negatives: int = 5,
optimizer: Optional[Union[str, dict, partial, torch.optim.Optimizer]] = None,
scheduler: Optional[
Union[str, dict, partial, torch.optim.lr_scheduler.LRScheduler]
] = None,
accumulate_grad_batches: int = 1,
update_interval: int = 10,
warmup_epochs: int = 0,
distance_metric: Literal["euclidean", "cosine"] = "euclidean",
plot_interval: int = 10,
save_dir: Optional[str] = None,
input_dim: Optional[Union[int, tuple, list]] = None,
):
super().__init__(
name=name,
optimizer=optimizer,
scheduler=scheduler,
accumulate_grad_batches=accumulate_grad_batches,
)
self.input = input
self.target = target
self.queue_length = queue_length
self.k_neighbors = k_neighbors
self.n_negatives = n_negatives
self.update_interval = update_interval
self.warmup_epochs = warmup_epochs
self.distance_metric = distance_metric
self.plot_interval = plot_interval
self.save_dir = save_dir
        if input_dim is not None and isinstance(input_dim, (list, tuple)):
            input_dim = int(np.prod(input_dim))
self.input_dim = input_dim
self._projection_config = projection
# Will be initialized in setup
self._input_queue = None
self._target_queue = None
logging.info(f"Initialized LatentViz callback: {name}")
logging.info(f" - Input: {input}")
logging.info(
f" - Target: {target if target else 'None (no labels for coloring)'}"
)
logging.info(f" - Queue length: {queue_length}")
logging.info(f" - K neighbors: {k_neighbors}")
logging.info(f" - Negative samples: {n_negatives}")
logging.info(f" - Update interval: {update_interval} batches")
logging.info(f" - Warmup epochs: {warmup_epochs}")
logging.info(f" - Accumulate grad batches: {accumulate_grad_batches}")
def _initialize_module(self, pl_module: LightningModule) -> torch.nn.Module:
"""Initialize the projection module from configuration."""
if isinstance(self._projection_config, torch.nn.Module):
projection_module = self._projection_config
elif callable(self._projection_config):
projection_module = self._projection_config()
else:
projection_module = instantiate(self._projection_config, _convert_="object")
return projection_module
def setup_optimizer(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Initialize optimizer - default to AdamW for dimensionality reduction tasks."""
if self._optimizer_config is None:
# Use AdamW by default for LatentViz (better weight decay handling)
logging.info(
f"{self.name}: Using default AdamW optimizer for dimensionality reduction"
)
self.optimizer = torch.optim.AdamW(
self.module.parameters(),
lr=1e-3, # Good default for AdamW
weight_decay=1e-2, # Higher weight decay works well with AdamW
betas=(0.9, 0.999), # Standard Adam betas
)
else:
# Use explicitly provided optimizer config
from stable_pretraining.optim.utils import create_optimizer
self.optimizer = create_optimizer(
self.module.parameters(), self._optimizer_config
)
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
"""Setup module, optimizer, scheduler, and queues."""
super().setup(trainer, pl_module, stage)
if stage != "fit":
return
# Find or create queues (same as knn.py)
self._input_queue = find_or_create_queue_callback(
trainer,
self.input,
self.queue_length,
self.input_dim,
torch.float32 if self.input_dim is not None else None,
gather_distributed=True,
create_if_missing=True,
)
logging.info(f"{self.name}: Using queue for input '{self.input}'")
# Only create target queue if target is specified
if self.target is not None:
self._target_queue = find_or_create_queue_callback(
trainer,
self.target,
self.queue_length,
None, # No specific dimension for targets
torch.long,
gather_distributed=True,
create_if_missing=True,
)
logging.info(f"{self.name}: Using queue for target '{self.target}'")
def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: Dict,
batch: Dict,
batch_idx: int,
) -> None:
"""Perform projection network training step."""
# Skip training during warmup period
if trainer.current_epoch < self.warmup_epochs:
if batch_idx == 0: # Log once per epoch
logging.info(
f"{self.name}: Warmup period - skipping projection training "
f"(epoch {trainer.current_epoch + 1}/{self.warmup_epochs})"
)
return
# Only update every N batches to reduce computational overhead
if batch_idx % self.update_interval != 0:
return
# Get cached features directly from the shared queue
# Access the raw queue from the class-level registry
from .queue import OnlineQueue
shared_queue = OnlineQueue._shared_queues.get(self.input)
if shared_queue is None:
return
cached_features = shared_queue.get()
if cached_features is None or len(cached_features) == 0:
return
self.module.train()
with torch.enable_grad():
# Detach features to prevent gradients flowing to main model
x = cached_features.detach()
proj_dtype = next(self.module.parameters()).dtype
if x.dtype != proj_dtype:
x = x.to(proj_dtype)
z_2d = self.module(x)
loss = self._compute_loss(x, z_2d)
loss = loss / self.accumulate_grad_batches
loss.backward()
loss_value = loss.item() * self.accumulate_grad_batches
pl_module.log(
f"train/{self.name}_loss",
loss_value,
on_step=True,
on_epoch=True,
prog_bar=True,
sync_dist=True,
)
self.optimizer_step(batch_idx, trainer)
def _compute_loss(
self,
x_high: Tensor,
z_2d: Tensor,
) -> Tensor:
"""Compute the neighborhood-preserving loss.
Loss = -∑_{ij} P_{ij} log Q_{ij} + ∑_{i,j ∈ Neg(i)} log(1 - Q_{ij})
Args:
x_high: High-dimensional features [N, D]
z_2d: 2D projections [N, 2]
"""
n_samples = x_high.size(0)
device = x_high.device
chunk_size = 256 if n_samples > 1000 else -1
high_d_distances = compute_pairwise_distances_chunked(
x_high, x_high, metric=self.distance_metric, chunk_size=chunk_size
)
k_actual = min(self.k_neighbors, n_samples - 1) # Exclude self
high_d_distances.fill_diagonal_(float("inf")) # Exclude self
_, nn_indices = high_d_distances.topk(k=k_actual, dim=1, largest=False)
# Compute 2D similarities (Q matrix) using Student-t kernel - chunked for memory efficiency
# Student-t kernel: q_ij = (1 + ||z_i - z_j||^2)^(-1)
z_distances_sq = compute_pairwise_distances_chunked(
z_2d, z_2d, metric="squared_euclidean", chunk_size=chunk_size
)
q_matrix = 1.0 / (1.0 + z_distances_sq)
# Set diagonal to 0 (no self-similarity)
mask = torch.ones_like(q_matrix).detach()
mask.fill_diagonal_(0)
q_matrix = q_matrix * mask
        # Rescale as Q / (1 + Q) so similarities stay strictly below 1,
        # which keeps log(1 - Q) in the repulsion term finite
        q_matrix = q_matrix / (q_matrix + 1)
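        # For intuition: coincident points (d^2 = 0) give Q = 1, rescaled to 0.5;
        # a pair at d^2 = 9 gives Q = 0.1, rescaled to ~0.09.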
# Compute attraction loss for positive pairs (neighbors) - vectorized
row_indices = (
torch.arange(n_samples, device=device).unsqueeze(1).expand(-1, k_actual)
)
q_neighbors = q_matrix[row_indices, nn_indices]
attraction_loss = -(q_neighbors + 1e-10).log().mean()
# Compute repulsion loss for negative pairs - uniform sampling
n_negatives_per_point = self.n_negatives * k_actual
neg_indices = torch.randint(
0, n_samples, (n_samples, n_negatives_per_point), device=device
)
row_indices_neg = (
torch.arange(n_samples, device=device)
.unsqueeze(1)
.expand(-1, n_negatives_per_point)
)
q_negatives = q_matrix[row_indices_neg, neg_indices]
repulsion_loss = -((1 - q_negatives).clamp(min=1e-10).log()).mean()
total_loss = attraction_loss + repulsion_loss
return total_loss
def on_validation_epoch_end(
self, trainer: Trainer, pl_module: LightningModule
) -> None:
"""Plot 2D visualization at specified intervals."""
# Skip visualization during warmup period
if trainer.current_epoch < self.warmup_epochs:
logging.info(
f"{self.name}: Warmup period - skipping visualization "
f"(epoch {trainer.current_epoch + 1}/{self.warmup_epochs})"
)
return
# Plot visualization at intervals
if trainer.current_epoch % self.plot_interval != 0:
return
# Get cached features
cached_features = self._input_queue.data
if cached_features is None or cached_features.numel() == 0:
return
# Get cached labels if available
cached_labels = None
if self._target_queue is not None:
cached_labels = self._target_queue.data
if cached_labels is not None and cached_labels.numel() == 0:
cached_labels = None
# Project to 2D
self.module.eval()
with torch.no_grad():
# Ensure correct dtype
proj_dtype = next(self.module.parameters()).dtype
if cached_features.dtype != proj_dtype:
cached_features = cached_features.to(proj_dtype)
z_2d = self.module(cached_features)
# Create visualization
self._plot_2d_embeddings(z_2d, cached_labels, trainer.current_epoch, trainer)
def _plot_2d_embeddings(
self, z_2d: Tensor, labels: Optional[Tensor], epoch: int, trainer: Trainer
) -> None:
"""Save 2D embeddings to file and log to experiment tracker."""
import os
# Save coordinates to NPZ file
z_2d_np = z_2d.cpu().numpy()
labels_np = labels.cpu().numpy() if labels is not None else None
if self.save_dir is not None:
save_dir = self.save_dir
else:
save_dir = f"latent_viz_{self.name}"
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, f"epoch_{epoch:04d}.npz")
save_data = {"coordinates": z_2d_np}
if labels_np is not None:
save_data["labels"] = labels_np
np.savez_compressed(save_path, **save_data)
logging.info(f"{self.name}: Saved 2D coordinates to {save_path}")
# Log to experiment tracker if available
try:
from lightning.pytorch.loggers import WandbLogger
if isinstance(trainer.logger, WandbLogger):
import wandb
# Create WandB-specific table only (no direct scatter logging)
if labels_np is not None:
data = np.column_stack([z_2d_np, labels_np.astype(int)])
columns = ["x", "y", "class"]
else:
data = z_2d_np
columns = ["x", "y"]
table = wandb.Table(columns=columns, data=data.tolist())
# Log table - will overwrite previous epoch's table
wandb.log(
{
f"{self.name}/2d_latent_table": table,
f"{self.name}/current_epoch": epoch,
}
)
logging.info(
f"{self.name}: Logged latent table to experiment tracker at epoch {epoch}"
)
except ImportError:
logging.debug(
f"{self.name}: WandB not installed, skipping visualization logging"
)
except Exception as e:
logging.error(f"{self.name}: Failed to log visualization: {e}")
@property
def projection_module(self):
"""Alias for self.module for backward compatibility."""
return self.module