Skip to content

Commit 6fca568

Browse files
committed
override shfl methods for torch.half
1 parent 66bcc36 commit 6fca568

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

csrc/cuda/segment_coo_cuda.cu

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include <ATen/cuda/CUDAContext.h>
44
#include <ATen/cuda/detail/IndexUtils.cuh>
55
#include <ATen/cuda/detail/TensorInfo.cuh>
6-
#include <type_traits>
76

87
#include "reducer.cuh"
98
#include "utils.cuh"
@@ -26,10 +25,6 @@ segment_coo_kernel(const scalar_t *src_data,
2625
int lane_idx = row_idx & (32 - 1);
2726
int D = index_info.sizes[index_info.dims - 1];
2827

29-
using cuda_scalar_t =
30-
typename std::conditional<std::is_same<scalar_t, at::Half>::value, __half,
31-
scalar_t>::type;
32-
3328
if (row_idx < E) {
3429
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
3530
row_idx, index_info);
@@ -41,7 +36,7 @@ segment_coo_kernel(const scalar_t *src_data,
4136
#pragma unroll
4237
for (int i = 1; i < 32; i *= 2) {
4338
// Parallel reduction inside a single warp.
44-
tmp = __shfl_up_sync(FULL_MASK, (cuda_scalar_t)val, i);
39+
tmp = __shfl_up_sync(FULL_MASK, val, i);
4540
next_idx = __shfl_up_sync(FULL_MASK, idx, i);
4641
if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
4742
assert(idx >= next_idx);

csrc/cuda/segment_csr_cuda.cu

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@ segment_csr_kernel(const scalar_t *src_data,
2626
int row_idx = thread_idx / TB;
2727
int lane_idx = thread_idx & (TB - 1);
2828

29-
using cuda_scalar_t =
30-
typename std::conditional<std::is_same<scalar_t, at::Half>::value, __half,
31-
scalar_t>::type;
32-
3329
if (row_idx < N) {
3430
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
3531
int64_t row_start = __ldg(indptr_info.data + offset);
@@ -52,8 +48,7 @@ segment_csr_kernel(const scalar_t *src_data,
5248
if (REDUCE == MIN || REDUCE == MAX)
5349
arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
5450
Reducer<scalar_t, REDUCE>::update(
55-
&val, __shfl_down_sync(FULL_MASK, (cuda_scalar_t)val, i), &arg,
56-
arg_tmp);
51+
&val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
5752
}
5853

5954
if (lane_idx == 0) {

csrc/cuda/utils.cuh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,15 @@
55
#define CHECK_CUDA(x) \
66
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
77
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
8+
9+
__device__ __inline__ at::Half __shfl_up_sync(const unsigned mask,
10+
const at::Half var,
11+
const unsigned int delta) {
12+
return __shfl_up_sync(mask, (__half)var, delta);
13+
}
14+
15+
__device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
16+
const at::Half var,
17+
const unsigned int delta) {
18+
return __shfl_down_sync(mask, (__half)var, delta);
19+
}

0 commit comments

Comments
 (0)