# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Specialized priors that behave like the Prior class.

The Prior class has design constraints that prevent it from covering every
case, so this module collects priors that do not inherit from the Prior
class but expose many of the same methods.
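
For example, a minimal (illustrative) usage within a model context:

.. code-block:: python

    import pymc as pm

    from pymc_marketing.special_priors import LogNormalPositiveParam

    prior = LogNormalPositiveParam(mu=1.0, sigma=1.0)

    with pm.Model():
        variable = prior.create_variable("variable")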
| 21 | +""" |
| 22 | + |
| 23 | +import numpy as np |
| 24 | +import pymc as pm |
| 25 | +import pytensor.tensor as pt |
| 26 | +import xarray as xr |
| 27 | +from pymc_extras.deserialize import deserialize, register_deserialization |
| 28 | +from pymc_extras.prior import Prior, create_dim_handler, sample_prior |
| 29 | +from pytensor.tensor import TensorVariable |
| 30 | + |
| 31 | + |
| 32 | +class LogNormalPositiveParam: |
| 33 | + """ |
| 34 | + A specialized implementation of a log normal distribution. |
| 35 | +
|
| 36 | + Like the LogNormal distribution, this distribution has support over the positive numbers. |
| 37 | + However, unlike the lognormal, the parameters are also specified in the positive |
| 38 | + domain. |
| 39 | +
|
| 40 | + The other advantage of this prior is in constructing hierarchical models. It allows users to toggle |
| 41 | + between centered and non-centered parameterizations. This enables rapid iteration when searching |
| 42 | + for a parameterization that samples efficiently. |
| 43 | +
|
| 44 | + Parameters |
| 45 | + ---------- |
| 46 | + mu : Prior, float, int, array-like |
| 47 | + The mean of the distribution. |
| 48 | + sigma : Prior, float, int, array-like |
| 49 | + The standard deviation of the distribution. |
| 50 | + dims : tuple[str, ...], optional |
| 51 | + The dimensions of the distribution, by default None. |
| 52 | + centered : bool, optional |
| 53 | + Whether the distribution is centered, by default True. |
| 54 | +
|
| 55 | + Examples |
| 56 | + -------- |
| 57 | + Build a non-centered hierarchical model where information is shared across geos. |
| 58 | +
|
| 59 | + .. code-block:: python |
| 60 | + from pymc_marketing.special_priors import LogNormalPositiveParam |
| 61 | +
|
| 62 | + normal = LogNormalPositiveParam( |
| 63 | + mu=Prior("Gamma", mu=1.0, sigma=1.0), |
| 64 | + sigma=Prior("HalfNormal", sigma=1.0), |
| 65 | + dims=("geo",), |
| 66 | + centered=False, |
| 67 | + ) |
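
    Sample from the prior to inspect the implied distribution (the ``geo``
    coordinate values here are only illustrative):

    .. code-block:: python

        prior_samples = lognormal.sample_prior(coords={"geo": ["A", "B", "C"]})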
    """

    def __init__(self, dims: tuple | None = None, centered: bool = True, **parameters):
        self.parameters = parameters
        self.dims = dims
        self.centered = centered

        self._checks()

    def _checks(self) -> None:
        self._parameters_are_correct_set()

    def _parameters_are_correct_set(self) -> None:
        if set(self.parameters.keys()) != {"mu", "sigma"}:
            raise ValueError("Parameters must be mu and sigma")

    def _create_parameter(self, param, value, name):
        # Plain values (floats, arrays) pass through unchanged; anything with a
        # ``create_variable`` method is turned into a child variable first.
        if not hasattr(value, "create_variable"):
            return value

        child_name = f"{name}_{param}"
        return self.dim_handler(value.create_variable(child_name), value.dims)

    def create_variable(self, name: str) -> TensorVariable:
        """Create a variable from the prior distribution."""
        self.dim_handler = create_dim_handler(self.dims)
        parameters = {
            param: self._create_parameter(param, value, name)
            for param, value in self.parameters.items()
        }
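        # Moment matching (a standard identity, stated here for reference): if
        # X ~ LogNormal(mu_log, sigma_log), then E[X] = exp(mu_log + sigma_log**2 / 2)
        # and SD[X]**2 = (exp(sigma_log**2) - 1) * exp(2 * mu_log + sigma_log**2).
        # Solving for mu_log and sigma_log so that E[X] = mu and SD[X] = sigma
        # gives the expressions below.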
        mu_log = pt.log(
            parameters["mu"] ** 2
            / pt.sqrt(parameters["mu"] ** 2 + parameters["sigma"] ** 2)
        )
        sigma_log = pt.sqrt(
            pt.log(1 + (parameters["sigma"] ** 2 / parameters["mu"] ** 2))
        )
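
        # The centered form draws log(phi) directly; the non-centered form draws
        # a standard-normal offset and shifts/scales it, which can sample more
        # efficiently in hierarchical models when group-level values are only
        # weakly identified by the data.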
        if self.centered:
            log_phi = pm.Normal(
                f"{name}_log", mu=mu_log, sigma=sigma_log, dims=self.dims
            )

        else:
            log_phi_z = pm.Normal(
                f"{name}_log_offset", mu=0, sigma=1, dims=self.dims
            )
            log_phi = mu_log + log_phi_z * sigma_log

        phi = pm.math.exp(log_phi)
        phi = pm.Deterministic(name, phi, dims=self.dims)

        return phi

    def to_dict(self):
        """Convert the prior distribution to a dictionary.
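
        The resulting structure looks like the following sketch (exact values
        depend on the parameters; this example assumes a non-centered prior
        with ``dims=("geo",)``)::

            {
                "special_prior": "LogNormalPositiveParam",
                "kwargs": {"mu": ..., "sigma": ...},
                "centered": False,
                "dims": ("geo",),
            }
        """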
        data = {
            "special_prior": "LogNormalPositiveParam",
        }
        if self.parameters:

            def handle_value(value):
                if isinstance(value, Prior):
                    return value.to_dict()

                if isinstance(value, pt.TensorVariable):
                    value = value.eval()

                if isinstance(value, np.ndarray):
                    return value.tolist()

                if hasattr(value, "to_dict"):
                    return value.to_dict()

                return value

            data["kwargs"] = {
                param: handle_value(value) for param, value in self.parameters.items()
            }
        if not self.centered:
            data["centered"] = False

        if self.dims:
            data["dims"] = self.dims

        return data

    @classmethod
    def from_dict(cls, data) -> "LogNormalPositiveParam":
        """Create a LogNormalPositiveParam prior from a dictionary."""
        if not isinstance(data, dict):
            msg = (
                "Must be a dictionary representation of a prior distribution. "
                f"Not of type: {type(data)}"
            )
            raise ValueError(msg)

        kwargs = data.get("kwargs", {})

        def handle_value(value):
            if isinstance(value, dict):
                return deserialize(value)

            if isinstance(value, list):
                return np.array(value)

            return value

        kwargs = {param: handle_value(value) for param, value in kwargs.items()}
        centered = data.get("centered", True)
        dims = data.get("dims")
        if isinstance(dims, list):
            dims = tuple(dims)

        return cls(dims=dims, centered=centered, **kwargs)

    def sample_prior(
        self,
        coords=None,
        name: str = "variable",
        **sample_prior_predictive_kwargs,
    ) -> xr.Dataset:
        """Sample from the prior distribution."""
        return sample_prior(
            factory=self,
            coords=coords,
            name=name,
            **sample_prior_predictive_kwargs,
        )


def _is_lognormalpositiveparam_type(data: dict) -> bool:
    return data.get("special_prior") == "LogNormalPositiveParam"


register_deserialization(
    is_type=_is_lognormalpositiveparam_type,
    deserialize=LogNormalPositiveParam.from_dict,
)
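
# Illustrative round trip (``prior`` stands for any LogNormalPositiveParam
# instance): ``deserialize(prior.to_dict())`` rebuilds an equivalent prior,
# because the "special_prior" tag checked above routes the dictionary back to
# ``LogNormalPositiveParam.from_dict``.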