3
3
#include " index_info.h"
4
4
#include " reducer.h"
5
5
#include " utils.h"
6
+ #include < ATen/OpMathType.h>
6
7
7
8
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
8
9
segment_coo_cpu (torch::Tensor src, torch::Tensor index,
@@ -70,11 +71,12 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
70
71
auto stride = index_info.strides [index_info.dims - 1 ];
71
72
std::vector<int64_t > args (K);
72
73
AT_DISPATCH_ALL_TYPES_AND2 (at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type (), " segment_coo_cpu" , [&] {
74
+ using opmath_t = at::opmath_type<scalar_t >;
73
75
auto src_data = src.data_ptr <scalar_t >();
74
76
auto out_data = out.data_ptr <scalar_t >();
75
77
scalar_t *count_data = nullptr ;
76
78
77
- std::vector<scalar_t > vals (K);
79
+ std::vector<opmath_t > vals (K);
78
80
int64_t idx, next_idx, row_start;
79
81
AT_DISPATCH_REDUCTION_TYPES (reduce, [&] {
80
82
if (!optional_out.has_value ())
@@ -87,19 +89,19 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
87
89
idx = index_info.data [offset];
88
90
89
91
for (auto k = 0 ; k < K; k++)
90
- vals[k] = out_data[b * N * K + k];
92
+ vals[k] = static_cast < opmath_t >( out_data[b * N * K + k]) ;
91
93
92
94
row_start = 0 ;
93
95
for (auto e = 0 ; e < E; e++) {
94
96
95
97
for (auto k = 0 ; k < K; k++)
96
- Reducer<scalar_t , REDUCE>::update (
97
- &vals[k], src_data[b * E * K + e * K + k], &args[k], e);
98
+ Reducer<opmath_t , REDUCE>::update (
99
+ &vals[k], static_cast < opmath_t >( src_data[b * E * K + e * K + k]) , &args[k], e);
98
100
99
101
if (e == E - 1 ) {
100
102
for (auto k = 0 ; k < K; k++)
101
103
Reducer<scalar_t , REDUCE>::write (
102
- out_data + b * N * K + idx * K + k, vals[k],
104
+ out_data + b * N * K + idx * K + k, static_cast < scalar_t >( vals[k]) ,
103
105
arg_out_data + b * N * K + idx * K + k, args[k],
104
106
e + 1 - row_start);
105
107
if (REDUCE == MEAN)
@@ -111,11 +113,11 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
111
113
if (idx != next_idx) {
112
114
for (auto k = 0 ; k < K; k++) {
113
115
Reducer<scalar_t , REDUCE>::write (
114
- out_data + b * N * K + idx * K + k, vals[k],
116
+ out_data + b * N * K + idx * K + k, static_cast < scalar_t >( vals[k]) ,
115
117
arg_out_data + b * N * K + idx * K + k, args[k],
116
118
e + 1 - row_start);
117
119
118
- vals[k] = out_data[b * N * K + next_idx * K + k];
120
+ vals[k] = static_cast < opmath_t >( out_data[b * N * K + next_idx * K + k]) ;
119
121
}
120
122
if (REDUCE == MEAN)
121
123
count_data[b * N + idx] = (scalar_t )(e + 1 - row_start);
0 commit comments