diff --git a/xtuner/v1/ops/moe/npu/group_gemm.py b/xtuner/v1/ops/moe/npu/group_gemm.py index f09ceaa5a..4ce113023 100644 --- a/xtuner/v1/ops/moe/npu/group_gemm.py +++ b/xtuner/v1/ops/moe/npu/group_gemm.py @@ -1,9 +1,8 @@ import torch +from mindspeed.core.fusions.grouped_matmul import Ops def npu_group_gemm(x: torch.Tensor, weights: torch.Tensor, split_sizes: torch.Tensor) -> torch.Tensor: - from mindspeed.core.fusions.grouped_matmul import Ops - weights = weights.transpose(1, 2) out = Ops.gmm(x, weights, split_sizes, trans_b=False)