
Conversation

akroviakov
Contributor

This PR adds distribution patterns for scattered load and store ops, including chunk-size support.

XeGPU is moving toward offsets being part of the load/store ops themselves, so the pass supports only this case. Manipulating a vector of offsets indirectly through create_tdesc is complex and will soon become obsolete anyway.
This PR assumes the SIMT-adapted scatter op verification introduced in #154653; the distribution itself can be reviewed in the meantime.
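
For illustration, here is a condensed before/after of the lowering, adapted from the scatter_ops test added in this PR (cache-hint attributes omitted; the SSA names %src, %offsets, %lane_id are illustrative):

// Subgroup level: offsets and mask are SG-sized vectors.
%mask = arith.constant dense<1> : vector<16xi1>
%v = xegpu.load %src[%offsets], %mask : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
xegpu.store %v, %src[%offsets], %mask : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>

// After SIMT distribution: each lane extracts its own offset and mask element
// and works on a vector<1x...> slice.
%lane_id = gpu.lane_id
%lane_mask = arith.constant dense<true> : vector<1xi1>
%lane_off = vector.extract %offsets[%lane_id] : index from vector<16xindex>
%lane_off_vec = vector.broadcast %lane_off : index to vector<1xindex>
%v = xegpu.load %src[%lane_off_vec], %lane_mask : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
xegpu.store %v, %src[%lane_off_vec], %lane_mask : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>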

@llvmbot
Member

llvmbot commented Aug 22, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)



Full diff: https://github.com/llvm/llvm-project/pull/154949.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+137-4)
  • (modified) mlir/test/Dialect/XeGPU/subgroup-distribute.mlir (+36-13)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2088c3c7fc5ec..a1e5855aed264 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -811,6 +811,136 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+    if (!storeScatterOp)
+      return failure();
+    else if (!storeScatterOp.getOffsets())
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Store op must have offsets argument");
+    else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
+                 .getRank() != 1)
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Expected 1D offsets vector");
+
+    VectorType storeVecTy =
+        cast<VectorType>(storeScatterOp.getValue().getType());
+    assert(storeVecTy.getRank() <= 2 &&
+           "Expected at most 2D result at SG level");
+    VectorType distStoreVecTy;
+    if (storeVecTy.getRank() == 2)
+      distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
+    else // rank 1
+      distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands =
+        llvm::to_vector_of<Value>(storeScatterOp->getOperands());
+    SmallVector<Type> operandTypes =
+        llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
+    operandTypes[0] = distStoreVecTy;
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    Value offsetsVec = newStoreScatterOpOperands[2];
+    Value maskVec = newStoreScatterOpOperands[3];
+
+    auto loc = newWarpOp.getLoc();
+    Value laneId = warpOp.getLaneid();
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value laneOffset =
+        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
+    laneOffset = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
+    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
+    laneMask = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
+    newStoreScatterOpOperands[2] = laneOffset;
+    newStoreScatterOpOperands[3] = laneMask;
+
+    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
+        rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
+        storeScatterOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
+    rewriter.eraseOp(storeScatterOp);
+    return success();
+  }
+};
+
+struct LoadDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
+      if (!isa<xegpu::LoadGatherOp>(op))
+        return false;
+      auto yield = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      return yield->getPrevNode() == op;
+    });
+    if (!yieldOperand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a xegpu::LoadGatherOp op");
+
+    auto loadGatherOp =
+        yieldOperand->get().getDefiningOp<xegpu::LoadGatherOp>();
+    if (!loadGatherOp.getOffsets())
+      return rewriter.notifyMatchFailure(loadGatherOp,
+                                         "Load op must have offsets argument");
+    else if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() !=
+             1)
+      return rewriter.notifyMatchFailure(loadGatherOp,
+                                         "Expected 1D offsets vector");
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands =
+        llvm::to_vector_of<Value>(loadGatherOp->getOperands());
+    SmallVector<Type> operandTypes =
+        llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+
+    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    const unsigned operandIdx = yieldOperand->getOperandNumber();
+    VectorType loadVecTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
+
+    Value offsetsVec = newLoadGatherOperands[1];
+    Value maskVec = newLoadGatherOperands[2];
+    auto loc = newWarpOp.getLoc();
+    Value laneId = warpOp.getLaneid();
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value laneOffset =
+        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
+    laneOffset = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
+    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
+    laneMask = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
+    newLoadGatherOperands[1] = laneOffset;
+    newLoadGatherOperands[2] = laneMask;
+
+    xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
+        loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -823,10 +953,11 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
-      patterns.getContext());
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
+           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
+          patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
@@ -841,6 +972,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;
 
+      if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
+        continue;
       xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
       if (!layout) {
         op->emitError("Could not find layout attribute for operand ")
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 54ef56e013abb..a4757dd132024 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -302,20 +302,43 @@ gpu.module @test {
 }
 
 // -----
-// CHECK-LABEL: gpu.func @gpu_barrier({{.*}}) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
-// CHECK-NEXT: %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<16xf16> -> vector<1xf16>
-// CHECK-NEXT: gpu.barrier
-// CHECK-NEXT: %[[T2:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
-// CHECK-NEXT: xegpu.store_nd %[[T1]], %[[T2]] : vector<1xf16>, !xegpu.tensor_desc<16xf16>
+// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
+// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.module @test {
-  gpu.func @gpu_barrier(%arg0: memref<256xf16>, %arg1: memref<256xf16>) {
-    %c0 = arith.constant 0 : index
-    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
-    gpu.barrier
-    %2 = xegpu.create_nd_tdesc %arg1[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-    xegpu.store_nd %1, %2 : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+  gpu.func @scatter_ops_chunksize(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+    %1 = arith.constant dense<1>: vector<16xi1>
+    %3 = xegpu.load %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+    xegpu.store %3, %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
+// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+  gpu.func @scatter_ops(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+    %1 = arith.constant dense<1>: vector<16xi1>
+    %3 = xegpu.load %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+    xegpu.store %3, %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }
 }

@akroviakov
Contributor Author

Pinging @charithaintc

charithaintc self-requested a review (August 22, 2025 17:50)
@@ -841,6 +972,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (!isa<VectorType>(operand.get().getType()))
continue;

if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))

Deserves a comment, I think.

Contributor

@charithaintc charithaintc left a comment


I did a first pass.

Generally, I don't agree with the extract[laneid] -> broadcast code sequence.

Instead, I think the offsets and masks must be distributed. We should discuss this.

General comments:

  1. No need for casts on typed values.
  2. Please pay attention to variable names.

@@ -811,6 +811,136 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
}
};

struct StoreDistribution final : public gpu::WarpDistributionPattern {
Contributor


Add a comment on what the pattern does; an example is preferred.

LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
Contributor


A getTerminator helper was recently added to WarpOp; please use it.

auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
if (!storeScatterOp)
return failure();
else if (!storeScatterOp.getOffsets())
Contributor


Nit: no need for the else here; you can simply use a new if.

"Expected 1D offsets vector");

VectorType storeVecTy =
cast<VectorType>(storeScatterOp.getValue().getType());
Contributor


Suggested change:
-        cast<VectorType>(storeScatterOp.getValue().getType());
+        storeScatterOp.getValueType();

else if (!storeScatterOp.getOffsets())
return rewriter.notifyMatchFailure(storeScatterOp,
"Store op must have offsets argument");
else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
Contributor


Looks like getOffsets returns a typed value, so no need for the cast.

}
};

struct LoadDistribution final : public gpu::WarpDistributionPattern {
Contributor


add a comment for the pattern.

LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
if (!isa<xegpu::LoadGatherOp>(op))
Contributor


Add an explanation of why this check is needed (memory-ordering preservation).

if (!loadGatherOp.getOffsets())
return rewriter.notifyMatchFailure(loadGatherOp,
"Load op must have offsets argument");
else if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() !=
Contributor


nit: drop else

Contributor


No need for casts on typed values.

SmallVector<size_t> newRetIndices;
SmallVector<Value> operands =
llvm::to_vector_of<Value>(loadGatherOp->getOperands());
SmallVector<Type> operandTypes =
Contributor


Use a better name; check other patterns. Also, same comment as above: I think the offsets and masks must be distributed.

Comment on lines 926 to 930
Value laneOffset =
vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
laneOffset = vector::BroadcastOp::create(
rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
laneMask = vector::BroadcastOp::create(
rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
newLoadGatherOperands[1] = laneOffset;
newLoadGatherOperands[2] = laneMask;
Contributor


not sure about this code sequence.

@akroviakov
Contributor Author

akroviakov commented Aug 23, 2025

Instead I think offsets and masks must be distributed.

This is the main difference between scattered ops and nd ops.

  1. We do not have an intrinsic beneath these ops that would provide clear rules for (i.e., describe the structure of) a load/store.
  2. We do not have a single offset that defines a base pointer for a 2D shape whose structure we could describe using a layout attribute.

The offsets are the layout, and they are not necessarily linear (w.r.t. lane id) or compile-time defined.

The documentation does not prevent me from supplying a completely unstructured vector of offsets (e.g., [0, 5, 2, 11, 1]); it only says that the op needs an SG-sized vector of offsets:

  • offsets: represents offsets from source. required if source in not a TensorDescType.
    offsets is a vector of index type and vector length is either the subgroup size
    or 1 in SIMT mode. scalar offset is also valid for SIMT mode.

Therefore, we cannot "distribute" such a vector based on lane_layout = [1, N]. How would that look for the above unstructured vector? And if we could distribute it, why would we need a vector of offsets at the SG level?

The same applies to the mask: one can supply any random vector of i1. How do we convey that lane 0 is 0 but lane 3 is 1 in pure SIMT?

The offsets/mask vectors are not SG-uniform, they are allowed to be unstructured, and they can be completely runtime defined. What is distribution supposed to do with them at compile time, in your opinion?
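
For instance, nothing in the op definition prevents the offsets and mask from being fully runtime-defined, e.g. arriving as function arguments (a hypothetical kernel in the spirit of the tests in this PR; the names are illustrative):

gpu.func @scatter_runtime_offsets(%src: memref<256xf16>, %offsets: vector<16xindex>, %mask: vector<16xi1>) {
  %v = xegpu.load %src[%offsets], %mask : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
  xegpu.store %v, %src[%offsets], %mask : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
  gpu.return
}

Here the pass only sees block arguments; there is no producer op whose structure it could inspect.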

akroviakov force-pushed the akroviak/xegpu-scatter-sg-to-wi branch from 566cb7b to a081c8d (August 23, 2025 10:35)
@charithaintc
Contributor

charithaintc commented Aug 25, 2025

The offsets/mask vectors are not SG-uniform, they are allowed to be unstructured, and they can be completely runtime defined. What is distribution supposed to do with them at compile time, in your opinion?

When I say "offsets are distributed", it does not mean we have to describe them as some affine function of laneID. I mean that the vector<16xindex> will become a vector<1xindex>.

And then each lane can extract the scalar value from this <1xindex> vector. Let me give an example.

Before.

%offsets = arith.constant dense<0> : vector<16xindex>
// insert any value to this vector (random, linear does not matter)
%v = xegpu.load %base[%offsets] : i64, vector<16xindex> -> vector<16xf16>

After SIMT distribution.

%offsets = arith.constant dense<0> : vector<1xindex>
// insert any value to this vector (random, linear does not matter)
%scalar_offset = vector.extract %offsets[0] : index from vector<1xindex>
%v = xegpu.load %base[%scalar_offset] : i64, index -> vector<1xf16>

Can you please explain why such a strategy would not work?

If instead we broadcast the offsets, we waste a lot of registers, plus broadcasting needs cross-lane communication.

Also, upstream already has patterns to distribute the constants (i.e., elementwise ops), so you don't have to do anything there.

@akroviakov
Contributor Author

akroviakov commented Aug 25, 2025

That is the point: the offsets are not necessarily

%offsets = arith.constant dense<0> : vector<16xindex>

They can be arbitrary; how does the proposed distribution work with a vector of [0, 5, 2, 11, 1] instead of an SG-uniform value?

UPD: I see, let me think about it

@charithaintc
Contributor

They can be arbitrary; how does the proposed distribution work with a vector of [0, 5, 2, 11, 1] instead of an SG-uniform value?

I feel such random offsets are more likely to come in as a memory buffer and are unlikely to be static at the SG level.

@akroviakov
Contributor Author

Before.

%offsets = arith.constant dense<0> : vector<16xindex>
// insert any value to this vector (random, linear does not matter)
%v = xegpu.load %base[%offsets] : i64, vector<16xindex> -> vector<16xf16>

But should we care about the vector producer? Arith distribution is there, but there are more ways to create a vector; we could even receive it as an argument from the runtime. Am I missing an op or distribution-logic constraint that the offsets producer must be retrievable?

@charithaintc
Contributor

But should we care about the vector producer? Arith distribution is there, but there are more ways to create a vector; we could even receive it as an argument from the runtime. Am I missing an op or distribution-logic constraint that the offsets producer must be retrievable?

"even receive it as an argument from the runtime." what does this mean? a func argument?

AFAIK, from the load-gather distribution perspective, the only thing we need to care about is the distributed types for the base, offsets, and mask. Everything else should ideally be handled by the framework.

I suggest testing it with the distributed type as <1xindex> and seeing what the framework does.

@akroviakov
Contributor Author

"even receive it as an argument from the runtime." What does this mean? A func argument?

Cases where SG distribution has no access to %offsets = arith.constant dense<0> : vector<16xindex>.
At the SG level of a scattered op, we only see a vector of n offsets. We cannot assume (unless otherwise stated in the pass restrictions) whether its producer is distributable or not. Without this assumption, how can we rely on <1xindex>?

I agree that extracting at idx 0 would work if the input is already distributed, but we go bottom up, and we cannot assume that it will be distributed.

@charithaintc
Contributor

I did some quick testing. My conclusion is that we don't have to care about how the offset is defined. It will be taken care of by the framework (unless it is produced by some op that is not supported, in which case we need to add support).

Example 1: (Trivially distributable)

func.func @lane_dependent_warp_propagate_read(
    %src: memref<1024xf32>, %dest: memref<1024xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %laneid = gpu.lane_id
  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
    %2 = arith.constant dense<0.0> : vector<32xf32>
    gpu.yield %2 : vector<32xf32>
  }
  vector.transfer_write %r, %dest[%laneid] : vector<1xf32>, memref<1024xf32>
  return
}

To

  func.func @lane_dependent_warp_propagate_read(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
    %cst = arith.constant dense<0.000000e+00> : vector<1xf32>
    %0 = gpu.lane_id
    vector.transfer_write %cst, %arg1[%0] : vector<1xf32>, memref<1024xf32>
    return
  }

Example 2 (complicated case).

func.func @lane_dependent_warp_propagate_read(
    %src: memref<1024xf32>, %dest: memref<1024xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %laneid = gpu.lane_id
  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
    // %2 = arith.constant dense<0.0> : vector<32xf32>
    %2 = arith.constant dense<[0.0, 1.0, 3.0, 4.0, 0.0, 1.0, 3.0, 4.0,0.0, 1.0, 3.0, 4.0,0.0, 1.0, 3.0, 4.0,0.0, 1.0, 3.0, 4.0,0.0, 1.0, 3.0, 4.0,0.0, 1.0, 3.0, 4.0,0.0, 1.0, 3.0, 4.0]> : vector<32xf32>
    gpu.yield %2 : vector<32xf32>
  }
  vector.transfer_write %r, %dest[%laneid] : vector<1xf32>, memref<1024xf32>
  return
}

To

  func.func @lane_dependent_warp_propagate_read(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
    %cst = arith.constant dense<[0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<32xf32>
    %0 = gpu.lane_id
    %1 = gpu.warp_execute_on_lane_0(%0)[32] -> (vector<1xf32>) {
      gpu.yield %cst : vector<32xf32>
    }
    vector.transfer_write %1, %arg1[%0] : vector<1xf32>, memref<1024xf32>
    return
  }

I agree that for the complex case broadcasting is indeed needed. But I think this is outside the scope of gather/scatter distribution; it should not care about it.

@charithaintc
Contributor

I agree that extracting at idx 0 would work if the input is already distributed, but we go bottom up, and we cannot assume that it will be distributed.

I get your point. But I think the gather/scatter logic should not care about this. It should simply assume this is always distributable. Maybe we should wait for @Jianhui-Li's input also :-)

@akroviakov
Contributor Author

assume this is always distributable.

This solves the major issue.

In your examples, how does the distribution pattern decide whether to use laneId (example 2) or 0 (example 1) if it only sees the argument as a vector of n values?

@charithaintc
Contributor

In your examples, how does the distribution pattern decide whether to use laneId (example 2) or 0 (example 1) if it only sees the argument as a vector of n values?

This is done by WarpOpElementwise. If the compiler can prove the value is uniform, it will uniformly distribute the vector. If not, the values cannot be distributed, so the warp op is not fully eliminated. There is a pattern to lower the remaining warp ops to scf.if.

@charithaintc
Contributor

Final result (after lowering the remaining warp op):

  memref.global "private" @__shared_32xf32 : memref<32xf32, 3>
  func.func @lane_dependent_warp_propagate_read(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
    %0 = ub.poison : f32
    %c0 = arith.constant 0 : index
    %cst = arith.constant dense<[0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00, 0.000000e+00, 1.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<32xf32>
    %1 = gpu.lane_id
    %2 = arith.cmpi eq, %1, %c0 : index
    %3 = memref.get_global @__shared_32xf32 : memref<32xf32, 3>
    scf.if %2 {
      vector.transfer_write %cst, %3[%c0] {in_bounds = [true]} : vector<32xf32>, memref<32xf32, 3>
    }
    gpu.barrier
    %4 = vector.transfer_read %3[%1], %0 {in_bounds = [true]} : memref<32xf32, 3>, vector<1xf32>
    vector.transfer_write %4, %arg1[%1] : vector<1xf32>, memref<1024xf32>
    return
  }

As you can see, if the offsets are "weird", the data goes through SLM. My point is that gather/scatter distribution should not care about it; it is a matter of separation of concerns.

@akroviakov
Contributor Author

Thanks @charithaintc for the examples and clarifications. I updated the patterns to distribute the offsets and mask. The vector.step op distribution (as the main expected producer) will be added in a subsequent PR.
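
A rough sketch of the resulting lowering (assumed from the discussion above; the exact updated IR is not part of the diff shown on this page): the offsets and mask now receive distributed types through the warp op, so the per-lane op consumes them directly instead of going through the extract/broadcast sequence:

// Hypothetical per-lane form after distribution; %lane_offsets : vector<1xindex> and
// %lane_mask : vector<1xi1> are produced by the upstream warp-distribution patterns
// (e.g. constant/elementwise distribution, or the planned vector.step pattern).
%v = xegpu.load %src[%lane_offsets], %lane_mask : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>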

@akroviakov
Contributor Author

Will address the rest of the feedback in a separate commit
