Source code for stable_pretraining.utils.distance_metrics
"""Distance metric functions for computing pairwise distances between tensors."""
from typing import Literal
import torch
[docs]
def compute_pairwise_distances(
x: torch.Tensor,
y: torch.Tensor,
metric: Literal[
"euclidean", "squared_euclidean", "cosine", "manhattan"
] = "euclidean",
) -> torch.Tensor:
"""Compute pairwise distances between two sets of vectors.
Args:
x: Tensor of shape (n, d) containing n vectors of dimension d
y: Tensor of shape (m, d) containing m vectors of dimension d
metric: Distance metric to use. Options:
- "euclidean": L2 distance
- "squared_euclidean": Squared L2 distance
- "cosine": Cosine distance (1 - cosine_similarity)
- "manhattan": L1 distance
Returns:
Distance matrix of shape (n, m) where element (i, j) is the distance
between x[i] and y[j]
"""
if metric == "euclidean":
return torch.cdist(x, y, p=2)
elif metric == "squared_euclidean":
return torch.cdist(x, y, p=2).pow(2)
elif metric == "cosine":
# Normalize vectors to unit length
x_norm = torch.nn.functional.normalize(x, p=2, dim=1)
y_norm = torch.nn.functional.normalize(y, p=2, dim=1)
# Cosine similarity = dot product of normalized vectors
cosine_sim = torch.mm(x_norm, y_norm.t())
# Cosine distance = 1 - cosine similarity
return 1 - cosine_sim
elif metric == "manhattan":
return torch.cdist(x, y, p=1)
raise ValueError(
f"Unknown metric: {metric}. Must be one of: euclidean, squared_euclidean, cosine, manhattan"
)
[docs]
def compute_pairwise_distances_chunked(
x: torch.Tensor,
y: torch.Tensor,
metric: Literal[
"euclidean", "squared_euclidean", "cosine", "manhattan"
] = "euclidean",
chunk_size: int = 1024,
) -> torch.Tensor:
"""Memory-efficient computation of pairwise distances using chunking.
Args:
x: Tensor of shape (n, d) containing n vectors of dimension d
y: Tensor of shape (m, d) containing m vectors of dimension d
metric: Distance metric to use
chunk_size: Process y in chunks of this size to save memory
Returns:
Distance matrix of shape (n, m)
"""
m = y.shape[0]
# If chunk_size is -1 or larger than m, process all at once
if chunk_size <= 0 or chunk_size >= m:
return compute_pairwise_distances(x, y, metric)
# Build distance matrix by concatenating chunks
chunks = []
# Process y in chunks
for i in range(0, m, chunk_size):
end_idx = min(i + chunk_size, m)
y_chunk = y[i:end_idx]
chunk_distances = compute_pairwise_distances(x, y_chunk, metric)
chunks.append(chunk_distances)
# Concatenate chunks along the second dimension
distances = torch.cat(chunks, dim=1)
return distances