Abstract class for target/aux computation #1184
base: develop
Conversation
Implemented Identity class. TODO: implement EMATeacher.
The big question on the EMA teacher side, to me, is how to allow for a flexible teacher and student architecture that can differ. We updated some APIs of the abstract base class to allow the ema_model forward; this is subject to change given the loss calculator, which is imho the second big question mark.
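For orientation, a minimal sketch of what such an abstract base class interface could look like. Only the name TargetAndAuxModuleBase and the constructor signatures of its subclasses come from the diff; the method names and their split here are assumptions, not the PR's actual API.

```python
from abc import ABC, abstractmethod


class TargetAndAuxModuleBase(ABC):
    """Hypothetical interface: subclasses produce the targets (and any auxiliary
    outputs, e.g. teacher predictions) that the loss calculator consumes."""

    def __init__(self, model, rng, **kwargs):
        self.model = model  # the student model being trained
        self.rng = rng

    @abstractmethod
    def compute(self, batch):
        """Return (targets, aux) for this batch."""

    def update(self, istep):
        """Optional hook after each optimizer step (e.g. EMA update of a teacher)."""
```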
    class EMATeacher(TargetAndAuxModuleBase):
        def __init__(self, model, rng, ema_model, batch_size, **kwargs):
            # One of the issues is that the teacher model may have a different architecture
Do you mean that e.g. in JEPA the student has the predictor too?
Yeah, in JEPA the student is Predictor(Encoder(x')) whereas the teacher is just Encoder(x); there is also a difference in BYOL, for instance.
Cool. Is there a useful abstraction we could stick with that would be helpful -- an always-EMA'ed encoder, for example? An EMATeacherEncoder that is always the same, to which we then add e.g. a predictor? This might not help, and I don't know if it holds for BYOL; just thinking.
I agree. The predictor could be the identity if it's not present.
We will need different "heads" for different latent student-teacher losses; the predictor would be just one of them.
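A hedged sketch of the abstraction discussed in this thread (not code from the PR): keep the EMA'ed encoder as the teacher, and let the method-specific student head (JEPA predictor, BYOL predictor, ...) default to the identity when a method does not need one.

```python
import torch.nn as nn


class StudentWithHead(nn.Module):
    """Student = head(encoder(x)); the head is nn.Identity when a method has none."""

    def __init__(self, encoder: nn.Module, head: nn.Module | None = None):
        super().__init__()
        self.encoder = encoder
        self.head = head if head is not None else nn.Identity()

    def forward(self, x):
        # JEPA: head is the predictor; plain distillation: head is the identity.
        return self.head(self.encoder(x))
```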
Easier to read, and as batch size handling gets more complicated in SSL this will be a useful abstraction.
It runs so far. Next steps:
- Route all the config options
- Start writing the loss functions to understand the state requirements
Looks very nice overall already, but some minor structural changes would be good; see the detailed comments.
        return preds_tokens


    def get_model(student_or_teacher, cf: Config, sources_size, targets_num_channels, targets_coords_size, **kwargs):
instantiate_model() is a more natural name for me
And I don't think it should go into model.py. If we have this function, then it seems more natural that it is also the one responsible for deciding which model to instantiate.
it felt unnecessary to create another file for it
      maybe_sharded_sd = self.original_model.state_dict()
      # this copies correctly tested in pdb
    - mkeys, ukeys = self.ema_model.load_state_dict(maybe_sharded_sd, strict=True, assign=False)
    + mkeys, ukeys = self.ema_model.load_state_dict(maybe_sharded_sd, strict=False, assign=False)
Why is this changed?
Because the teacher arch != the student arch, the load cannot be strict.
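A small self-contained illustration of the point (toy modules, not the PR's models): when the architectures differ, strict=True raises on any key mismatch, while strict=False returns the missing/unexpected keys so they can be inspected.

```python
import torch.nn as nn

student = nn.Sequential(nn.Linear(4, 4))                   # e.g. encoder only
teacher = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))  # e.g. encoder + extra head

# strict=True would raise a RuntimeError because the key sets differ.
missing, unexpected = teacher.load_state_dict(student.state_dict(), strict=False)
print(missing)     # keys the teacher has but the student state dict lacks ('1.weight', '1.bias')
print(unexpected)  # keys in the state dict that the teacher does not have (empty here)
```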
    if student_or_teacher == "student" or student_or_teacher == "teacher":
        return Model(cf, sources_size, targets_num_channels, targets_coords_size).create()
    else:
        if cf["training_mode"] == "masking":  # TODO implement mode "student-teacher-pretrain"
This should be a nested dict. But we should write an example config to see how it looks and feels and how it works.
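Purely as a discussion aid, one possible shape for such a nested config, written here as a Python dict; every key except training_mode (which appears in the diff) is hypothetical.

```python
config = {
    "training_mode": "masking",  # or "student-teacher-pretrain" once implemented
    "student_teacher": {
        "teacher": {"ema_halflife": 1000},  # hypothetical EMA settings
        "student": {"head": "predictor"},   # hypothetical head choice (identity, predictor, ...)
    },
}
```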
    class IdentityTargetAndAux(TargetAndAuxModuleBase):
        def __init__(self, model, rng, config):
Could we have brief documentation for this class?
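One possible wording for that documentation, reusing the class and constructor from the diff (the docstring text is a suggestion, not text from the PR):

```python
class IdentityTargetAndAux(TargetAndAuxModuleBase):
    """Pass-through target/aux module for plain masked reconstruction: targets are
    taken directly from the batch (physical space) and no auxiliary model such as
    an EMA teacher is involved."""

    def __init__(self, model, rng, config):
        ...
```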
      loss_values = self.loss_calculator.compute_loss(
          preds=preds,
    -     streams_data=batch[0],
    +     streams_data=batch[0],  # should additionally take targets?
Yes, this should take targets. We should have a TargetAndAuxCalculatorIdentity class that takes the batch and returns just the physical-space targets. (No strong feelings on whether we call it TargetAndAuxCalculatorIdentity or TargetAndAuxCalculatorPhysical or something similar.)
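A hedged sketch of the suggested class; the name comes from the comment above, while the method name and batch layout are assumptions.

```python
class TargetAndAuxCalculatorIdentity:
    """Returns the physical-space targets from the batch unchanged; no aux outputs."""

    def __init__(self, model, rng, config):
        self.model = model
        self.rng = rng
        self.config = config

    def compute(self, batch):
        streams_data, targets = batch[0], batch[1]  # assumed batch layout
        return targets, None  # no auxiliary state (e.g. no teacher predictions)
```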
      self.ema_model.update(
    -     self.cf.istep * self.world_size_original * self.cf.batch_size_per_gpu,
    -     self.world_size_original * self.cf.batch_size_per_gpu,
    +     self.cf.istep * get_batch_size(self.cf, self.world_size_original),
We need to abstract this into a function in utils/distributed.py
This change introduces exactly that abstraction; I'm not sure I understand.
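The helper implied by the diff above could look roughly like this in utils/distributed.py; the body follows the expression it replaces, while the docstring is an assumption.

```python
def get_batch_size(cf, world_size: int) -> int:
    """Global batch size across all ranks: per-GPU batch size times world size."""
    return world_size * cf.batch_size_per_gpu
```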
    # should be moved to its own file so as to prevent cyclical imports
    def get_target_and_aux_calculator(config, model, rng, batch_size, **kwargs):
This should go to the same file as instantiate_model.py.
Sure, how strongly are you married to instantiate_model?
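For reference, a hedged sketch of how this factory could dispatch, assuming the classes defined in this module; only the signature, the training_mode key, and the class names appear in the PR, the dispatch logic itself is an assumption.

```python
def get_target_and_aux_calculator(config, model, rng, batch_size, **kwargs):
    mode = config["training_mode"]
    if mode == "masking":
        return IdentityTargetAndAux(model, rng, config)
    if mode == "student-teacher-pretrain":  # still a TODO in the PR
        return EMATeacher(model, rng, kwargs.pop("ema_model"), batch_size, **kwargs)
    raise ValueError(f"Unknown training_mode: {mode}")
```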
Description
Implemented Identity class. TODO: implement EMATeacher.
Issue Number
Closes #1179
Is this PR a draft? Mark it as draft.
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60