diff --git a/easybuild/framework/easyconfig/easyconfig.py b/easybuild/framework/easyconfig/easyconfig.py index dd022fc786..1164f959b9 100644 --- a/easybuild/framework/easyconfig/easyconfig.py +++ b/easybuild/framework/easyconfig/easyconfig.py @@ -65,6 +65,7 @@ from easybuild.framework.easyconfig.templates import ALTERNATIVE_EASYCONFIG_TEMPLATES, DEPRECATED_EASYCONFIG_TEMPLATES from easybuild.framework.easyconfig.templates import TEMPLATE_CONSTANTS, TEMPLATE_NAMES_DYNAMIC, template_constant_dict from easybuild.tools import LooseVersion +from easybuild.tools.entrypoints import EntrypointEasyblock from easybuild.tools.build_log import EasyBuildError, EasyBuildExit, print_warning, print_msg from easybuild.tools.config import GENERIC_EASYBLOCK_PKG, LOCAL_VAR_NAMING_CHECK_ERROR, LOCAL_VAR_NAMING_CHECK_LOG from easybuild.tools.config import LOCAL_VAR_NAMING_CHECK_WARN @@ -2016,9 +2017,15 @@ def get_easyblock_class(easyblock, name=None, error_on_failed_import=True, error class_name, modulepath) cls = get_class_for(modulepath, class_name) else: - modulepath = get_module_path(easyblock) - cls = get_class_for(modulepath, class_name) - _log.info("Derived full easyblock module path for %s: %s" % (class_name, modulepath)) + eb_from_eps = EntrypointEasyblock.get_loaded_entrypoints(name=easyblock) + if eb_from_eps: + ep = eb_from_eps[0] + cls = ep.wrapped + _log.info("Obtained easyblock class '%s' from entrypoint '%s'", easyblock, str(ep)) + else: + modulepath = get_module_path(easyblock) + cls = get_class_for(modulepath, class_name) + _log.info("Derived full easyblock module path for %s: %s" % (class_name, modulepath)) else: # if no easyblock specified, try to find if one exists if name is None: diff --git a/easybuild/framework/easyconfig/tools.py b/easybuild/framework/easyconfig/tools.py index 2e548298ff..d2b12ba357 100644 --- a/easybuild/framework/easyconfig/tools.py +++ b/easybuild/framework/easyconfig/tools.py @@ -54,6 +54,7 @@ from easybuild.framework.easyconfig.easyconfig import process_easyconfig from easybuild.framework.easyconfig.style import cmdline_easyconfigs_style_check from easybuild.tools import LooseVersion +from easybuild.tools.entrypoints import EntrypointEasyblock from easybuild.tools.build_log import EasyBuildError, EasyBuildExit, print_error, print_msg, print_warning from easybuild.tools.config import build_option from easybuild.tools.environment import restore_env @@ -799,6 +800,14 @@ def avail_easyblocks(): else: raise EasyBuildError("Failed to determine easyblock class name for %s", easyblock_loc) + ept_eb_lst = EntrypointEasyblock.get_loaded_entrypoints() + + for ept_eb in ept_eb_lst: + easyblocks[ept_eb.module] = { + 'class': ept_eb.name, + 'loc': ept_eb.file, + } + return easyblocks diff --git a/easybuild/tools/config.py b/easybuild/tools/config.py index cef93774a6..138e0ff153 100644 --- a/easybuild/tools/config.py +++ b/easybuild/tools/config.py @@ -351,6 +351,7 @@ def mk_full_default_path(name, prefix=DEFAULT_PREFIX): 'upload_test_report', 'update_modules_tool_cache', 'use_ccache', + 'use_entrypoints', 'use_existing_modules', 'use_f90cache', 'wait_on_lock_limit', diff --git a/easybuild/tools/entrypoints.py b/easybuild/tools/entrypoints.py new file mode 100644 index 0000000000..447e5eda28 --- /dev/null +++ b/easybuild/tools/entrypoints.py @@ -0,0 +1,210 @@ +"""Python module to manage entry points for EasyBuild. + +Authors: + +* Davide Grassano (CECAM) +""" +import sys +import importlib +from easybuild.tools.config import build_option + +from easybuild.base import fancylogger +from easybuild.tools.build_log import EasyBuildError +from typing import TypeVar, List, Set, Any + +_T = TypeVar('_T') + + +HAVE_ENTRY_POINTS = False +HAVE_ENTRY_POINTS_CLS = False +if sys.version_info >= (3, 8): + HAVE_ENTRY_POINTS = True + from importlib.metadata import entry_points, EntryPoint +else: + EntryPoint = Any + +if sys.version_info >= (3, 10): + # Python >= 3.10 uses importlib.metadata.EntryPoints as a type for entry_points() + HAVE_ENTRY_POINTS_CLS = True + + +_log = fancylogger.getLogger('entrypoints', fname=False) + + +class EasybuildEntrypoint: + group = None + expected_type = None + registered = {} + + def __init__(self): + if self.group is None: + raise EasyBuildError( + "Cannot use drirectly. Please use a subclass that defines `group`", + ) + + self.wrapped = None + self.module = None + self.name = None + self.file = None + + def __repr__(self): + return f"{self.__class__.__name__} <{self.module}:{self.name}>" + + def __call__(self, wrap: _T) -> _T: + """Use an instance of this class as a decorator to register an entrypoint.""" + if self.expected_type is not None: + check = False + try: + check = isinstance(wrap, self.expected_type) or issubclass(wrap, self.expected_type) + except Exception: + pass + if not check: + raise EasyBuildError( + "Entrypoint '%s' expected type '%s', got '%s'", + self.name, self.expected_type, type(wrap) + ) + self.wrapped = wrap + self.module = getattr(wrap, '__module__', None) + self.name = getattr(wrap, '__name__', None) + if self.module: + mod = importlib.import_module(self.module) + self.file = getattr(mod, '__file__', None) + + grp = self.registered.setdefault(self.group, set()) + + for ep in grp: + if ep.name == self.name and ep.module != self.module: + raise ValueError( + "Entrypoint '%s' already registered in group '%s' by module '%s' vs '%s'", + self.name, self.group, ep.module, self.module + ) + grp.add(self) + + self.validate() + + _log.debug("Registered entrypoint: %s", self) + + return wrap + + @classmethod + def retrieve_entrypoints(cls) -> Set[EntryPoint]: + """"Get all entrypoints in this group.""" + strict_python = True + use_eps = build_option('use_entrypoints', default=None) + if use_eps is None: + # Default True needed to work with commands like --list-toolchains that do not initialize the BuildOptions + use_eps = True + # Needed to work with older Python versions: do not raise errors when entry points are default enabled + strict_python = False + res = set() + if use_eps: + if not HAVE_ENTRY_POINTS: + if strict_python: + msg = "`--use-entrypoints` requires importlib.metadata (Python >= 3.8)" + _log.warning(msg) + raise EasyBuildError(msg) + else: + _log.debug("`get_group_entrypoints` called before BuildOptions initialized, with python < 3.8") + else: + if HAVE_ENTRY_POINTS_CLS: + res = set(entry_points(group=cls.group)) + else: + res = set(entry_points().get(cls.group, [])) + + return res + + @classmethod + def load_entrypoints(cls): + """Load all the entrypoints in this group. This is needed for the modules contining the entrypoints to be + actually imported in order to process the function decorators that will register them in the + `registered` dict.""" + for ep in cls.retrieve_entrypoints(): + try: + ep.load() + except Exception as e: + msg = f"Error loading entrypoint {ep}: {e}" + _log.warning(msg) + raise EasyBuildError(msg) from e + + @classmethod + def get_loaded_entrypoints(cls: _T, name: str = None, **filter_params) -> List[_T]: + """Get all entrypoints in this group.""" + cls.load_entrypoints() + + entrypoints = [] + for ep in cls.registered.get(cls.group, []): + cond = name is None or ep.name == name + for key, value in filter_params.items(): + cond = cond and getattr(ep, key, None) == value + if cond: + entrypoints.append(ep) + + return entrypoints + + @staticmethod + def clear(): + """Clear the registered entrypoints. Used for testing when the same entrypoint is loaded multiple times + from different temporary directories.""" + EasybuildEntrypoint.registered.clear() + + def validate(self): + """Validate the entrypoint.""" + if self.module is None or self.name is None: + raise EasyBuildError("Entrypoint `%s` has no module or name associated", self.wrapped) + + +class EntrypointHook(EasybuildEntrypoint): + """Class to represent a hook entrypoint.""" + group = 'easybuild.hooks' + + def __init__(self, step, pre_step=False, post_step=False, priority=0): + """Initialize the EntrypointHook.""" + super().__init__() + self.step = step + self.pre_step = pre_step + self.post_step = post_step + self.priority = priority + + def validate(self): + """Validate the hook entrypoint.""" + from easybuild.tools.hooks import KNOWN_HOOKS, HOOK_SUFF, PRE_PREF, POST_PREF + super().validate() + + if not callable(self.wrapped): + raise EasyBuildError("Hook entrypoint `%s` is not callable", self.wrapped) + + prefix = '' + if self.pre_step: + prefix = PRE_PREF + elif self.post_step: + prefix = POST_PREF + + hook_name = f'{prefix}{self.step}{HOOK_SUFF}' + + if hook_name not in KNOWN_HOOKS: + msg = f"Attempting to register unknown hook '{hook_name}'" + _log.warning(msg) + raise EasyBuildError(msg) + + +class EntrypointEasyblock(EasybuildEntrypoint): + """Class to represent an easyblock entrypoint.""" + group = 'easybuild.easyblock' + + def __init__(self): + super().__init__() + # Avoid circular imports by importing EasyBlock here + from easybuild.framework.easyblock import EasyBlock + self.expected_type = EasyBlock + + +class EntrypointToolchain(EasybuildEntrypoint): + """Class to represent a toolchain entrypoint.""" + group = 'easybuild.toolchain' + + def __init__(self, prepend=False): + super().__init__() + # Avoid circular imports by importing Toolchain here + from easybuild.tools.toolchain.toolchain import Toolchain + self.expected_type = Toolchain + self.prepend = prepend diff --git a/easybuild/tools/hooks.py b/easybuild/tools/hooks.py index 4451439856..8e95f23c01 100644 --- a/easybuild/tools/hooks.py +++ b/easybuild/tools/hooks.py @@ -32,6 +32,8 @@ import difflib import os +from easybuild.tools.entrypoints import EntrypointHook + from easybuild.base import fancylogger from easybuild.tools.build_log import EasyBuildError, print_msg from easybuild.tools.config import build_option @@ -233,12 +235,9 @@ def run_hook(label, hooks, pre_step_hook=False, post_step_hook=False, args=None, """ hook = find_hook(label, hooks, pre_step_hook=pre_step_hook, post_step_hook=post_step_hook) res = None + args = args or [] + kwargs = kwargs or {} if hook: - if args is None: - args = [] - if kwargs is None: - kwargs = {} - if pre_step_hook: label = 'pre-' + label elif post_step_hook: @@ -251,4 +250,26 @@ def run_hook(label, hooks, pre_step_hook=False, post_step_hook=False, args=None, _log.info("Running '%s' hook function (args: %s, keyword args: %s)...", hook.__name__, args, kwargs) res = hook(*args, **kwargs) + + entrypoint_hooks = EntrypointHook.get_loaded_entrypoints( + step=label, pre_step=pre_step_hook, post_step=post_step_hook + ) + if entrypoint_hooks: + msg = "Running entry point %s hook..." % label + if build_option('debug') and not build_option('silence_hook_trigger'): + print_msg(msg) + entrypoint_hooks.sort( + key=lambda x: (-x.priority, x.name), + ) + for hook in entrypoint_hooks: + _log.info( + "Running entry point '%s' hook function (args: %s, keyword args: %s)...", + hook.name, args, kwargs + ) + try: + res = hook.wrapped(*args, **kwargs) + except Exception as e: + _log.warning("Error running entry point '%s' hook: %s", hook.name, e) + raise EasyBuildError("Error running entry point '%s' hook: %s", hook.name, e) from e + return res diff --git a/easybuild/tools/options.py b/easybuild/tools/options.py index cd539cf547..aea28cc3c9 100644 --- a/easybuild/tools/options.py +++ b/easybuild/tools/options.py @@ -110,6 +110,7 @@ from easybuild.tools.systemtools import get_cpu_features, get_gpu_info, get_os_type, get_system_info from easybuild.tools.utilities import flatten from easybuild.tools.version import this_is_easybuild +from easybuild.tools.entrypoints import EntrypointHook, EntrypointEasyblock, EntrypointToolchain try: @@ -303,6 +304,9 @@ def basic_options(self): 'stop': ("Stop the installation after certain step", 'choice', 'store_or_None', EXTRACT_STEP, 's', all_stops), 'strict': ("Set strictness level", 'choice', 'store', WARN, strictness_options), + 'use-entrypoints': ( + "Use entry points for easyblocks, toolchains, and hooks", None, 'store_true', False, + ), }) self.log.debug("basic_options: descr %s opts %s" % (descr, opts)) @@ -1634,6 +1638,19 @@ def det_location(opt, prefix=''): pretty_print_opts(opts_dict) + if build_option('use_entrypoints', default=True): + for prefix, cls in [ + ('Hook', EntrypointHook), + ('Easyblock', EntrypointEasyblock), + ('Toolchain', EntrypointToolchain), + ]: + ept_list = cls.retrieve_entrypoints() + if ept_list: + print() + print("%ss from entrypoints (%d):" % (prefix, len(ept_list))) + for ept in ept_list: + print('-', ept) + def parse_options(args=None, with_include=True): """wrapper function for option parsing""" diff --git a/easybuild/tools/toolchain/toolchain.py b/easybuild/tools/toolchain/toolchain.py index d89c10b71a..312724767e 100644 --- a/easybuild/tools/toolchain/toolchain.py +++ b/easybuild/tools/toolchain/toolchain.py @@ -168,7 +168,7 @@ class Toolchain: CLASS_CONSTANTS_TO_RESTORE = None CLASS_CONSTANT_COPIES = {} - # class method + @classmethod def _is_toolchain_for(cls, name): """see if this class can provide support for toolchain named name""" # TODO report later in the initialization the found version @@ -181,8 +181,6 @@ def _is_toolchain_for(cls, name): # is no name is supplied, check whether class can be used as a toolchain return bool(getattr(cls, 'NAME', None)) - _is_toolchain_for = classmethod(_is_toolchain_for) - def __init__(self, name=None, version=None, mns=None, class_constants=None, tcdeps=None, modtool=None, hidden=False): """ diff --git a/easybuild/tools/toolchain/utilities.py b/easybuild/tools/toolchain/utilities.py index 90a0c99583..733df5ceec 100644 --- a/easybuild/tools/toolchain/utilities.py +++ b/easybuild/tools/toolchain/utilities.py @@ -40,6 +40,7 @@ import sys import easybuild.tools.toolchain +from easybuild.tools.entrypoints import EntrypointToolchain from easybuild.base import fancylogger from easybuild.tools.build_log import EasyBuildError from easybuild.tools.toolchain.toolchain import Toolchain @@ -77,6 +78,7 @@ def search_toolchain(name): # exclude the toolchain class defined in that module if not tc_mod.__file__ == sys.modules[elem.__module__].__file__: elem_name = getattr(elem, '__name__', elem) + # print(f" Adding {elem_name} to list of imported classes used for looking for constants") _log.debug("Adding %s to list of imported classes used for looking for constants", elem_name) mod_classes.append(elem) @@ -106,6 +108,15 @@ def search_toolchain(name): # obtain all subclasses of toolchain found_tcs = nub(get_subclasses(Toolchain)) + # Getting all subclasses will also include toolchains that are registered as entrypoints even if we are not + # using the `--use-entrypoints` option, so we filter them out here and re-add them later if needed. + all_eps_names = [ep.wrapped.NAME for ep in EntrypointToolchain.get_loaded_entrypoints()] + found_tcs = [x for x in found_tcs if x.NAME not in all_eps_names] + + prepend_eps = [_.wrapped for _ in EntrypointToolchain.get_loaded_entrypoints(prepend=True)] + append_eps = [_.wrapped for _ in EntrypointToolchain.get_loaded_entrypoints(prepend=False)] + found_tcs = prepend_eps + found_tcs + append_eps + # filter found toolchain subclasses based on whether they can be used a toolchains found_tcs = [tc for tc in found_tcs if tc._is_toolchain_for(None)] diff --git a/test/framework/entrypoints.py b/test/framework/entrypoints.py new file mode 100644 index 0000000000..ce925b9365 --- /dev/null +++ b/test/framework/entrypoints.py @@ -0,0 +1,488 @@ +# # +# Copyright 2013-2025 Ghent University +# +# This file is part of EasyBuild, +# originally created by the HPC team of Ghent University (http://ugent.be/hpc/en), +# with support of Ghent University (http://ugent.be/hpc), +# the Flemish Supercomputer Centre (VSC) (https://www.vscentrum.be), +# Flemish Research Foundation (FWO) (http://www.fwo.be/en) +# and the Department of Economy, Science and Innovation (EWI) (http://www.ewi-vlaanderen.be/en). +# +# https://github.com/easybuilders/easybuild +# +# EasyBuild is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation v2. +# +# EasyBuild is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with EasyBuild. If not, see . +# # +""" +Unit tests for EasyBuild configuration. + +@author: Davide Grassano (CECAM - EPFL) +""" + +import os +import shutil +import sys +import tempfile +from importlib import reload +from test.framework.utilities import EnhancedTestCase, TestLoaderFiltered, init_config +from unittest import TextTestRunner + +import easybuild.tools.options as eboptions +from easybuild.tools.build_log import EasyBuildError +from easybuild.tools.docs import list_easyblocks, list_toolchains +from easybuild.tools.entrypoints import ( + HAVE_ENTRY_POINTS, EntrypointHook, EntrypointEasyblock, EntrypointToolchain, EasybuildEntrypoint +) +from easybuild.tools.filetools import write_file +from easybuild.tools.hooks import run_hook, START, CONFIGURE_STEP +from easybuild.framework.easyconfig.easyconfig import get_easyblock_class + + +if HAVE_ENTRY_POINTS: + from importlib.metadata import DistributionFinder, Distribution +else: + DistributionFinder = object + Distribution = object + + +MOCK_HOOK_EP_NAME = "mock_hook" +MOCK_EASYBLOCK_EP_NAME = "mock_easyblock" +MOCK_TOOLCHAIN_EP_NAME = "mock_toolchain" + +MOCK_HOOK = "hello_world_12412412" +MOCK_EASYBLOCK = "TestEasyBlock_1212461" +MOCK_TOOLCHAIN = "MockTc_352124671346" + + +MOCK_EP_FILE = f""" +from easybuild.tools.entrypoints import EntrypointHook +from easybuild.tools.hooks import CONFIGURE_STEP, START + + +@EntrypointHook(START) +def {MOCK_HOOK}(): + print("Hello, World! ----------------------------------------") + +def {MOCK_HOOK}_invalid(): + print("This hook should not be registered, as it is invalid.") + +########################################################################## +from easybuild.framework.easyblock import EasyBlock +from easybuild.tools.entrypoints import EntrypointEasyblock + +@EntrypointEasyblock() +class {MOCK_EASYBLOCK}(EasyBlock): + def configure_step(self): + print("{MOCK_EASYBLOCK}: configure_step called.") + + def build_step(self): + print("{MOCK_EASYBLOCK}: build_step called.") + + def install_step(self): + print("{MOCK_EASYBLOCK}: install_step called.") + + def sanity_check_step(self): + print("{MOCK_EASYBLOCK}: sanity_check_step called.") + +class {MOCK_EASYBLOCK}_invalid(EasyBlock): + pass + +########################################################################## +from easybuild.tools.entrypoints import EntrypointToolchain +from easybuild.tools.toolchain.compiler import DEFAULT_OPT_LEVEL, Compiler +from easybuild.tools.toolchain.toolchain import SYSTEM_TOOLCHAIN_NAME + +TC_CONSTANT_MOCK = "Mock" + +class MockCompiler(Compiler): + COMPILER_FAMILY = TC_CONSTANT_MOCK + SUBTOOLCHAIN = SYSTEM_TOOLCHAIN_NAME + +@EntrypointToolchain() +class {MOCK_TOOLCHAIN}(MockCompiler): + NAME = '{MOCK_TOOLCHAIN}' # Using `...tc` to distinguish toolchain from package + COMPILER_MODULE_NAME = [NAME] + SUBTOOLCHAIN = [SYSTEM_TOOLCHAIN_NAME] + +class {MOCK_TOOLCHAIN}_invalid(MockCompiler): + pass +""" + + +MOCK_EP_META_FILE = f""" +[{EntrypointHook.group}] +{MOCK_HOOK_EP_NAME} = {{module}}:{MOCK_HOOK} +{{invalid_hook}} + +[{EntrypointEasyblock.group}] +{MOCK_EASYBLOCK_EP_NAME} = {{module}}:{MOCK_EASYBLOCK} +{{invalid_easyblock}} + +[{EntrypointToolchain.group}] +{MOCK_TOOLCHAIN_EP_NAME} = {{module}}:{MOCK_TOOLCHAIN} +{{invalid_toolchain}} +""" + +FORMAT_DCT = { + 'invalid_hook': '', + 'invalid_easyblock': '', + 'invalid_toolchain': '', +} + + +class MockDistribution(Distribution): + """Mock distribution for testing entry points.""" + def __init__(self, module): + self.module = module + + def read_text(self, filename): + if filename == "entry_points.txt": + return MOCK_EP_META_FILE.format(module=self.module, **FORMAT_DCT) + + if filename == "METADATA": + return "Name: mock_hook\nVersion: 0.1.0\n" + + +class MockDistributionFinder(DistributionFinder): + """Mock distribution finder for testing entry points.""" + def __init__(self, *args, module, **kwargs): + super().__init__(*args, **kwargs) + self.module = module + + def find_distributions(self, context=None): + yield MockDistribution(self.module) + + +class EasyBuildEntrypointsTest(EnhancedTestCase): + """Test cases for EasyBuild configuration.""" + + tmpdir = None + + def _run_mock_eb(self, args, strip=False, **kwargs): + """Helper function to mock easybuild runs + + Return (stdout, stderr) optionally stripped of whitespace at start/end + """ + with self.mocked_stdout_stderr() as (stdout, stderr): + self.eb_main(args, **kwargs) + stdout_txt = stdout.getvalue() + stderr_txt = stderr.getvalue() + if strip: + stdout_txt = stdout_txt.strip() + stderr_txt = stderr_txt.strip() + return stdout_txt, stderr_txt + + def setUp(self): + """Set up the test environment.""" + global FORMAT_DCT + + FORMAT_DCT = { + 'invalid_hook': '', + 'invalid_easyblock': '', + 'invalid_toolchain': '', + } + + reload(eboptions) + super().setUp() + self.tmpdir = tempfile.mkdtemp(prefix='easybuild_test_') + + if HAVE_ENTRY_POINTS: + filename_root = "mock" + dirname, dirpath = os.path.split(self.tmpdir) + + self.module = '.'.join([dirpath, filename_root]) + sys.path.insert(0, dirname) + sys.meta_path.insert(0, MockDistributionFinder(module=self.module)) + + # Create a mock entry point for testing + self.mock_hook_file = os.path.join(self.tmpdir, f'{filename_root}.py') + write_file(self.mock_hook_file, MOCK_EP_FILE) + else: + self.skipTest("Entry points not available in this Python version") + + def tearDown(self): + """Clean up the test environment.""" + super().tearDown() + + try: + shutil.rmtree(self.tmpdir) + except OSError: + pass + tempfile.tempdir = None + + if HAVE_ENTRY_POINTS: + # Remove the entry point from the working set + dirname, _ = os.path.split(self.tmpdir) + if dirname in sys.path: + sys.path.remove(dirname) + torm = [] + for idx, cls in enumerate(sys.meta_path): + if isinstance(cls, MockDistributionFinder): + torm.append(idx) + for idx in reversed(torm): + del sys.meta_path[idx] + + EntrypointHook.clear() + + def test_entrypoints_baseclass_raises(self): + """Test that attempting to register an entry point with the base class raises an error.""" + with self.assertRaises(EasyBuildError): + EasybuildEntrypoint()(lambda: None) + + def test_entrypoints_register_hook(self): + """Test registering entry point hooks with both valid and invalid hook names.""" + # Dummy function + def func(): + return + + # Invalid step name + with self.assertRaises(EasyBuildError): + EntrypointHook('123')(func) + + # Valid name but invalid combination of step and pre/post + with self.assertRaises(EasyBuildError): + EntrypointHook(START, pre_step=True)(func) + + # Valid hook registration + EntrypointHook(START)(func) + + def test_entrypoints_register_easyblock(self): + """Test registering entry point easyblocks with both valid and invalid easyblock names.""" + from easybuild.framework.easyblock import EasyBlock + decorator = EntrypointEasyblock() + + with self.assertRaises(EasyBuildError): + decorator(123) + + class MOCK(): + pass + with self.assertRaises(EasyBuildError): + decorator(MOCK) + + class MOCK(EasyBlock): + pass + decorator(MOCK) + + def test_entrypoints_register_toolchain(self): + """Test registering entry point toolchains with both valid and invalid toolchain names.""" + from easybuild.tools.toolchain.toolchain import Toolchain + decorator = EntrypointToolchain() + + with self.assertRaises(EasyBuildError): + decorator(123) + + class MOCK(): + pass + with self.assertRaises(EasyBuildError): + decorator(MOCK) + + class MOCK(Toolchain): + pass + decorator(MOCK) + + def test_entrypoints_get_group(self): + """Test retrieving entrypoints for a specific group.""" + expected = { + EntrypointHook: MOCK_HOOK_EP_NAME, + EntrypointEasyblock: MOCK_EASYBLOCK_EP_NAME, + EntrypointToolchain: MOCK_TOOLCHAIN_EP_NAME, + } + + for ep_type in [EntrypointHook, EntrypointEasyblock, EntrypointToolchain]: + group = ep_type.group + epts = ep_type.retrieve_entrypoints() + self.assertIsInstance(epts, set, f"Expected set for group {group}") + self.assertEqual(len(epts), 0, f"Expected non-empty set for group {group}") + + init_config(build_options={'use_entrypoints': True}) + for ep_type in [EntrypointHook, EntrypointEasyblock, EntrypointToolchain]: + group = ep_type.group + epts = ep_type.retrieve_entrypoints() + self.assertIsInstance(epts, set, f"Expected set for group {group}") + self.assertGreater(len(epts), 0, f"Expected non-empty set for group {group}") + + loaded_names = [ep.name for ep in epts] + expt = expected[ep_type] + self.assertIn(expt, loaded_names, f"Expected entry point {expt} in group {group}") + + def test_entrypoints_exclude_invalid(self): + """Check that invalid entry points are excluded from the get_entrypoints function.""" + init_config(build_options={'use_entrypoints': True}) + + # Check that the invalid hook is not registered + + FORMAT_DCT['invalid_hook'] = f"{MOCK_HOOK_EP_NAME}_invalid = {self.module}:{MOCK_HOOK}_invalid" + FORMAT_DCT['invalid_easyblock'] = f"{MOCK_EASYBLOCK_EP_NAME}_invalid = {self.module}:{MOCK_EASYBLOCK}_invalid" + FORMAT_DCT['invalid_toolchain'] = f"{MOCK_TOOLCHAIN_EP_NAME}_invalid = {self.module}:{MOCK_TOOLCHAIN}_invalid" + + hooks = EntrypointHook.get_loaded_entrypoints() + self.assertNotIn( + MOCK_HOOK + '_invalid', [ep.name for ep in hooks], "Invalid hook should not be registered" + ) + + # Check that the invalid easyblock is not registered + easyblocks = EntrypointEasyblock.get_loaded_entrypoints() + self.assertNotIn( + MOCK_EASYBLOCK + '_invalid', [ep.name for ep in easyblocks], "Invalid easyblock should not be registered" + ) + + # Check that the invalid toolchain is not registered + toolchains = EntrypointToolchain.get_loaded_entrypoints() + self.assertNotIn( + MOCK_TOOLCHAIN + '_invalid', [ep.name for ep in toolchains], "Invalid toolchain should not be registered" + ) + + def test_entrypoints_list_easyblocks(self): + """ + Tests for list_easyblocks function with entry points enabled. + """ + # Invalid EBs are still picked up as subclasses of EasyBlock, difficult to exclude them from this behavior + # txt = list_easyblocks() + # self.assertNotIn("TestEasyBlock", txt, "TestEasyBlock should not be listed without entry points enabled") + + init_config(build_options={'use_entrypoints': True}) + txt = list_easyblocks() + self.assertIn("TestEasyBlock", txt, "TestEasyBlock should be listed with entry points enabled") + + def test_entrypoints_list_toolchains(self): + """ + Tests for list_toolchains function with entry points enabled. + """ + # Invalid TCs are still picked up as subclasses of Toolchain, difficult to exclude them from this behavior + # txt = list_toolchains() + # self.assertNotIn(MOCK_TOOLCHAIN, txt, f"{MOCK_TOOLCHAIN} should not be listed without entry points enabled") + + init_config(build_options={'use_entrypoints': True}) + + txt = list_toolchains() + self.assertIn(MOCK_TOOLCHAIN, txt, f"{MOCK_TOOLCHAIN} should be listed with entry points enabled") + + def test_entrypoints_get_easyblock_class(self): + """ + Tests for get_easyblock_class function with entry points enabled. + """ + with self.assertRaises(EasyBuildError): + get_easyblock_class(MOCK_EASYBLOCK) + # self.assertIn('.generic.', module_path, "Module path should contain '.generic.'") + + init_config(build_options={'use_entrypoints': True}) + # Reload the EasyBlock module to ensure it is recognized + cls = get_easyblock_class(MOCK_EASYBLOCK) + self.assertEqual(cls.__module__, self.module, "Module path should match the mock module path") + + def test_entrypoints_show_config(self): + """Test that showing configuration includes entry points.""" + args = ['--show-config'] + stdout, stderr = self._run_mock_eb(args, strip=True) + + for name in ['Hooks', 'Easyblocks', 'Toolchains']: + pattern = f"{name} from entrypoints (" + self.assertIn(pattern, stdout, f"Expected {name} in configuration output") + + args = ['--show-full-config'] + stdout, stderr = self._run_mock_eb(args, strip=True) + + for name in ['Hooks', 'Easyblocks', 'Toolchains']: + pattern = f"{name} from entrypoints (" + self.assertIn(pattern, stdout, f"Expected {name} in configuration output") + + def test_entrypoints_register_invalid_hook(self): + """Test that registering an invalid hook steps raises an error.""" + # Invalid step name + with self.assertRaises(EasyBuildError): + EntrypointHook('invalid_hook_name')(lambda: None) + + # START does not have the pre/post prefixes + with self.assertRaises(EasyBuildError): + EntrypointHook(START, pre_step=True)(lambda: None) + + # CONFIGURE_STEP must have a pre/post prefix + with self.assertRaises(EasyBuildError): + EntrypointHook(CONFIGURE_STEP)(lambda: None) + + def test_entrypoints_run_hook(self): + """Ensure that entry point hooks are run in the correct order.""" + cnt = 0 + + @EntrypointHook(START, priority=50) + def func2_2(): + nonlocal cnt + self.assertEqual(cnt, 2, "This hook should be run third because of name ordering") + cnt += 1 + + @EntrypointHook(START, priority=50) + def func2_1(): + nonlocal cnt + self.assertEqual(cnt, 1, "This hook should be run second because of name ordering") + cnt += 1 + + @EntrypointHook(START, priority=10) + def func3(): + nonlocal cnt + self.assertEqual(cnt, 3, "This hook should be run last") + cnt += 1 + + @EntrypointHook(START, priority=100) + def func1(): + nonlocal cnt + self.assertEqual(cnt, 0, "This hook should be run first") + cnt += 1 + + @EntrypointHook(CONFIGURE_STEP, pre_step=True) + def func_configure_pre(): + nonlocal cnt + self.assertEqual(cnt, 4, "This hook should be run after all START hooks") + cnt += 1 + + run_hook(START, {}) + + self.assertEqual(cnt, 4, "All hooks should have been run in the correct order") + + def test_entrypoints_run_hook_onlyreq(self): + """Ensure that only the hooks required for a step are run.""" + tpl_flags = {'start': False, 'pre_cfg': False, 'post_cfg': False} + + @EntrypointHook(START) + def func_start(): + flags['start'] = True + + @EntrypointHook(CONFIGURE_STEP, pre_step=True) + def func_configure_pre(): + flags['pre_cfg'] = True + + @EntrypointHook(CONFIGURE_STEP, post_step=True) + def func_configure_post(): + flags['post_cfg'] = True + + flags = tpl_flags.copy() + run_hook(START, {}) + for key, val in flags.items(): + self.assertEqual(val, key == 'start', "Should only run START hooks") + + flags = tpl_flags.copy() + run_hook(CONFIGURE_STEP, {}, pre_step_hook=True) + for key, val in flags.items(): + self.assertEqual(val, key == 'pre_cfg', "Should only run pre-configure hooks") + + flags = tpl_flags.copy() + run_hook(CONFIGURE_STEP, {}, post_step_hook=True) + for key, val in flags.items(): + self.assertEqual(val, key == 'post_cfg', "Should only run post-configure hooks") + + +def suite(): + return TestLoaderFiltered().loadTestsFromTestCase(EasyBuildEntrypointsTest, sys.argv[1:]) + + +if __name__ == '__main__': + res = TextTestRunner(verbosity=1).run(suite()) + sys.exit(len(res.failures)) diff --git a/test/framework/suite.py b/test/framework/suite.py index afec127c83..28de8ee56e 100755 --- a/test/framework/suite.py +++ b/test/framework/suite.py @@ -52,6 +52,7 @@ import test.framework.easyconfigversion as ev import test.framework.easystack as es import test.framework.ebconfigobj as ebco +import test.framework.entrypoints as epts import test.framework.environment as env import test.framework.docs as d import test.framework.filetools as f @@ -119,7 +120,7 @@ # call suite() for each module and then run them all # note: make sure the options unit tests run first, to avoid running some of them with a readily initialized config -tests = [gen, d, bl, o, r, ef, ev, ebco, ep, e, mg, m, mt, f, run, a, robot, b, v, g, tcv, tc, t, c, s, lic, f_c, +tests = [gen, d, bl, o, r, ef, ev, ebco, ep, epts, e, mg, m, mt, f, run, a, robot, b, v, g, tcv, tc, t, c, s, lic, f_c, tw, p, i, pkg, env, et, st, h, ct, lib, u, es, ou] SUITE = unittest.TestSuite([x.suite() for x in tests])