105 changes: 102 additions & 3 deletions torchtitan/experiments/auto_parallel/parallelize_llama.py
@@ -10,16 +10,98 @@

from autoparallel.api import AutoParallel

from torch.distributed import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Replicate, Shard
from torch.distributed.tensor.placement_types import Partial, Replicate, Shard

from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

from torchtitan.tools.logging import logger


def group_mm_nodes_with_its_gradients(nodes):
    fwd_nodes = [n for n in nodes if "nn_module_stack" in n.meta]
    bwd_nodes = [n for n in nodes if "fwd_nn_module_stack" in n.meta]
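    # every forward mm is expected to contribute exactly two backward mms (one for
    # the weight grad, one for the input grad), hence the 2x check below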
    assert len(fwd_nodes) * 2 == len(bwd_nodes)
    res = {}
    for fwd_node in fwd_nodes:
        o = []
        for bwd_node in bwd_nodes:
            if fwd_node.meta["nn_module_stack"] == bwd_node.meta["fwd_nn_module_stack"]:
                o.append(bwd_node)
        assert len(o) == 2
        res[fwd_node] = o
    return res


def force_tp_constraints(autop, mm_nodes, feat_dim=1, bwd_constraint=False):
    # out = x @ w - S(0)R, RS(1) -> S(0)S(1)
    # g_w = g.T @ x - S(1)S(0), S(0)R -> PS(0)
    # g_x = g @ w.T - S(0)S(1), RS(0) -> S(0)P
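    # notation: S(i) = Shard(i), R = Replicate, P = Partial, with one placement per
    # mesh dimension (presumably dp x tp here)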

    add_node_constraint = autop.sharding_optimizer.add_node_constraint
    fwd_bwd_groups = group_mm_nodes_with_its_gradients(mm_nodes)
    fwd_nodes = list(fwd_bwd_groups.keys())
    dim1 = 0 if feat_dim == 1 else 1
    dim2 = 1 if feat_dim == 1 else 0
    # assume there are 7 mm nodes per transformer block
    # skip last mm as it's the final projection layer
    assert (
        len(fwd_nodes) - 1
    ) % 7 == 0, f"expected 7 mm nodes per transformer block, {len(fwd_nodes) - 1}"
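    # the 7 matmuls per block are presumably, in trace order: wq, wk, wv, wo, w1, w3, w2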
    for block in range(0, len(fwd_nodes) - 1, 7):
        fwd_nodes_block = fwd_nodes[block : block + 7]
        # force the first 3 mm nodes to be S(0)S(1)
        the_nodes = fwd_nodes_block[:3] + fwd_nodes_block[4:6]
        for n in the_nodes:
            add_node_constraint(n, (Shard(0), Shard(feat_dim)))
            add_node_constraint(n.all_input_nodes[0], (Shard(0), Replicate()))
            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(1)))

            if bwd_constraint:
                bwd_nodes = fwd_bwd_groups[n]
                # first is g_w, second is g_x
                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim1)))
                add_node_constraint(bwd_nodes[1], (Shard(0), Partial()))

        # add reduction to finish TP, yielding S(0)P
        the_nodes = fwd_nodes_block[3:4] + fwd_nodes_block[6:7]
        for n in the_nodes:
            add_node_constraint(n, (Shard(0), Partial()))
            add_node_constraint(n.all_input_nodes[0], (Shard(0), Shard(feat_dim)))
            add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(0)))

            if bwd_constraint:
                bwd_nodes = fwd_bwd_groups[n]
                # first is g_w, second is g_x
                add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim2)))
                add_node_constraint(bwd_nodes[1], (Shard(0), Shard(feat_dim)))


def add_tp_constraints(autop):
    mm_nodes = autop.gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.mm.default
    )
    einsum_nodes = autop.gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.einsum.default
    )
    assert (len(mm_nodes) > 0) ^ (
        len(einsum_nodes) > 0
    ), f"only one should be non-empty, got {len(mm_nodes)} and {len(einsum_nodes)}"
    feat_dim = 1 if len(mm_nodes) > 0 else 2
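    # mm nodes appear to operate on flattened 2-D (tokens, features) activations while
    # einsum nodes keep the 3-D (batch, seq, features) layout, hence the feature dim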
    tgt_nodes = mm_nodes + einsum_nodes
    force_tp_constraints(autop, tgt_nodes, feat_dim=feat_dim, bwd_constraint=True)

    if einsum_nodes:
        # add sequence parallelism if we have einsum nodes
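        # under the 7-matmuls-per-block layout assumed above, tgt_nodes[3] should be
        # the attention output projection; its user is pinned to (Shard(0), Partial())
        # and the next node to (Shard(0), Shard(1)), i.e. reduce-scattered over seq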
        autop.sharding_optimizer.add_node_constraint(
            list(tgt_nodes[3].users)[0], (Shard(0), Partial())
        )
        autop.sharding_optimizer.add_node_constraint(
            list(list(tgt_nodes[3].users)[0].users)[0], (Shard(0), Shard(1))
        )


def parallelize_llama(
    model,
    parallel_dims: ParallelDims,
@@ -33,6 +115,7 @@ def parallelize_llama(
    the model must fit on GPU or CPU memory.
    """
    world_mesh = parallel_dims.world_mesh

    def input_fn():
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
@@ -62,6 +145,17 @@ def input_fn():
        lambda bucket_idx: 1000 / parallel_dims.tp
    )

    # XXX MICROPIPELINE
    enable_async_tp = True
    if enable_async_tp:
        mesh = world_mesh
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
        torch._inductor.config._micro_pipeline_tp = True
        torch._inductor.config.reorder_for_compute_comm_overlap = False
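        # registering the TP group for symmetric memory lets inductor's micro-pipeline
        # TP pass decompose and overlap the matmul + collective pairs (async TP)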
    # XXX--- MICROPIPELINE

    # bail out
    # model = model_fn()
    # return model
@@ -78,6 +172,7 @@ def input_fn():
        world_mesh,
        mp_policy=mp_policy,
        compile=job_config.compile,
        repeated_subgraphs=True,
    ) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)

@@ -101,7 +196,8 @@ def input_fn():
        )
        out_sharding = x_sharding
        loss_parallel_enabled = (
            parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
            parallel_dims.tp_enabled
            and not job_config.parallelism.disable_loss_parallel
        )
        if loss_parallel_enabled:
            out_sharding = tuple(
@@ -111,6 +207,9 @@
        )
        autop.add_input_constraints([x_sharding])
        autop.add_output_constraints([out_sharding])
        enable_manual_constraint = True
        if enable_manual_constraint:
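            # pin the manual (presumably Megatron-style) TP shardings before the solver
            # runs, so optimize_placement only decides the remaining placements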
            add_tp_constraints(autop)
        t0 = time.time()
        sharding_placement = autop.optimize_placement()
        t1 = time.time()
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/__init__.py
@@ -42,7 +42,7 @@
    ),
    "8B": TransformerModelArgs(
        dim=4096,
        n_layers=32,
        n_layers=1,
        n_heads=32,
        n_kv_heads=8,
        ffn_dim_multiplier=1.3,