diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 8d8381ed0428..3ddd010a97d3 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -382,7 +382,13 @@ def _replace(self, child, name, conv_linear_layer):
             if self.conv_linear_layer:
                 return Conv_LinearALlreduce(child, self.mp_group, name=name)
             elif name == "lm_head" or name == 'embed_out':
-                return LmHeadLinearAllreduce(child, self.mp_group)
+                if is_autotp_training_mode():
+                    return child
+
+                    ## gather output column parallel
+                    ## return LinearLayer(child, self.mp_group, name=name, gather_output=True)
+                else:
+                    return LmHeadLinearAllreduce(child, self.mp_group)
 
             return LinearAllreduce(child, self.mp_group, name=name)
         else:
diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index 3c7491e99999..2e494f82cfa3 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -110,6 +110,36 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
         return None, grad_output
 
 
+class GatherTensor(torch.autograd.Function):
+    """Gather the input from the model parallel region and concatenate."""
+
+    @staticmethod
+    def forward(ctx, group, input_):
+        """Forward function."""
+        # gather along the last dim
+        world_size = dist.get_world_size(group)
+        if world_size == 1:
+            return input_
+        ctx.group = group
+        ctx.world_size = world_size
+
+        gather_shape = (world_size, ) + input_.shape
+        output = torch.empty(gather_shape, dtype=input_.dtype, device=get_accelerator().current_device_name())
+        dist.all_gather_into_tensor(output, input_.contiguous(), group)
+        tensor_list = output.chunk(world_size, dim=0)
+        output = torch.cat(tensor_list, dim=-1).squeeze(0).contiguous()
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """Backward function."""
+        # split along the last dim and keep the rank-local shard
+        rank = dist.get_rank(ctx.group)
+        input_list = torch.chunk(grad_output, ctx.world_size, -1)
+        grad_output = input_list[rank].contiguous()
+        return None, grad_output
+
+
 class TensorParallel_Layer(nn.Module, ABC):
     """
     A base class for model layers with tensor parallelism support.
@@ -394,16 +424,18 @@ def uneven_partition(self, params_list):
 
     #remove kwargs from partition.
 class LinearLayer(TensorParallel_Layer):
 
-    def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
+    def __init__(self, module, mp_group=None, skip_partition=False, gather_output=False, **kwargs):
         super(LinearLayer, self).__init__(mp_group, **kwargs)
         self.weight = module.weight
         self.bias = module.bias
+
         if not skip_partition:
             self._tp_partition([self.weight, self.bias])
         self.support_training = True
         self.config_tp_params(self.weight)
         if self.bias is not None:
             self.config_tp_params(self.bias)
+        self.gather_output = gather_output
 
     def forward(self, input):
         if getattr(self, 'mp_group', None) is not None:
@@ -411,6 +443,10 @@ def forward(self, input):
         output = torch.matmul(input, self.weight.transpose(-1, -2))
         if self.bias is not None:
             output += self.bias
+
+        if self.gather_output:
+            output = GatherTensor.apply(self.mp_group, output)
+
         return output
 
     @torch.no_grad()
@@ -598,6 +634,7 @@ def __init__(self, module, mp_group, **kwargs):
     def forward(self, input):
         input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head")
         input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index])
+
         output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size],
                               self.weight.transpose(-1, -2))
         if self.mp_group is not None:
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index ed94a5021fee..ef2ae2394152 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -335,9 +335,6 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
         return new_module
 
     def set_lm_head(module):
-        if is_autotp_training_mode():
-            # we need to handle autoTP training mode separately.
-            return
         embedding_weight = None
         for n, p in module.named_parameters():
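
Not part of the patch: the sketch below illustrates, under stated assumptions, the column-parallel gather semantics that gather_output=True and GatherTensor are meant to provide. Each rank's output shard is all-gathered and concatenated along the last dimension in the forward pass, while the backward pass keeps only the rank-local slice of the gradient. It assumes an already-initialized torch.distributed process group and equally sized shards; gather_last_dim and mp_group are illustrative names, not DeepSpeed API.

import torch
import torch.distributed as dist

def gather_last_dim(shard: torch.Tensor, group=None) -> torch.Tensor:
    """All-gather equally sized per-rank shards and concatenate them along the last dim."""
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return shard
    # one buffer per rank, then concatenate to recover the full (e.g. vocab) dimension
    parts = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(parts, shard.contiguous(), group=group)
    return torch.cat(parts, dim=-1)

# Usage sketch (hypothetical): full_logits = gather_last_dim(local_logits, group=mp_group)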