
Conversation

@ebsmothers
Contributor

This PR adds support for full bfloat16 training. In SFT it is pretty common to store everything in bfloat16 to save memory, with select tensors (logits, RoPE buffers, and activations) maintained in a higher precision to preserve numerical accuracy. Separately, I think having this supported more generally would be useful for faster iteration -- e.g. it allows me to run Llama3 70B on a single node of H100s, which is otherwise not possible with the default config.
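To make the setup concrete, a minimal sketch of this pattern (illustrative only; the model and shapes are stand-ins, not the actual PR code):

import torch
import torch.nn as nn

# Create parameters and buffers in bf16 by switching the default dtype
# during model construction, then restore the previous default.
prev_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    model = nn.Linear(512, 1024)  # stand-in for the real model
finally:
    torch.set_default_dtype(prev_dtype)

x = torch.randn(8, 512, dtype=torch.bfloat16)
logits = model(x)                            # bf16 weights and activations
loss = logits.float().logsumexp(-1).mean()   # keep logit/loss math in fp32
loss.backward()                              # gradients come out in bf16, like the params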

Assuming this is generally useful, would like feedback on:

  1. Acceptable loss convergence: over the first 100 steps on Llama3 8B, full bf16 training loss goes from 12.25 -> 8, as opposed to 12.25 -> 7 with fp32 training. Is this a concern? (As mentioned, for SFT this is less of an issue; happy to validate that statement if that's helpful.)
  2. Interaction with mixed precision training -- where is the right place to validate that these are not both set at once?
  3. Where to put the set_default_dtype API


  h = self.norm(h) if self.norm else h
- output = self.output(h) if self.output else h
+ output = self.output(h).float() if self.output else h
Contributor

If we set the training dtype during training initialization, why not also do the output conversion in the trainer (train loop)?
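A rough sketch of that suggestion, i.e. upcasting only at the loss inside the train loop rather than in the model's forward (names here are illustrative, not the actual torchtitan trainer):

def train_step(model, loss_fn, inputs, labels):
    # Model forward stays in its native (bf16) dtype end to end...
    logits = model(inputs)
    # ...and the upcast to fp32 happens only where the loss is computed.
    loss = loss_fn(logits.float(), labels)
    loss.backward()
    return loss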


Contributor Author

Thanks, just removed

In the case of full bf16 training, RoPE calculations and logits will still be in fp32.
"""

mixed_precision_param: Literal["bfloat16", "float32"] = "bfloat16"
Contributor

What if mixed_precision_param is float32 but dtype is bfloat16? Should there be a check?

Contributor Author

Yeah agreed. Do we want to do this somewhere in train.py? Lmk if you think there's a better place

Contributor

mixed_precision_param comes from FSDP2. I think if FSDP2 can work with that combination, it's the user's responsibility to configure them properly.

Contributor

We also make it work with DDP/single device: #1303. I think at the very least a warning is required.

Contributor Author

Sounds good. In that case I will leave this as is

Contributor

@fegin
autocast is not well supported in torchtitan anyway. I'm not sure it is still maintained; see other issues like #1525.

But sure, having a warning sounds good.
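A possible shape for that warning, as a hedged sketch; the real config field names and where the check lives in torchtitan may differ:

import logging

logger = logging.getLogger(__name__)

def warn_on_dtype_mismatch(training_dtype: str, mixed_precision_param: str) -> None:
    # Hypothetical helper: flag the combination discussed above instead of erroring out.
    if training_dtype == "bfloat16" and mixed_precision_param == "float32":
        logger.warning(
            "Model is initialized in bfloat16 but mixed_precision_param is float32; "
            "double-check that this precision combination is intended."
        )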

@fegin
Contributor

fegin commented Sep 16, 2025

One last thing in my mind is that the set_default_dtype() definition should be moved to torchtitan/model.

@tianyu-l
Contributor

@fegin

> One last thing in my mind is that the set_default_dtype() definition should be moved to torchtitan/model

IIUC it is a general context manager, very similar to with torch.device(). Curious why you think it should be moved to model folder?

@tianyu-l left a comment
Contributor

I'm a bit surprised that PyTorch doesn't provide such a context manager natively.
https://discuss.pytorch.org/t/context-manager-for-dtype-and-device/73827/3

LGTM.
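For reference, a context manager of this kind can be written in a few lines on top of torch.get_default_dtype/torch.set_default_dtype; this is a generic sketch, not necessarily identical to the helper added in this PR:

import contextlib
import torch

@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype):
    # Temporarily change the default floating-point dtype, restoring it on exit,
    # analogous in spirit to `with torch.device(...)`.
    prev = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(prev)

# Parameters created inside the block default to bf16.
with set_default_dtype(torch.bfloat16):
    layer = torch.nn.Linear(16, 16)
assert layer.weight.dtype == torch.bfloat16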

@fegin
Contributor

fegin commented Sep 16, 2025

> IIUC it is a general context manager, very similar to with torch.device(). Curious why you think it should be moved to model folder?

Unlike torch.device(), which allocates the same model on different devices, this context changes the entire model -- the meaning of the model is technically different. That's my motivation. But I don't have a strong opinion on this; it is also reasonable to put it under the tool folder.

@tianyu-l
Contributor

Oh I see what you mean. I think the function itself can be used for more than model definition, so I'd still prefer it being in a util folder. Maybe let's merge it as is if you don't have a strong opinion.

@tianyu-l tianyu-l merged commit e99e16c into pytorch:main Sep 16, 2025
7 checks passed
@hann-wang
Contributor

hann-wang commented Sep 18, 2025

I don't think keeping optimizer states in BF16 is a good idea.

Generally speaking, keeping optimizer states in BF16 will degrade the final performance. Megatron-LM supports only FP16 and FP32 for optimizer states. (FP16 requires a separate scaling factor)

Here's a comparison between FP32/BF16 optimizer states on Megatron-LM:
[image: FP32 vs. BF16 optimizer-state comparison on Megatron-LM]

@fegin
Contributor

fegin commented Sep 18, 2025

@ebsmothers any thoughts on this?

@samsja
Contributor

samsja commented Sep 19, 2025

> I don't think keeping optimizer states in BF16 is a good idea.
>
> Generally speaking, keeping optimizer states in BF16 will degrade the final performance. Megatron-LM supports only FP16 and FP32 for optimizer states. (FP16 requires a separate scaling factor)
>
> Here's a comparison between FP32/BF16 optimizer states on Megatron-LM: [image: FP32 vs. BF16 optimizer-state comparison]

I also don't think that doing pure bf16 training makes sense, even for SFT. If the goal is to reduce the memory footprint of the optimizer, I think adam8bit is a better tradeoff at low GPU counts, and with many GPUs FSDP should make the optimizer state quite small on each GPU.
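For illustration, the kind of drop-in that comment has in mind (bitsandbytes is a separate package, not something torchtitan ships; treat this purely as a sketch of the tradeoff):

import torch
import bitsandbytes as bnb  # external dependency, shown only for illustration

model = torch.nn.Linear(4096, 4096, dtype=torch.bfloat16, device="cuda")

# 8-bit Adam keeps the optimizer moments in block-quantized 8-bit buffers and
# dequantizes them for the update, instead of storing full bf16/fp32 state.
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)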

@tianyu-l
Contributor

@joecummings sounds like we should revert this PR, as doing bf16 everywhere does not seem to be the right way to save memory. Wdyt?

@ebsmothers
Contributor Author

Sorry just getting caught up here. My two cents: pure bf16 should not preclude using optimizers like 8-bit Adam. In my mind it is still generally useful (and fairly standard: see e.g. Lightning’s true bf16 precision setting) to store model weights and gradients in bfloat16 without an extra higher-precision copy. Is 8-bit Adam currently supported by Titan? My impression was that it isn’t, but lmk if that’s mistaken.

tianyu-l pushed a commit that referenced this pull request Sep 30, 2025
Unfortunately I went out on leave after opening #1646, so I never actually finished it off by enabling bf16 training in the forge experiment, which is what we ultimately wanted (thanks to @joecummings and @tianyu-l for pushing it through).

I also see there was some discussion on the original PR, which I belatedly responded to. If there are still concerns there, let me know. Otherwise, if we are not gonna revert that PR, we should at least land that one so that forge can reap the benefits as intended.