From f30587bfe55240a9547aed6528cc8e853f016f59 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 4 Sep 2025 22:44:58 -0700 Subject: [PATCH 1/6] benchmarking --- torchtitan/components/checkpoint.py | 27 ++++++++++++++++--- torchtitan/config/job_config.py | 3 +++ torchtitan/models/deepseek_v3/__init__.py | 2 +- .../train_configs/deepseek_v3_671b.toml | 6 +++-- 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index e9e7014425..f8c10c7d5a 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -40,6 +40,7 @@ from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.optimizer import OptimizersContainer from torchtitan.config import Checkpoint as CheckpointConfig, TORCH_DTYPE_MAP +from torchtitan.models.deepseek_v3.model.quantization import BLOCK_SIZE from torchtitan.protocols import BaseStateDictAdapter from torchtitan.tools.logging import logger from torchtitan.tools.utils import GarbageCollection @@ -247,6 +248,7 @@ def load_state_dict(state_dict): # Checkpoint policy related fields. self.initial_load_model_only = checkpoint_config.initial_load_model_only self.initial_load_in_hf = checkpoint_config.initial_load_in_hf + self.initial_load_dequantize = checkpoint_config.initial_load_dequantize self.initial_load_path = checkpoint_config.initial_load_path self.last_save_model_only = checkpoint_config.last_save_model_only self.last_save_in_hf = checkpoint_config.last_save_in_hf @@ -417,6 +419,7 @@ def dcp_load( state_dict: dict[str, Any], checkpoint_id: str, from_hf: bool, + dequantize: bool = False, ) -> None: """Load the checkpoint with dcp. Args: @@ -432,10 +435,25 @@ def dcp_load( ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided." hf_state_dict = self.sd_adapter.to_hf(state_dict) - dcp.load( - hf_state_dict, - storage_reader=HuggingFaceStorageReader(path=checkpoint_id), - ) + if not dequantize: + dcp.load( + hf_state_dict, + storage_reader=HuggingFaceStorageReader(path=checkpoint_id), + ) + else: + from torch.distributed.checkpoint.quantized_hf_storage import ( + QuantizedHuggingFaceStorageReader, + ) + BLOCK_SIZE = 128 # hardcode for deepseek 671b now + dcp.load( + hf_state_dict, + storage_reader=QuantizedHuggingFaceStorageReader( + path=checkpoint_id, + target_dtype=torch.float32, + block_size=BLOCK_SIZE, + thread_count=4, + ), + ) state_dict = self.sd_adapter.from_hf(hf_state_dict) self.states[MODEL].load_state_dict(state_dict) @@ -600,6 +618,7 @@ def load(self, step: int = -1) -> bool: states, checkpoint_id=checkpoint_id, from_hf=from_hf, + dequantize=self.initial_load_dequantize, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index e0189c9bb3..13fdb7d50e 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -443,6 +443,9 @@ class Checkpoint: non-tensors. The default value is False. 
""" + initial_load_dequantize: bool = False + + last_save_model_only: bool = True """ When last_save_model_only=True, only the model will be saved at the end of training, diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index d0478cc961..f467a2dec5 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -135,7 +135,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=61, + n_layers=4, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index a92a6f5a16..6c2efa1531 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -39,7 +39,6 @@ local_batch_size = 4 seq_len = 4096 max_norm = 1.0 # grad norm clipping steps = 10_000 -compile = false dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) [parallelism] @@ -54,12 +53,15 @@ expert_parallel_degree = 1 expert_tensor_parallel_degree = 1 [checkpoint] -enable = false +enable = true folder = "checkpoint" interval = 500 last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" +initial_load_in_hf = true +initial_load_dequantize = true +initial_load_path = "/data/users/jianiw/model/DeepSeek-V3.1-Base" [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] From 5c1d887f3e10d8fcdc7090cb5951401d94357fba Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 8 Sep 2025 16:51:02 -0700 Subject: [PATCH 2/6] test --- torchtitan/components/checkpoint.py | 6 ++++ torchtitan/models/deepseek_v3/__init__.py | 2 +- .../deepseek_v3/model/state_dict_adapter.py | 33 ++++++++++++++++--- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index f8c10c7d5a..22f0bda8c0 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -436,15 +436,20 @@ def dcp_load( hf_state_dict = self.sd_adapter.to_hf(state_dict) if not dequantize: + begin_load = time.monotonic() + logger.info("Starting dcp.load with HuggingFaceStorageReader") dcp.load( hf_state_dict, storage_reader=HuggingFaceStorageReader(path=checkpoint_id), ) + logger.info(f"dcp.load with HuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds") else: from torch.distributed.checkpoint.quantized_hf_storage import ( QuantizedHuggingFaceStorageReader, ) BLOCK_SIZE = 128 # hardcode for deepseek 671b now + begin_load = time.monotonic() + logger.info("Starting dcp.load with QuantizedHuggingFaceStorageReader") dcp.load( hf_state_dict, storage_reader=QuantizedHuggingFaceStorageReader( @@ -454,6 +459,7 @@ def dcp_load( thread_count=4, ), ) + logger.info(f"dcp.load with QuantizedHuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds") state_dict = self.sd_adapter.from_hf(hf_state_dict) self.states[MODEL].load_state_dict(state_dict) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index f467a2dec5..d0478cc961 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -135,7 +135,7 @@ dim=7168, inter_dim=18432, moe_inter_dim=2048, - n_layers=4, + n_layers=61, n_dense_layers=3, n_heads=128, moe_args=MoEArgs( diff --git 
a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index e947d70695..7631752518 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -6,6 +6,7 @@ import re +import time from typing import Any import torch @@ -16,6 +17,7 @@ from .args import DeepSeekV3ModelArgs from .quantization import calculate_scale_shape, dequantize_from_fp8 +from torchtitan.tools.logging import logger class DeepSeekV3StateDictAdapter(StateDictAdapter): @@ -220,6 +222,8 @@ def _get_local_experts_weights( Returns: Dictionary mapping individual expert keys to their DTensor weights """ + start_time = time.time() + logger.info(f"Starting _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}") device_mesh = grouped_expert_weight.device_mesh dtensor_placements = grouped_expert_weight.placements @@ -285,6 +289,9 @@ def _get_local_experts_weights( local_expert_tensors[expert_key] = expert_dtensor + end_time = time.time() + duration = end_time - start_time + logger.info(f"Completed _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}, duration: {duration:.4f}s") return local_expert_tensors def _concatenate_expert_weights_dtensor( @@ -312,6 +319,8 @@ def _concatenate_expert_weights_dtensor( Returns: Concatenated GroupedExperts weight DTensor if all experts are available, otherwise None """ + start_time = time.time() + logger.info(f"Starting _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}") # If we have all the experts for this abstract_key, concatenate them experts = expert_weights_by_layer[layer_num][abstract_key] expected_n_experts = ( @@ -341,6 +350,9 @@ def _concatenate_expert_weights_dtensor( if not expert_weights_by_layer[layer_num]: del expert_weights_by_layer[layer_num] + end_time = time.time() + duration = end_time - start_time + logger.info(f"Completed _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}, duration: {duration:.4f}s") return stacked_dtensor def _split_experts_weights( @@ -453,6 +465,9 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: 1. Convert between the HF shape and the torchtitan shape. 2. Split the GroupedExperts' weight into separate expert's wegiht. """ + start_time = time.time() + logger.info(f"Starting to_hf conversion, state_dict has {len(state_dict)} keys") + to_hf_map = {v: k for k, v in self.from_hf_map.items()} hf_state_dict = {} @@ -501,10 +516,13 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: hf_state_dict[new_key] = value # Prepare for dequantization - hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors( - hf_state_dict - ) - return hf_state_dict_with_scale_inv + # hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors( + # hf_state_dict + # ) + end_time = time.time() + duration = end_time - start_time + logger.info(f"Completed to_hf conversion, generated {len(hf_state_dict)} keys, duration: {duration:.4f}s") + return hf_state_dict def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """ @@ -512,10 +530,12 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: 2. Convert between the HF shape and the torchtitan shape. 3. Concate separate expert's wegiht into GroupedExperts' weight. 
""" + start_time = time.time() + logger.info(f"Starting from_hf conversion, state_dict has {len(hf_state_dict)} keys") # dequantize the tensor in state_dict and remove the scale_inv tensor - hf_state_dict = self._dequantize(hf_state_dict) + # hf_state_dict = self._dequantize(hf_state_dict) state_dict = {} expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} @@ -565,4 +585,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: new_key = self.from_hf_map[key] state_dict[new_key] = value + end_time = time.time() + duration = end_time - start_time + logger.info(f"Completed from_hf conversion, processed {len(hf_state_dict)} keys, duration: {duration:.4f}s") return state_dict From 1d0dff1c1f121a7c6b7093600153df43c671a203 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 9 Sep 2025 16:53:56 -0700 Subject: [PATCH 3/6] benchmarking --- torchtitan/components/checkpoint.py | 7 +++++-- .../models/deepseek_v3/train_configs/deepseek_v3_671b.toml | 5 +---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 22f0bda8c0..7c70bcb242 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -447,7 +447,10 @@ def dcp_load( from torch.distributed.checkpoint.quantized_hf_storage import ( QuantizedHuggingFaceStorageReader, ) - BLOCK_SIZE = 128 # hardcode for deepseek 671b now + + # NOTE: The following config is for DeepSeek-V3 671B model, which is using + # FP8 weight format with 128x128 block scaling. + BLOCK_SIZE = 128 begin_load = time.monotonic() logger.info("Starting dcp.load with QuantizedHuggingFaceStorageReader") dcp.load( @@ -456,7 +459,7 @@ def dcp_load( path=checkpoint_id, target_dtype=torch.float32, block_size=BLOCK_SIZE, - thread_count=4, + thread_count=8, ), ) logger.info(f"dcp.load with QuantizedHuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds") diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 6c2efa1531..1003d19496 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -53,15 +53,12 @@ expert_parallel_degree = 1 expert_tensor_parallel_degree = 1 [checkpoint] -enable = true +enable = false folder = "checkpoint" interval = 500 last_save_model_only = true export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" -initial_load_in_hf = true -initial_load_dequantize = true -initial_load_path = "/data/users/jianiw/model/DeepSeek-V3.1-Base" [activation_checkpoint] mode = "selective" # ["none", "selective", "full"] From 9797c300b174da8a85914636193bc3509cf9c21b Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 9 Sep 2025 16:55:37 -0700 Subject: [PATCH 4/6] remove dequantize --- .../models/deepseek_v3/model/quantization.py | 73 ------------------- .../deepseek_v3/model/state_dict_adapter.py | 58 --------------- 2 files changed, 131 deletions(-) delete mode 100644 torchtitan/models/deepseek_v3/model/quantization.py diff --git a/torchtitan/models/deepseek_v3/model/quantization.py b/torchtitan/models/deepseek_v3/model/quantization.py deleted file mode 100644 index a8ac6003a2..0000000000 --- a/torchtitan/models/deepseek_v3/model/quantization.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from torchtitan.tools.logging import logger - -# Fixed block size of 128x128 as specified in the algorithm -BLOCK_SIZE = 128 - - -def calculate_scale_shape( - weight: torch.Tensor, BLOCK_SIZE: int = BLOCK_SIZE -) -> torch.Size: - # Calculate the scale tensor shape - orig_shape = weight.shape - - # Calculate number of blocks needed - block_rows = (orig_shape[0] + BLOCK_SIZE - 1) // BLOCK_SIZE - block_cols = (orig_shape[1] + BLOCK_SIZE - 1) // BLOCK_SIZE - - # Verify scale_inv shape matches expected block dimensions - expected_scale_shape = torch.Size((block_rows, block_cols)) - - return expected_scale_shape - - -def dequantize_from_fp8( - weight: torch.Tensor, - scale_inv: torch.Tensor, - dtype=torch.bfloat16, - BLOCK_SIZE: int = BLOCK_SIZE, -) -> torch.Tensor: - # Convert to float32 for computation - float_weight = weight.to(torch.float32) - # Get original dimensions - orig_shape = weight.shape - - # Verify scale_inv shape matches expected block dimensions - expected_scale_shape = calculate_scale_shape(weight, BLOCK_SIZE) - block_rows, block_cols = expected_scale_shape - if scale_inv.shape != expected_scale_shape: - logger.warning( - f"scale_inv shape {scale_inv.shape} doesn't match expected shape {expected_scale_shape}" - ) - - # NOTE: When processing large models on-the-fly, misalignment between block boundaries - # and DTensor local shape partitioning can lead to silent numerical inaccuracies. - dequantized = float_weight.detach().clone().to(dtype=dtype) - - # Apply scaling factors to each block - for i in range(block_rows): - row_start = i * BLOCK_SIZE - row_end = min(row_start + BLOCK_SIZE, orig_shape[0]) - - for j in range(block_cols): - col_start = j * BLOCK_SIZE - col_end = min(col_start + BLOCK_SIZE, orig_shape[1]) - - # Get the block - block = float_weight[row_start:row_end, col_start:col_end] - - scale = scale_inv[i, j] - block = block * scale - - # Explicitly convert block to dtype - block_converted = block.to(dtype=torch.float32) - # Store the dequantized block - dequantized[row_start:row_end, col_start:col_end] = block_converted - - return dequantized diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 7631752518..c14703a094 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -16,7 +16,6 @@ from torchtitan.protocols.state_dict_adapter import StateDictAdapter from .args import DeepSeekV3ModelArgs -from .quantization import calculate_scale_shape, dequantize_from_fp8 from torchtitan.tools.logging import logger @@ -410,55 +409,6 @@ def _concatenate_expert_weights( return stacked_tensor - def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]: - """ - Dequantize the weights from float8 to float32. 
- """ - - scale_inv_keys = [] - for key, weight in state_dict.items(): - if key.endswith(".weight") and key + "_scale_inv" in state_dict: - scale_inv = state_dict[key + "_scale_inv"] - dequantized_weight = dequantize_from_fp8( - weight, scale_inv, dtype=torch.float32 - ) - # update the weight and remove the scale_inv tensor - state_dict[key] = dequantized_weight - scale_inv_keys.append(key + "_scale_inv") - - for key in scale_inv_keys: - state_dict.pop(key) - - return state_dict - - def _add_quantization_scale_inv_tensors( - self, state_dict: dict[str, Any] - ) -> dict[str, Any]: - """ - Add quantization scale tensors the state_dict. - """ - non_quantized_keys = [ - "input_layernorm.weight", - "post_attention_layernorm.weight", - "norm.weight", - "lm_head.weight", - "embed_tokens.weight", - "mlp.gate.weight", - ] - - weight_scale_inv_state_dict = {} - for key, value in state_dict.items(): - if key.endswith(".weight") and not any( - non_quantized_key in key for non_quantized_key in non_quantized_keys - ): - expected_scale_shape = calculate_scale_shape(value) - # add weight_scale_inv to the state_dict - weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones( - expected_scale_shape, dtype=torch.float32 - ) - - state_dict.update(weight_scale_inv_state_dict) - return state_dict def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ @@ -515,10 +465,6 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: new_key = to_hf_map[key] hf_state_dict[new_key] = value - # Prepare for dequantization - # hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors( - # hf_state_dict - # ) end_time = time.time() duration = end_time - start_time logger.info(f"Completed to_hf conversion, generated {len(hf_state_dict)} keys, duration: {duration:.4f}s") @@ -533,11 +479,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: start_time = time.time() logger.info(f"Starting from_hf conversion, state_dict has {len(hf_state_dict)} keys") - # dequantize the tensor in state_dict and remove the scale_inv tensor - - # hf_state_dict = self._dequantize(hf_state_dict) state_dict = {} - expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} for key, value in hf_state_dict.items(): From 57d8632cb5c37392164ca6b5285bc21d2f76cf76 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Wed, 10 Sep 2025 14:05:28 -0700 Subject: [PATCH 5/6] reformat --- torchtitan/components/checkpoint.py | 42 ++++----------- torchtitan/config/job_config.py | 3 -- .../qwen3/model/state_dict_adapter.py | 5 ++ torchtitan/models/deepseek_v3/__init__.py | 5 +- torchtitan/models/deepseek_v3/model/args.py | 3 ++ .../deepseek_v3/model/state_dict_adapter.py | 54 +++++++++++++++---- .../train_configs/deepseek_v3_671b.toml | 2 +- .../models/llama3/model/state_dict_adapter.py | 4 ++ torchtitan/protocols/state_dict_adapter.py | 16 ++++++ 9 files changed, 86 insertions(+), 48 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 7c70bcb242..933a6eda3b 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -40,7 +40,6 @@ from torchtitan.components.lr_scheduler import LRSchedulersContainer from torchtitan.components.optimizer import OptimizersContainer from torchtitan.config import Checkpoint as CheckpointConfig, TORCH_DTYPE_MAP -from torchtitan.models.deepseek_v3.model.quantization import BLOCK_SIZE from torchtitan.protocols import BaseStateDictAdapter from torchtitan.tools.logging import logger from 
torchtitan.tools.utils import GarbageCollection @@ -248,7 +247,6 @@ def load_state_dict(state_dict): # Checkpoint policy related fields. self.initial_load_model_only = checkpoint_config.initial_load_model_only self.initial_load_in_hf = checkpoint_config.initial_load_in_hf - self.initial_load_dequantize = checkpoint_config.initial_load_dequantize self.initial_load_path = checkpoint_config.initial_load_path self.last_save_model_only = checkpoint_config.last_save_model_only self.last_save_in_hf = checkpoint_config.last_save_in_hf @@ -419,7 +417,6 @@ def dcp_load( state_dict: dict[str, Any], checkpoint_id: str, from_hf: bool, - dequantize: bool = False, ) -> None: """Load the checkpoint with dcp. Args: @@ -434,35 +431,17 @@ def dcp_load( self.sd_adapter is not None ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided." hf_state_dict = self.sd_adapter.to_hf(state_dict) + hf_storage_reader = self.sd_adapter.get_hf_storage_reader(checkpoint_id) - if not dequantize: - begin_load = time.monotonic() - logger.info("Starting dcp.load with HuggingFaceStorageReader") - dcp.load( - hf_state_dict, - storage_reader=HuggingFaceStorageReader(path=checkpoint_id), - ) - logger.info(f"dcp.load with HuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds") - else: - from torch.distributed.checkpoint.quantized_hf_storage import ( - QuantizedHuggingFaceStorageReader, - ) - - # NOTE: The following config is for DeepSeek-V3 671B model, which is using - # FP8 weight format with 128x128 block scaling. - BLOCK_SIZE = 128 - begin_load = time.monotonic() - logger.info("Starting dcp.load with QuantizedHuggingFaceStorageReader") - dcp.load( - hf_state_dict, - storage_reader=QuantizedHuggingFaceStorageReader( - path=checkpoint_id, - target_dtype=torch.float32, - block_size=BLOCK_SIZE, - thread_count=8, - ), - ) - logger.info(f"dcp.load with QuantizedHuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds") + begin_load = time.monotonic() + logger.info("Starting dcp.load with HuggingFaceStorageReader") + dcp.load( + hf_state_dict, + storage_reader=hf_storage_reader, + ) + logger.info( + f"dcp.load with HuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds" + ) state_dict = self.sd_adapter.from_hf(hf_state_dict) self.states[MODEL].load_state_dict(state_dict) @@ -627,7 +606,6 @@ def load(self, step: int = -1) -> bool: states, checkpoint_id=checkpoint_id, from_hf=from_hf, - dequantize=self.initial_load_dequantize, ) GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 13fdb7d50e..e0189c9bb3 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -443,9 +443,6 @@ class Checkpoint: non-tensors. The default value is False. 
""" - initial_load_dequantize: bool = False - - last_save_model_only: bool = True """ When last_save_model_only=True, only the model will be saved at the end of training, diff --git a/torchtitan/experiments/qwen3/model/state_dict_adapter.py b/torchtitan/experiments/qwen3/model/state_dict_adapter.py index 760cc662be..600d9b511e 100644 --- a/torchtitan/experiments/qwen3/model/state_dict_adapter.py +++ b/torchtitan/experiments/qwen3/model/state_dict_adapter.py @@ -15,6 +15,8 @@ import re from typing import Any +from torch.distributed.checkpoint import HuggingFaceStorageReader + from torchtitan.protocols.state_dict_adapter import StateDictAdapter from .args import Qwen3ModelArgs @@ -45,6 +47,9 @@ def __init__(self, model_args: Qwen3ModelArgs, hf_assets_path: str | None): "lm_head.weight": "output.weight", } + def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: + return HuggingFaceStorageReader(path) + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: to_hf_map = {v: k for k, v in self.from_hf_map.items()} diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index d0478cc961..357ca80101 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -154,8 +154,9 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", + hf_weight_quantized=True, ), } diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index d6afedfa34..b27b7a9d50 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -86,6 +86,9 @@ class DeepSeekV3ModelArgs(BaseModelArgs): beta_slow: int = 1 mscale: float = 1.0 + # HF checkpoint args + hf_weight_quantized: bool = False + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: seq_len = job_config.training.seq_len if seq_len > self.max_seq_len: diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index c14703a094..4c18a944a8 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -10,13 +10,14 @@ from typing import Any import torch +from torch.distributed.checkpoint import HuggingFaceStorageReader from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard from torchtitan.protocols.state_dict_adapter import StateDictAdapter +from torchtitan.tools.logging import logger from .args import DeepSeekV3ModelArgs -from torchtitan.tools.logging import logger class DeepSeekV3StateDictAdapter(StateDictAdapter): @@ -79,6 +80,26 @@ def __init__( self.grouped_expert_weight_shape = {} # {titan_abstract_key: shape} self.local_experts_indices = {} # {titan_abstract_key: (start_idx, end_idx)} + def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: + if self.model_args.hf_weight_quantized: + from torch.distributed.checkpoint.quantized_hf_storage import ( + QuantizedHuggingFaceStorageReader, + ) + + # NOTE: Now we use Quantized HF storage reader to read DeepSeek-V3 671B model. 
+ # If loading checkpoints without quantization, use HuggingFaceStorageReader instead + BLOCK_SIZE = 128 + return ( + QuantizedHuggingFaceStorageReader( + path=path, + target_dtype=torch.float32, + block_size=BLOCK_SIZE, + thread_count=8, + ), + ) + else: + return HuggingFaceStorageReader(path) + def _calculate_strided_shard_shard_indices( self, strided_shard_dim_degree: int, @@ -222,7 +243,9 @@ def _get_local_experts_weights( Dictionary mapping individual expert keys to their DTensor weights """ start_time = time.time() - logger.info(f"Starting _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}") + logger.info( + f"Starting _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}" + ) device_mesh = grouped_expert_weight.device_mesh dtensor_placements = grouped_expert_weight.placements @@ -290,7 +313,9 @@ def _get_local_experts_weights( end_time = time.time() duration = end_time - start_time - logger.info(f"Completed _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}, duration: {duration:.4f}s") + logger.info( + f"Completed _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}, duration: {duration:.4f}s" + ) return local_expert_tensors def _concatenate_expert_weights_dtensor( @@ -319,7 +344,9 @@ def _concatenate_expert_weights_dtensor( Concatenated GroupedExperts weight DTensor if all experts are available, otherwise None """ start_time = time.time() - logger.info(f"Starting _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}") + logger.info( + f"Starting _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}" + ) # If we have all the experts for this abstract_key, concatenate them experts = expert_weights_by_layer[layer_num][abstract_key] expected_n_experts = ( @@ -351,7 +378,9 @@ def _concatenate_expert_weights_dtensor( end_time = time.time() duration = end_time - start_time - logger.info(f"Completed _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}, duration: {duration:.4f}s") + logger.info( + f"Completed _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}, duration: {duration:.4f}s" + ) return stacked_dtensor def _split_experts_weights( @@ -409,7 +438,6 @@ def _concatenate_expert_weights( return stacked_tensor - def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ 1. Convert between the HF shape and the torchtitan shape. @@ -417,7 +445,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ start_time = time.time() logger.info(f"Starting to_hf conversion, state_dict has {len(state_dict)} keys") - + to_hf_map = {v: k for k, v in self.from_hf_map.items()} hf_state_dict = {} @@ -467,7 +495,9 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: end_time = time.time() duration = end_time - start_time - logger.info(f"Completed to_hf conversion, generated {len(hf_state_dict)} keys, duration: {duration:.4f}s") + logger.info( + f"Completed to_hf conversion, generated {len(hf_state_dict)} keys, duration: {duration:.4f}s" + ) return hf_state_dict def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: @@ -477,7 +507,9 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: 3. Concate separate expert's wegiht into GroupedExperts' weight. 
""" start_time = time.time() - logger.info(f"Starting from_hf conversion, state_dict has {len(hf_state_dict)} keys") + logger.info( + f"Starting from_hf conversion, state_dict has {len(hf_state_dict)} keys" + ) state_dict = {} expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} @@ -529,5 +561,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: end_time = time.time() duration = end_time - start_time - logger.info(f"Completed from_hf conversion, processed {len(hf_state_dict)} keys, duration: {duration:.4f}s") + logger.info( + f"Completed from_hf conversion, processed {len(hf_state_dict)} keys, duration: {duration:.4f}s" + ) return state_dict diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml index 1003d19496..7d44a6a1f6 100644 --- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml +++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml @@ -65,7 +65,7 @@ mode = "selective" # ["none", "selective", "full"] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [compile] -enable=true +enable = true components = ["loss"] # ["model", "loss"] [float8] diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py index 2c386ece0d..1475ba2055 100644 --- a/torchtitan/models/llama3/model/state_dict_adapter.py +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -10,6 +10,7 @@ logger = logging.getLogger() +from torch.distributed.checkpoint import HuggingFaceStorageReader from torchtitan.protocols.state_dict_adapter import StateDictAdapter from .args import TransformerModelArgs @@ -41,6 +42,9 @@ def __init__( "lm_head.weight": "output.weight", } + def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: + return HuggingFaceStorageReader(path) + # HuggingFace permutation function (exact copy from their conversion script) def _permute(self, w, n_heads_arg, dim1=None, dim2=None): if dim1 is None: diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index 5b441e9bbf..a6fedd0af4 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -11,6 +11,9 @@ from abc import ABC, abstractmethod from typing import Any +from torch.distributed.checkpoint import HuggingFaceStorageReader + + logger = logging.getLogger() from .model import BaseModelArgs @@ -58,6 +61,19 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """ pass + @abstractmethod + def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: + """Returns hf storage reader to read HF checkpoint + + Args: + path: the path to read HF checkpoint + + Returns: + THe HuggingFace storage reader to read rom HF checkpoint + + """ + pass + class StateDictAdapter(BaseStateDictAdapter): """State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping""" From b0b90036204668277e4f7c93a79aa017a2d66f28 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 15 Sep 2025 11:36:25 -0700 Subject: [PATCH 6/6] fix return --- .../models/deepseek_v3/model/state_dict_adapter.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index 4c18a944a8..f0e40cd87b 100644 --- 
a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -89,13 +89,11 @@ def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader: # NOTE: Now we use Quantized HF storage reader to read DeepSeek-V3 671B model. # If loading checkpoints without quantization, use HuggingFaceStorageReader instead BLOCK_SIZE = 128 - return ( - QuantizedHuggingFaceStorageReader( - path=path, - target_dtype=torch.float32, - block_size=BLOCK_SIZE, - thread_count=8, - ), + return QuantizedHuggingFaceStorageReader( + path=path, + target_dtype=torch.float32, + block_size=BLOCK_SIZE, + thread_count=8, ) else: return HuggingFaceStorageReader(path)
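
Reviewer note (not part of the patches): a minimal sketch of how the storage-reader hook introduced in PATCH 5 is consumed, assuming an adapter that implements `get_hf_storage_reader` as added there. The names `adapter`, `state_dict`, and `checkpoint_id` are placeholders for objects `CheckpointManager.dcp_load` already holds; the reader parameters in the comment mirror the DeepSeek-V3 671B values from the diff.

```python
# Illustrative sketch of the dcp_load() flow after this series: the state-dict
# adapter owns both the HF key mapping (to_hf/from_hf) and the choice of
# HuggingFace storage reader.
import torch.distributed.checkpoint as dcp


def load_hf_checkpoint(adapter, state_dict, checkpoint_id: str) -> dict:
    # Map torchtitan FQNs to HF names (and split grouped expert weights).
    hf_state_dict = adapter.to_hf(state_dict)

    # For DeepSeek-V3 671B (hf_weight_quantized=True) the adapter returns a
    # QuantizedHuggingFaceStorageReader configured for FP8 weights with
    # 128x128 block scaling (target_dtype=torch.float32, block_size=128,
    # thread_count=8); otherwise it returns a plain
    # HuggingFaceStorageReader(checkpoint_id).
    storage_reader = adapter.get_hf_storage_reader(checkpoint_id)

    # Load the HF-keyed state dict in place through the selected reader.
    dcp.load(hf_state_dict, storage_reader=storage_reader)

    # Map back to torchtitan names (and re-group expert weights).
    return adapter.from_hf(hf_state_dict)
```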