Commit 605f85f

Trainer + reference model in bf16 (#189)
1 parent 88ed1c2 commit 605f85f

2 files changed: 14 additions, 2 deletions

apps/grpo/qwen3_1_7b.yaml

Lines changed: 4 additions & 1 deletion
@@ -46,7 +46,8 @@ trainer:
     local_batch_size: ${batch_size}
     seq_len: 2048
     max_norm: 1.0
-    steps: 5
+    steps: 1000000
+    dtype: bfloat16
   compile:
     enable: false
   parallelism:
@@ -80,6 +81,8 @@ ref_model:
     name: qwen3
     flavor: 1.7B
     hf_assets_path: hf://${model}
+  training:
+    dtype: bfloat16
   compile:
     enable: false
   parallelism:
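
For a rough sense of what the `dtype` setting affects, the sketch below estimates the memory taken by the weights alone at 4 bytes per parameter (float32) versus 2 bytes (bfloat16). Treating the `1.7B` flavor as exactly 1.7e9 parameters is an assumption, and `weight_memory_gib` is a hypothetical helper, not part of this repository. Activations, gradients, and optimizer state are not modeled.

```python
# Bytes per element for the two dtypes being compared (standard sizes).
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2}

# Assumption: treat the "1.7B" flavor as exactly 1.7e9 parameters.
NUM_PARAMS = 1_700_000_000


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Return the memory needed for the weights alone, in GiB."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 2**30


for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: {weight_memory_gib(NUM_PARAMS, dtype):.2f} GiB")
```

Under these assumptions the bfloat16 weights take half the space of the float32 weights, and this applies to both the trainer and the reference model.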

src/forge/actors/reference_model.py

Lines changed: 10 additions & 1 deletion
@@ -15,7 +15,13 @@
 from monarch.actor import current_rank, current_size, endpoint
 from torch.distributed.tensor import DTensor
 
-from torchtitan.config.job_config import Checkpoint, Compile, Model, Parallelism
+from torchtitan.config.job_config import (
+    Checkpoint,
+    Compile,
+    Model,
+    Parallelism,
+    Training,
+)
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
@@ -29,6 +35,9 @@ class ReferenceModel(ForgeActor):
     parallelism: Parallelism = field(default_factory=Parallelism)
     checkpoint: Checkpoint = field(default_factory=Checkpoint)
     compile: Compile = field(default_factory=Compile)
+    training: Training = field(
+        default_factory=Training
+    )  # Only needed in order to correctly set a lower dtype
 
     # Populated in setup
     # TODO: Commented out since engine_config parsing extracts from class members
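
The new `training` field is what lets the `training: dtype: bfloat16` entry under `ref_model` in the YAML reach the actor. A minimal sketch of that pattern using only the standard library follows. The `Training` class here is a hypothetical stand-in for torchtitan's config section, its `float32` default is an assumption, and `ReferenceModelConfig` and `from_yaml_dict` are illustrative names rather than the project's actual loading code.

```python
from dataclasses import dataclass, field


@dataclass
class Training:
    """Hypothetical stand-in for torchtitan's Training config section."""

    # Assumption: full precision is the default unless overridden.
    dtype: str = "float32"


@dataclass
class ReferenceModelConfig:
    """Hypothetical, reduced version of the ReferenceModel fields."""

    # default_factory builds a fresh Training for every instance.
    training: Training = field(default_factory=Training)


def from_yaml_dict(cfg: dict) -> ReferenceModelConfig:
    """Build the config from a dict shaped like the `ref_model` YAML section."""
    if "training" in cfg:
        return ReferenceModelConfig(training=Training(**cfg["training"]))
    return ReferenceModelConfig()


# Without a `training` section, the default dtype applies.
default_cfg = from_yaml_dict({})
# With a section like the one added in this commit, the dtype is lowered.
bf16_cfg = from_yaml_dict({"training": {"dtype": "bfloat16"}})

print(default_cfg.training.dtype, bf16_cfg.training.dtype)
```

In this sketch, omitting the section falls back to the stand-in default, which matches the inline comment that the field exists only so a lower dtype can be set.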
