We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 93e779f commit fc012e0Copy full SHA for fc012e0
torch_scatter/scatter.py
@@ -60,7 +60,7 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
60
index_dim = dim
61
if index_dim < 0:
62
index_dim = index_dim + src.dim()
63
- if index.dim() <= dim:
+ if index.dim() <= index_dim:
64
index_dim = index.dim() - 1
65
66
ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
0 commit comments