-from itertools import chain
-
 import torch
 from torch.autograd import Function

-from .._ext import ffi
-
-
-def has_arg(name):
-    return name in ['max', 'min']
-
-
-def _scatter(name, dim, *data):
-    a, b, c = data[:3]
-
-    # Assert index dimension is valid.
-    assert dim >= 0 and dim < a.dim(), 'Index dimension is out of bounds'
-
-    # Assert same dimensionality across all inputs.
-    assert b.dim() == c.dim(), ('Index tensor must have same dimensions as '
-                                'input tensor')
-    assert a.dim() == c.dim(), ('Input tensor must have same dimensions as '
-                                'output tensor')
-
-    # Assert same tensor length across index and input.
-    assert b.numel() == c.numel(), ('Index tensor must have same size as '
-                                    'input tensor')
+from .ffi import scatter as ffi_scatter

-    # Assert same tensor sizes across input and output apart from `dim`.
-    for d in chain(range(dim), range(dim + 1, a.dim())):
-        assert a.size(d) == c.size(d), (
-            'Input tensor must have same size as output tensor apart from the '
-            'specified dimension')

-    typename = type(data[0]).__name__.replace('Tensor', '')
-    cuda = 'cuda_' if data[0].is_cuda else ''
-    func = getattr(ffi, 'scatter_{}_{}{}'.format(name, cuda, typename))
-    func(dim, *data)
-    return (data[0], data[3]) if has_arg(name) else data[0]
-
-
-def index_backward(dim, index, grad, arg):  # pragma: no cover
-    typename = type(grad).__name__.replace('Tensor', '')
-    cuda = 'cuda_' if grad.is_cuda else ''
-    func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
-    output = grad.new(index.size()).fill_(0)
-    func(dim, output, index, grad, arg)
-    return output
-
-
-class _Scatter(Function):
+class Scatter(Function):
     def __init__(self, name, dim):
-        super(_Scatter, self).__init__()
+        super(Scatter, self).__init__()
         self.name = name
         self.dim = dim

+    def save_for_backward_step(self, *data):
+        raise NotImplementedError
+
     def forward(self, *data):
         assert not self.needs_input_grad[1], 'Can\'t differentiate the index'

         self.mark_dirty(data[0])  # Mark output as dirty.
         self.len = len(data)  # Save number of arguments for backward step.

-        _scatter(self.name, self.dim, *data)
+        output = ffi_scatter(self.name, self.dim, *data)
+        self.save_for_backward_step(*data)

-        # `scatter_min` and `scatter_max` additionally return the `argmax`
-        # respectively `argmin`. Therefore, we need to save the `arg` for the
-        # backward pass.
-        if has_arg(self.name):
-            self.save_for_backward(data[1], data[3])
-            return data[0], data[3]
-        else:
-            self.save_for_backward(data[1])
-            return data[0]
+        return output

     def backward(self, *data):  # pragma: no cover
         grad_output = grad_input = None

         if self.needs_input_grad[0]:
             grad_output = data[0]

-        # Different grad computation of `input` if `scatter_max` or
-        # `scatter_min` was used.
-        if self.needs_input_grad[2] and not has_arg(self.name):
-            index, = self.saved_variables
-            grad_input = data[0].gather(self.dim, index.data)
-
-        if self.needs_input_grad[2] and has_arg(self.name):
-            index, arg = self.saved_variables
-            data = (index.data, data[0], arg.data)
-            grad_input = index_backward(self.dim, *data)
+        # Call grad computation of `input` for the specific scatter operation.
+        if self.needs_input_grad[2]:
+            grad_input = self.backward_step(data[0], *self.saved_variables)

-        # Return and fill with empty grads for non-differentiable passed
-        # arguments in forward pass.
+        # Return and fill with empty grads for non-differentiable arguments.
         return (grad_output, None, grad_input) + (None, ) * (self.len - 3)

+    def backward_step(self, *data):
+        raise NotImplementedError
+

-def scatter(name, dim, *data):
+def scatter(Clx, name, dim, *data):
     if torch.is_tensor(data[0]):
-        return _scatter(name, dim, *data)
+        return ffi_scatter(name, dim, *data)
     else:
-        return _Scatter(name, dim)(*data)
+        return Clx(dim)(*data)
+
+
+# def index_backward(dim, index, grad, arg):  # pragma: no cover
+#     typename = type(grad).__name__.replace('Tensor', '')
+#     cuda = 'cuda_' if grad.is_cuda else ''
+#     func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
+#     output = grad.new(index.size()).fill_(0)
+#     func(dim, output, index, grad, arg)
+#     return output
+
+# def _scatter_backward(name, dim, saved, *data):
+#     # saved = (index, ), (index, arg) or (index, count)
+
+#     print(name)
+#     print(len(data))
+#     print(len(saved))
+#     print(saved[1].size())
+#     # data = (grad, )
+#     # index, = saved
+#     if has_arg(name):
+#         return index_backward(dim, saved[0].data, data[0], saved[1].data)
+
+#     if has_count(name):
+#         return (data[0] / saved[1]).gather(dim, saved[0].data)
+#     # Different grad computation of `input` if `scatter_max` or
+#     # `scatter_min` was used.
+#     # if self.needs_input_grad[2] and not has_arg(self.name):
+#     #     index, = self.saved_variables
+#     #     grad_input = data[0].gather(self.dim, index.data)
+
+#     # if self.needs_input_grad[2] and has_arg(self.name):
+#     #     index, arg = self.saved_variables
+#     #     data = (index.data, data[0], arg.data)
+#     grad_input = index_backward(self.dim, *data)
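For context on how the refactored base class is meant to be extended: the commit itself ships no concrete subclass, so the sketch below is only an illustration under assumptions. `ScatterAdd` and the `scatter_add_` wrapper are hypothetical names that are not part of this diff; the subclass is assumed to live in the same module as `Scatter`, and its gradient rule simply mirrors the non-`arg` branch deleted from the old `_Scatter.backward`.

# Hypothetical subclass (illustrative only, not part of this commit),
# assumed to sit next to `Scatter` and `scatter` in this module.
class ScatterAdd(Scatter):
    def __init__(self, dim):
        # The dispatcher instantiates subclasses as `Clx(dim)`, so the
        # operation name is fixed here and passed up to the base class.
        super(ScatterAdd, self).__init__('add', dim)

    def save_for_backward_step(self, *data):
        # `data` is (output, index, input, ...); only the index is needed
        # to route gradients back to the scattered inputs.
        self.save_for_backward(data[1])

    def backward_step(self, grad, index):
        # Same rule as the non-`arg` branch removed from `_Scatter.backward`:
        # each input element receives the gradient of the output slot it
        # was scattered into.
        return grad.gather(self.dim, index.data)


def scatter_add_(output, index, input, dim=0):
    # Thin wrapper around the generic dispatcher: plain tensors go straight
    # to the FFI kernel, Variables are routed through the autograd Function.
    return scatter(ScatterAdd, 'add', dim, output, index, input)

Keeping `save_for_backward_step` and `backward_step` abstract lets each operation decide what to stash for its backward pass (index only, index plus `arg` for min/max, or index plus count), which replaces the `has_arg` special-casing the old `_Scatter` needed.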