Skip to content

Commit fdcf96e

Browse files
committed
Added exclusive_scan function and new CUDA kernels
1 parent 432a11d commit fdcf96e

15 files changed

+505
-44
lines changed

dev/generate-kernel-signatures.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,8 @@
4747
"awkward_RegularArray_getitem_next_range",
4848
"awkward_RegularArray_getitem_next_range_spreadadvanced",
4949
"awkward_RegularArray_getitem_next_array",
50+
"awkward_RegularArray_reduce_local_nextparents",
51+
"awkward_RegularArray_reduce_nonlocal_preparenext",
5052
"awkward_missing_repeat",
5153
"awkward_RegularArray_getitem_jagged_expand",
5254
"awkward_ListArray_getitem_jagged_expand",
@@ -93,6 +95,7 @@
9395
"awkward_reduce_sum_bool",
9496
"awkward_reduce_prod_bool",
9597
"awkward_reduce_countnonzero",
98+
"awkward_sorting_ranges_length",
9699
]
97100

98101

dev/generate-tests.py

Lines changed: 38 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,13 @@ def __init__(self, name, typename, direction, role="default"):
3535
self.role = role
3636

3737

38+
no_role_kernels = [
39+
"awkward_NumpyArray_sort_asstrings_uint8",
40+
"awkward_argsort",
41+
"awkward_sort",
42+
]
43+
44+
3845
class Specification:
3946
def __init__(self, templatized_kernel_name, spec, testdata, blacklisted):
4047
self.templatized_kernel_name = templatized_kernel_name
@@ -51,6 +58,8 @@ def __init__(self, templatized_kernel_name, spec, testdata, blacklisted):
5158
)
5259
if blacklisted:
5360
self.tests = []
61+
elif templatized_kernel_name in no_role_kernels:
62+
self.tests = []
5463
else:
5564
self.tests = self.gettests(testdata)
5665

@@ -185,6 +194,7 @@ def gettests(self, testdata):
185194

186195
def readspec():
187196
specdict = {}
197+
specdict_unit = {}
188198
with open(os.path.join(CURRENT_DIR, "..", "kernel-specification.yml")) as f:
189199
loadfile = yaml.load(f, Loader=yaml.CSafeLoader)
190200

@@ -193,6 +203,13 @@ def readspec():
193203
data = json.load(f)["tests"]
194204

195205
for spec in indspec:
206+
for childfunc in spec["specializations"]:
207+
specdict_unit[childfunc["name"]] = Specification(
208+
spec["name"],
209+
childfunc,
210+
data,
211+
not spec["automatic-tests"],
212+
)
196213
if "def " in spec["definition"]:
197214
for childfunc in spec["specializations"]:
198215
specdict[childfunc["name"]] = Specification(
@@ -201,7 +218,7 @@ def readspec():
201218
data,
202219
not spec["automatic-tests"],
203220
)
204-
return specdict
221+
return specdict, specdict_unit
205222

206223

207224
def getdtypes(args):
@@ -215,6 +232,8 @@ def getdtypes(args):
215232
typename = typename + "_"
216233
if count == 1:
217234
dtypes.append("cupy." + typename)
235+
elif count == 2:
236+
dtypes.append("cupy." + typename)
218237
return dtypes
219238

220239

@@ -239,7 +258,12 @@ def checkintrange(test_args, error, args):
239258
if "int" in typename or "uint" in typename:
240259
dtype = gettypename(typename)
241260
min_val, max_val = np.iinfo(dtype).min, np.iinfo(dtype).max
242-
if "List" in typename:
261+
if "List[List" in typename:
262+
for row in val:
263+
for data in row:
264+
if not (min_val <= data <= max_val):
265+
flag = False
266+
elif "List" in typename:
243267
for data in val:
244268
if not (min_val <= data <= max_val):
245269
flag = False
@@ -687,6 +711,8 @@ def gencpuunittests(specdict):
687711
"awkward_RegularArray_getitem_next_range",
688712
"awkward_RegularArray_getitem_next_range_spreadadvanced",
689713
"awkward_RegularArray_getitem_next_array",
714+
"awkward_RegularArray_reduce_local_nextparents",
715+
"awkward_RegularArray_reduce_nonlocal_preparenext",
690716
"awkward_missing_repeat",
691717
"awkward_RegularArray_getitem_jagged_expand",
692718
"awkward_ListArray_getitem_jagged_expand",
@@ -733,6 +759,7 @@ def gencpuunittests(specdict):
733759
"awkward_reduce_sum_bool",
734760
"awkward_reduce_prod_bool",
735761
"awkward_reduce_countnonzero",
762+
"awkward_sorting_ranges_length",
736763
]
737764

738765

@@ -973,8 +1000,12 @@ def gencudaunittests(specdict):
9731000
)
9741001
)
9751002
elif count == 2:
976-
raise NotImplementedError
977-
1003+
f.write(
1004+
" " * 4
1005+
+ "{} = cupy.array({}, dtype=cupy.{})\n".format(
1006+
arg, val, typename
1007+
)
1008+
)
9781009
cuda_string = (
9791010
"funcC = cupy_backend['"
9801011
+ spec.templatized_kernel_name
@@ -1075,10 +1106,10 @@ def evalkernels():
10751106
if __name__ == "__main__":
10761107
genpykernels()
10771108
evalkernels()
1078-
specdict = readspec()
1109+
specdict, specdict_unit = readspec()
10791110
genspectests(specdict)
10801111
gencpukerneltests(specdict)
1081-
gencpuunittests(specdict)
1112+
gencpuunittests(specdict_unit)
10821113
genunittests()
10831114
gencudakerneltests(specdict)
1084-
gencudaunittests(specdict)
1115+
gencudaunittests(specdict_unit)

0 commit comments

Comments
 (0)