TeacherStudentModule#
- class stable_ssl.backbone.TeacherStudentModule(student: Module, warm_init: bool = True, base_ema_coefficient: float = 0.996, final_ema_coefficient: float = 1)[source]#
Bases:
Module
Student network paired with a teacher network that is updated as an EMA of the student network.
The teacher model is updated by taking a running average of the student’s parameters and buffers. When ema_coefficient == 0.0, the teacher and student are literally the same object, which saves memory, but forward passes through the teacher will not produce any gradients.
- Parameters:
student (torch.nn.Module) – The student model whose parameters will be tracked.
warm_init (bool, optional) – If True, performs an initialization step so that the teacher’s parameters immediately match the student’s. Default is True.
base_ema_coefficient (float, optional) – EMA decay factor at the start of training. This value will be updated following a cosine schedule. Should be in [0, 1]. A value of 0.0 means the teacher is fully updated to the student’s parameters on every step, while a value of 1.0 means the teacher remains unchanged. Default is 0.996.
final_ema_coefficient (float, optional) – EMA decay factor at the end of training. Default is 1.
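A minimal construction sketch; the ResNet-18 student and the coefficient values below are illustrative, only the keyword names come from the signature above:
    from torchvision.models import resnet18
    from stable_ssl.backbone import TeacherStudentModule

    student = resnet18(num_classes=128)  # illustrative student backbone; any torch.nn.Module works

    teacher_student = TeacherStudentModule(
        student,
        warm_init=True,               # teacher starts as an exact copy of the student
        base_ema_coefficient=0.996,   # teacher tracks the student slowly at the start of training
        final_ema_coefficient=1.0,    # teacher is effectively frozen at the end of training
    )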
- forward(*args, **kwargs)[source]#
Forward pass through either the student or teacher network.
The default forward lets you choose which network to run. Since the teacher is the one most commonly evaluated, it is used by default.
- forward_student(*args, **kwargs)[source]#
Forward pass through the student network. Gradients will flow normally.
- forward_teacher(*args, **kwargs)[source]#
Forward pass through the teacher network.
By default, the teacher network does not require grad. If ema_coefficient == 0, the teacher and the student are the same object, so the call is wrapped in torch.no_grad() to ensure no gradients flow.
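A minimal sketch of one training step using these methods, assuming a toy linear student, random tensors standing in for two augmented views, and a placeholder MSE distillation loss (all illustrative; only the module’s methods come from this page):
    import torch
    from torch import nn
    from stable_ssl.backbone import TeacherStudentModule

    student = nn.Linear(32, 16)                 # toy student; stands in for a real backbone
    module = TeacherStudentModule(student)
    optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

    view_a, view_b = torch.randn(8, 32), torch.randn(8, 32)  # stand-ins for two augmented views

    student_out = module.forward_student(view_a)   # gradients flow through the student
    teacher_out = module.forward_teacher(view_b)   # gradients are already disabled here
    loss = nn.functional.mse_loss(student_out, teacher_out)  # placeholder distillation loss

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    module.update_teacher()   # EMA step on the teacher, typically after the optimizer update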
- update_ema_coefficient(epoch: int, total_epochs: int)[source]#
Update the EMA coefficient following a cosine schedule.
- The EMA coefficient is updated following a cosine schedule:
ema_coefficient = final_ema_coefficient - 0.5 * (final_ema_coefficient - base_ema_coefficient) * (1 + cos(epoch / total_epochs * pi))
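The schedule can be checked in isolation; the sketch below re-implements the formula as a standalone function (not the library’s internal code) to show that it starts at base_ema_coefficient and ends at final_ema_coefficient:
    import math

    def cosine_ema_schedule(epoch, total_epochs, base=0.996, final=1.0):
        # Same formula as above: cos(0) = 1 gives `base`, cos(pi) = -1 gives `final`.
        return final - 0.5 * (final - base) * (1 + math.cos(epoch / total_epochs * math.pi))

    print(cosine_ema_schedule(0, 100))    # ≈ 0.996 -> base_ema_coefficient
    print(cosine_ema_schedule(50, 100))   # ≈ 0.998 -> halfway between base and final
    print(cosine_ema_schedule(100, 100))  # 1.0     -> final_ema_coefficient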
- update_teacher()[source]#
Perform one EMA update step on the teacher’s parameters.
- The update rule is:
teacher_param = ema_coefficient * teacher_param + (1 - ema_coefficient) * student_param
This is done in a no_grad context to ensure the teacher’s parameters do not accumulate gradients, while the student remains fully trainable.
Everything is updated, including buffers (e.g. batch norm running averages).
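A sketch of how such an EMA step could be written in plain PyTorch; this is an illustrative re-implementation, not the library’s code, and the special-casing of integer buffers is an assumption:
    import torch

    @torch.no_grad()  # no_grad context: the teacher must not accumulate gradients
    def ema_update(teacher: torch.nn.Module, student: torch.nn.Module, ema_coefficient: float):
        # teacher_param = ema_coefficient * teacher_param + (1 - ema_coefficient) * student_param
        for t, s in zip(teacher.parameters(), student.parameters()):
            t.mul_(ema_coefficient).add_(s, alpha=1.0 - ema_coefficient)
        # Buffers (e.g. batch norm running averages) are updated as well.
        for t, s in zip(teacher.buffers(), student.buffers()):
            if t.dtype.is_floating_point:
                t.mul_(ema_coefficient).add_(s, alpha=1.0 - ema_coefficient)
            else:
                t.copy_(s)  # assumption: integer buffers such as num_batches_tracked are copied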