Skip to content

Commit 8099c53

Browse files
committed
added new tests
1 parent f0fdfe2 commit 8099c53

File tree

3 files changed

+36
-34
lines changed

3 files changed

+36
-34
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from setuptools import setup, find_packages
44

5-
__version__ = '0.2.3'
5+
__version__ = '0.3.0'
66
url = 'https://github.com/rusty1s/pytorch_scatter'
77

88
install_requires = ['cffi']

test/backward.json

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,56 @@
11
[
22
{
33
"name": "add",
4-
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
5-
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
6-
"dim": 1,
4+
"index": [2, 0, 1, 1, 0],
5+
"input": [1, 2, 3, 4, 5],
6+
"dim": 0,
77
"fill_value": 0,
8-
"grad": [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]],
9-
"expected": [[50, 60, 50, 30, 40], [15, 15, 35, 35, 25]]
8+
"grad": [4, 8, 6],
9+
"expected": [6, 4, 8, 8, 4]
1010
},
1111
{
12-
"name": "add",
13-
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
14-
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
12+
"name": "sub",
13+
"index": [2, 0, 1, 1, 0],
14+
"input": [1, 2, 3, 4, 5],
1515
"dim": 0,
1616
"fill_value": 0,
17-
"grad": [[10, 20], [15, 25]],
18-
"expected": [[10, 20], [15, 25], [15, 25], [10, 20]]
17+
"grad": [4, 8, 6],
18+
"expected": [-6, -4, -8, -8, -4]
1919
},
2020
{
2121
"name": "mean",
22-
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
23-
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
24-
"dim": 1,
22+
"index": [2, 0, 1, 1, 0],
23+
"input": [1, 2, 3, 4, 5],
24+
"dim": 0,
2525
"fill_value": 0,
26-
"grad": [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]],
27-
"expected": [[50, 60, 50, 30, 40], [15, 15, 35, 35, 25]]
26+
"grad": [4, 8, 6],
27+
"expected": [6, 2, 4, 4, 2]
2828
},
2929
{
30-
"name": "mean",
31-
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
32-
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
30+
"name": "max",
31+
"index": [2, 0, 1, 1, 0],
32+
"input": [1, 2, 3, 4, 5],
3333
"dim": 0,
3434
"fill_value": 0,
35-
"grad": [[10, 20], [15, 25]],
36-
"expected": [[10, 20], [15, 25], [15, 25], [10, 20]]
35+
"grad": [4, 8, 6],
36+
"expected": [6, 0, 0, 8, 4]
3737
},
3838
{
39-
"name": "max",
40-
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
41-
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
42-
"dim": 1,
43-
"fill_value": 0,
44-
"grad": [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]],
45-
"expected": [[50, 60, 0, 30, 40], [0, 15, 0, 35, 25]]
39+
"name": "min",
40+
"index": [2, 0, 1, 1, 0],
41+
"input": [1, 2, 3, 4, 5],
42+
"dim": 0,
43+
"fill_value": 3,
44+
"grad": [4, 8, 6],
45+
"expected": [6, 4, 8, 0, 0]
4646
},
4747
{
48-
"name": "max",
49-
"index": [[0, 0], [1, 1], [1, 1], [0, 0]],
50-
"input": [[5, 2], [2, 5], [4, 3], [1, 3]],
48+
"name": "mul",
49+
"index": [2, 0, 1, 1, 0],
50+
"input": [1, 2, 3, 4, 5],
5151
"dim": 0,
52-
"fill_value": 0,
53-
"grad": [[10, 20], [15, 25]],
54-
"expected": [[10, 0], [0, 25], [15, 0], [0, 20]]
52+
"fill_value": 2,
53+
"grad": [4, 8, 6],
54+
"expected": [12, 40, 64, 48, 16]
5555
}
5656
]

test/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
tensors = [t[:-4] for t in tensor_classes]
55
tensors.remove('ShortTensor') # TODO: PyTorch `atomicAdd` bug with short type.
6+
tensors.remove('ByteTensor') # We cannot properly test unsigned values.
7+
tensors.remove('CharTensor') # Overflow on gradient computations :(
68

79

810
def Tensor(str, x):

0 commit comments

Comments
 (0)