Adding GPT OSS Support #646
Conversation
Signed-off-by: Mustafa Eyceoz <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from a0b201d to cc30a90.
Force-pushed from cc30a90 to 2e17497.
Signed-off-by: Mustafa Eyceoz <[email protected]>
Force-pushed from 3737e2a to 37e70bc.
Great work on the GPT-OSS integration! The MXFP4 quantization/dequantization implementation, router parameter freezing, auxiliary loss support, and batch collation refactor represent solid technical execution.
Before we can merge this PR, there are several issues we need to address:
- Data processing may fail to correctly identify GPT-OSS models in edge cases
- The batch collator looks like it could drop minibatches and fail to accumulate sample/token counts correctly
- The kernels dependency is missing from the requirements
- In the training loop, it looks like we're reducing floats instead of ints in several places (see the sketch after this list)
- Quite a few functions and variables could use more descriptive names
- Several places where the logic can be simplified, especially in the GPT-OSS model-saving workflow
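For the float-vs-int point, here's roughly what I mean (a sketch only, not code from this PR; labels stands in for the batch's label tensor with -100 on masked positions):

import torch
import torch.distributed as dist

# Per-rank count of loss-counted tokens; this is an exact integer.
local_count = torch.tensor(
    int((labels != -100).sum()), dtype=torch.int64, device=labels.device
)

# Summing as int64 keeps the global count exact; accumulating these counts as
# float32 can drift once the counts or the world size get large.
dist.all_reduce(local_count, op=dist.ReduceOp.SUM)
global_loss_counted_tokens = local_count.item()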
I like the batch collator refactor, but I'm worried that using a running estimate for batch_num_loss_counted_tokens may cause problems. Since we're now estimating the total number of loss-counted tokens in the entire minibatch, the gradient signal the model receives is biased toward the first few microbatches, which can heavily overshoot or undershoot the estimate compared to later microbatches. This could be a real problem when training with a large EBS (3840) and on datasets where the samples you actually care about are rare (~10k/370k).
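As a toy illustration (an assumed estimation scheme, not the PR's actual code): if each microbatch's loss is divided by a running estimate of the minibatch total, the early microbatches see a denominator that can be far off the true total:

true_total = 2000                     # actual loss-counted tokens in the minibatch
microbatch_tokens = [200, 600, 1200]  # tokens contributed by each microbatch

running = 0
for i, tokens in enumerate(microbatch_tokens):
    running += tokens
    # naive running estimate: extrapolate the average so far to the full minibatch
    estimate = running * len(microbatch_tokens) / (i + 1)
    # the microbatch loss would be divided by this estimate, so its gradient is
    # effectively scaled by true_total / estimate relative to the exact version
    print(f"microbatch {i}: gradient weight off by x{true_total / estimate:.2f}")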
We may also want to add code in the future that reads the number of loss-counted tokens in each batch (validation loss, logging, etc.); in those cases the value won't be stable until we've reached the full number of accumulation steps.
One way to solve this, while also fixing the dual use of the regular and batch samplers during the distributed-sampler fallback, would be to change how we use the MultipackSampler so it behaves like a regular sampler: generate the samples ahead of time within a token budget, then collate them at the end. That way, each time you pull a batch from the data loader, you already know which tokens it contains. A rough sketch is below.
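Sketch of the idea only (the class name PreplannedBatchSampler is made up; the real change would live in how MultipackSampler is driven):

class PreplannedBatchSampler:
    """Plan every minibatch up front on a token budget, so the exact number of
    loss-counted tokens per minibatch is known before any microbatch runs."""

    def __init__(self, lengths, token_budget):
        self.batches = []
        current, current_tokens = [], 0
        for idx, n_tokens in enumerate(lengths):
            if current and current_tokens + n_tokens > token_budget:
                self.batches.append(current)
                current, current_tokens = [], 0
            current.append(idx)
            current_tokens += n_tokens
        if current:
            self.batches.append(current)

    def __iter__(self):
        # Each batch is a fixed list of sample indices; the collator can compute
        # exact token counts for it before training starts on the minibatch.
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)

Something like this drops in as a batch_sampler for a torch DataLoader, and the distributed fallback could simply shard self.batches across ranks instead of switching to a different sampler type.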
It would also be good to add some behavior-driven tests for the sampling so we can have confidence in the refactor.
Please address the issues listed above and consider the suggested approach for the batch collator. The core implementation is solid; these changes will make it production-ready.
test_tokens = ["<|start|>", "<|channel|>", "<|message|>"]
for token in test_tokens:
    # If any of these tokens can't be encoded, it's not GPT-OSS
    tokenizer.encode(token, add_special_tokens=False)
In what cases would this raise an exception? AFAIK, tokenizer.encode will encode any string into a sequence of tokens.
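If the goal is an actual discriminator, one option (not in the diff; the helper name is made up) is to check vocabulary membership instead, assuming a Hugging Face tokenizer:

def looks_like_gpt_oss(tokenizer) -> bool:
    # The GPT-OSS chat markers must each exist as a single dedicated token;
    # encode() would simply split unknown markers into multiple pieces.
    test_tokens = ["<|start|>", "<|channel|>", "<|message|>"]
    vocab = tokenizer.get_vocab()
    return all(token in vocab for token in test_tokens)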
Force-pushed from 37da1e4 to 4011142.
Force-pushed from 4011142 to 17e39e5.
Force-pushed from 277aeb4 to b2418c9.
* addition of padded batch packer + simplified train loop
* update tests + linting
Force-pushed from b2418c9 to 3ab23f7.
Force-pushed from bc27542 to 99987ef.
Force-pushed from 99987ef to 3913460.
LGTM
Adds support for GPT OSS:
Also cleans up our loss calculation by removing our forward override and using accurate batch stats
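A minimal sketch of what an accurate-batch-stats loss path can look like (assumed shapes and helper name, not the exact code in this PR), where global_loss_counted_tokens is the int-reduced count described in the review above:

import torch.nn.functional as F

def causal_lm_loss(logits, labels, global_loss_counted_tokens):
    # Sum (rather than average) the per-token loss locally, then divide by the
    # exact number of loss-counted tokens across all ranks and microbatches.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    loss_sum = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
        reduction="sum",
    )
    return loss_sum / global_loss_counted_tokens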