diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc index 39d07991f3c5a..d9fa7edc1aac8 100644 --- a/onnxruntime/core/providers/webgpu/tensor/gather.cc +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -9,32 +9,51 @@ namespace onnxruntime { namespace webgpu { Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& data = shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& data = shader.AddInput("data", ShaderUsage::UseIndicesTypeAlias); const auto& indices = shader.AddInput("input_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + const auto& output = shader.AddOutput("output", ShaderUsage::UseValueTypeAlias); + const auto& data_indices = shader.AddIndices("data_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& output_indices = shader.AddIndices("output_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + bool is_bool = Inputs()[0].var_type == ProgramVariableDataType::Boolx4; shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") - << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" - << " var indices_indices = input_indices_indices_t(0);\n"; - for (int i = 0; i < indices.Rank(); i++) { - shader.MainFunctionBody() << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", axis_ + i)) << ";\n"; - } - shader.MainFunctionBody() << " var idx = " << indices.GetByIndices("indices_indices") << ";\n" - << " if (idx < 0) {\n" - << " idx = idx + input_indices_value_t(" << data.IndicesGet("uniforms.data_shape", axis_) << ");\n" - << " }\n" - << " var data_indices : data_indices_t;\n"; - for (int i = 0, j = 0; i < data.Rank(); i++) { - if (static_cast(i) == axis_) { - shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, "u32(idx)") << ";\n"; - j += indices.Rank(); + << " var idx : input_indices_value_t;\n" + << " var output_indices : output_indices_indices_t;\n" + << " var indices_indices : input_indices_indices_t;\n" + << " var data_indices : data_indices_indices_t;\n" + << " var value : output_value_t;\n" + << " var data_offset : u32;\n"; + for (int comp = 0; comp < (is_bool ? 4 : 1); comp++) { + shader.MainFunctionBody() << " output_indices = " << output_indices.OffsetToIndices(is_bool ? (std::to_string(comp) + " + 4 * global_idx") : "global_idx") << ";\n"; + + for (int i = 0; i < indices.Rank(); i++) { + shader.MainFunctionBody() << " " << indices.IndicesSet("indices_indices", i, output_indices.IndicesGet("output_indices", axis_ + i)) << ";\n"; + } + + shader.MainFunctionBody() << " idx = " << indices.GetByIndices("indices_indices") << ";\n" + << " if (idx < 0) {\n" + << " idx = idx + input_indices_value_t(" << data_indices.IndicesGet("uniforms.data_indices_shape", axis_) << ");\n" + << " }\n"; + + for (int i = 0, j = 0; i < data_indices.Rank(); i++) { + if (static_cast(i) == axis_) { + shader.MainFunctionBody() << " " << data_indices.IndicesSet("data_indices", i, "u32(idx)") << ";\n"; + j += indices.Rank(); + } else { + shader.MainFunctionBody() << " " << data_indices.IndicesSet("data_indices", i, output_indices.IndicesGet("output_indices", j)) << ";\n"; + j++; + } + } + + shader.MainFunctionBody() << " data_offset = " << data_indices.IndicesToOffset("data_indices") << ";\n"; + if (is_bool) { + shader.MainFunctionBody() << " value[" << comp << "] = " << data.GetByOffset("data_offset / 4") << "[data_offset % 4];\n"; } else { - shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n"; - j++; + shader.MainFunctionBody() << " value = " << data.GetByOffset("data_offset") << ";\n"; } } - shader.MainFunctionBody() << " " << output.SetByOffset("global_idx", data.GetByIndices("data_indices")); + shader.MainFunctionBody() << " " << output.SetByOffset("global_idx", "value"); return Status::OK(); } @@ -47,14 +66,20 @@ Status Gather::ComputeInternal(ComputeContext& context) const { return Status::OK(); } + bool is_bool = p.input_tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; + if (is_bool) { + data_size = (data_size + 3) / 4; + } uint32_t axis = static_cast(p.axis); GatherProgram program{axis}; program - .AddInputs({{p.input_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + .AddInputs({{p.input_tensor, ProgramTensorMetadataDependency::TypeAndRank, ProgramInput::Flatten, (is_bool ? 4 : 1)}, {p.indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutput({p.output_tensor, ProgramTensorMetadataDependency::Rank}) + .AddOutput({p.output_tensor, ProgramTensorMetadataDependency::Rank, {data_size}, (is_bool ? 4 : 1)}) .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .CacheHint(std::to_string(axis)) + .AddIndices(p.input_tensor->Shape()) + .AddIndices(p.output_tensor->Shape()) .AddUniformVariables({{data_size}}); return context.RunProgram(program); } @@ -71,9 +96,9 @@ Status Gather::ComputeInternal(ComputeContext& context) const { KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), \ KERNEL_CLASS); -WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 1, 10, Gather, WebGpuSupportedNumberTypes()) -WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 11, 12, Gather, WebGpuSupportedNumberTypes()) -WEBGPU_GATHER_KERNEL(Gather, 13, Gather, WebGpuSupportedNumberTypes()) +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 1, 10, Gather, WebGpuSupportedNumberAndBoolTypes()) +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 11, 12, Gather, WebGpuSupportedNumberAndBoolTypes()) +WEBGPU_GATHER_KERNEL(Gather, 13, Gather, WebGpuSupportedNumberAndBoolTypes()) } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h index ff66cd535399e..28f523219ba44 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h +++ b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h @@ -20,6 +20,14 @@ using SupportedFloats = float, MLFloat16>; +using SupportedNumberAndBoolTypes = + TypeList< + float, + MLFloat16, + int32_t, + uint32_t, + bool>; + inline const std::vector& WebGpuSupportedNumberTypes() { static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); return supportedDataTypes; @@ -30,5 +38,10 @@ inline const std::vector& WebGpuSupportedFloatTypes() { return supportedDataTypes; } +inline const std::vector& WebGpuSupportedNumberAndBoolTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); + return supportedDataTypes; +} + } // namespace webgpu } // namespace onnxruntime