8 changes: 7 additions & 1 deletion torchtitan/components/checkpoint.py
@@ -431,10 +431,16 @@ def dcp_load(
self.sd_adapter is not None
), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
hf_state_dict = self.sd_adapter.to_hf(state_dict)
hf_storage_reader = self.sd_adapter.get_hf_storage_reader(checkpoint_id)
Contributor: Nice, this is cleaner!


begin_load = time.monotonic()
logger.info("Starting dcp.load with HuggingFaceStorageReader")
dcp.load(
hf_state_dict,
storage_reader=HuggingFaceStorageReader(path=checkpoint_id),
storage_reader=hf_storage_reader,
)
logger.info(
f"dcp.load with HuggingFaceStorageReader completed in {time.monotonic() - begin_load:.2f} seconds"
)

state_dict = self.sd_adapter.from_hf(hf_state_dict)
5 changes: 5 additions & 0 deletions torchtitan/experiments/qwen3/model/state_dict_adapter.py
@@ -15,6 +15,8 @@
import re
from typing import Any

from torch.distributed.checkpoint import HuggingFaceStorageReader

from torchtitan.protocols.state_dict_adapter import StateDictAdapter

from .args import Qwen3ModelArgs
@@ -45,6 +47,9 @@ def __init__(self, model_args: Qwen3ModelArgs, hf_assets_path: str | None):
"lm_head.weight": "output.weight",
}

def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
return HuggingFaceStorageReader(path)

def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:

to_hf_map = {v: k for k, v in self.from_hf_map.items()}
5 changes: 3 additions & 2 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -154,8 +154,9 @@
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
use_flex_attn=True,
attn_mask_type="block_causal",
# use_flex_attn=True,
# attn_mask_type="block_causal",
hf_weight_quantized=True,
),
}

3 changes: 3 additions & 0 deletions torchtitan/models/deepseek_v3/model/args.py
@@ -86,6 +86,9 @@ class DeepSeekV3ModelArgs(BaseModelArgs):
beta_slow: int = 1
mscale: float = 1.0

# HF checkpoint args
hf_weight_quantized: bool = False

def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
seq_len = job_config.training.seq_len
if seq_len > self.max_seq_len:
73 changes: 0 additions & 73 deletions torchtitan/models/deepseek_v3/model/quantization.py

This file was deleted.
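The block-wise FP8 dequantization this module provided (together with the `_dequantize` / `_add_quantization_scale_inv_tensors` helpers removed from the state dict adapter below) is now delegated to `QuantizedHuggingFaceStorageReader`. For reference, a minimal sketch of the behavior it replaced, assuming 128x128 scaling blocks as configured in the new reader; the helper name and exact tiling are illustrative, not the library's API:

```python
import torch

def dequantize_from_fp8_blockwise(
    weight: torch.Tensor,        # FP8 weight, shape (out_dim, in_dim)
    scale_inv: torch.Tensor,     # per-block scales, shape (ceil(out/128), ceil(in/128))
    dtype: torch.dtype = torch.float32,
    block_size: int = 128,
) -> torch.Tensor:
    # Expand each per-block scale over its (block_size x block_size) tile,
    # crop to the weight's exact shape (the last row/column of blocks may be
    # partial), and rescale the FP8 weight into the target dtype.
    scales = scale_inv.repeat_interleave(block_size, dim=0)
    scales = scales.repeat_interleave(block_size, dim=1)
    scales = scales[: weight.shape[0], : weight.shape[1]]
    return weight.to(dtype) * scales.to(dtype)
```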

115 changes: 56 additions & 59 deletions torchtitan/models/deepseek_v3/model/state_dict_adapter.py
@@ -6,16 +6,18 @@


import re
import time
from typing import Any

import torch
from torch.distributed.checkpoint import HuggingFaceStorageReader
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard
from torchtitan.protocols.state_dict_adapter import StateDictAdapter
from torchtitan.tools.logging import logger

from .args import DeepSeekV3ModelArgs
from .quantization import calculate_scale_shape, dequantize_from_fp8


class DeepSeekV3StateDictAdapter(StateDictAdapter):
@@ -78,6 +80,24 @@ def __init__(
self.grouped_expert_weight_shape = {} # {titan_abstract_key: shape}
self.local_experts_indices = {} # {titan_abstract_key: (start_idx, end_idx)}

def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
if self.model_args.hf_weight_quantized:
from torch.distributed.checkpoint.quantized_hf_storage import (
QuantizedHuggingFaceStorageReader,
)

# NOTE: We use QuantizedHuggingFaceStorageReader to read the quantized DeepSeek-V3 671B checkpoint.
# If loading a checkpoint without quantization, use HuggingFaceStorageReader instead.
BLOCK_SIZE = 128
return QuantizedHuggingFaceStorageReader(
path=path,
target_dtype=torch.float32,
block_size=BLOCK_SIZE,
thread_count=8,
)
else:
return HuggingFaceStorageReader(path)

def _calculate_strided_shard_shard_indices(
self,
strided_shard_dim_degree: int,
@@ -220,6 +240,10 @@ def _get_local_experts_weights(
Returns:
Dictionary mapping individual expert keys to their DTensor weights
"""
start_time = time.time()
logger.info(
f"Starting _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}"
)
device_mesh = grouped_expert_weight.device_mesh
dtensor_placements = grouped_expert_weight.placements

@@ -285,6 +309,11 @@ def _get_local_experts_weights(

local_expert_tensors[expert_key] = expert_dtensor

end_time = time.time()
duration = end_time - start_time
logger.info(
f"Completed _get_local_experts_weights for layer {layer_id}, abstract_key: {abstract_key}, duration: {duration:.4f}s"
)
return local_expert_tensors

def _concatenate_expert_weights_dtensor(
@@ -312,6 +341,10 @@
Returns:
Concatenated GroupedExperts weight DTensor if all experts are available, otherwise None
"""
start_time = time.time()
logger.info(
f"Starting _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}"
)
# If we have all the experts for this abstract_key, concatenate them
experts = expert_weights_by_layer[layer_num][abstract_key]
expected_n_experts = (
@@ -341,6 +374,11 @@
if not expert_weights_by_layer[layer_num]:
del expert_weights_by_layer[layer_num]

end_time = time.time()
duration = end_time - start_time
logger.info(
f"Completed _concatenate_expert_weights_dtensor for layer {layer_num}, abstract_key: {abstract_key}, duration: {duration:.4f}s"
)
return stacked_dtensor

def _split_experts_weights(
@@ -398,61 +436,14 @@ def _concatenate_expert_weights(

return stacked_tensor

def _dequantize(self, state_dict: dict[str, Any]) -> dict[str, Any]:
"""
Dequantize the weights from float8 to float32.
"""

scale_inv_keys = []
for key, weight in state_dict.items():
if key.endswith(".weight") and key + "_scale_inv" in state_dict:
scale_inv = state_dict[key + "_scale_inv"]
dequantized_weight = dequantize_from_fp8(
weight, scale_inv, dtype=torch.float32
)
# update the weight and remove the scale_inv tensor
state_dict[key] = dequantized_weight
scale_inv_keys.append(key + "_scale_inv")

for key in scale_inv_keys:
state_dict.pop(key)

return state_dict

def _add_quantization_scale_inv_tensors(
self, state_dict: dict[str, Any]
) -> dict[str, Any]:
"""
Add quantization scale tensors the state_dict.
"""
non_quantized_keys = [
"input_layernorm.weight",
"post_attention_layernorm.weight",
"norm.weight",
"lm_head.weight",
"embed_tokens.weight",
"mlp.gate.weight",
]

weight_scale_inv_state_dict = {}
for key, value in state_dict.items():
if key.endswith(".weight") and not any(
non_quantized_key in key for non_quantized_key in non_quantized_keys
):
expected_scale_shape = calculate_scale_shape(value)
# add weight_scale_inv to the state_dict
weight_scale_inv_state_dict[key + "_scale_inv"] = torch.ones(
expected_scale_shape, dtype=torch.float32
)

state_dict.update(weight_scale_inv_state_dict)
return state_dict

def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
"""
1. Convert between the HF shape and the torchtitan shape.
2. Split the GroupedExperts' weight into separate experts' weights.
"""
start_time = time.time()
logger.info(f"Starting to_hf conversion, state_dict has {len(state_dict)} keys")

to_hf_map = {v: k for k, v in self.from_hf_map.items()}

hf_state_dict = {}
@@ -500,24 +491,25 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
new_key = to_hf_map[key]
hf_state_dict[new_key] = value

# Prepare for dequantization
hf_state_dict_with_scale_inv = self._add_quantization_scale_inv_tensors(
hf_state_dict
end_time = time.time()
duration = end_time - start_time
logger.info(
f"Completed to_hf conversion, generated {len(hf_state_dict)} keys, duration: {duration:.4f}s"
)
return hf_state_dict_with_scale_inv
return hf_state_dict

def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
"""
1. When loading from HF checkpoint, dequantize the weights from float8 to float32.
2. Convert between the HF shape and the torchtitan shape.
3. Concatenate separate experts' weights into the GroupedExperts' weight.
"""
start_time = time.time()
logger.info(
f"Starting from_hf conversion, state_dict has {len(hf_state_dict)} keys"
)

# dequantize the tensor in state_dict and remove the scale_inv tensor

hf_state_dict = self._dequantize(hf_state_dict)
state_dict = {}

expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}}

for key, value in hf_state_dict.items():
@@ -565,4 +557,9 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
new_key = self.from_hf_map[key]
state_dict[new_key] = value

end_time = time.time()
duration = end_time - start_time
logger.info(
f"Completed from_hf conversion, processed {len(hf_state_dict)} keys, duration: {duration:.4f}s"
)
return state_dict
(changes to a training config .toml file; filename not shown)
@@ -39,7 +39,6 @@ local_batch_size = 4
seq_len = 4096
max_norm = 1.0 # grad norm clipping
steps = 10_000
compile = false
dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
@@ -66,7 +65,7 @@ mode = "selective" # ["none", "selective", "full"]
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[compile]
enable=true
enable = true
components = ["loss"] # ["model", "loss"]

[float8]
4 changes: 4 additions & 0 deletions torchtitan/models/llama3/model/state_dict_adapter.py
@@ -10,6 +10,7 @@

logger = logging.getLogger()

from torch.distributed.checkpoint import HuggingFaceStorageReader
from torchtitan.protocols.state_dict_adapter import StateDictAdapter

from .args import TransformerModelArgs
@@ -41,6 +42,9 @@ def __init__(
"lm_head.weight": "output.weight",
}

def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
return HuggingFaceStorageReader(path)

# HuggingFace permutation function (exact copy from their conversion script)
def _permute(self, w, n_heads_arg, dim1=None, dim2=None):
if dim1 is None:
16 changes: 16 additions & 0 deletions torchtitan/protocols/state_dict_adapter.py
@@ -11,6 +11,9 @@
from abc import ABC, abstractmethod
from typing import Any

from torch.distributed.checkpoint import HuggingFaceStorageReader


logger = logging.getLogger()

from .model import BaseModelArgs
@@ -58,6 +61,19 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
"""
pass

@abstractmethod
def get_hf_storage_reader(self, path: str) -> HuggingFaceStorageReader:
"""Returns hf storage reader to read HF checkpoint

Args:
path: the path to the HF checkpoint

Returns:
The HuggingFace storage reader to read from the HF checkpoint.

"""
pass


class StateDictAdapter(BaseStateDictAdapter):
"""State dict adapter base class which provides convenient default behavior to build fqn_to_index_mapping"""