@@ -280,6 +280,24 @@ def _get_parameterization(self):
280
280
return id (self )
281
281
282
282
283
+ class UnspecifiedParameters (tfb .Bijector ):
284
+ """A bijector that fails to pass `parameters` to the base class."""
285
+
286
+ def __init__ (self , loc ):
287
+ self ._loc = loc
288
+ super (UnspecifiedParameters , self ).__init__ (
289
+ validate_args = False ,
290
+ is_constant_jacobian = True ,
291
+ forward_min_event_ndims = 0 ,
292
+ name = 'unspecified_parameters' )
293
+
294
+ def _forward (self , x ):
295
+ return x + self ._loc
296
+
297
+ def _forward_log_det_jacobian (self , x ):
298
+ return tf .constant (0. , x .dtype )
299
+
300
+
283
301
@test_util .test_all_tf_execution_regimes
284
302
class BijectorTestEventNdims (test_util .TestCase ):
285
303
@@ -440,6 +458,18 @@ def testUniqueCacheKey(self):
440
458
self .assertLen (bijector_1 ._cache .weak_keys (direction = 'forward' ), 1 )
441
459
self .assertLen (bijector_2 ._cache .weak_keys (direction = 'forward' ), 1 )
442
460
461
+ def testBijectorsWithUnspecifiedParametersDoNotShareCache (self ):
462
+ bijector_1 = UnspecifiedParameters (loc = tf .constant (1. , dtype = tf .float32 ))
463
+ bijector_2 = UnspecifiedParameters (loc = tf .constant (2. , dtype = tf .float32 ))
464
+
465
+ x = tf .constant (3. , dtype = tf .float32 )
466
+ y_1 = bijector_1 .forward (x )
467
+ y_2 = bijector_2 .forward (x )
468
+
469
+ self .assertIsNot (y_1 , y_2 )
470
+ self .assertLen (bijector_1 ._cache .weak_keys (direction = 'forward' ), 1 )
471
+ self .assertLen (bijector_2 ._cache .weak_keys (direction = 'forward' ), 1 )
472
+
443
473
def testInstanceCache (self ):
444
474
instance_cache_bijector = tfb .Exp ()
445
475
instance_cache_bijector ._cache = cache_util .BijectorCache (
0 commit comments