Commit 48cede7

Optimize segment_coo and segment_csr BFloat16/Half implementation in CPU backend (#375)

* Optimize segment_csr_cpu BFloat16 implementation
* Optimize segment_coo_cpu BFloat16 implementation

1 parent dbf42c4 commit 48cede7
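As the diffs below show, the change switches the per-segment accumulator from scalar_t to at::opmath_type<scalar_t>, so Half/BFloat16 inputs are reduced in float and only cast back to the storage type when the result is written, presumably to avoid rounding to (and converting from) 16-bit precision on every update. A minimal standalone sketch of that pattern follows; the sum_segment helper is illustrative only and is not part of this commit.

#include <ATen/OpMathType.h>
#include <cstdint>

// Illustrative only: accumulate in at::opmath_type<scalar_t> (float for
// Half/BFloat16) and cast back to scalar_t once at the end, mirroring the
// accumulator change made in segment_coo_cpu/segment_csr_cpu.
template <typename scalar_t>
scalar_t sum_segment(const scalar_t *src, int64_t n) {
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t acc = opmath_t(0);
  for (int64_t i = 0; i < n; i++)
    acc += static_cast<opmath_t>(src[i]);   // accumulate in higher precision
  return static_cast<scalar_t>(acc);        // single cast back at the end
}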

File tree: 2 files changed, +16 -12 lines

csrc/cpu/segment_coo_cpu.cpp

Lines changed: 9 additions & 7 deletions

@@ -3,6 +3,7 @@
 #include "index_info.h"
 #include "reducer.h"
 #include "utils.h"
+#include <ATen/OpMathType.h>
 
 std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
 segment_coo_cpu(torch::Tensor src, torch::Tensor index,
@@ -70,11 +71,12 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
   auto stride = index_info.strides[index_info.dims - 1];
   std::vector<int64_t> args(K);
   AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "segment_coo_cpu", [&] {
+    using opmath_t = at::opmath_type<scalar_t>;
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
     scalar_t *count_data = nullptr;
 
-    std::vector<scalar_t> vals(K);
+    std::vector<opmath_t> vals(K);
     int64_t idx, next_idx, row_start;
     AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
       if (!optional_out.has_value())
@@ -87,19 +89,19 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
         idx = index_info.data[offset];
 
         for (auto k = 0; k < K; k++)
-          vals[k] = out_data[b * N * K + k];
+          vals[k] = static_cast<opmath_t>(out_data[b * N * K + k]);
 
         row_start = 0;
         for (auto e = 0; e < E; e++) {
 
           for (auto k = 0; k < K; k++)
-            Reducer<scalar_t, REDUCE>::update(
-                &vals[k], src_data[b * E * K + e * K + k], &args[k], e);
+            Reducer<opmath_t, REDUCE>::update(
+                &vals[k], static_cast<opmath_t>(src_data[b * E * K + e * K + k]), &args[k], e);
 
           if (e == E - 1) {
             for (auto k = 0; k < K; k++)
               Reducer<scalar_t, REDUCE>::write(
-                  out_data + b * N * K + idx * K + k, vals[k],
+                  out_data + b * N * K + idx * K + k, static_cast<scalar_t>(vals[k]),
                   arg_out_data + b * N * K + idx * K + k, args[k],
                   e + 1 - row_start);
             if (REDUCE == MEAN)
@@ -111,11 +113,11 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
             if (idx != next_idx) {
               for (auto k = 0; k < K; k++) {
                 Reducer<scalar_t, REDUCE>::write(
-                    out_data + b * N * K + idx * K + k, vals[k],
+                    out_data + b * N * K + idx * K + k, static_cast<scalar_t>(vals[k]),
                     arg_out_data + b * N * K + idx * K + k, args[k],
                     e + 1 - row_start);
 
-                vals[k] = out_data[b * N * K + next_idx * K + k];
+                vals[k] = static_cast<opmath_t>(out_data[b * N * K + next_idx * K + k]);
               }
               if (REDUCE == MEAN)
                 count_data[b * N + idx] = (scalar_t)(e + 1 - row_start);

csrc/cpu/segment_csr_cpu.cpp

Lines changed: 7 additions & 5 deletions

@@ -3,6 +3,7 @@
 #include "index_info.h"
 #include "reducer.h"
 #include "utils.h"
+#include <ATen/OpMathType.h>
 
 std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
 segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
@@ -58,10 +59,11 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
   auto stride = indptr_info.strides[indptr_info.dims - 1];
   std::vector<int64_t> args(K);
   AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "segment_csr_cpu", [&] {
+    using opmath_t = at::opmath_type<scalar_t>;
     auto src_data = src.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
 
-    std::vector<scalar_t> vals(K);
+    std::vector<opmath_t> vals(K);
     int64_t row_start, row_end;
     AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
       for (auto n = 0; n < N; n++) {
@@ -71,15 +73,15 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
 
         offset = (n / (indptr.size(-1) - 1)) * E * K;
         for (auto k = 0; k < K; k++)
-          vals[k] = Reducer<scalar_t, REDUCE>::init();
+          vals[k] = Reducer<opmath_t, REDUCE>::init();
 
         for (auto e = row_start; e < row_end; e++)
           for (auto k = 0; k < K; k++)
-            Reducer<scalar_t, REDUCE>::update(
-                &vals[k], src_data[offset + e * K + k], &args[k], e);
+            Reducer<opmath_t, REDUCE>::update(
+                &vals[k], static_cast<opmath_t>(src_data[offset + e * K + k]), &args[k], e);
 
         for (auto k = 0; k < K; k++)
-          Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
+          Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, static_cast<scalar_t>(vals[k]),
                                            arg_out_data + n * K + k, args[k],
                                            row_end - row_start);
       }
