diff --git a/pymc_marketing/pytensor_utils.py b/pymc_marketing/pytensor_utils.py index c527fae5..dad40ece 100644 --- a/pymc_marketing/pytensor_utils.py +++ b/pymc_marketing/pytensor_utils.py @@ -15,6 +15,7 @@ """PyTensor utility functions.""" import arviz as az +import pandas as pd import pytensor import pytensor.tensor as pt from arviz import InferenceData @@ -239,3 +240,271 @@ def extract_response_distribution( ) return response_distribution + + +class ModelSamplerEstimator: + """Estimate computational characteristics of a PyMC model using JAX/NumPyro. + + This utility measures the average evaluation time of the model's logp and gradients + and estimates the number of integrator steps taken by NUTS during warmup + sampling. + It then compiles the information into a single-row pandas DataFrame with helpful + metadata to guide planning and benchmarking. + + Parameters + ---------- + tune : int, default 1000 + Number of warmup iterations to use when estimating NUTS steps. + draws : int, default 1000 + Number of sampling iterations to use when estimating NUTS steps. + chains : int, default 1 + Intended number of chains (metadata only; not used in JAX runs here). + sequential_chains : int, default 1 + Number of chains expected to run sequentially on the target environment. + Used to scale the wall-clock time estimate. + seed : int | None, default None + Random seed used for the step estimation runs. + + Examples + -------- + .. code-block:: python + + est = ModelSamplerEstimator( + tune=1000, draws=1000, chains=4, sequential_chains=1, seed=1 + ) + df = est.run(model) + print(df) + """ + + def __init__( + self, + *, + tune: int = 1000, + draws: int = 1000, + chains: int = 1, + sequential_chains: int = 1, + seed: int | None = None, + ) -> None: + self.tune = int(tune) + self.draws = int(draws) + self.chains = int(chains) + self.sequential_chains = int(sequential_chains) + self.seed = seed + + @property + def default_nuts_kwargs(self) -> dict: + """Default keyword arguments for a NumPyro NUTS kernel. + + Mirrors the current hard-coded defaults used in this estimator. + """ + return { + "target_accept_prob": 0.8, + "adapt_step_size": True, + "adapt_mass_matrix": True, + "dense_mass": False, + } + + @property + def default_mcmc_kwargs(self) -> dict: + """Default keyword arguments for a NumPyro MCMC runner. + + Parameters that depend on the run size (``num_warmup`` and ``num_samples``) + are intentionally excluded and provided explicitly by the estimator. + """ + return { + "num_chains": 1, + "postprocess_fn": None, + "chain_method": "sequential", + "progress_bar": False, + } + + def estimate_model_eval_time( + self, model: Model, num_evaluations: int | None = None + ) -> float: + """Estimate average evaluation time (seconds) of logp+dlogp using JAX. + + Parameters + ---------- + model : Model + PyMC model whose logp and gradients are jitted and evaluated. + num_evaluations : int | None, optional + Number of repeated evaluations to average over. If ``None``, a value + is chosen to take roughly 5 seconds in total for a stable estimate. + + Returns + ------- + float + Average evaluation time in seconds. + """ + from time import perf_counter_ns + + import numpy as np + + try: + import jax + from pymc.sampling.jax import get_jaxified_logp + except Exception as err: # pragma: no cover - environment specific + raise RuntimeError( + "JAX backend is required for ModelSamplerEstimator." + ) from err + + initial_point = list(model.initial_point().values()) + logp_fn = get_jaxified_logp(model) + logp_dlogp_fn = jax.jit(jax.value_and_grad(logp_fn, argnums=0)) + logp_res, grad_res = logp_dlogp_fn(initial_point) + for val in (logp_res, *grad_res): + if not np.isfinite(val).all(): + raise RuntimeError( + "logp or gradients are not finite at the model initial point; the model may be misspecified" + ) + + if num_evaluations is None: + start = perf_counter_ns() + jax.block_until_ready(logp_dlogp_fn(initial_point)) + end = perf_counter_ns() + num_evaluations = max(5, int(5e9 / max(end - start, 1))) + + start = perf_counter_ns() + for _ in range(num_evaluations): + jax.block_until_ready(logp_dlogp_fn(initial_point)) + end = perf_counter_ns() + eval_time = (end - start) / num_evaluations * 1e-9 + return float(eval_time) + + def estimate_num_steps_sampling( + self, + model: Model, + *, + tune: int | None = None, + draws: int | None = None, + seed: int | None = None, + nuts_kwargs: dict | None = None, + mcmc_kwargs: dict | None = None, + ) -> int: + """Estimate total number of NUTS steps during warmup + sampling using NumPyro. + + Parameters + ---------- + model : Model + PyMC model to estimate steps for using a JAX/NumPyro NUTS kernel. + tune : int | None, optional + Warmup iterations. Defaults to the estimator setting if ``None``. + draws : int | None, optional + Sampling iterations. Defaults to the estimator setting if ``None``. + seed : int | None, optional + Random seed for the JAX run. Defaults to the estimator setting if ``None``. + nuts_kwargs : dict | None, optional + Additional keyword arguments passed to ``numpyro.infer.NUTS``. If not provided, + the estimator's ``default_nuts_kwargs`` are used. Provided values override + the defaults. + mcmc_kwargs : dict | None, optional + Additional keyword arguments passed to ``numpyro.infer.MCMC`` (excluding + ``num_warmup`` and ``num_samples``, which are set by ``tune``/``draws``). If not + provided, the estimator's ``default_mcmc_kwargs`` are used. Provided values + override the defaults. + + Returns + ------- + int + Total number of leapfrog steps across warmup + sampling. + """ + import numpy as np + + try: + import jax + from numpyro.infer import MCMC, NUTS + from pymc.sampling.jax import get_jaxified_logp + except Exception as err: # pragma: no cover - environment specific + raise RuntimeError( + "JAX and NumPyro are required for ModelSamplerEstimator." + ) from err + + num_warmup = int(self.tune if tune is None else tune) + num_samples = int(self.draws if draws is None else draws) + + initial_point = list(model.initial_point().values()) + logp_fn = get_jaxified_logp(model, negative_logp=False) + merged_nuts_kwargs = {**self.default_nuts_kwargs, **(nuts_kwargs or {})} + nuts_kernel = NUTS( + potential_fn=logp_fn, + **merged_nuts_kwargs, + ) + + merged_mcmc_kwargs = {**self.default_mcmc_kwargs, **(mcmc_kwargs or {})} + mcmc = MCMC( + nuts_kernel, + num_warmup=num_warmup, + num_samples=num_samples, + **merged_mcmc_kwargs, + ) + + if seed is None: + rng_seed = int(np.random.default_rng().integers(2**32)) + else: + rng_seed = int(seed) + + tune_rng, sample_rng = jax.random.split(jax.random.PRNGKey(int(rng_seed)), 2) + mcmc.warmup( + tune_rng, + init_params=initial_point, + extra_fields=("num_steps",), + collect_warmup=True, + ) + warmup_steps = int(mcmc.get_extra_fields()["num_steps"].sum()) + mcmc.run(sample_rng, extra_fields=("num_steps",)) + sample_steps = int(mcmc.get_extra_fields()["num_steps"].sum()) + return int(warmup_steps + sample_steps) + + def run(self, model: Model) -> pd.DataFrame: + """Execute the estimation pipeline and return a single-row DataFrame. + + Parameters + ---------- + model : Model + PyMC model to evaluate. + + Returns + ------- + pandas.DataFrame + Single-row DataFrame with columns including ``num_steps``, ``eval_time_seconds``, + ``sequential_chains``, and estimated sampling wall-clock time in seconds, + minutes, and hours, along with metadata such as ``tune``, ``draws``, ``chains``, + ``seed``, ``timestamp``, and ``model_name``. + + Examples + -------- + .. code-block:: python + + df = ModelSamplerEstimator().run(model) + df[ + [ + "num_steps", + "eval_time_seconds", + "estimated_sampling_time_minutes", + ] + ] + """ + import time + + steps = self.estimate_num_steps_sampling( + model, tune=self.tune, draws=self.draws, seed=self.seed + ) + eval_time_s = self.estimate_model_eval_time(model) + + sampling_time_seconds = float( + eval_time_s * steps * max(self.sequential_chains, 1) + ) + data = { + "model_name": getattr(model, "name", "PyMCModel"), + "num_steps": int(steps), + "eval_time_seconds": float(eval_time_s), + "sequential_chains": int(self.sequential_chains), + "estimated_sampling_time_seconds": sampling_time_seconds, + "estimated_sampling_time_minutes": sampling_time_seconds / 60.0, + "estimated_sampling_time_hours": sampling_time_seconds / 3600.0, + "tune": int(self.tune), + "draws": int(self.draws), + "chains": int(self.chains), + "seed": int(self.seed) if self.seed is not None else None, + "timestamp": pd.Timestamp.utcfromtimestamp(int(time.time())), + } + return pd.DataFrame([data]) diff --git a/tests/test_pytensor_utils.py b/tests/test_pytensor_utils.py index 7335bae2..4541d8b6 100644 --- a/tests/test_pytensor_utils.py +++ b/tests/test_pytensor_utils.py @@ -14,6 +14,9 @@ """Tests for pytensor_utils module.""" +import builtins as _builtins +import sys + import numpy as np import pandas as pd import pymc as pm @@ -28,7 +31,11 @@ MMM, MultiDimensionalBudgetOptimizerWrapper, ) -from pymc_marketing.pytensor_utils import _prefix_model, merge_models +from pymc_marketing.pytensor_utils import ( + ModelSamplerEstimator, + _prefix_model, + merge_models, +) @pytest.fixture @@ -321,6 +328,103 @@ def test_extract_response_distribution_vs_sample_response( print("\n✓ Both methods produce consistent results!") +@pytest.mark.parametrize("draws, tune", [(50, 50)]) +def test_model_sampler_estimator_with_simple_model(monkeypatch, draws, tune): + """Smoke test for ModelSamplerEstimator on a tiny PyMC model.""" + + with pm.Model() as model: + mu = pm.Normal("mu", 0.0, 1.0) + sigma = pm.HalfNormal("sigma", 1.0) + pm.Normal("y", mu=mu, sigma=sigma, observed=np.random.normal(0, 1, size=10)) + + est = ModelSamplerEstimator( + tune=tune, draws=draws, chains=2, sequential_chains=1, seed=123 + ) + df = est.run(model) + + # Check schema and basic values + expected_columns = { + "model_name", + "num_steps", + "eval_time_seconds", + "sequential_chains", + "estimated_sampling_time_seconds", + "estimated_sampling_time_minutes", + "estimated_sampling_time_hours", + "tune", + "draws", + "chains", + "seed", + "timestamp", + } + assert set(df.columns) >= expected_columns + + +def test_model_sampler_estimator_eval_time_multidim_model(fitted_multidim_mmm): + """Measure eval time for a fitted multidimensional MMM's PyMC model.""" + pm_model = fitted_multidim_mmm.model + + est = ModelSamplerEstimator( + tune=50, draws=50, chains=1, sequential_chains=1, seed=123 + ) + + steps = est.estimate_num_steps_sampling(pm_model) + assert steps > 0, f"Expected positive number of steps, got {steps}" + assert isinstance(steps, int), f"Expected steps to be an integer, got {type(steps)}" + + est_df = est.run(pm_model) + + # Check schema and basic values + expected_columns = { + "model_name", + "num_steps", + "eval_time_seconds", + "sequential_chains", + "estimated_sampling_time_seconds", + "estimated_sampling_time_minutes", + "estimated_sampling_time_hours", + "tune", + "draws", + "chains", + "seed", + "timestamp", + } + assert set(est_df.columns) >= expected_columns, ( + f"Expected columns {expected_columns} not found in {set(est_df.columns)}" + ) + + assert isinstance(est.default_mcmc_kwargs, dict), ( + f"Expected default_mcmc_kwargs to be a dict, got {type(est.default_mcmc_kwargs)}" + ) + assert isinstance(est.default_nuts_kwargs, dict), ( + f"Expected default_nuts_kwargs to be a dict, got {type(est.default_nuts_kwargs)}" + ) + + +def test_jax_numpyro_not_available(monkeypatch, fitted_multidim_mmm): + """Ensure estimate_model_eval_time raises when JAX is unavailable.""" + # Simulate that JAX (and pymc's jax helper) are not importable + monkeypatch.delitem(sys.modules, "jax", raising=False) + monkeypatch.delitem(sys.modules, "pymc.sampling.jax", raising=False) + + real_import = _builtins.__import__ + + def _blocked_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "jax" or name.startswith("pymc.sampling.jax"): + raise ImportError("blocked for test") + return real_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(_builtins, "__import__", _blocked_import) + + with pytest.raises( + RuntimeError, match="JAX backend is required for ModelSamplerEstimator" + ): + est = ModelSamplerEstimator( + tune=50, draws=50, chains=1, sequential_chains=1, seed=123 + ) + est.estimate_model_eval_time(fitted_multidim_mmm.model) + + def test_merge_models_prefix_and_merge_on_channel_data( fitted_multidim_mmm, sample_multidim_data ):