
Commit f0fdfe2

correct gradient computations
1 parent: 43432f4


8 files changed (+180, -84 lines)


torch_scatter/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from .functions.max import scatter_max_, scatter_max
 from .functions.min import scatter_min_, scatter_min
 
-__version__ = '0.2.3'
+__version__ = '0.3.0'
 
 __all__ = [
     'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',

torch_scatter/functions/div.py

Lines changed: 15 additions & 2 deletions
@@ -1,7 +1,20 @@
-from .scatter import scatter
+from .scatter import Scatter, scatter
 from .utils import gen_output
 
 
+class ScatterDiv(Scatter):
+    def __init__(self, dim):
+        super(ScatterDiv, self).__init__('div', dim)
+
+    def save_for_backward_step(self, *data):
+        output, index, input = data
+        self.save_for_backward(output, index, input)
+
+    def backward_step(self, *data):
+        grad, output, index, input = data
+        return (grad / output.data).gather(self.dim, index.data) * input.data
+
+
 def scatter_div_(output, index, input, dim=0):
     r"""
     |
@@ -53,7 +66,7 @@ def scatter_div_(output, index, input, dim=0):
     0.5000  0.2500  0.1667  1.0000  1.0000  1.0000
     [torch.FloatTensor of size 2x6]
     """
-    return scatter('div', dim, output, index, input)
+    return scatter(ScatterDiv, 'div', dim, output, index, input)
 
 
 def scatter_div(index, input, dim=0, size=None, fill_value=1):
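
Note: the small examples added below use the modern torch.tensor constructor purely for illustration, while the diff itself targets the old Tensor/Variable API. As a reference for what the in-place kernel behind ScatterDiv does, here is a plain 1-D, dim=0 sketch; scatter_div_reference is a hypothetical helper, not part of the package:

    import torch

    def scatter_div_reference(output, index, input):
        # Divide output[index[i]] by input[i], mirroring the in-place kernel.
        for i in range(input.numel()):
            output[index[i]] = output[index[i]] / input[i]
        return output

    out = torch.ones(4)
    idx = torch.tensor([0, 0, 1, 2])
    src = torch.tensor([2.0, 4.0, 5.0, 10.0])
    print(scatter_div_reference(out, idx, src))
    # tensor([0.1250, 0.2000, 0.1000, 1.0000])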

torch_scatter/functions/ffi.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+from itertools import chain
+
+from .._ext import ffi
+
+
+def scatter(name, dim, *data):
+    # data = output, index, input, additional data
+    a, b, c = data[:3]
+
+    # Assert index dimension is valid.
+    assert dim >= 0 and dim < b.dim(), 'Index dimension is out of bounds'
+
+    # Assert same dimensionality across all inputs.
+    assert b.dim() == c.dim(), ('Index tensor must have same dimensions as '
+                                'input tensor')
+    assert a.dim() == c.dim(), ('Input tensor must have same dimensions as '
+                                'output tensor')
+
+    # Assert same tensor length across index and input.
+    assert b.numel() == c.numel(), ('Index tensor must have same size as '
+                                    'input tensor')
+
+    # Assert same tensor sizes across input and output apart from `dim`.
+    for d in chain(range(dim), range(dim + 1, a.dim())):
+        assert a.size(d) == c.size(d), (
+            'Input tensor must have same size as output tensor apart from the '
+            'specified dimension')
+
+    typename = type(data[0]).__name__.replace('Tensor', '')
+    cuda = 'cuda_' if data[0].is_cuda else ''
+    func = getattr(ffi, 'scatter_{}_{}{}'.format(name, cuda, typename))
+    func(dim, *data)
+
+    if len(data) <= 3:
+        return data[0]
+
+    return (data[0], ) + tuple(data[3:])
+
+
+def index_backward(dim, index, grad, arg):  # pragma: no cover
+    typename = type(grad).__name__.replace('Tensor', '')
+    cuda = 'cuda_' if grad.is_cuda else ''
+    func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
+    output = grad.new(index.size()).fill_(0)
+    func(dim, output, index, grad, arg)
+    return output
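
The dispatcher above resolves the compiled kernel purely by string concatenation over the operation name, the device, and the tensor type. The snippet below only mirrors that naming logic; the real ffi module is generated at build time, and FloatTensor here is a stand-in class since the tensor type name differs across PyTorch versions:

    class FloatTensor(object):  # stand-in for torch.FloatTensor (PyTorch 0.3 era)
        is_cuda = False

    def kernel_name(op, tensor):
        # Same string logic as in scatter() above.
        typename = type(tensor).__name__.replace('Tensor', '')
        cuda = 'cuda_' if tensor.is_cuda else ''
        return 'scatter_{}_{}{}'.format(op, cuda, typename)

    print(kernel_name('max', FloatTensor()))  # -> 'scatter_max_Float'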

torch_scatter/functions/max.py

Lines changed: 16 additions & 2 deletions
@@ -1,7 +1,21 @@
-from .scatter import scatter
+from .scatter import Scatter, scatter
+from .ffi import index_backward
 from .utils import gen_filled_tensor, gen_output
 
 
+class ScatterMax(Scatter):
+    def __init__(self, dim):
+        super(ScatterMax, self).__init__('max', dim)
+
+    def save_for_backward_step(self, *data):
+        output, index, input, arg = data
+        self.save_for_backward(index, arg)
+
+    def backward_step(self, *data):
+        grad, index, arg = data
+        return index_backward(self.dim, index.data, grad, arg.data)
+
+
 def scatter_max_(output, index, input, dim=0):
     r"""
     |
@@ -61,7 +75,7 @@ def scatter_max_(output, index, input, dim=0):
     )
     """
     arg = gen_filled_tensor(index, output.size(), fill_value=-1)
-    return scatter('max', dim, output, index, input, arg)
+    return scatter(ScatterMax, 'max', dim, output, index, input, arg)
 
 
 def scatter_max(index, input, dim=0, size=None, fill_value=0):
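
A rough, loop-based sketch of what the forward kernel and the arg buffer provide (1-D, dim=0, illustrative only): each output slot keeps the maximum of the values routed to it, and arg records which input position supplied that maximum (-1 where nothing was scattered), so the backward pass can route the gradient only to those recorded positions via index_backward:

    import torch

    idx = torch.tensor([0, 0, 1, 2])
    src = torch.tensor([3.0, 5.0, 2.0, 7.0])
    out = torch.zeros(4)
    arg = torch.full((4,), -1, dtype=torch.long)

    for i in range(src.numel()):
        if src[i] > out[idx[i]]:   # keep the running maximum per output slot
            out[idx[i]] = src[i]
            arg[idx[i]] = i        # remember which input produced it

    print(out)  # tensor([5., 2., 7., 0.])
    print(arg)  # tensor([ 1,  2,  3, -1])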

torch_scatter/functions/mean.py

Lines changed: 18 additions & 3 deletions
@@ -1,9 +1,22 @@
 from __future__ import division
 
-from .scatter import scatter
+from .scatter import Scatter, scatter
 from .utils import gen_filled_tensor, gen_output
 
 
+class ScatterMean(Scatter):
+    def __init__(self, dim):
+        super(ScatterMean, self).__init__('mean', dim)
+
+    def save_for_backward_step(self, *data):
+        output, index, input, count = data
+        self.save_for_backward(index)
+
+    def backward_step(self, *data):
+        grad, index = data
+        return grad.gather(self.dim, index.data)
+
+
 def scatter_mean_(output, index, input, dim=0):
     r"""
     |
@@ -56,10 +69,12 @@ def scatter_mean_(output, index, input, dim=0):
     1.0000  4.0000  2.0000  0.0000  0.0000  0.0000
     [torch.FloatTensor of size 2x6]
     """
+    init = gen_filled_tensor(output, output.size(), fill_value=0)
     count = gen_filled_tensor(output, output.size(), fill_value=0)
-    scatter('mean', dim, output, index, input, count)
+    scatter(ScatterMean, 'mean', dim, init, index, input, count)
     count[count == 0] = 1
-    output /= count
+    init /= count
+    output += init
     return output
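
The updated forward path accumulates into a fresh zero buffer instead of dividing the pre-filled output, which keeps any initial values in output intact. A minimal 1-D sketch of that behaviour, with plain loops standing in for the compiled kernel:

    import torch

    idx = torch.tensor([0, 0, 1])
    src = torch.tensor([2.0, 4.0, 6.0])
    out = torch.zeros(3)

    init = torch.zeros_like(out)   # sums accumulate here, not in `out`
    count = torch.zeros_like(out)  # number of values hitting each slot
    for i in range(src.numel()):
        init[idx[i]] += src[i]
        count[idx[i]] += 1

    count[count == 0] = 1          # avoid division by zero for empty slots
    out += init / count
    print(out)  # tensor([3., 6., 0.])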

torch_scatter/functions/min.py

Lines changed: 16 additions & 2 deletions
@@ -1,7 +1,21 @@
-from .scatter import scatter
+from .scatter import Scatter, scatter
+from .ffi import index_backward
 from .utils import gen_filled_tensor, gen_output
 
 
+class ScatterMin(Scatter):
+    def __init__(self, dim):
+        super(ScatterMin, self).__init__('min', dim)
+
+    def save_for_backward_step(self, *data):
+        output, index, input, arg = data
+        self.save_for_backward(index, arg)
+
+    def backward_step(self, *data):
+        grad, index, arg = data
+        return index_backward(self.dim, index.data, grad, arg.data)
+
+
 def scatter_min_(output, index, input, dim=0):
     r"""
     |
@@ -61,7 +75,7 @@ def scatter_min_(output, index, input, dim=0):
     )
     """
     arg = gen_filled_tensor(index, output.size(), fill_value=-1)
-    return scatter('min', dim, output, index, input, arg)
+    return scatter(ScatterMin, 'min', dim, output, index, input, arg)
 
 
 def scatter_min(index, input, dim=0, size=None, fill_value=0):
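
ScatterMin shares its backward with ScatterMax: the incoming gradient is routed back only to the input positions recorded in arg. A rough, loop-based guess at what the compiled index_backward kernel computes for the 1-D, dim=0 case (index_backward_reference is purely illustrative):

    import torch

    def index_backward_reference(index, grad, arg):
        grad_input = grad.new_zeros(index.size())  # gradient w.r.t. the input
        for j in range(grad.numel()):
            if arg[j] >= 0:                   # slot j received at least one value
                grad_input[arg[j]] = grad[j]  # send its gradient to the arg position
        return grad_input

    grad = torch.tensor([1.0, 2.0, 3.0])      # gradient w.r.t. the output
    index = torch.tensor([0, 0, 1, 2])        # same size as the original input
    arg = torch.tensor([1, 2, 3])             # argmin positions per output slot
    print(index_backward_reference(index, grad, arg))
    # tensor([0., 1., 2., 3.])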

torch_scatter/functions/mul.py

Lines changed: 15 additions & 2 deletions
@@ -1,7 +1,20 @@
-from .scatter import scatter
+from .scatter import Scatter, scatter
 from .utils import gen_output
 
 
+class ScatterMul(Scatter):
+    def __init__(self, dim):
+        super(ScatterMul, self).__init__('mul', dim)
+
+    def save_for_backward_step(self, *data):
+        output, index, input = data
+        self.save_for_backward(output, index, input)
+
+    def backward_step(self, *data):
+        grad, output, index, input = data
+        return (grad * output.data).gather(self.dim, index.data) / input.data
+
+
 def scatter_mul_(output, index, input, dim=0):
     r"""
     |
@@ -52,7 +65,7 @@ def scatter_mul_(output, index, input, dim=0):
     6  4  8  1  1  1
     [torch.FloatTensor of size 2x6]
     """
-    return scatter('mul', dim, output, index, input)
+    return scatter(ScatterMul, 'mul', dim, output, index, input)
 
 
 def scatter_mul(index, input, dim=0, size=None, fill_value=1):
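
ScatterMul's backward_step relies on the product-rule identity that for y = c * x1 * x2 the derivative with respect to x1 is y / x1 (for non-zero inputs); that is exactly what (grad * output).gather(dim, index) / input evaluates per input position. A quick autograd check of the identity:

    import torch

    x = torch.tensor([2.0, 3.0], requires_grad=True)
    y = 4.0 * x[0] * x[1]          # y = c * x1 * x2 with c = 4
    y.backward()

    print(x.grad)                  # tensor([12.,  8.])
    print(y.item() / x.detach())   # tensor([12.,  8.]) -- equals y / x_i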

torch_scatter/functions/scatter.py

Lines changed: 53 additions & 72 deletions
@@ -1,101 +1,82 @@
-from itertools import chain
-
 import torch
 from torch.autograd import Function
 
-from .._ext import ffi
-
-
-def has_arg(name):
-    return name in ['max', 'min']
-
-
-def _scatter(name, dim, *data):
-    a, b, c = data[:3]
-
-    # Assert index dimension is valid.
-    assert dim >= 0 and dim < a.dim(), 'Index dimension is out of bounds'
-
-    # Assert same dimensionality across all inputs.
-    assert b.dim() == c.dim(), ('Index tensor must have same dimensions as '
-                                'input tensor')
-    assert a.dim() == c.dim(), ('Input tensor must have same dimensions as '
-                                'output tensor')
-
-    # Assert same tensor length across index and input.
-    assert b.numel() == c.numel(), ('Index tensor must have same size as '
-                                    'input tensor')
+from .ffi import scatter as ffi_scatter
 
-    # Assert same tensor sizes across input and output apart from `dim`.
-    for d in chain(range(dim), range(dim + 1, a.dim())):
-        assert a.size(d) == c.size(d), (
-            'Input tensor must have same size as output tensor apart from the '
-            'specified dimension')
 
-    typename = type(data[0]).__name__.replace('Tensor', '')
-    cuda = 'cuda_' if data[0].is_cuda else ''
-    func = getattr(ffi, 'scatter_{}_{}{}'.format(name, cuda, typename))
-    func(dim, *data)
-    return (data[0], data[3]) if has_arg(name) else data[0]
-
-
-def index_backward(dim, index, grad, arg):  # pragma: no cover
-    typename = type(grad).__name__.replace('Tensor', '')
-    cuda = 'cuda_' if grad.is_cuda else ''
-    func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
-    output = grad.new(index.size()).fill_(0)
-    func(dim, output, index, grad, arg)
-    return output
-
-
-class _Scatter(Function):
+class Scatter(Function):
     def __init__(self, name, dim):
-        super(_Scatter, self).__init__()
+        super(Scatter, self).__init__()
         self.name = name
         self.dim = dim
 
+    def save_for_backward_step(self, *data):
+        raise NotImplementedError
+
     def forward(self, *data):
         assert not self.needs_input_grad[1], 'Can\'t differentiate the index'
 
         self.mark_dirty(data[0])  # Mark output as dirty.
         self.len = len(data)  # Save number of arguments for backward step.
 
-        _scatter(self.name, self.dim, *data)
+        output = ffi_scatter(self.name, self.dim, *data)
+        self.save_for_backward_step(*data)
 
-        # `scatter_min` and `scatter_max` additionally return the `argmax`
-        # respectively `argmin`. Therefore, we need to save the `arg` for the
-        # backward pass.
-        if has_arg(self.name):
-            self.save_for_backward(data[1], data[3])
-            return data[0], data[3]
-        else:
-            self.save_for_backward(data[1])
-            return data[0]
+        return output
 
     def backward(self, *data):  # pragma: no cover
         grad_output = grad_input = None
 
         if self.needs_input_grad[0]:
            grad_output = data[0]
 
-        # Different grad computation of `input` if `scatter_max` or
-        # `scatter_min` was used.
-        if self.needs_input_grad[2] and not has_arg(self.name):
-            index, = self.saved_variables
-            grad_input = data[0].gather(self.dim, index.data)
-
-        if self.needs_input_grad[2] and has_arg(self.name):
-            index, arg = self.saved_variables
-            data = (index.data, data[0], arg.data)
-            grad_input = index_backward(self.dim, *data)
+        # Call grad computation of `input` for the specific scatter operation.
+        if self.needs_input_grad[2]:
+            grad_input = self.backward_step(data[0], *self.saved_variables)
 
-        # Return and fill with empty grads for non-differentiable passed
-        # arguments in forward pass.
+        # Return and fill with empty grads for non-differentiable arguments.
         return (grad_output, None, grad_input) + (None, ) * (self.len - 3)
 
+    def backward_step(self, *data):
+        raise NotImplementedError
+
 
-def scatter(name, dim, *data):
+def scatter(Clx, name, dim, *data):
     if torch.is_tensor(data[0]):
-        return _scatter(name, dim, *data)
+        return ffi_scatter(name, dim, *data)
     else:
-        return _Scatter(name, dim)(*data)
+        return Clx(dim)(*data)
+
+
+# def index_backward(dim, index, grad, arg):  # pragma: no cover
+#     typename = type(grad).__name__.replace('Tensor', '')
+#     cuda = 'cuda_' if grad.is_cuda else ''
+#     func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
+#     output = grad.new(index.size()).fill_(0)
+#     func(dim, output, index, grad, arg)
+#     return output
+
+# def _scatter_backward(name, dim, saved, *data):
+#     # saved = (index, ), (index, arg) or (index, count)
+
+#     print(name)
+#     print(len(data))
+#     print(len(saved))
+#     print(saved[1].size())
+#     # data = (grad, )
+#     # index, = seved
+#     if has_arg(name):
+#         return index_backward(dim, saved[0].data, data[0], saved[1].data)
+
+#     if has_count(name):
+#         return (data[0] / saved[1]).gather(dim, saved[0].data)
+#     # Different grad computation of `input` if `scatter_max` or
+#     # `scatter_min` was used.
+#     # if self.needs_input_grad[2] and not has_arg(self.name):
+#     #     index, = self.saved_variables
+#     #     grad_input = data[0].gather(self.dim, index.data)
+
+#     # if self.needs_input_grad[2] and has_arg(self.name):
+#     #     index, arg = self.saved_variables
+#     #     data = (index.data, data[0], arg.data)
+#     grad_input = index_backward(self.dim, *data)
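
With this refactor every reduction plugs into the Scatter base class through the two hooks above. As a purely hypothetical illustration (not part of this commit; the repository's real add implementation may differ), a subclass only has to state what to save and how to map the output gradient back to input positions:

    # Hypothetical subclass, shown only to illustrate the hook points.
    from .scatter import Scatter

    class ScatterAdd(Scatter):
        def __init__(self, dim):
            super(ScatterAdd, self).__init__('add', dim)

        def save_for_backward_step(self, *data):
            output, index, input = data
            self.save_for_backward(index)  # only the index is needed later

        def backward_step(self, *data):
            grad, index = data
            # For a sum, each input position receives the gradient of the
            # output slot it was scattered into.
            return grad.gather(self.dim, index.data)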
