Skip to content

Commit 9600b80

Browse files
committed
Revert "Add numpy/jax rewrite for cumulative_logsumexp."
This reverts commit 82520ca for compatibility with TF 2.11.
1 parent 17ddf7c commit 9600b80

File tree

6 files changed

+4
-39
lines changed

6 files changed

+4
-39
lines changed

tensorflow_probability/python/experimental/mcmc/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,7 @@ multi_substrate_py_library(
704704
"//tensorflow_probability/python/internal:prefer_static",
705705
"//tensorflow_probability/python/internal:tensor_util",
706706
"//tensorflow_probability/python/internal:tensorshape_util",
707+
"//tensorflow_probability/python/math:generic",
707708
"//tensorflow_probability/python/math:gradient",
708709
"//tensorflow_probability/python/mcmc/internal:util",
709710
],

tensorflow_probability/python/experimental/mcmc/weighted_resampling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from tensorflow_probability.python.distributions import uniform
2222
from tensorflow_probability.python.internal import distribution_util as dist_util
2323
from tensorflow_probability.python.internal import prefer_static as ps
24+
from tensorflow_probability.python.math.generic import log_cumsum_exp
2425
from tensorflow_probability.python.math.gradient import value_and_gradient
2526
from tensorflow_probability.python.mcmc.internal import util as mcmc_util
2627

@@ -134,7 +135,7 @@ def _resample_using_log_points(log_probs, sample_shape, log_points, name=None):
134135
tf.zeros(points_shape, dtype=tf.int32)],
135136
axis=-1)
136137
log_marker_positions = tf.broadcast_to(
137-
tf.math.cumulative_logsumexp(log_probs, axis=-1),
138+
log_cumsum_exp(log_probs, axis=-1),
138139
markers_shape)
139140
log_markers_and_points = ps.concat(
140141
[log_marker_positions, log_points], axis=-1)

tensorflow_probability/python/internal/backend/numpy/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ py_test(
469469
"--test_mode=xla",
470470
# TODO(b/168718272): reduce_*([nan, nan], axis=0) (GPU)
471471
# histogram_fixed_width_bins fails with f32([0.]), [0.0, 0.0], 2
472-
"--xla_disabled=math.cumulative_logsumexp,math.reduce_min,math.reduce_max,histogram_fixed_width_bins",
472+
"--xla_disabled=math.reduce_min,math.reduce_max,histogram_fixed_width_bins",
473473
],
474474
main = "numpy_test.py",
475475
shard_count = 11,

tensorflow_probability/python/internal/backend/numpy/numpy_math.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
'count_nonzero',
6262
'cumprod',
6363
'cumsum',
64-
'cumulative_logsumexp',
6564
'digamma',
6665
'divide',
6766
'divide_no_nan',
@@ -261,23 +260,6 @@ def _cumop(op, x, axis=0, exclusive=False, reverse=False, name=None,
261260
_cumsum = utils.partial(_cumop, np.cumsum, initial_value=0.)
262261

263262

264-
def _cumulative_logsumexp(x, axis=0, exclusive=False, reverse=False, name=None):
265-
del name
266-
axis = int(axis)
267-
if axis < 0:
268-
axis = axis + len(x.shape)
269-
if JAX_MODE:
270-
op = jax.lax.cumlogsumexp
271-
else:
272-
op = np.logaddexp.accumulate
273-
return _cumop(
274-
op, x,
275-
axis=axis,
276-
exclusive=exclusive,
277-
reverse=reverse,
278-
initial_value=-np.inf)
279-
280-
281263
def _equal(x, y, name=None):
282264
del name
283265
x = _convert_to_tensor(x)
@@ -578,11 +560,6 @@ def _unsorted_segment_sum(data, segment_ids, num_segments, name=None):
578560
'tf.math.cumsum',
579561
_cumsum)
580562

581-
cumulative_logsumexp = utils.copy_docstring(
582-
'tf.math.cumulative_logsumexp',
583-
_cumulative_logsumexp)
584-
585-
586563
digamma = utils.copy_docstring(
587564
'tf.math.digamma',
588565
lambda x, name=None: scipy_special.digamma(x))

tensorflow_probability/python/internal/backend/numpy/numpy_test.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,16 +1203,6 @@ def _not_implemented(*args, **kwargs):
12031203
hps.booleans()).map(lambda x: x[0] + (x[1], x[2]))
12041204
],
12051205
xla_const_args=(1, 2, 3)),
1206-
TestCase(
1207-
'math.cumulative_logsumexp', [
1208-
hps.tuples(
1209-
array_axis_tuples(
1210-
elements=floats(min_value=-1e12, max_value=1e12)),
1211-
hps.booleans(),
1212-
hps.booleans()).map(lambda x: x[0] + (x[1], x[2]))
1213-
],
1214-
rtol=6e-5,
1215-
xla_const_args=(1, 2, 3)),
12161206
]
12171207

12181208
NUMPY_TEST_CASES += [ # break the array for pylint to not timeout.

tensorflow_probability/python/math/generic.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from tensorflow_probability.python.internal import tensorshape_util
3131
from tensorflow_probability.python.internal import variadic_reduce
3232
from tensorflow_probability.python.math.scan_associative import scan_associative
33-
from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import
3433

3534

3635
__all__ = [
@@ -90,9 +89,6 @@ def log_combinations(n, counts, name='log_combinations'):
9089

9190
# TODO(b/154562929): Remove this once the built-in op supports XLA.
9291
# TODO(b/156297366): Derivatives of this function may not always be correct.
93-
@deprecation.deprecated('2023-03-01',
94-
'`log_cumsum_exp` is deprecated; '
95-
' Use `tf.math.cumulative_logsumexp` instead.')
9692
def log_cumsum_exp(x, axis=-1, name=None):
9793
"""Computes log(cumsum(exp(x))).
9894

0 commit comments

Comments
 (0)