Skip to content

Commit 584e3d9

Browse files
[SYCL][NFC] More sub_group_mask.hpp cleanup (#18975)
1 parent 74d8de6 commit 584e3d9

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

sycl/include/sycl/detail/spirv.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#pragma once
1010

1111
#ifdef __SYCL_DEVICE_ONLY__
12-
1312
// Some __spirv_* inrinsics are automatically forward-declared by the compiler,
1413
// but not all of them. For example:
1514
// __spirv_AtomicStore(unsigned long long*, ...)
@@ -18,7 +17,10 @@
1817
#include <sycl/__spirv/spirv_ops.hpp>
1918
#include <sycl/__spirv/spirv_types.hpp>
2019

21-
#include <sycl/ext/oneapi/experimental/non_uniform_groups.hpp> // for IdToMaskPosition
20+
#include <sycl/access/access.hpp>
21+
#include <sycl/detail/generic_type_traits.hpp>
22+
#include <sycl/id.hpp>
23+
#include <sycl/multi_ptr.hpp>
2224

2325
#if defined(__NVPTX__)
2426
#include <sycl/ext/oneapi/experimental/cuda/masked_shuffles.hpp>
@@ -33,6 +35,7 @@ struct sub_group;
3335
namespace ext {
3436
namespace oneapi {
3537
struct sub_group;
38+
struct sub_group_mask;
3639
namespace experimental {
3740
template <typename ParentGroup> class fragment;
3841

@@ -61,6 +64,9 @@ GetMultiPtrDecoratedAs(multi_ptr<FromT, Space, IsDecorated> MPtr) {
6164

6265
template <typename NonUniformGroup>
6366
inline uint32_t IdToMaskPosition(NonUniformGroup Group, uint32_t Id);
67+
template <typename NonUniformGroup>
68+
inline ext::oneapi::sub_group_mask GetMask(NonUniformGroup Group);
69+
inline sycl::vec<unsigned, 4> ExtractMask(ext::oneapi::sub_group_mask Mask);
6470

6571
namespace spirv {
6672

sycl/include/sycl/ext/oneapi/sub_group_mask.hpp

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
#include <sycl/detail/helpers.hpp> // for Builder
1111
#include <sycl/detail/memcpy.hpp> // detail::memcpy
12-
#include <sycl/exception.hpp> // for errc, exception
13-
#include <sycl/feature_test.hpp> // for SYCL_EXT_ONEAPI_SUB_GROUP_MASK
14-
#include <sycl/id.hpp> // for id
15-
#include <sycl/marray.hpp> // for marray
12+
#include <sycl/detail/spirv.hpp>
13+
#include <sycl/feature_test.hpp> // for SYCL_EXT_ONEAPI_SUB_GROUP_MASK
14+
#include <sycl/id.hpp> // for id
15+
#include <sycl/marray.hpp> // for marray
1616
#include <sycl/sub_group.hpp>
1717
#include <sycl/vector.hpp> // for vec
1818

@@ -378,19 +378,11 @@ group_ballot([[maybe_unused]] Group g, [[maybe_unused]] bool predicate) {
378378
#ifdef __SYCL_DEVICE_ONLY__
379379
return sycl::detail::commonGroupBallotImpl(g, predicate);
380380
#else
381-
throw exception{errc::feature_not_supported,
382-
"Sub-group mask is not supported on host device"};
381+
// Groups are not user-constructible, this call should not be reachable from
382+
// host and therefore we do nothing here.
383383
#endif
384384
}
385385

386386
} // namespace ext::oneapi
387387
} // namespace _V1
388388
} // namespace sycl
389-
390-
// We have a cyclic dependency with
391-
// sub_group_mask.hpp
392-
// detail/spirv.hpp
393-
// non_uniform_groups.hpp
394-
// "Break" it by including this at the end (instead of beginning). Ideally, we
395-
// should refactor this somehow...
396-
#include <sycl/detail/spirv.hpp>

0 commit comments

Comments
 (0)