-
-
Notifications
You must be signed in to change notification settings - Fork 205
Consolidate distribution classes #1592
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
e3687f8
9fcd35a
abef403
16b40b4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -219,9 +219,8 @@ class IndexDistribution(Distribution): | |
class (such as Bernoulli, LogNormal, etc.) with information | ||
about the conditions on the parameters of the distribution. | ||
|
||
For example, an IndexDistribution can be defined as | ||
a Bernoulli distribution whose parameter p is a function of | ||
a different input parameter. | ||
It can also wrap a list of pre-discretized distributions (previously | ||
provided by TimeVaryingDiscreteDistribution) and provide the same API. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -235,14 +234,19 @@ class (such as Bernoulli, LogNormal, etc.) with information | |
Keys should match the arguments to the engine class | ||
constructor. | ||
|
||
distributions: [DiscreteDistribution] | ||
Optional. A list of discrete distributions to wrap directly. | ||
|
||
seed : int | ||
Seed for random number generator. | ||
""" | ||
|
||
conditional = None | ||
engine = None | ||
|
||
def __init__(self, engine, conditional, RNG=None, seed=0): | ||
def __init__( | ||
self, engine=None, conditional=None, distributions=None, RNG=None, seed=0 | ||
): | ||
if RNG is None: | ||
# Set up the RNG | ||
super().__init__(seed) | ||
|
@@ -255,11 +259,24 @@ def __init__(self, engine, conditional, RNG=None, seed=0): | |
# and create a new one. | ||
self.seed = seed | ||
|
||
self.conditional = conditional | ||
# Mode 1: wrapping a list of discrete distributions | ||
if distributions is not None: | ||
self.distributions = distributions | ||
self.engine = None | ||
self.conditional = None | ||
self.dstns = [] | ||
return | ||
|
||
# Mode 2: engine + conditional parameters (original IndexDistribution) | ||
self.conditional = conditional if conditional is not None else {} | ||
self.engine = engine | ||
|
||
self.dstns = [] | ||
|
||
# If no engine/conditional were provided, remain empty (should not happen in normal use) | ||
if self.engine is None and not self.conditional: | ||
return | ||
|
||
# Test one item to determine case handling | ||
item0 = list(self.conditional.values())[0] | ||
|
||
|
@@ -273,7 +290,7 @@ def __init__(self, engine, conditional, RNG=None, seed=0): | |
|
||
elif type(item0) is float: | ||
self.dstns = [ | ||
self.engine(seed=self._rng.integers(0, 2**31 - 1), **conditional) | ||
self.engine(seed=self._rng.integers(0, 2**31 - 1), **self.conditional) | ||
] | ||
|
||
else: | ||
|
@@ -284,6 +301,9 @@ def __init__(self, engine, conditional, RNG=None, seed=0): | |
) | ||
|
||
def __getitem__(self, y): | ||
# Prefer discrete list mode if present | ||
if hasattr(self, "distributions") and self.distributions: | ||
return self.distributions[y] | ||
return self.dstns[y] | ||
|
||
def discretize(self, N, **kwds): | ||
|
@@ -302,16 +322,16 @@ def discretize(self, N, **kwds): | |
|
||
Returns: | ||
------------ | ||
dists : [DiscreteDistribution] | ||
A list of DiscreteDistributions that are the | ||
approximation of engine distribution under each condition. | ||
|
||
TODO: It would be better if there were a conditional discrete | ||
distribution representation. But that integrates with the | ||
solution code. This implementation will return the list of | ||
distributions representations expected by the solution code. | ||
dists : [DiscreteDistribution] or IndexDistribution | ||
If parameterization is constant, returns a single DiscreteDistribution. | ||
If parameterization varies with index, returns an IndexDistribution in | ||
discrete-list mode, wrapping the corresponding discrete distributions. | ||
""" | ||
|
||
# If already in discrete list mode, return self (already discretized) | ||
if hasattr(self, "distributions") and self.distributions: | ||
return self | ||
|
||
# test one item to determine case handling | ||
item0 = list(self.conditional.values())[0] | ||
|
||
|
@@ -320,8 +340,12 @@ def discretize(self, N, **kwds): | |
return self.dstns[0].discretize(N, **kwds) | ||
|
||
if type(item0) is list: | ||
return TimeVaryingDiscreteDistribution( | ||
[self[i].discretize(N, **kwds) for i, _ in enumerate(item0)] | ||
# Return an IndexDistribution wrapping a list of discrete distributions | ||
return IndexDistribution( | ||
distributions=[ | ||
self[i].discretize(N, **kwds) for i, _ in enumerate(item0) | ||
], | ||
seed=self.seed, | ||
) | ||
|
||
def draw(self, condition): | ||
|
@@ -345,6 +369,15 @@ def draw(self, condition): | |
# are of the same type. | ||
# this matches the HARK 'time-varying' model architecture. | ||
|
||
# If wrapping discrete distributions, draw from those | ||
if hasattr(self, "distributions") and self.distributions: | ||
draws = np.zeros(condition.size) | ||
for c in np.unique(condition): | ||
these = c == condition | ||
N = np.sum(these) | ||
draws[these] = self.distributions[c].draw(N) | ||
return draws | ||
|
||
# test one item to determine case handling | ||
item0 = list(self.conditional.values())[0] | ||
|
||
|
@@ -367,70 +400,6 @@ def draw(self, condition): | |
these = c == condition | ||
N = np.sum(these) | ||
|
||
cond = {key: val[c] for (key, val) in self.conditional.items()} | ||
draws[these] = self[c].draw(N) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The line Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||
return draws | ||
|
||
|
||
class TimeVaryingDiscreteDistribution(Distribution): | ||
""" | ||
This class provides a way to define a discrete distribution that | ||
is conditional on an index. | ||
|
||
Wraps a list of discrete distributions. | ||
|
||
Parameters | ||
---------- | ||
|
||
distributions : [DiscreteDistribution] | ||
A list of discrete distributions | ||
|
||
seed : int | ||
Seed for random number generator. | ||
""" | ||
|
||
distributions = [] | ||
|
||
def __init__(self, distributions, seed=0): | ||
# Set up the RNG | ||
super().__init__(seed) | ||
|
||
self.distributions = distributions | ||
|
||
def __getitem__(self, y): | ||
return self.distributions[y] | ||
|
||
def draw(self, condition): | ||
""" | ||
Generate arrays of draws. | ||
The input is an array containing the conditions. | ||
The output is an array of the same length (axis 1 dimension) | ||
as the conditions containing random draws of the conditional | ||
distribution. | ||
|
||
Parameters | ||
---------- | ||
condition : np.array | ||
The input conditions to the distribution. | ||
|
||
Returns: | ||
------------ | ||
draws : np.array | ||
""" | ||
# for now, assume that all the conditionals | ||
# are of the same type. | ||
# this matches the HARK 'time-varying' model architecture. | ||
|
||
# conditions are indices into list | ||
# somewhat convoluted sampling strategy retained | ||
# for test backwards compatibility | ||
draws = np.zeros(condition.size) | ||
|
||
for c in np.unique(condition): | ||
these = c == condition | ||
N = np.sum(these) | ||
|
||
draws[these] = self.distributions[c].draw(N) | ||
|
||
return draws |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment suggests a condition that 'should not happen in normal use', but the code allows it. Consider either removing this case handling or throwing an exception if this truly represents an invalid state.
Copilot uses AI. Check for mistakes.