# Source code for stable_ssl.data.collate
import numpy as np
import torch
def _collapse_nested_dict(base, other):
    """Concatenate two identically structured nested containers leaf by leaf.

    ``base`` and ``other`` are walked in parallel. Lists and dicts are updated
    in place, tuples are rebuilt (they are immutable, so item assignment would
    raise ``TypeError``), and every other value is treated as a tensor leaf
    and joined with ``torch.cat`` along dimension 0.

    Args:
        base: Nested combination of lists, tuples and dicts whose leaves can
            be passed to ``torch.cat``. Lists and dicts are mutated in place.
        other: Structure indexed with the same positions/keys as ``base``.

    Returns:
        The merged structure: ``base`` itself for lists and dicts, a new plain
        ``tuple`` for tuples, or the concatenated tensor for a leaf.
    """
    if isinstance(base, list):
        for i, item in enumerate(base):
            base[i] = _collapse_nested_dict(item, other[i])
        return base
    if isinstance(base, tuple):
        # Tuples do not support item assignment; build a fresh one instead.
        return tuple(
            _collapse_nested_dict(item, other[i]) for i, item in enumerate(base)
        )
    if isinstance(base, dict):
        # Only values are reassigned, so iterating the keys stays safe.
        for key in base:
            base[key] = _collapse_nested_dict(base[key], other[key])
        return base
    return torch.cat([base, other], 0)
# [docs]  (documentation-link marker carried over from the rendered page; not code)
class Collator:
    """Custom collate function that optionally builds an affinity (or "graph") matrix based on a specified field.

    Each sample is a dict that must contain an ``"image"`` entry. When that
    entry is a tensor the batch is collated with
    ``torch.utils.data.default_collate``; otherwise the samples are treated as
    multi-view and flattened with :meth:`_flatten`.

    Args:
        G_from: Key of the collated batch used to build the matrix stored
            under ``"G"``. ``None`` (the default) disables the matrix.
    """

    def __init__(self, G_from=None):
        self.G_from = G_from

    def _flatten(self, x):
        """Collate a list of sample dicts, merging per-sample lists of views.

        For every key of the first sample:

        * list values are concatenated across samples (sample order, then
          position inside each list);
        * tensor entries are stacked along a new leading dimension;
        * dict entries are collated recursively;
        * anything else goes through ``np.array`` and ``torch.from_numpy``.

        Args:
            x: Non-empty list of dicts sharing the keys of ``x[0]``.

        Returns:
            A dict with the same keys holding the collated values.
        """
        batch = {}
        for name in x[0]:
            if isinstance(x[0][name], list):
                # Single pass over the views; ``sum(lists, [])`` would copy
                # the accumulated list again for every sample.
                flattened = [view for sample in x for view in sample[name]]
            else:
                flattened = [sample[name] for sample in x]
            head = flattened[0]
            if torch.is_tensor(head):
                batch[name] = torch.stack(flattened, 0)
            elif isinstance(head, dict):
                batch[name] = self._flatten(flattened)
            else:
                batch[name] = torch.from_numpy(np.array(flattened))
        return batch

    def __call__(self, samples):
        """Collate ``samples`` and, if requested, attach the ``"G"`` matrix.

        Args:
            samples: List of sample dicts, each with an ``"image"`` entry that
                is either a tensor (single view) or a list (multiple views).

        Returns:
            The collated batch dict. When ``G_from`` is set it also holds
            ``"G"``: for a 1-D integer/bool field, the pairwise equality
            matrix cast to the device and dtype of ``batch["image"]``;
            otherwise the Gram matrix of the field flattened from dim 1.
        """
        # A tensor under "image" means one view per sample; anything else is
        # handled as a collection of views that must be contracted.
        if torch.is_tensor(samples[0]["image"]):
            samples = torch.utils.data.default_collate(samples)
        else:
            samples = self._flatten(samples)

        if self.G_from is not None:
            t = samples[self.G_from]
            # Any 1-D non-floating tensor is label-like. Restricting this to
            # int32/int64 sent e.g. int16 labels to ``flatten(1)``, which
            # fails on 1-D input.
            label_like = t.ndim == 1 and not (
                t.is_floating_point() or t.is_complex()
            )
            if label_like:
                image = samples["image"]
                G = t[:, None].eq(t).to(device=image.device, dtype=image.dtype)
            else:
                feats = t.flatten(1)
                G = feats @ feats.T
            samples["G"] = G
        return samples

    @staticmethod
    def _test():
        """Self-check on random single-view and multi-view batches.

        Returns:
            ``True`` when every assertion passes.
        """
        indices = torch.randperm(50000)[:128]
        images = torch.randn((128, 3, 28, 28))
        labels = torch.randint(0, 10, size=(128,))

        # single view
        data = [
            dict(image=images[i], label=labels[i], idx=indices[i])
            for i in range(128)
        ]
        collected = Collator(G_from="label")(data)
        assert collected["image"].eq(images).all()
        assert collected["label"].eq(labels).all()
        assert collected["idx"].eq(indices).all()
        assert collected["G"].eq(labels[:, None] == labels).all()
        collected = Collator(G_from="idx")(data)
        assert collected["G"].eq(indices[:, None] == indices).all()

        # multi-view
        data = [
            dict(
                image=[images[i] + torch.randn((3, 28, 28)) for _ in range(2)],
                label=[labels[i]] * 2,
                idx=[indices[i]] * 2,
            )
            for i in range(128)
        ]
        # Views of one sample are adjacent after flattening.
        indices = torch.repeat_interleave(indices, 2)
        labels = torch.repeat_interleave(labels, 2)
        collected = Collator(G_from="label")(data)
        assert collected["G"].eq(labels[:, None] == labels).all()
        collected = Collator(G_from="idx")(data)
        assert collected["G"].eq(indices[:, None] == indices).all()
        return True