Skip to content

Commit 0be33ff

Browse files
committed
potential windows fix
1 parent feca30d commit 0be33ff

File tree

4 files changed

+25
-24
lines changed

4 files changed

+25
-24
lines changed

csrc/cpu/reducer.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ const std::map<std::string, ReductionType> reduce2REDUCE = {
4040
} \
4141
}()
4242

43-
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
44-
static inline scalar_t init() {
43+
template <typename scalar_t> struct Reducer {
44+
static inline scalar_t init(ReductionType REDUCE) {
4545
if (REDUCE == MUL || REDUCE == DIV)
4646
return (scalar_t)1;
4747
else if (REDUCE == MIN)
@@ -52,8 +52,8 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
5252
return (scalar_t)0;
5353
}
5454

55-
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
56-
int64_t new_arg) {
55+
static inline void update(ReductionType REDUCE, scalar_t *val,
56+
scalar_t new_val, int64_t *arg, int64_t new_arg) {
5757
if (REDUCE == SUM || REDUCE == MEAN)
5858
*val = *val + new_val;
5959
else if (REDUCE == MUL)
@@ -67,8 +67,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
6767
}
6868
}
6969

70-
static inline void write(scalar_t *address, scalar_t val,
71-
int64_t *arg_address, int64_t arg, int count) {
70+
static inline void write(ReductionType REDUCE, scalar_t *address,
71+
scalar_t val, int64_t *arg_address, int64_t arg,
72+
int count) {
7273
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
7374
*address = val;
7475
else if (REDUCE == MEAN)

csrc/cpu/scatter_cpu.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,22 +61,22 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
6161
int64_t i, idx;
6262
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
6363
if (!optional_out.has_value())
64-
out.fill_(Reducer<scalar_t, REDUCE>::init());
64+
out.fill_(Reducer<scalar_t>::init(REDUCE));
6565

6666
for (auto b = 0; b < B; b++) {
6767
for (auto e = 0; e < E; e++) {
6868
for (auto k = 0; k < K; k++) {
6969
i = b * E * K + e * K + k;
7070
idx = index_info.data[IndexToOffset<int64_t>::get(i, index_info)];
71-
Reducer<scalar_t, REDUCE>::update(
72-
out_data + b * N * K + idx * K + k, src_data[i],
71+
Reducer<scalar_t>::update(
72+
REDUCE, out_data + b * N * K + idx * K + k, src_data[i],
7373
arg_out_data + b * N * K + idx * K + k, e);
7474
}
7575
}
7676
}
7777

7878
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
79-
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
79+
out.masked_fill_(out == Reducer<scalar_t>::init(REDUCE), (scalar_t)0);
8080
});
8181
});
8282

csrc/cpu/segment_coo_cpu.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
7272
int64_t idx, next_idx, row_start;
7373
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
7474
if (!optional_out.has_value())
75-
out.fill_(Reducer<scalar_t, REDUCE>::init());
75+
out.fill_(Reducer<scalar_t>::init(REDUCE));
7676
if (REDUCE == MEAN)
7777
count_data = arg_out.value().data_ptr<scalar_t>();
7878

@@ -87,13 +87,13 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
8787
for (auto e = 0; e < E; e++) {
8888

8989
for (auto k = 0; k < K; k++)
90-
Reducer<scalar_t, REDUCE>::update(
91-
&vals[k], src_data[b * E * K + e * K + k], &args[k], e);
90+
Reducer<scalar_t>::update(
91+
REDUCE, &vals[k], src_data[b * E * K + e * K + k], &args[k], e);
9292

9393
if (e == E - 1) {
9494
for (auto k = 0; k < K; k++)
95-
Reducer<scalar_t, REDUCE>::write(
96-
out_data + b * N * K + idx * K + k, vals[k],
95+
Reducer<scalar_t>::write(
96+
REDUCE, out_data + b * N * K + idx * K + k, vals[k],
9797
arg_out_data + b * N * K + idx * K + k, args[k],
9898
e + 1 - row_start);
9999
if (REDUCE == MEAN)
@@ -104,8 +104,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
104104

105105
if (idx != next_idx) {
106106
for (auto k = 0; k < K; k++) {
107-
Reducer<scalar_t, REDUCE>::write(
108-
out_data + b * N * K + idx * K + k, vals[k],
107+
Reducer<scalar_t>::write(
108+
REDUCE, out_data + b * N * K + idx * K + k, vals[k],
109109
arg_out_data + b * N * K + idx * K + k, args[k],
110110
e + 1 - row_start);
111111

@@ -121,7 +121,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
121121
}
122122
}
123123
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
124-
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
124+
out.masked_fill_(out == Reducer<scalar_t>::init(REDUCE), (scalar_t)0);
125125

126126
if (REDUCE == MEAN)
127127
arg_out.value().clamp_(1);

csrc/cpu/segment_csr_cpu.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,17 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
6868

6969
offset = (n / (indptr.size(-1) - 1)) * E * K;
7070
for (auto k = 0; k < K; k++)
71-
vals[k] = Reducer<scalar_t, REDUCE>::init();
71+
vals[k] = Reducer<scalar_t>::init(REDUCE);
7272

7373
for (auto e = row_start; e < row_end; e++)
7474
for (auto k = 0; k < K; k++)
75-
Reducer<scalar_t, REDUCE>::update(
76-
&vals[k], src_data[offset + e * K + k], &args[k], e);
75+
Reducer<scalar_t>::update(
76+
REDUCE, &vals[k], src_data[offset + e * K + k], &args[k], e);
7777

7878
for (auto k = 0; k < K; k++)
79-
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
80-
arg_out_data + n * K + k, args[k],
81-
row_end - row_start);
79+
Reducer<scalar_t>::write(REDUCE, out_data + n * K + k, vals[k],
80+
arg_out_data + n * K + k, args[k],
81+
row_end - row_start);
8282
}
8383
});
8484
});

0 commit comments

Comments
 (0)