[Full DTensor] Add full_dtensor flag #2002
Conversation
This refactors the PP splitting logic to consolidate around specifying FQNs for each model chunk. For example:

```
[
    ['tok_embeddings', 'layers.0'],  # stage0
    ['layers.1', 'layers.2'],        # stage1
    ['layers.3', 'layers.4'],        # stage2
    ...                              # and so on
]
```

This is better because it applies generally to all models, and the code can be reused for cases that don't explicitly require pipelined execution (for example, streaming diloco needs to communicate model chunks). A minimal sketch of this FQN-based splitting is shown below.

Changes:
- Refactor deepseekv3 and llama to share the same pipeline util functions
- Add `module_names_per_model_chunk` config, deprecate `pipeline_parallel_split_points`

TODO (follow-up PRs):
- `pipeline_module_split` will be upstreamed to PyTorch as a `torch.distributed.pipelining` utility since it contains no model-specific code.
- Additional changes are needed to make this work for torchft streaming diloco, including updating the training loop to not execute if the pipeline schedule isn't set, and making sure the pipelining_fn returns the correct model chunks.

cc @tushar00jain
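A minimal sketch of what FQN-list-based splitting can look like, assuming a torchtitan-style model whose transformer blocks live in a `layers` ModuleDict; `build_stage` is an illustrative helper, not the `pipeline_module_split` utility itself:

```python
import copy

import torch.nn as nn


def build_stage(model: nn.Module, stage_fqns: list[str]) -> nn.Module:
    """Keep only the top-level modules assigned to this stage; drop the rest."""
    stage = copy.deepcopy(model)
    # Drop transformer layers this stage does not own (assumes a ModuleDict named `layers`).
    for idx in list(stage.layers.keys()):
        if f"layers.{idx}" not in stage_fqns:
            del stage.layers[idx]
    # Drop embeddings / final norm / output head when they are not listed.
    for name in ("tok_embeddings", "norm", "output"):
        if name not in stage_fqns and getattr(stage, name, None) is not None:
            setattr(stage, name, None)
    return stage


# Usage with the chunk spec above:
# chunks = [build_stage(model, fqns) for fqns in module_names_per_model_chunk]
```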
Based on the discussion in #1495, the conclusion was to ensure that 16B uses the proper tokenizer to avoid the cudaAssertError in the current config, which comes from a mismatch between the embedding size and the tokenizer vocab. Thus, this PR:
1. adds an additional line to the readme so users can pull the 16B-chat tokenizer,
2. updates the 16B toml config to point to the 16B tokenizer under `/assets/tokenizer/deepseek-moe-16b-chat`.

With that, the vocab size of 102400 already in the toml now works flawlessly.

**Testing:**
- run download tokenizer
- run 20 iters with 16B without issue

<img width="1255" height="201" alt="Screenshot 2025-07-30 at 12 46 38 PM" src="https://github.com/user-attachments/assets/e33556bf-51c6-4fa0-ab71-d1b02ef74d99" />
Summary: remove some stale code that determines parameters to pass to the outer optimizer.

Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1501).
* #1446
* #1502
* __->__ #1501
Summary: the leaf folder wasn't being created and no profiles were being written, so create it if it doesn't exist.

Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1502).
* #1446
* __->__ #1502
* #1501
With the recent API change to the pipeline schedule (pytorch/pytorch#157795), we can now schedule the forward pass and calculate loss, allowing us to use validation and PP together.

To test correctness, we train from a seed checkpoint with training.seed and training.determinism set, with varying degrees of parallelism and different pipeline schedules, and compare whether the loss remains the same:

| Parallelism | Loss |
| --- | --- |
| FSDP=2 | <img width="960" height="328" alt="Screenshot 2025-07-29 at 5 12 49 PM" src="https://github.com/user-attachments/assets/3aedc87d-f12c-409c-88da-86b0ac72a1a7" /> |
| FSDP=2, TP=2, PP=2, PP_schedule="1F1B" | <img width="964" height="334" alt="Screenshot 2025-07-29 at 5 17 18 PM" src="https://github.com/user-attachments/assets/b5f8979b-0f44-48fc-aa4d-38e938c5cf43" /> |
| FSDP=2, PP=4, PP_schedule="1F1B" | <img width="973" height="335" alt="Screenshot 2025-07-29 at 5 15 53 PM" src="https://github.com/user-attachments/assets/29636394-b602-4a21-995d-94769771f599" /> |
| FSDP=2, PP=4, PP_schedule="Interleaved1F1B" | <img width="964" height="329" alt="Screenshot 2025-07-29 at 5 39 39 PM" src="https://github.com/user-attachments/assets/de960111-d0ad-4470-a096-493d7f59461e" /> |
| FSDP=2, PP=4, PP_schedule="GPipe" | <img width="971" height="329" alt="Screenshot 2025-07-29 at 5 49 36 PM" src="https://github.com/user-attachments/assets/2100b2a2-2725-43c8-a937-78fb05962247" /> |
| FSDP=2, PP=4, PP_schedule="LoopedBFS" | <img width="963" height="330" alt="Screenshot 2025-07-29 at 5 54 55 PM" src="https://github.com/user-attachments/assets/102df0f7-bd4f-47a6-a94a-a1bf488237ce" /> |
| FSDP=2, PP=4, PP_schedule="InterleavedZeroBubble" | <img width="960" height="343" alt="Screenshot 2025-07-30 at 2 30 53 PM" src="https://github.com/user-attachments/assets/1d2bce1a-0b8c-4d09-85b8-0a0634f68690" /> |
# Fix incorrect data loading time measurement

This PR fixes the timing of the `data_loading_times` measurement in `batch_generator`. Previously, the timer started after calling `next(data_iterator)`, which excluded the actual data fetching time from the measurement. Now, the timer starts before the `next()` call to correctly capture the full DataLoader latency.
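A minimal sketch of the corrected measurement (names are illustrative, not torchtitan's exact `batch_generator`): the timer now wraps the `next()` call, so the DataLoader fetch is included in `data_loading_times`.

```python
import time
from typing import Iterator


def batch_generator(data_iterator: Iterator, data_loading_times: list[float]):
    """Yield batches while recording how long each fetch from the iterator takes."""
    while True:
        start = time.perf_counter()
        try:
            batch = next(data_iterator)  # the actual fetch is now inside the timed region
        except StopIteration:
            return
        data_loading_times.append(time.perf_counter() - start)
        yield batch
```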
…ts in training scripts (#1473)

The goal of this PR is to add support for command line arguments in the bash training scripts. `run_train.sh` had support for `overrides`, but the `multinode_trainer.slurm` script did not. The `overrides` flag adds support for commands like:

`sbatch multinode_trainer.slurm --job.description="TEST_RUN"`

However, there is a problem with the current `overrides` implementation: when passing arguments containing a space, such as `"TEST RUN"` instead of `"TEST_RUN"`, the variable `job.description` would only get `TEST` as input, and the training script throws an error for the unrecognized argument `RUN`, which is passed separately. To address this, I simplified the code and directly pass the additional overrides through `$@`. This solves the issue for commands such as:

`sbatch multinode_trainer.slurm --job.description="TEST RUN"`
This PR adds learning rate logging. There was a previous attempt to implement this in an [earlier PR](#937), but that one was ultimately **closed**. This version ensures that LR logging works properly; I verified it using the WSD scheduler that was recently added in [another PR](#938).

<img width="1842" height="730" alt="image" src="https://github.com/user-attachments/assets/8f23674a-d689-4cc2-9d9b-30bff4e63f3b" />

One design consideration here is that torchtitan supports multiple optimizers and learning rate schedules, each potentially having its own LR. However, in practice, I believe that 99.9999% of use cases will use a single LR. Given that, the logging works as follows:
- If there is only one learning rate, it gets logged directly under the main charts as `lr`.
- If there are multiple learning rates, they are logged under a separate section, each with its corresponding label.

Alternatively, we could have ignored the multi-LR case and always logged a single LR, but I prefer this approach since it handles both scenarios robustly with minimal extra code. Happy to adjust if others have a strong preference for simplicity over robustness.
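A sketch of the logging rule described above (the scheduler container and metric keys are illustrative): a single LR is reported under the main charts as `lr`, while multiple LRs get their own labeled entries.

```python
def lr_metrics(lr_schedulers) -> dict[str, float]:
    """Collect the current learning rate(s) from a list of LR schedulers."""
    lrs = [scheduler.get_last_lr()[0] for scheduler in lr_schedulers]
    if len(lrs) == 1:
        return {"lr": lrs[0]}  # the common single-LR case goes to the main charts
    return {f"lr/{i}": lr for i, lr in enumerate(lrs)}  # one labeled entry per scheduler
```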
## Summary
- For mxfp8, token group sizes must be multiples of "block_size" because in the backward pass for `grad_weight = grad_output_t @ input`, the "M" (token) dimension is the contracting dimension, and each token group is a logically distinct subtensor, so we scale them separately. This means each token group's contracting dimension must be divisible by the mxfp8 block_size (default 32). Here is a diagram showing the problem: https://www.internalfb.com/excalidraw/EX521879
- To solve this, this PR makes the token group M alignment configurable.

## Test plan
- Integration test with torchao passes: pytorch/ao#2642
- Did a manual test run with the llama4 debug model using bf16
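A sketch of the alignment rule (not the torchao kernel itself; the helper name is illustrative): each expert's token-group size along M is rounded up to a multiple of the configurable alignment, which defaults to the mxfp8 block size of 32.

```python
import torch


def align_token_group_sizes(tokens_per_expert: torch.Tensor, alignment: int = 32) -> torch.Tensor:
    """Round each token-group size up to a multiple of `alignment` (ceil division)."""
    return ((tokens_per_expert + alignment - 1) // alignment) * alignment


# e.g. tensor([ 5, 32, 33]) with alignment=32 -> tensor([32, 32, 64])
```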
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #1504

**Summary**

## ~~Change tokenizer size~~

This is resolved by downloading the right tokenizer.

Before the change:
```
  File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/xilunwu/pytorch/torch/nn/modules/normalization.py", line 414, in forward
    return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/xilunwu/pytorch/torch/nn/functional.py", line 2924, in rms_norm
    return torch.rms_norm(input, normalized_shape, weight, eps)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```

Adding CUDA_LAUNCH_BLOCKING=1 to the launch command shows the real error is in the embedding. After fixing the tokenizer size the training works fine.

## Add `.contiguous()` to output after calling transpose()

Command: `NGPU=8 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.context-parallel-degree 2`

Error:
```
[rank0]:[rank0]:   File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:[rank0]:     return forward_call(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/xilunwu/oss/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 237, in forward
[rank0]:[rank0]:     output = output.view(bsz, seqlen, -1)  # (bsz, seqlen, n_heads * v_head_dim)
[rank0]:[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]: RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```

The model code didn't match llama3's. After adding `.contiguous()` it runs correctly:

```
NGPU=8 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.context-parallel-degree 2
+ NGPU=8
+ export LOG_RANK=0
+ LOG_RANK=0
+ CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+ overrides=
+ '[' 2 -ne 0 ']'
+ overrides='--parallelism.context-parallel-degree 2'
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml --parallelism.context-parallel-degree 2
W0731 11:31:18.279000 3852560 torch/distributed/run.py:803]
W0731 11:31:18.279000 3852560 torch/distributed/run.py:803] *****************************************
W0731 11:31:18.279000 3852560 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0731 11:31:18.279000 3852560 torch/distributed/run.py:803] *****************************************
[rank0]:[titan] 2025-07-31 11:31:25,671 - root - INFO - Starting job: DeepSeek-V3 16B model training
[rank0]:[titan] 2025-07-31 11:31:27,890 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-07-31 11:31:27,891 - root - INFO - Building 2-D device mesh with ['dp_shard', 'cp'], [4, 2]
[rank0]:[titan] 2025-07-31 11:31:27,897 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-07-31 11:31:32,956 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-07-31 11:31:33,170 - root - INFO - Preparing c4 dataset from allenai/c4
[rank0]:[titan] 2025-07-31 11:31:38,681 - root - INFO - Building deepseek_v3 16B with DeepSeekV3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', max_batch_size=8, max_seq_len=4096, dtype='bf16', vocab_size=129280, dim=2048, inter_dim=10944, moe_inter_dim=1408, n_layers=27, n_dense_layers=1, n_heads=16, norm_eps=1e-05, n_routed_experts=64, n_shared_experts=2, n_activated_experts=6, n_expert_groups=1, n_limited_groups=1, score_func='softmax', route_scale=1.0, use_grouped_mm=True, load_balance_coeff=0.001, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, use_flex_attn=False, attn_mask_type='causal', original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.7)
[rank0]:[titan] 2025-07-31 11:31:38,855 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory
[rank0]:[titan] 2025-07-31 11:31:38,929 - root - INFO - Total parameter count: dense 968,486,400, sparse 14,848,098,304, active 2,771,250,688
[rank0]:[titan] 2025-07-31 11:31:38,929 - root - INFO - Model deepseek_v3 16B size: 15,816,584,704 total parameters
[rank0]:[titan] 2025-07-31 11:31:38,930 - root - INFO - Applied full activation checkpointing to the model
[rank0]:[titan] 2025-07-31 11:31:39,021 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-07-31 11:31:39,021 - root - INFO - Applied Context Parallel to the model
[rank0]:[titan] 2025-07-31 11:31:39,398 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-07-31 11:31:39,399 - root - INFO - CUDA memory usage for model: 8.84GiB(9.30%)
[rank0]:[titan] 2025-07-31 11:31:39,400 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-07-31 11:31:39,400 - root - INFO - Trainer is initialized with local batch size 8, global batch size 32, gradient accumulation steps 1, sequence length 4096, total steps 1000 (warmup 200)
[rank0]:[titan] 2025-07-31 11:31:39,400 - root - INFO - Training starts at step 1
[rank0]:[titan] 2025-07-31 11:31:49,242 - root - INFO - step: 1 loss: 12.2584 grad_norm: 1.2466 memory: 53.49GiB(56.30%) tps: 1,589 tflops: 28.21 mfu: 2.85%
[rank0]:[titan] 2025-07-31 11:31:49,242 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-31 11:32:13,707 - root - INFO - step: 10 loss: 11.5358 grad_norm: 1.4495 memory: 71.08GiB(74.82%) tps: 6,027 tflops: 107.02 mfu: 10.82%
[rank0]:[titan] 2025-07-31 11:32:40,848 - root - INFO - step: 20 loss: 10.0093 grad_norm: 7.7745 memory: 71.08GiB(74.82%) tps: 6,037 tflops: 107.20 mfu: 10.84%
```
As titled: fixes a small typo in importing the deepseekv3 model functions.
Summary: Instead of maintaining a mapping in torchtitan with valid mx recipe names, just pass the recipe string directly to torchao. This way torchao can iterate on recipes without any changes to torchtitan being required to use them. Note that appropriate error messages will be thrown from torchao if the user specifies an invalid config name, so there is no need to duplicate them in torchtitan.

Test Plan:
```bash
with-proxy CONFIG_FILE="torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 50 --model.converters mx --mx.recipe_name "mxfp8_cublas_rceil"
```
Currently, `ntokens_seen` is only locally logged. I think it is almost always desirable to only track the global quantity (the only use case I can see for per-device tracking is for debugging?). Therefore, I propose to all-reduce `ntokens_seen` before logging.
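A sketch of the proposed change (assumes `torch.distributed` is already initialized; names are illustrative): sum the local counters across ranks before handing the value to the metrics logger.

```python
import torch
import torch.distributed as dist


def global_ntokens_seen(local_ntokens_seen: int, device: torch.device) -> int:
    """All-reduce the per-rank token counter so the logged value is global."""
    counter = torch.tensor(local_ntokens_seen, dtype=torch.int64, device=device)
    dist.all_reduce(counter, op=dist.ReduceOp.SUM)
    return int(counter.item())
```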
Currently, the first time validation metrics are computed is when `step == job_config.validation.freq`. I think it is preferable to always compute them for the first step as well.
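A one-line sketch of the adjusted trigger (attribute names are illustrative):

```python
def should_validate(step: int, freq: int) -> bool:
    # Validate on the very first step as well as on every `freq`-th step.
    return step == 1 or step % freq == 0
```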
…ed mm to have col-major memory layout (#1517)

# Summary

Rather than storing expert weights pre-transposed (E, in_dim, out_dim), we should store them non-transposed (E, out_dim, in_dim) and transpose before the grouped gemm, for (1) compatible dims for the gemm, and (2) the column-major memory layout required for the right operand in the grouped gemm. Doing this simple transpose (a metadata-only change) is much more efficient than doing this [inefficient memory layout transformation before every GEMM in fp8](https://github.com/pytorch/ao/blob/6e941c87c4d9fb9a74e6f979dd522605c696ca42/torchao/prototype/moe_training/scaled_grouped_mm.py#L96). A shape-only sketch follows the benchmarks below.

# Eager Performance

Llama4 debug model with FSDP=8, using config:

```python
"debugmodel": TransformerModelArgs(
    dim=5120,
    n_layers=4,
    n_heads=40,
    n_kv_heads=8,
    ffn_dim_multiplier=1.2,
    multiple_of=2048,
    rope_theta=500000,
    max_seq_len=10485760,
    num_experts=16,
    interleave_moe_layer_step=1,
),
```

### bfloat16

With change:
```
=====================================================
Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 2147.0
Max Memory Usage: 92.67 GiB
```

Without change:
```
=====================================================
Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 1711.0
Max Memory Usage: 92.67 GiB
```

### fp8 rowwise

With change:
```
(torchtitan) [[email protected] ~/ao/benchmarks/float8/training (metdata)]$ TORCHTITAN_ROOT=/home/danvm/torchtitan NGPU=8 EXTRA_ARGS="--model.converters="float8" --float8.recipe_name="rowwise" --float8.filter_fqns="output,auto_filter_small_kn" --float8.moe_fqns_prototype="experts"" ./llama4.sh
=====================================================
Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 2675.0
Max Memory Usage: 90.35 GiB
```

Without change:
```
(torchtitan) [[email protected] ~/ao/benchmarks/float8/training (metdata)]$ TORCHTITAN_ROOT=/home/danvm/torchtitan NGPU=8 EXTRA_ARGS="--model.converters="float8" --float8.recipe_name="rowwise" --float8.filter_fqns="output,auto_filter_small_kn" --float8.moe_fqns_prototype="experts"" ./llama4.sh
=====================================================
Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 2360.0
Max Memory Usage: 90.35 GiB
```
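A shape-only sketch of the layout argument (not the real MoE module): storing the weights as (E, out_dim, in_dim) and transposing just before the grouped GEMM is a pure metadata change that leaves each expert matrix column-major.

```python
import torch

E, in_dim, out_dim = 4, 256, 512
w = torch.randn(E, out_dim, in_dim)    # stored non-transposed
w_t = w.transpose(-2, -1)              # logical shape (E, in_dim, out_dim) for the GEMM

assert w_t.shape == (E, in_dim, out_dim)
assert w_t.stride(-2) == 1             # each expert matrix is column-major in memory
assert w_t.data_ptr() == w.data_ptr()  # no copy: the transpose only changes metadata
```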
…ers (#1519) We should add an `apply_compile` function for llama4 that uses fullgraph=False for MoE layers and fullgraph=True for dense layers. I keep manually applying this hack during development to test compile composability, but IMO we should have this merged and update to use fullgraph=True everywhere once that is supported. cc @xmfan @tianyu-l any thoughts?
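A sketch of what such an `apply_compile` could look like (module layout assumed to match torchtitan's llama4, where MoE blocks carry a `moe` submodule; this is not the merged implementation):

```python
import torch
import torch.nn as nn


def apply_compile(model: nn.Module) -> None:
    """Compile each transformer block, relaxing fullgraph for MoE blocks only."""
    for layer_id, block in model.layers.named_children():
        fullgraph = not hasattr(block, "moe")  # dense blocks keep fullgraph=True
        compiled_block = torch.compile(block, fullgraph=fullgraph)
        model.layers.register_module(layer_id, compiled_block)
```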
This PR updates the config to use the base tokenizer rather than chat (they are the same, just named differently) and makes it clear we are not loading the model weights for 16B. Testing: download via script; run 20 iters with the 16b_base tokenizer.
As titled, a quick follow-up for #1499.
This PR implements the validator class for Flux, following the method discussed in the Stable Diffusion 3 paper. The paper shows that creating 8 equidistant timesteps and calculating the average loss on them results in a loss that is highly correlated with external validation metrics such as CLIP or FID score. Rather than creating 8 stratified timesteps per sample, this PR's implementation applies one of these equidistant timesteps to each sample in a round-robin fashion (see the sketch after the comparison table below). Aggregated over many samples in a validation set, this should give a similar validation score to the full-timestep method, while processing validation samples more quickly.

### Implementations
- Integrates the image generation evaluation in the validation step, users can …
- Refactors and combines the eval job_config with validation
- Adds an `all_timesteps` option to the job_config to choose whether to use round-robin timesteps or full timesteps per sample
- Creates a validator class and validation dataloader for Flux; the validator dataloader handles generating timesteps for the round-robin method of validation

### Enabling all timesteps

Developers can enable the full-timestep method of validation by setting `all_timesteps = True` in the Flux validation job config. Enabling all_timesteps may require tweaking some hyperparams (`validation.local_batch_size`, `validation.steps`) to prevent spiking memory and to optimize throughput. Using a ratio of around 1/4 for `validation.local_batch_size` to `training.local_batch_size` keeps memory from spiking above training when `fsdp = 8`.

Below we can see the difference between round robin and all timesteps. In the comparison the total number of validation samples processed is the same, but in the `all_timesteps=True` configuration we have to lower the batch size to prevent memory spiking. All timesteps also achieves a higher throughput (tps) but still processes the full validation set more slowly.

| Round Robin (batch_size=32, steps=1, fsdp=8) | All Timesteps (batch_size=8, steps=4, fsdp=8) |
| ---- | --- |
| <img width="682" height="303" alt="Screenshot 2025-08-01 at 3 46 42 PM" src="https://github.com/user-attachments/assets/30328bfe-4c3c-4912-a329-2b94c834b67b" /> | <img width="719" height="308" alt="Screenshot 2025-08-01 at 3 30 10 PM" src="https://github.com/user-attachments/assets/c7325d21-8a7b-41d9-a0d2-74052e425083" /> |
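A minimal sketch of the round-robin assignment (not the actual Flux validator; the choice of equidistant points is illustrative): each validation sample gets one of `n_bins` equidistant timesteps, cycling through the bins as the global sample index advances.

```python
import torch


def round_robin_timesteps(global_sample_idx: torch.Tensor, n_bins: int = 8) -> torch.Tensor:
    """Map sample indices to one of `n_bins` equidistant timesteps in (0, 1)."""
    bins = torch.arange(1, n_bins + 1, dtype=torch.float32) / (n_bins + 1)
    return bins[global_sample_idx % n_bins]


# The first 10 samples cycle through the 8 timestep bins:
print(round_robin_timesteps(torch.arange(10)))
```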
Summary:
- add a configuration option for users to specify how they want to partition the model
- if this is provided, the model needs to implement `FaultTolerantTrainingSpec`, which defines the fragmentation function that splits the model based on the configuration
- determine the model fragments in the training script to pass to the ft manager

Test Plan: Running llama3 8b with 2 fragments, 1 step delay, each fragment gets synced every 20 steps.

<img width="944" height="545" alt="image" src="https://github.com/user-attachments/assets/6d16f486-7260-49d6-8ba3-3e98cd331e58" />

Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1446).
* #1516
* __->__ #1446
Fix typo in checkpoint.md
After changing the setup for the JobConfig and ConfigManager, some files still had the old structure for paths and directories, which affected the import paths. This PR fixes those paths and directories. Mainly, these changes replace `from torchtitan.config_manager import ...` with `from torchtitan.config import ...`.

Co-authored-by: Ali Sol <[email protected]>
Given the complexity of the MoE and EP modules, this PR:
1. creates `torchtitan/models/moe.py` as the central MoE implementation (similar to why we have `torchtitan/models/attention.py`)
2. creates `torchtitan/distributed/expert_parallel.py` as the central EP implementation
3. renames `torchtitan/distributed/pipeline.py` -> `torchtitan/distributed/pipeline_parallel.py` to be consistent with EP
4. applies the temporary fix by @rakkit in #1467 until the memory leak issue with AC + PT-D all_to_all_single_autograd is fixed (cc @soulitzer)
Summary:
- print statements weren't getting printed; the statements do show up when using the logger
- also added some logging to validate that the model is being split for diloco
The current EP grad clipping logic assumes that when using EP all of the norms returned by `torch.nn.utils.get_total_norm` are `DTensor`s. This assumption can be violated, and the subsequent `full_tensor` call can correspondingly fail, in the edge case where the [ep_grad list](https://github.com/pytorch/torchtitan/blob/a1fdd7e43694bbfeff5d6ad8ac738c067bb90d41/torchtitan/distributed/utils.py?plain=1#L408) is empty, in which case `get_total_norm` returns `tensor(0.)`, a non-`DTensor`.

https://github.com/pytorch/torchtitan/blob/a1fdd7e43694bbfeff5d6ad8ac738c067bb90d41/torchtitan/distributed/utils.py?plain=1#L421-L423

```
File "/app/torchtitan/torchtitan/distributed/utils.py", line 423, in _clip_grad_norm_with_ep
    ).full_tensor()
      ^^^^^^^^^^^
AttributeError: 'Tensor' object has no attribute 'full_tensor'
```

This edge case can occur in PP+EP setups when the model uses some fully dense and some MoE layers (like DSv3), in which case some PP ranks may not be assigned any MoE layers. I suppose it is possible that `non_ep_grads` could also be empty, but I can only imagine this happening in extreme cases, so I did not change the `non_ep_grads` code.

CC @tianyu-l
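A sketch of the guard this implies (not the exact `utils.py` patch): only call `full_tensor()` when `get_total_norm` actually returned a `DTensor`, which it will not when `ep_grads` is empty on a PP rank that owns no MoE layers.

```python
import torch
from torch.distributed.tensor import DTensor


def norm_to_full_tensor(total_norm: torch.Tensor) -> torch.Tensor:
    """Materialize a DTensor norm; pass plain tensors (e.g. tensor(0.)) through."""
    if isinstance(total_norm, DTensor):
        return total_norm.full_tensor()
    return total_norm
```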
If validation and checkpoint occur on the same training step, do checkpointing first.
Issue pointed out in #1534 (comment) and pytorch/pytorch#160285; solution given by @rakkit in #1534 (comment).
This experiment provides a complete framework for bitwise-deterministic reinforcement learning training that combines vLLM for fast rollouts and TorchTitan for training with gradients.

Key features:
- Bitwise-deterministic forward and backward passes
- vLLM-compatible Qwen3 model with merged projections
- Custom Flash Attention with gradient support
- Gradient support for vLLM's batch-invariant operations
- Complete RL training loop with GRPO-style advantages
- Comprehensive test suite for determinism verification

Components:
- models/attention.py: VLLMCompatibleFlashAttention
- models/qwen3/model_vllm_compat.py: vLLM-compatible Qwen3 model
- batch_invariant_backward.py: Gradient support for vLLM operations
- simple_rl.py: End-to-end RL training loop
- tests/: Test suite for backward passes and determinism

Co-authored-by: Teja <[email protected]>
Co-authored-by: Chien-Chin Huang <[email protected]>
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):
* #2002
* #2001
* #1995
* __->__ #1985

We are adding more actions to convert the raw inputs and label.
1. The new CP can do the input/label/BlockMask sharding in this method.
2. The experimental full dtensor model can simply override this method without changing too much Trainer code.

This method is extracted from #1857. Making this a standalone PR allows us to continue the two projects above without one blocking the other.
When full_dtensor is True, the compute_placement will be preserved. This means that `to_local()` won't be called in the FSDP-only case. The nD parallelism case (FSDP + TP) will error out, as we have not implemented it yet.

This argument doesn't affect the current simple_fsdp. We have verified the `full_dtensor=True` case with the full dtensor skeleton PR, which will be published once it is ready.

ghstack-source-id: 9f9efce
Pull-Request: #2002
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):
* #2002
* #2001
* __->__ #1995

People are creating different train.py files and duplicating the `main` function, but in reality they just want to use different Trainer subclasses. This PR creates a main() in torchtitan/train.py to deduplicate the code.
I thought the full dtensor version is in the compiler toolkit, isn't it? cc @SherlockNoMad @yiming0416
Do we have a plan to migrate full dtensor to the simplefsdp folder?
@ruisizhang123 Not really, DTensorizing inputs in the compiler toolkit only applies to the … Also, currently we directly import the …
```python
mp_policy: MixedPrecisionPolicy | None,
reshard_after_forward: bool,
reduction_divide_factor: float | None,
full_dtensor: bool = False,
```
Should we rename `full_dtensor` to something more explicit (e.g., `is_input_dtensor`)?
`is_input_dtensor` sounds like we are using full dtensor because the input is a DTensor, but the idea should be the other way around -- we are using full dtensor, so the input should be a full DTensor, as well as the params. So I think `full_dtensor` or `use_full_dtensor` is OK for now. Eventually I think we should deprecate the non-full-dtensor paths.
This allows people to customize the distributed environment, including ParallelDims and distributed backend.
We should be able to control which passes run in the compiler. This PR uses the config `compile.passes` to specify a list of graph passes to apply on the captured gm. By default, no pass is applied. Users can specify which passes to apply; currently there are `autobucketing_reordering_pass` and `regional_inductor_pass`.

```
NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor
```

Also updated CI to include this new config.
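A sketch of the dispatch this config implies (the registry contents are placeholders for the real autobucketing_reordering and regional_inductor passes): map the configured names to graph-pass callables and apply them in order to the captured graph module.

```python
from typing import Any, Callable


def apply_compiler_passes(
    gm: Any,
    pass_names: list[str],
    registry: dict[str, Callable[[Any], Any]],
) -> Any:
    """Apply the configured graph passes in order; unknown names fail loudly."""
    for name in pass_names:
        if name not in registry:
            raise ValueError(f"unknown compiler pass {name!r}; available: {sorted(registry)}")
        gm = registry[name](gm)
    return gm


# e.g. registry = {"autobucketing_reordering": autobucketing_reordering_pass,
#                  "regional_inductor": regional_inductor_pass}
```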
After some offline discussion, we've concluded that life would be easier if we put simplefsdp's checkpoint logic for `reshard_after_forward` into the compiler. The AC annotation part is borrowed from AP: [LINK](https://github.com/meta-pytorch/autoparallel/blob/main/autoparallel/activation_checkpointing.py#L69).

**Trace and Loss Check** (all with torch.compile enabled)

reshard_after_fwd = False
1. SAC + llama3 ([trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-05-06_rank0_trace.json))
<img width="768" height="115" alt="Screenshot 2025-10-30 at 4 28 59 PM" src="https://github.com/user-attachments/assets/e4e22335-2e3f-46c8-8def-a60d592fee0a" />
<img width="689" height="512" alt="Screenshot 2025-11-05 at 9 02 30 PM" src="https://github.com/user-attachments/assets/40a71316-a457-4e72-9002-cc8beea8f32c" />
2. Full AC + llama3 [(trace)]()
<img width="729" height="105" alt="Screenshot 2025-10-30 at 4 30 53 PM" src="https://github.com/user-attachments/assets/e8d63460-579b-4f0a-8504-851480e5b548" />
<img width="789" height="763" alt="Screenshot 2025-11-05 at 9 11 34 PM" src="https://github.com/user-attachments/assets/1a13d09e-04c4-4db9-99fe-cf10d24bf7f5" />
3. No AC + llama3 ([trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-03-50_rank0_trace.json))
<img width="748" height="115" alt="Screenshot 2025-10-30 at 4 32 05 PM" src="https://github.com/user-attachments/assets/20104d24-9d45-4eba-b694-815e133b88d0" />
<img width="800" height="764" alt="Screenshot 2025-11-05 at 9 07 46 PM" src="https://github.com/user-attachments/assets/55b104ce-8ec1-4ed6-95e7-300e96ad55af" />

reshard_after_fwd = True
1. SAC + llama3 ([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-34-24_rank0_trace.json))
<img width="795" height="108" alt="Screenshot 2025-10-31 at 11 34 47 AM" src="https://github.com/user-attachments/assets/a3988f72-7e87-4e52-90f9-8bee840cd6f4" />
2. Full AC + llama3 ([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-31-11-36-27_rank0_trace.json))
<img width="593" height="110" alt="Screenshot 2025-10-31 at 11 38 02 AM" src="https://github.com/user-attachments/assets/5ee61b2b-9600-4af8-9a24-61b3564f93ca" />
3. No AC + llama3 ([Trace](https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/ruisizhang123_2025-10-30-17-02-44_rank0_trace.json))
<img width="701" height="109" alt="Screenshot 2025-10-31 at 11 43 04 AM" src="https://github.com/user-attachments/assets/576b28f6-dae4-4ff7-b005-57b0cf9ad7cc" />
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #2002
* __->__ #2001

Add typing, credit to Claude.
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):
* __->__ #2013

When full_dtensor is True, the compute_placement will be preserved. This means that `to_local()` won't be called in the FSDP-only case. The nD parallelism case (FSDP + TP) will error out, as we have not implemented it yet.

This argument doesn't affect the current simple_fsdp. We have verified the `full_dtensor=True` case with the full dtensor skeleton PR, which will be published once it is ready.

**This is a reland PR of #2002. The previous one was broken during rebase.**
Stack from ghstack (oldest at bottom):

When full_dtensor is True, the compute_placement will be preserved. This means that `to_local()` won't be called in the FSDP-only case. The nD parallelism case (FSDP + TP) will error out, as we have not implemented it yet.

This argument doesn't affect the current simple_fsdp. We have verified the `full_dtensor=True` case with the full dtensor skeleton PR, which will be published once it is ready.
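A minimal sketch of the gating this flag introduces, assuming a simple_fsdp-style unshard step (the function and placement names are illustrative, not simple_fsdp's actual parametrization code): with `full_dtensor=True` the unsharded parameter stays a DTensor, so compute placements are preserved; otherwise it is converted to a plain local tensor as before.

```python
import torch
from torch.distributed.tensor import DTensor


def unshard_for_compute(
    param: DTensor, compute_placements, full_dtensor: bool = False
) -> torch.Tensor:
    """Unshard a parameter for compute, optionally keeping it a DTensor end to end."""
    unsharded = param.redistribute(placements=compute_placements)
    if full_dtensor:
        return unsharded           # preserve compute placements on a full DTensor
    return unsharded.to_local()    # legacy path: drop to a plain local tensor
```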