2121"""
2222
2323from collections .abc import Callable
24- from typing import Optional , TypeVar , Union
24+ from typing import TypeVar
2525
2626import jax
2727import jax .numpy as jnp
2828import jax .scipy as jsp
2929import numpy as np
30- from jaxtyping import Array , Float , Shaped
30+ from jaxtyping import Array , ArrayLike , Float , Scalar , Shaped
3131
3232from coreax import Coresubset , Data , SupervisedData
3333from coreax .kernels import SquaredExponentialKernel , SteinKernel , median_heuristic
@@ -61,12 +61,12 @@ class IterativeKernelHerding(KernelHerding[_Data]): # pylint: disable=too-many-
6161 """
6262
6363 num_iterations : int = 1
64- t_schedule : Optional [ Array ] = None
64+ t_schedule : Array | None = None
6565
6666 def reduce (
6767 self ,
6868 dataset : _Data ,
69- solver_state : Optional [ HerdingState ] = None ,
69+ solver_state : HerdingState | None = None ,
7070 ) -> tuple [Coresubset [_Data ], HerdingState ]:
7171 """
7272 Perform Kernel Herding reduction followed by additional refinement iterations.
@@ -118,7 +118,7 @@ def initialise_solvers( # noqa: C901
118118 train_data_umap : Data ,
119119 key : KeyArrayLike ,
120120 cpp_oversampling_factor : int ,
121- leaf_size : Optional [ int ] = None ,
121+ leaf_size : int | None = None ,
122122) -> dict [str , Callable [[int ], Solver ]]:
123123 """
124124 Initialise and return a list of solvers for various coreset algorithms.
@@ -147,7 +147,7 @@ def initialise_solvers( # noqa: C901
147147 kernel = SquaredExponentialKernel (length_scale = length_scale )
148148 sqrt_kernel = kernel .get_sqrt_kernel (16 )
149149
150- def _get_thinning_solver (_size : int ) -> Union [ KernelThinning , MapReduce ] :
150+ def _get_thinning_solver (_size : int ) -> KernelThinning | MapReduce :
151151 """
152152 Set up kernel thinning solver.
153153
@@ -169,7 +169,7 @@ def _get_thinning_solver(_size: int) -> Union[KernelThinning, MapReduce]:
169169 return thinning_solver
170170 return MapReduce (thinning_solver , leaf_size = leaf_size )
171171
172- def _get_herding_solver (_size : int ) -> Union [ KernelHerding , MapReduce ] :
172+ def _get_herding_solver (_size : int ) -> KernelHerding | MapReduce :
173173 """
174174 Set up kernel herding solver.
175175
@@ -185,7 +185,7 @@ def _get_herding_solver(_size: int) -> Union[KernelHerding, MapReduce]:
185185 return herding_solver
186186 return MapReduce (herding_solver , leaf_size = leaf_size )
187187
188- def _get_stein_solver (_size : int ) -> Union [ SteinThinning , MapReduce ] :
188+ def _get_stein_solver (_size : int ) -> SteinThinning | MapReduce :
189189 """
190190 Set up Stein thinning solver.
191191
@@ -199,20 +199,19 @@ def _get_stein_solver(_size: int) -> Union[SteinThinning, MapReduce]:
199199 kde = jsp .stats .gaussian_kde (train_data_umap .data [idx ].T )
200200
201201 # Define the score function as the gradient of log density given by the KDE
202- def score_function (
203- x : Union [Shaped [Array , " n d" ], Shaped [Array , "" ], float , int ],
204- ) -> Union [Shaped [Array , " n d" ], Shaped [Array , " 1 1" ]]:
202+ def score_function (x : Shaped [ArrayLike , " n d" ]) -> Shaped [Array , " n d" ]:
205203 """
206204 Compute the score function (gradient of log density) for a single point.
207205
208206 :param x: Input point represented as array.
209207 :return: Gradient of log probability density at the given point.
210208 """
211209
212- def logpdf_single (x : Shaped [Array , " d" ]) -> Shaped [ Array , "" ] :
210+ def logpdf_single (x : Shaped [Array , " d" ]) -> Scalar :
213211 return kde .logpdf (x .reshape (1 , - 1 ))[0 ]
214212
215- return jax .grad (logpdf_single )(x )
213+ _x = jnp .asarray (x )
214+ return jax .grad (logpdf_single )(_x )
216215
217216 stein_kernel = SteinKernel (
218217 base_kernel = kernel ,
@@ -264,7 +263,7 @@ def _get_compress_solver(_size: int) -> CompressPlusPlus:
264263
265264 def _get_probabilistic_herding_solver (
266265 _size : int ,
267- ) -> Union [ IterativeKernelHerding , MapReduce ] :
266+ ) -> IterativeKernelHerding | MapReduce :
268267 """
269268 Set up KernelHerding with probabilistic selection.
270269
@@ -289,7 +288,7 @@ def _get_probabilistic_herding_solver(
289288
290289 def _get_iterative_herding_solver (
291290 _size : int ,
292- ) -> Union [ IterativeKernelHerding , MapReduce ] :
291+ ) -> IterativeKernelHerding | MapReduce :
293292 """
294293 Set up KernelHerding with probabilistic selection.
295294
@@ -314,7 +313,7 @@ def _get_iterative_herding_solver(
314313
315314 def _get_cubic_iterative_herding_solver (
316315 _size : int ,
317- ) -> Union [ IterativeKernelHerding , MapReduce ] :
316+ ) -> IterativeKernelHerding | MapReduce :
318317 """
319318 Set up KernelHerding with probabilistic selection.
320319
0 commit comments