Skip to content

Commit 605566e

Browse files
author
Miltos
authored
Add optional parameter to scatter_[log_]softmax to indicate the number of segments (#243)
* Add parameter to `scatter_[log_]softmax` * Update softmax.py * Format code.
1 parent 26844e1 commit 605566e

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

torch_scatter/composite/softmax.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,51 @@
1+
from typing import Optional
2+
13
import torch
24

35
from torch_scatter import scatter_sum, scatter_max
46
from torch_scatter.utils import broadcast
57

68

79
def scatter_softmax(src: torch.Tensor, index: torch.Tensor,
8-
dim: int = -1) -> torch.Tensor:
10+
dim: int = -1,
11+
dim_size: Optional[int] = None) -> torch.Tensor:
912
if not torch.is_floating_point(src):
1013
raise ValueError('`scatter_softmax` can only be computed over tensors '
1114
'with floating point data types.')
1215

1316
index = broadcast(index, src, dim)
1417

15-
max_value_per_index = scatter_max(src, index, dim=dim)[0]
18+
max_value_per_index = scatter_max(
19+
src, index, dim=dim, dim_size=dim_size)[0]
1620
max_per_src_element = max_value_per_index.gather(dim, index)
1721

1822
recentered_scores = src - max_per_src_element
1923
recentered_scores_exp = recentered_scores.exp_()
2024

21-
sum_per_index = scatter_sum(recentered_scores_exp, index, dim)
25+
sum_per_index = scatter_sum(
26+
recentered_scores_exp, index, dim, dim_size=dim_size)
2227
normalizing_constants = sum_per_index.gather(dim, index)
2328

2429
return recentered_scores_exp.div(normalizing_constants)
2530

2631

2732
def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
28-
eps: float = 1e-12) -> torch.Tensor:
33+
eps: float = 1e-12,
34+
dim_size: Optional[int] = None) -> torch.Tensor:
2935
if not torch.is_floating_point(src):
3036
raise ValueError('`scatter_log_softmax` can only be computed over '
3137
'tensors with floating point data types.')
3238

3339
index = broadcast(index, src, dim)
3440

35-
max_value_per_index = scatter_max(src, index, dim=dim)[0]
41+
max_value_per_index = scatter_max(
42+
src, index, dim=dim, dim_size=dim_size)[0]
3643
max_per_src_element = max_value_per_index.gather(dim, index)
3744

3845
recentered_scores = src - max_per_src_element
3946

40-
sum_per_index = scatter_sum(recentered_scores.exp(), index, dim)
47+
sum_per_index = scatter_sum(
48+
recentered_scores.exp(), index, dim, dim_size=dim_size)
4149
normalizing_constants = sum_per_index.add_(eps).log_().gather(dim, index)
4250

4351
return recentered_scores.sub_(normalizing_constants)

0 commit comments

Comments
 (0)