
Conversation

Maxusmusti
Collaborator

Adds support for GPT OSS:

  • trains on de-quantized weights
  • re-quantizes weights at save time
  • freezes routing parameters (for now), with auxiliary-loss support also included

Also cleans up our loss calculation by removing our forward override and using accurate batch stats
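The router freezing described above can be sketched as a name-based parameter partition. This is a hypothetical illustration, not the PR's actual code: the "router"/"gate" substrings are assumptions about GPT-OSS MoE parameter naming, and in a real training loop you would set requires_grad = False on the frozen set.

```python
# Hypothetical sketch of router-parameter freezing by name.
# The "router"/"gate" markers are assumed MoE naming conventions.
ROUTER_MARKERS = ("router", "gate")

def partition_trainable(param_names):
    """Split parameter names into (frozen, trainable) by router markers."""
    frozen = [n for n in param_names if any(m in n for m in ROUTER_MARKERS)]
    trainable = [n for n in param_names if n not in frozen]
    return frozen, trainable

names = [
    "model.layers.0.mlp.router.weight",
    "model.layers.0.mlp.experts.0.up_proj.weight",
    "model.layers.0.self_attn.q_proj.weight",
]
frozen, trainable = partition_trainable(names)
print(frozen)  # only the router weight ends up frozen
```

In PyTorch terms, the frozen list would correspond to parameters whose requires_grad is set to False before the optimizer is built.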

@mergify mergify bot added CI/CD Affects CI/CD configuration documentation Improvements or additions to documentation testing Relates to testing dependencies Pull requests that update a dependency file labels Aug 21, 2025
Contributor

mergify bot commented Aug 21, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. @Maxusmusti please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 21, 2025
@Maxusmusti Maxusmusti force-pushed the fixed-speed-gpt-oss-support-freezing-fix-loss branch from a0b201d to cc30a90 Compare August 21, 2025 21:54
@mergify mergify bot added ci-failure and removed needs-rebase labels Aug 21, 2025
@Maxusmusti Maxusmusti force-pushed the fixed-speed-gpt-oss-support-freezing-fix-loss branch from cc30a90 to 2e17497 Compare August 21, 2025 22:00
Signed-off-by: Mustafa Eyceoz <[email protected]>
@mergify mergify bot added ci-failure and removed ci-failure labels Sep 9, 2025
@Maxusmusti Maxusmusti force-pushed the fixed-speed-gpt-oss-support-freezing-fix-loss branch from 3737e2a to 37e70bc Compare September 9, 2025 19:17
@mergify mergify bot added ci-failure and removed ci-failure labels Sep 9, 2025
Signed-off-by: Mustafa Eyceoz <[email protected]>
@mergify mergify bot removed the ci-failure label Sep 9, 2025
@Maxusmusti Maxusmusti marked this pull request as ready for review September 9, 2025 20:07
Member

@RobotSail RobotSail left a comment


Great work on the GPT-OSS integration! The MXFP4 quantization/dequantization implementation, router parameter freezing, auxiliary loss support, and batch collation refactor represent solid technical execution.

Before we can merge this PR, there are several issues we need to address:

  1. Data processing may fail to correctly identify GPT-OSS models in edge cases
  2. The batch collator appears to drop minibatches and to accumulate sample/token counts incorrectly
  3. The kernels dependency is missing from requirements
  4. In the training loop, we're reducing floats instead of ints in several places
  5. Several functions and variables could use more descriptive names
  6. The logic can be simplified in several places, especially in the GPT-OSS model-saving workflow
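On the float-vs-int reduction point: once counts exceed float32's exact-integer range (2^24), reducing them as floats silently loses increments. A quick stdlib illustration, simulating a float32 round-trip with struct (the actual fix would be reducing int64 tensors in the collective):

```python
import struct

def to_float32(x):
    """Round-trip an integer through IEEE-754 float32,
    as a float32 all-reduce of a count effectively would."""
    return struct.unpack("f", struct.pack("f", float(x)))[0]

token_count = 2**24 + 1              # 16,777,217 tokens counted so far
print(to_float32(token_count))       # 16777216.0 -- the +1 is lost
print(to_float32(token_count) == token_count)  # False
```

This is why token/sample counters should stay integer-typed all the way through the distributed reduction.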

I like the batch collator refactor, but I'm worried that using a running estimate for batch_num_loss_counted_tokens may cause problems. Since we're now using a running estimate of the total number of loss-counted tokens in the entire minibatch, the gradient signal the model receives will be biased toward the first few microbatches, which can heavily overshoot or undershoot the estimate relative to later microbatches. This could be a big problem when training with a large EBS (3840) and on datasets with rare samples you want to train on (~10k/370k).

We may want to add code in the future that reads the number of loss counted tokens in each batch (validation loss, logging, etc.). In these cases, the value won't be stable until we hit the number of accumulation steps.
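To make the bias concern concrete, here is a toy comparison (pure Python, made-up numbers, and a simplified stand-in for the real collator logic) of normalizing each microbatch by an extrapolated running estimate versus by the exact total counted over all accumulation steps:

```python
def weighted_loss_running(losses, tokens):
    """Scale each microbatch by a running extrapolation of the total
    token count -- early microbatches dominate the estimate (biased)."""
    total = 0.0
    running = 0
    for i, (loss_sum, t) in enumerate(zip(losses, tokens), start=1):
        running += t
        estimate = running / i * len(losses)  # extrapolate total from prefix
        total += loss_sum / estimate
    return total

def weighted_loss_exact(losses, tokens):
    """Scale every microbatch by the true total token count (unbiased)."""
    total_tokens = sum(tokens)
    return sum(loss_sum / total_tokens for loss_sum in losses)

# per-microbatch summed token losses and loss-counted token counts
losses = [50.0, 10.0, 120.0, 40.0]
tokens = [500, 100, 1200, 400]
print(weighted_loss_exact(losses, tokens))    # 0.1, the true loss per token
print(weighted_loss_running(losses, tokens))  # drifts away from 0.1
```

Every sample here has the same per-token loss (0.1), yet the running-estimate version returns a different number whenever microbatch sizes are uneven, which is exactly the skew described above.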

One way to solve this issue, while also fixing the dual use of the regular and batch samplers during the distributed-sampler fallback, would be to update how we use the MultipackSampler so that it behaves like a regular sampler: generate samples ahead of time on a token budget, then collate them together at the end. That way, each time you get a batch from the data loader, you know exactly which tokens will be loaded.
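The suggested approach could look roughly like the following sketch (a hypothetical greedy first-fit packer, not the MultipackSampler's actual algorithm): batches are built ahead of time under a token budget, so the token count of every batch is known before the loader yields it.

```python
def pack_batches(sample_lengths, token_budget):
    """Greedily pack sample indices into batches under a token budget.

    Returns a list of (indices, total_tokens), so each batch's token
    count is known up front rather than estimated during iteration.
    """
    batches, current, current_tokens = [], [], 0
    for idx, length in enumerate(sample_lengths):
        if current and current_tokens + length > token_budget:
            batches.append((current, current_tokens))
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += length
    if current:
        batches.append((current, current_tokens))
    return batches

lengths = [300, 700, 512, 200, 900]
for indices, total in pack_batches(lengths, token_budget=1024):
    print(indices, total)
```

With pre-generated batches like this, batch_num_loss_counted_tokens can be read exactly per batch instead of kept as a running estimate.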

It would also be good to add some behavior-driven tests for the sampling so we can have confidence in the refactor.

Please address the issues listed above and consider the suggested approach for the batch collator. The core implementation is solid; these changes will make it production-ready.

test_tokens = ["<|start|>", "<|channel|>", "<|message|>"]
for token in test_tokens:
    # If any of these tokens can't be encoded, it's not GPT-OSS
    tokenizer.encode(token, add_special_tokens=False)
Member


In what cases would this raise an exception? AFAIK, tokenizer.encode will encode any string into a sequence of tokens.
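If the intent is to detect GPT-OSS by its chat-format markers, a check that can actually distinguish models might test vocabulary membership instead of relying on encode raising. A sketch under the assumption that the tokenizer exposes a get_vocab()-style token-to-id mapping (as Hugging Face tokenizers do); this is a suggestion, not the PR's code:

```python
GPT_OSS_MARKER_TOKENS = ["<|start|>", "<|channel|>", "<|message|>"]

def looks_like_gpt_oss(vocab):
    """True only if every marker exists as a single token in the vocab."""
    return all(tok in vocab for tok in GPT_OSS_MARKER_TOKENS)

# e.g. vocab = tokenizer.get_vocab()
print(looks_like_gpt_oss({"<|start|>": 1, "<|channel|>": 2, "<|message|>": 3}))  # True
print(looks_like_gpt_oss({"hello": 0}))  # False
```

Unlike encode, which will happily split an unknown marker into sub-tokens without raising, a membership test fails cleanly on non-GPT-OSS vocabularies.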

Signed-off-by: Mustafa Eyceoz <[email protected]>
@mergify mergify bot added ci-failure and removed ci-failure labels Sep 10, 2025
@Maxusmusti Maxusmusti force-pushed the fixed-speed-gpt-oss-support-freezing-fix-loss branch from 37da1e4 to 4011142 Compare September 10, 2025 21:42
@mergify mergify bot added ci-failure and removed ci-failure labels Sep 10, 2025
@Maxusmusti Maxusmusti force-pushed the fixed-speed-gpt-oss-support-freezing-fix-loss branch from 4011142 to 17e39e5 Compare September 10, 2025 21:58
@mergify mergify bot added ci-failure and removed ci-failure labels Sep 10, 2025
@RobotSail RobotSail force-pushed the fixed-speed-gpt-oss-support-freezing-fix-loss branch from 277aeb4 to b2418c9 Compare September 15, 2025 21:21
@mergify mergify bot added ci-failure and removed ci-failure labels Sep 15, 2025
* addition of padded batch packer + simplified train loop

* update tests + linting
@RobotSail RobotSail force-pushed the fixed-speed-gpt-oss-support-freezing-fix-loss branch from b2418c9 to 3ab23f7 Compare September 16, 2025 02:00
@mergify mergify bot removed the ci-failure label Sep 16, 2025
@RobotSail RobotSail force-pushed the fixed-speed-gpt-oss-support-freezing-fix-loss branch 2 times, most recently from bc27542 to 99987ef Compare September 16, 2025 02:15
@mergify mergify bot added the ci-failure label Sep 16, 2025
@RobotSail RobotSail force-pushed the fixed-speed-gpt-oss-support-freezing-fix-loss branch from 99987ef to 3913460 Compare September 16, 2025 02:19
@mergify mergify bot removed the ci-failure label Sep 16, 2025
Member

@RobotSail RobotSail left a comment


LGTM

@mergify mergify bot added the one-approval label Sep 16, 2025
@Maxusmusti Maxusmusti merged commit 4934f4c into main Sep 17, 2025
23 of 28 checks passed
@Maxusmusti Maxusmusti deleted the fixed-speed-gpt-oss-support-freezing-fix-loss branch September 17, 2025 15:13