2828 Generic ,
2929 Literal ,
3030 NamedTuple ,
31- Optional ,
3231 TypeVar ,
33- Union ,
3432 cast ,
3533)
3634from unittest .mock import MagicMock , patch
9189_RefinementSolver = TypeVar ("_RefinementSolver" , bound = RefinementSolver )
9290
9391if TYPE_CHECKING :
94- # In Python 3.9-3. 10, this raises
92+ # In Python 3.10, this raises
9593 # `TypeError: Multiple inheritance with NamedTuple is not supported`.
9694 # Thus, we have to do the actual full typing here, and a non-generic one
9795 # below to be used at runtime.
9896 class _ReduceProblem (NamedTuple , Generic [_Data , _Solver ]):
9997 dataset : _Data
10098 solver : _Solver
101- expected_coreset : Optional [ AbstractCoreset ] = None
99+ expected_coreset : AbstractCoreset | None = None
102100
103101 class _RefineProblem (NamedTuple , Generic [_RefinementSolver ]):
104102 initial_coresubset : Coresubset
105103 solver : _RefinementSolver
106- expected_coresubset : Optional [ Coresubset ] = None
104+ expected_coresubset : Coresubset | None = None
107105else :
108106 # This is the implementation that's used at runtime.
109107 class _ReduceProblem (NamedTuple ):
110108 dataset : _Data
111109 solver : _Solver
112- expected_coreset : Optional [ AbstractCoreset ] = None
110+ expected_coreset : AbstractCoreset | None = None
113111
114112 class _RefineProblem (NamedTuple ):
115113 initial_coresubset : Coresubset
116114 solver : _RefinementSolver
117- expected_coresubset : Optional [ Coresubset ] = None
115+ expected_coresubset : Coresubset | None = None
118116
119117
120118class SolverTest :
@@ -151,7 +149,7 @@ def reduce_problem(
151149 return _ReduceProblem (Data (dataset ), solver , expected_coreset )
152150
153151 def check_solution_invariants (
154- self , coreset : AbstractCoreset , problem : Union [ _RefineProblem , _ReduceProblem ]
152+ self , coreset : AbstractCoreset , problem : _RefineProblem | _ReduceProblem
155153 ) -> None :
156154 """
157155 Check that a coreset obeys certain expected invariant properties.
@@ -796,7 +794,7 @@ def test_functions_impl(x):
796794
797795 @override
798796 def check_solution_invariants (
799- self , coreset : AbstractCoreset , problem : Union [ _RefineProblem , _ReduceProblem ]
797+ self , coreset : AbstractCoreset , problem : _RefineProblem | _ReduceProblem
800798 ) -> None :
801799 r"""
802800 Check that a coreset obeys certain expected invariant properties.
@@ -1006,7 +1004,7 @@ class ExplicitSizeSolverTest(SolverTest):
10061004
10071005 @override
10081006 def check_solution_invariants (
1009- self , coreset : AbstractCoreset , problem : Union [ _RefineProblem , _ReduceProblem ]
1007+ self , coreset : AbstractCoreset , problem : _RefineProblem | _ReduceProblem
10101008 ) -> None :
10111009 super ().check_solution_invariants (coreset , problem )
10121010 solver = problem .solver
@@ -1026,7 +1024,7 @@ def check_solution_invariants(
10261024 def test_check_init (
10271025 self ,
10281026 solver_factory : jtu .Partial ,
1029- coreset_size : Union [ int , float , str ] ,
1027+ coreset_size : int | float | str ,
10301028 context : AbstractContextManager ,
10311029 ) -> None :
10321030 """
@@ -1073,7 +1071,7 @@ def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial:
10731071 def reduce_problem (
10741072 self ,
10751073 request : pytest .FixtureRequest ,
1076- solver_factory : Union [ type [Solver ], jtu .Partial ] ,
1074+ solver_factory : type [Solver ] | jtu .Partial ,
10771075 ) -> _ReduceProblem :
10781076 if request .param == "random" :
10791077 dataset = jr .uniform (self .random_key , self .shape )
@@ -1707,7 +1705,7 @@ class TestRandomSample(ExplicitSizeSolverTest):
17071705
17081706 @override
17091707 def check_solution_invariants (
1710- self , coreset : AbstractCoreset , problem : Union [ _RefineProblem , _ReduceProblem ]
1708+ self , coreset : AbstractCoreset , problem : _RefineProblem | _ReduceProblem
17111709 ) -> None :
17121710 super ().check_solution_invariants (coreset , problem )
17131711 solver = cast (RandomSample , problem .solver )
@@ -1730,7 +1728,7 @@ class TestRPCholesky(ExplicitSizeSolverTest):
17301728
17311729 @override
17321730 def check_solution_invariants (
1733- self , coreset : AbstractCoreset , problem : Union [ _RefineProblem , _ReduceProblem ]
1731+ self , coreset : AbstractCoreset , problem : _RefineProblem | _ReduceProblem
17341732 ) -> None :
17351733 """Check functionality of 'unique' in addition to the default checks."""
17361734 super ().check_solution_invariants (coreset , problem )
@@ -2044,7 +2042,7 @@ def solver_factory(self, request: pytest.FixtureRequest) -> jtu.Partial:
20442042 ],
20452043 )
20462044 def test_regulariser_lambda (
2047- self , test_lambda : Optional [ Union [ float , int ]] , reduce_problem : _ReduceProblem
2045+ self , test_lambda : float | int | None , reduce_problem : _ReduceProblem
20482046 ) -> None :
20492047 """Basic checks for the regularisation parameter, lambda."""
20502048 dataset , base_solver , _ = reduce_problem
@@ -2411,7 +2409,7 @@ def solver_factory(self, request) -> jtu.Partial:
24112409 def reduce_problem (
24122410 self ,
24132411 request : pytest .FixtureRequest ,
2414- solver_factory : Union [ type [Solver ], jtu .Partial ] ,
2412+ solver_factory : type [Solver ] | jtu .Partial ,
24152413 ) -> _ReduceProblem :
24162414 if request .param == "random" :
24172415 data_key , supervision_key = jr .split (self .random_key )
@@ -2427,7 +2425,7 @@ def reduce_problem(
24272425
24282426 @override
24292427 def check_solution_invariants (
2430- self , coreset : AbstractCoreset , problem : Union [ _RefineProblem , _ReduceProblem ]
2428+ self , coreset : AbstractCoreset , problem : _RefineProblem | _ReduceProblem
24312429 ) -> None :
24322430 """Check functionality of 'unique' in addition to the default checks."""
24332431 super ().check_solution_invariants (coreset , problem )
@@ -2796,7 +2794,7 @@ def __init__(self, _data: np.ndarray, **kwargs):
27962794 del kwargs
27972795 self .data = _data
27982796
2799- def get_arrays (self ) -> tuple [Union [ np .ndarray , None ] , ...]:
2797+ def get_arrays (self ) -> tuple [np .ndarray | None , ...]:
28002798 """Mock sklearn.neighbours.BinaryTree.get_arrays method."""
28012799 return None , np .arange (len (self .data )), None , None
28022800
0 commit comments