@@ -807,6 +807,136 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
807
807
}
808
808
};
809
809
810
/// Distribute a scattered store (xegpu.store_scatter) that appears as the
/// last operation (before the terminator) inside a warp-execute-on-lane-0
/// region. The store operands are yielded out of the warp op with the stored
/// value rewritten to its per-lane (distributed) type; each lane then keeps
/// only its own element of the subgroup-wide 1D offsets and mask vectors,
/// re-wrapped as vector<1x...> operands of a new store created after the
/// warp op.
struct StoreDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Only match when the store is the op immediately preceding the
    // terminator of the warp-op body.
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
    if (!storeScatterOp)
      return failure();
    // The explicit-offsets form is required, and the offsets must be a 1D
    // vector (one offset per lane).
    else if (!storeScatterOp.getOffsets())
      return rewriter.notifyMatchFailure(storeScatterOp,
                                         "Store op must have offsets argument");
    else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
                 .getRank() != 1)
      return rewriter.notifyMatchFailure(storeScatterOp,
                                         "Expected 1D offsets vector");

    // Compute the distributed type of the stored value: a 2D subgroup value
    // drops its leading dimension, a 1D value shrinks to a single element.
    VectorType storeVecTy =
        cast<VectorType>(storeScatterOp.getValue().getType());
    assert(storeVecTy.getRank() <= 2 &&
           "Expected at most 2D result at SG level");
    VectorType distStoreVecTy;
    if (storeVecTy.getRank() == 2)
      distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
    else // rank 1
      distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);

    // Yield all store operands from the warp op; operand 0 (the value being
    // stored) is returned with its distributed type instead.
    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands =
        llvm::to_vector_of<Value>(storeScatterOp->getOperands());
    SmallVector<Type> operandTypes =
        llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
    operandTypes[0] = distStoreVecTy;

    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypes, newRetIndices);
    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    // Operand order of xegpu.store_scatter as yielded above:
    // [0] value, [1] dest, [2] offsets, [3] mask.
    Value offsetsVec = newStoreScatterOpOperands[2];
    Value maskVec = newStoreScatterOpOperands[3];

    auto loc = newWarpOp.getLoc();
    // NOTE(review): the lane id is read from the original warpOp after its
    // region was moved into newWarpOp — confirm the old op's lane id value is
    // still valid at this point, or whether newWarpOp.getLaneid() is intended.
    Value laneId = warpOp.getLaneid();
    rewriter.setInsertionPointAfter(newWarpOp);
    // Extract this lane's offset and mask elements and broadcast each back to
    // the vector<1xT> shape the scatter op expects.
    Value laneOffset =
        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
    laneOffset = vector::BroadcastOp::create(
        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
    laneMask = vector::BroadcastOp::create(
        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
    newStoreScatterOpOperands[2] = laneOffset;
    newStoreScatterOpOperands[3] = laneMask;

    // Recreate the store outside the warp op with the per-lane operands,
    // carry over the original attributes, and strip the (subgroup-level)
    // layout attributes that no longer apply after distribution.
    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
        rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
        storeScatterOp->getAttrs());
    xegpu::removeLayoutAttrs(newOp);
    rewriter.eraseOp(storeScatterOp);
    return success();
  }
};
874
+
875
/// Distribute a gathered load (xegpu.load_gather) whose result is yielded by
/// a warp-execute-on-lane-0 op and which is the last operation before the
/// terminator. The load operands are yielded out of the warp op; each lane
/// then keeps only its own element of the 1D offsets and mask vectors
/// (re-wrapped as vector<1x...>), and a new per-lane load created after the
/// warp op replaces the warp op's distributed result.
struct LoadDistribution final : public gpu::WarpDistributionPattern {
  using gpu::WarpDistributionPattern::WarpDistributionPattern;
  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a yielded warp result produced by a load that sits immediately
    // before the terminator.
    OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
      if (!isa<xegpu::LoadGatherOp>(op))
        return false;
      auto yield = cast<gpu::YieldOp>(
          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
      return yield->getPrevNode() == op;
    });
    if (!yieldOperand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a xegpu::LoadGatherOp op");

    auto loadGatherOp =
        yieldOperand->get().getDefiningOp<xegpu::LoadGatherOp>();
    // The explicit-offsets form is required, and the offsets must be a 1D
    // vector (one offset per lane).
    if (!loadGatherOp.getOffsets())
      return rewriter.notifyMatchFailure(loadGatherOp,
                                         "Load op must have offsets argument");
    else if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() !=
             1)
      return rewriter.notifyMatchFailure(loadGatherOp,
                                         "Expected 1D offsets vector");

    // Yield all load operands (with their original types) out of the warp op.
    SmallVector<size_t> newRetIndices;
    SmallVector<Value> operands =
        llvm::to_vector_of<Value>(loadGatherOp->getOperands());
    SmallVector<Type> operandTypes =
        llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());

    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypes, newRetIndices);

    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });

    // The distributed result type was already computed on the warp op's
    // corresponding result; it must be 1D after distribution.
    const unsigned operandIdx = yieldOperand->getOperandNumber();
    VectorType loadVecTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");

    // Operand order of xegpu.load_gather as yielded above:
    // [0] source, [1] offsets, [2] mask.
    Value offsetsVec = newLoadGatherOperands[1];
    Value maskVec = newLoadGatherOperands[2];
    auto loc = newWarpOp.getLoc();
    // NOTE(review): as in StoreDistribution, the lane id is read from the
    // original warpOp after its region moved into newWarpOp — confirm this is
    // intended rather than newWarpOp.getLaneid().
    Value laneId = warpOp.getLaneid();
    rewriter.setInsertionPointAfter(newWarpOp);
    // Extract this lane's offset and mask elements and broadcast each back to
    // the vector<1xT> shape the gather op expects.
    Value laneOffset =
        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
    laneOffset = vector::BroadcastOp::create(
        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
    laneMask = vector::BroadcastOp::create(
        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
    newLoadGatherOperands[1] = laneOffset;
    newLoadGatherOperands[2] = laneMask;

    // NOTE(review): uses rewriter.create<> while StoreDistribution uses the
    // Op::create(rewriter, ...) form, and no removeLayoutAttrs is applied to
    // the new load — confirm whether either difference is intentional.
    xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
        loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
    // Replace the warp op's distributed result with the per-lane load result.
    Value distributedVal = newWarpOp.getResult(operandIdx);
    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
    return success();
  }
};
939
+
810
940
} // namespace
811
941
812
942
namespace {
@@ -819,10 +949,11 @@ struct XeGPUSubgroupDistributePass final
819
949
820
950
/// Registers all XeGPU subgroup-distribution rewrite patterns, including the
/// scattered-access patterns (LoadDistribution, StoreDistribution) alongside
/// the nd-descriptor, dpas, prefetch, offset-update, and barrier patterns.
void xegpu::populateXeGPUSubgroupDistributePatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
          patterns.getContext());
}
827
958
828
959
void XeGPUSubgroupDistributePass::runOnOperation () {
@@ -837,6 +968,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
837
968
if (!isa<VectorType>(operand.get ().getType ()))
838
969
continue ;
839
970
971
+ if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
972
+ continue ;
840
973
xegpu::LayoutAttr layout = xegpu::getLayoutAttr (operand);
841
974
if (!layout) {
842
975
op->emitError (" Could not find layout attribute for operand " )
0 commit comments