Skip to content

Commit 0c282ef

Browse files
davmrejburnim
authored andcommitted
Avoid cache errors when user-written Bijectors don't pass parameters.
This addresses the issue highlighted in #1202 in which recent changes have broken some user-written Bijectors. PiperOrigin-RevId: 348849165
1 parent dcd59ed commit 0c282ef

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

tensorflow_probability/python/bijectors/bijector.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ def __init__(self,
511511
name = name_util.strip_invalid_chars(name)
512512
super(Bijector, self).__init__(name=name)
513513
self._name = name
514+
# TODO(b/176242804): Infer `parameters` if not specified by the child class.
514515
self._parameters = self._no_dependency(parameters)
515516

516517
self._graph_parents = self._no_dependency(graph_parents or [])
@@ -648,6 +649,8 @@ def parameters(self):
648649
# Remove "self", "__class__", or other special variables. These can appear
649650
# if the subclass used:
650651
# `parameters = dict(locals())`.
652+
if self._parameters is None:
653+
return None
651654
return {k: v for k, v in self._parameters.items()
652655
if not k.startswith('__') and k != 'self'}
653656

@@ -689,6 +692,11 @@ def __eq__(self, other):
689692
return True
690693

691694
def _get_parameterization(self):
695+
if self.parameters is None:
696+
# If a user-written bijector doesn't specify `parameters`, we must assume
697+
# that all instances are unique.
698+
# TODO(b/176242804): this can be removed if we always infer `parameters`.
699+
return id(self)
692700
return self.parameters
693701

694702
def __call__(self, value, name=None, **kwargs):

tensorflow_probability/python/bijectors/bijector_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,24 @@ def _get_parameterization(self):
280280
return id(self)
281281

282282

283+
class UnspecifiedParameters(tfb.Bijector):
284+
"""A bijector that fails to pass `parameters` to the base class."""
285+
286+
def __init__(self, loc):
287+
self._loc = loc
288+
super(UnspecifiedParameters, self).__init__(
289+
validate_args=False,
290+
is_constant_jacobian=True,
291+
forward_min_event_ndims=0,
292+
name='unspecified_parameters')
293+
294+
def _forward(self, x):
295+
return x + self._loc
296+
297+
def _forward_log_det_jacobian(self, x):
298+
return tf.constant(0., x.dtype)
299+
300+
283301
@test_util.test_all_tf_execution_regimes
284302
class BijectorTestEventNdims(test_util.TestCase):
285303

@@ -440,6 +458,18 @@ def testUniqueCacheKey(self):
440458
self.assertLen(bijector_1._cache.weak_keys(direction='forward'), 1)
441459
self.assertLen(bijector_2._cache.weak_keys(direction='forward'), 1)
442460

461+
def testBijectorsWithUnspecifiedParametersDoNotShareCache(self):
462+
bijector_1 = UnspecifiedParameters(loc=tf.constant(1., dtype=tf.float32))
463+
bijector_2 = UnspecifiedParameters(loc=tf.constant(2., dtype=tf.float32))
464+
465+
x = tf.constant(3., dtype=tf.float32)
466+
y_1 = bijector_1.forward(x)
467+
y_2 = bijector_2.forward(x)
468+
469+
self.assertIsNot(y_1, y_2)
470+
self.assertLen(bijector_1._cache.weak_keys(direction='forward'), 1)
471+
self.assertLen(bijector_2._cache.weak_keys(direction='forward'), 1)
472+
443473
def testInstanceCache(self):
444474
instance_cache_bijector = tfb.Exp()
445475
instance_cache_bijector._cache = cache_util.BijectorCache(

0 commit comments

Comments
 (0)