Skip to content

Commit ebfa066

Browse files
[NFC][SYCL] Drop bfloat16::Bfloat16StorageT
It's not part of the specification and should have never been a public type alias inside `bfloat16`. There aren't too many uses of it (`bfloat16` itself and `convertToOpenCLType`/`vec::convert`) so I don't see much value in creating a named type alias.
1 parent 9b938b1 commit ebfa066

File tree

4 files changed

+18
-16
lines changed

4 files changed

+18
-16
lines changed

sycl/include/sycl/detail/generic_type_traits.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,7 @@ template <typename T> auto convertToOpenCLType(T &&x) {
175175
} else if constexpr (std::is_same_v<no_ref, ext::oneapi::bfloat16>) {
176176
// On host, don't interpret BF16 as uint16.
177177
#ifdef __SYCL_DEVICE_ONLY__
178-
using OpenCLType = typename no_ref::Bfloat16StorageT;
179-
return sycl::bit_cast<OpenCLType>(x);
178+
return sycl::bit_cast<uint16_t>(x);
180179
#else
181180
return std::forward<T>(x);
182181
#endif

sycl/include/sycl/detail/vector_convert.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -895,8 +895,7 @@ vec<convertT, NumElements> vec<DataT, NumElements>::convert() const {
895895
#endif
896896
bool, /*->*/ std::uint8_t, //
897897
sycl::half, /*->*/ sycl::detail::half_impl::StorageT, //
898-
sycl::ext::oneapi::bfloat16,
899-
/*->*/ sycl::ext::oneapi::bfloat16::Bfloat16StorageT, //
898+
sycl::ext::oneapi::bfloat16, /*->*/ uint16_t, //
900899
char, /*->*/ detail::ConvertToOpenCLType_t<char>, //
901900
DataT, /*->*/ DataT //
902901
>::type

sycl/include/sycl/ext/oneapi/bfloat16.hpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ namespace ext::oneapi {
2121

2222
class bfloat16 {
2323
public:
24-
using Bfloat16StorageT = uint16_t;
24+
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
25+
using Bfloat16StorageT
26+
__SYCL_DEPRECATED("bfloat16::Bfloat16StorageT is non-standard and has "
27+
"been deprecated.") = uint16_t;
28+
#endif
2529

2630
bfloat16() = default;
2731
~bfloat16() = default;
@@ -58,7 +62,7 @@ class bfloat16 {
5862
friend bfloat16 operator-(const bfloat16 &lhs) {
5963
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \
6064
(__SYCL_CUDA_ARCH__ >= 800)
61-
Bfloat16StorageT res;
65+
uint16_t res;
6266
asm("neg.bf16 %0, %1;" : "=h"(res) : "h"(lhs.value));
6367
return bit_cast<bfloat16>(res);
6468
#else
@@ -146,18 +150,18 @@ class bfloat16 {
146150
#endif
147151

148152
private:
149-
Bfloat16StorageT value;
153+
uint16_t value;
150154

151155
// Private tag used to avoid constructor ambiguity.
152156
struct private_tag {
153157
explicit private_tag() = default;
154158
};
155159

156-
constexpr bfloat16(Bfloat16StorageT Value, private_tag) : value{Value} {}
160+
constexpr bfloat16(uint16_t Value, private_tag) : value{Value} {}
157161

158162
// Explicit conversion functions
159-
static float to_float(const Bfloat16StorageT &a);
160-
static Bfloat16StorageT from_float(const float &a);
163+
static float to_float(const uint16_t &a);
164+
static uint16_t from_float(const float &a);
161165

162166
// Friend traits.
163167
friend std::numeric_limits<bfloat16>;
@@ -178,7 +182,7 @@ class bfloat16 {
178182
extern "C" __DPCPP_SYCL_EXTERNAL float
179183
__devicelib_ConvertBF16ToFINTEL(const uint16_t &) noexcept;
180184
#endif
181-
inline float bfloat16::to_float(const bfloat16::Bfloat16StorageT &a) {
185+
inline float bfloat16::to_float(const uint16_t &a) {
182186
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
183187
return __devicelib_ConvertBF16ToFINTEL(a);
184188
#else
@@ -213,11 +217,11 @@ inline uint16_t from_float_to_uint16_t(const float &a) {
213217
extern "C" __DPCPP_SYCL_EXTERNAL uint16_t
214218
__devicelib_ConvertFToBF16INTEL(const float &) noexcept;
215219
#endif
216-
inline bfloat16::Bfloat16StorageT bfloat16::from_float(const float &a) {
220+
inline uint16_t bfloat16::from_float(const float &a) {
217221
#if defined(__SYCL_DEVICE_ONLY__)
218222
#if defined(__NVPTX__)
219223
#if (__SYCL_CUDA_ARCH__ >= 800)
220-
Bfloat16StorageT res;
224+
uint16_t res;
221225
asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(res) : "f"(a));
222226
return res;
223227
#else

sycl/test-e2e/BFloat16/bfloat_hw.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@ using get_uint_type_of_size = typename std::conditional_t<
1717
std::conditional_t<Size == 8, uint64_t, void>>>>;
1818

1919
using bfloat16 = sycl::ext::oneapi::bfloat16;
20-
using Bfloat16StorageT = get_uint_type_of_size<sizeof(bfloat16)>;
20+
static_assert(sizeof(bfloat16) == size(uint16_t));
2121

22-
bool test(float Val, Bfloat16StorageT Bits) {
22+
bool test(float Val, uint16_t Bits) {
2323
std::cout << "Value: " << Val << " Bits: " << std::hex << "0x" << Bits
2424
<< std::dec << "...\n";
2525
bool Passed = true;
2626
{
2727
std::cout << " float -> bfloat16 conversion ...";
28-
Bfloat16StorageT RawVal = sycl::bit_cast<Bfloat16StorageT>(bfloat16(Val));
28+
auto RawVal = sycl::bit_cast<uint16_t>(bfloat16(Val));
2929
bool Res = (RawVal == Bits);
3030
Passed &= Res;
3131

0 commit comments

Comments
 (0)