Skip to content

Commit 169f7f5

Browse files
committed
Rewrite apply_slice_along_axis using np.vectorize
1 parent 26f4f12 commit 169f7f5

File tree

1 file changed

+20
-22
lines changed

1 file changed

+20
-22
lines changed

tensorflow_probability/python/stats/sample_stats_test.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -699,24 +699,26 @@ def apply_slice_along_axis(self, func, arr, low, high, axis):
699699
specified through `low` and `high`. Support broadcasting.
700700
"""
701701
np.testing.assert_equal(low.shape, high.shape)
702-
ni, _, nk = arr.shape[:axis], arr.shape[axis], arr.shape[axis + 1:]
703-
si, j, sk = low.shape[:axis], low.shape[axis], low.shape[axis + 1:]
704-
mk = max(nk, sk)
705-
mi = max(ni, si)
706-
out = np.empty(mi + (j,) + mk)
707-
for ki in np.ndindex(ni):
708-
for kk in np.ndindex(mk):
709-
ak = tuple(np.mod(kk, nk))
710-
ik = tuple(np.mod(kk, sk))
711-
ai = tuple(np.mod(ki, ni))
712-
ii = tuple(np.mod(ki, si))
713-
a_1d = arr[ai + np.s_[:, ] + ak]
714-
out_1d = out[ki + np.s_[:, ] + kk]
715-
low_1d = low[ii + np.s_[:, ] + ik]
716-
high_1d = high[ii + np.s_[:, ] + ik]
717-
718-
for r in range(j):
719-
out_1d[r] = func(a_1d[low_1d[r]:high_1d[r]])
702+
703+
def apply_func(vector, l, h):
704+
return func(vector[l:h])
705+
706+
apply_func_1d = np.vectorize(apply_func, signature='(n), (), ()->()')
707+
vectorized_func = np.vectorize(apply_func_1d,
708+
signature='(n), (k), (k)->(m)')
709+
710+
# Put `axis` at the innermost dimension
711+
dims = list(range(arr.ndim))
712+
dims[-1] = axis
713+
dims[axis] = arr.ndim - 1
714+
t_arr = np.transpose(arr, axes=dims)
715+
t_low = np.transpose(low, axes=dims)
716+
t_high = np.transpose(high, axes=dims)
717+
718+
t_out = vectorized_func(t_arr, t_low, t_high)
719+
720+
# Replace `axis` at its place
721+
out = np.transpose(t_out, axes=dims)
720722
return out
721723

722724
def check_gaussian_windowed(self, shape, indice_shape, axis,
@@ -797,10 +799,6 @@ def test_windowed_mean_graph(self):
797799
def test_windowed_variance(self):
798800
self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var)
799801

800-
def test_windowed_variance_graph(self):
801-
func = tf.function(sample_stats.windowed_variance)
802-
self.check_windowed(func=func, numpy_func=np.var)
803-
804802

805803
@test_util.test_all_tf_execution_regimes
806804
class LogAverageProbsTest(test_util.TestCase):

0 commit comments

Comments
 (0)