From 739f95403c3fcf13e61e73e573e0fd391a9af911 Mon Sep 17 00:00:00 2001 From: fjwillemsen Date: Wed, 8 Oct 2025 12:07:42 +0200 Subject: [PATCH 1/3] This fixes the issue regarding backwards compatibility with old monolithic restrictions functions reported in #333 --- kernel_tuner/searchspace.py | 18 +++++++++++++----- test/test_searchspace.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/kernel_tuner/searchspace.py b/kernel_tuner/searchspace.py index d3d00052..de9b33c0 100644 --- a/kernel_tuner/searchspace.py +++ b/kernel_tuner/searchspace.py @@ -7,6 +7,7 @@ from warnings import warn from copy import deepcopy from collections import defaultdict, deque +from inspect import signature import numpy as np from scipy.stats.qmc import LatinHypercube @@ -495,6 +496,8 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver: def __add_restrictions(self, parameter_space: Problem) -> Problem: """Add the user-specified restrictions as constraints on the parameter space.""" restrictions = deepcopy(self.restrictions) + if len(restrictions) == 1 and not isinstance(restrictions[0], (Constraint, FunctionConstraint, str)) and callable(restrictions[0]) and len(signature(restrictions[0]).parameters) == 1: + restrictions = restrictions[0] if isinstance(restrictions, list): for restriction in restrictions: required_params = self.param_names @@ -504,11 +507,16 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem: required_params = restriction[1] restriction = restriction[0] if callable(restriction) and not isinstance(restriction, Constraint): - # def restrictions_wrapper(*args): - # return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False) - # print(restriction, isinstance(restriction, Constraint)) - # restriction = FunctionConstraint(restrictions_wrapper) - restriction = FunctionConstraint(restriction, required_params) + # differentiate between old style monolithic with single 'p' argument and newer *args style + if len(signature(restriction).parameters) == 1 and len(self.param_names) != 1: + def restrictions_wrapper(*args): + # raise ValueError(self.param_names, args, restriction, signature(restriction).parameters) + # return restriction(dict(zip(self.param_names, args))) + return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False) + + restriction = FunctionConstraint(restrictions_wrapper) + else: + restriction = FunctionConstraint(restriction, required_params) # add as a Constraint all_params_required = all(param_name in required_params for param_name in self.param_names) diff --git a/test/test_searchspace.py b/test/test_searchspace.py index f742a4b7..32f58bec 100644 --- a/test/test_searchspace.py +++ b/test/test_searchspace.py @@ -623,3 +623,37 @@ def test_full_searchspace(compare_against_bruteforce=False): compare_two_searchspace_objects(searchspace, searchspace_bruteforce) else: assert searchspace.size == len(searchspace.list) == 349853 + +def test_restriction_backwards_compatibility(): + """Test whether the backwards compatibility code for restrictions (list of strings) works as expected.""" + # create a searchspace with mixed parameter types + max_threads = 1024 + tune_params = dict() + tune_params["N_PER_BLOCK"] = [32, 64, 128, 256, 512, 1024] + tune_params["M_PER_BLOCK"] = [32, 64, 128, 256, 512, 1024] + tune_params["block_size_y"] = [1, 2, 4, 8, 16, 32] + tune_params["block_size_z"] = [1, 2, 4, 8, 16, 32] + + # old style monolithic restriction function + def restrict(p): + n_global_per_warp = int(p["N_PER_BLOCK"] // p["block_size_y"]) + m_global_per_warp = int(p["M_PER_BLOCK"] // p["block_size_z"]) + if n_global_per_warp == 0 or m_global_per_warp == 0: + return False + + searchspace_callable = Searchspace(tune_params, restrict, max_threads) + + def restrict_args(N_PER_BLOCK, M_PER_BLOCK, block_size_y, block_size_z): + n_global_per_warp = int(N_PER_BLOCK // block_size_y) + m_global_per_warp = int(M_PER_BLOCK // block_size_z) + if n_global_per_warp == 0 or m_global_per_warp == 0: + return False + + # args-style restriction + searchspace_str = Searchspace(tune_params, restrict_args, max_threads) + + # check the size + assert searchspace_str.size == searchspace_callable.size + + # check that both searchspaces are identical in outcome + compare_two_searchspace_objects(searchspace_str, searchspace_callable) From 4d19ee8201e2a7af6a2db9f80c5f36f1b9ab9c3d Mon Sep 17 00:00:00 2001 From: fjwillemsen Date: Wed, 8 Oct 2025 14:58:14 +0200 Subject: [PATCH 2/3] Improvements to code style --- kernel_tuner/searchspace.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/kernel_tuner/searchspace.py b/kernel_tuner/searchspace.py index de9b33c0..d5612f86 100644 --- a/kernel_tuner/searchspace.py +++ b/kernel_tuner/searchspace.py @@ -496,7 +496,12 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver: def __add_restrictions(self, parameter_space: Problem) -> Problem: """Add the user-specified restrictions as constraints on the parameter space.""" restrictions = deepcopy(self.restrictions) - if len(restrictions) == 1 and not isinstance(restrictions[0], (Constraint, FunctionConstraint, str)) and callable(restrictions[0]) and len(signature(restrictions[0]).parameters) == 1: + # differentiate between old style monolithic with single 'p' argument and newer *args style + if (len(restrictions) == 1 + and not isinstance(restrictions[0], (Constraint, FunctionConstraint, str)) + and callable(restrictions[0]) + and len(signature(restrictions[0]).parameters) == 1 + and len(self.param_names) > 1): restrictions = restrictions[0] if isinstance(restrictions, list): for restriction in restrictions: @@ -507,16 +512,7 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem: required_params = restriction[1] restriction = restriction[0] if callable(restriction) and not isinstance(restriction, Constraint): - # differentiate between old style monolithic with single 'p' argument and newer *args style - if len(signature(restriction).parameters) == 1 and len(self.param_names) != 1: - def restrictions_wrapper(*args): - # raise ValueError(self.param_names, args, restriction, signature(restriction).parameters) - # return restriction(dict(zip(self.param_names, args))) - return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False) - - restriction = FunctionConstraint(restrictions_wrapper) - else: - restriction = FunctionConstraint(restriction, required_params) + restriction = FunctionConstraint(restriction, required_params) # add as a Constraint all_params_required = all(param_name in required_params for param_name in self.param_names) @@ -537,6 +533,7 @@ def restrictions_wrapper(*args): elif callable(restrictions): def restrictions_wrapper(*args): + """Wrap old-style monolithic restrictions to work with multiple arguments.""" return check_instance_restrictions(restrictions, dict(zip(self.param_names, args)), False) parameter_space.addConstraint(FunctionConstraint(restrictions_wrapper), self.param_names) From 38700afd348e59f8ff9b7629612be979d32617bd Mon Sep 17 00:00:00 2001 From: fjwillemsen Date: Wed, 8 Oct 2025 15:01:00 +0200 Subject: [PATCH 3/3] Removed trailing whitespace --- kernel_tuner/searchspace.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kernel_tuner/searchspace.py b/kernel_tuner/searchspace.py index d5612f86..42d0be4d 100644 --- a/kernel_tuner/searchspace.py +++ b/kernel_tuner/searchspace.py @@ -497,10 +497,10 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem: """Add the user-specified restrictions as constraints on the parameter space.""" restrictions = deepcopy(self.restrictions) # differentiate between old style monolithic with single 'p' argument and newer *args style - if (len(restrictions) == 1 - and not isinstance(restrictions[0], (Constraint, FunctionConstraint, str)) - and callable(restrictions[0]) - and len(signature(restrictions[0]).parameters) == 1 + if (len(restrictions) == 1 + and not isinstance(restrictions[0], (Constraint, FunctionConstraint, str)) + and callable(restrictions[0]) + and len(signature(restrictions[0]).parameters) == 1 and len(self.param_names) > 1): restrictions = restrictions[0] if isinstance(restrictions, list):