@@ -35,6 +35,13 @@ def __init__(self, name, typename, direction, role="default"):
35
35
self .role = role
36
36
37
37
38
+ no_role_kernels = [
39
+ "awkward_NumpyArray_sort_asstrings_uint8" ,
40
+ "awkward_argsort" ,
41
+ "awkward_sort" ,
42
+ ]
43
+
44
+
38
45
class Specification :
39
46
def __init__ (self , templatized_kernel_name , spec , testdata , blacklisted ):
40
47
self .templatized_kernel_name = templatized_kernel_name
@@ -51,6 +58,8 @@ def __init__(self, templatized_kernel_name, spec, testdata, blacklisted):
51
58
)
52
59
if blacklisted :
53
60
self .tests = []
61
+ elif templatized_kernel_name in no_role_kernels :
62
+ self .tests = []
54
63
else :
55
64
self .tests = self .gettests (testdata )
56
65
@@ -185,6 +194,7 @@ def gettests(self, testdata):
185
194
186
195
def readspec ():
187
196
specdict = {}
197
+ specdict_unit = {}
188
198
with open (os .path .join (CURRENT_DIR , ".." , "kernel-specification.yml" )) as f :
189
199
loadfile = yaml .load (f , Loader = yaml .CSafeLoader )
190
200
@@ -193,6 +203,13 @@ def readspec():
193
203
data = json .load (f )["tests" ]
194
204
195
205
for spec in indspec :
206
+ for childfunc in spec ["specializations" ]:
207
+ specdict_unit [childfunc ["name" ]] = Specification (
208
+ spec ["name" ],
209
+ childfunc ,
210
+ data ,
211
+ not spec ["automatic-tests" ],
212
+ )
196
213
if "def " in spec ["definition" ]:
197
214
for childfunc in spec ["specializations" ]:
198
215
specdict [childfunc ["name" ]] = Specification (
@@ -201,7 +218,7 @@ def readspec():
201
218
data ,
202
219
not spec ["automatic-tests" ],
203
220
)
204
- return specdict
221
+ return specdict , specdict_unit
205
222
206
223
207
224
def getdtypes (args ):
@@ -215,6 +232,8 @@ def getdtypes(args):
215
232
typename = typename + "_"
216
233
if count == 1 :
217
234
dtypes .append ("cupy." + typename )
235
+ elif count == 2 :
236
+ dtypes .append ("cupy." + typename )
218
237
return dtypes
219
238
220
239
@@ -239,7 +258,12 @@ def checkintrange(test_args, error, args):
239
258
if "int" in typename or "uint" in typename :
240
259
dtype = gettypename (typename )
241
260
min_val , max_val = np .iinfo (dtype ).min , np .iinfo (dtype ).max
242
- if "List" in typename :
261
+ if "List[List" in typename :
262
+ for row in val :
263
+ for data in row :
264
+ if not (min_val <= data <= max_val ):
265
+ flag = False
266
+ elif "List" in typename :
243
267
for data in val :
244
268
if not (min_val <= data <= max_val ):
245
269
flag = False
@@ -687,6 +711,8 @@ def gencpuunittests(specdict):
687
711
"awkward_RegularArray_getitem_next_range" ,
688
712
"awkward_RegularArray_getitem_next_range_spreadadvanced" ,
689
713
"awkward_RegularArray_getitem_next_array" ,
714
+ "awkward_RegularArray_reduce_local_nextparents" ,
715
+ "awkward_RegularArray_reduce_nonlocal_preparenext" ,
690
716
"awkward_missing_repeat" ,
691
717
"awkward_RegularArray_getitem_jagged_expand" ,
692
718
"awkward_ListArray_getitem_jagged_expand" ,
@@ -733,6 +759,7 @@ def gencpuunittests(specdict):
733
759
"awkward_reduce_sum_bool" ,
734
760
"awkward_reduce_prod_bool" ,
735
761
"awkward_reduce_countnonzero" ,
762
+ "awkward_sorting_ranges_length" ,
736
763
]
737
764
738
765
@@ -973,8 +1000,12 @@ def gencudaunittests(specdict):
973
1000
)
974
1001
)
975
1002
elif count == 2 :
976
- raise NotImplementedError
977
-
1003
+ f .write (
1004
+ " " * 4
1005
+ + "{} = cupy.array({}, dtype=cupy.{})\n " .format (
1006
+ arg , val , typename
1007
+ )
1008
+ )
978
1009
cuda_string = (
979
1010
"funcC = cupy_backend['"
980
1011
+ spec .templatized_kernel_name
@@ -1075,10 +1106,10 @@ def evalkernels():
1075
1106
if __name__ == "__main__" :
1076
1107
genpykernels ()
1077
1108
evalkernels ()
1078
- specdict = readspec ()
1109
+ specdict , specdict_unit = readspec ()
1079
1110
genspectests (specdict )
1080
1111
gencpukerneltests (specdict )
1081
- gencpuunittests (specdict )
1112
+ gencpuunittests (specdict_unit )
1082
1113
genunittests ()
1083
1114
gencudakerneltests (specdict )
1084
- gencudaunittests (specdict )
1115
+ gencudaunittests (specdict_unit )
0 commit comments