Skip to content

Commit 0299bab

Browse files
committed
feat: add CUDA kernels (need to be fixed)
1 parent 190060b commit 0299bab

File tree

4 files changed

+214
-0
lines changed

4 files changed

+214
-0
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
2+
3+
// BEGIN PYTHON
4+
// def f(grid, block, args):
5+
// (numnull, mask, length, validwhen, invocation_index, err_code) = args
6+
// scan_in_array = cupy.empty(length, dtype=cupy.int64)
7+
// cuda_kernel_templates.get_function(fetch_specialization(['awkward_ByteMaskedArray_numnull_a', numnull.dtype, mask.dtype]))(grid, block, (numnull, mask, length, validwhen, scan_in_array, invocation_index, err_code))
8+
// scan_in_array = inclusive_scan(grid, block, (scan_in_array, invocation_index, err_code))
9+
// cuda_kernel_templates.get_function(fetch_specialization(['awkward_ByteMaskedArray_numnull_b', numnull.dtype, mask.dtype]))(grid, block, (numnull, mask, length, validwhen, scan_in_array, invocation_index, err_code))
10+
// out["awkward_ByteMaskedArray_numnull_a", {dtype_specializations}] = None
11+
// out["awkward_ByteMaskedArray_numnull_b", {dtype_specializations}] = None
12+
// END PYTHON
13+
14+
// Stage 1 of counting flagged entries in a byte mask: writes a 0/1 flag per
// entry into scan_in_array. Per the launcher comment for this kernel, the
// host then runs an inclusive scan over scan_in_array and launches
// awkward_ByteMaskedArray_numnull_b to read the total.
//
//   numnull           single-element output; reset to 0 here
//   mask              `length` mask entries
//   length            number of entries in mask / scan_in_array
//   validwhen         entry i is flagged when (mask[i] != 0) differs from this
//   scan_in_array     `length` outputs: 1 for flagged entries, 0 otherwise
//   invocation_index  not used by this kernel
//   err_code          nothing is done unless err_code[0] == NO_ERROR
template <typename T, typename C>
__global__ void
awkward_ByteMaskedArray_numnull_a(T* numnull,
                                  const C* mask,
                                  int64_t length,
                                  bool validwhen,
                                  int64_t* scan_in_array,
                                  uint64_t invocation_index,
                                  uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
    // evaluated in 32-bit unsigned arithmetic and wrap for very large grids.
    int64_t thread_id =
        static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    // A single thread is enough to reset the counter; doing it outside the
    // bounds check also covers length == 0.
    if (thread_id == 0) {
      *numnull = 0;
    }

    if (thread_id < length) {
      scan_in_array[thread_id] =
          ((mask[thread_id] != 0) != validwhen) ? 1 : 0;
    }
  }
}
37+
38+
// Stage 2 of counting flagged entries in a byte mask: stores the last element
// of scan_in_array into *numnull. Per the launcher comment for this kernel,
// scan_in_array has been replaced by its inclusive scan before this launch,
// so the last element is the total number of flagged entries.
//
//   numnull           single-element output
//   mask              not used by this kernel
//   length            number of entries in scan_in_array
//   validwhen         not used by this kernel
//   scan_in_array     scanned flags; only element length - 1 is read
//   invocation_index  not used by this kernel
//   err_code          nothing is done unless err_code[0] == NO_ERROR
template <typename T, typename C>
__global__ void
awkward_ByteMaskedArray_numnull_b(T* numnull,
                                  const C* mask,
                                  int64_t length,
                                  bool validwhen,
                                  int64_t* scan_in_array,
                                  uint64_t invocation_index,
                                  uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // An empty input has no last element: report 0 instead of reading
    // scan_in_array[-1].
    *numnull = length > 0 ? scan_in_array[length - 1] : 0;
  }
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
2+
3+
// BEGIN PYTHON
4+
// def f(grid, block, args):
5+
// (toindex, length, invocation_index, err_code) = args
6+
// scan_in_array = cupy.empty(length, dtype=cupy.int64)
7+
// scan_in_array_n_non_null = cupy.empty(length, dtype=cupy.int64)
8+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_Index_nones_as_index_a", toindex.dtype]))(grid, block, (toindex, length, scan_in_array, scan_in_array_n_non_null, invocation_index, err_code))
9+
// scan_in_array = inclusive_scan(grid, block, (scan_in_array, invocation_index, err_code))
10+
// scan_in_array_n_non_null = inclusive_scan(grid, block, (scan_in_array_n_non_null, invocation_index, err_code))
11+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_Index_nones_as_index_b", toindex.dtype]))(grid, block, (toindex, length, scan_in_array, scan_in_array_n_non_null, invocation_index, err_code))
12+
// out["awkward_Index_nones_as_index_a", {dtype_specializations}] = None
13+
// out["awkward_Index_nones_as_index_b", {dtype_specializations}] = None
14+
// END PYTHON
15+
16+
// Stage 1 of replacing -1 entries in an index: writes two complementary 0/1
// flag arrays. Per the launcher comment for this kernel, the host then runs
// an inclusive scan over each array and launches
// awkward_Index_nones_as_index_b.
//
//   toindex                   `length` index entries (read only here)
//   length                    number of entries
//   scan_in_array             output: 1 where toindex[i] != -1, else 0
//   scan_in_array_n_non_null  output: 1 where toindex[i] == -1, else 0
//                             (note: despite its name this flags the -1s)
//   invocation_index          not used by this kernel
//   err_code                  nothing is done unless err_code[0] == NO_ERROR
template <typename T>
__global__ void
awkward_Index_nones_as_index_a(T* toindex,
                               int64_t length,
                               int64_t* scan_in_array,
                               int64_t* scan_in_array_n_non_null,
                               uint64_t invocation_index,
                               uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
    // evaluated in 32-bit unsigned arithmetic and wrap for very large grids.
    int64_t thread_id =
        static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    if (thread_id < length) {
      const bool is_none = (toindex[thread_id] == -1);
      scan_in_array[thread_id] = is_none ? 0 : 1;
      scan_in_array_n_non_null[thread_id] = is_none ? 1 : 0;
    }
  }
}
38+
39+
// Stage 2 of replacing -1 entries in an index: every entry equal to -1 is
// overwritten with
//     scan_in_array[length - 1] + scan_in_array_n_non_null[i] - 1
// and all other entries are left untouched. Per the launcher comment for
// this kernel, both scan arrays have been replaced by their inclusive scans
// before this launch.
//
//   toindex                   `length` index entries, updated in place
//   length                    number of entries
//   scan_in_array             scanned flags; only element length - 1 is read
//   scan_in_array_n_non_null  scanned flags, read at the thread's own position
//   invocation_index          not used by this kernel
//   err_code                  nothing is done unless err_code[0] == NO_ERROR
template <typename T>
__global__ void
awkward_Index_nones_as_index_b(T* toindex,
                               int64_t length,
                               int64_t* scan_in_array,
                               int64_t* scan_in_array_n_non_null,
                               uint64_t invocation_index,
                               uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
    // evaluated in 32-bit unsigned arithmetic and wrap for very large grids.
    int64_t thread_id =
        static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    if (thread_id < length && toindex[thread_id] == -1) {
      // Read the total only inside the bounds check, where length >= 1 is
      // guaranteed, so an empty input never touches scan_in_array[-1].
      int64_t n_non_null = scan_in_array[length - 1];
      toindex[thread_id] =
          n_non_null + scan_in_array_n_non_null[thread_id] - 1;
    }
  }
}
55+
56+
// fails for [-1]
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
2+
3+
// BEGIN PYTHON
4+
// def f(grid, block, args):
5+
// (numnull, fromindex, lenindex, invocation_index, err_code) = args
6+
// scan_in_array = cupy.empty(lenindex, dtype=cupy.int64)
7+
// cuda_kernel_templates.get_function(fetch_specialization(['awkward_IndexedArray_numnull_a', numnull.dtype, fromindex.dtype]))(grid, block, (numnull, fromindex, lenindex, scan_in_array, invocation_index, err_code))
8+
// scan_in_array = inclusive_scan(grid, block, (scan_in_array, invocation_index, err_code))
9+
// cuda_kernel_templates.get_function(fetch_specialization(['awkward_IndexedArray_numnull_b', numnull.dtype, fromindex.dtype]))(grid, block, (numnull, fromindex, lenindex, scan_in_array, invocation_index, err_code))
10+
// out["awkward_IndexedArray_numnull_a", {dtype_specializations}] = None
11+
// out["awkward_IndexedArray_numnull_b", {dtype_specializations}] = None
12+
// END PYTHON
13+
14+
// Stage 1 of counting negative entries in an index: writes a 0/1 flag per
// entry into scan_in_array. Per the launcher comment for this kernel, the
// host then runs an inclusive scan over scan_in_array and launches
// awkward_IndexedArray_numnull_b to read the total.
//
//   numnull           not used by this kernel
//   fromindex         `lenindex` index entries
//   lenindex          number of entries in fromindex / scan_in_array
//   scan_in_array     output: 1 where fromindex[i] < 0, else 0
//   invocation_index  not used by this kernel
//   err_code          nothing is done unless err_code[0] == NO_ERROR
template <typename T, typename C>
__global__ void
awkward_IndexedArray_numnull_a(T* numnull,
                               const C* fromindex,
                               int64_t lenindex,
                               int64_t* scan_in_array,
                               uint64_t invocation_index,
                               uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
    // evaluated in 32-bit unsigned arithmetic and wrap for very large grids.
    int64_t thread_id =
        static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    if (thread_id < lenindex) {
      scan_in_array[thread_id] = (fromindex[thread_id] < 0) ? 1 : 0;
    }
  }
}
35+
36+
// Stage 2 of counting negative entries in an index: stores the last element
// of scan_in_array into *numnull. Per the launcher comment for this kernel,
// scan_in_array has been replaced by its inclusive scan before this launch,
// so the last element is the total number of flagged entries.
//
//   numnull           single-element output
//   fromindex         not used by this kernel
//   lenindex          number of entries in scan_in_array
//   scan_in_array     scanned flags; only element lenindex - 1 is read
//   invocation_index  not used by this kernel
//   err_code          nothing is done unless err_code[0] == NO_ERROR
template <typename T, typename C>
__global__ void
awkward_IndexedArray_numnull_b(T* numnull,
                               const C* fromindex,
                               int64_t lenindex,
                               int64_t* scan_in_array,
                               uint64_t invocation_index,
                               uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // An empty input has no last element: report 0 instead of reading
    // scan_in_array[-1].
    *numnull = lenindex > 0 ? scan_in_array[lenindex - 1] : 0;
  }
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
2+
3+
// BEGIN PYTHON
4+
// def f(grid, block, args):
5+
// (numnull, tolength, fromindex, lenindex, invocation_index, err_code) = args
6+
// scan_in_array = cupy.empty(lenindex, dtype=cupy.int64)
7+
// cuda_kernel_templates.get_function(fetch_specialization(['awkward_IndexedArray_numnull_parents_a', numnull.dtype, tolength.dtype, fromindex.dtype]))(grid, block, (numnull, tolength, fromindex, lenindex, scan_in_array, invocation_index, err_code))
8+
// scan_in_array = inclusive_scan(grid, block, (scan_in_array, invocation_index, err_code))
9+
// cuda_kernel_templates.get_function(fetch_specialization(['awkward_IndexedArray_numnull_parents_b', numnull.dtype, tolength.dtype, fromindex.dtype]))(grid, block, (numnull, tolength, fromindex, lenindex, scan_in_array, invocation_index, err_code))
10+
// out["awkward_IndexedArray_numnull_parents_a", {dtype_specializations}] = None
11+
// out["awkward_IndexedArray_numnull_parents_b", {dtype_specializations}] = None
12+
// END PYTHON
13+
14+
// Stage 1 of awkward_IndexedArray_numnull_parents: writes a 0/1 flag per
// entry into scan_in_array. Per the launcher comment for this kernel, the
// host then runs an inclusive scan over scan_in_array and launches
// awkward_IndexedArray_numnull_parents_b.
//
// Template parameters follow the argument order (T numnull, C tolength,
// U fromindex), which is the order in which the launcher comment passes the
// dtypes to fetch_specialization.
// NOTE(review): this assumes fetch_specialization binds dtypes to template
// parameters positionally -- confirm against its implementation.
//
//   numnull           not used by this kernel
//   tolength          not used by this kernel
//   fromindex         `lenindex` index entries
//   lenindex          number of entries in fromindex / scan_in_array
//   scan_in_array     output: 1 where fromindex[i] < 0, else 0
//   invocation_index  not used by this kernel
//   err_code          nothing is done unless err_code[0] == NO_ERROR
template <typename T, typename C, typename U>
__global__ void
awkward_IndexedArray_numnull_parents_a(T* numnull,
                                       C* tolength,
                                       const U* fromindex,
                                       int64_t lenindex,
                                       int64_t* scan_in_array,
                                       uint64_t invocation_index,
                                       uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
    // evaluated in 32-bit unsigned arithmetic and wrap for very large grids.
    int64_t thread_id =
        static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    if (thread_id < lenindex) {
      scan_in_array[thread_id] = (fromindex[thread_id] < 0) ? 1 : 0;
    }
  }
}
36+
37+
// Stage 2 of awkward_IndexedArray_numnull_parents: writes numnull[i] = 1
// where fromindex[i] < 0 and 0 elsewhere, and stores the last element of
// scan_in_array into *tolength. Per the launcher comment for this kernel,
// scan_in_array has been replaced by its inclusive scan before this launch,
// so that last element is the total number of negative entries.
//
// Template parameters follow the argument order (T numnull, C tolength,
// U fromindex), which is the order in which the launcher comment passes the
// dtypes to fetch_specialization.
// NOTE(review): this assumes fetch_specialization binds dtypes to template
// parameters positionally -- confirm against its implementation.
//
//   numnull           `lenindex` outputs: 1 for negative entries, else 0
//   tolength          single-element output: total number of negative entries
//   fromindex         `lenindex` index entries
//   lenindex          number of entries
//   scan_in_array     scanned flags; only element lenindex - 1 is read
//   invocation_index  not used by this kernel
//   err_code          nothing is done unless err_code[0] == NO_ERROR
template <typename T, typename C, typename U>
__global__ void
awkward_IndexedArray_numnull_parents_b(T* numnull,
                                       C* tolength,
                                       const U* fromindex,
                                       int64_t lenindex,
                                       int64_t* scan_in_array,
                                       uint64_t invocation_index,
                                       uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
    // evaluated in 32-bit unsigned arithmetic and wrap for very large grids.
    int64_t thread_id =
        static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    if (thread_id < lenindex) {
      numnull[thread_id] = (fromindex[thread_id] < 0) ? 1 : 0;
    }

    // Every launched thread stores the same total. An empty input has no
    // last element, so report 0 instead of reading scan_in_array[-1].
    *tolength = lenindex > 0 ? scan_in_array[lenindex - 1] : 0;
  }
}
60+
61+
// fails for [-1]

0 commit comments

Comments
 (0)