Skip to content

Commit c72f36c

Browse files
committed
upgrade to PyTorch 1.10
1 parent 605566e commit c72f36c

File tree

4 files changed

+4
-5
lines changed

4 files changed

+4
-5
lines changed

.github/workflows/testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
matrix:
1212
os: [ubuntu-latest, windows-latest]
1313
python-version: [3.6]
14-
torch-version: [1.8.0, 1.9.0]
14+
torch-version: [1.9.0, 1.10.0]
1515

1616
steps:
1717
- uses: actions/checkout@v2

csrc/cuda/segment_coo_cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
274274
if (out.is_floating_point())
275275
out.true_divide_(count);
276276
else
277-
out.floor_divide_(count);
277+
out.div_(count, "floor");
278278
}
279279
});
280280
});

csrc/scatter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ class ScatterMean : public torch::autograd::Function<ScatterMean> {
132132
if (out.is_floating_point())
133133
out.true_divide_(count);
134134
else
135-
out.floor_divide_(count);
135+
out.div_(count, "floor");
136136

137137
ctx->save_for_backward({index, count});
138138
if (optional_out.has_value())

torch_scatter/scatter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
3838
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
3939
out: Optional[torch.Tensor] = None,
4040
dim_size: Optional[int] = None) -> torch.Tensor:
41-
4241
out = scatter_sum(src, index, dim, out, dim_size)
4342
dim_size = out.size(dim)
4443

@@ -55,7 +54,7 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
5554
if out.is_floating_point():
5655
out.true_divide_(count)
5756
else:
58-
out.floor_divide_(count)
57+
out.div_(count, rounding_mode='floor')
5958
return out
6059

6160

0 commit comments

Comments
 (0)