Skip to content

Commit fc012e0

Browse files
committed
bugfix for scatter mean
1 parent 93e779f commit fc012e0

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torch_scatter/scatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
6060
index_dim = dim
6161
if index_dim < 0:
6262
index_dim = index_dim + src.dim()
63-
if index.dim() <= dim:
63+
if index.dim() <= index_dim:
6464
index_dim = index.dim() - 1
6565

6666
ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)

0 commit comments

Comments
 (0)