diff --git a/mitreattack/navlayers/core/exceptions.py b/mitreattack/navlayers/core/exceptions.py index c7125b1..d8deaad 100644 --- a/mitreattack/navlayers/core/exceptions.py +++ b/mitreattack/navlayers/core/exceptions.py @@ -55,13 +55,18 @@ def typeChecker(caller, testee, desired_type, field): :param caller: the entity that called this function (used for error messages) :param testee: the element to test - :param desired_type: the type the element should be + :param desired_type: the type the element should be or a list of + allowed types :param field: what the element is to be used as (used for error messages) :raises BadType: error denoting the testee element is not of the correct type """ - if not isinstance(testee, desired_type): + if isinstance(desired_type, list): + if not any(isinstance(testee, t) for t in desired_type): + handler(caller, f"{testee} [{field}] is not one of {str(desired_type)}") + raise BadType + elif not isinstance(testee, desired_type): handler(caller, f"{testee} [{field}] is not a {str(desired_type)}") raise BadType diff --git a/mitreattack/navlayers/core/technique.py b/mitreattack/navlayers/core/technique.py index 21d2a4a..c66d505 100644 --- a/mitreattack/navlayers/core/technique.py +++ b/mitreattack/navlayers/core/technique.py @@ -94,12 +94,8 @@ def score(self): @score.setter def score(self, score): """Setter for score.""" - try: - typeChecker(type(self).__name__, score, int, "score") - self.__score = score - except BadType: - typeChecker(type(self).__name__, score, float, "score") - self.__score = int(score) + typeChecker(type(self).__name__, score, [int, float], "score") + self.__score = score @property def color(self):