Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions kernel_tuner/searchspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from warnings import warn
from copy import deepcopy
from collections import defaultdict, deque
from inspect import signature

import numpy as np
from scipy.stats.qmc import LatinHypercube
Expand Down Expand Up @@ -495,6 +496,13 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
def __add_restrictions(self, parameter_space: Problem) -> Problem:
"""Add the user-specified restrictions as constraints on the parameter space."""
restrictions = deepcopy(self.restrictions)
# differentiate between old style monolithic with single 'p' argument and newer *args style
if (len(restrictions) == 1
and not isinstance(restrictions[0], (Constraint, FunctionConstraint, str))
and callable(restrictions[0])
and len(signature(restrictions[0]).parameters) == 1
and len(self.param_names) > 1):
restrictions = restrictions[0]
if isinstance(restrictions, list):
for restriction in restrictions:
required_params = self.param_names
Expand All @@ -504,10 +512,6 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
required_params = restriction[1]
restriction = restriction[0]
if callable(restriction) and not isinstance(restriction, Constraint):
# def restrictions_wrapper(*args):
# return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False)
# print(restriction, isinstance(restriction, Constraint))
# restriction = FunctionConstraint(restrictions_wrapper)
restriction = FunctionConstraint(restriction, required_params)

# add as a Constraint
Expand All @@ -529,6 +533,7 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
elif callable(restrictions):

def restrictions_wrapper(*args):
"""Wrap old-style monolithic restrictions to work with multiple arguments."""
return check_instance_restrictions(restrictions, dict(zip(self.param_names, args)), False)

parameter_space.addConstraint(FunctionConstraint(restrictions_wrapper), self.param_names)
Expand Down
34 changes: 34 additions & 0 deletions test/test_searchspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,3 +623,37 @@ def test_full_searchspace(compare_against_bruteforce=False):
compare_two_searchspace_objects(searchspace, searchspace_bruteforce)
else:
assert searchspace.size == len(searchspace.list) == 349853

def test_restriction_backwards_compatibility():
"""Test whether the backwards compatibility code for restrictions (list of strings) works as expected."""
# create a searchspace with mixed parameter types
max_threads = 1024
tune_params = dict()
tune_params["N_PER_BLOCK"] = [32, 64, 128, 256, 512, 1024]
tune_params["M_PER_BLOCK"] = [32, 64, 128, 256, 512, 1024]
tune_params["block_size_y"] = [1, 2, 4, 8, 16, 32]
tune_params["block_size_z"] = [1, 2, 4, 8, 16, 32]

# old style monolithic restriction function
def restrict(p):
n_global_per_warp = int(p["N_PER_BLOCK"] // p["block_size_y"])
m_global_per_warp = int(p["M_PER_BLOCK"] // p["block_size_z"])
if n_global_per_warp == 0 or m_global_per_warp == 0:
return False

searchspace_callable = Searchspace(tune_params, restrict, max_threads)

def restrict_args(N_PER_BLOCK, M_PER_BLOCK, block_size_y, block_size_z):
n_global_per_warp = int(N_PER_BLOCK // block_size_y)
m_global_per_warp = int(M_PER_BLOCK // block_size_z)
if n_global_per_warp == 0 or m_global_per_warp == 0:
return False

# args-style restriction
searchspace_str = Searchspace(tune_params, restrict_args, max_threads)

# check the size
assert searchspace_str.size == searchspace_callable.size

# check that both searchspaces are identical in outcome
compare_two_searchspace_objects(searchspace_str, searchspace_callable)