Source code for stable_pretraining.utils.distributed
"""Distributed training utilities."""
import torch
import torch.distributed as dist
import torch.distributed.nn

def is_dist_avail_and_initialized():
    """Check if distributed training is available and initialized.

    Returns:
        bool: True if distributed is available and initialized, False otherwise.
    """
    return dist.is_available() and dist.is_initialized()
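
# Usage sketch (illustration only, not part of the library): guard work that
# should run on a single process, such as logging, behind this check.
def _example_log_on_main_process(message):
    if not is_dist_avail_and_initialized() or dist.get_rank() == 0:
        print(message)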

def all_gather(tensor, *args, **kwargs):
    """Gather tensors from all processes.

    Args:
        tensor: The tensor to gather.
        *args: Additional arguments for all_gather.
        **kwargs: Additional keyword arguments for all_gather.

    Returns:
        Tuple containing the gathered tensors (one per process).
    """
    if is_dist_avail_and_initialized():
        # Autograd-aware all_gather: gradients flow back to every process.
        return torch.distributed.nn.functional.all_gather(tensor, *args, **kwargs)
    return (tensor,)
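
# Usage sketch (illustration only, not part of the library): gather the
# per-process embeddings into one global batch, e.g. for a contrastive loss.
# Because the autograd-aware all_gather above is used, gradients flow back
# to every process.
def _example_gather_embeddings(local_embeddings):
    gathered = all_gather(local_embeddings)  # tuple with one tensor per process
    return torch.cat(gathered, dim=0)        # (world_size * local_batch, dim)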

def all_reduce(tensor, *args, **kwargs):
    """Reduce tensors across all processes.

    Args:
        tensor: The tensor to reduce.
        *args: Additional arguments for all_reduce.
        **kwargs: Additional keyword arguments for all_reduce.

    Returns:
        The reduced tensor.
    """
    if is_dist_avail_and_initialized():
        # Autograd-aware all_reduce: return its result so gradients can
        # propagate across processes.
        return torch.distributed.nn.functional.all_reduce(tensor, *args, **kwargs)
    return tensor
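
# Usage sketch (illustration only, not part of the library): average a scalar
# metric across processes. all_reduce sums by default, so divide by the world
# size when running distributed.
def _example_mean_across_processes(local_value):
    reduced = all_reduce(local_value)
    if is_dist_avail_and_initialized():
        reduced = reduced / dist.get_world_size()
    return reduced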

class FullGatherLayer(torch.autograd.Function):
    """Gather tensors from all processes.

    Supports backward propagation of the gradients across processes.
    """
    @staticmethod
    def forward(ctx, x):
        if not torch.distributed.is_initialized():
            # Single-process fallback: mimic a world size of 1.
            return x.unsqueeze(0)
        output = [
            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(output, x)
        # Shape: (world_size, *x.shape).
        return torch.stack(output)
    @staticmethod
    def backward(ctx, grad):
        if not torch.distributed.is_initialized():
            return grad.squeeze(0)
        # Average the incoming gradients across processes, then return the
        # slice corresponding to this rank's original input.
        torch.distributed.all_reduce(grad, op=torch.distributed.ReduceOp.AVG)
        return grad[torch.distributed.get_rank()]
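
# Usage sketch (illustration only, not part of the library): FullGatherLayer
# is an autograd.Function, so it is invoked through .apply(). The result has
# shape (world_size, *x.shape) and can be flattened into a single global batch.
def _example_full_gather(local_features):
    stacked = FullGatherLayer.apply(local_features)  # (world_size, batch, dim)
    return stacked.flatten(0, 1)                     # (world_size * batch, dim)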