234 changes: 234 additions & 0 deletions pymc_marketing/pytensor_utils.py
@@ -15,6 +15,7 @@
"""PyTensor utility functions."""

import arviz as az
import pandas as pd
import pytensor.tensor as pt
from arviz import InferenceData
from pymc import Model
@@ -106,3 +107,236 @@ def extract_response_distribution(
)

return response_distribution


class ModelSamplerEstimator:
Collaborator: The name is confusing because it is not an estimator, right? Also, shouldn't the name contain something referring to JAX or NumPyro?

"""Estimate computational characteristics of a PyMC model using JAX/NumPyro.
This utility measures the average evaluation time of the model's logp and gradients
and estimates the number of integrator steps taken by NUTS during warmup + sampling.
It then compiles the information into a single-row pandas DataFrame with helpful
metadata to guide planning and benchmarking.
Parameters
----------
tune : int, default 1000
Number of warmup iterations to use when estimating NUTS steps.
draws : int, default 1000
Number of sampling iterations to use when estimating NUTS steps.
chains : int, default 1
Intended number of chains (metadata only; not used in JAX runs here).
sequential_chains : int, default 1
Number of chains expected to run sequentially on the target environment.
Used to scale the wall-clock time estimate.
seed : int | None, default None
Random seed used for the step estimation runs.
Examples
--------
.. code-block:: python
est = ModelSamplerEstimator(
tune=1000, draws=1000, chains=4, sequential_chains=1, seed=1
)
df = est.run(model)
print(df)
"""

def __init__(
self,
*,
tune: int = 1000,
draws: int = 1000,
chains: int = 1,
sequential_chains: int = 1,
seed: int | None = None,
) -> None:
self.tune = int(tune)
self.draws = int(draws)
self.chains = int(chains)
self.sequential_chains = int(sequential_chains)
self.seed = seed

def estimate_model_eval_time(self, model: Model, n: int | None = None) -> float:
Collaborator: Can we give n a more descriptive name?

"""Estimate average evaluation time (seconds) of logp+dlogp using JAX.
Parameters
----------
model : Model
PyMC model whose logp and gradients are jitted and evaluated.
n : int | None, optional
Number of repeated evaluations to average over. If ``None``, a value
is chosen to take roughly 5 seconds in total for a stable estimate.
Returns
-------
float
Average evaluation time in seconds.
"""
from time import perf_counter_ns

import numpy as np

try:
import jax
from pymc.sampling.jax import get_jaxified_logp
except Exception as err: # pragma: no cover - environment specific
raise RuntimeError(
"JAX backend is required for ModelSamplerEstimator."
) from err

initial_point = list(model.initial_point().values())
logp_fn = get_jaxified_logp(model)
logp_dlogp_fn = jax.jit(jax.value_and_grad(logp_fn, argnums=0))
logp_res, grad_res = logp_dlogp_fn(initial_point)
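# Sanity check: non-finite logp or gradients at the initial point would make the timing runs meaningless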
for val in (logp_res, *grad_res):
if not np.isfinite(val).all():
raise RuntimeError(
"logp or gradients are not finite at the model initial point; the model may be misspecified"
)

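# If no evaluation count was given, pick one so the timing loop below runs for roughly 5 seconds in total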
if n is None:
start = perf_counter_ns()
jax.block_until_ready(logp_dlogp_fn(initial_point))
end = perf_counter_ns()
n = max(5, int(5e9 / max(end - start, 1)))

start = perf_counter_ns()
for _ in range(n):
jax.block_until_ready(logp_dlogp_fn(initial_point))
end = perf_counter_ns()
eval_time = (end - start) / n * 1e-9
return float(eval_time)

def estimate_num_steps_sampling(
self,
model: Model,
*,
tune: int | None = None,
draws: int | None = None,
seed: int | None = None,
) -> int:
"""Estimate total number of NUTS steps during warmup + sampling using NumPyro.
Parameters
----------
model : Model
PyMC model to estimate steps for using a JAX/NumPyro NUTS kernel.
tune : int | None, optional
Warmup iterations. Defaults to the estimator setting if ``None``.
draws : int | None, optional
Sampling iterations. Defaults to the estimator setting if ``None``.
seed : int | None, optional
Random seed for the JAX run. Defaults to the estimator setting if ``None``.
Returns
-------
int
Total number of leapfrog steps across warmup + sampling.
"""
import numpy as np

try:
import jax
from numpyro.infer import MCMC, NUTS
from pymc.sampling.jax import get_jaxified_logp
except Exception as err: # pragma: no cover - environment specific
raise RuntimeError(
"JAX and NumPyro are required for ModelSamplerEstimator."
) from err

num_warmup = int(self.tune if tune is None else tune)
num_samples = int(self.draws if draws is None else draws)

initial_point = list(model.initial_point().values())
logp_fn = get_jaxified_logp(model, negative_logp=False)
nuts_kernel = NUTS(
potential_fn=logp_fn,
target_accept_prob=0.8,
adapt_step_size=True,
adapt_mass_matrix=True,
dense_mass=False,
Comment on lines +255 to +258
Collaborator: Can we pass these as kwargs? (One possible shape is sketched after this method.)

)

mcmc = MCMC(
nuts_kernel,
num_warmup=num_warmup,
num_samples=num_samples,
num_chains=1,
postprocess_fn=None,
chain_method="sequential",
progress_bar=False,
Comment on lines +266 to +268
Collaborator: Same here; can these also be passed as kwargs? (Covered by the same sketch below.)

)

if seed is None:
rng_seed = int(np.random.default_rng().integers(2**32))
else:
rng_seed = int(seed)

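# Run warmup and sampling separately so the leapfrog step counts ("num_steps") can be summed per phase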
tune_rng, sample_rng = jax.random.split(jax.random.PRNGKey(int(rng_seed)), 2)
mcmc.warmup(
tune_rng,
init_params=initial_point,
extra_fields=("num_steps",),
collect_warmup=True,
)
warmup_steps = int(mcmc.get_extra_fields()["num_steps"].sum())
mcmc.run(sample_rng, extra_fields=("num_steps",))
sample_steps = int(mcmc.get_extra_fields()["num_steps"].sum())
return int(warmup_steps + sample_steps)
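One possible shape for the kwargs suggestion in the two review comments above; a self-contained sketch in which build_mcmc, nuts_kwargs, and mcmc_kwargs are hypothetical names, not part of this diff:

import jax
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS


def build_mcmc(potential_fn, num_warmup, num_samples, nuts_kwargs=None, mcmc_kwargs=None):
    # Merge caller overrides into the defaults that are currently hard-coded above
    nuts_defaults = dict(
        target_accept_prob=0.8,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        dense_mass=False,
    )
    mcmc_defaults = dict(num_chains=1, chain_method="sequential", progress_bar=False)
    kernel = NUTS(potential_fn=potential_fn, **{**nuts_defaults, **(nuts_kwargs or {})})
    return MCMC(
        kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        **{**mcmc_defaults, **(mcmc_kwargs or {})},
    )


def toy_potential(position):
    # Stand-in for the jaxified logp: a standard-normal potential over a list of arrays
    return 0.5 * jnp.sum(position[0] ** 2)


mcmc = build_mcmc(toy_potential, num_warmup=10, num_samples=10, nuts_kwargs={"dense_mass": True})
mcmc.run(jax.random.PRNGKey(0), init_params=[jnp.zeros(1)], extra_fields=("num_steps",))
total_steps = int(mcmc.get_extra_fields()["num_steps"].sum())

Keeping the defaults in one place would let callers override individual NUTS/MCMC settings without the estimator growing one parameter per option.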

def run(self, model: Model) -> pd.DataFrame:
"""Execute the estimation pipeline and return a single-row DataFrame.
Parameters
----------
model : Model
PyMC model to evaluate.
Returns
-------
pandas.DataFrame
Single-row DataFrame with columns including ``num_steps``, ``eval_time_seconds``,
``sequential_chains``, and estimated sampling wall-clock time in seconds,
minutes, and hours, along with metadata such as ``tune``, ``draws``, ``chains``,
``seed``, ``timestamp``, and ``model_name``.
Examples
--------
.. code-block:: python
df = ModelSamplerEstimator().run(model)
df[
[
"num_steps",
"eval_time_seconds",
"estimated_sampling_time_minutes",
]
]
"""
import time

steps = self.estimate_num_steps_sampling(
model, tune=self.tune, draws=self.draws, seed=self.seed
)
eval_time_s = self.estimate_model_eval_time(model)

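# Rough wall-clock estimate: average cost of one logp+gradient evaluation (≈ one leapfrog step)
# times the total number of steps, scaled by the number of chains run back-to-back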
sampling_time_seconds = float(
eval_time_s * steps * max(self.sequential_chains, 1)
)
data = {
"model_name": getattr(model, "name", "PyMCModel"),
"num_steps": int(steps),
"eval_time_seconds": float(eval_time_s),
"sequential_chains": int(self.sequential_chains),
"estimated_sampling_time_seconds": sampling_time_seconds,
"estimated_sampling_time_minutes": sampling_time_seconds / 60.0,
"estimated_sampling_time_hours": sampling_time_seconds / 3600.0,
"tune": int(self.tune),
"draws": int(self.draws),
"chains": int(self.chains),
"seed": int(self.seed) if self.seed is not None else None,
"timestamp": pd.Timestamp.utcfromtimestamp(int(time.time())),
}
df = pd.DataFrame([data])
Collaborator: I think pd.DataFrame(data) would still work here (or data=data)?

return df
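For context on the suggestion above: pandas refuses to build a DataFrame from a dict of all-scalar values unless an index is supplied, so wrapping data in a list (or passing an explicit index) is what produces the one-row frame. A minimal check with hypothetical values:

import pandas as pd

row = {"num_steps": 1234, "eval_time_seconds": 0.002}

pd.DataFrame([row])  # one-row frame, as in the diff above
pd.DataFrame(row, index=[0])  # equivalent alternative with an explicit index
# pd.DataFrame(row) on its own raises a ValueError asking for an index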
70 changes: 70 additions & 0 deletions tests/test_pytensor_utils.py
@@ -16,6 +16,7 @@

import numpy as np
import pandas as pd
import pymc as pm
import pytest
import xarray as xr
from pytensor import function
@@ -25,6 +26,7 @@
MMM,
MultiDimensionalBudgetOptimizerWrapper,
)
from pymc_marketing.pytensor_utils import ModelSamplerEstimator


@pytest.fixture
@@ -301,3 +303,71 @@ def test_extract_response_distribution_vs_sample_response(
)

print("\n✓ Both methods produce consistent results!")


@pytest.mark.parametrize("draws, tune", [(50, 50)])
def test_model_sampler_estimator_with_simple_model(draws, tune):
"""Smoke test for ModelSamplerEstimator on a tiny PyMC model.

- Builds a simple Normal model with known parameters.
- Skips when JAX/NumPyro are unavailable and keeps tune/draws small so the run stays fast.
- Verifies expected columns and basic invariants of the returned DataFrame.
"""

pytest.importorskip("jax")
pytest.importorskip("numpyro")
pytest.importorskip("pymc.sampling.jax")

with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0)
sigma = pm.HalfNormal("sigma", 1.0)
pm.Normal("y", mu=mu, sigma=sigma, observed=np.random.normal(0, 1, size=10))

est = ModelSamplerEstimator(
tune=tune, draws=draws, chains=2, sequential_chains=1, seed=123
)
df = est.run(model)

# Check schema and basic values
expected_columns = {
"model_name",
"num_steps",
"eval_time_seconds",
"sequential_chains",
"estimated_sampling_time_seconds",
"estimated_sampling_time_minutes",
"estimated_sampling_time_hours",
"tune",
"draws",
"chains",
"seed",
"timestamp",
}
assert set(df.columns) >= expected_columns


def test_model_sampler_estimator_eval_time_multidim_model(fitted_multidim_mmm):
"""Measure eval time for a fitted multidimensional MMM's PyMC model."""
pm_model = fitted_multidim_mmm.model

est = ModelSamplerEstimator(
tune=50, draws=50, chains=1, sequential_chains=1, seed=123
)
est_df = est.run(pm_model)

# Check schema and basic values
expected_columns = {
"model_name",
"num_steps",
"eval_time_seconds",
"sequential_chains",
"estimated_sampling_time_seconds",
"estimated_sampling_time_minutes",
"estimated_sampling_time_hours",
"tune",
"draws",
"chains",
"seed",
"timestamp",
}
assert set(est_df.columns) >= expected_columns