Skip to content

Conversation

@sophie-xhonneux
Copy link
Contributor

Implemented Identity class

TODO: implement EMATeacher

Description

Issue Number

Closes #1179

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hegdedoc in the github issue with all the configurations and runs for this experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

Implemented Identity class

TODO: implement EMATeacher
The big question on the EMA teacher side to me is how to allow for a
fleixble teacher and student architecture that can differ

We updated some APIs of the abstract base class to allow the ema_model
forward, subject to change given the loss calculator, which is imho the
second big question mark
@shmh40 shmh40 self-assigned this Oct 31, 2025
@shmh40 shmh40 moved this to In Progress in WeatherGen-dev Oct 31, 2025

class EMATeacher(TargetAndAuxModuleBase):
def __init__(self, model, rng, ema_model, batch_size, **kwargs):
# One of the issues is that the teacher model may have a different architecture
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean that e.g. in JEPA the student has the predictor too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, in JEPA the student is Predictor(Encoder(x')) whereas the teacher is just Encoder(x), but also in BYOL there is a difference for instance

Copy link
Contributor

@shmh40 shmh40 Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. Is there a useful abstraction we could stick with that would be helpful -- always EMA'ed encoder for example? EMATeacherEncoder always the same, then add e.g. predictor to this? This might not help, and don't know if this holds for byol, just thinking

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. The predictor could be the identity if it's not present.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will need different "heads" for different latent student-teacher losses, the predictor would be just one of them

Easier to read and as batchsize gets more complicated in SSL this will
be a useful abstraction
@github-actions github-actions bot added the model Related to model training or definition (not generic infra) label Nov 5, 2025
It runs so far. Next steps:
 - Route all the config options
 - Start writing the loss functions to understand the state requirements
Copy link
Collaborator

@clessig clessig left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks already very nice overall but some minor structural changes would be good, see detailed comments.

return preds_tokens


def get_model(student_or_teacher, cf: Config, sources_size, targets_num_channels, targets_coords_size, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instantiate_model() is a more natural name for me

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And I don't think it should go to model.py. If we have the function then it seems more natural that it is also responsible which model potentially to instantiate.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it felt unnecessary to create another file for it

maybe_sharded_sd = self.original_model.state_dict()
# this copies correctly tested in pdb
mkeys, ukeys = self.ema_model.load_state_dict(maybe_sharded_sd, strict=True, assign=False)
mkeys, ukeys = self.ema_model.load_state_dict(maybe_sharded_sd, strict=False, assign=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this changed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because teacher arch =/= student arch so it cannot be strict

if student_or_teacher == "student" or student_or_teacher == "teacher":
return Model(cf, sources_size, targets_num_channels, targets_coords_size).create()
else:
if cf["training_mode"] == "masking": # TODO implement mode "student-teacher-pretrain":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a nested dict. But we should write an example config to see how it looks and feels like and how it works.



class IdentityTargetAndAux(TargetAndAuxModuleBase):
def __init__(self, model, rng, config):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we have a brief documentation


class EMATeacher(TargetAndAuxModuleBase):
def __init__(self, model, rng, ema_model, batch_size, **kwargs):
# One of the issues is that the teacher model may have a different architecture
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. The predictor could be the identity if it's not present.

loss_values = self.loss_calculator.compute_loss(
preds=preds,
streams_data=batch[0],
streams_data=batch[0], # should additionally take targets?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this should take targets. We should have an TargetAndAuxCalculatorIdentity class that takes the batch and returns just the physical space targets. (No strong feelings if we call TargetAndAuxCalculatorIdentity or TargetAndAuxCalculatorPhysical or something similar)

self.ema_model.update(
self.cf.istep * self.world_size_original * self.cf.batch_size_per_gpu,
self.world_size_original * self.cf.batch_size_per_gpu,
self.cf.istep * get_batch_size(self.cf, self.world_size_original),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to abstract this into a function in utils/distributed.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change does this abstraction, not sure I understand



# should be moved to its own file so as to prevent cyclical imports
def get_target_and_aux_calculator(config, model, rng, batch_size, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should go to the same file as instantiate_model.py.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, how strongly are you married to instantiate_model?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

model:pretrain model Related to model training or definition (not generic infra)

Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

Abstract class for the teacher in student-teacher training

4 participants