Commit e99e16c
[RFC] Support full bf16 training (#1646)
This PR adds support for full bfloat16 training. In SFT it is pretty common to store everything in bfloat16 to save memory, with select tensors (logits, RoPE buffers, and activations) kept in higher precision to preserve numerical accuracy. Separately, I think having this supported more generally would be useful for faster iteration -- e.g. it lets me run Llama3 70B on a single node of H100s, which is otherwise not possible with the default config.

Assuming this is generally useful, I would like feedback on:

1) Acceptable loss convergence: in the first 100 steps on Llama3 8B, full bf16 training goes from 12.25 -> 8, as opposed to 12.25 -> 7 with fp32 training. Is this a concern? (As mentioned, for SFT this is less of an issue; happy to validate that statement if that's helpful.)
2) Interaction with mixed precision training -- where is the right place to validate that these are not both set at once?
3) Where to put the `set_default_dtype` API.
1 parent 40a8725 commit e99e16c
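For intuition, here is a minimal, self-contained sketch (plain PyTorch, not torchtitan code) of the pattern described above: parameters, activations, and gradients all live in bfloat16, and only the logits are upcast to fp32 before the loss.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy model held entirely in bfloat16 (a hypothetical stand-in for the real model).
model = nn.Linear(16, 8, dtype=torch.bfloat16)
x = torch.randn(4, 16, dtype=torch.bfloat16)
target = torch.randint(0, 8, (4,))

logits = model(x)                                # bf16 parameters and activations
loss = F.cross_entropy(logits.float(), target)   # upcast logits to fp32 for a numerically stable loss
loss.backward()                                  # gradients are stored in bf16, matching the parameters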

File tree

3 files changed (+41, -3 lines)

torchtitan/config/job_config.py

Lines changed: 7 additions & 0 deletions
@@ -201,6 +201,13 @@ class Training:
     Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP
     """
 
+    dtype: Literal["bfloat16", "float32"] = "float32"
+    """
+    torch dtype for training. In contrast to mixed precision training, setting training_dtype=bfloat16 will
+    put all parameters, gradients, and optimizer states in bfloat16, without an extra copy of fp32 weights.
+    In the case of full bf16 training, RoPE calculations and logits will still be in fp32.
+    """
+
     mixed_precision_param: Literal["bfloat16", "float32"] = "bfloat16"
     """
     torch dtype to use for parameters when applying mixed precision via fully_shard or torch.autocast.
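For context, enabling the new option from a job config would look roughly like the snippet below, assuming the usual mapping of the Training dataclass onto a [training] table in the job TOML file:

[training]
# Full bf16: parameters, gradients, and optimizer states all in bfloat16
# (see question 2 in the PR description about combining this with mixed precision).
dtype = "bfloat16"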

torchtitan/tools/utils.py

Lines changed: 29 additions & 1 deletion
@@ -4,11 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import contextlib
 import gc
 import subprocess
 import time
 from dataclasses import dataclass
-from typing import Optional
+from typing import Generator, Optional
 
 import torch
 from torch._utils import _get_available_device_type, _get_device_module
@@ -174,3 +175,30 @@ def check_if_feature_in_pytorch(
         f"{min_nightly_version}. Please upgrade a newer version to include the "
         f"change in ({pull_request}) for correct {feature_name}."
     )
+
+
+@contextlib.contextmanager
+def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
+    """
+    Context manager to set torch's default dtype.
+
+    Args:
+        dtype (torch.dtype): The desired default dtype inside the context manager.
+
+    Returns:
+        ContextManager: context manager for setting default dtype.
+
+    Example:
+        >>> with set_default_dtype(torch.bfloat16):
+        >>>     x = torch.tensor([1.0, 2.0, 3.0])
+        >>>     x.dtype
+        torch.bfloat16
+
+
+    """
+    old_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    try:
+        yield
+    finally:
+        torch.set_default_dtype(old_dtype)
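A quick usage sketch of the helper above, assuming torchtitan is installed so the function can be imported from torchtitan.tools.utils as added in this diff: tensors and parameters created inside the context pick up the new default dtype, and the previous default is restored on exit.

import torch
import torch.nn as nn

from torchtitan.tools.utils import set_default_dtype

# In a fresh process torch's default dtype is float32.
assert torch.get_default_dtype() == torch.float32

with set_default_dtype(torch.bfloat16):
    layer = nn.Linear(8, 8)                      # factory calls now follow the bf16 default
    assert layer.weight.dtype == torch.bfloat16

# The finally block restores the previous default on exit.
assert torch.get_default_dtype() == torch.float32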

torchtitan/train.py

Lines changed: 5 additions & 2 deletions
@@ -22,7 +22,7 @@
     build_metrics_processor,
     ensure_pp_loss_visible,
 )
-from torchtitan.config import ConfigManager, JobConfig
+from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims, utils as dist_utils
 from torchtitan.models.attention import init_attention_mask
 from torchtitan.protocols.model_converter import build_model_converters
@@ -154,7 +154,10 @@ def __init__(self, job_config: JobConfig):
         logger.info(
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
         )
-        with torch.device("meta"):
+        with (
+            torch.device("meta"),
+            utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]),
+        ):
             model = self.train_spec.model_cls(model_args)
 
         # Build the collection of model converters. No-op if `model.converters` empty
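To illustrate why the two context managers compose, here is a small sketch with plain torch modules (nn.Linear standing in for the real model class, and the helper imported from torchtitan.tools.utils): construction on the meta device allocates no memory but records bf16 as the parameter dtype, so later materialization keeps bf16.

import torch
import torch.nn as nn

from torchtitan.tools.utils import set_default_dtype

# Build on the meta device with a bf16 default, mirroring the diff above.
with torch.device("meta"), set_default_dtype(torch.bfloat16):
    model = nn.Linear(1024, 1024)

assert model.weight.is_meta                      # no parameter memory allocated yet
assert model.weight.dtype == torch.bfloat16      # but the dtype is already recorded as bf16

# Materializing later (here via to_empty) keeps the recorded bf16 dtype.
model = model.to_empty(device="cpu")
assert model.weight.dtype == torch.bfloat16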

0 comments