[RFC] Support full bf16 training #1646
Conversation
```diff
  h = self.norm(h) if self.norm else h
- output = self.output(h) if self.output else h
+ output = self.output(h).float() if self.output else h
```
If we set the training dtype during the training initialization, why not also do the output conversion in the trainer (train loop)?
This is already in the loss function https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/loss.py#L21
Also see #642
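For context, the loss in torchtitan already upcasts the logits to fp32 itself, roughly along these lines (a paraphrase of the linked file, not an exact copy):

```python
import torch
import torch.nn.functional as F

def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Upcast logits to fp32 before cross-entropy for numerical stability,
    # so the model forward does not need its own .float() on the output.
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))
```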
Thanks, just removed
```python
In the case of full bf16 training, RoPE calculations and logits will still be in fp32.
"""

mixed_precision_param: Literal["bfloat16", "float32"] = "bfloat16"
```
What if mixed_precision_param is float32 but the training dtype is bfloat16? Shouldn't there be a check?
Yeah agreed. Do we want to do this somewhere in train.py? Lmk if you think there's a better place
mixed_precision_param comes from FSDP2. I think as long as FSDP2 can work with that combination, it's the user's responsibility to configure these settings properly.
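For context, in FSDP2 this knob maps onto the `param_dtype` of a `MixedPrecisionPolicy` passed to `fully_shard`, roughly as below (the import path varies by PyTorch version, and this assumes torch.distributed is already initialized):

```python
import torch
import torch.nn as nn
# Recent PyTorch exposes these from torch.distributed.fsdp; older versions
# keep them under torch.distributed._composable.fsdp.
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

model = nn.Linear(1024, 1024)  # stand-in for a transformer block/model

# mixed_precision_param selects param_dtype here; gradient reduction happens
# in reduce_dtype regardless of how the parameters themselves are stored.
fully_shard(
    model,
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    ),
)
```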
We also make it work with DDP/single device: #1303. I think a warning is at least required.
Sounds good. In that case I will leave this as is
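For illustration, the suggested warning could be a small check in train.py along these lines (the parameter names and placement are assumptions, not the actual torchtitan code):

```python
import warnings

def check_dtype_config(train_dtype: str, mixed_precision_param: str) -> None:
    # Hypothetical check: building the model in bf16 while asking FSDP/DDP
    # mixed precision to keep a float32 parameter copy is almost certainly
    # a misconfiguration, so surface it instead of failing silently.
    if train_dtype == "bfloat16" and mixed_precision_param == "float32":
        warnings.warn(
            "Model parameters are bfloat16 but mixed_precision_param is float32; "
            "double-check that this combination is intended."
        )
```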
One last thing in my mind is that the …
IIUC it is a general context manager, very similar to …
I'm a bit surprised that PyTorch doesn't provide such a context manager natively.
https://discuss.pytorch.org/t/context-manager-for-dtype-and-device/73827/3
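For reference, such a context manager can be built on `torch.get_default_dtype` / `torch.set_default_dtype` in a few lines (a minimal sketch, not necessarily the exact utility added in this PR):

```python
import contextlib
import torch

@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype):
    # Temporarily change the default floating-point dtype so that modules
    # constructed inside the block allocate parameters/buffers in it,
    # then restore the previous default on exit.
    old = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old)
```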
LGTM.
Unlike …
Oh I see what you mean. I think the function itself can be used for more than model definition, so I'd still prefer keeping it in a util folder. Maybe let's merge it as is if you don't have a strong opinion.
@ebsmothers any thoughts on this?
@joecummings sounds like we should revert this PR, as doing bf16 everywhere does not seem to be the right way to save memory. Wdyt?
Sorry, just getting caught up here. My two cents: pure bf16 should not preclude using optimizers like 8-bit Adam. In my mind it is still generally useful (and fairly standard: see e.g. Lightning’s true bf16 precision setting) to store model weights and gradients in bfloat16 without an extra higher-precision copy. Is 8-bit Adam currently supported by Titan? My impression was that it isn’t, but lmk if that’s mistaken.
Unfortunately I went out on leave after opening #1646, so I never actually finished it to enable bf16 training in the forge experiment, which is what we ultimately wanted (thanks to @joecummings and @tianyu-l for pushing it through). I also see there was some discussion on the original PR which I belatedly responded to. If there are still concerns there, let me know. Otherwise, if we are not gonna revert that PR, we should at least land that one so that forge can reap the benefits as intended.


This PR adds support for full bfloat16 training. In SFT it is pretty common to store everything in bfloat16 to save memory, with select tensors (logits, RoPE buffers and activations) maintained in a higher precision to preserve numerical accuracy. Separately I think having this supported more generally would be useful for faster iteration -- e.g. it allows me to run Llama3 70B on a single node of H100s, which otherwise is not possible with the default config.
Assuming this is generally useful, would like feedback on:
- The `set_default_dtype` API (see the sketch below)
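As a rough sketch of how these pieces fit together (the model class and arguments below are placeholders, not torchtitan's actual code):

```python
import torch
import torch.nn as nn

# Placeholder model standing in for the real Transformer definition.
class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

# Construct the model under a bf16 default dtype so parameters and buffers
# are created directly in bfloat16; the loss then upcasts logits to fp32.
with set_default_dtype(torch.bfloat16):  # context manager sketched earlier
    model = TinyModel()

assert next(model.parameters()).dtype == torch.bfloat16
```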