|
2 | 2 |
|
3 | 3 | #include <torch/extension.h>
|
4 | 4 |
|
| 5 | +#include "compat.h" |
| 6 | + |
5 | 7 | #define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
|
6 | 8 | [&] { \
|
7 |
| - TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \ |
| 9 | + TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \ |
8 | 10 | auto TENSOR1##_size = TENSOR1.size(DIM); \
|
9 | 11 | auto TENSOR1##_stride = TENSOR1.stride(DIM); \
|
10 | 12 | \
|
11 |
| - TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \ |
| 13 | + TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \ |
12 | 14 | auto TENSOR2##_size = TENSOR2.size(DIM); \
|
13 | 15 | auto TENSOR2##_stride = TENSOR2.stride(DIM); \
|
14 | 16 | \
|
15 |
| - TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \ |
| 17 | + TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \ |
16 | 18 | auto TENSOR3##_size = TENSOR3.size(DIM); \
|
17 | 19 | auto TENSOR3##_stride = TENSOR3.stride(DIM); \
|
18 | 20 | \
|
19 | 21 | auto dims = TENSOR1.dim(); \
|
20 | 22 | auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
|
21 |
| - auto counter = zeros.data<int64_t>(); \ |
| 23 | + auto counter = zeros.DATA_PTR<int64_t>(); \ |
22 | 24 | bool has_finished = false; \
|
23 | 25 | \
|
24 | 26 | while (!has_finished) { \
|
|
59 | 61 | #define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \
|
60 | 62 | TENSOR4, DIM, CODE) \
|
61 | 63 | [&] { \
|
62 |
| - TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \ |
| 64 | + TYPE1 *TENSOR1##_data = TENSOR1.DATA_PTR<TYPE1>(); \ |
63 | 65 | auto TENSOR1##_size = TENSOR1.size(DIM); \
|
64 | 66 | auto TENSOR1##_stride = TENSOR1.stride(DIM); \
|
65 | 67 | \
|
66 |
| - TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \ |
| 68 | + TYPE2 *TENSOR2##_data = TENSOR2.DATA_PTR<TYPE2>(); \ |
67 | 69 | auto TENSOR2##_size = TENSOR2.size(DIM); \
|
68 | 70 | auto TENSOR2##_stride = TENSOR2.stride(DIM); \
|
69 | 71 | \
|
70 |
| - TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \ |
| 72 | + TYPE3 *TENSOR3##_data = TENSOR3.DATA_PTR<TYPE3>(); \ |
71 | 73 | auto TENSOR3##_size = TENSOR3.size(DIM); \
|
72 | 74 | auto TENSOR3##_stride = TENSOR3.stride(DIM); \
|
73 | 75 | \
|
74 |
| - TYPE4 *TENSOR4##_data = TENSOR4.data<TYPE4>(); \ |
| 76 | + TYPE4 *TENSOR4##_data = TENSOR4.DATA_PTR<TYPE4>(); \ |
75 | 77 | auto TENSOR4##_size = TENSOR4.size(DIM); \
|
76 | 78 | auto TENSOR4##_stride = TENSOR4.stride(DIM); \
|
77 | 79 | \
|
78 | 80 | auto dims = TENSOR1.dim(); \
|
79 | 81 | auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
|
80 |
| - auto counter = zeros.data<int64_t>(); \ |
| 82 | + auto counter = zeros.DATA_PTR<int64_t>(); \ |
81 | 83 | bool has_finished = false; \
|
82 | 84 | \
|
83 | 85 | while (!has_finished) { \
|
|
0 commit comments