Skip to content

Commit a081c8d

Browse files
committed
[MLIR][XeGPU] Scattered ops sg-to-wi distribution
1 parent 230b9b2 commit a081c8d

File tree

2 files changed

+179
-4
lines changed

2 files changed

+179
-4
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 137 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,136 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
807807
}
808808
};
809809

810+
struct StoreDistribution final : public gpu::WarpDistributionPattern {
811+
using gpu::WarpDistributionPattern::WarpDistributionPattern;
812+
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
813+
PatternRewriter &rewriter) const override {
814+
auto yield = cast<gpu::YieldOp>(
815+
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
816+
Operation *lastNode = yield->getPrevNode();
817+
auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
818+
if (!storeScatterOp)
819+
return failure();
820+
else if (!storeScatterOp.getOffsets())
821+
return rewriter.notifyMatchFailure(storeScatterOp,
822+
"Store op must have offsets argument");
823+
else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
824+
.getRank() != 1)
825+
return rewriter.notifyMatchFailure(storeScatterOp,
826+
"Expected 1D offsets vector");
827+
828+
VectorType storeVecTy =
829+
cast<VectorType>(storeScatterOp.getValue().getType());
830+
assert(storeVecTy.getRank() <= 2 &&
831+
"Expected at most 2D result at SG level");
832+
VectorType distStoreVecTy;
833+
if (storeVecTy.getRank() == 2)
834+
distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
835+
else // rank 1
836+
distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
837+
838+
SmallVector<size_t> newRetIndices;
839+
SmallVector<Value> operands =
840+
llvm::to_vector_of<Value>(storeScatterOp->getOperands());
841+
SmallVector<Type> operandTypes =
842+
llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
843+
operandTypes[0] = distStoreVecTy;
844+
845+
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
846+
rewriter, warpOp, operands, operandTypes, newRetIndices);
847+
SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
848+
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
849+
850+
Value offsetsVec = newStoreScatterOpOperands[2];
851+
Value maskVec = newStoreScatterOpOperands[3];
852+
853+
auto loc = newWarpOp.getLoc();
854+
Value laneId = warpOp.getLaneid();
855+
rewriter.setInsertionPointAfter(newWarpOp);
856+
Value laneOffset =
857+
vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
858+
laneOffset = vector::BroadcastOp::create(
859+
rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
860+
Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
861+
laneMask = vector::BroadcastOp::create(
862+
rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
863+
newStoreScatterOpOperands[2] = laneOffset;
864+
newStoreScatterOpOperands[3] = laneMask;
865+
866+
xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
867+
rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
868+
storeScatterOp->getAttrs());
869+
xegpu::removeLayoutAttrs(newOp);
870+
rewriter.eraseOp(storeScatterOp);
871+
return success();
872+
}
873+
};
874+
875+
struct LoadDistribution final : public gpu::WarpDistributionPattern {
876+
using gpu::WarpDistributionPattern::WarpDistributionPattern;
877+
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
878+
PatternRewriter &rewriter) const override {
879+
OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
880+
if (!isa<xegpu::LoadGatherOp>(op))
881+
return false;
882+
auto yield = cast<gpu::YieldOp>(
883+
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
884+
return yield->getPrevNode() == op;
885+
});
886+
if (!yieldOperand)
887+
return rewriter.notifyMatchFailure(
888+
warpOp, "warp result is not a xegpu::LoadGatherOp op");
889+
890+
auto loadGatherOp =
891+
yieldOperand->get().getDefiningOp<xegpu::LoadGatherOp>();
892+
if (!loadGatherOp.getOffsets())
893+
return rewriter.notifyMatchFailure(loadGatherOp,
894+
"Load op must have offsets argument");
895+
else if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() !=
896+
1)
897+
return rewriter.notifyMatchFailure(loadGatherOp,
898+
"Expected 1D offsets vector");
899+
900+
SmallVector<size_t> newRetIndices;
901+
SmallVector<Value> operands =
902+
llvm::to_vector_of<Value>(loadGatherOp->getOperands());
903+
SmallVector<Type> operandTypes =
904+
llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
905+
906+
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
907+
rewriter, warpOp, operands, operandTypes, newRetIndices);
908+
909+
SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
910+
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
911+
912+
const unsigned operandIdx = yieldOperand->getOperandNumber();
913+
VectorType loadVecTy =
914+
cast<VectorType>(warpOp.getResult(operandIdx).getType());
915+
assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
916+
917+
Value offsetsVec = newLoadGatherOperands[1];
918+
Value maskVec = newLoadGatherOperands[2];
919+
auto loc = newWarpOp.getLoc();
920+
Value laneId = warpOp.getLaneid();
921+
rewriter.setInsertionPointAfter(newWarpOp);
922+
Value laneOffset =
923+
vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
924+
laneOffset = vector::BroadcastOp::create(
925+
rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
926+
Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
927+
laneMask = vector::BroadcastOp::create(
928+
rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
929+
newLoadGatherOperands[1] = laneOffset;
930+
newLoadGatherOperands[2] = laneMask;
931+
932+
xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
933+
loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
934+
Value distributedVal = newWarpOp.getResult(operandIdx);
935+
rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
936+
return success();
937+
}
938+
};
939+
810940
} // namespace
811941

812942
namespace {
@@ -819,10 +949,11 @@ struct XeGPUSubgroupDistributePass final
819949

820950
void xegpu::populateXeGPUSubgroupDistributePatterns(
821951
RewritePatternSet &patterns) {
822-
patterns.add<CreateNdDescDistribution, StoreNdDistribution,
823-
LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
824-
UpdateNdOffsetDistribution, GpuBarrierDistribution>(
825-
patterns.getContext());
952+
patterns
953+
.add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
954+
DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
955+
GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
956+
patterns.getContext());
826957
}
827958

828959
void XeGPUSubgroupDistributePass::runOnOperation() {
@@ -837,6 +968,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
837968
if (!isa<VectorType>(operand.get().getType()))
838969
continue;
839970

971+
if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
972+
continue;
840973
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
841974
if (!layout) {
842975
op->emitError("Could not find layout attribute for operand ")

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,45 @@ gpu.module @test {
319319
gpu.return
320320
}
321321
}
322+
323+
// -----
324+
// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
325+
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
326+
// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
327+
// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
328+
// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
329+
// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
330+
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
331+
// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
332+
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
333+
gpu.module @test {
334+
gpu.func @scatter_ops_chunksize(%src: memref<256xf16>, %offset1: vector<16xindex>) {
335+
%1 = arith.constant dense<1>: vector<16xi1>
336+
%3 = xegpu.load %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
337+
: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
338+
xegpu.store %3, %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
339+
: vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
340+
gpu.return
341+
}
342+
}
343+
344+
// -----
345+
// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
346+
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
347+
// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
348+
// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
349+
// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
350+
// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
351+
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
352+
// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
353+
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
354+
gpu.module @test {
355+
gpu.func @scatter_ops(%src: memref<256xf16>, %offset1: vector<16xindex>) {
356+
%1 = arith.constant dense<1>: vector<16xi1>
357+
%3 = xegpu.load %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
358+
: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
359+
xegpu.store %3, %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
360+
: vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
361+
gpu.return
362+
}
363+
}

0 commit comments

Comments
 (0)