Add post_dataloading_processing method to Trainer #1985
base: gh/fegin/24/base
Conversation
We are adding more actions to convert the raw inputs and labels:
1. The new CP can do the input/label/BlockMask sharding in this method.
2. The experimental full-DTensor model can simply override this method without changing too much Trainer code.

This method is extracted from #1857. Making this a standalone PR allows us to continue the two projects above without one blocking the other.

ghstack-source-id: d1882a7
Pull-Request: #1985
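As a rough illustration, here is a minimal sketch of the kind of hook being added, assuming a Trainer whose train step unpacks each batch into inputs, label, extra_inputs, and extra_kwargs. The signature and return values match the diff below; the method body and the call site are assumptions, not the PR's exact implementation:

```python
import torch

class Trainer:
    def post_dataloading_processing(
        self,
        inputs: torch.Tensor,
        label: torch.Tensor,
        extra_inputs: dict,
        extra_kwargs: dict,
    ):
        # Default behavior: return the batch unchanged. Context-parallel
        # sharding of inputs/label/BlockMask, or a full-DTensor conversion,
        # can be implemented by overriding this method in a subclass.
        return inputs, label, extra_inputs, extra_kwargs

    def train_step(self, inputs, label, extra_inputs, extra_kwargs):
        inputs, label, extra_inputs, extra_kwargs = self.post_dataloading_processing(
            inputs, label, extra_inputs, extra_kwargs
        )
        # ... forward, loss, backward, optimizer step ...
```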
        extra_inputs=extra_inputs,
    )

    return inputs, label, extra_inputs, extra_kwargs
I think we should add a docstring for the returns, especially on the difference between extra_inputs and extra_kwargs.
Also not sure if we should just merge inputs and extra_inputs. Not urgent though.
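A hedged sketch of the kind of Returns docstring being requested; the described semantics of each field are assumptions about the intent, not confirmed by the PR:

```python
def post_dataloading_processing(self, inputs, label, extra_inputs, extra_kwargs):
    """Convert raw dataloader outputs before the forward pass.

    Returns:
        inputs: main input tensor passed to the model's forward (assumed).
        label: target tensor consumed by the loss function (assumed).
        extra_inputs: additional model inputs, e.g. auxiliary tensors that
            accompany ``inputs`` (assumed semantics).
        extra_kwargs: keyword arguments forwarded to the model's forward,
            e.g. attention masks or metadata (assumed semantics).
    """
    return inputs, label, extra_inputs, extra_kwargs
```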
    model_parts = self.model_parts
    parallel_dims = self.parallel_dims

    def post_dataloading_processing(
The name accurately describes where this could be called, but we are not putting it right after data loading; rather, we are putting it right before training. That makes sense, because when another library depends on torchtitan training but not on torchtitan data loading, this is the right place for it.
I just wonder if we could find another name that expresses that it happens right before (but mostly as part of) training, e.g. a bad and verbose version would be pre-actual-training-last-minute-data-preparation.
How about pre_training_data_processing or pre_training_data_preparation? Or, if the "last" is really the important message, final_data_preparation?
I was trying to avoid the term "pre-training", which could cause confusion.
I think we can go with post_dataloading_process; that seems unambiguous.
-    def post_dataloading_processing(
+    def post_dataloading_process(
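For the override use case mentioned in the description, a hedged sketch of how an experimental full-DTensor trainer might implement the renamed hook. The mesh access and the sequence-dimension sharding are assumptions for illustration, not torchtitan's actual code:

```python
from torch.distributed.tensor import Shard, distribute_tensor

class FullDTensorTrainer(Trainer):  # Trainer as in the sketch above
    # Uses the post-rename name discussed in this thread.
    def post_dataloading_process(self, inputs, label, extra_inputs, extra_kwargs):
        # Assumed: ParallelDims exposes a world mesh with a named "cp" dim.
        cp_mesh = self.parallel_dims.world_mesh["cp"]
        # Shard inputs and labels along the sequence dimension (dim 1)
        # across the context-parallel mesh before the forward pass.
        inputs = distribute_tensor(inputs, cp_mesh, [Shard(1)])
        label = distribute_tensor(label, cp_mesh, [Shard(1)])
        return inputs, label, extra_inputs, extra_kwargs
```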