TeacherStudentModule#

class stable_ssl.backbone.TeacherStudentModule(student: Module, warm_init: bool = True, base_ema_coefficient: float = 0.996, final_ema_coefficient: float = 1)[source]#

Bases: Module

Student network and its teacher network updated as an EMA of the student network.

The teacher model is updated by taking a running average of the student's parameters and buffers. When ema_coefficient == 0.0, the teacher and student are the same underlying object, which saves memory; forward passes through the teacher, however, do not produce any gradients.

Parameters:
  • student (torch.nn.Module) – The student model whose parameters will be tracked.

  • warm_init (bool, optional) – If True, the teacher is initialized to match the student's parameters immediately. Default is True.

  • base_ema_coefficient (float, optional) – EMA decay factor at the start of training. This value will be updated following a cosine schedule. Should be in [0, 1]. A value of 0.0 means the teacher is fully updated to the student’s parameters on every step, while a value of 1.0 means the teacher remains unchanged. Default is 0.996.

  • final_ema_coefficient (float, optional) – EMA decay factor at the end of training. Default is 1.
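
A minimal usage sketch (not taken from the library itself): the import path, constructor arguments, and method names follow this page, while the torchvision ResNet-18 backbone, the input shape, and the placement of the update calls in a training loop are illustrative assumptions:

    import torch
    import torchvision

    from stable_ssl.backbone import TeacherStudentModule

    # Any torch.nn.Module can serve as the student.
    student = torchvision.models.resnet18(num_classes=128)

    module = TeacherStudentModule(
        student,
        warm_init=True,              # teacher starts as an exact copy of the student
        base_ema_coefficient=0.996,  # EMA decay at the start of training
        final_ema_coefficient=1.0,   # EMA decay at the end of training
    )

    x = torch.randn(4, 3, 224, 224)
    student_out = module.forward_student(x)  # gradients flow through the student
    teacher_out = module.forward_teacher(x)  # teacher output carries no gradients

    # Assumed loop placement: EMA step after each optimizer step,
    # coefficient update once per epoch.
    module.update_teacher()
    module.update_ema_coefficient(epoch=0, total_epochs=100)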

forward(*args, **kwargs)[source]#

Forward pass through either the student or teacher network.

The default forward lets you choose which network to run. Since the teacher is the network most commonly evaluated, it is used by default.
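
For example, continuing from the construction sketch above and relying on the default dispatch to the teacher described here:

    out = module(torch.randn(4, 3, 224, 224))  # runs the teacher by default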

forward_student(*args, **kwargs)[source]#

Forward pass through the student network. Gradients will flow normally.

forward_teacher(*args, **kwargs)[source]#

Forward pass through the teacher network.

By default, the teacher network does not require grad. If ema_coefficient == 0, the teacher and the student are the same module, so the call is wrapped in torch.no_grad() to ensure no gradients flow.
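
A quick check of this behaviour, assuming a module built as in the construction sketch above and an input that does not itself require grad:

    y = module.forward_teacher(torch.randn(4, 3, 224, 224))
    assert not y.requires_grad  # no gradient graph is built through the teacher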

update_ema_coefficient(epoch: int, total_epochs: int)[source]#

Update the EMA coefficient following a cosine schedule.

The EMA coefficient is updated following a cosine schedule:

ema_coefficient = final_ema_coefficient - 0.5 * (final_ema_coefficient - base_ema_coefficient) * (1 + cos(epoch / total_epochs * pi))

Parameters:
  • epoch (int) – Current epoch in the training loop.

  • total_epochs (int) – Total number of epochs in the training loop.
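
As a sanity check, the endpoints of the schedule can be reproduced with a small standalone sketch (not the library code itself; the default coefficients are assumed):

    import math

    def cosine_ema(epoch, total_epochs, base=0.996, final=1.0):
        # Same formula as above, written out for illustration.
        return final - 0.5 * (final - base) * (1 + math.cos(epoch / total_epochs * math.pi))

    cosine_ema(0, 100)    # 0.996 -> starts at base_ema_coefficient
    cosine_ema(50, 100)   # 0.998 -> halfway point of the cosine ramp
    cosine_ema(100, 100)  # 1.0   -> ends at final_ema_coefficient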

update_teacher()[source]#

Perform one EMA update step on the teacher’s parameters.

The update rule is:

teacher_param = ema_coefficient * teacher_param + (1 - ema_coefficient) * student_param

This is done in a no_grad context to ensure the teacher’s parameters do not accumulate gradients, but the student remains fully trainable.

Everything is updated, including buffers (e.g. batch norm running averages).
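
The step can be pictured with the following standalone sketch (an illustration of the rule above, not the library's implementation; buffers are simply copied here, whereas the module may average them as well):

    import torch

    @torch.no_grad()  # ensure no gradients accumulate on the teacher
    def ema_update(teacher: torch.nn.Module, student: torch.nn.Module, ema: float):
        # teacher_param = ema * teacher_param + (1 - ema) * student_param
        for t, s in zip(teacher.parameters(), student.parameters()):
            t.mul_(ema).add_(s, alpha=1.0 - ema)
        # Buffers (e.g. batch-norm running statistics) are carried over as well.
        for t, s in zip(teacher.buffers(), student.buffers()):
            t.copy_(s)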