Skip to content

Commit cf8cf0c

Browse files
committed
fix backward pass
1 parent 5b1737a commit cf8cf0c

File tree

5 files changed

+17
-5
lines changed

5 files changed

+17
-5
lines changed

test/composite/test_logsumexp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
def test_logsumexp():
66
src = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, -100])
7+
src.requires_grad_()
78
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
89

910
out = scatter_logsumexp(src, index)
@@ -16,3 +17,5 @@ def test_logsumexp():
1617

1718
expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
1819
assert torch.allclose(out, expected)
20+
21+
out.backward(torch.randn_like(out))

test/composite/test_softmax.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
def test_softmax():
66
src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
7+
src.requires_grad_()
78
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
89

910
out = scatter_softmax(src, index)
@@ -19,9 +20,12 @@ def test_softmax():
1920

2021
assert torch.allclose(out, expected)
2122

23+
out.backward(torch.randn_like(out))
24+
2225

2326
def test_log_softmax():
2427
src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')])
28+
src.requires_grad_()
2529
index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])
2630

2731
out = scatter_log_softmax(src, index)
@@ -36,3 +40,5 @@ def test_log_softmax():
3640
], dim=0)
3741

3842
assert torch.allclose(out, expected)
43+
44+
out.backward(torch.randn_like(out))

test/composite/test_std.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44

55
def test_std():
66
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype=torch.float)
7+
src.requires_grad_()
78
index = torch.tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.long)
89

910
out = scatter_std(src, index, dim=-1, unbiased=True)
1011
std = src.std(dim=-1, unbiased=True)[0]
1112
expected = torch.tensor([[std, 0], [0, std]])
1213
assert torch.allclose(out, expected)
14+
15+
out.backward(torch.randn_like(out))

torch_scatter/composite/softmax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
1717
max_per_src_element = max_value_per_index.gather(dim, index)
1818

1919
recentered_scores = src - max_per_src_element
20-
recentered_scores_exp = recentered_scores.exp_()
20+
recentered_scores_exp = recentered_scores.exp()
2121

2222
sum_per_index = scatter_sum(recentered_scores_exp, index, dim)
2323
normalizing_constants = sum_per_index.add_(eps).gather(dim, index)
2424

25-
return recentered_scores_exp.div_(normalizing_constants)
25+
return recentered_scores_exp.div(normalizing_constants)
2626

2727

2828
@torch.jit.script

torch_scatter/composite/std.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@ def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
2727
index = broadcast(index, src, dim)
2828
tmp = scatter_sum(src, index, dim, dim_size=dim_size)
2929
count = broadcast(count, tmp, dim).clamp_(1)
30-
mean = tmp.div_(count)
30+
mean = tmp.div(count)
3131

3232
var = (src - mean.gather(dim, index))
3333
var = var * var
3434
out = scatter_sum(var, index, dim, out, dim_size)
3535

3636
if unbiased:
37-
count.sub_(1).clamp_(1)
38-
out.div_(count).sqrt_()
37+
count = count.sub(1).clamp_(1)
38+
out = out.div(count).sqrt()
3939

4040
return out

0 commit comments

Comments
 (0)