Bayesian models in numpyro
This page provides guidance for the models in the bayesian-envhealth-models repo.
Users should be able to build up their own models from these building blocks.
Binomial
latent_rate = ...  # placeholder: effects making up logit-transformed rate
with numpyro.plate("N", size=N):
    mu_logit = latent_rate[...]
    numpyro.sample(
        "deaths",
        # no need for a logit transform with numpyro
        # it can cope with inputting the logits directly
        dist.Binomial(total_count=population, logits=mu_logit),
        # this is where the model sees the data - the number of deaths
        obs=deaths,
    )
Poisson
latent_rate = ...  # placeholder: effects making up log-transformed rate
with numpyro.plate("N", size=N):
    # offset for population
    # equivalent in other PPLs of:
    # log(mu[i]) <- log(n[i]) + lograte[...]
    mu = jnp.exp(jnp.log(population) + latent_rate[...])
    numpyro.sample(
        "deaths",
        dist.Poisson(rate=mu),
        obs=deaths,
    )
Normal (not appropriate for count data, but may suit continuous environmental variables such as temperature)
mean = ...  # placeholder: effects making up the (untransformed) mean
sigma = numpyro.sample("sigma", dist.HalfNormal(5.0))
with numpyro.plate("N", size=N):
    mu = mean[...]
    numpyro.sample("outcome", dist.Normal(loc=mu, scale=sigma), obs=outcome)
We want to enforce similarity over adjacent age groups.
One way of doing this is with random walk priors
# use a plate to create 1d array `age_drift` of size `N_age - 1`
with numpyro.plate("age_groups", size=(N_age-1)):
    age_drift = numpyro.sample("age_drift", dist.Normal(0, sigma_age))
    # a random walk is just a cumulative sum of random effects https://observablehq.com/@observablehq/plot-random-walk
    # we pad the array with a 0 for identifiability
    # so the array now has length `N_age` (`[0, ...]`)
    # and the effect for the first age group is 0
    age_effect = jnp.pad(jnp.cumsum(age_drift, -1), (1, 0))
In small-area studies, it is common to smooth data using models with explicit spatial dependence, which are designed to give more weight to nearby areas than those further away. There are three main categories for modelling spatial effects. First, we can treat space as a continuous surface using Gaussian processes or splines. Second, we can use areal models, which make use of the spatial neighbourhood structure of the units. Third, we can build models that exploit a nested hierarchy of geographical units, for example between state, county and census tract in the US. Each of these methods relies on assumptions which may make it more or less appropriate in different applications.
Here, we will focus on areal models, which are the most common in disease mapping studies.
A popular prior for areal data is the conditional autoregressive (CAR) prior, also known as a Gaussian Markov random field (GMRF).
These priors form a joint distribution that is usually defined in terms of the precision matrix rather than the covariance matrix.
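To make the precision-matrix view concrete, here is a minimal NumPy sketch. It assumes the standard proper-CAR form Q = tau * (D - alpha * W), where W is the adjacency matrix, D is the diagonal matrix of neighbour counts, alpha is the correlation and tau is the conditional precision. The four-area chain `adj` and the helper name `car_precision` are made up for illustration.

```python
import numpy as np


def car_precision(adj, correlation, conditional_precision=1.0):
    """Precision matrix Q = tau * (D - alpha * W) of a proper CAR prior (assumed form)."""
    adj = np.asarray(adj, dtype=float)
    degree = np.diag(adj.sum(axis=1))  # D: number of neighbours of each area
    return conditional_precision * (degree - correlation * adj)


# hypothetical map: four areas in a row, each adjacent to the next
adj = np.array(
    [
        [0, 1, 0, 0],
        [1, 0, 1, 0],
        [0, 1, 0, 1],
        [0, 0, 1, 0],
    ]
)

Q_proper = car_precision(adj, correlation=0.99)
Q_icar = car_precision(adj, correlation=1.0)

# with correlation < 1 the precision matrix is positive definite
min_eig_proper = np.linalg.eigvalsh(Q_proper).min()
# with correlation = 1 every row sums to zero, so Q is singular
min_eig_icar = np.linalg.eigvalsh(Q_icar).min()
print(min_eig_proper, min_eig_icar)
```

With a correlation of exactly 1 the matrix cannot be inverted, so the prior has no proper joint density. This is one way to read the warning that the correlation cannot be set to 1.0.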
The ICAR prior is specified as
spatial_effect_raw = numpyro.sample(
    "spatial_effect_raw",
    dist.CAR(
        loc=0.0,
        # effectively ICAR – there are mathematical reasons it cannot be 1.0
        correlation=0.99,
        conditional_precision=1.0,
        # `adj` is a matrix of 1s and 0s specifying the neighbourhood adjacency
        adj_matrix=adj,
        is_sparse=True,
    ),
)
spatial_effect = spatial_scale * spatial_effect_raw
For annual data, we would consider linear slopes for trends (
Space-time interactions (the same ideas apply to age-time or race-time interactions) can range from fully independent effects, to each spatial unit having its own independent temporal pattern, to inseparable space-time variation where interactions borrow strength across neighbouring spatial units and neighbouring time periods. The most common types are fully independent (type I) and each spatial unit having an independent temporal pattern (type II).
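A small NumPy simulation may help to picture the two common types. This is only a sketch of the prior structure, not the model code. The sizes `N_space` and `N_t` and the scale `sigma` are made-up values.

```python
import numpy as np

rng = np.random.default_rng(0)
N_space, N_t, sigma = 3, 5, 0.1  # hypothetical sizes and scale

# type I: every space-time cell gets its own independent effect
type_1 = rng.normal(0.0, sigma, size=(N_space, N_t))

# type II: each spatial unit follows its own random walk over time
drift = rng.normal(0.0, sigma, size=(N_space, N_t - 1))
# cumulative sum over time, padded with a 0 at the first time step for identifiability
type_2 = np.pad(np.cumsum(drift, axis=-1), [(0, 0), (1, 0)])

print(type_1.shape, type_2.shape)
```

Both arrays have one entry per space-time cell. In the type II array, neighbouring time steps within a row differ only by one drift term, so each unit's trajectory is smooth over time but unrelated to the other units.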
A type I interaction is specified as
with age_plate, space_plate:
    age_space_interaction = numpyro.sample("age_space_interaction", dist.Normal(0, sigma_age_space))
A type II interaction, where the temporal effects are separate random walks for each age group, can be adapted from the random walk implementation above
with age_plate, time_plate:
    # two plates make `age_time_drift` a 2d array with size `(N_age, N_t - 1)`
    age_time_drift = numpyro.sample(
        "age_time_drift", dist.Normal(0, sigma_rw_age_time)
    )
    # we pad the array with a 0 for identifiability
    age_time_effect = jnp.pad(jnp.cumsum(age_time_drift, -1), [(0, 0), (1, 0)])
There are two ways to encode categorical variables. Firstly, using random effects. This works the same as the age group variable above, but without the cumulative sum, i.e. just a mean-zero normal
with numpyro.plate("race", size=N_race):
    race_effect = numpyro.sample("race_effect", dist.Normal(0, sigma_race))
Secondly, we can use fixed effects.
For a small number of categories, I would propose using fixed effects.
This can be incorporated into a model by one-hot encoding the variables and passing in a matrix X.
Note, fixed effects are relative, so you need a reference category (e.g. measuring the effect of Black relative to reference category White).
The matrix X can include other fixed effects, including continuous variables, where the effect is a linear slope.
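As a sketch of how such a matrix X could be built, the NumPy snippet below one-hot encodes a categorical variable and drops the column of the reference category. The category codes, the `temperature` covariate and all values are invented for illustration. Standardising the continuous covariate is a common choice rather than something this page prescribes.

```python
import numpy as np

# hypothetical integer codes for a 3-level categorical variable; code 0 is the reference
race_id = np.array([0, 2, 1, 0, 2])
N_race = 3

one_hot = np.eye(N_race)[race_id]  # shape (N, N_race), a single 1 per row
X_race = one_hot[:, 1:]            # drop the reference column -> shape (N, N_race - 1)

# a continuous covariate enters as an extra column, giving a linear slope
temperature = np.array([12.1, 15.3, 9.8, 20.4, 17.0])
temperature_std = (temperature - temperature.mean()) / temperature.std()

X = np.column_stack([X_race, temperature_std])  # shape (N, N_covariates)
N_covariates = X.shape[1]
print(X.shape)
```

Rows belonging to the reference category have zeros in every one-hot column, so each remaining coefficient is read as a difference from that reference.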
In the model, it will look like
with numpyro.plate("covariates", size=N_covariates):
    beta_covariates = numpyro.sample("beta_covariates", dist.Normal(0, 1)) # prior for each independent covariate effect
covariate_effects = jnp.dot(X, beta_covariates)
Here is an example of a full model with:
- intercept (first term of random walk over age effect)
- slope over time
- random walk over age effect
- age-specific random walk over time (type II interaction)
- Binomial likelihood
The code below is annotated to explain how each effect is built up.
def model_age_time_interaction(age_id, time_id, population, deaths):
    N = len(population)            # size of the data
    N_age = len(np.unique(age_id)) # number of age groups
    N_t = len(np.unique(time_id))  # number of time steps
    # plates control the `size` of the effect, and replace `for i in 1:N_age` in other PPLs
    # the argument `dim` helps us keep track of shapes and allows for clever broadcasting
    # the `dim=-2` argument means any `numpyro.sample` statement within `age_plate` will
    # have shape `(N_age, 1)` rather than just `(N_age,)` – there has been an extra dimension
    # created here meaning all age effects are in the 2nd rightmost dim and all time effects
    # are in the rightmost (`-1`) dim.
    # There is more information on shapes in the docs https://pyro.ai/examples/tensor_shapes.html
    # and in this post by Eric Ma https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/
    age_plate = numpyro.plate("age_groups", size=N_age, dim=-2)
    time_plate = numpyro.plate("time", size=(N_t - 1), dim=-1)
    # hyperparameters
    slope = numpyro.sample("slope", dist.Normal(loc=0.0, scale=1.0))
    sigma_rw_age = numpyro.sample("sigma_rw_age", dist.HalfNormal(1.0)) # Half-Normal is a good prior for positive sd effects
    sigma_rw_age_time = numpyro.sample("sigma_rw_age_time", dist.HalfNormal(1.0))
    # slope over time is the same as adding slope at each timestep
    slope_cum = slope * jnp.arange(N_t) # jnp.arange(N_t) is [0, 1, ..., N_t - 1], so this becomes [0, 1 * slope, ..., (N_t - 1) * slope]
    # random walk over age
    with age_plate:
        age_drift_scale = jnp.pad(
            jnp.broadcast_to(sigma_rw_age, N_age - 1),
            (1, 0),
            # pad so first term is the intercept, prior N(0, 10)
            # this is a bit fancy really
            # we could also just have an effect of size `N_age - 1`
            # do the cumsum and then pad the first term with zero for identifiability (as above)
            constant_values=10.0,
        )[:, jnp.newaxis] # manually turn this from shape `(N_age,)` to `(N_age, 1)` using `jnp.newaxis`
        # `numpyro.sample` statement within `age_plate` will have shape `(N_age, 1)`
        age_drift = numpyro.sample("age_drift", dist.Normal(0, age_drift_scale))
        # needs to be applied over the age dimension, i.e. `dim=-2`
        age_effect = jnp.cumsum(age_drift, -2)
    # age-time random walk (type II) interaction
    with age_plate, time_plate:
        # random sample of shape `(N_age, N_t - 1)`
        age_time_drift = numpyro.sample(
            "age_time_drift", dist.Normal(0, sigma_rw_age_time)
        )
        # cumulative sum over the time dimension (`dim=-1`) and then pad the time dimension with 0 for identifiability
        # so shape of effect becomes `(N_age, N_t)`
        age_time_effect = jnp.pad(jnp.cumsum(age_time_drift, -1), [(0, 0), (1, 0)])
    # this is where the shape magic happens
    # `slope_cum` has shape `(N_t,)`
    # `age_effect` has shape `(N_age, 1)`
    # `age_time_effect` has shape `(N_age, N_t)`
    # we add these things with different shapes
    # the age effect will be "broadcast" (repeated) over each time step
    # i.e. the age effect is the same in each time step
    # the same happens for the `slope_cum` over age groups
    # `latent_rate` has shape `(N_age, N_t)`
    latent_rate = slope_cum + age_effect + age_time_effect
    # likelihood
    with numpyro.plate("N", size=N):
        # this line plucks out the right `latent_rate` according to the dataset
        # for example, if the row in the data was age group 3 and time step 17
        # it would pick out `latent_rate[3, 17]`
        mu_logit = latent_rate[age_id, time_id]
        numpyro.sample(
            "deaths",
            # no need for a logit transform with numpyro
            # it can cope with inputting the logits directly
            dist.Binomial(total_count=population, logits=mu_logit),
            # this is where the model sees the data - the number of deaths
            obs=deaths,
        )
- Binomial likelihood
- Global intercept and slope
- Age group-specific intercepts and slopes (random walk)
- Two-tier nested hierarchy of random effects over space, intercepts and slopes
- Temporal non-linear random walk
- Binomial likelihood
- Global intercept and slope
- Age group-specific intercepts and slopes (random walk)
- Spatial intercept and slopes, either three-tier nested hierarchy of random effects or ICAR
- Age-space IID (type I) interaction
- Age-time type II interaction
The default sampler is NUTS, which is a variant of HMC.
NUTS samplers work better in a non-centred parametrisation, i.e. they prefer 
numpyro lets us write models in the centred parametrisation, then add a decorator to the model to tell numpyro to evaluate the relevant parameters as non-centred.
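For a Normal effect, the two parametrisations describe the same distribution: the non-centred version samples a standard normal and then shifts and scales it. A minimal NumPy sketch, with made-up values for `mu` and `sigma`:

```python
import numpy as np

rng = np.random.default_rng(1)
mu, sigma, n = 1.5, 2.0, 100_000  # hypothetical location, scale and number of draws

# centred: draw the effect directly from Normal(mu, sigma)
x_centred = rng.normal(mu, sigma, size=n)

# non-centred: draw a standard normal, then shift and scale it deterministically
z = rng.normal(0.0, 1.0, size=n)
x_non_centred = mu + sigma * z

# both sets of draws have approximately the same mean and standard deviation
print(x_centred.mean(), x_non_centred.mean())
print(x_centred.std(), x_non_centred.std())
```

In the non-centred form the sampler only ever sees `z`, whose scale does not depend on `sigma`. That tends to make hierarchical posteriors easier to explore.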
In the case below, we are telling the model to use the non-centred parametrisations for the age_drift and age_time_drift parameters.
@numpyro.handlers.reparam(
    config={
        k: LocScaleReparam(0)
        for k in [
            "age_drift",
            "age_time_drift",
        ]
    }
)
def model_age_time_interaction(
    age_id: Int[Array, "data"],
    time_id: Int[Array, "data"],
    population: Int[Array, "data"],
    deaths: Optional[Int[Array, "data"]] = None,
) -> None:
This is another way to make sure users pass the correct type of data to the model, and it helps prevent errors.
Python is a dynamically typed language. Recent versions of Python allow optional type annotations in the form of type hints.
In this way, we can convert the basic function definition
def model_age_time_interaction(age_id, time_id, population, deaths=None):
to the equivalent, but strongly typed
@jaxtyped(typechecker=beartype)
def model_age_time_interaction(
    age_id: Int[Array, "data"],
    time_id: Int[Array, "data"],
    population: Int[Array, "data"],
    deaths: Optional[Int[Array, "data"]] = None,
) -> None:
In the first basic model, there is nothing to stop the user passing the arguments age_id="r", population=True, deaths=12.34, which of course makes no sense.
Although these are just type hints, by using the decorator @jaxtyped(typechecker=beartype), these types are enforced and the model throws an error if the user passes something faulty.
Read more about the beartype project here.
Further, jaxtyping makes sure the shapes and dtypes of the jax arrays match (the dimension "data" in this case).
Running a model is great, but you need to check the inference has converged.
arviz has inbuilt methods for this.
The most important thing is to check the r_hat column values in the summary are all below 1.01 (although 1.05 is also a reasonable threshold).
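To give a feel for what r_hat measures, the sketch below computes the classic Gelman-Rubin statistic with NumPy by comparing between-chain and within-chain variance. Note the assumption: arviz reports a more robust rank-normalised split version, so its numbers will not match this simplified formula exactly. The simulated chains are made up.

```python
import numpy as np


def basic_r_hat(chains):
    """Classic Gelman-Rubin R-hat for draws of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()           # W: mean within-chain variance
    between = n_draws * chains.mean(axis=1).var(ddof=1)  # B: scaled variance of chain means
    var_hat = (n_draws - 1) / n_draws * within + between / n_draws
    return np.sqrt(var_hat / within)


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))           # four chains exploring the same distribution
stuck = mixed + np.arange(4)[:, np.newaxis]  # four chains centred on different values

r_hat_mixed = basic_r_hat(mixed)
r_hat_stuck = basic_r_hat(stuck)
print(r_hat_mixed, r_hat_stuck)
```

Chains that agree give a value very close to 1, while chains sitting in different places give a value well above the thresholds mentioned above.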
import arviz as az
import xarray as xr
ds = xr.open_dataset("../../output/model_age_time_interaction_samples.nc")
az.summary(ds)
Although csv files are human-readable, when the dataset gets large the most efficient way of holding data is in binary files.
Also, when dealing with multiple causes of death, the population column often contains repeat data, and we might only want to hold one version of this.
Below is a code snippet to convert the csv to the required npy binary files for loading into the modelling framework.
import pandas as pd
import numpy as np
# Read the CSV file into a pandas DataFrame
df = pd.read_csv('simulated_deaths.csv')
# convert columns into integer index codes for the hierarchical model
df = df.assign(
    year_id = lambda x: x.year.astype("category").cat.codes,
    age_id = lambda x: x.age.astype("category").cat.codes
)
# Save each column as an array using np.save()
for column in df.columns:
    np.save(f'{column}.npy', df[column].values)
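Reading the arrays back is then one `np.load` call per column. The sketch below does a full round trip in a temporary directory so that it runs on its own. The tiny data frame stands in for `simulated_deaths.csv`, and its column names and values are assumptions.

```python
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd

# hypothetical stand-in for the contents of the csv file
df = pd.DataFrame(
    {
        "year": [2000, 2000, 2001, 2001],
        "age": ["0-4", "5-9", "0-4", "5-9"],
        "population": [1000, 1200, 1010, 1190],
        "deaths": [3, 1, 2, 0],
    }
).assign(
    year_id=lambda x: x.year.astype("category").cat.codes,
    age_id=lambda x: x.age.astype("category").cat.codes,
)

with tempfile.TemporaryDirectory() as tmp:
    out_dir = Path(tmp)
    # only the numeric columns the model needs are written here
    columns = ["year_id", "age_id", "population", "deaths"]
    for column in columns:
        np.save(out_dir / f"{column}.npy", df[column].to_numpy())

    # load every array back into a dict keyed by column name
    data = {column: np.load(out_dir / f"{column}.npy") for column in columns}

print({name: array.tolist() for name, array in data.items()})
```

Restricting the loop to numeric columns is a design choice in this sketch. Text columns such as `age` may be stored as object arrays, which `np.load` will not read back unless `allow_pickle=True` is passed.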