Skip to content

Commit 43a9d6c

Browse files
authored
Merge pull request #1205 from jburnim/r0.12
Prepare branch for TFP 0.12.1 release
2 parents dcd59ed + c1818a8 commit 43a9d6c

File tree

3 files changed

+39
-1
lines changed

3 files changed

+39
-1
lines changed

tensorflow_probability/python/bijectors/bijector.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ def __init__(self,
511511
name = name_util.strip_invalid_chars(name)
512512
super(Bijector, self).__init__(name=name)
513513
self._name = name
514+
# TODO(b/176242804): Infer `parameters` if not specified by the child class.
514515
self._parameters = self._no_dependency(parameters)
515516

516517
self._graph_parents = self._no_dependency(graph_parents or [])
@@ -648,6 +649,8 @@ def parameters(self):
648649
# Remove "self", "__class__", or other special variables. These can appear
649650
# if the subclass used:
650651
# `parameters = dict(locals())`.
652+
if self._parameters is None:
653+
return None
651654
return {k: v for k, v in self._parameters.items()
652655
if not k.startswith('__') and k != 'self'}
653656

@@ -689,6 +692,11 @@ def __eq__(self, other):
689692
return True
690693

691694
def _get_parameterization(self):
695+
if self.parameters is None:
696+
# If a user-written bijector doesn't specify `parameters`, we must assume
697+
# that all instances are unique.
698+
# TODO(b/176242804): this can be removed if we always infer `parameters`.
699+
return id(self)
692700
return self.parameters
693701

694702
def __call__(self, value, name=None, **kwargs):

tensorflow_probability/python/bijectors/bijector_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,24 @@ def _get_parameterization(self):
280280
return id(self)
281281

282282

283+
class UnspecifiedParameters(tfb.Bijector):
284+
"""A bijector that fails to pass `parameters` to the base class."""
285+
286+
def __init__(self, loc):
287+
self._loc = loc
288+
super(UnspecifiedParameters, self).__init__(
289+
validate_args=False,
290+
is_constant_jacobian=True,
291+
forward_min_event_ndims=0,
292+
name='unspecified_parameters')
293+
294+
def _forward(self, x):
295+
return x + self._loc
296+
297+
def _forward_log_det_jacobian(self, x):
298+
return tf.constant(0., x.dtype)
299+
300+
283301
@test_util.test_all_tf_execution_regimes
284302
class BijectorTestEventNdims(test_util.TestCase):
285303

@@ -440,6 +458,18 @@ def testUniqueCacheKey(self):
440458
self.assertLen(bijector_1._cache.weak_keys(direction='forward'), 1)
441459
self.assertLen(bijector_2._cache.weak_keys(direction='forward'), 1)
442460

461+
def testBijectorsWithUnspecifiedParametersDoNotShareCache(self):
462+
bijector_1 = UnspecifiedParameters(loc=tf.constant(1., dtype=tf.float32))
463+
bijector_2 = UnspecifiedParameters(loc=tf.constant(2., dtype=tf.float32))
464+
465+
x = tf.constant(3., dtype=tf.float32)
466+
y_1 = bijector_1.forward(x)
467+
y_2 = bijector_2.forward(x)
468+
469+
self.assertIsNot(y_1, y_2)
470+
self.assertLen(bijector_1._cache.weak_keys(direction='forward'), 1)
471+
self.assertLen(bijector_2._cache.weak_keys(direction='forward'), 1)
472+
443473
def testInstanceCache(self):
444474
instance_cache_bijector = tfb.Exp()
445475
instance_cache_bijector._cache = cache_util.BijectorCache(

tensorflow_probability/python/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# We follow Semantic Versioning (https://semver.org/)
1818
_MAJOR_VERSION = '0'
1919
_MINOR_VERSION = '12'
20-
_PATCH_VERSION = '0'
20+
_PATCH_VERSION = '1'
2121

2222
# When building releases, we can update this value on the release branch to
2323
# reflect the current release candidate ('rc0', 'rc1') or, finally, the official

0 commit comments

Comments
 (0)