@@ -699,24 +699,26 @@ def apply_slice_along_axis(self, func, arr, low, high, axis):
699
699
specified through `low` and `high`. Support broadcasting.
700
700
"""
701
701
np .testing .assert_equal (low .shape , high .shape )
702
- ni , _ , nk = arr .shape [:axis ], arr .shape [axis ], arr .shape [axis + 1 :]
703
- si , j , sk = low .shape [:axis ], low .shape [axis ], low .shape [axis + 1 :]
704
- mk = max (nk , sk )
705
- mi = max (ni , si )
706
- out = np .empty (mi + (j ,) + mk )
707
- for ki in np .ndindex (ni ):
708
- for kk in np .ndindex (mk ):
709
- ak = tuple (np .mod (kk , nk ))
710
- ik = tuple (np .mod (kk , sk ))
711
- ai = tuple (np .mod (ki , ni ))
712
- ii = tuple (np .mod (ki , si ))
713
- a_1d = arr [ai + np .s_ [:, ] + ak ]
714
- out_1d = out [ki + np .s_ [:, ] + kk ]
715
- low_1d = low [ii + np .s_ [:, ] + ik ]
716
- high_1d = high [ii + np .s_ [:, ] + ik ]
717
-
718
- for r in range (j ):
719
- out_1d [r ] = func (a_1d [low_1d [r ]:high_1d [r ]])
702
+
703
+ def apply_func (vector , l , h ):
704
+ return func (vector [l :h ])
705
+
706
+ apply_func_1d = np .vectorize (apply_func , signature = '(n), (), ()->()' )
707
+ vectorized_func = np .vectorize (apply_func_1d ,
708
+ signature = '(n), (k), (k)->(m)' )
709
+
710
+ # Put `axis` at the innermost dimension
711
+ dims = list (range (arr .ndim ))
712
+ dims [- 1 ] = axis
713
+ dims [axis ] = arr .ndim - 1
714
+ t_arr = np .transpose (arr , axes = dims )
715
+ t_low = np .transpose (low , axes = dims )
716
+ t_high = np .transpose (high , axes = dims )
717
+
718
+ t_out = vectorized_func (t_arr , t_low , t_high )
719
+
720
+ # Replace `axis` at its place
721
+ out = np .transpose (t_out , axes = dims )
720
722
return out
721
723
722
724
def check_gaussian_windowed (self , shape , indice_shape , axis ,
@@ -797,10 +799,6 @@ def test_windowed_mean_graph(self):
797
799
def test_windowed_variance (self ):
798
800
self .check_windowed (func = sample_stats .windowed_variance , numpy_func = np .var )
799
801
800
- def test_windowed_variance_graph (self ):
801
- func = tf .function (sample_stats .windowed_variance )
802
- self .check_windowed (func = func , numpy_func = np .var )
803
-
804
802
805
803
@test_util .test_all_tf_execution_regimes
806
804
class LogAverageProbsTest (test_util .TestCase ):
0 commit comments