@@ -35,6 +35,13 @@ def __init__(self, name, typename, direction, role="default"):
35
35
self .role = role
36
36
37
37
38
+ no_role_kernels = [
39
+ "awkward_NumpyArray_sort_asstrings_uint8" ,
40
+ "awkward_argsort" ,
41
+ "awkward_sort" ,
42
+ ]
43
+
44
+
38
45
class Specification :
39
46
def __init__ (self , templatized_kernel_name , spec , testdata , blacklisted ):
40
47
self .templatized_kernel_name = templatized_kernel_name
@@ -51,6 +58,8 @@ def __init__(self, templatized_kernel_name, spec, testdata, blacklisted):
51
58
)
52
59
if blacklisted :
53
60
self .tests = []
61
+ elif templatized_kernel_name in no_role_kernels :
62
+ self .tests = []
54
63
else :
55
64
self .tests = self .gettests (testdata )
56
65
@@ -185,6 +194,7 @@ def gettests(self, testdata):
185
194
186
195
def readspec ():
187
196
specdict = {}
197
+ specdict_unit = {}
188
198
with open (os .path .join (CURRENT_DIR , ".." , "kernel-specification.yml" )) as f :
189
199
loadfile = yaml .load (f , Loader = yaml .CSafeLoader )
190
200
@@ -193,6 +203,13 @@ def readspec():
193
203
data = json .load (f )["tests" ]
194
204
195
205
for spec in indspec :
206
+ for childfunc in spec ["specializations" ]:
207
+ specdict_unit [childfunc ["name" ]] = Specification (
208
+ spec ["name" ],
209
+ childfunc ,
210
+ data ,
211
+ not spec ["automatic-tests" ],
212
+ )
196
213
if "def " in spec ["definition" ]:
197
214
for childfunc in spec ["specializations" ]:
198
215
specdict [childfunc ["name" ]] = Specification (
@@ -201,7 +218,7 @@ def readspec():
201
218
data ,
202
219
not spec ["automatic-tests" ],
203
220
)
204
- return specdict
221
+ return specdict , specdict_unit
205
222
206
223
207
224
def getdtypes (args ):
@@ -215,6 +232,8 @@ def getdtypes(args):
215
232
typename = typename + "_"
216
233
if count == 1 :
217
234
dtypes .append ("cupy." + typename )
235
+ elif count == 2 :
236
+ dtypes .append ("cupy." + typename )
218
237
return dtypes
219
238
220
239
@@ -239,7 +258,12 @@ def checkintrange(test_args, error, args):
239
258
if "int" in typename or "uint" in typename :
240
259
dtype = gettypename (typename )
241
260
min_val , max_val = np .iinfo (dtype ).min , np .iinfo (dtype ).max
242
- if "List" in typename :
261
+ if "List[List" in typename :
262
+ for row in val :
263
+ for data in row :
264
+ if not (min_val <= data <= max_val ):
265
+ flag = False
266
+ elif "List" in typename :
243
267
for data in val :
244
268
if not (min_val <= data <= max_val ):
245
269
flag = False
@@ -652,12 +676,16 @@ def gencpuunittests(specdict):
652
676
653
677
654
678
cuda_kernels_tests = [
679
+ "awkward_Index_nones_as_index" ,
655
680
"awkward_ListArray_min_range" ,
656
681
"awkward_ListArray_validity" ,
657
682
"awkward_BitMaskedArray_to_ByteMaskedArray" ,
658
683
"awkward_ListArray_compact_offsets" ,
659
684
"awkward_ListOffsetArray_flatten_offsets" ,
660
685
"awkward_IndexedArray_overlay_mask" ,
686
+ "awkward_ByteMaskedArray_numnull" ,
687
+ "awkward_IndexedArray_numnull" ,
688
+ "awkward_IndexedArray_numnull_parents" ,
661
689
"awkward_IndexedArray_numnull_unique_64" ,
662
690
"awkward_NumpyArray_fill" ,
663
691
"awkward_ListArray_fill" ,
@@ -683,12 +711,19 @@ def gencpuunittests(specdict):
683
711
"awkward_RegularArray_getitem_next_range" ,
684
712
"awkward_RegularArray_getitem_next_range_spreadadvanced" ,
685
713
"awkward_RegularArray_getitem_next_array" ,
714
+ "awkward_RegularArray_getitem_next_array_regularize" ,
715
+ "awkward_RegularArray_reduce_local_nextparents" ,
716
+ "awkward_RegularArray_reduce_nonlocal_preparenext" ,
686
717
"awkward_missing_repeat" ,
687
718
"awkward_RegularArray_getitem_jagged_expand" ,
688
719
"awkward_ListArray_getitem_jagged_expand" ,
720
+ "awkward_ListArray_getitem_jagged_carrylen" ,
689
721
"awkward_ListArray_getitem_next_array_advanced" ,
690
722
"awkward_ListArray_getitem_next_array" ,
691
723
"awkward_ListArray_getitem_next_at" ,
724
+ "awkward_ListArray_getitem_next_range_counts" ,
725
+ "awkward_ListArray_rpad_and_clip_length_axis1" ,
726
+ "awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64" ,
692
727
"awkward_NumpyArray_reduce_adjust_starts_64" ,
693
728
"awkward_NumpyArray_reduce_adjust_starts_shifts_64" ,
694
729
"awkward_RegularArray_getitem_next_at" ,
@@ -726,6 +761,7 @@ def gencpuunittests(specdict):
726
761
"awkward_reduce_sum_bool" ,
727
762
"awkward_reduce_prod_bool" ,
728
763
"awkward_reduce_countnonzero" ,
764
+ "awkward_sorting_ranges_length" ,
729
765
]
730
766
731
767
@@ -966,8 +1002,12 @@ def gencudaunittests(specdict):
966
1002
)
967
1003
)
968
1004
elif count == 2 :
969
- raise NotImplementedError
970
-
1005
+ f .write (
1006
+ " " * 4
1007
+ + "{} = cupy.array({}, dtype=cupy.{})\n " .format (
1008
+ arg , val , typename
1009
+ )
1010
+ )
971
1011
cuda_string = (
972
1012
"funcC = cupy_backend['"
973
1013
+ spec .templatized_kernel_name
@@ -1068,10 +1108,10 @@ def evalkernels():
1068
1108
if __name__ == "__main__" :
1069
1109
genpykernels ()
1070
1110
evalkernels ()
1071
- specdict = readspec ()
1111
+ specdict , specdict_unit = readspec ()
1072
1112
genspectests (specdict )
1073
1113
gencpukerneltests (specdict )
1074
- gencpuunittests (specdict )
1114
+ gencpuunittests (specdict_unit )
1075
1115
genunittests ()
1076
1116
gencudakerneltests (specdict )
1077
- gencudaunittests (specdict )
1117
+ gencudaunittests (specdict_unit )
0 commit comments