"""Modular tensor aggregation module for feeding multi-scale/multi-layer features to MLPs.
Commonly used for:
- SSL linear probes using multiple transformer layers
- Multi-scale feature fusion
- Combining features from different network stages
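
A minimal usage sketch (shapes are illustrative):

    >>> aggregator = TensorAggregator("mean")
    >>> features = aggregator(torch.randn(8, 768, 7, 7))  # -> (8, 768), ready for an MLP head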
"""
from typing import Union, List, Dict, Optional, Literal
import torch
import torch.nn as nn
from loguru import logger
AggregationMode = Literal["mean", "max", "cls", "flatten", "adaptive"]
TensorInput = Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]]
class TensorAggregator(nn.Module):
"""Aggregates multi-dimensional tensors into 2D format for MLP input.
Pure aggregation module with NO trainable parameters.
Handles various input formats and aggregation strategies.
Args:
input_spec: Specification of input format and aggregation modes:
- str: Single aggregation mode for all tensors (e.g., "mean")
- List[str]: Per-tensor aggregation modes for list inputs
- Dict[str, str]: Per-key aggregation modes for dict inputs
adaptive_pool_size: Output size for adaptive pooling (default: 1)
Aggregation Modes:
- "mean": Spatial/temporal mean pooling
- "max": Spatial/temporal max pooling
- "cls": Take first token (for transformers with [CLS] token)
- "flatten": Flatten all dimensions after batch
- "adaptive": Adaptive average pooling to fixed size
Examples:
>>> # Single tensor with mean pooling
>>> agg = TensorAggregator("mean")
>>> x = torch.randn(4, 768, 14, 14)
>>> out = agg(x) # Shape: (4, 768)
>>> # SSL: Last 4 transformer layers with CLS token
>>> agg = TensorAggregator(["cls", "cls", "cls", "cls"])
>>> layers = [torch.randn(4, 197, 768) for _ in range(4)]
>>> out = agg(layers) # Shape: (4, 3072) # 768 * 4
>>> # Multi-scale features
>>> agg = TensorAggregator({"layer1": "cls", "layer2": "mean", "conv": "mean"})
>>> out = agg(
... {
... "layer1": torch.randn(4, 197, 768),
... "layer2": torch.randn(4, 197, 768),
... "conv": torch.randn(4, 512, 14, 14),
... }
... ) # Shape: (4, 2048)
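        >>> # Adaptive pooling keeps a small spatial grid per channel (here 2x2; shapes illustrative)
        >>> agg = TensorAggregator("adaptive", adaptive_pool_size=2)
        >>> out = agg(torch.randn(4, 512, 14, 14))  # Shape: (4, 2048)  # 512 * 2 * 2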
"""
def __init__(
self,
input_spec: Union[str, List[str], Dict[str, str]],
adaptive_pool_size: int = 1,
):
super().__init__()
self.input_spec = input_spec
self.adaptive_pool_size = adaptive_pool_size
# Determine input type
if isinstance(input_spec, str):
self.input_type = "single"
self.agg_modes = {"default": input_spec}
elif isinstance(input_spec, list):
self.input_type = "list"
self.agg_modes = {i: mode for i, mode in enumerate(input_spec)}
elif isinstance(input_spec, dict):
self.input_type = "dict"
self.agg_modes = input_spec
else:
raise ValueError(f"Invalid input_spec type: {type(input_spec)}")
# Validate aggregation modes
valid_modes = {"mean", "max", "cls", "flatten", "adaptive"}
for mode in self.agg_modes.values():
if mode not in valid_modes:
raise ValueError(
f"Invalid aggregation mode: {mode}. Valid modes: {valid_modes}"
)
logger.info(f"Initialized TensorAggregator with {self.input_type} input")
logger.debug(f"Aggregation modes: {self.agg_modes}")
def _aggregate_single_tensor(
self, x: torch.Tensor, mode: str, key: Optional[Union[str, int]] = None
) -> torch.Tensor:
"""Aggregate a single tensor to 2D based on aggregation mode.
Args:
x: Input tensor of shape (B, ..., D) or (B, D, H, W)
mode: Aggregation mode
key: Optional key for logging
Returns:
2D tensor of shape (B, features)
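        Example (illustrative shape):
            >>> agg = TensorAggregator("max")
            >>> agg._aggregate_single_tensor(torch.randn(2, 512, 7, 7), "max").shape
            torch.Size([2, 512])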
"""
batch_size = x.shape[0]
original_shape = x.shape
logger.trace(f"Aggregating {key or 'tensor'}: {x.shape} using '{mode}'")
# Already 2D - nothing to do!
if x.ndim == 2:
logger.trace(f"Already 2D: {x.shape}")
return x
# 3D: (B, L, D) - sequence data
elif x.ndim == 3:
result = self._aggregate_3d(x, mode, key)
# 4D: (B, C, H, W) - image/feature maps
elif x.ndim == 4:
result = self._aggregate_4d(x, mode, key)
# 5D: (B, C, T, H, W) - video/3D data
elif x.ndim == 5:
result = self._aggregate_5d(x, mode, key)
else:
raise ValueError(
f"Unsupported tensor dimension: {x.ndim}. Supported: 2 (no-op), 3, 4, 5"
)
# Ensure output is 2D
if result.ndim != 2:
result = result.reshape(batch_size, -1)
logger.trace(f"Aggregated {original_shape} -> {result.shape}")
return result
def _aggregate_3d(
self, x: torch.Tensor, mode: str, key: Optional[Union[str, int]]
) -> torch.Tensor:
"""Aggregate 3D tensor (B, L, D) to 2D."""
if mode == "mean":
return x.mean(dim=1) # (B, D)
elif mode == "max":
return x.max(dim=1)[0] # (B, D)
elif mode == "cls":
return x[:, 0] # (B, D) - first token
elif mode == "flatten":
return x.reshape(x.shape[0], -1) # (B, L*D)
elif mode == "adaptive":
# Pool sequence dimension to fixed size
return nn.functional.adaptive_avg_pool1d(
x.transpose(1, 2),
self.adaptive_pool_size, # (B, D, L)
).squeeze(-1) # (B, D)
else:
raise ValueError(f"Mode '{mode}' not supported for 3D tensors")
def _aggregate_4d(
self, x: torch.Tensor, mode: str, key: Optional[Union[str, int]]
) -> torch.Tensor:
"""Aggregate 4D tensor (B, C, H, W) to 2D."""
batch_size = x.shape[0]
if mode == "mean":
return x.mean(dim=(2, 3)) # (B, C)
elif mode == "max":
return x.amax(dim=(2, 3)) # (B, C)
elif mode == "adaptive":
return nn.functional.adaptive_avg_pool2d(
x, (self.adaptive_pool_size, self.adaptive_pool_size)
).reshape(batch_size, -1) # (B, C * pool_size^2)
elif mode == "flatten":
return x.reshape(batch_size, -1) # (B, C*H*W)
elif mode == "cls":
            logger.warning(
                "Mode 'cls' on a 4D tensor takes the [0, 0] spatial position; "
                f"consider 'mean' or 'adaptive' instead. Shape: {x.shape}"
            )
return x[:, :, 0, 0] # (B, C)
else:
raise ValueError(f"Mode '{mode}' not supported for 4D tensors")
def _aggregate_5d(
self, x: torch.Tensor, mode: str, key: Optional[Union[str, int]]
) -> torch.Tensor:
"""Aggregate 5D tensor (B, C, T, H, W) to 2D."""
batch_size = x.shape[0]
if mode == "mean":
return x.mean(dim=(2, 3, 4)) # (B, C)
elif mode == "max":
return x.amax(dim=(2, 3, 4)) # (B, C)
elif mode == "adaptive":
pool_size = self.adaptive_pool_size
return nn.functional.adaptive_avg_pool3d(
x, (pool_size, pool_size, pool_size)
).reshape(batch_size, -1) # (B, C * pool_size^3)
elif mode == "flatten":
return x.reshape(batch_size, -1) # (B, C*T*H*W)
else:
raise ValueError(
f"Mode '{mode}' not supported for 5D tensors. "
f"Use: mean, max, adaptive, flatten"
)
def forward(self, x: TensorInput) -> torch.Tensor:
"""Aggregate input tensor(s) to 2D format.
Args:
x: Input tensor, list of tensors, or dict of tensors
Returns:
Aggregated 2D tensor of shape (B, total_features)
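        Example (illustrative; shapes are assumptions):
            >>> agg = TensorAggregator(["cls", "mean"])
            >>> out = agg([torch.randn(2, 197, 768), torch.randn(2, 256, 7, 7)])
            >>> out.shape
            torch.Size([2, 1024])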
"""
# Single tensor
if isinstance(x, torch.Tensor):
if self.input_type != "single":
                logger.warning(
                    f"Expected {self.input_type} input but got a single tensor"
                )
mode = self.agg_modes.get("default", "mean")
return self._aggregate_single_tensor(x, mode)
# List of tensors
elif isinstance(x, list):
if self.input_type == "single":
mode = self.agg_modes["default"]
aggregated = [
self._aggregate_single_tensor(tensor, mode, i)
for i, tensor in enumerate(x)
]
else:
if len(x) != len(self.agg_modes):
logger.warning(
f"Number of tensors ({len(x)}) != number of modes "
f"({len(self.agg_modes)})"
)
aggregated = []
for i, tensor in enumerate(x):
mode = self.agg_modes.get(i, list(self.agg_modes.values())[0])
agg = self._aggregate_single_tensor(tensor, mode, i)
aggregated.append(agg)
result = torch.cat(aggregated, dim=1)
logger.debug(f"Concatenated {len(aggregated)} tensors -> {result.shape}")
return result
# Dict of tensors (sorted for determinism)
elif isinstance(x, dict):
if self.input_type == "single":
mode = self.agg_modes["default"]
aggregated = [
self._aggregate_single_tensor(tensor, mode, key)
for key, tensor in sorted(x.items())
]
else:
aggregated = []
for key, tensor in sorted(x.items()):
mode = self.agg_modes.get(key, "mean")
if key not in self.agg_modes:
logger.warning(f"No mode specified for '{key}', using 'mean'")
agg = self._aggregate_single_tensor(tensor, mode, key)
aggregated.append(agg)
result = torch.cat(aggregated, dim=1)
logger.debug(
f"Concatenated {len(aggregated)} dict entries -> {result.shape}"
)
return result
else:
raise TypeError(
f"Unsupported input type: {type(x)}. "
f"Expected Tensor, List[Tensor], or Dict[str, Tensor]"
)
def compute_output_dim(
self, input_shapes: Union[tuple, List[tuple], Dict[str, tuple]]
) -> int:
"""Compute the output dimension given input shapes.
Args:
input_shapes: Shape(s) of input tensor(s) (excluding batch dim)
Returns:
Total output features
Examples:
>>> agg = TensorAggregator(["cls", "mean"])
>>> agg.compute_output_dim([(197, 768), (197, 768)])
1536
>>> agg = TensorAggregator({"l1": "cls", "conv": "mean"})
>>> agg.compute_output_dim({"l1": (197, 768), "conv": (512, 14, 14)})
1280
"""
def _compute_single_dim(shape: tuple, mode: str) -> int:
"""Compute output dim for a single tensor."""
ndim = len(shape)
# Already 2D
if ndim == 1:
return shape[0]
# 3D tensor (seq_len, features)
elif ndim == 2:
if mode in ["cls", "mean", "max"]:
return shape[1]
elif mode == "flatten":
return shape[0] * shape[1]
elif mode == "adaptive":
return shape[1] * self.adaptive_pool_size
# 4D tensor (channels, height, width)
elif ndim == 3:
if mode in ["mean", "max", "cls"]:
return shape[0]
elif mode == "flatten":
return shape[0] * shape[1] * shape[2]
elif mode == "adaptive":
return shape[0] * (self.adaptive_pool_size**2)
# 5D tensor (channels, time, height, width)
elif ndim == 4:
if mode in ["mean", "max"]:
return shape[0]
elif mode == "flatten":
return shape[0] * shape[1] * shape[2] * shape[3]
elif mode == "adaptive":
return shape[0] * (self.adaptive_pool_size**3)
raise ValueError(f"Cannot compute dim for shape {shape} with mode {mode}")
# Single input
if isinstance(input_shapes, tuple):
mode = self.agg_modes.get("default", "mean")
return _compute_single_dim(input_shapes, mode)
# List of inputs
elif isinstance(input_shapes, list):
total = 0
for i, shape in enumerate(input_shapes):
mode = self.agg_modes.get(i, list(self.agg_modes.values())[0])
total += _compute_single_dim(shape, mode)
return total
# Dict of inputs
elif isinstance(input_shapes, dict):
total = 0
for key, shape in input_shapes.items():
mode = self.agg_modes.get(key, "mean")
total += _compute_single_dim(shape, mode)
return total
else:
raise TypeError(f"Unsupported input_shapes type: {type(input_shapes)}")
def __repr__(self) -> str:
return (
f"TensorAggregator(type={self.input_type}, "
f"modes={self.agg_modes}, "
f"adaptive_pool_size={self.adaptive_pool_size})"
)
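

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the public API): wiring the aggregator in
# front of a simple linear probe, as in the SSL use case from the module
# docstring. Batch size, layer count, and feature dims below are assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Four hypothetical ViT layers of shape (batch=4, tokens=197, dim=768).
    layers = [torch.randn(4, 197, 768) for _ in range(4)]
    aggregator = TensorAggregator(["cls"] * 4)
    # Size the probe head from shapes alone, without a dry forward pass.
    probe = nn.Linear(aggregator.compute_output_dim([(197, 768)] * 4), 10)
    logits = probe(aggregator(layers))  # (4, 3072) -> (4, 10)
    logger.info(f"Linear-probe output shape: {logits.shape}")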