diffsynth_engine/models/base.py (1 addition, 1 deletion)
@@ -57,7 +57,7 @@ def unload_loras(self):
     def get_tp_plan(self):
         raise NotImplementedError(f"{self.__class__.__name__} does not support TP")
 
-    def get_fsdp_modules(self):
+    def get_fsdp_module_cls(self):
         raise NotImplementedError(f"{self.__class__.__name__} does not support FSDP")
 
 
diffsynth_engine/models/flux/flux_dit.py (2 additions, 2 deletions)
@@ -515,5 +515,5 @@ def compile_repeated_blocks(self, *args, **kwargs):
         for block in self.blocks:
             block.compile(*args, **kwargs)
 
-    def get_fsdp_modules(self):
-        return ["blocks", "single_blocks"]
+    def get_fsdp_module_cls(self):
+        return {FluxDoubleTransformerBlock, FluxSingleTransformerBlock}
diffsynth_engine/models/qwen_image/qwen2_5_vl.py (5 additions, 0 deletions)
@@ -942,6 +942,8 @@ def _unmask_unattended(
 
 
 class Qwen2_5_VLForConditionalGeneration(PreTrainedModel):
+    _supports_parallelization = True
+
     def __init__(
         self,
         vision_config: Qwen2_5_VLVisionConfig,
@@ -1173,6 +1175,9 @@ def get_rope_index(
 
         return position_ids, mrope_position_deltas
 
+    def get_fsdp_module_cls(self):
+        return {Qwen2_5_VisionBlock, Qwen2_5_VLDecoderLayer}
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
diffsynth_engine/models/qwen_image/qwen_image_dit.py (2 additions, 2 deletions)
@@ -535,5 +535,5 @@ def compile_repeated_blocks(self, *args, **kwargs):
         for block in self.transformer_blocks:
             block.compile(*args, **kwargs)
 
-    def get_fsdp_modules(self):
-        return ["transformer_blocks"]
+    def get_fsdp_module_cls(self):
+        return {QwenImageTransformerBlock}
diffsynth_engine/models/wan/wan_audio_encoder.py (0 additions, 1 deletion)
@@ -223,7 +223,6 @@ def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
 
 class Wav2Vec2Model(PreTrainedModel):
     converter = Wav2Vec2StateDictConverter()
-    _supports_parallelization = False
 
     def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
         super().__init__()
diffsynth_engine/models/wan/wan_dit.py (2 additions, 2 deletions)
@@ -502,5 +502,5 @@ def compile_repeated_blocks(self, *args, **kwargs):
         for block in self.single_blocks:
             block.compile(*args, **kwargs)
 
-    def get_fsdp_modules(self):
-        return ["blocks"]
+    def get_fsdp_module_cls(self):
+        return {DiTBlock}
diffsynth_engine/pipelines/flux_image.py (3 additions, 3 deletions)
@@ -143,7 +143,7 @@ def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dic
             layer_id, layer_type = name.split("_", 1)
             layer_type = layer_type.replace("self_attn_", "self_attn.").replace("mlp_", "mlp.")
             rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type]])
 
             lora_args = {}
             lora_args["alpha"] = param
             lora_args["up"] = lora_state_dict[origin_key.replace(".alpha", ".lora_up.weight")]
@@ -517,7 +517,7 @@ def _from_state_dict(cls, state_dicts: FluxStateDicts, config: FluxPipelineConfi
         if config.use_fbcache:
             dit = FluxDiTFBCache.from_state_dict(
                 state_dicts.model,
-                device=init_device,
+                device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
                 in_channel=config.control_type.get_in_channel(),
                 attn_kwargs=attn_kwargs,
@@ -526,7 +526,7 @@ def _from_state_dict(cls, state_dicts: FluxStateDicts, config: FluxPipelineConfi
         else:
             dit = FluxDiT.from_state_dict(
                 state_dicts.model,
-                device=init_device,
+                device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
                 in_channel=config.control_type.get_in_channel(),
                 attn_kwargs=attn_kwargs,
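Reviewer note on the `device=("cpu" if config.use_fsdp else init_device)` pattern used in this and the following pipeline diffs: with FSDP, each rank would otherwise materialize the full, unsharded DiT on its own GPU before wrapping; building the weights on CPU and letting FSDP move each shard to `device_id` avoids that peak. A back-of-envelope sketch with purely illustrative numbers (the parameter count and world size below are assumptions, not taken from this PR):

```python
# Illustrative only: why CPU init before FSDP sharding lowers peak GPU memory.
params = 12e9          # assume a ~12B-parameter DiT (hypothetical figure)
bytes_per_param = 2    # bf16 weights
world_size = 4         # number of FSDP ranks

full_copy_gib = params * bytes_per_param / 1024**3
per_rank_gib = full_copy_gib / world_size

print(f"unsharded copy per GPU (init on GPU): {full_copy_gib:.1f} GiB")  # ~22.4 GiB
print(f"FULL_SHARD weights per GPU:           {per_rank_gib:.1f} GiB")   # ~5.6 GiB
```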
diffsynth_engine/pipelines/qwen_image.py (3 additions, 3 deletions)
@@ -201,7 +201,7 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
             state_dicts.encoder,
             vision_config=vision_config,
             config=text_config,
-            device=init_device,
+            device=("cpu" if config.use_fsdp else init_device),
             dtype=config.encoder_dtype,
         )
         with open(QWEN_IMAGE_VAE_CONFIG_FILE, "r", encoding="utf-8") as f:
@@ -221,15 +221,15 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
         if config.use_fbcache:
             dit = QwenImageDiTFBCache.from_state_dict(
                 state_dicts.model,
-                device=init_device,
+                device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
                 attn_kwargs=attn_kwargs,
                 relative_l1_threshold=config.fbcache_relative_l1_threshold,
             )
         else:
             dit = QwenImageDiT.from_state_dict(
                 state_dicts.model,
-                device=init_device,
+                device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
                 attn_kwargs=attn_kwargs,
             )
diffsynth_engine/pipelines/sdxl_image.py (1 addition, 1 deletion)
@@ -181,7 +181,7 @@ def from_pretrained(cls, model_path_or_config: SDXLPipelineConfig) -> "SDXLImage
 
     @classmethod
     def from_state_dict(cls, state_dicts: SDXLStateDicts, config: SDXLPipelineConfig) -> "SDXLImagePipeline":
-        init_device = "cpu" if config.offload_mode else config.device
+        init_device = "cpu" if config.offload_mode is not None else config.device
         tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
         tokenizer_2 = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_2_CONF_PATH)
         with LoRAContext():
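Note on the `is not None` change: for string-valued offload modes, truthiness and an explicit None check agree, but they diverge if the mode is ever carried by a falsy-but-set value. A small illustration, assuming a hypothetical `OffloadMode` IntEnum rather than the actual diffsynth_engine type:

```python
from enum import IntEnum


class OffloadMode(IntEnum):      # hypothetical representation of offload modes
    SEQUENTIAL_CPU_OFFLOAD = 0   # falsy, yet a real mode
    CPU_OFFLOAD = 1


mode = OffloadMode.SEQUENTIAL_CPU_OFFLOAD
print(bool(mode))        # False -> `if config.offload_mode:` would pick the GPU device
print(mode is not None)  # True  -> `is not None` selects CPU init as intended
```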
diffsynth_engine/pipelines/wan_s2v.py (1 addition, 1 deletion)
@@ -664,7 +664,7 @@ def _from_state_dict(
         dit = WanS2VDiT.from_state_dict(
             state_dicts.model,
             config=model_config,
-            device=init_device,
+            device=("cpu" if config.use_fsdp else init_device),
             dtype=config.model_dtype,
             attn_kwargs=attn_kwargs,
         )
diffsynth_engine/pipelines/wan_video.py (2 additions, 2 deletions)
@@ -566,7 +566,7 @@ def _from_state_dict(cls, state_dicts: WanStateDicts, config: WanPipelineConfig)
         dit = WanDiT.from_state_dict(
             dit_state_dict,
             config=dit_config,
-            device=init_device,
+            device=("cpu" if config.use_fsdp else init_device),
             dtype=config.model_dtype,
             attn_kwargs=attn_kwargs,
         )
@@ -578,7 +578,7 @@ def _from_state_dict(cls, state_dicts: WanStateDicts, config: WanPipelineConfig)
         dit2 = WanDiT.from_state_dict(
             dit2_state_dict,
             config=dit_config,
-            device=init_device,
+            device=("cpu" if config.use_fsdp else init_device),
             dtype=config.model_dtype,
             attn_kwargs=attn_kwargs,
         )
diffsynth_engine/utils/parallel.py (5 additions, 16 deletions)
@@ -8,14 +8,14 @@
 import torch.distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import ShardingStrategy
-from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel.style import ParallelStyle
 from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
 from contextlib import contextmanager
 from datetime import timedelta
 from functools import partial
-from typing import Dict, List, Union, Optional
+from typing import Dict, List, Set, Type, Union, Optional
 from queue import Empty
 
 import diffsynth_engine.models.basic.attention as attention_ops
@@ -174,25 +174,14 @@ def to_device(data, device):
 def shard_model(
     module: nn.Module,
     device_id: int | torch.device,
+    wrap_module_cls: Set[Type[nn.Module]],
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD,
-    wrap_module_names: Optional[List[str]] = None,
 ):
-    wrap_module_names = wrap_module_names or []
-
-    def wrap_fn(m):
-        for name in wrap_module_names:
-            submodule = getattr(module, name)
-            if isinstance(submodule, nn.ModuleList) and m in submodule:
-                return True
-            elif not isinstance(submodule, nn.ModuleList) and m is submodule:
-                return True
-        return False
-
     return FSDP(
         module,
         device_id=device_id,
         sharding_strategy=sharding_strategy,
-        auto_wrap_policy=partial(lambda_auto_wrap_policy, lambda_fn=wrap_fn),
+        auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls=wrap_module_cls),
     )
 
 
@@ -283,7 +272,7 @@ def wrap_for_parallel(module: Union[PreTrainedModel, BasePipeline]):
                parallelize_plan=module.get_tp_plan(),
            )
        elif use_fsdp:
-           module = shard_model(module, device_id=device, wrap_module_names=module.get_fsdp_modules())
+           module = shard_model(module, device_id=device, wrap_module_cls=module.get_fsdp_module_cls())
        return module
 
    module = None
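For reference, a minimal, self-contained sketch of the new wrapping path (`get_fsdp_module_cls` plus `transformer_auto_wrap_policy`), meant to be launched with torchrun; `MyBlock` and `MyModel` are hypothetical stand-ins for the block classes and models touched above, not code from this PR:

```python
import os
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class MyBlock(nn.Module):  # stands in for DiTBlock, FluxSingleTransformerBlock, ...
    def __init__(self, dim: int = 256):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.fc(x))


class MyModel(nn.Module):  # stands in for a model with repeated transformer blocks
    def __init__(self, num_blocks: int = 8):
        super().__init__()
        self.blocks = nn.ModuleList(MyBlock() for _ in range(num_blocks))

    def get_fsdp_module_cls(self):
        # New contract: return the block classes to wrap, not attribute names.
        return {MyBlock}

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


if __name__ == "__main__":
    dist.init_process_group("nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    model = MyModel()  # built on CPU, mirroring the pipeline changes above
    model = FSDP(
        model,
        device_id=local_rank,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        # Every submodule whose class is in the set becomes its own FSDP unit,
        # regardless of which attribute (blocks, single_blocks, ...) holds it.
        auto_wrap_policy=partial(
            transformer_auto_wrap_policy, transformer_layer_cls=model.get_fsdp_module_cls()
        ),
    )

    x = torch.randn(2, 256, device=f"cuda:{local_rank}")
    print(model(x).shape)  # torch.Size([2, 256])
    dist.destroy_process_group()
```

Unlike the removed name-based `wrap_fn`, which only inspected attributes looked up on the top-level module, class matching also catches blocks nested deeper in the module tree.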