Skip to content

Commit 6b30c10

Browse files
Special priors (#1939)
* add special priors file which includes the lognormalexp prior * add tests * change name of class from lognormalexp -> lognormalpositiveparam * add test for no dims * expand tests to cover a few more cases. * add special_priors to the index of the api docs. * minor updates to doc string.
1 parent e2ca250 commit 6b30c10

File tree

3 files changed

+291
-0
lines changed

3 files changed

+291
-0
lines changed

docs/source/api/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@
2121
model_graph
2222
plot
2323
prior
24+
special_priors
2425
utils
2526
```

pymc_marketing/special_priors.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
# Copyright 2022 - 2025 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Specialized priors that behave like the Prior class.
16+
17+
The Prior class has certain design constraints that prevent it from
18+
covering all cases. So this module contains a collection of
19+
priors that do not inherit from the Prior class but have many
20+
of the same methods.
21+
"""
22+
23+
import numpy as np
24+
import pymc as pm
25+
import pytensor.tensor as pt
26+
import xarray as xr
27+
from pymc_extras.deserialize import deserialize, register_deserialization
28+
from pymc_extras.prior import Prior, create_dim_handler, sample_prior
29+
from pytensor.tensor import TensorVariable
30+
31+
32+
class LogNormalPositiveParam:
33+
"""
34+
A specialized implementation of a log normal distribution.
35+
36+
Like the LogNormal distribution, this distribution has support over the positive numbers.
37+
However, unlike the lognormal, the parameters are also specified in the positive
38+
domain.
39+
40+
The other advantage of this prior is in constructing hierarchical models. It allows users to toggle
41+
between centered and non-centered parameterizations. This enables rapid iteration when searching
42+
for a parameterization that samples efficiently.
43+
44+
Parameters
45+
----------
46+
mu : Prior, float, int, array-like
47+
The mean of the distribution.
48+
sigma : Prior, float, int, array-like
49+
The standard deviation of the distribution.
50+
dims : tuple[str, ...], optional
51+
The dimensions of the distribution, by default None.
52+
centered : bool, optional
53+
Whether the distribution is centered, by default True.
54+
55+
Examples
56+
--------
57+
Build a non-centered hierarchical model where information is shared across geos.
58+
59+
.. code-block:: python
60+
from pymc_marketing.special_priors import LogNormalPositiveParam
61+
62+
normal = LogNormalPositiveParam(
63+
mu=Prior("Gamma", mu=1.0, sigma=1.0),
64+
sigma=Prior("HalfNormal", sigma=1.0),
65+
dims=("geo",),
66+
centered=False,
67+
)
68+
"""
69+
70+
def __init__(self, dims: tuple | None = None, centered: bool = True, **parameters):
71+
self.parameters = parameters
72+
self.dims = dims
73+
self.centered = centered
74+
75+
self._checks()
76+
77+
def _checks(self) -> None:
78+
self._parameters_are_correct_set()
79+
80+
def _parameters_are_correct_set(self) -> None:
81+
if set(self.parameters.keys()) != {"mu", "sigma"}:
82+
raise ValueError("Parameters must be mu and sigma")
83+
84+
def _create_parameter(self, param, value, name):
85+
if not hasattr(value, "create_variable"):
86+
return value
87+
88+
child_name = f"{name}_{param}"
89+
return self.dim_handler(value.create_variable(child_name), value.dims)
90+
91+
def create_variable(self, name: str) -> TensorVariable:
92+
"""Create a variable from the prior distribution."""
93+
self.dim_handler = create_dim_handler(self.dims)
94+
parameters = {
95+
param: self._create_parameter(param, value, name)
96+
for param, value in self.parameters.items()
97+
}
98+
mu_log = pt.log(
99+
parameters["mu"] ** 2
100+
/ pt.sqrt(parameters["mu"] ** 2 + parameters["sigma"] ** 2)
101+
)
102+
sigma_log = pt.sqrt(
103+
pt.log(1 + (parameters["sigma"] ** 2 / parameters["mu"] ** 2))
104+
)
105+
106+
if self.centered:
107+
log_phi = pm.Normal(
108+
name + "_log", mu=mu_log, sigma=sigma_log, dims=self.dims
109+
)
110+
111+
else:
112+
log_phi_z = pm.Normal(
113+
name + "_log" + "_offset", mu=0, sigma=1, dims=self.dims
114+
)
115+
log_phi = mu_log + log_phi_z * sigma_log
116+
117+
phi = pm.math.exp(log_phi)
118+
phi = pm.Deterministic(name, phi, dims=self.dims)
119+
120+
return phi
121+
122+
def to_dict(self):
123+
"""Convert the prior distribution to a dictionary."""
124+
data = {
125+
"special_prior": "LogNormalPositiveParam",
126+
}
127+
if self.parameters:
128+
129+
def handle_value(value):
130+
if isinstance(value, Prior):
131+
return value.to_dict()
132+
133+
if isinstance(value, pt.TensorVariable):
134+
value = value.eval()
135+
136+
if isinstance(value, np.ndarray):
137+
return value.tolist()
138+
139+
if hasattr(value, "to_dict"):
140+
return value.to_dict()
141+
142+
return value
143+
144+
data["kwargs"] = {
145+
param: handle_value(value) for param, value in self.parameters.items()
146+
}
147+
if not self.centered:
148+
data["centered"] = False
149+
150+
if self.dims:
151+
data["dims"] = self.dims
152+
153+
return data
154+
155+
@classmethod
156+
def from_dict(cls, data) -> Prior:
157+
"""Create a LogNormalPositiveParam prior from a dictionary."""
158+
if not isinstance(data, dict):
159+
msg = (
160+
"Must be a dictionary representation of a prior distribution. "
161+
f"Not of type: {type(data)}"
162+
)
163+
raise ValueError(msg)
164+
165+
kwargs = data.get("kwargs", {})
166+
167+
def handle_value(value):
168+
if isinstance(value, dict):
169+
return deserialize(value)
170+
171+
if isinstance(value, list):
172+
return np.array(value)
173+
174+
return value
175+
176+
kwargs = {param: handle_value(value) for param, value in kwargs.items()}
177+
centered = data.get("centered", True)
178+
dims = data.get("dims")
179+
if isinstance(dims, list):
180+
dims = tuple(dims)
181+
182+
return cls(dims=dims, centered=centered, **kwargs)
183+
184+
def sample_prior(
185+
self,
186+
coords=None,
187+
name: str = "variable",
188+
**sample_prior_predictive_kwargs,
189+
) -> xr.Dataset:
190+
"""Sample from the prior distribution."""
191+
return sample_prior(
192+
factory=self,
193+
coords=coords,
194+
name=name,
195+
**sample_prior_predictive_kwargs,
196+
)
197+
198+
199+
def _is_lognormalpositiveparam_type(data: dict) -> bool:
200+
if "special_prior" in data:
201+
return data["special_prior"] == "LogNormalPositiveParam"
202+
else:
203+
return False
204+
205+
206+
# Hook the class into the shared deserialization registry at import time:
# dictionaries accepted by the predicate are rebuilt through ``from_dict``.
register_deserialization(
    deserialize=LogNormalPositiveParam.from_dict,
    is_type=_is_lognormalpositiveparam_type,
)

tests/test_special_priors.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2022 - 2025 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pymc as pm
17+
import pytest
18+
import xarray as xr
19+
from pymc_extras.prior import Prior
20+
21+
from pymc_marketing.special_priors import (
22+
LogNormalPositiveParam,
23+
_is_lognormalpositiveparam_type,
24+
)
25+
26+
27+
@pytest.mark.parametrize(
    "mu, sigma, centered, dims",
    [
        (
            Prior("Gamma", mu=1.0, sigma=1.0),
            Prior("Gamma", mu=1.0, sigma=1.0),
            True,
            ("channel",),
        ),
        (1.0, 2.0, False, ("channel",)),
        (1.0, 2.0, True, ("channel",)),
        (np.array([1, 2, 3]), np.array([4, 5, 6]), True, ("channel",)),
        (np.array([1, 2, 3]), np.array([4, 5, 6]), False, ("channel",)),
        (1.0, 2.0, True, ()),
    ],
)
def test_LogNormalPositiveParam_args(mu, sigma, centered, dims):
    """
    Exercise the public surface of LogNormalPositiveParam.

    Checks:
    - sample_prior runs (with coords when dims are given, without otherwise)
    - the non-centered form exposes the offset variable
    - create_variable runs inside a model
    - serializing, rebuilding and serializing again yields the same dict
    """
    channel_coords = {"channel": ["C1", "C2", "C3"]}
    special_prior = LogNormalPositiveParam(
        mu=mu, sigma=sigma, centered=centered, dims=dims
    )

    if not dims:
        samples = special_prior.sample_prior()
    else:
        samples = special_prior.sample_prior(coords=channel_coords)
        assert samples.channel.shape == (len(channel_coords["channel"]),)
    assert isinstance(samples, xr.Dataset)

    if centered is False:
        assert "variable_log_offset" in samples.data_vars

    with pm.Model(coords=channel_coords):
        special_prior.create_variable("test")

    rebuilt = special_prior.from_dict(special_prior.to_dict())
    assert special_prior.to_dict() == rebuilt.to_dict()
68+
69+
70+
def test_LogNormalPositiveParam_args_invalid():
    """Parameters other than ``mu`` and ``sigma`` are rejected at construction."""
    # ``match`` pins the failure to the parameter-name validation, so an
    # unrelated ValueError raised elsewhere cannot make this test pass.
    with pytest.raises(ValueError, match="Parameters must be mu and sigma"):
        LogNormalPositiveParam(alpha=1.0, beta=1.0)
73+
74+
75+
def test_the_deserializer_can_distinguish_between_types_of_prior_classes():
    """The type predicate accepts special-prior dicts and rejects Prior dicts."""
    special_prior_dict = LogNormalPositiveParam(mu=1.0, sigma=1.0).to_dict()
    regular_prior_dict = Prior("Normal", mu=1.0, sigma=1.0).to_dict()

    assert _is_lognormalpositiveparam_type(special_prior_dict)
    assert not _is_lognormalpositiveparam_type(regular_prior_dict)

0 commit comments

Comments
 (0)