diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index e9e7014425..933a6eda3b 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -431,10 +431,16 @@ def dcp_load(
                 self.sd_adapter is not None
             ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
             hf_state_dict = self.sd_adapter.to_hf(state_dict)
 
+            hf_storage_reader = self.sd_adapter.get_hf_storage_reader(checkpoint_id)
+            begin_load = time.monotonic()
+            logger.info("Starting dcp.load with HuggingFaceStorageReader")
             dcp.load(
                 hf_state_dict,
-                storage_reader=HuggingFaceStorageReader(path=checkpoint_id),
+                storage_reader=hf_storage_reader,
+            )
+            logger.info(
+                f"dcp.load with HuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds"
             )
             state_dict = self.sd_adapter.from_hf(hf_state_dict)
diff --git a/torchtitan/experiments/qwen3/model/state_dict_adapter.py b/torchtitan/experiments/qwen3/model/state_dict_adapter.py
index 760cc662be..600d9b511e 100644
--- a/torchtitan/experiments/qwen3/model/state_dict_adapter.py
+++ b/torchtitan/experiments/qwen3/model/state_dict_adapter.py
@@ -15,6 +15,8 @@
 import re
 from typing import Any
 
+from torch.distributed.checkpoint import HuggingFaceStorageReader
+
 from torchtitan.protocols.state_dict_adapter import StateDictAdapter
 
 from .args import Qwen3ModelArgs
@@ -45,6 +47,9 @@ def __init__(self, model_args: Qwen3ModelArgs, hf_assets_path: str | None):
             "lm_head.weight": "output.weight",
         }
 
+    def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
+        return HuggingFaceStorageReader(path)
+
     def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
         to_hf_map = {v: k for k, v in self.from_hf_map.items()}
diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index d0478cc961..357ca80101 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -154,8 +154,9 @@
         qk_nope_head_dim=128,
         qk_rope_head_dim=64,
         v_head_dim=128,
-        use_flex_attn=True,
-        attn_mask_type="block_causal",
+        # use_flex_attn=True,
+        # attn_mask_type="block_causal",
+        hf_weight_quantized=True,
     ),
 }
diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py
index d6afedfa34..b27b7a9d50 100644
--- a/torchtitan/models/deepseek_v3/model/args.py
+++ b/torchtitan/models/deepseek_v3/model/args.py
@@ -86,6 +86,9 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
     beta_slow: int = 1
     mscale: float = 1.0
 
+    # HF checkpoint args
+    hf_weight_quantized: bool = False
+
     def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         seq_len = job_config.training.seq_len
         if seq_len > self.max_seq_len:
diff --git a/torchtitan/models/deepseek_v3/model/quantization.py b/torchtitan/models/deepseek_v3/model/quantization.py
deleted file mode 100644
index a8ac6003a2..0000000000
--- a/torchtitan/models/deepseek_v3/model/quantization.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-from torchtitan.tools.logging import logger
-
-# Fixed block size of 128x128 as specified in the algorithm
-BLOCK_SIZE = 128
-
-
-def calculate_scale_shape(
-    weight: torch.Tensor, BLOCK_SIZE: int = BLOCK_SIZE
-) -> torch.Size:
-    # Calculate the scale tensor shape
-    orig_shape = weight.shape
-
-    # Calculate number of blocks needed
-    block_rows = (orig_shape[0] + BLOCK_SIZE - 1) // BLOCK_SIZE
-    block_cols = (orig_shape[1] + BLOCK_SIZE - 1) // BLOCK_SIZE
-
-    # Verify scale_inv shape matches expected block dimensions
-    expected_scale_shape = torch.Size((block_rows, block_cols))
-
-    return expected_scale_shape
-
-
-def dequantize_from_fp8(
-    weight: torch.Tensor,
-    scale_inv: torch.Tensor,
-    dtype=torch.bfloat16,
-    BLOCK_SIZE: int = BLOCK_SIZE,
-) -> torch.Tensor:
-    # Convert to float32 for computation
-    float_weight = weight.to(torch.float32)
-    # Get original dimensions
-    orig_shape = weight.shape
-
-    # Verify scale_inv shape matches expected block dimensions
-    expected_scale_shape = calculate_scale_shape(weight, BLOCK_SIZE)
-    block_rows, block_cols = expected_scale_shape
-    if scale_inv.shape != expected_scale_shape:
-        logger.warning(
-            f"scale_inv shape {scale_inv.shape} doesn't match expected shape {expected_scale_shape}"
-        )
-
-    # NOTE: When processing large models on-the-fly, misalignment between block boundaries
-    # and DTensor local shape partitioning can lead to silent numerical inaccuracies.
-    dequantized = float_weight.detach().clone().to(dtype=dtype)
-
-    # Apply scaling factors to each block
-    for i in range(block_rows):
-        row_start = i * BLOCK_SIZE
-        row_end = min(row_start + BLOCK_SIZE, orig_shape[0])
-
-        for j in range(block_cols):
-            col_start = j * BLOCK_SIZE
-            col_end = min(col_start + BLOCK_SIZE, orig_shape[1])
-
-            # Get the block
-            block = float_weight[row_start:row_end, col_start:col_end]
-
-            scale = scale_inv[i, j]
-            block = block * scale
-
-            # Explicitly convert block to dtype
-            block_converted = block.to(dtype=torch.float32)
-            # Store the dequantized block
-            dequantized[row_start:row_end, col_start:col_end] = block_converted
-
-    return dequantized
diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py
index e947d70695..f0e40cd87b 100644
--- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py
+++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py
@@ -6,16 +6,18 @@
 
 import re
+import time
 from typing import Any
 
 import torch
+from torch.distributed.checkpoint import HuggingFaceStorageReader
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
 from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard
 
 from torchtitan.protocols.state_dict_adapter import StateDictAdapter
+from torchtitan.tools.logging import logger
 
 from .args import DeepSeekV3ModelArgs
-from .quantization import calculate_scale_shape, dequantize_from_fp8
 
 
 class DeepSeekV3StateDictAdapter(StateDictAdapter):
@@ -78,6 +80,24 @@ def __init__(
         self.grouped_expert_weight_shape = {}  # {titan_abstract_key: shape}
         self.local_experts_indices = {}  # {titan_abstract_key: (start_idx, end_idx)}
 
+    def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
+        if self.model_args.hf_weight_quantized:
+            from torch.distributed.checkpoint.quantized_hf_storage import (
+                QuantizedHuggingFaceStorageReader,
+            )
+
+            # NOTE: We use QuantizedHuggingFaceStorageReader to load the quantized DeepSeek-V3 671B checkpoint.
+            # If loading checkpoints without quantization, use HuggingFaceStorageReader instead.
+            BLOCK_SIZE = 128
+            return QuantizedHuggingFaceStorageReader(
+                path=path,
+                target_dtype=torch.float32,
+                block_size=BLOCK_SIZE,
+                thread_count=8,
+            )
+        else:
+            return HuggingFaceStorageReader(path)
+
     def _calculate_strided_shard_shard_indices(
         self,
         strided_shard_dim_degree: int,
@@ -220,6 +240,10 @@ def _get_local_experts_weights(
         Returns:
             Dictionary mapping individual expert keys to their DTensor weights
         """
+        start_time = time.time()
+        logger.info(
+            f"Starting _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}"
+        )
         device_mesh = grouped_expert_weight.device_mesh
         dtensor_placements = grouped_expert_weight.placements
 
@@ -285,6 +309,11 @@ def _get_local_experts_weights(
 
             local_expert_tensors[expert_key] = expert_dtensor
 
+        end_time = time.time()
+        duration = end_time - start_time
+        logger.info(
+            f"Completed _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}, duration: {duration:.4f}s"
+        )
         return local_expert_tensors
 
     def _concatenate_expert_weights_dtensor(
@@ -312,6 +341,10 @@ def _concatenate_expert_weights_dtensor(
         Returns:
             Concatenated GroupedExperts weight DTensor if all experts are available, otherwise None
         """
+        start_time = time.time()
+        logger.info(
+            f"Starting _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}"
+        )
         # If we have all the experts for this abstract_key, concatenate them
         experts = expert_weights_by_layer[layer_num][abstract_key]
         expected_n_experts = (
@@ -341,6 +374,11 @@ def _concatenate_expert_weights_dtensor(
         if not expert_weights_by_layer[layer_num]:
             del expert_weights_by_layer[layer_num]
 
+        end_time = time.time()
+        duration = end_time - start_time
+        logger.info(
+            f"Completed _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}, duration: {duration:.4f}s"
+        )
         return stacked_dtensor
 
     def _split_experts_weights(
@@ -398,61 +436,14 @@ def _concatenate_expert_weights(
 
         return stacked_tensor
 
-    def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]:
-        """
-        Dequantize the weights from float8 to float32.
-        """
-
-        scale_inv_keys = []
-        for key, weight in state_dict.items():
-            if key.endswith(".weight") and key + "_scale_inv" in state_dict:
-                scale_inv = state_dict[key + "_scale_inv"]
-                dequantized_weight = dequantize_from_fp8(
-                    weight, scale_inv, dtype=torch.float32
-                )
-                # update the weight and remove the scale_inv tensor
-                state_dict[key] = dequantized_weight
-                scale_inv_keys.append(key + "_scale_inv")
-
-        for key in scale_inv_keys:
-            state_dict.pop(key)
-
-        return state_dict
-
-    def _add_quantization_scale_inv_tensors(
-        self, state_dict: dict[str, Any]
-    ) -> dict[str, Any]:
-        """
-        Add quantization scale tensors the state_dict.
-        """
-        non_quantized_keys = [
-            "input_layernorm.weight",
-            "post_attention_layernorm.weight",
-            "norm.weight",
-            "lm_head.weight",
-            "embed_tokens.weight",
-            "mlp.gate.weight",
-        ]
-
-        weight_scale_inv_state_dict = {}
-        for key, value in state_dict.items():
-            if key.endswith(".weight") and not any(
-                non_quantized_key in key for non_quantized_key in non_quantized_keys
-            ):
-                expected_scale_shape = calculate_scale_shape(value)
-                # add weight_scale_inv to the state_dict
-                weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones(
-                    expected_scale_shape, dtype=torch.float32
-                )
-
-        state_dict.update(weight_scale_inv_state_dict)
-        return state_dict
-
     def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
         """
         1. Convert between the HF shape and the torchtitan shape.
         2. Split the GroupedExperts' weight into separate experts' weights.
         """
+        start_time = time.time()
+        logger.info(f"Starting to_hf conversion, state_dict has {len(state_dict)} keys")
+
         to_hf_map = {v: k for k, v in self.from_hf_map.items()}
         hf_state_dict = {}
 
@@ -500,11 +491,12 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
                 new_key = to_hf_map[key]
                 hf_state_dict[new_key] = value
 
-        # Prepare for dequantization
-        hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors(
-            hf_state_dict
+        end_time = time.time()
+        duration = end_time - start_time
+        logger.info(
+            f"Completed to_hf conversion, generated {len(hf_state_dict)} keys, duration: {duration:.4f}s"
         )
-        return hf_state_dict_with_scale_inv
+        return hf_state_dict
 
     def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
         """
@@ -512,12 +504,12 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
         2. Convert between the HF shape and the torchtitan shape.
         3. Concatenate separate experts' weights into GroupedExperts' weight.
         """
+        start_time = time.time()
+        logger.info(
+            f"Starting from_hf conversion, state_dict has {len(hf_state_dict)} keys"
+        )
 
-        # dequantize the tensor in state_dict and remove the scale_inv tensor
-
-        hf_state_dict = self._dequantize(hf_state_dict)
         state_dict = {}
-
         expert_weights_by_layer = {}  # {layer: {abstract_key: {expert_id: tensor}}}
 
         for key, value in hf_state_dict.items():
@@ -565,4 +557,9 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
                 new_key = self.from_hf_map[key]
                 state_dict[new_key] = value
 
+        end_time = time.time()
+        duration = end_time - start_time
+        logger.info(
+            f"Completed from_hf conversion, processed {len(hf_state_dict)} keys, duration: {duration:.4f}s"
+        )
         return state_dict
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml
index a92a6f5a16..7d44a6a1f6 100644
--- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_671b.toml
@@ -39,7 +39,6 @@ local_batch_size = 4
 seq_len = 4096
 max_norm = 1.0  # grad norm clipping
 steps = 10_000
-compile = false
 dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
 
 [parallelism]
@@ -66,7 +65,7 @@ mode = "selective"  # ["none", "selective", "full"]
 selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [compile]
-enable=true
+enable = true
 components = ["loss"]  # ["model", "loss"]
 
 [float8]
diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py
index 2c386ece0d..1475ba2055 100644
--- a/torchtitan/models/llama3/model/state_dict_adapter.py
+++ b/torchtitan/models/llama3/model/state_dict_adapter.py
@@ -10,6 +10,7 @@
 
 logger = logging.getLogger()
 
+from torch.distributed.checkpoint import HuggingFaceStorageReader
 from torchtitan.protocols.state_dict_adapter import StateDictAdapter
 
 from .args import TransformerModelArgs
@@ -41,6 +42,9 @@ def __init__(
             "lm_head.weight": "output.weight",
         }
 
+    def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
+        return HuggingFaceStorageReader(path)
+
     # HuggingFace permutation function (exact copy from their conversion script)
     def _permute(self, w, n_heads_arg, dim1=None, dim2=None):
         if dim1 is None:
diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py
index 5b441e9bbf..a6fedd0af4 100644
--- a/torchtitan/protocols/state_dict_adapter.py
+++ b/torchtitan/protocols/state_dict_adapter.py
@@ -11,6 +11,9 @@
 from abc import ABC, abstractmethod
 from typing import Any
 
+from torch.distributed.checkpoint import HuggingFaceStorageReader
+
+
 logger = logging.getLogger()
 
 from .model import BaseModelArgs
@@ -58,6 +61,19 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
         """
         pass
 
+    @abstractmethod
+    def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
+        """Returns the HF storage reader used to read an HF checkpoint.
+
+        Args:
+            path: the path to the HF checkpoint
+
+        Returns:
+            The HuggingFace storage reader to read from the HF checkpoint
+
+        """
+        pass
+
 
 class StateDictAdapter(BaseStateDictAdapter):
     """State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping"""
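
Hypothetical usage sketch (not part of the diff above): a minimal adapter implementing the new get_hf_storage_reader hook. DemoStateDictAdapter, its identity key mappings, and the quantized flag are invented for illustration; the quantized branch mirrors the DeepSeek-V3 adapter, and the plain branch mirrors the Llama 3 and Qwen3 adapters.

from typing import Any

import torch
from torch.distributed.checkpoint import HuggingFaceStorageReader
from torch.distributed.checkpoint.quantized_hf_storage import (
    QuantizedHuggingFaceStorageReader,
)

from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter


class DemoStateDictAdapter(BaseStateDictAdapter):
    """Hypothetical adapter illustrating the get_hf_storage_reader hook."""

    def __init__(self, quantized: bool = False):
        # Whether the HF checkpoint stores FP8 block-quantized weights.
        self.quantized = quantized

    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        # Identity key mapping keeps the sketch self-contained.
        return dict(state_dict)

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        return dict(hf_state_dict)

    def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
        if self.quantized:
            # Dequantize 128x128 FP8 blocks to fp32 while reading,
            # as the DeepSeek-V3 adapter does above.
            return QuantizedHuggingFaceStorageReader(
                path=path,
                target_dtype=torch.float32,
                block_size=128,
                thread_count=8,
            )
        return HuggingFaceStorageReader(path)

With this hook in place, CheckpointManager.dcp_load no longer hard-codes HuggingFaceStorageReader: it asks the adapter for a reader, so each model family can choose the reader that matches its checkpoint format.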