Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions sycl/include/sycl/detail/vector_arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@

#include <sycl/aliases.hpp>
#include <sycl/detail/generic_type_traits.hpp>
#include <sycl/detail/type_traits.hpp>
#include <sycl/detail/type_traits/vec_marray_traits.hpp>
#include <sycl/ext/oneapi/bfloat16.hpp>

#include <functional>

Expand Down
6 changes: 3 additions & 3 deletions sycl/include/sycl/detail/vector_convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,16 @@

#pragma once

#include <sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
#include <sycl/exception.hpp> // for errc
#include <sycl/detail/generic_type_traits.hpp>

#include <sycl/detail/memcpy.hpp>
#include <sycl/ext/oneapi/bfloat16.hpp>
#include <sycl/half_type.hpp>
#include <sycl/vector.hpp>

#ifndef __SYCL_DEVICE_ONLY__
#include <cfenv> // for fesetround, fegetround
#include <cfenv>
#include <sycl/exception.hpp>
#endif

#include <type_traits>
Expand Down
6 changes: 3 additions & 3 deletions sycl/include/sycl/ext/oneapi/experimental/cuda/builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ ldg(const T *ptr) {
} else if constexpr (std::is_same_v<T, sycl::vec<half, 2>>) {
typedef __fp16 h2 ATTRIBUTE_EXT_VEC_TYPE(2);
auto rv = __nvvm_ldg_h2(reinterpret_cast<const h2 *>(ptr));
sycl::vec<half, 2> ret;
T ret;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just makes this statement template-type-dependent, so we don't need a definition of half unless the caller used it (and hence already had it).

ret.x() = rv[0];
ret.y() = rv[1];
return ret;
Expand All @@ -376,7 +376,7 @@ ldg(const T *ptr) {
h2 rv_2 = __nvvm_ldg_h2(reinterpret_cast<const h2 *>(ptr));
auto rv = __nvvm_ldg_h(reinterpret_cast<const __fp16 *>(
std::next(reinterpret_cast<const h2 *>(ptr))));
sycl::vec<half, 3> ret;
T ret;
ret.x() = rv_2[0];
ret.y() = rv_2[1];
ret.z() = rv;
Expand All @@ -385,7 +385,7 @@ ldg(const T *ptr) {
typedef __fp16 h2 ATTRIBUTE_EXT_VEC_TYPE(2);
auto rv1 = __nvvm_ldg_h2(reinterpret_cast<const h2 *>(ptr));
auto rv2 = __nvvm_ldg_h2(std::next(reinterpret_cast<const h2 *>(ptr)));
sycl::vec<half, 4> ret;
T ret;
ret.x() = rv1[0];
ret.y() = rv1[1];
ret.z() = rv2[0];
Expand Down
11 changes: 3 additions & 8 deletions sycl/include/sycl/marray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,15 @@

#include <sycl/aliases.hpp>
#include <sycl/detail/common.hpp>
#include <sycl/detail/is_device_copyable.hpp>
#include <sycl/half_type.hpp>

#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>
#include <sycl/detail/fwd/half.hpp>

namespace sycl {
inline namespace _V1 {

template <typename DataT, std::size_t N> class marray;

template <typename T> struct is_device_copyable;

namespace detail {

// Helper trait for counting the aggregate number of arguments in a type list,
Expand Down
32 changes: 10 additions & 22 deletions sycl/include/sycl/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,16 @@
#error "SYCL device compiler is built without ext_vector_type support"
#endif

#include <sycl/access/access.hpp> // for decorated, address_space
#include <sycl/aliases.hpp> // for half, cl_char, cl_int
#include <sycl/detail/common.hpp> // for ArrayCreator
#include <sycl/detail/defines_elementary.hpp> // for __SYCL2020_DEPRECATED
#include <sycl/detail/fwd/accessor.hpp>
#include <sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
#include <sycl/detail/memcpy.hpp> // for memcpy
#include <sycl/detail/named_swizzles_mixin.hpp>
#include <sycl/detail/type_traits.hpp> // for is_floating_point
#include <sycl/detail/vector_arith.hpp>
#include <sycl/half_type.hpp> // for StorageT, half, Vec16...

#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16
#include <sycl/detail/common.hpp>
#include <sycl/detail/fwd/accessor.hpp>
#include <sycl/detail/fwd/half.hpp>
#include <sycl/detail/memcpy.hpp>

#include <algorithm> // for std::min
#include <array> // for array
#include <cassert> // for assert
#include <cstddef> // for size_t, NULL, byte
#include <cstdint> // for uint8_t, int16_t, int...
#include <functional> // for divides, multiplies
#include <iterator> // for pair
#include <ostream> // for operator<<, basic_ost...
#include <type_traits> // for enable_if_t, is_same
#include <utility> // for index_sequence, make_...
#include <algorithm>
#include <functional>

namespace sycl {

Expand All @@ -63,6 +49,9 @@ namespace sycl {
enum class rounding_mode { automatic = 0, rte = 1, rtz = 2, rtp = 3, rtn = 4 };

inline namespace _V1 {
namespace ext::oneapi {
class bfloat16;
}

struct elem {
static constexpr int x = 0;
Expand Down Expand Up @@ -512,8 +501,7 @@ class __SYCL_EBO vec :
#endif
bool, /*->*/ std::uint8_t, //
sycl::half, /*->*/ sycl::detail::half_impl::StorageT, //
sycl::ext::oneapi::bfloat16,
/*->*/ sycl::ext::oneapi::bfloat16::Bfloat16StorageT, //
sycl::ext::oneapi::bfloat16, /*->*/ uint16_t, //
char, /*->*/ detail::ConvertToOpenCLType_t<char>, //
DataT, /*->*/ DataT //
>::type;
Expand Down