6 changes: 3 additions & 3 deletions .github/workflows/pr-test.yml
@@ -122,17 +122,17 @@ jobs:
# Actual tests
encoder-test:
- 'fastvideo/v1/models/encoders/**'
- 'fastvideo/v1/models/loaders/**'
- 'fastvideo/v1/models/loader/**'
- 'fastvideo/v1/tests/encoders/**'
- *common-paths
vae-test:
- 'fastvideo/v1/models/vaes/**'
- 'fastvideo/v1/models/loaders/**'
- 'fastvideo/v1/models/loader/**'
- 'fastvideo/v1/tests/vaes/**'
- *common-paths
transformer-test:
- 'fastvideo/v1/models/dits/**'
- 'fastvideo/v1/models/loaders/**'
- 'fastvideo/v1/models/loader/**'
- 'fastvideo/v1/tests/transformers/**'
- 'fastvideo/v1/layers/**'
- 'fastvideo/v1/attention/**'
2 changes: 1 addition & 1 deletion examples/inference/basic/basic.py
@@ -10,7 +10,7 @@ def main():
# attempt to identify the optimal arguments.
generator = VideoGenerator.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
# if num_gpus > 1, FastVideo will automatically handle distributed setup
# FastVideo will automatically handle distributed setup
num_gpus=2,
use_fsdp_inference=True,
use_cpu_offload=False
6 changes: 4 additions & 2 deletions fastvideo/v1/configs/models/base.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field, fields
from typing import Any, Dict
from typing import Any, Dict, List, Tuple

from fastvideo.v1.logger import init_logger

@@ -12,7 +12,9 @@
# 3. Any field in ArchConfig is fixed upon initialization, and should be hidden away from users
@dataclass
class ArchConfig:
pass
stacked_params_mapping: List[Tuple[str, str, str]] = field(
default_factory=list
) # mapping from huggingface weight names to custom names


@dataclass
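Note: the new `stacked_params_mapping` field on `ArchConfig` uses the common `(param_name, shard_name, shard_id)` convention: a checkpoint key containing `shard_name` is redirected to the fused `param_name`, and `shard_id` says which slice of the fused tensor it fills. A minimal sketch of that key rewrite, assuming plain substring matching; the `remap` helper is illustrative, not part of FastVideo:

```python
from typing import Any, List, Optional, Tuple

# Illustrative only: shows what a (param_name, shard_name, shard_id)
# triple encodes, using plain substring replacement on checkpoint keys.
stacked_params_mapping: List[Tuple[str, str, str]] = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def remap(weight_name: str) -> Tuple[str, Optional[Any]]:
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in weight_name:
            # q_proj/k_proj/v_proj all land in the fused qkv_proj parameter.
            return weight_name.replace(shard_name, param_name), shard_id
    return weight_name, None

print(remap("text_model.encoder.layers.0.self_attn.q_proj.weight"))
# ('text_model.encoder.layers.0.self_attn.qkv_proj.weight', 'q')
```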
8 changes: 3 additions & 5 deletions fastvideo/v1/configs/models/dits/stepvideo.py
@@ -5,13 +5,11 @@
from fastvideo.v1.configs.models.dits.base import DiTArchConfig, DiTConfig


def is_blocks(n: str, m) -> bool:
return "blocks" in n and str.isdigit(n.split(".")[-1])


@dataclass
class StepVideoArchConfig(DiTArchConfig):
_fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks])
_fsdp_shard_conditions: list = field(
default_factory=lambda:
[lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit()])

_param_names_mapping: dict = field(
default_factory=lambda: {
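Note: the inline lambda that replaces `is_blocks` matches module names like `transformer_blocks.0` (the prefix plus a trailing numeric index). A rough sketch of how `_fsdp_shard_conditions` predicates are typically evaluated over `named_modules()` to pick the submodules that become FSDP wrap units; the `modules_to_shard` helper is hypothetical, not FastVideo's wrapping code:

```python
import torch.nn as nn

def modules_to_shard(model: nn.Module, shard_conditions) -> list:
    # Each condition receives (qualified_name, module); any True result
    # marks that submodule as its own shard/wrap unit.
    return [
        name for name, module in model.named_modules()
        if any(cond(name, module) for cond in shard_conditions)
    ]

is_block = lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit()

model = nn.ModuleDict({
    "transformer_blocks": nn.ModuleList([nn.Linear(8, 8) for _ in range(2)]),
    "norm_out": nn.LayerNorm(8),
})
print(modules_to_shard(model, [is_block]))
# ['transformer_blocks.0', 'transformer_blocks.1']
```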
5 changes: 4 additions & 1 deletion fastvideo/v1/configs/models/encoders/base.py
@@ -32,8 +32,11 @@ class TextEncoderArchConfig(EncoderArchConfig):
output_past: bool = True
scalable_attention: bool = True
tie_word_embeddings: bool = False

stacked_params_mapping: List[Tuple[str, str, str]] = field(
default_factory=list
) # mapping from huggingface weight names to custom names
tokenizer_kwargs: Dict[str, Any] = field(default_factory=dict)
_fsdp_shard_conditions: list = field(default_factory=lambda: [])

def __post_init__(self) -> None:
self.tokenizer_kwargs = {
26 changes: 25 additions & 1 deletion fastvideo/v1/configs/models/encoders/clip.py
@@ -1,13 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional
from typing import List, Optional, Tuple

from fastvideo.v1.configs.models.encoders.base import (ImageEncoderArchConfig,
ImageEncoderConfig,
TextEncoderArchConfig,
TextEncoderConfig)


def _is_transformer_layer(n: str, m) -> bool:
return "layers" in n and str.isdigit(n.split(".")[-1])


def _is_embeddings(n: str, m) -> bool:
return n.endswith("embeddings")


@dataclass
class CLIPTextArchConfig(TextEncoderArchConfig):
vocab_size: int = 49408
@@ -27,6 +35,15 @@ class CLIPTextArchConfig(TextEncoderArchConfig):
bos_token_id: int = 49406
eos_token_id: int = 49407
text_len: int = 77
stacked_params_mapping: List[Tuple[str, str,
str]] = field(default_factory=lambda: [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
])
_fsdp_shard_conditions: list = field(
default_factory=lambda: [_is_transformer_layer, _is_embeddings])


@dataclass
@@ -45,6 +62,13 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig):
attention_dropout: float = 0.0
initializer_range: float = 0.02
initializer_factor: float = 1.0
stacked_params_mapping: List[Tuple[str, str,
str]] = field(default_factory=lambda: [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
])


@dataclass
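Note: with per-model triples like the CLIP ones above, a weight loader can fold separate q/k/v checkpoint tensors into the fused `qkv_proj` parameter. A condensed sketch of that load loop in the style popularized by vLLM; `params_dict`, `weights`, and the `weight_loader` attribute are assumptions about the surrounding loader, not FastVideo's exact API:

```python
def load_weights(model, weights, stacked_params_mapping):
    # weights: iterable of (checkpoint_key, tensor) pairs.
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            # Redirect e.g. ...self_attn.q_proj.weight to ...self_attn.qkv_proj.weight
            # and let the fused parameter place it in the shard_id slice.
            param = params_dict[name.replace(shard_name, param_name)]
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # No stacked mapping matched: plain one-to-one copy.
            param = params_dict[name]
            param.data.copy_(loaded_weight)
```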
26 changes: 25 additions & 1 deletion fastvideo/v1/configs/models/encoders/llama.py
@@ -1,11 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional
from typing import List, Optional, Tuple

from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig,
TextEncoderConfig)


def _is_transformer_layer(n: str, m) -> bool:
return "layers" in n and str.isdigit(n.split(".")[-1])


def _is_embeddings(n: str, m) -> bool:
return n.endswith("embed_tokens")


def _is_final_norm(n: str, m) -> bool:
return n.endswith("norm")


@dataclass
class LlamaArchConfig(TextEncoderArchConfig):
vocab_size: int = 32000
@@ -32,6 +44,18 @@ class LlamaArchConfig(TextEncoderArchConfig):
head_dim: Optional[int] = None
hidden_state_skip_layer: int = 2
text_len: int = 256
stacked_params_mapping: List[Tuple[str, str, str]] = field(
default_factory=lambda: [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0), # type: ignore
(".gate_up_proj", ".up_proj", 1), # type: ignore
])
_fsdp_shard_conditions: list = field(
default_factory=lambda:
[_is_transformer_layer, _is_embeddings, _is_final_norm])


@dataclass
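Note: the Llama mapping mixes string and integer shard ids: `"q"`/`"k"`/`"v"` for the fused attention projection, `0`/`1` for the fused `gate_up_proj` MLP projection (hence the `# type: ignore`). A toy illustration, with invented dimensions, of what an integer shard id selects in the fused tensor:

```python
import torch

hidden, intermediate = 16, 32
gate_up = torch.empty(2 * intermediate, hidden)   # fused gate_up_proj.weight
gate_w = torch.randn(intermediate, hidden)        # HF gate_proj.weight
up_w = torch.randn(intermediate, hidden)          # HF up_proj.weight

for shard_id, loaded in ((0, gate_w), (1, up_w)):
    # shard_id indexes the output-dimension block the tensor fills.
    start = shard_id * intermediate
    gate_up[start:start + intermediate].copy_(loaded)

assert torch.equal(gate_up[:intermediate], gate_w)
assert torch.equal(gate_up[intermediate:], up_w)
```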
24 changes: 23 additions & 1 deletion fastvideo/v1/configs/models/encoders/t5.py
@@ -1,11 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional
from typing import List, Optional, Tuple

from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig,
TextEncoderConfig)


def _is_transformer_layer(n: str, m) -> bool:
return "block" in n and str.isdigit(n.split(".")[-1])


def _is_embeddings(n: str, m) -> bool:
return n.endswith("shared")


def _is_final_layernorm(n: str, m) -> bool:
return n.endswith("final_layer_norm")


@dataclass
class T5ArchConfig(TextEncoderArchConfig):
vocab_size: int = 32128
@@ -29,6 +41,16 @@ class T5ArchConfig(TextEncoderArchConfig):
eos_token_id: int = 1
classifier_dropout: float = 0.0
text_len: int = 512
stacked_params_mapping: List[Tuple[str, str,
str]] = field(default_factory=lambda: [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q", "q"),
(".qkv_proj", ".k", "k"),
(".qkv_proj", ".v", "v"),
])
_fsdp_shard_conditions: list = field(
default_factory=lambda:
[_is_transformer_layer, _is_embeddings, _is_final_layernorm])

# Referenced from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py
def __post_init__(self):
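Note: unlike CLIP and Llama, Hugging Face T5 names its attention projections `q`, `k`, `v` rather than `q_proj`/`k_proj`/`v_proj`, which is why the T5 triples pair `.qkv_proj` with the short suffixes. A quick check against a typical encoder key (naive substring matching here is only for illustration):

```python
mapping = [(".qkv_proj", ".q", "q"), (".qkv_proj", ".k", "k"), (".qkv_proj", ".v", "v")]
key = "encoder.block.0.layer.0.SelfAttention.q.weight"
for param_name, shard_name, shard_id in mapping:
    if shard_name in key:
        print(key.replace(shard_name, param_name), shard_id)
        break
# encoder.block.0.layer.0.SelfAttention.qkv_proj.weight q
```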
@@ -11,7 +11,7 @@
build_parquet_iterable_style_dataloader)
from fastvideo.v1.distributed import get_world_rank
from fastvideo.v1.distributed.parallel_state import (
cleanup_dist_env_and_memory, get_torch_device,
cleanup_dist_env_and_memory, get_local_torch_device,
maybe_init_distributed_environment_and_model_parallel)
from fastvideo.v1.logger import init_logger

@@ -148,8 +148,8 @@ def main() -> None:
break

# Move data to device
latents = latents.to(get_torch_device())
embeddings = embeddings.to(get_torch_device())
latents = latents.to(get_local_torch_device())
embeddings = embeddings.to(get_local_torch_device())

# Calculate actual batch size
batch_size = latents.size(0)
@@ -13,7 +13,7 @@
build_parquet_map_style_dataloader)
from fastvideo.v1.distributed import get_world_rank
from fastvideo.v1.distributed.parallel_state import (
cleanup_dist_env_and_memory, get_torch_device,
cleanup_dist_env_and_memory, get_local_torch_device,
maybe_init_distributed_environment_and_model_parallel)
from fastvideo.v1.logger import init_logger

@@ -165,8 +165,8 @@ def main() -> None:
break

# Move data to device
latents = latents.to(get_torch_device())
embeddings = embeddings.to(get_torch_device())
latents = latents.to(get_local_torch_device())
embeddings = embeddings.to(get_local_torch_device())

# Calculate actual batch size
batch_size = latents.size(0)
10 changes: 5 additions & 5 deletions fastvideo/v1/distributed/__init__.py
@@ -3,10 +3,10 @@
from fastvideo.v1.distributed.communication_op import *
from fastvideo.v1.distributed.parallel_state import (
cleanup_dist_env_and_memory, get_dp_group, get_dp_rank, get_dp_world_size,
get_sp_group, get_sp_parallel_rank, get_sp_world_size, get_torch_device,
get_tp_group, get_tp_rank, get_tp_world_size, get_world_group,
get_world_rank, get_world_size, init_distributed_environment,
initialize_model_parallel,
get_local_torch_device, get_sp_group, get_sp_parallel_rank,
get_sp_world_size, get_tp_group, get_tp_rank, get_tp_world_size,
get_world_group, get_world_rank, get_world_size,
init_distributed_environment, initialize_model_parallel,
maybe_init_distributed_environment_and_model_parallel,
model_parallel_is_initialized)
from fastvideo.v1.distributed.utils import *
@@ -40,5 +40,5 @@
"get_tp_world_size",

# Get torch device
"get_torch_device",
"get_local_torch_device",
]
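Note: downstream callers pick up the renamed helper from the package root. A minimal usage sketch, assuming a CUDA build and that `LOCAL_RANK` is set the way FastVideo's launcher expects:

```python
import torch
from fastvideo.v1.distributed import get_local_torch_device

device = get_local_torch_device()        # e.g. cuda:0 on local rank 0
latents = torch.zeros(1, 4, 8, 8, device=device)
```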
38 changes: 32 additions & 6 deletions fastvideo/v1/distributed/parallel_state.py
@@ -36,6 +36,7 @@

import torch
import torch.distributed
import torch.distributed as dist
from torch.distributed import Backend, ProcessGroup, ReduceOp

import fastvideo.v1.envs as envs
@@ -692,13 +693,19 @@ def destroy(self) -> None:


_WORLD: Optional[GroupCoordinator] = None
_NODE: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
assert _WORLD is not None, ("world group is not initialized")
return _WORLD


def get_node_group() -> GroupCoordinator:
assert _NODE is not None, ("node group is not initialized")
return _NODE


def init_world_group(ranks: List[int], local_rank: int,
backend: str) -> GroupCoordinator:
return GroupCoordinator(
@@ -710,6 +717,18 @@ def init_world_group(ranks: List[int], local_rank: int,
)


def init_node_group(local_rank: int, backend: str):
cpu_group = get_world_group().cpu_group
node_ranks = same_node_ranks(cpu_group)
node_size = len(node_ranks)
all_node_ranks = [
list(range(i * node_size, (i + 1) * node_size))
for i in range(dist.get_world_size() // node_size)
]
global _NODE
_NODE = init_model_parallel_group(all_node_ranks, local_rank, backend)


def init_model_parallel_group(
group_ranks: List[List[int]],
local_rank: int,
@@ -782,6 +801,8 @@ def init_distributed_environment(
else:
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size")
# Init a group for each node
init_node_group(local_rank, backend)


_SP: Optional[GroupCoordinator] = None
@@ -904,7 +925,7 @@ def get_dp_rank() -> int:
return get_dp_group().rank_in_group


def get_torch_device() -> torch.device:
def get_local_torch_device() -> torch.device:
"""Return the torch device for the current rank."""
return torch.device(f"cuda:{envs.LOCAL_RANK}")

@@ -1021,17 +1042,22 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
"torch._C._host_emptyCache() only available in Pytorch >=2.5")


def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
source_rank: int = 0) -> List[bool]:
def same_node_ranks(pg: Union[ProcessGroup, StatelessProcessGroup],
source_rank: int = 0) -> List[int]:
"""
This is a collective operation that returns if each rank is in the same node
This is a collective operation that returns ranks that are in the same node
as the source rank. It tests if processes are attached to the same
memory system (shared access to shared memory).
Args:
pg: the global process group to test
source_rank: the rank to test against
Returns:
A list of ranks that are in the same node as the source rank.
"""
if isinstance(pg, ProcessGroup):
assert torch.distributed.get_backend(
pg) != torch.distributed.Backend.NCCL, (
"in_the_same_node_as should be tested with a non-NCCL group.")
"same_node_ranks should be tested with a non-NCCL group.")
# local rank inside the group
rank = torch.distributed.get_rank(group=pg)
world_size = torch.distributed.get_world_size(group=pg)
@@ -1103,7 +1129,7 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
aggregated_data += rank_data

return [x == 1 for x in aggregated_data.tolist()]
return [i for i, x in enumerate(aggregated_data.tolist()) if x == 1]


def initialize_tensor_parallel_group(
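Note: the rename from `in_the_same_node_as` to `same_node_ranks` changes the return value from a boolean mask to the rank indices themselves, and `init_node_group` then assumes homogeneous nodes when it slices the world into contiguous per-node groups. A plain-Python illustration of both steps (no process group involved; only the final list manipulations from the diff are mirrored):

```python
# Result of the shared-memory test on a 4-rank, 2-ranks-per-node world,
# as seen from source_rank 0.
aggregated = [1, 1, 0, 0]

in_the_same_node_as = [x == 1 for x in aggregated]                   # old: List[bool]
same_node_ranks = [i for i, x in enumerate(aggregated) if x == 1]    # new: List[int]
print(in_the_same_node_as)   # [True, True, False, False]
print(same_node_ranks)       # [0, 1]

# init_node_group: node_size = len(same_node_ranks); contiguous per-node groups.
node_size, world_size = len(same_node_ranks), 4
all_node_ranks = [
    list(range(i * node_size, (i + 1) * node_size))
    for i in range(world_size // node_size)
]
print(all_node_ranks)        # [[0, 1], [2, 3]]
```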