11"""Unit tests for the Lai procedure"""
22
3+ import os
4+ from pathlib import Path
5+
36import numpy as np
47import pytest
58
69from sequentialized_barnard_tests import Decision , Hypothesis
710from sequentialized_barnard_tests .lai import LaiTest , MirroredLaiTest
811
912##### Lai Test #####
13+ paper_data_path = str (
14+ Path (
15+ os .path .join (
16+ os .path .dirname (os .path .abspath (__file__ )),
17+ "../eval_data/" ,
18+ )
19+ ).resolve ()
20+ )
21+ eval_clean_up_spill = np .load (
22+ f"{ paper_data_path } /TRI_CLEAN_SPILL_v4.npy"
23+ ) # Must be flipped for standard form
24+ eval_fold_red_towel = np .load (
25+ f"{ paper_data_path } /TRI_FOLD_RED_TOWEL.npy"
26+ ) # ALREADY in standard form
27+ eval_sim_spoon_on_towel = np .load (
28+ f"{ paper_data_path } /TRI_SIM_SPOON_ON_TOWEL.npy"
29+ ) # Must be flipped for standard form
30+ eval_sim_eggplant_in_basket = np .load (
31+ f"{ paper_data_path } /TRI_SIM_EGGPLANT_IN_BASKET.npy"
32+ ) # Must be flipped for standard form
33+ eval_sim_stack_cube = np .load (
34+ f"{ paper_data_path } /TRI_SIM_STACK_CUBE.npy"
35+ ) # Must be flipped for standard form
1036
1137
1238@pytest .fixture (scope = "module" )
@@ -18,7 +44,7 @@ def lai(request):
1844 calibrate_regularizer = False ,
1945 use_offline_calibration = False ,
2046 )
21- test .set_c (4.3320915613895993e -05 )
47+ test .set_c (5.3077895340120925e -05 )
2248 return test
2349
2450
@@ -62,6 +88,70 @@ def test_lai(lai, sequence_0, sequence_1, expected):
6288 assert result .decision == expected
6389
6490
91+ @pytest .fixture (scope = "module" )
92+ def lai200 (request ):
93+ test = LaiTest (
94+ alternative = request .param ,
95+ n_max = 200 ,
96+ alpha = 0.05 ,
97+ )
98+ test .set_c (0.00014121395942619315 )
99+ return test
100+
101+
102+ @pytest .mark .parametrize (
103+ ("lai200" , "sequence_0" , "sequence_1" , "expected" ),
104+ [
105+ # fmt: off
106+ (Hypothesis .P0LessThanP1 , eval_clean_up_spill [:, 1 ], eval_clean_up_spill [:, 0 ], 13 ),
107+ (Hypothesis .P0MoreThanP1 , eval_clean_up_spill [:, 1 ], eval_clean_up_spill [:, 0 ], 50 ),
108+ (Hypothesis .P0LessThanP1 , eval_clean_up_spill [:, 0 ], eval_clean_up_spill [:, 1 ], 50 ),
109+ (Hypothesis .P0MoreThanP1 , eval_clean_up_spill [:, 0 ], eval_clean_up_spill [:, 1 ], 13 ),
110+ (Hypothesis .P0LessThanP1 , eval_fold_red_towel [:, 0 ], eval_fold_red_towel [:, 1 ], 21 ),
111+ (Hypothesis .P0MoreThanP1 , eval_fold_red_towel [:, 0 ], eval_fold_red_towel [:, 1 ], 50 ),
112+ (Hypothesis .P0LessThanP1 , eval_fold_red_towel [:, 1 ], eval_fold_red_towel [:, 0 ], 50 ),
113+ (Hypothesis .P0MoreThanP1 , eval_fold_red_towel [:, 1 ], eval_fold_red_towel [:, 0 ], 21 ),
114+ # fmt: on
115+ ],
116+ indirect = ["lai200" ],
117+ )
118+ def test_lai200_time (lai200 , sequence_0 , sequence_1 , expected ):
119+ result = lai200 .run_on_sequence (sequence_0 , sequence_1 )
120+ assert np .abs (result .info ["Time" ] - expected ) <= 0.6
121+
122+
123+ @pytest .fixture (scope = "module" )
124+ def lai50 (request ):
125+ test = LaiTest (
126+ alternative = request .param ,
127+ n_max = 50 ,
128+ alpha = 0.05 ,
129+ )
130+ test .set_c (0.000561395711114114 )
131+ return test
132+
133+
134+ @pytest .mark .parametrize (
135+ ("lai50" , "sequence_0" , "sequence_1" , "expected" ),
136+ [
137+ # fmt: off
138+ (Hypothesis .P0LessThanP1 , eval_clean_up_spill [:, 1 ], eval_clean_up_spill [:, 0 ], 8 ),
139+ (Hypothesis .P0MoreThanP1 , eval_clean_up_spill [:, 1 ], eval_clean_up_spill [:, 0 ], 50 ),
140+ (Hypothesis .P0LessThanP1 , eval_clean_up_spill [:, 0 ], eval_clean_up_spill [:, 1 ], 50 ),
141+ (Hypothesis .P0MoreThanP1 , eval_clean_up_spill [:, 0 ], eval_clean_up_spill [:, 1 ], 8 ),
142+ (Hypothesis .P0LessThanP1 , eval_fold_red_towel [:, 0 ], eval_fold_red_towel [:, 1 ], 17 ),
143+ (Hypothesis .P0MoreThanP1 , eval_fold_red_towel [:, 0 ], eval_fold_red_towel [:, 1 ], 50 ),
144+ (Hypothesis .P0LessThanP1 , eval_fold_red_towel [:, 1 ], eval_fold_red_towel [:, 0 ], 50 ),
145+ (Hypothesis .P0MoreThanP1 , eval_fold_red_towel [:, 1 ], eval_fold_red_towel [:, 0 ], 17 ),
146+ # fmt: on
147+ ],
148+ indirect = ["lai50" ],
149+ )
150+ def test_lai50_time (lai50 , sequence_0 , sequence_1 , expected ):
151+ result = lai50 .run_on_sequence (sequence_0 , sequence_1 )
152+ assert np .abs (result .info ["Time" ] - expected ) <= 0.6
153+
154+
65155##### Mirrored Lai Test #####
66156
67157
@@ -99,6 +189,108 @@ def test_mirrored_lai(mirrored_lai, sequence_0, sequence_1, expected):
99189 assert result .decision == expected
100190
101191
192+ @pytest .fixture (scope = "module" )
193+ def mirrored_lai200 (request ):
194+ test = MirroredLaiTest (
195+ alternative = request .param ,
196+ n_max = 200 ,
197+ alpha = 0.05 ,
198+ )
199+ test .set_c (0.00014121395942619315 )
200+ return test
201+
202+
203+ @pytest .mark .parametrize (
204+ ("mirrored_lai200" , "sequence_0" , "sequence_1" , "expected" ),
205+ [
206+ # fmt: off
207+ (Hypothesis .P0LessThanP1 , eval_clean_up_spill [:, 1 ], eval_clean_up_spill [:, 0 ], 13 ),
208+ (Hypothesis .P0MoreThanP1 , eval_clean_up_spill [:, 1 ], eval_clean_up_spill [:, 0 ], 13 ),
209+ (Hypothesis .P0LessThanP1 , eval_clean_up_spill [:, 0 ], eval_clean_up_spill [:, 1 ], 13 ),
210+ (Hypothesis .P0MoreThanP1 , eval_clean_up_spill [:, 0 ], eval_clean_up_spill [:, 1 ], 13 ),
211+ (Hypothesis .P0LessThanP1 , eval_fold_red_towel [:, 0 ], eval_fold_red_towel [:, 1 ], 21 ),
212+ (Hypothesis .P0MoreThanP1 , eval_fold_red_towel [:, 0 ], eval_fold_red_towel [:, 1 ], 21 ),
213+ (Hypothesis .P0LessThanP1 , eval_fold_red_towel [:, 1 ], eval_fold_red_towel [:, 0 ], 21 ),
214+ (Hypothesis .P0MoreThanP1 , eval_fold_red_towel [:, 1 ], eval_fold_red_towel [:, 0 ], 21 ),
215+ # fmt: on
216+ ],
217+ indirect = ["mirrored_lai200" ],
218+ )
219+ def test_mirrored_lai200_time (mirrored_lai200 , sequence_0 , sequence_1 , expected ):
220+ result = mirrored_lai200 .run_on_sequence (sequence_0 , sequence_1 )
221+ assert np .abs (result .info ["Time" ] - expected ) <= 0.6
222+
223+
224+ @pytest .fixture (scope = "module" )
225+ def mirrored_lai50 (request ):
226+ test = MirroredLaiTest (
227+ alternative = request .param ,
228+ n_max = 50 ,
229+ alpha = 0.05 ,
230+ )
231+ test .set_c (0.000561395711114114 )
232+ return test
233+
234+
235+ @pytest .mark .parametrize (
236+ ("mirrored_lai50" , "sequence_0" , "sequence_1" , "expected" ),
237+ [
238+ # fmt: off
239+ (Hypothesis .P0LessThanP1 , eval_clean_up_spill [:, 1 ], eval_clean_up_spill [:, 0 ], 8 ),
240+ (Hypothesis .P0MoreThanP1 , eval_clean_up_spill [:, 1 ], eval_clean_up_spill [:, 0 ], 8 ),
241+ (Hypothesis .P0LessThanP1 , eval_clean_up_spill [:, 0 ], eval_clean_up_spill [:, 1 ], 8 ),
242+ (Hypothesis .P0MoreThanP1 , eval_clean_up_spill [:, 0 ], eval_clean_up_spill [:, 1 ], 8 ),
243+ (Hypothesis .P0LessThanP1 , eval_fold_red_towel [:, 0 ], eval_fold_red_towel [:, 1 ], 17 ),
244+ (Hypothesis .P0MoreThanP1 , eval_fold_red_towel [:, 0 ], eval_fold_red_towel [:, 1 ], 17 ),
245+ (Hypothesis .P0LessThanP1 , eval_fold_red_towel [:, 1 ], eval_fold_red_towel [:, 0 ], 17 ),
246+ (Hypothesis .P0MoreThanP1 , eval_fold_red_towel [:, 1 ], eval_fold_red_towel [:, 0 ], 17 ),
247+ # fmt: on
248+ ],
249+ indirect = ["mirrored_lai50" ],
250+ )
251+ def test_mirrored_lai50_time (mirrored_lai50 , sequence_0 , sequence_1 , expected ):
252+ result = mirrored_lai50 .run_on_sequence (sequence_0 , sequence_1 )
253+ assert np .abs (result .info ["Time" ] - expected ) <= 0.6
254+
255+
256+ @pytest .fixture (scope = "module" )
257+ def mirrored_lai500 (request ):
258+ test = MirroredLaiTest (
259+ alternative = request .param ,
260+ n_max = 500 ,
261+ alpha = 0.01 ,
262+ calibrate_regularizer = False ,
263+ use_offline_calibration = False ,
264+ )
265+ test .set_c (1.013009359863071e-05 )
266+ return test
267+
268+
269+ @pytest .mark .parametrize (
270+ ("mirrored_lai500" , "sequence_0" , "sequence_1" , "expected" ),
271+ [
272+ # fmt: off
273+ (Hypothesis .P0LessThanP1 , eval_sim_spoon_on_towel [:, 1 ], eval_sim_spoon_on_towel [:, 0 ], 36 ),
274+ (Hypothesis .P0MoreThanP1 , eval_sim_spoon_on_towel [:, 1 ], eval_sim_spoon_on_towel [:, 0 ], 36 ),
275+ (Hypothesis .P0LessThanP1 , eval_sim_spoon_on_towel [:, 0 ], eval_sim_spoon_on_towel [:, 1 ], 36 ),
276+ (Hypothesis .P0MoreThanP1 , eval_sim_spoon_on_towel [:, 0 ], eval_sim_spoon_on_towel [:, 1 ], 36 ),
277+ (Hypothesis .P0LessThanP1 , eval_sim_eggplant_in_basket [:, 1 ], eval_sim_eggplant_in_basket [:, 0 ], 125 ),
278+ (Hypothesis .P0MoreThanP1 , eval_sim_eggplant_in_basket [:, 1 ], eval_sim_eggplant_in_basket [:, 0 ], 125 ),
279+ (Hypothesis .P0LessThanP1 , eval_sim_eggplant_in_basket [:, 0 ], eval_sim_eggplant_in_basket [:, 1 ], 125 ),
280+ (Hypothesis .P0MoreThanP1 , eval_sim_eggplant_in_basket [:, 0 ], eval_sim_eggplant_in_basket [:, 1 ], 125 ),
281+ (Hypothesis .P0LessThanP1 , eval_sim_stack_cube [:, 1 ], eval_sim_stack_cube [:, 0 ], 417 ),
282+ (Hypothesis .P0MoreThanP1 , eval_sim_stack_cube [:, 1 ], eval_sim_stack_cube [:, 0 ], 417 ),
283+ (Hypothesis .P0LessThanP1 , eval_sim_stack_cube [:, 0 ], eval_sim_stack_cube [:, 1 ], 417 ),
284+ (Hypothesis .P0MoreThanP1 , eval_sim_stack_cube [:, 0 ], eval_sim_stack_cube [:, 1 ], 417 ),
285+ # fmt: on
286+ ],
287+ indirect = ["mirrored_lai500" ],
288+ )
289+ def test_mirrored_lai500_time (mirrored_lai500 , sequence_0 , sequence_1 , expected ):
290+ result = mirrored_lai500 .run_on_sequence (sequence_0 , sequence_1 )
291+ assert np .abs (result .info ["Time" ] - expected ) <= 0.6
292+
293+
102294##### Offline Calibration Test #####
103295@pytest .mark .parametrize (
104296 ("alpha" , "n_max" ),
0 commit comments