|
8 | 8 |
|
9 | 9 | #pragma once
|
10 | 10 |
|
11 |
| -#include <sycl/access/access.hpp> // for decorated, address_space |
12 |
| -#include <sycl/aliases.hpp> // for half, cl_char, cl_double |
13 |
| -#include <sycl/detail/helpers.hpp> // for marray |
14 |
| -#include <sycl/detail/type_traits.hpp> // for is_gen_based_on_type_s... |
15 |
| -#include <sycl/half_type.hpp> // for BIsRepresentationT |
16 |
| -#include <sycl/multi_ptr.hpp> // for multi_ptr, address_spa... |
17 |
| - |
18 |
| -#include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16 storage type. |
| 11 | +#include <sycl/access/access.hpp> |
| 12 | +#include <sycl/aliases.hpp> |
| 13 | +#include <sycl/bit_cast.hpp> |
| 14 | +#include <sycl/detail/fwd/half.hpp> |
| 15 | +#include <sycl/detail/type_traits.hpp> |
19 | 16 |
|
20 | 17 | #include <cstddef> // for byte
|
21 | 18 | #include <cstdint> // for uint8_t
|
|
24 | 21 |
|
25 | 22 | namespace sycl {
|
26 | 23 | inline namespace _V1 {
|
| 24 | +namespace ext::oneapi { |
| 25 | +class bfloat16; |
| 26 | +} |
27 | 27 | namespace detail {
|
28 | 28 | template <typename T>
|
29 | 29 | using is_byte = typename
|
@@ -166,13 +166,16 @@ template <typename T> auto convertToOpenCLType(T &&x) {
|
166 | 166 | static_assert(sizeof(OpenCLType) == sizeof(T));
|
167 | 167 | return static_cast<OpenCLType>(x);
|
168 | 168 | } else if constexpr (std::is_same_v<no_ref, half>) {
|
169 |
| - using OpenCLType = sycl::detail::half_impl::BIsRepresentationT; |
| 169 | + // Make it template-param-dependent to compile with incomplete `half`: |
| 170 | + using OpenCLType = |
| 171 | + std::enable_if_t<std::is_same_v<no_ref, half>, |
| 172 | + sycl::detail::half_impl::BIsRepresentationT>; |
170 | 173 | static_assert(sizeof(OpenCLType) == sizeof(T));
|
171 | 174 | return static_cast<OpenCLType>(x);
|
172 | 175 | } else if constexpr (std::is_same_v<no_ref, ext::oneapi::bfloat16>) {
|
173 | 176 | // On host, don't interpret BF16 as uint16.
|
174 | 177 | #ifdef __SYCL_DEVICE_ONLY__
|
175 |
| - using OpenCLType = sycl::ext::oneapi::bfloat16::Bfloat16StorageT; |
| 178 | + using OpenCLType = typename no_ref::Bfloat16StorageT; |
176 | 179 | return sycl::bit_cast<OpenCLType>(x);
|
177 | 180 | #else
|
178 | 181 | return std::forward<T>(x);
|
|
0 commit comments