|
| 1 | +from typing import Optional |
| 2 | + |
1 | 3 | import torch
|
2 | 4 |
|
3 | 5 | from torch_scatter import scatter_sum, scatter_max
|
4 | 6 | from torch_scatter.utils import broadcast
|
5 | 7 |
|
6 | 8 |
|
def scatter_softmax(src: torch.Tensor, index: torch.Tensor,
                    dim: int = -1,
                    dim_size: Optional[int] = None) -> torch.Tensor:
    """Softmax over the entries of ``src`` that share the same ``index``.

    Entries of ``src`` are grouped along ``dim`` by their value in ``index``
    and each group is normalized independently.

    Args:
        src: Floating point tensor holding the scores to normalize.
        index: Group assignment of the entries of ``src``; it is passed
            through ``broadcast(index, src, dim)`` before use.
        dim: Dimension along which the groups are formed.
        dim_size: Forwarded unchanged as ``dim_size`` to both ``scatter_max``
            and ``scatter_sum``. ``None`` leaves the choice to those
            functions (presumably inferred from ``index`` -- confirm against
            the torch_scatter documentation).

    Returns:
        ``exp(src - group_max) / group_sum`` evaluated for every entry of
        ``src``.

    Raises:
        ValueError: If ``src`` does not have a floating point dtype.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over tensors '
                         'with floating point data types.')

    index = broadcast(index, src, dim)

    # Subtract each group's maximum before exponentiating: softmax is
    # invariant to a per-group shift, and the shift keeps every exponent <= 0.
    max_value_per_index = scatter_max(
        src, index, dim=dim, dim_size=dim_size)[0]
    max_per_src_element = max_value_per_index.gather(dim, index)

    # `src - ...` allocates a new tensor, so the in-place `exp_` below does
    # not modify the caller's `src`.
    recentered_scores = src - max_per_src_element
    recentered_scores_exp = recentered_scores.exp_()

    # Per-group sum of exponentials, gathered back to one value per entry.
    sum_per_index = scatter_sum(
        recentered_scores_exp, index, dim, dim_size=dim_size)
    normalizing_constants = sum_per_index.gather(dim, index)

    return recentered_scores_exp.div(normalizing_constants)
|
25 | 30 |
|
26 | 31 |
|
def scatter_log_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                        eps: float = 1e-12,
                        dim_size: Optional[int] = None) -> torch.Tensor:
    """Log-softmax over the entries of ``src`` that share the same ``index``.

    Entries of ``src`` are grouped along ``dim`` by their value in ``index``
    and each group is normalized independently, in log space.

    Args:
        src: Floating point tensor holding the scores to normalize.
        index: Group assignment of the entries of ``src``; it is passed
            through ``broadcast(index, src, dim)`` before use.
        dim: Dimension along which the groups are formed.
        eps: Constant added to every per-group sum of exponentials before
            the logarithm is taken.
        dim_size: Forwarded unchanged as ``dim_size`` to both ``scatter_max``
            and ``scatter_sum``. ``None`` leaves the choice to those
            functions (presumably inferred from ``index`` -- confirm against
            the torch_scatter documentation).

    Returns:
        ``(src - group_max) - log(group_sum + eps)`` evaluated for every
        entry of ``src``, where ``group_sum`` is the per-group sum of
        ``exp(src - group_max)``.

    Raises:
        ValueError: If ``src`` does not have a floating point dtype.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_log_softmax` can only be computed over '
                         'tensors with floating point data types.')

    index = broadcast(index, src, dim)

    # Shift every group by its maximum so that all exponents are <= 0.
    max_value_per_index = scatter_max(
        src, index, dim=dim, dim_size=dim_size)[0]
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element

    # `.exp()` is out-of-place here because `recentered_scores` is still
    # needed, unmodified, for the final subtraction.
    sum_per_index = scatter_sum(
        recentered_scores.exp(), index, dim, dim_size=dim_size)
    # The in-place `add_`/`log_` only touch the freshly computed sums.
    normalizing_constants = sum_per_index.add_(eps).log_().gather(dim, index)

    # `recentered_scores` is a local temporary, so subtracting in place does
    # not affect the caller's `src`.
    return recentered_scores.sub_(normalizing_constants)
|
0 commit comments