Skip to content

Commit b8a3c55

Browse files
committed
pytorch 1.3 support
1 parent 08dda1a commit b8a3c55

File tree

4 files changed

+26
-14
lines changed

4 files changed

+26
-14
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ before_install:
1717
- export CXX="g++-4.9"
1818
install:
1919
- pip install numpy
20-
- pip install -q torch -f https://download.pytorch.org/whl/nightly/cpu/torch.html
20+
- pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
2121
- pip install pycodestyle
2222
- pip install flake8
2323
- pip install codecov

cpu/compat.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#ifdef VERSION_GE_1_3
2+
#define DATA_PTR data_ptr
3+
#else
4+
#define DATA_PTR data
5+
#endif

cpu/dim_apply.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,25 @@
22

33
#include <torch/extension.h>
44

5+
#include "compat.h"
6+
57
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
68
[&] { \
7-
TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \
9+
TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \
810
auto TENSOR1##_size = TENSOR1.size(DIM); \
911
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
1012
\
11-
TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \
13+
TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \
1214
auto TENSOR2##_size = TENSOR2.size(DIM); \
1315
auto TENSOR2##_stride = TENSOR2.stride(DIM); \
1416
\
15-
TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \
17+
TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \
1618
auto TENSOR3##_size = TENSOR3.size(DIM); \
1719
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
1820
\
1921
auto dims = TENSOR1.dim(); \
2022
auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
21-
auto counter = zeros.data<int64_t>(); \
23+
auto counter = zeros.DATA_PTR<int64_t>(); \
2224
bool has_finished = false; \
2325
\
2426
while (!has_finished) { \
@@ -59,25 +61,25 @@
5961
#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \
6062
TENSOR4, DIM, CODE) \
6163
[&] { \
62-
TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \
64+
TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \
6365
auto TENSOR1##_size = TENSOR1.size(DIM); \
6466
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
6567
\
66-
TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \
68+
TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \
6769
auto TENSOR2##_size = TENSOR2.size(DIM); \
6870
auto TENSOR2##_stride = TENSOR2.stride(DIM); \
6971
\
70-
TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \
72+
TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \
7173
auto TENSOR3##_size = TENSOR3.size(DIM); \
7274
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
7375
\
74-
TYPE4 *TENSOR4##_data = TENSOR4.data<TYPE4>(); \
76+
TYPE4 *TENSOR4##_data = TENSOR4.DATA_PTR<TYPE4>(); \
7577
auto TENSOR4##_size = TENSOR4.size(DIM); \
7678
auto TENSOR4##_stride = TENSOR4.stride(DIM); \
7779
\
7880
auto dims = TENSOR1.dim(); \
7981
auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
80-
auto counter = zeros.data<int64_t>(); \
82+
auto counter = zeros.DATA_PTR<int64_t>(); \
8183
bool has_finished = false; \
8284
\
8385
while (!has_finished) { \

setup.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,19 @@
33
import torch
44
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
55

6+
TORCH_MAJOR = int(torch.__version__.split('.')[0])
7+
TORCH_MINOR = int(torch.__version__.split('.')[1])
8+
69
extra_compile_args = []
710
if platform.system() != 'Windows':
811
extra_compile_args += ['-Wno-unused-variable']
912

13+
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
14+
extra_compile_args += ['-DVERSION_GE_1_3']
15+
1016
ext_modules = [
11-
CppExtension(
12-
'torch_scatter.scatter_cpu', ['cpu/scatter.cpp'],
13-
extra_compile_args=extra_compile_args)
17+
CppExtension('torch_scatter.scatter_cpu', ['cpu/scatter.cpp'],
18+
extra_compile_args=extra_compile_args)
1419
]
1520
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
1621

@@ -20,7 +25,7 @@
2025
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
2126
]
2227

23-
__version__ = '1.3.1'
28+
__version__ = '1.3.2'
2429
url = 'https://github.com/rusty1s/pytorch_scatter'
2530

2631
install_requires = []

0 commit comments

Comments
 (0)